# Training Setup

In [None]:
# Ubuntu environment only
! apt install libasound2-dev portaudio19-dev

### Setup AudioCraft

In [None]:
!git clone https://github.com/facebookresearch/audiocraft.git
%cd audiocraft
!pip install -e .
AUDIOCRAFT_ROOT = "/content/audiocraft"

### Python dependencies

In [None]:
!pip install dora-search numba
!pip install git+https://github.com/tnadav/prompt-synth.git#subdirectory=audiomanip
!pip install torchvision==0.16

### Dataset Generation

In [None]:
import yaml


def make_dataset_yaml(
    name: str,
    eval_path: str,
    generate_path: str,
    train_path: str,
    valid_path: str,
    audiocraft_root: str = "/content/audiocraft",
) -> None:
    """Write an AudioCraft dataset config selectable as ``dset=audio/<name>``.

    The file ``<audiocraft_root>/config/dset/audio/<name>.yaml`` is (over)written with a
    ``# @package __global__`` header (Hydra package directive) followed by a YAML
    ``datasource`` mapping that points each split at the given path.

    Args:
        name: Config name; becomes the YAML file stem.
        eval_path: Value stored under ``datasource.evaluate``.
        generate_path: Value stored under ``datasource.generate``.
        train_path: Value stored under ``datasource.train``.
        valid_path: Value stored under ``datasource.valid``.
        audiocraft_root: Root of the AudioCraft checkout. Defaults to the Colab location,
            so existing calls are unaffected.
    """
    import os

    data = yaml.dump(
        {
            "datasource": {
                # mono, 32 kHz — presumably chosen to match the 32 kHz solver used for
                # training; confirm before reusing with another model.
                "max_channels": 1,
                "max_sample_rate": 32000,
                "evaluate": eval_path,
                "generate": generate_path,
                "train": train_path,
                "valid": valid_path,
            }
        }
    )

    config_dir = os.path.join(audiocraft_root, "config", "dset", "audio")
    # Create the directory if needed so the function also works for a fresh/custom root.
    os.makedirs(config_dir, exist_ok=True)

    with open(os.path.join(config_dir, f"{name}.yaml"), "w", encoding="utf-8") as f:
        # NOTE(review): the directive is assembled from a variable rather than written
        # literally; the reason isn't visible here, so it is kept. The written text is
        # exactly "# @package __global__".
        _package = "package"
        f.write(f"# @{_package} __global__\n\n")
        f.write(data)

In [None]:
from google.colab import drive

drive.mount("/content/drive/")

nsynth_test = "/content/drive/MyDrive/prompt-synth/musicgen-nsynth-test"
make_dataset_yaml(
    "nsynth-test",
    nsynth_test,
    nsynth_test,
    nsynth_test,
    nsynth_test,
)

!rm -rf /content/audiocraft/dataset/nsynth-test
!cp -r /content/drive/MyDrive/prompt-synth/musicgen-nsynth-test/nsynth-test /content/audiocraft/dataset/nsynth-test

# Train using Dora

In [None]:
%env USER=nadav
%env AUDIOCRAFT_TEAM=default

# clear cuda mem
from numba import cuda

device = cuda.get_current_device()
device.reset()

command = (
    "dora run solver=magnet/magnet_32khz"
    " model/lm/model_scale=small"
    " continue_from=//pretrained/facebook/magnet-small-10secs"
    " conditioner=text2music"
    " dset=audio/nsynth-test"
    " dataset.num_workers=1"
    " dataset.valid.num_samples=1"
    " dataset.batch_size=1"  # batch_size 2 with T4 resulted in OOM
    " schedule.cosine.warmup=8"
    " optim.optimizer=adamw"  # uses dadaw by default, which is worse for single-gpu runs
    " optim.lr=1e-4"
    " optim.epochs=5"  # stops training after 5 epochs- change this
    " optim.updates_per_epoch=1000"  # 2000 by default, change this if you want checkpoints quicker ig
    " optim.adam.weight_decay=0.01"
)

!cd /content/audiocraft
!{command}