In [1]:
import os
import re
from typing import List, Optional

import click
import dnnlib
import numpy as np
import torch

import legacy

import librosa
import librosa.display
import soundfile as sf

from tifresi.utils import load_signal
from tifresi.utils import preprocess_signal
from tifresi.stft import GaussTF, GaussTruncTF
from tifresi.transforms import log_spectrogram
from tifresi.transforms import inv_log_spectrogram

import matplotlib.pyplot as plt
np.seterr(divide='ignore', invalid='ignore')
import warnings
import matplotlib.cbook
warnings.filterwarnings("ignore",category=matplotlib.cbook.mplDeprecation)

from IPython.display import Audio 
import IPython

%load_ext autoreload
%autoreload 2

In [2]:
# STFT / audio settings shared by the inverse transform and the plots below.
stft_channels, hop_size = 512, 128   # STFT channels and hop size (samples)
n_frames = 256                       # presumably frames per spectrogram; not used below
sample_rate = 16000                  # Hz, used for writing WAVs and plotting

def pghi_istft(x, use_truncated_window=True):
    """Invert a log-magnitude spectrogram with a leading singleton axis to audio.

    The spectrogram is mapped back from the log domain with
    ``inv_log_spectrogram`` and then inverted by the tifresi STFT system's
    ``invert_spectrogram`` (phase reconstruction, PGHI per the function name),
    using the module-level ``hop_size`` and ``stft_channels`` settings.

    Args:
        x: Array whose axis 0 has length 1 (it is squeezed away); the
            remaining axes are the log-spectrogram passed to tifresi.
        use_truncated_window: If True (the previous hard-coded behaviour),
            use ``GaussTruncTF``; otherwise use the untruncated ``GaussTF``.

    Returns:
        The reconstructed time-domain signal returned by
        ``invert_spectrogram``.

    Raises:
        ValueError: from ``np.squeeze`` if axis 0 does not have length 1.
    """
    tf_class = GaussTruncTF if use_truncated_window else GaussTF
    stft_system = tf_class(hop_size=hop_size, stft_channels=stft_channels)

    log_spec = np.squeeze(x, axis=0)
    magnitude = inv_log_spectrogram(log_spec)
    return stft_system.invert_spectrogram(magnitude)

In [3]:
# --- Checkpoint selection -----------------------------------------------------
# GreatestHits run (active)
checkpoint_num = '2200'
network_pkl = f'training-runs/00041-vis-data-256-split-auto1-noaug/network-snapshot-00{checkpoint_num}.pkl'
output_dir = f'sefa-results/greatesthits/sefa-00041-{checkpoint_num}'

# TokWotel run (swap in to analyse this model instead)
# checkpoint_num = '0200'
# network_pkl = f'training-runs/00040-tokwotel-auto1-noaug/network-snapshot-00{checkpoint_num}.pkl'
# output_dir = f'sefa-results/tokwotel/sefa-00040-{checkpoint_num}'


In [4]:

# --- SeFa sweep configuration ----------------------------------------------------
num_samples = 7      # random latent codes to edit
num_semantics = 7    # directions (largest eigenvalues first) to visualise
distance_value = 5   # the sweep covers [-distance_value, +distance_value]
start_distance, end_distance = -float(distance_value), float(distance_value)
num_steps = 11
distances = np.linspace(start_distance, end_distance, num_steps)

# Synthesis-network blocks whose torgb affine weights are factorised.
# Other presets tried: ['b4','b8','b16','b32'], ['b4','b8','b16'],
# ['b16','b32','b64'], ['b64','b128','b256'], ['b8','b16','b32','b64'].
base_layers = ['b4', 'b8', 'b16', 'b32', 'b64', 'b128', 'b256']
layers_identifier = f"{'-'.join(base_layers)}-{distance_value}dist-G"

# Every block name is listed twice, in resolution order (author's rationale:
# the style vector y is twice the number of feature maps, see the first
# StyleGAN paper). With all 7 blocks this gives 14 entries, matching axis 1 of
# `codes` (shape (7, 14, 128) in the saved output). The SeFa cell uses the
# positions in this list as w-indices.
# NOTE(review): confirm that mapping for the subset presets.
layers = sorted(base_layers * 2, key=lambda name: int(name[1:]))
print(layers)

sefa_output_dir = os.path.join(output_dir, layers_identifier)
sefa_output_audio_dir = os.path.join(sefa_output_dir, 'audio')
for _dir in (sefa_output_dir, sefa_output_audio_dir):
    os.makedirs(_dir, exist_ok=True)


['b4', 'b4', 'b8', 'b8', 'b16', 'b16', 'b32', 'b32', 'b64', 'b64', 'b128', 'b128', 'b256', 'b256']


In [5]:
device = torch.device('cuda')

# Load the snapshot once and take both generators from it (it was previously
# opened and unpickled twice).
# NOTE: network snapshots are pickles; unpickling can execute arbitrary code,
# so only load checkpoints you trust.
with dnnlib.util.open_url(network_pkl) as f:
    network = legacy.load_network_pkl(f)
G = network['G'].to(device).eval()
G_ema = network['G_ema'].to(device).eval()
del network  # release references to the other unpickled entries

# The analysis below factorises and samples the raw (non-EMA) generator;
# assign G_ema here instead to analyse the EMA weights.
generator = G

In [6]:
# SeFa: closed-form factorisation of the torgb style-affine weights.
# For each selected block, take the transposed torgb affine weight and stack
# them column-wise. Rows presumably index the w dimension (128 here, matching
# the last axis of `codes`) — TODO confirm.
weights = [
    getattr(generator.synthesis, layer_name).torgb.affine.weight.T.detach().cpu().numpy()
    for layer_name in layers
]
# Positions in `layers`; the editing loop uses them as indices into axis 1 of `codes`.
layer_ids = list(range(len(layers)))

weight = np.concatenate(weights, axis=1).astype(np.float32)
weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)  # unit-norm columns
# W W^T is symmetric, so use eigh. It guarantees real eigenvalues and
# orthonormal eigenvectors, whereas eig may return complex output from
# round-off. Eigenvector signs are arbitrary, so a direction's +/- sweep
# may be mirrored relative to an eig-based run.
eigen_values, eigen_vectors = np.linalg.eigh(weight.dot(weight.T))
boundaries, values = eigen_vectors.T, eigen_values  # one direction per row of `boundaries`

print(boundaries.shape, values.shape)

# Sort by decreasing eigenvalue. The stable sort keeps ties in their original
# order, and `values_ind` stays an integer index array into `boundaries`.
values_ind = np.argsort(-values, kind='stable')
values = values[values_ind]

print(values, values_ind)

(128, 128) (128,)
[138.17625427  83.62195587  80.00395203  77.82159424  76.91971588
  75.88919067  74.14586639  73.17889404  70.00765991  68.87132263
  68.37974548  68.07419586  67.04872894  65.73930359  64.77319336
  64.22251892  63.7420311   62.47785568  62.1232872   61.59173965
  61.15961838  59.89933014  59.53933716  59.04004288  58.66625214
  58.45537186  57.56650925  56.87845612  56.6220932   56.16900253
  55.30134201  55.15444183  54.95201492  54.40515518  53.87666321
  53.28897858  52.83316422  52.18307495  52.0953331   51.51014328
  50.99672318  50.49062729  49.9239006   49.6997261   49.2090683
  48.56822205  48.34690857  48.03170776  47.24629211  46.78878784
  46.36758804  45.79691315  45.29172516  44.97393036  44.85342789
  43.9659462   43.59421539  42.97405243  42.86740112  42.49003601
  41.81612015  41.35634995  40.58515167  40.48667145  40.2219429
  39.52517319  38.50452805  38.13554001  37.80888367  37.39937592
  37.26199722  36.71719742  36.04849625  35.58002472  35.023

In [7]:
# Optional fixed seed for reproducible samples (e.g. 987348234).
# None keeps fresh random codes on every run (the previous behaviour).
seed = None
if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)

# Prepare codes: draw z ~ N(0, I) on the CPU (same RNG stream as before), then
# map it to per-layer w codes. Inference only, so no autograd graph is needed.
with torch.no_grad():
    z = torch.randn(num_samples, generator.z_dim).to(device)
    codes = generator.mapping(z, None)
codes = codes.cpu().numpy()
codes.shape

Setting up PyTorch plugin "bias_act_plugin"... Done.


(7, 14, 128)

In [None]:
def _to_log_spectrogram(image):
    """Map one generated uint8 image of shape (C, H, W) to a (C, H + 1, W) log-spectrogram in [-50, 0].

    A row filled with the image minimum is appended along axis 1. This is
    presumably the frequency bin dropped to make training spectrograms H=256
    high; 257 = stft_channels // 2 + 1. The uint8 range [0, 255] is then
    rescaled linearly to [-50, 0].
    """
    filler = np.full((1, 1, image.shape[-1]), np.min(image))
    padded = np.append(image, filler, axis=1)
    return -50 + (padded / 255) * 50


# For every sample and top semantic, sweep the latent code along the SeFa
# direction. Each step is written as a WAV plus a spectrogram PNG, named so
# the HTML page can reference them.
for sample_id in range(num_samples):
    code = codes[sample_id:sample_id + 1]

    for semantic_id in range(num_semantics):
        # Row of `boundaries` with the (semantic_id)-th largest eigenvalue.
        val_sorted_ind = int(values_ind[semantic_id])
        boundary = boundaries[val_sorted_ind:val_sorted_ind + 1]

        for dist in distances:
            temp_code = code.copy()
            temp_code[:, layer_ids, :] += boundary * dist
            with torch.no_grad():  # inference only; don't build autograd graphs
                image = generator.synthesis(torch.from_numpy(temp_code).to(device))
            image = (image * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            image = image.cpu().numpy()[0]
            log_spec = _to_log_spectrogram(image)

            audio = pghi_istft(log_spec)

            stem = 'samplenum_' + str(sample_id) + '_semantic_' + str(semantic_id) + '_distance_' + str(dist)
            sf.write(os.path.join(sefa_output_audio_dir, stem + '.wav'), audio.astype(float), sample_rate)

            fig = plt.figure()
            # specshow defaults to hop_length=512; pass the real STFT hop so the
            # time axis is labelled correctly.
            librosa.display.specshow(log_spec[0], y_axis='log', sr=sample_rate,
                                     hop_length=hop_size, x_axis='time')
            fig.savefig(os.path.join(sefa_output_audio_dir, stem + '.png'))
            plt.close(fig)  # avoid accumulating hundreds of open figures

Setting up PyTorch plugin "upfirdn2d_plugin"... Done.


In [None]:
# HTML templates for the SeFa results page. They are filled with str.format,
# so literal braces must not appear in them.

# Page skeleton: loads Open Sans and Bootstrap 5 CSS from CDNs;
# {html_data} receives the generated grid of header and sample rows.
overall_html_template_str = '''<html>
    <head>
        <meta http-equiv="Content-Type" content="text/html; charset=UTF-8">

        <meta charset="utf-8">
        <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=yes">
        <link rel="preconnect" href="https://fonts.gstatic.com">
        <link
            href="https://fonts.googleapis.com/css2?family=Open+Sans:ital,wght@0,300;0,400;0,600;0,700;0,800;1,300;1,400;1,600;1,700;1,800&display=swap"
            rel="stylesheet">
        <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.0.0-beta2/dist/css/bootstrap.min.css" rel="stylesheet"
            integrity="sha384-BmbxuPwQa2lc/FVzBcNJ7UAyJxM6wuqIj61tLrc4wSX0szH/Ev+nYRRuWlolflfl" crossorigin="anonymous">
    </head>
    <body>

        {html_data}

    </body>
</html>'''

# Header row wrapper; {header_row} is the concatenated header cells.
header_template_str = '''
<div class='row'>
{header_row}
</div>
'''

# First header cell: the semantic's index label and its eigenvalue.
header_template_first_col_str = '''
<div class='col-1 border'>({semantic_id}) Semantic ({semantic_val})</div>
'''

# Remaining header cells: one per sweep distance.
header_template_remaining_col_str = '''
<div class='col-1 border'>Direction ({direction_val})</div>
'''


# Wrapper for sample cells; {sample_row} is the concatenated cells.
sample_template_str = '''
<div class='row text-wrap'>
{sample_row}
</div>
'''

# First cell of a sample row: the sample index.
sample_template_first_col_str = '''
<div class='col-1 border text-wrap'>Sample ({sample_id})</div>
'''


# One grid cell for a (sample, semantic, distance) triple: the spectrogram PNG
# and a player for the WAV. Paths are relative to the page and mirror the
# filenames written by the generation loop into the 'audio' subdirectory.
sample_template_remaining_col_str = '''
<div class='col-1 border text-wrap'>
    <img width='100%' src='audio/samplenum_{sample_id}_semantic_{semantic_id}_distance_{dist}.png'/><br/>
    <audio  style='width:100%' controls>
        <source src='audio/samplenum_{sample_id}_semantic_{semantic_id}_distance_{dist}.wav' type='audio/wav'/>
    </audio>
</div>
'''


In [None]:
# Assemble the results page. Each semantic gets a header row (label,
# eigenvalue, sweep distances), followed by one row per sample holding the
# spectrogram and audio cell for every distance.
body_parts = []
for sem_id in range(num_semantics):
    header_cells = [header_template_first_col_str.format(
        semantic_val='{:.2f}'.format(values[sem_id]),
        semantic_id=int(sem_id) + 1,  # 1-based label for display only
    )]
    header_cells += [header_template_remaining_col_str.format(direction_val='{:.2f}'.format(dist))
                     for dist in distances]
    body_parts.append(header_template_str.format(header_row=''.join(header_cells)))

    for sample_id in range(num_samples):
        sample_cells = [sample_template_first_col_str.format(sample_id=sample_id)]
        # semantic_id here is the 0-based id used in the generated filenames.
        sample_cells += [sample_template_remaining_col_str.format(sample_id=sample_id,
                                                                  semantic_id=sem_id,
                                                                  dist=dist)
                         for dist in distances]
        # One Bootstrap row per sample. Previously every sample of a semantic
        # was packed into a single wrapping row.
        body_parts.append(sample_template_str.format(sample_row=''.join(sample_cells)))

body_html = ''.join(body_parts)
overall_html = overall_html_template_str.format(html_data=body_html)
# The page declares UTF-8, so write UTF-8 regardless of the platform default.
with open(os.path.join(sefa_output_dir, 'sefa.html'), 'w', encoding='utf-8') as sefaf:
    sefaf.write(overall_html)