# Model Trainer
Use this notebook to train a Musika model on an audio dataset. For best results, and for compatibility with the frontend, use a dataset containing **one genre** of audio/music.

This notebook mounts your Google Drive so you can point it to your dataset.

## Important Note
If `genre_name` is set to "example", the notebook ignores `dataset_path` and instead downloads an example dataset for training.
This is for demonstration purposes.

### Setup
Clones the Musika repository, installs its dependencies, and mounts Google Drive.

In [None]:
%cd /content
!git clone https://github.com/marcoppasini/musika
%cd musika
!pip install -r requirements.txt
!pip install --upgrade --no-cache-dir gdown
!apt-get install -y unzip

from google.colab import drive
drive.mount("/content/drive")

### Inputs
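`genre_name`: Name of the audio genre, used to name the encodings folder. Set it to "example" to download the example dataset. Avoid spaces and special characters, since the name becomes part of a path passed to shell commands.

`dataset_path`: Path to the folder containing the audio dataset (for example, a folder on the Google Drive mounted under `/content/drive`). Ignored when `genre_name` is "example".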

In [None]:
genre_name = "example"  # or the name of your own genre
dataset_path = "/path/to/audio/dir"  # ignored when genre_name == "example"
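Because `genre_name` is interpolated into the encodings path used by the shell commands below, a quick check of both inputs before encoding can save a failed run. The sketch below is a hypothetical helper (`check_inputs` is not part of Musika); the allowed character set is an assumption, and the check only confirms that the folder exists and contains at least one file somewhere inside it, not that those files are valid audio.

```python
import os
import re


def check_inputs(genre_name: str, dataset_path: str) -> None:
    """Hypothetical pre-flight check for the notebook inputs (not part of Musika)."""
    # Assumption: restrict genre names to characters that are safe in paths and shell commands
    if not re.fullmatch(r"[A-Za-z0-9_-]+", genre_name):
        raise ValueError(
            f"genre_name {genre_name!r} should contain only letters, digits, '_' or '-'"
        )
    # The "example" genre downloads its own dataset, so dataset_path is not used
    if genre_name == "example":
        return
    if not os.path.isdir(dataset_path):
        raise FileNotFoundError(f"dataset_path {dataset_path!r} is not a folder")
    # Look through subfolders too; this does not validate the audio files themselves
    if not any(files for _, _, files in os.walk(dataset_path)):
        raise ValueError(f"dataset_path {dataset_path!r} contains no files")
```

In the notebook, calling `check_inputs(genre_name, dataset_path)` right after setting the inputs would surface a typo in the path before the encoding step starts.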

### Encode Dataset
Encodes the dataset into the format expected by the training script and saves the encodings to `/content/<genre_name>_encodings`.

In [None]:
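import os, subprocess
import gdown

if genre_name == "example":
  # The example dataset is downloaded to /content, so the user-supplied path is ignored
  dataset_path = "/content/dataset"
  zip_path = dataset_path + ".zip"
  # Skip the download and extraction steps if they already ran in this session
  if not os.path.exists(zip_path):
    dataset_url = "https://drive.google.com/uc?id=15iroZ6Sh89pFuL41gd-Q1rkdhmb7DKzJ"
    gdown.download(dataset_url, zip_path)
  if not os.path.exists(dataset_path):
    subprocess.run(["unzip", "-q", zip_path, "-d", "/content"], check=True)

# Quote the paths so folders containing spaces are passed as a single argument
!python musika_encode.py --files_path "$dataset_path" --save_path "/content/{genre_name}_encodings"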
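The example branch of the encoding cell shells out to the `unzip` command, which is why Setup installs it. As a sketch, assuming the archive unpacks into a top-level `dataset/` folder (as the notebook's check on `/content/dataset` implies), the same "extract only once" step can use Python's standard `zipfile` module instead; `extract_once` is a hypothetical helper name, not part of Musika or gdown.

```python
import os
import zipfile


def extract_once(zip_path: str, extract_to: str, expected_dir: str) -> str:
    """Extract zip_path into extract_to unless expected_dir already exists.

    Mirrors the notebook's skip-if-already-unzipped logic using only the
    standard library. Only use this on archives from a trusted source.
    """
    if not os.path.isdir(expected_dir):
        with zipfile.ZipFile(zip_path) as archive:
            archive.extractall(extract_to)
    return expected_dir
```

For the example dataset this would be called as `extract_once("/content/dataset.zip", "/content", "/content/dataset")`. Keeping the existence check on the extracted folder rather than the archive means a rerun of the cell skips extraction even if the zip has since been deleted.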

### Train model on dataset
Trains the model on the encoded dataset (your own dataset, or the example dataset if `genre_name` is "example").

After training finishes, prints the path of the folder that contains the model checkpoints and weights.

In [None]:
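!python musika_train.py --train_path "/content/{genre_name}_encodings"

import os
# Checkpoint folders that are not produced by this training run
default_checkpoints = {"ae", "misc", "misc_small", "techno"}
checkpoints_root = "/content/musika/checkpoints"
new_dirs = [os.path.join(checkpoints_root, d) for d in os.listdir(checkpoints_root) if d not in default_checkpoints and os.path.isdir(os.path.join(checkpoints_root, d))]
# If training was run more than once, pick the most recently modified run folder
weight_dir = max(new_dirs, key=os.path.getmtime)
print(f"Folder with model checkpoints and respective weights: {weight_dir}")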

Folder with model checkpoints and respective weights: /content/musika/checkpoints/MUSIKA_latlen_256_latdepth_64_sr_44100_time_20221208-063552
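The folder name printed above appears to encode the run's settings (presumably latent length, latent depth, and sample rate) followed by a start timestamp. Assuming names always follow the `MUSIKA_latlen_<n>_latdepth_<n>_sr_<n>_time_<YYYYMMDD-HHMMSS>` pattern seen in this output (inferred from this single example, not from Musika's documentation), a small parser can recover those fields; `parse_run_folder` is a hypothetical helper.

```python
import re
from datetime import datetime

# Pattern inferred from one example folder name; not documented by Musika
RUN_NAME = re.compile(
    r"MUSIKA_latlen_(?P<latlen>\d+)_latdepth_(?P<latdepth>\d+)"
    r"_sr_(?P<sr>\d+)_time_(?P<time>\d{8}-\d{6})$"
)


def parse_run_folder(name: str) -> dict:
    """Split a checkpoint folder name (or full path) into its settings and timestamp."""
    # search() plus the $ anchor lets a full path be passed as well as a bare name
    match = RUN_NAME.search(name)
    if match is None:
        raise ValueError(f"unrecognised checkpoint folder name: {name!r}")
    return {
        "latlen": int(match["latlen"]),
        "latdepth": int(match["latdepth"]),
        "sr": int(match["sr"]),
        "time": datetime.strptime(match["time"], "%Y%m%d-%H%M%S"),
    }


info = parse_run_folder("MUSIKA_latlen_256_latdepth_64_sr_44100_time_20221208-063552")
print(info["sr"], info["time"])  # prints: 44100 2022-12-08 06:35:52
```

Raising on an unrecognised name, rather than returning partial results, makes it obvious if a future version of the training script names its folders differently.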
