<a href="https://colab.research.google.com/github/rajakumaran/musiclm-pytorch/blob/main/musiclm-pytorch-demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Requirements

In [None]:
# Print the NVIDIA driver/GPU status table for this runtime.
!nvidia-smi

# If this command fails, no GPU is available or detected on this runtime.

Thu Feb  9 02:50:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P0    25W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Explicit %pip magic: installs into the running kernel's environment and does
# not depend on IPython automagic being enabled.
%pip install musiclm-pytorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting musiclm-pytorch
  Downloading musiclm_pytorch-0.0.17-py3-none-any.whl (10 kB)
Collecting vector-quantize-pytorch>=1.0.0
  Downloading vector_quantize_pytorch-1.0.0-py3-none-any.whl (8.7 kB)
Collecting x-clip
  Downloading x_clip-0.12.0-py3-none-any.whl (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m19.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting audiolm-pytorch>=0.10.4
  Downloading audiolm_pytorch-0.11.15-py3-none-any.whl (28 kB)
Collecting einops>=0.6
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype
  Downloading beartype-0.12.0-py3-none-any.whl (754 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m754.5/754.5 KB[0m [31m65.0 MB/s[0m eta [36m0:00:00[0m
Collect

In [1]:
%%capture
# datasets (with the audio extra) and yt-dlp are needed by the Data section.
# %pip installs into the kernel's environment; the extra is quoted so the shell
# never treats the square brackets as a glob pattern.
%pip install "datasets[audio]" yt-dlp

# Imports

In [2]:
import subprocess
import os
import requests
from pathlib import Path

from datasets import load_dataset, Audio

import torch
import torchaudio
from torch.utils.data import Dataset

from musiclm_pytorch import MuLaN, MuLaNTrainer, AudioSpectrogramTransformer, TextTransformer, MuLaNEmbedQuantizer, MusicLM
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM



ModuleNotFoundError: ignored

# Data

In [None]:
def download_clip(
    video_identifier,
    output_filename,
    start_time,
    end_time,
    tmp_dir='/tmp/musiccaps',
    num_attempts=5,
    url_base='https://www.youtube.com/watch?v='
):
    """
    Download one section of a video's audio track as a WAV file using yt-dlp.

    Args:
        video_identifier: Video id; appended to ``url_base`` to form the URL.
        output_filename: Path handed to yt-dlp's ``-o`` option.
        start_time: Section start, used in ``--download-sections "*start-end"``.
        end_time: Section end, used in ``--download-sections "*start-end"``.
        tmp_dir: Unused; kept only so existing callers keep working.
        num_attempts: Maximum number of yt-dlp invocations. Values below 1 are
            treated as 1 (previously they caused an endless retry loop).
        url_base: URL prefix the video id is appended to.

    Returns:
        A ``(status, log)`` tuple. When yt-dlp exits successfully, ``status``
        is whether ``output_filename`` exists and ``log`` is ``'Downloaded'``.
        When every attempt fails, ``status`` is ``False`` and ``log`` is the
        captured output of the last failed attempt. When yt-dlp cannot be
        started at all, ``status`` is ``False`` and ``log`` is the error text.
    """
    # Argument list instead of a shell string: ids, paths and timestamps are
    # passed verbatim and can never be interpreted (or globbed) by a shell.
    command = [
        'yt-dlp', '--quiet', '--no-warnings',
        '-x', '--audio-format', 'wav',
        '-f', 'bestaudio',
        '-o', str(output_filename),
        '--download-sections', f'*{start_time}-{end_time}',
        f'{url_base}{video_identifier}',
    ]

    last_output = b''
    for _ in range(max(1, num_attempts)):
        try:
            subprocess.check_output(command, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as err:
            # yt-dlp ran but failed; remember its output and retry.
            last_output = err.output
        except OSError as err:
            # yt-dlp could not be launched (e.g. not installed); retrying is pointless.
            return False, str(err)
        else:
            # Check if the clip was actually written to disk.
            return os.path.exists(output_filename), 'Downloaded'

    return False, last_output


def main(
    data_dir: str,
    sampling_rate: int = 44100,
    limit: int = None,
    num_proc: int = 1,
    writer_batch_size: int = 1000,
):
    """
    Download the clips within the MusicCaps dataset from YouTube.

    Clips whose WAV file already exists in ``data_dir`` are not downloaded
    again, so the function can be re-run to resume an interrupted download.

    Args:
        data_dir: Directory to save the clips to (created if missing).
        sampling_rate: Sampling rate of the audio clips.
        limit: Limit the number of examples to download. ``None`` means no
            limit; values larger than the split are clamped to its size.
        num_proc: Number of processes to use for downloading.
        writer_batch_size: Batch size for writing the dataset. This is per process.

    Returns:
        The dataset with two extra columns: 'audio', the path
        ``<data_dir>/<ytid>.wav`` cast to ``Audio(sampling_rate=...)``, and
        'download_status', a boolean. The 'audio' path is set even when the
        download failed, so filter on 'download_status' before reading audio.
    """

    ds = load_dataset('google/MusicCaps', split='train')
    if limit is not None:
        print(f"Limiting to {limit} examples")
        # Clamp so a limit above the split size does not request out-of-range rows.
        ds = ds.select(range(min(limit, len(ds))))

    data_dir = Path(data_dir)
    data_dir.mkdir(exist_ok=True, parents=True)

    def process(example):
        """Ensure the clip for one example is on disk and record the outcome."""
        outfile_path = str(data_dir / f"{example['ytid']}.wav")
        if os.path.exists(outfile_path):
            # Already downloaded on a previous run.
            status = True
        else:
            status, _ = download_clip(
                example['ytid'],
                outfile_path,
                example['start_s'],
                example['end_s'],
            )

        example['audio'] = outfile_path
        example['download_status'] = status
        return example

    return ds.map(
        process,
        num_proc=num_proc,
        writer_batch_size=writer_batch_size,
        keep_in_memory=False
    ).cast_column('audio', Audio(sampling_rate=sampling_rate))

In [None]:
# Download the first 32 MusicCaps examples into ./music_data using 2 worker processes.
ds = main('./music_data', num_proc=2, limit=32, writer_batch_size=4) # change limit for larger dataset



Limiting to 32 examples
    

#0:   0%|          | 0/16 [00:00<?, ?ex/s]

#1:   0%|          | 0/16 [00:00<?, ?ex/s]

In [None]:
class TextAudioDataset(Dataset):
    """Map-style dataset that pairs each example's caption with its waveform.

    Wraps an indexable collection whose items expose a 'caption' entry and an
    'audio' entry containing a 'path' to the audio file.
    """

    def __init__(self, dset):
        """Keep a reference to the underlying collection of examples."""
        super().__init__()
        self.dset = dset

    def __len__(self):
        """Number of examples in the wrapped collection."""
        return len(self.dset)

    def __getitem__(self, idx):
        """Return ``(caption, waveform)`` for the example at ``idx``."""
        record = self.dset[idx]
        text = record['caption']
        # torchaudio.load also returns the sample rate, which is not used here.
        waveform, _ = torchaudio.load(record['audio']['path'])
        return text, waveform

# Train MuLaN

In [1]:
# Audio encoder for MuLaN; the spec_* arguments configure its spectrogram settings.
audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

# Text encoder for MuLaN; same dim / depth / heads / dim_head as the audio encoder.
text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

# Joint audio-text model built from the two encoders above.
mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

# Train MuLaN on the (caption, waveform) pairs produced by TextAudioDataset over `ds`.
# NOTE(review): unlike the other trainers in this notebook, no num_train_steps is
# passed here - confirm how long train() runs with the library defaults.
trainer = MuLaNTrainer(
    mulan = mulan,
    dataset = TextAudioDataset(ds),
    batch_size = 4
)

trainer.train()

NameError: ignored

In [None]:
# Set up the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as per namespace (one namespace per transformer).
# The conditioning dimension of 1024 equals the `dim = 1024` used for the semantic, coarse and fine transformers in this notebook.

quantizer = MuLaNEmbedQuantizer(
    mulan = mulan,                          # pass in trained mulan from above
    conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024
    namespaces = ('semantic', 'coarse', 'fine')
)

# Download Hubert

In [None]:
# Download the HuBERT checkpoint and its k-means file
#   https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt
#   https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin
# into a folder called hubert in the current directory.

def download_file(url, file_name, chunk_size=1 << 20, timeout=60):
    """
    Stream the content at ``url`` into ``file_name``.

    The body is written in chunks to ``<file_name>.part`` and moved into place
    only after the whole download succeeded, so ``file_name`` never holds a
    truncated file or an HTTP error page.

    Args:
        url: Address to download.
        file_name: Destination path.
        chunk_size: Number of bytes requested per chunk while streaming.
        timeout: Timeout in seconds passed to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    partial_name = f"{file_name}.part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(partial_name, "wb") as file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                file.write(chunk)
    os.replace(partial_name, file_name)


def get_hubert():
    """
    Download the HuBERT files into ./hubert.

    Safe to re-run: the folder may already exist, and files that are already
    present are not downloaded again.
    """
    # Create a folder called hubert (no error if it is already there)
    os.makedirs("hubert", exist_ok=True)

    # Download the files that are still missing
    targets = (
        ("https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", "hubert/hubert_base_ls960.pt"),
        ("https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin", "hubert/hubert_base_ls960_L9_km500.bin"),
    )
    for url, file_name in targets:
        if not os.path.exists(file_name):
            download_file(url, file_name)


In [None]:
# Fetch the HuBERT files into ./hubert (network download).
get_hubert()

# Semantic Transformer Training

In [None]:
# HuBERT + k-means model loaded from the files downloaded into ./hubert;
# its codebook_size sets the number of semantic tokens below.
wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

# dim = 1024 equals the 'semantic' entry of the quantizer's conditioning_dims.
semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6,
    audio_text_condition = True      # this must be set to True (same for CoarseTransformer and FineTransformers)
).cuda()

# Trains on the audio files in ./music_data.
# NOTE: num_train_steps = 1 runs a single step only.
trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    audio_conditioner = quantizer,   # pass in the MulanEmbedQuantizer instance above
    folder ='./music_data',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train()

# Soundstream Training

In [None]:
# SoundStream codec; the same codebook_size / rq_num_quantizers are repeated in the
# coarse and fine cells when the checkpoint is loaded again.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

# num_train_steps = 9 with save_model_every = 4: the coarse and fine cells load
# ./results/soundstream.8.pt - presumably the checkpoint written at step 8;
# verify the trainer's output folder and file naming.
trainer = SoundStreamTrainer(
    soundstream,
    folder = './music_data',
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
).cuda()

trainer.train()

# Coarse Transformer Training

In [None]:
# Re-create the HuBERT + k-means model from the same files as in the semantic cell.
wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

# Fresh SoundStream with the same settings as in the SoundStream training cell,
# then restored from the saved checkpoint.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load("./results/soundstream.8.pt")

# codebook_size matches SoundStream's; dim = 1024 equals the 'coarse' entry of the
# quantizer's conditioning_dims.
# NOTE(review): num_coarse_quantizers = 3 presumably covers the first 3 of
# SoundStream's 8 quantizers - confirm.
coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 1024,
    depth = 6,
    audio_text_condition = True
)

# Trains on the audio files in ./music_data, conditioned through the MuLaN quantizer.
trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    audio_conditioner = quantizer,
    wav2vec = wav2vec,
    folder = './music_data',
    batch_size = 1,
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
)
# NOTE: num_train_steps was reduced from 10000 to 9 (i.e. 8 + 1) so the demo finishes quickly;
# the save_*_every values were lowered for the same reason.

trainer.train()

# Fine Transformer Training

In [None]:
# Fresh SoundStream with the same settings as in the SoundStream training cell,
# then restored from the saved checkpoint.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load("./results/soundstream.8.pt")

# 3 coarse + 5 fine quantizers add up to the 8 quantizers configured on SoundStream;
# dim = 1024 equals the 'fine' entry of the quantizer's conditioning_dims.
fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 1024,
    depth = 6,
    audio_text_condition = True
)

# Trains on the audio files in ./music_data, conditioned through the MuLaN quantizer.
trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    soundstream = soundstream,
    audio_conditioner = quantizer,
    folder = './music_data',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 9
)
# NOTE: num_train_steps was reduced from 10000 to 9 (i.e. 8 + 1) so the demo finishes quickly.
# No save_*_every values are overridden in this cell.

trainer.train()

# Generate

In [None]:
# NOTE: this import duplicates the MusicLM import already present in the Imports cell.
from musiclm_pytorch import MusicLM

# Assemble AudioLM from the models created in the training cells above.
# `wav2vec` and `soundstream` are whichever instances were assigned last
# (both names are re-bound in several cells).
audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)
# Wrap AudioLM with the MuLaN quantizer.
musiclm = MusicLM(
    audio_lm = audiolm,
    mulan_embed_quantizer = quantizer
)

# Generate audio for a single text prompt.
music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.Tensor

In [None]:
# Write the generated tensor to out.wav.
# NOTE(review): 44100 equals the default sampling_rate of main(), but the rate of the
# generated audio is determined by the SoundStream model - confirm before relying on it.
# NOTE(review): assumes `music` is 2-D so torchaudio.save accepts it - TODO confirm.
output_path = "out.wav"
sample_rate = 44100
torchaudio.save(output_path, music.cpu(), sample_rate)