# Training MusicLM model
### Robert Chen, Ahmadsho Akdodshoev, Philip Timofeev

## 0. Imports

In [None]:
import os
import urllib.request
import wave
from typing import Any

import torch
import torchaudio
from audiolm_pytorch import (
    SemanticTransformer, SemanticTransformerTrainer,
    CoarseTransformer, CoarseTransformerTrainer,
    FineTransformer, FineTransformerTrainer,
    AudioLM, HubertWithKmeans, MusicLMSoundStream,
    SoundStreamTrainer, SoundStream
)
from musiclm_pytorch import (
    MuLaN, MuLaNEmbedQuantizer, MuLaNTrainer,
    AudioSpectrogramTransformer, TextTransformer,
    MusicLM
)
from torch.utils.data import Dataset, DataLoader

## 1. Creating dataloaders and downloading Hubert K-means checkpoints

Creating the dataset

In [None]:
class MusicLMDataset(Dataset):
    """Placeholder dataset handed to the trainers in this notebook.

    Only the constructor does real work so far: it records where the training
    files live. Sample loading is still to be written.

    NOTE(review): map-style PyTorch datasets normally also define ``__len__``;
    add it together with the real ``__getitem__`` implementation.
    """

    def __init__(self, path: str) -> None:
        """Remember the location of the training files.

        Args:
            path: Directory (or file) the samples will be read from.
        """
        super().__init__()
        # Kept so that the future sample-loading code knows where to look.
        self.path = path

    def __getitem__(self, index) -> Any:
        """Return the sample at ``index``.

        Not implemented yet: the body is a stub, so ``None`` is returned.
        """
        # TODO: load and return the sample stored under ``self.path``.
        ...


train_dataset = MusicLMDataset('path/to/files')

Downloading the HuBERT checkpoint and its k-means quantizer

In [None]:
# Local destinations of the pretrained HuBERT weights and k-means quantizer,
# plus the SoundStream checkpoint that later cells load.
hubert_ckpt = './hubert/hubert_base_ls960.pt'
hubert_quantizer = './hubert/hubert_base_ls960_L9_km500.bin'
soundstream_ckpt = './results/soundstream.pt'

# Host that serves both HuBERT files.
HUBERT_DOWNLOAD_ROOT = 'https://dl.fbaipublicfiles.com'

# exist_ok avoids the separate "does it exist" check and the race it implies.
os.makedirs('hubert', exist_ok=True)

for local_path in (hubert_ckpt, hubert_quantizer):
    if os.path.isfile(local_path):
        # Already downloaded on a previous run.
        continue
    # The local paths start with './'. Pasting them into the URL verbatim
    # yields "..././hubert/...", which urllib sends unnormalised, so the
    # remote path is derived without that prefix.
    remote_path = local_path[2:] if local_path.startswith('./') else local_path
    urllib.request.urlretrieve(f'{HUBERT_DOWNLOAD_ROOT}/{remote_path}', local_path)

## 2. Training MuLaN

The arguments for each module are collected in a dedicated dictionary so that hyperparameters can be adjusted in one place.

In [1]:
# Keyword arguments unpacked into AudioSpectrogramTransformer.
AUDIO_KWARGS = dict(
    dim=512,
    depth=6,
    heads=8,
    dim_head=64,
    spec_n_fft=128,
    spec_win_length=24,
    spec_aug_stretch_factor=0.8,
)

# Keyword arguments unpacked into TextTransformer.
TEXT_KWARGS = dict(
    dim=512,
    depth=6,
    heads=8,
    dim_head=64,
)

# Extra keyword arguments for MuLaNTrainer (the model and dataset are passed
# positionally where the trainer is built).
MULAN_KWARGS = dict(
    num_train_steps=10,
    batch_size=4,
    force_clear_prev_results=True,
    save_model_every=5,
)

# Keyword arguments for MuLaNEmbedQuantizer: one conditioning dimension per
# namespace, listed in the same order.
MULAN_QUANTIZER_KWARGS = dict(
    conditioning_dims=(1024, 1024, 1024),
    namespaces=('semantic', 'coarse', 'fine'),
)

# Keyword arguments for HubertWithKmeans; both paths come from the download cell.
HUBERT_KWARGS = dict(
    checkpoint_path=hubert_ckpt,
    kmeans_path=hubert_quantizer,
)

# The dictionaries below are still empty placeholders.
# NOTE(review): presumably each must be filled with the matching module's
# hyperparameters before the corresponding cell can run -- confirm against the
# audiolm_pytorch constructors.
SOUNDSTREAM_KWARGS = dict()

SOUNDSTREAM_TRAINER_KWARGS = dict()

SEMANTIC_KWARGS = dict()

COARSE_KWARGS = dict()

FINE_KWARGS = dict()

# Shared by the semantic, coarse and fine transformer trainers.
TRANSFORMER_TRAINER_KWARGS = dict(
    dataset=train_dataset,
)

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

NameError: name 'dim' is not defined

Training MuLaN

In [None]:
# Build the two towers of MuLaN from their keyword-argument dictionaries.
audio_transformer = AudioSpectrogramTransformer(**AUDIO_KWARGS)

text_transformer = TextTransformer(**TEXT_KWARGS)

mulan = MuLaN(audio_transformer, text_transformer)

mulan_trainer = MuLaNTrainer(mulan, train_dataset, **MULAN_KWARGS)

# Create the output directory before training starts, so a missing folder
# cannot make the final save fail after the training time has been spent.
os.makedirs('./models', exist_ok=True)

mulan_trainer.train()

mulan_trainer.save('./models/MuLaN.pt')

## 3. Training SoundStream

In [None]:
# Build the codec and its trainer from the dictionaries declared in the
# configuration cell; the trainer has to be told which model to optimise.
soundstream = MusicLMSoundStream(**SOUNDSTREAM_KWARGS)
soundstream_trainer = SoundStreamTrainer(soundstream, **SOUNDSTREAM_TRAINER_KWARGS)

soundstream_trainer.train()

# The transformer cells load the codec from `soundstream_ckpt`, so write the
# trained weights there (creating the folder if it does not exist yet).
os.makedirs(os.path.dirname(soundstream_ckpt), exist_ok=True)
soundstream_trainer.save(soundstream_ckpt)

## 4. Training conditioning embeddings

Defining the MuLaN Embed Quantizer and Hubert K-means Embedder

In [None]:
# Quantizer built around the MuLaN model from the previous section; it is
# passed to the transformer trainers as their `audio_conditioner`.
quantizer = MuLaNEmbedQuantizer(mulan=mulan, **MULAN_QUANTIZER_KWARGS)

# HuBERT + k-means model, constructed from the downloaded checkpoint paths.
wav2vec = HubertWithKmeans(**HUBERT_KWARGS)

Training Semantic Transformer

In [None]:
# Fresh codec instance restored from the saved SoundStream checkpoint.
soundstream = MusicLMSoundStream()
soundstream.load(soundstream_ckpt)

# The semantic vocabulary size is taken from the HuBERT k-means model.
semantic_transformer = SemanticTransformer(num_semantic_tokens=wav2vec.codebook_size, **SEMANTIC_KWARGS)
semantic_transformer = semantic_transformer.to(DEVICE)

# The MuLaN quantizer is supplied as the audio conditioner.
semantic_trainer = SemanticTransformerTrainer(
    wav2vec, semantic_transformer, audio_conditioner=quantizer, **TRANSFORMER_TRAINER_KWARGS
)
semantic_trainer.train()

Training Coarse Transformer

In [None]:
# Fresh codec instance restored from the saved SoundStream checkpoint.
soundstream = MusicLMSoundStream()

soundstream.load(soundstream_ckpt)

# The semantic vocabulary size is taken from the HuBERT k-means model.
coarse_transformer = CoarseTransformer(
    num_semantic_tokens=wav2vec.codebook_size,
    **COARSE_KWARGS
).to(DEVICE)

# Train the coarse transformer (not the semantic one). Keyword arguments are
# used because the trainer's positional order is (transformer, codec, wav2vec),
# which differs from the semantic trainer's (wav2vec, transformer).
coarse_trainer = CoarseTransformerTrainer(
    transformer=coarse_transformer,
    codec=soundstream,
    wav2vec=wav2vec,
    audio_conditioner=quantizer,
    **TRANSFORMER_TRAINER_KWARGS
)

coarse_trainer.train()

Training Fine Transformer

In [None]:
# Fresh codec instance restored from the saved SoundStream checkpoint.
soundstream = MusicLMSoundStream()

soundstream.load(soundstream_ckpt)

# NOTE(review): `codebook_size` is taken from the HuBERT k-means model here;
# the fine transformer models codec tokens, so this presumably should be the
# codec's codebook size instead -- confirm before training.
fine_transformer = FineTransformer(
    codebook_size=wav2vec.codebook_size,
    **FINE_KWARGS
).to(DEVICE)

# Train the fine transformer (not the semantic one). The fine stage works on
# codec tokens only, so the trainer takes the transformer and the codec and,
# per the audiolm-pytorch README, no wav2vec model.
fine_trainer = FineTransformerTrainer(
    transformer=fine_transformer,
    codec=soundstream,
    audio_conditioner=quantizer,
    **TRANSFORMER_TRAINER_KWARGS
)

fine_trainer.train()

## 5. Combining AudioLM and MusicLM

In [None]:
# Collect the trained stages, then hand them to AudioLM in one call.
audio_lm_parts = dict(
    wav2vec=wav2vec,
    codec=soundstream,
    semantic_transformer=semantic_transformer,
    coarse_transformer=coarse_transformer,
    fine_transformer=fine_transformer,
)
audio_lm = AudioLM(**audio_lm_parts)

In [None]:
# Combine the assembled AudioLM with the MuLaN embedding quantizer.
music_lm = MusicLM(audio_lm=audio_lm, mulan_embed_quantizer=quantizer)
