In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import IPython
from IPython.display import clear_output

import torch
from torch.utils import data
from torchvision import transforms
import numpy as np
from IPython.display import Audio
import librosa
import librosa.display
from tqdm import tqdm
import h5py
import os
from pathlib import Path

from GANsynth_pytorch.pytorch_nsynth_lib.nsynth import (
    NSynth, WavToSpectrogramDataLoader)

from GANsynth_pytorch import phase_operation
from GANsynth_pytorch.utils import plots
from GANsynth_pytorch import spec_ops
from GANsynth_pytorch import phase_operation as phase_op
from GANsynth_pytorch import spectrograms_helper as spec_helper

In [None]:
# Dataset configuration: which NSynth split to load and where the data lives.
subset = 'valid'

base_path = Path('~/code/data/nsynth/').expanduser()
subsets = ['train', 'valid']

# Audio directories of every split are passed to NSynth; `json_data_paths[subset]`
# selects the examples.json file used for this dataset.
audio_directory_paths = [base_path / split / 'json_wav/audio/'
                         for split in subsets]

json_data_paths = {split: base_path / split / 'json_wav/examples.json'
                   for split in subsets}

# Resolved relative to the user's home directory (like `base_path`) instead of a
# hard-coded /home/<user> path, so the notebook also runs on other machines.
balanced_splits_base_path = Path(
    '~/code/data/nsynth-balanced-split-fixed_seed/').expanduser()
balanced_splits_json_data_paths = {
    split: balanced_splits_base_path / split / 'examples.json'
    for split in subsets}

# use instrument_family and pitch as classification targets
dataset = NSynth(audio_directory_paths=audio_directory_paths,
                 json_data_path=json_data_paths[subset],
                 categorical_field_list=["instrument_family", "pitch"],
                 valid_pitch_range=[24, 84]
                )
# Assumes a constant sampling rate across the dataset: read it from the first example.
FS_HZ = dataset[0][3]['sample_rate']

In [None]:
# Spectrogram settings shared by both loaders.
HOP_LENGTH = 512

use_mel_scale = True
mel_break_frequency_hertz = 700

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Keyword arguments common to the ordered and the shuffled loader, written once so
# the two cannot drift apart.
loader_kwargs = dict(batch_size=1,
                     device=DEVICE,
                     use_mel_scale=use_mel_scale,
                     mel_break_frequency_hertz=mel_break_frequency_hertz,
                     fs_hz=FS_HZ,
                     hop_length=HOP_LENGTH)

loader = WavToSpectrogramDataLoader(dataset, shuffle=False, **loader_kwargs)
shuffled_loader = WavToSpectrogramDataLoader(dataset, shuffle=True, **loader_kwargs)

In [None]:
# Inspect the raw structure of one dataset item (audio tensor, targets, metadata).
first_example = dataset[0]
print(first_example)

In [None]:
def expand(mat, n_repeats=2):
    """Append `n_repeats` copies of the last column of a 2-D array.

    Used to pad the time axis of an STFT-derived matrix; the export cell relies on
    this to reach the 128 columns it asserts.

    Args:
        mat: 2-D array-like, indexed as (rows, columns).
        n_repeats: number of copies of the last column to append (default 2).

    Returns:
        A new array with `mat.shape[1] + n_repeats` columns.
    """
    # Slice with `-1:` to keep the column 2-D, so it can be stacked horizontally.
    # Using the actual last column avoids relying on a fixed column count.
    last_column = mat[:, -1:]
    return np.hstack([mat] + [last_column] * n_repeats)

# Visualization of the computed representations

Re-run the cell below to visualize the representations of a different, randomly drawn input.

In [None]:
# Draw one batch from the shuffled loader; a fresh iterator is created on every
# run of this cell, so each run shows a different example.
first_batch = next(iter(shuffled_loader))
samples, *_, targets = first_batch
sample_name = targets['note_str'][0]

pitch = targets['pitch'].data.numpy()[0]

# Remove the singleton batch dimension and move the data to host memory for plotting.
sample = samples.data.cpu().numpy().squeeze()
# presumably channel 0 is the mel magnitude and channel 1 the mel IF — TODO confirm
first_channel, second_channel = sample[0], sample[1]
plots.plot_mel_representations(first_channel, second_channel,
                               hop_length=HOP_LENGTH, fs_hz=FS_HZ)

In [None]:
import scipy.signal
window_length = 1023

# Sanity check that torch's and scipy's Hann windows agree.
# `torch.hann_window` is *periodic* by default, while `scipy.signal.windows.hann`
# is *symmetric* by default, so the two conventions are compared explicitly.
# Both windows are computed in float64 (scipy's native dtype) to avoid a dtype mismatch.
for periodic in (True, False):
    window_scipy = torch.from_numpy(
        scipy.signal.windows.hann(window_length, sym=not periodic))
    window_torch = torch.hann_window(window_length, periodic=periodic,
                                     dtype=torch.float64)
    print(f"periodic={periodic}:",
          torch.nn.functional.mse_loss(window_scipy, window_torch))

In [None]:
# Round-trip test: audio -> (mel) spectrogram -> audio, for several mel break
# frequencies, with plots, playback and a normalized reconstruction error.
target_n_fft = 2048
mel_downscale = 1
n_fft = target_n_fft * mel_downscale
window_length = target_n_fft // 2
hop_length = window_length // 4

# Example used for the round-trip test; set to None to draw a random example instead.
SAMPLE_INDEX = 7871
sample_index = (np.random.randint(0, len(dataset)) if SAMPLE_INDEX is None
                else SAMPLE_INDEX)
print(sample_index)
sample = dataset[sample_index]
sample_name = sample[3]['note_str']
print("sample_name", sample_name)
audio = sample[0].flatten()

# `librosa.display.waveplot` was removed in librosa 0.10 in favour of `waveshow`;
# use whichever one the installed version provides.
plot_waveform = getattr(librosa.display, 'waveshow', None) or librosa.display.waveplot


def peak_normalize(signal):
    """Return `signal` scaled so that its largest absolute value is 1.

    An all-zero signal is returned unchanged instead of producing NaNs.
    """
    peak = signal.abs().max()
    return signal / peak if peak > 0 else signal


# original audio
IPython.display.display(IPython.display.Audio(audio.cpu(), rate=FS_HZ))
plot_waveform(audio.numpy(), sr=FS_HZ)
plt.show()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
for my_mel_break_frequency in [80, 120, 180, 200, 700, 2000]:
    # The loader is only used for its audio <-> spectrogram transforms, hence the
    # empty dataset.
    my_loader = WavToSpectrogramDataLoader([], batch_size=1, shuffle=False,
                                           device=device,
                                           use_mel_scale=True,
                                           lower_edge_hertz=20.,
                                           upper_edge_hertz=8000.,
                                           mel_break_frequency_hertz=my_mel_break_frequency,
                                           n_fft=n_fft,
                                           hop_length=hop_length,
                                           window_length=window_length,
                                           mel_downscale=mel_downscale,
                                           fs_hz=FS_HZ,
                                           expand_resolution_factor=1.1)

    spectrogram = my_loader.to_spectrogram(audio.unsqueeze(0))
    spectrogram_np = spectrogram.squeeze(0).cpu().numpy()
    reconstructed_audio = my_loader.to_audio(spectrogram).cpu()
    reconstruction = reconstructed_audio[0]

    print(spectrogram.shape)
    plots.plot_mel_representations(spectrogram_np[0], spectrogram_np[1],
                                   hop_length=hop_length, fs_hz=FS_HZ)
    plt.show()

    IPython.display.display(IPython.display.Audio(reconstruction, rate=FS_HZ))
    plot_waveform(reconstruction.numpy(), sr=FS_HZ)
    plt.show()

    # Compare on the common length: the reconstruction may be longer or shorter than
    # the input, and mismatched lengths would make `mse_loss` fail.
    common_length = min(audio.size(0), reconstruction.size(0))
    print(torch.nn.functional.mse_loss(
        peak_normalize(audio)[:common_length],
        peak_normalize(reconstruction[:common_length])))

In [None]:
import random

# Display the features stored in one randomly chosen exported HDF5 file.
# NOTE(review): `h5_save_path` is not defined anywhere in the visible notebook; it
# must name the directory written by the export cell, which has to run before this one.
stored_samples = sorted(h5_save_path.glob('*.h5'))
if not stored_samples:
    raise FileNotFoundError(
        f"no .h5 files found in {h5_save_path}: run the export cell first")
random_stored_sample = random.choice(stored_samples)
print(random_stored_sample.name)

with h5py.File(random_stored_sample, 'r') as sample_file:
    # `[()]` copies each HDF5 dataset into a NumPy array, so the data stays usable
    # after the file is closed.
    IF = sample_file['IF'][()]
    logmelmag2 = sample_file['mel_Spec'][()]
    mel_p = sample_file['mel_IF'][()]

plt.subplot(1, 3, 1)
librosa.display.specshow(logmelmag2, sr=FS_HZ, hop_length=HOP_LENGTH,
                         y_axis='mel')
plt.title("log-melspectrogram")

plt.subplot(1, 3, 2)
librosa.display.specshow(IF, sr=FS_HZ, hop_length=HOP_LENGTH,
                         y_axis='linear')
plt.title("IF")

plt.subplot(1, 3, 3)
librosa.display.specshow(mel_p, sr=FS_HZ, hop_length=HOP_LENGTH,
                         y_axis='mel')
plt.title("mel-IF")

plt.tight_layout()

plt.show()

In [None]:
# Export GANSynth-style features (log-magnitude, IF and their mel-scaled versions) of
# every example to one HDF5 file per note.
SKIP_EXISTING = False
N_FFT = 2048
# Keep the first N_FFT // 2 linear-frequency bins (the last STFT bin is dropped).
N_FREQUENCY_BINS = N_FFT // 2
count = 0

# NOTE(review): `h5_save_path` is not defined anywhere in the visible notebook —
# TODO define it (e.g. in the configuration cell) before running this cell.
h5_save_path.mkdir(parents=True, exist_ok=True)

# Iterate over the raw-audio `dataset`, not `loader`: `loader` yields spectrograms
# (see the visualization cell), which cannot be passed to `librosa.stft`.
for example_index in tqdm(range(len(dataset))):
    example = dataset[example_index]
    metadata = example[3]
    sample_name = metadata['note_str']
    h5_file_path = h5_save_path / f'{sample_name}.h5'

    if SKIP_EXISTING and h5_file_path.is_file():
        # skip already created file
        continue

    pitch = metadata['pitch']
    if pitch < 24 or pitch > 84:
        # filter-out extreme pitches, as advocated by GANSynth
        continue

    sample = example[0].flatten().cpu().numpy()
    spec = librosa.stft(sample, n_fft=N_FFT, hop_length=HOP_LENGTH)

    magnitude = np.log(np.abs(spec) + 1.0e-6)[:N_FREQUENCY_BINS]
    angle = np.angle(spec)

    IF = phase_operation.instantaneous_frequency(angle, time_axis=1)[:N_FREQUENCY_BINS]

    # Pad the time axis with copies of the last frame.
    magnitude = expand(magnitude)
    IF = expand(IF)
    logmelmag2, mel_p = spec_helper.specgrams_to_melspecgrams(magnitude, IF)

    assert magnitude.shape == (1024, 128), magnitude.shape
    assert IF.shape == (1024, 128), IF.shape
    with h5py.File(h5_file_path, 'w') as sample_file:
        sample_file.create_dataset("Spec", data=magnitude.astype(np.float32))
        sample_file.create_dataset("IF", data=IF.astype(np.float32))
        sample_file.create_dataset("mel_Spec", data=logmelmag2.astype(np.float32))
        sample_file.create_dataset("mel_IF", data=mel_p.astype(np.float32))
        sample_file.attrs.create("pitch", data=pitch)

    # Periodically show the latest exported example as a progress check.
    if count % 500 == 0:
        clear_output(wait=True)
        plots.plot_mel_representations(logmelmag2, mel_p,
                                       hop_length=HOP_LENGTH, fs_hz=FS_HZ)
        plt.show()
        IPython.display.display(IPython.display.Audio(sample, rate=FS_HZ))
    count += 1