# Training MusicLM model
### Robert Chen, Ahmadsho Akdodshoev, Philip Timofeev

## 0. Imports

In [3]:
!pip install musiclm-pytorch

Collecting musiclm-pytorch
  Downloading musiclm_pytorch-0.2.8-py3-none-any.whl (14 kB)
Collecting audiolm-pytorch>=0.17.0 (from musiclm-pytorch)
  Downloading audiolm_pytorch-1.6.3-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.1/42.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype (from musiclm-pytorch)
  Downloading beartype-0.16.4-py3-none-any.whl (819 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m819.1/819.1 kB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.6 (from musiclm-pytorch)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting lion-pytorch (from musiclm-pytorch)
  Downloading lion_pytorch-0.1.2-py3-none-any.whl (4.4 kB)
Collecting vector-quantize-pytorch>=1.0.0 (from musiclm-pytorch)
  Downloading vector_quantize_pytorch-1.10.4-p

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
from musiclm_pytorch import MuLaN, MuLaNEmbedQuantizer, MuLaNTrainer, \
                            AudioSpectrogramTransformer, TextTransformer, MusicLM
from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer, \
                            CoarseTransformer, CoarseTransformerTrainer, \
                            FineTransformer, FineTransformerTrainer, \
                            AudioLM, HubertWithKmeans, MusicLMSoundStream, \
                            SoundStreamTrainer, SoundStream 
import os
import wave
import urllib.request

## 1. Creating dataloaders and downloading Hubert K-means checkpoints

Creating the dataset

In [5]:
dataset_path = '/kaggle/input/soul-mzk/'

class MusicLMDataset(Dataset):
    def __init__(self, path: str):
        pass
    def __getitem__(self, idx):
        pass
    
train_dataset = MusicLMDataset(dataset_path)

Downloading Hubert checkpoints

In [6]:
hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = 'hubert/hubert_base_ls960_L9_km500.bin'
soundstream_ckpt = './results/soundstream.pt'

if not os.path.isdir("hubert"):
  os.makedirs("hubert")
if not os.path.isfile(hubert_ckpt):
  hubert_ckpt_download = f"https://dl.fbaipublicfiles.com/{hubert_ckpt}"
  urllib.request.urlretrieve(hubert_ckpt_download, f"./{hubert_ckpt}")
if not os.path.isfile(hubert_quantizer):
  hubert_quantizer_download = f"https://dl.fbaipublicfiles.com/{hubert_quantizer}"
  urllib.request.urlretrieve(hubert_quantizer_download, f"./{hubert_quantizer}")

## 2. Training MuLaN

Arguments for every module are defined in the respective dictionaries to make fine-tuning easier

In [7]:
AUDIO_KWARGS = {
    'dim': 512,
    'depth': 6,
    'heads': 8,
    'dim_head': 64,
    'spec_n_fft': 128,
    'spec_win_length': 24,
    'spec_aug_stretch_factor': 0.8
}

TEXT_KWARGS = {
    'dim': 512,
    'depth': 6,
    'heads': 8,
    'dim_head': 64
}

MULAN_KWARGS = {
    'dataset': train_dataset
    'num_train_steps': 10,
    'batch_size': 4,
    'force_clear_prev_results': True,
    'save_model_every': 5
}

MULAN_QUANTIZER_KWARGS = {
    'conditioning_dims': (1024, 1024, 1024),
    'namespaces': ('semantic', 'coarse', 'fine')
}

HUBERT_KWARGS = {
    'checkpoint_path': hubert_ckpt,
    'kmeans_path': hubert_quantizer
}

SOUNDSTREAM_TRAINER_KWARGS = {
    'folder': dataset_path,
    'num_train_steps': 20,
    'save_model_every': 2,
    'batch_size': 4
}
    
SEMANTIC_KWARGS = {
    'dim': 1024,
    'depth': 6,
    'audio_text_condition': True 
}

COARSE_KWARGS = {
    'codebook_size': 1024,
    'num_coarse_quantizers': 4,
    'dim': 1024,
    'depth': 6,
    'audio_text_condition': True 
}

FINE_KWARGS = {
    'codebook_size': 1024,
    'num_coarse_quantizers': 4,
    'num_fine_quantizers': 8,
    'dim': 1024,
    'depth': 6,
    'audio_text_condition': True 
}

TRANSFORMER_TRAINER_KWARGS = {
    'folder': dataset_path,
    'num_train_steps': 10,
    'save_model_every': 2,
    'batch_size': 4,
    'data_max_length': 320 * 32
}

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

Training MuLaN

In [8]:
audio_transformer = AudioSpectrogramTransformer(**AUDIO_KWARGS)

text_transformer = TextTransformer(**TEXT_KWARGS)

mulan = MuLaN(audio_transformer, text_transformer)

mulan_trainer = MuLaNTrainer(mulan, **MULAN_KWARGS)

mulan_trainer.train()

mulan_trainer.save()

TypeError: MuLaNTrainer.__init__() missing 1 required positional argument: 'dataset'

## 3. Training SoundStream

In [None]:
soundstream = MusicLMSoundStream()
soundstream_trainer = SoundStreamTrainer(
    soundstream,
    **SOUNDSTREAM_TRAINER_KWARGS
)

soundstream_trainer.train()

soundstream_trainer.save(soundstream_ckpt)

## 4. Training conditioning embeddings

Defining the MuLaN Embed Quantizer and Hubert K-means Embedder

In [None]:
quantizer = MuLaNEmbedQuantizer(
    mulan=mulan,                         
    **MULAN_QUANTIZER_KWARGS
)

wav2vec = HubertWithKmeans(
    **HUBERT_KWARGS
)

Training Semantic Transformer

In [None]:
semantic_transformer = SemanticTransformer(
   num_semantic_tokens=wav2vec.codebook_size,
   **SEMANTIC_KWARGS 
).to(DEVICE)

semantic_trainer = SemanticTransformerTrainer(
    wav2vec,
    semantic_transformer,
    audio_conditioner=quantizer,
    **TRANSFORMER_TRAINER_KWARGS
)

semantic_trainer.train()

Training Coarse Transformer

In [None]:
soundstream = MusicLMSoundStream()

soundstream.load(soundstream_ckpt)

coarse_transformer = CoarseTransformer(
    num_semantic_tokens=wav2vec.codebook_size,
    **COARSE_KWARGS
).to(DEVICE)

coarse_trainer = CoarseTransformerTrainer(
    wav2vec,
    semantic_transformer,
    codec=soundstream,
    audio_conditioner=quantizer,
    **TRANSFORMER_TRAINER_KWARGS
)

coarse_trainer.train()

Training Fine Transformer

In [None]:
soundstream = MusicLMSoundStream()

soundstream.load(soundstream_ckpt)

fine_transformer = FineTransformer(
    codebook_size=wav2vec.codebook_size,
    **FINE_KWARGS
).to(DEVICE)

fine_trainer = FineTransformerTrainer(
    wav2vec,
    semantic_transformer,
    codec=soundstream
    audio_conditioner=quantizer,
    **TRANSFORMER_TRAINER_KWARGS
)

fine_trainer.train()

## 5. Combining AudioLM and MusicLM

In [None]:
audio_lm = AudioLM(
    wav2vec=wav2vec,
    codec=soundstream,
    semantic_transformer=semantic_transformer,
    coarse_transformer=coarse_transformer,
    fine_transformer=fine_transformer   
)

In [None]:
music_lm = MusicLM(
    audio_lm=audio_lm,
    mulan_embed_quantizer=quantizer
)

music = music_lm('soul', num_samples=1)

In [None]:
torch.save(music, 'generated_music.pt')

In [None]:
output_path = "out.wav"
sample_rate = 44100
torchaudio.save(output_path, music.cpu(), sample_rate)