# Model Tester
Use this notebook to test and demonstrate model results.
Because the model generates new audio on every run, its output cannot be checked against a fixed expected result. Instead, a human tester evaluates the generated audio qualitatively (subjectively).

See: https://stackoverflow.com/a/69897034

> There is no testing phase in GANS as we normally have in other neural networks like CNN etc. GAN generator models are evaluated based on the quality of the images generated, often in the context of the target problem domain.

This notebook mounts your Google Drive so you can point it to a custom dataset.


### Setup
Installs dependencies and downloads the pretrained model weights.

In [None]:
#@title Setup
%cd /content
!git clone https://github.com/marcoppasini/musika
%cd musika
!pip install -r requirements.txt
!pip install scikit-learn
!pip install gensim
!pip install nltk
!pip install transformers
!pip install --upgrade --no-cache-dir gdown
!apt install -y ffmpeg unzip

import sys
import os
import gdown
import subprocess
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

weights_link = "https://drive.google.com/uc?id=1RDNjpcAH10JQ87HAvYtYCaIfnO_ZLf8T"
if not os.path.exists("./weights.zip"):
  gdown.download(weights_link, "./weights.zip")
if not os.path.exists("./weights"):
  subprocess.check_output(["unzip", "./weights.zip"])

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)

def get_embeddings(text, token_length):
  # Tokenize, padding or truncating to exactly token_length tokens.
  tokens = tokenizer(text, max_length=token_length, padding='max_length',
                     truncation=True, return_tensors='pt')
  # Inference only, so skip gradient tracking.
  with torch.no_grad():
    output = model(**tokens).hidden_states[-1]
  # Mean-pool the last hidden layer over all token positions (padding included).
  return torch.mean(output, dim=1).numpy()
  
genre_list = [genre.replace("/", "") for genre in os.listdir("./weights")]
genre_embeddings = [get_embeddings(genre, 20) for genre in genre_list]

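def get_closest(query, token_length=20):
    # Embed the free-text query and compare it with every available genre.
    query_embeddings = get_embeddings(query, token_length)
    sims = [cosine_similarity(embed, query_embeddings)[0][0] for embed in genre_embeddings]
    # Pair each similarity with its genre name and keep the best match.
    closest = max(zip(sims, genre_list))[1]
    return closest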
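
The genre matching in `get_closest` amounts to a nearest-neighbour lookup under cosine similarity. A minimal, dependency-free sketch of that idea follows. The three-dimensional vectors and genre names are invented for illustration and stand in for the BERT embeddings, and `cosine` and `closest_genre` are hypothetical helpers rather than part of the notebook.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Toy embeddings (made up for illustration; the notebook derives real ones from BERT).
toy_genre_embeddings = {
    "techno": [0.9, 0.1, 0.0],
    "piano": [0.1, 0.9, 0.2],
    "metal": [0.0, 0.2, 0.9],
}

def closest_genre(query_vec, embeddings):
    """Return the genre whose embedding has the highest cosine similarity."""
    return max(embeddings, key=lambda g: cosine(query_vec, embeddings[g]))

print(closest_genre([0.8, 0.2, 0.1], toy_genre_embeddings))  # prints "techno"
```

Passing `key=` to `max` means a tie between two similarities is not broken by comparing the genre names, which is what happens with `max(zip(sims, genre_list))`.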

### Generate Audio
This cell generates audio for the genre entered by the user. The input is mapped to the closest available genre, and the tester can then evaluate the result subjectively on whatever criteria they consider appropriate.

In [None]:

from IPython.display import Audio, display
import glob
import os

genre = "pop" #@param {type:"string"}
coerced_genre = get_closest(genre)

output_dir = "/content/musika/output"
# Clear files left over from previous runs so that only the new sample remains.
already_output_files = glob.glob(output_dir + "/*")
for f in already_output_files:
  os.remove(f)

weights_path = f"./weights/{coerced_genre}"

!python ./musika_generate.py --load_path $weights_path --num_samples 1 --seconds 30 --save_path ./output

output_file = os.path.join(output_dir, os.listdir(output_dir)[0])
display(Audio(output_file))
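
The cell above relies on the output directory containing exactly one file after generation. A slightly more defensive variant, sketched below, deletes only regular files and then picks the most recently modified one. The helper names `clear_files` and `newest_file` are hypothetical, and the demonstration runs in a temporary directory rather than the real `output` folder.

```python
import tempfile
from pathlib import Path

def clear_files(directory):
    """Delete regular files in `directory`, leaving subdirectories alone."""
    directory = Path(directory)
    directory.mkdir(parents=True, exist_ok=True)
    for entry in directory.iterdir():
        if entry.is_file():
            entry.unlink()

def newest_file(directory):
    """Return the most recently modified file in `directory`, or None if there is none."""
    files = [p for p in Path(directory).iterdir() if p.is_file()]
    return max(files, key=lambda p: p.stat().st_mtime, default=None)

# Demonstration in a throwaway directory.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "old.wav").write_bytes(b"")
    clear_files(tmp)                      # removes old.wav
    sample = Path(tmp) / "sample.wav"
    sample.write_bytes(b"")               # stands in for the generated audio
    print(newest_file(tmp).name)          # prints "sample.wav"
```

In the notebook, the result of `newest_file(output_dir)` could be passed to `Audio` in place of `os.listdir(output_dir)[0]`, after checking that it is not `None`.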