# Training Setup

In [None]:
# Ubuntu environment only
! apt install libasound2-dev portaudio19-dev

### Setup AudioCraft

In [None]:
!git clone https://github.com/facebookresearch/audiocraft.git
%cd audiocraft
!pip install -e .
AUDIOCRAFT_ROOT = "/content/audiocraft"

### Python dependencies

In [None]:
# `dora-search` provides the `dora` CLI used by the training cell; `numba` is used
# there to reset the CUDA device before launching training.
!pip install dora-search numba
# prompt-synth's audiomanip package, installed from its repository subdirectory.
!pip install git+https://github.com/tnadav/prompt-synth.git#subdirectory=audiomanip
# NOTE(review): torchvision 0.16 is the release paired with torch 2.1; presumably
# pinned to match the torch version AudioCraft installs — confirm before bumping.
!pip install torchvision==0.16

### Dataset Generation

In [None]:
import yaml


def make_dataset_yaml(
    name: str,
    train_path: str,
    valid_path: str,
    eval_path: str,
    generate_path: str,
    config_dir: str = "/content/audiocraft/config/dset/audio",
) -> None:
    """Write an AudioCraft dataset config named ``<name>.yaml`` into ``config_dir``.

    The config is selected at training time with ``dset=audio/<name>``. Its
    ``datasource`` section fixes mono audio (``max_channels: 1``) at up to 32 kHz
    and maps each split to the given path.

    Args:
        name: config file stem (without ``.yaml``).
        train_path / valid_path / eval_path / generate_path: dataset locations for
            the ``train``, ``valid``, ``evaluate`` and ``generate`` splits.
        config_dir: directory to write the YAML into; defaults to AudioCraft's
            ``config/dset/audio`` inside the Colab checkout.
    """
    import os

    data = yaml.dump(
        {
            "datasource": {
                "max_channels": 1,
                "max_sample_rate": 32000,
                "evaluate": eval_path,
                "generate": generate_path,
                "train": train_path,
                "valid": valid_path,
            }
        }
    )

    os.makedirs(config_dir, exist_ok=True)
    with open(os.path.join(config_dir, f"{name}.yaml"), "w") as f:
        # Hydra package directive: places these keys at the config root rather than
        # under `dset.audio`. NOTE(review): the directive is assembled from a variable
        # in the original code, presumably deliberately — kept as-is.
        _package = "package"
        f.write(f"# @{_package} __global__\n\n")
        f.write(data)

In [None]:
from google.colab import drive

drive.mount("/content/drive/")

nsynth_train = "/content/drive/MyDrive/prompt-synth/musicgen-nsynth-train-ext"
nsynth_valid = "/content/drive/MyDrive/prompt-synth/musicgen-nsynth-valid-ext"
nsynth_test = "/content/drive/MyDrive/prompt-synth/musicgen-nsynth-test-ext"
make_dataset_yaml(
    "nsynth-full-fixed-ext",
    train_path=nsynth_train,
    valid_path=nsynth_valid,
    eval_path=nsynth_test,
    generate_path=nsynth_test,
)

!rm -rf /content/audiocraft/dataset/nsynth-test
!cp -r /content/drive/MyDrive/prompt-synth/musicgen-nsynth-test/nsynth-test /content/audiocraft/dataset/nsynth-test

In [None]:
# Copy the NSynth splits from Drive to local disk.
# `dirs_exist_ok=True` keeps this cell re-runnable; plain `copytree` raises
# FileExistsError once the destination exists.
# NOTE(review): the dataset YAML written above points at the Drive paths, not at
# these local copies — confirm which location training is meant to read from.
import shutil

LOCAL_DATASET_DIR = "/content/audiocraft/dataset"
for _src, _dst_name in (
    (nsynth_train, "nsynth-train-ext"),
    (nsynth_valid, "nsynth-valid-ext"),
    (nsynth_test, "nsynth-test-ext"),
):
    shutil.copytree(_src, f"{LOCAL_DATASET_DIR}/{_dst_name}", dirs_exist_ok=True)

# Train using dora

In [None]:
%env USER=nadav
%env AUDIOCRAFT_TEAM=default

# clear cuda mem
from numba import cuda

device = cuda.get_current_device()
device.reset()

command = (
    "dora run solver=magnet/magnet_32khz"
    " model/lm/model_scale=small"
    " continue_from=//pretrained/facebook/magnet-small-10secs"
    " conditioner=text2music"
    " dset=audio/nsynth-test"
    " dataset.num_workers=1"
    " dataset.valid.num_samples=1"
    " dataset.batch_size=1"  # batch_size 2 with T4 resulted in OOM
    " schedule.cosine.warmup=8"
    " optim.optimizer=adamw"  # uses dadaw by default, which is worse for single-gpu runs
    " optim.lr=1e-4"
    " optim.epochs=5"  # stops training after 5 epochs- change this
    " optim.updates_per_epoch=1000"  # 2000 by default, change this if you want checkpoints quicker ig
    " optim.adam.weight_decay=0.01"
)

!cd /content/audiocraft
!{command}

# Export fine-tuned model

In [None]:
import os

from audiocraft import train
from audiocraft.utils import export

def extract_xp_dataset_name(xp):
    """Return the dataset name from the first ``dset=audio/<name>`` argument of ``xp.argv``.

    Raises:
        ValueError: if no argument starts with ``dset=audio/``.
    """
    prefix = "dset=audio/"
    dset_arg = next((arg for arg in xp.argv if arg.startswith(prefix)), None)
    if dset_arg is None:
        raise ValueError("Couldn't extract dataset name")
    return dset_arg.replace(prefix, "")

def get_xp_name(xp):
    """Build a readable export name: ``<solver>-<dataset>-<epochs>-epochs-<sig>``."""
    dataset = extract_xp_dataset_name(xp)
    solver, epochs = xp.cfg.solver, xp.cfg.optim.epochs
    return f"{solver}-{dataset}-{epochs}-epochs-{xp.sig}"

def export_model(sig, base_dir) -> str:
    """Export experiment ``sig`` into a new directory ``<base_dir>/<xp name>``.

    Writes the LM weights from the experiment's ``checkpoint.th`` to
    ``state_dict.bin`` and the pretrained ``facebook/encodec_32khz`` compression
    model to ``compression_state_dict.bin``.

    Returns:
        The path of the export directory.

    Raises:
        FileExistsError: if the export directory already exists (``os.makedirs``
            is called without ``exist_ok``).
    """
    xp = train.main.get_xp_from_sig(sig)
    export_dir = os.path.join(base_dir, get_xp_name(xp))
    os.makedirs(export_dir)

    lm_checkpoint = xp.folder / "checkpoint.th"
    export.export_lm(lm_checkpoint, os.path.join(export_dir, "state_dict.bin"))

    # Uses the stock pretrained EnCodec; change this if you trained your own.
    compression_target = os.path.join(export_dir, "compression_state_dict.bin")
    export.export_pretrained_compression_model("facebook/encodec_32khz", compression_target)

    return export_dir

def list_xps(xps_dir="/content/drive/MyDrive/prompt-synth/dora/xps"):
    """Print each experiment signature found in ``xps_dir`` with its export name.

    Entries are listed in sorted order so the output is stable between runs.
    Experiments that fail to load are reported and skipped instead of aborting
    the listing.

    Args:
        xps_dir: directory whose entries are Dora experiment signatures. Defaults
            to the Drive location used by this notebook. NOTE(review):
            ``train.main.get_xp_from_sig`` resolves signatures through Dora's own
            configured directory, so this should point at that same location.
    """
    for sig in sorted(os.listdir(xps_dir)):
        try:
            xp = train.main.get_xp_from_sig(sig)
            print(f"{sig}: {get_xp_name(xp)}")
        except Exception as e:  # best-effort: one broken XP must not hide the rest
            print(f"Failed to load {sig}: {e}")

In [None]:
# Print the available experiment signatures; pick one for the export cell below.
list_xps()

In [None]:
# Export one experiment; take its signature from the `list_xps()` output above.
# `export_model` has no default for `base_dir`, so the destination must be passed explicitly.
sig = "d83d6943"
EXPORT_BASE_DIR = "/content/drive/MyDrive/prompt-synth/exported-models"  # NOTE(review): choose the export location
exported_model_dir = export_model(sig, EXPORT_BASE_DIR)