# Before you start
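Open this link and add a shortcut to the `indices_genres` file at the top level of your Google Drive ("My Drive"); on Colab the notebook loads it from `/content/drive/MyDrive/indices_genres`: https://drive.google.com/file/d/1-0CjAdc5ZJIw_pxu8ycVBmHnJabAIxEg/view?usp=sharing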

# Code

In [None]:
# Detect whether the notebook runs on Google Colab and derive the absolute
# working directory (`path`, always with a trailing slash) that the other
# cells use as a prefix.
import os

try:
    from google.colab import drive
    drive.mount('/content/drive')
    colab = True
except Exception:
    # Not on Colab (import failed) or Drive could not be mounted: fall back
    # to local mode. `Exception` instead of a bare `except` so that
    # KeyboardInterrupt / SystemExit are not swallowed.
    colab = False
print ("Running colab:", colab)

# "" resolves to the current working directory once made absolute.
path = "/content/" if colab else ""
path = os.path.abspath(path) + "/"
path

In [None]:
#if not os.path.exists(path + "mel_specs_music"):
!mkdir mel_specs_music
!mkdir mel_specs_music/train
!mkdir mel_specs_music/val
!mkdir mel_specs_music/train/cl
!mkdir mel_specs_music/val/cl
!git clone https://github.com/nadavbh12/VQ-VAE.git
!mv ./VQ-VAE/ ./VQ_VAE/ 

In [None]:
import sklearn
from sklearn.model_selection import train_test_split
import glob
import pandas as pd
import pickle
import torch
from torchvision import datasets, transforms
from VQ_VAE.vq_vae.auto_encoder import VQ_CVAE
import tqdm
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import plotly.express as px
import pandas as pd
import librosa
import cv2
import scipy
from gc import collect
import IPython.display as ipd
import shutil

In [None]:
#@title Unzip
# Make sure your processed mel spectrogram data is in the same directory as this notebook
# On Colab the archive is extracted from the mounted Drive into the current
# working directory; outside Colab nothing is extracted by this cell.
# NOTE(review): the class below reads `music_mel_spectrograms/*.png` -
# presumably that is the folder this archive contains; confirm.
if colab: 
  !unzip '/content/drive/MyDrive/Big Data Project/large_music_mel_spectrograms.zip'

In [None]:
class vq_vae_search:
    def __init__(self, bs=64, dict_size=128, epochs=15, h=64, w=256):
        self.bs = bs
        self.dict_size = dict_size
        self.epochs = epochs
        self.h = h
        self.w = w
        self.path = path
        self.vqvae_music_checkpoints_folder = path + 'vqvae_music_checkpoints/'
        self.filter_and_org_data()
        self.tune_aug_hyperperameters()
        self.train()
        #self.load_model()
        #self.predict_values() 
        #self.top_k(50)
  
    def filter_and_org_data(self):
        if not os.path.exists(self.vqvae_music_checkpoints_folder):
            os.mkdir(self.vqvae_music_checkpoints_folder)
        self.ds_dir = path + 'mel_specs_music/val/'
        self.img = glob.glob(path + 'music_mel_spectrograms/*.png')
        if colab:
            with open('/content/drive/MyDrive/indices_genres', 'rb') as f:
                self.ig = pickle.load(f)
        else:
            with open('indices_genres', 'rb') as f:
                self.ig = pickle.load(f)
        #filter, undersample, and reorganize data
        self.img_ids = list(map(lambda x: os.path.basename(x)[:-4], self.img))
        #print(self.img_ids)
        self.df_filter = pd.DataFrame.from_dict({'genre': [self.ig[el] for el in self.img_ids], 'id':self.img}).dropna()
        self.df_filter = pd.concat([self.df_filter[self.df_filter['genre'] == genre][:1000].reset_index(drop=True) for genre in list(self.df_filter['genre'].value_counts()[:-7].keys())]).reset_index(drop=True)
        self.img = list(self.df_filter['id'].values)

        self.img_train, self.img_test = train_test_split(self.img, test_size=0.2, random_state=42)
        self.f_train = lambda x: path + 'mel_specs_music/train/cl/' + os.path.basename(x)
        self.f_test = lambda x: path + 'mel_specs_music/val/cl/' + os.path.basename(x)
        self.out_train = list(map(self.f_train, self.img_train))
        self.out_test = list(map(self.f_test, self.img_test))

        for el1, el2 in zip(self.img_train, self.out_train):
            #shutil.move(el1, el2)
            shutil.copy(el1, el2)

        for el1, el2 in zip(self.img_test, self.out_test):
            #shutil.move(el1, el2)
            shutil.copy(el1, el2)

    def tune_aug_hyperperameters(self):
        #@title Tune augmentation hyperparameters
        self.size = f'transforms.Resize(({self.h}, {self.w}))'
        self.replace_main_with = '''dataset_transforms = {
            'custom': transforms.Compose([transforms.Grayscale(), transforms.Resize((h, w)), 
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5), (0.5))]),
            'imagenet': transforms.Compose([transforms.Grayscale(), transforms.Resize((h, w)), 
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5), (0.5))]),
            'cifar10': transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5), (0.5))]),
            'mnist': transforms.ToTensor()
        }'''.replace('transforms.Resize((h, w))', self.size)
        self.replace_main = '''dataset_transforms = {
            'custom': transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
            'imagenet': transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256),
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
            'cifar10': transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
            'mnist': transforms.ToTensor()
        }'''

        with open(path + 'VQ_VAE/main.py', 'r') as f:
            self.data = f.read()

        self.data = self.data.replace(self.replace_main, self.replace_main_with)

        with open(path + 'VQ_VAE/main.py', 'w') as f:
            f.write(self.data)


        self.replace_main_with = '''dataset_n_channels = {
            'custom': 1,
            'imagenet': 1,
            'cifar10': 1,
            'mnist': 1,
        }'''
        self.replace_main = '''dataset_n_channels = {
            'custom': 3,
            'imagenet': 3,
            'cifar10': 3,
            'mnist': 1,
        }'''

        with open(path + 'VQ_VAE/main.py', 'r') as f:
            self.data = f.read()

        self.data = self.data.replace(self.replace_main, self.replace_main_with)

        with open(path + 'VQ_VAE/main.py', 'w') as f:
            f.write(self.data)

    def train(self):
        #@title Train and save to drive (be sure to save the previous trained model somewhere else as the checkpoint folder will be emptied)
        VQVAE_path = path + 'VQ_VAE'
        %cd $VQVAE_path
        if colab:
            !python3 main.py --dataset=custom --model=vqvae --data-dir=/content/mel_specs_music --epochs={self.epochs} --batch-size {self.bs} --dict-size {self.dict_size}
            self.checkpoint_path = sorted(glob.glob('/content/VQ_VAE/results/*/checkpoints/*.pth'))[-1]
        else:
            !python3 main.py --dataset=custom --model=vqvae --data-dir=../mel_specs_music --epochs={self.epochs} --batch-size {self.bs} --dict-size {self.dict_size}
            self.checkpoint_path = sorted(glob.glob('../VQ_VAE/results/*/checkpoints/*.pth'))[-1]

        !rm -rf {self.vqvae_music_checkpoints_folder}
        !mkdir {self.vqvae_music_checkpoints_folder}

        shutil.copyfile(self.checkpoint_path, os.path.join(self.vqvae_music_checkpoints_folder, os.path.basename(self.checkpoint_path)))

    def load_model(self, channels=1):
        #@title Load model
        self.model = VQ_CVAE(128, k = self.dict_size, num_channels=channels)
        self.model.load_state_dict(torch.load(self.checkpoint_path))

    def predict_values(self):
        #@title Predict values
        if not colab:
            !cd ~./VQ-VAE-SEARCH
        self.T = transforms.Compose([transforms.Grayscale(), transforms.Resize((self.h, self.w)), #transforms.CenterCrop(256),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5), (0.5))])

        self.test_dataset = torch.utils.data.DataLoader(datasets.ImageFolder(self.ds_dir, transform=self.T), batch_size=self.bs, shuffle=False)

        self.categories = []
        self.data_names = datasets.ImageFolder(self.ds_dir, transform=self.T)
        self.data_names = [el[0] for el in self.data_names.samples]
        self.test_ids = list(map(lambda x: os.path.basename(x)[:-4], self.data_names))
        self.categories = [self.ig[el] for el in self.test_ids]

        self.normalize = lambda x: (x - x.min()) / (x.max() - x.min())

        self.all_outputs = []
        i = 0
        for data, _ in tqdm.tqdm(self.test_dataset):
            with torch.no_grad():
                self.outputs = self.model(data)[2]
                self.outputs = self.outputs.reshape(self.outputs.shape[0], -1).detach().cpu().numpy() #outputs[1] = enc, outputs[2] = emb
            self.all_outputs.append(self.outputs)
            i += 1

        self.all_outputs = np.concatenate(self.all_outputs, 0)
        #self.all_outputs = self.normalize(all_outputs)

        with open(f'{self.vqvae_music_checkpoints_folder}/all_outputs.pickle', 'wb') as f:
            pickle.dump(self.all_outputs, f)

    def top_k(self, k=50):
        #@title Top K data
        self.n_comp=k
        with open(f'{self.vqvae_music_checkpoints_folder}/all_outputs.pickle', 'rb') as f:
            self.all_outputs = pickle.load(f)

        self.categories = []
        self.data_names = datasets.ImageFolder(self.ds_dir, transform=self.T)
        self.data_names = [el[0] for el in self.data_names.samples]
        self.test_ids = list(map(lambda x: os.path.basename(x)[:-4], self.data_names))
        self.categories = [self.ig[el] for el in self.test_ids]

        self.spec_paths = datasets.ImageFolder(self.ds_dir, transform=self.T)
        self.spec_paths = [el[0] for el in self.spec_paths.imgs]

        self.pca = PCA(n_components=self.n_comp)
        self.pca_result = self.pca.fit_transform(self.all_outputs)
        self.pca_df = pd.DataFrame(self.pca_result, columns=list(map(str,list(range(0, self.n_comp)))))
        self.pca_df['categories'] = self.categories
        self.pca_df['spec_paths'] = self.spec_paths

    def show_top_k(self, select_id=599, k=10):
        #@title Top K
        self.vecs = self.pca_df[list(map(str,list(range(0, self.n_comp))))].to_numpy()
        self.distances = sklearn.metrics.pairwise.cosine_similarity(self.vecs, self.vecs)
        self.top_k = np.flip(np.argsort(self.distances[select_id]))[:k]
        self.top_k_df = self.pca_df.iloc[self.top_k][['categories', 'spec_paths']]
        self.audios = [self.png_to_audio(spec) for spec in self.top_k_df['spec_paths'].values]
        display(self.top_k_df)

    def display_audio(self, audios):
        for audio in audios:
            display(audio)

    def pca(self):
        self.pca = PCA(n_components=3)
        self.pca_result = self.pca.fit_transform(self.all_outputs)
        self.pca_df = pd.DataFrame(self.pca_result, columns=['1', '2', '3'])
        self.pca_df['categories'] = self.categories

        fig = px.scatter_3d(self.pca_df, x='1', y='2', z='3', color='categories', width=1500)
        fig.show()

    def tsne(self):
        #@title TSNE (requires tuning)
        self.pca_tsne = PCA(n_components=50) # change the number of components as a hyperparameter
        self.pca_tsne_result = self.pca_tsne.fit_transform(self.all_outputs)

        self.tsne = TSNE(n_components=3, verbose=1, perplexity=25, n_iter=3000, learning_rate=200)
        self.tsne_results = self.tsne.fit_transform(self.pca_tsne_result)
        self.tsne_df = pd.DataFrame(self.tsne_results, columns=['1', '2', '3'])
        self.tsne_df['categories'] = self.categories

        fig = px.scatter_3d(self.tsne_df, x='1', y='2', z='3', color='categories', width=1200)
        fig.show()

    def png_to_audio(self, audio_file='spec.png', n_fft = 2000, hop_length = 150, win = 50, mi = -80.0, m = 0.0, sr = 22050, save=False):
        # read image
        spec = cv2.imread(audio_file, cv2.IMREAD_GRAYSCALE)

        # de_normalize
        spec = (spec * (m - mi) / 255) + mi
        spec = spec.astype(np.float32)

        # from spectrogram to audio
        aud = self.from_spectrogram(spec, n_fft=n_fft, hop_length=hop_length, win=win, sr=sr)

        # save audio
        aud = self.play_audio(aud)
        if save:
            with open(f'{audio_file[0:-4]}.wav', 'wb') as f:
                    f.write(aud.data)
        return aud

    def play_audio(self, audio_file, sr=22050):
        if type(audio_file) == str:
            return ipd.Audio(audio_file,  rate=sr)
        else:
            return ipd.Audio(audio_file, rate=sr)

    def from_spectrogram(self, spectrogram,  hop_length=150, n_fft=2000, win=50, sr = 22050):
        # undo power_to_db
        S = spectrogram
        S = librosa.db_to_power(S)
        S = librosa.feature.inverse.mel_to_audio(S, sr=sr, n_fft=n_fft, hop_length=hop_length, window=win)
        return S

    
        

Test dict size: 32

In [None]:
# Constructing the object prepares the data, patches VQ_VAE/main.py and runs
# training right away (all done inside __init__), here with dict_size=32.
model_dict_32 = vq_vae_search(bs=64, dict_size=32, epochs=15, h=64, w=256)


In [None]:

# Load the trained checkpoint (1 input channel), embed the validation images
# and reduce the embeddings to 50 PCA components for the similarity search.
model_dict_32.load_model(1)
model_dict_32.predict_values() 
model_dict_32.top_k(50)

In [None]:
# Show the 10 validation items most cosine-similar to row 599.
model_dict_32.show_top_k(select_id=599, k=10)

In [None]:
# 3-D PCA scatter plot of the embeddings, coloured by genre.
model_dict_32.pca()

In [None]:
# 3-D t-SNE scatter plot of the embeddings, coloured by genre.
model_dict_32.tsne()

# Drop the experiment object and run the garbage collector before the next run.
del model_dict_32
collect()

Test dict size: 64

In [None]:
# Same pipeline as above with dict_size=64; training runs inside __init__.
model_dict_64 = vq_vae_search(bs=64, dict_size=64, epochs=15, h=64, w=256)

In [None]:
# The constructor only trains. The checkpoint must be loaded, the validation
# images embedded and the PCA table built before show_top_k() can run -
# otherwise it fails on the missing n_comp / pca_df attributes.
model_dict_64.load_model(1)
model_dict_64.predict_values()
model_dict_64.top_k(50)

# Show the 10 validation items most cosine-similar to row 599.
model_dict_64.show_top_k(select_id=599, k=10)

In [None]:
# 3-D PCA scatter plot of the embeddings, coloured by genre.
model_dict_64.pca()

In [None]:
# 3-D t-SNE scatter plot, then release the experiment object.
model_dict_64.tsne()
del model_dict_64
collect()

Test dict size: 128

In [None]:
# Same pipeline as above with dict_size=128; training runs inside __init__.
model_dict_128 = vq_vae_search(bs=64, dict_size=128, epochs=15, h=64, w=256)

In [None]:
# The constructor only trains. The checkpoint must be loaded, the validation
# images embedded and the PCA table built before show_top_k() can run -
# otherwise it fails on the missing n_comp / pca_df attributes.
model_dict_128.load_model(1)
model_dict_128.predict_values()
model_dict_128.top_k(50)

# Show the 10 validation items most cosine-similar to row 599.
model_dict_128.show_top_k(select_id=599, k=10)

In [None]:
# 3-D PCA scatter plot of the embeddings, coloured by genre.
model_dict_128.pca()

In [None]:
# 3-D t-SNE scatter plot, then release the experiment object.
model_dict_128.tsne()
del model_dict_128
collect()