<a href="https://colab.research.google.com/github/aaronannecchiarico/musiclm-pytorch/blob/main/musiclm_pytorch_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Requirements

In [None]:
# Show the visible NVIDIA GPU(s). Later cells call `.cuda()` (on the semantic
# transformer and on the SoundStream trainer), so a GPU runtime is needed.
!nvidia-smi

# If this command fails, no GPU is available or detected.

In [1]:
pip install musiclm-pytorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting musiclm-pytorch
  Downloading musiclm_pytorch-0.0.12-py3-none-any.whl (10 kB)
Collecting accelerate
  Downloading accelerate-0.16.0-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.7/199.7 KB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting vector-quantize-pytorch>=1.0.0
  Downloading vector_quantize_pytorch-1.0.0-py3-none-any.whl (8.7 kB)
Collecting einops>=0.6
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting audiolm-pytorch>=0.10.4
  Downloading audiolm_pytorch-0.10.4-py3-none-any.whl (27 kB)
Collecting x-clip
  Downloading x_clip-0.12.0-py3-none-any.whl (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m29.5 MB/s[0m eta [36m0:00:00[0m
[?2

In [4]:
%%capture
! pip install datasets[audio] yt-dlp

# Imports

In [None]:
import subprocess
import os
import requests
from pathlib import Path

from datasets import load_dataset, Audio

import torch
import torchaudio
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer, MuLaNEmbedQuantizer, MusicLM
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM



# Train MuLaN

In [30]:
audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 128,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 128
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

# Placeholder data: random audio (2 clips x 512 samples) and random token ids in
# [0, 20000) stand in for real <sound, text> pairs. Real training needs a large set
# of pairs plus an optimizer step; this single backward pass only demonstrates the API.
wavs = torch.randn(2, 512)
texts = torch.randint(0, 20000, (2, 256))

loss = mulan(wavs, texts)
loss.backward()

# After training, sounds and text can be embedded into a joint space used to
# condition the audio LM. Extracting embeddings needs no gradients, so don't build
# an autograd graph; keep audio and text latents under separate names instead of
# overwriting one with the other.
with torch.no_grad():
    audio_embeds = mulan.get_audio_latents(wavs)  # during training
    embeds = mulan.get_text_latents(texts)        # during inference

spectrogram yielded shape of (65, 43), but had to be cropped to (64, 32) to be patchified for transformer
spectrogram yielded shape of (65, 43), but had to be cropped to (64, 32) to be patchified for transformer


In [31]:
# setup the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as namespace (per transformer)

quantizer = MuLaNEmbedQuantizer(
    mulan = mulan,                          # pass in trained mulan from above
    conditioning_dims = (512, 512, 512), # one entry per namespace, presumably matching each transformer's `dim` (semantic, coarse and fine all use dim = 512 below)
    namespaces = ('semantic', 'coarse', 'fine')
)

# now say you want the conditioning embeddings for semantic transformer
# (note: this rebinds `wavs`, replacing the (2, 512) tensor from the MuLaN cell)

wavs = torch.randn(2, 1024)
conds = quantizer(wavs = wavs, namespace = 'semantic') # presumably (2, num_quantizers, 512) — TODO confirm against MuLaNEmbedQuantizer

spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer


# Data

In [5]:
def download_clip(
    video_identifier,
    output_filename,
    start_time,
    end_time,
    tmp_dir='/tmp/musiccaps',
    num_attempts=5,
    url_base='https://www.youtube.com/watch?v='
):
    """
    Download one time-bounded section of a YouTube video's audio as WAV via yt-dlp.

    Args:
        video_identifier: YouTube video id, appended to ``url_base``.
        output_filename: Path passed to yt-dlp's ``-o``.
        start_time: Section start (seconds), used in ``--download-sections``.
        end_time: Section end (seconds), used in ``--download-sections``.
        tmp_dir: Unused; kept for backward compatibility.
        num_attempts: Maximum number of yt-dlp invocations. Values below 1 are
            treated as 1 (previously 0 or negative values looped forever).
        url_base: Prefix that ``video_identifier`` is appended to.

    Returns:
        ``(status, log)``. If yt-dlp exits successfully, ``status`` is whether
        ``output_filename`` now exists and ``log`` is ``'Downloaded'``. If every
        attempt fails, ``status`` is False and ``log`` is yt-dlp's combined
        stdout/stderr (bytes). If yt-dlp cannot be executed at all, ``status`` is
        False and ``log`` is the OS error message (bytes).
    """
    # Build an argv list instead of a shell string: ids and paths come from an
    # external dataset, and shell=True would let quotes or `;` in them inject
    # commands (or simply break the quoting).
    command = [
        'yt-dlp', '--quiet', '--no-warnings',
        '-x', '--audio-format', 'wav',
        '-f', 'bestaudio',
        '-o', str(output_filename),
        '--download-sections', f'*{start_time}-{end_time}',
        f'{url_base}{video_identifier}',
    ]

    last_output = b''
    for _ in range(max(1, num_attempts)):
        try:
            subprocess.check_output(command, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as err:
            last_output = err.output
        except OSError as err:
            # yt-dlp missing or not executable: retrying cannot help.
            return False, str(err).encode()
        else:
            break
    else:
        # Every attempt failed.
        return False, last_output

    # Check if the clip was actually written.
    status = os.path.exists(output_filename)
    return status, 'Downloaded'


def main(
    data_dir: str,
    sampling_rate: int = 44100,
    limit: int = None,
    num_proc: int = 1,
    writer_batch_size: int = 1000,
):
    """
    Download the clips within the MusicCaps dataset from YouTube.

    Each row's clip is saved as ``<data_dir>/<ytid>.wav``. Rows whose file already
    exists are not downloaded again, so re-running is cheap.

    Args:
        data_dir: Directory to save the clips to (created if missing).
        sampling_rate: Sampling rate the ``audio`` column is cast to.
        limit: Download at most this many examples (``None`` = all). A limit larger
            than the split is clamped to the split size.
        num_proc: Number of processes to use for downloading.
        writer_batch_size: Batch size for writing the dataset. This is per process.

    Returns:
        The MusicCaps ``train`` split with an ``audio`` column (file path cast to
        ``Audio(sampling_rate=sampling_rate)``) and a boolean ``download_status``
        column. Rows with ``download_status == False`` point at a file that does
        not exist.
    """
    ds = load_dataset('google/MusicCaps', split='train')
    if limit is not None:
        # Clamp so a limit larger than the split doesn't make `select` fail with
        # an out-of-range index.
        limit = min(limit, len(ds))
        print(f"Limiting to {limit} examples")
        ds = ds.select(range(limit))

    data_dir = Path(data_dir)
    data_dir.mkdir(exist_ok=True, parents=True)

    def process(example):
        outfile_path = str(data_dir / f"{example['ytid']}.wav")
        if os.path.exists(outfile_path):
            # Saved by a previous run; skip the network round-trip.
            status = True
        else:
            status, _log = download_clip(
                example['ytid'],
                outfile_path,
                example['start_s'],
                example['end_s'],
            )

        example['audio'] = outfile_path
        example['download_status'] = status
        return example

    return ds.map(
        process,
        num_proc=num_proc,
        writer_batch_size=writer_batch_size,
        keep_in_memory=False
    ).cast_column('audio', Audio(sampling_rate=sampling_rate))

In [29]:
# Download a small slice of MusicCaps into ./music_data (raise `limit` for a larger dataset).
music_ds = main('./music_data', num_proc=2, limit=2)

# Rows whose download failed still point at a .wav path that doesn't exist; say so
# here instead of letting the trainers below silently see fewer clips.
n_failed = sum(1 for ok in music_ds['download_status'] if not ok)
if n_failed:
    print(f"warning: {n_failed} of {len(music_ds)} clips failed to download")
music_ds



Limiting to 2 examples
    

#0:   0%|          | 0/1 [00:00<?, ?ex/s]

#1:   0%|          | 0/1 [00:00<?, ?ex/s]

Dataset({
    features: ['ytid', 'start_s', 'end_s', 'audioset_positive_labels', 'aspect_list', 'caption', 'author_id', 'is_balanced_subset', 'is_audioset_eval', 'audio', 'download_status'],
    num_rows: 2
})

# Download the HuBERT checkpoint and k-means model

In [7]:
# Helpers that fetch the HuBERT base checkpoint and its k-means model
# (https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt and
#  https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin)
# into a local `hubert` folder, where the HubertWithKmeans cells below look for them.

def download_file(url, file_name, chunk_size=1 << 20, timeout=60):
    """
    Download ``url`` to ``file_name``.

    The response is streamed in ``chunk_size`` pieces to ``file_name + '.part'``
    and renamed onto ``file_name`` only once complete. This avoids buffering the
    whole (multi-hundred-MB) file in memory, and a failed or interrupted download
    never leaves an empty or truncated file under the final name.

    Args:
        url: HTTP(S) URL to fetch.
        file_name: Destination path.
        chunk_size: Bytes per chunk while streaming.
        timeout: Seconds passed to ``requests.get`` as its connect/read timeout.

    Raises:
        requests.HTTPError: if the server answers with an error status (the error
            page is no longer silently saved as the checkpoint).
    """
    tmp_name = f"{file_name}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_name, "wb") as file:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    file.write(chunk)
    except BaseException:
        # Don't leave a partial download lying around.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
    os.replace(tmp_name, file_name)

def get_hubert(dest_dir="hubert"):
    """
    Fetch the HuBERT base checkpoint and its k-means model into ``dest_dir``.

    Safe to re-run (Restart & Run All): the folder is created only if missing, and
    a file that already exists with non-zero size is not downloaded again. An empty
    file, e.g. left by a failed earlier attempt, is treated as missing.

    Args:
        dest_dir: Target folder. The default ``"hubert"`` (relative to the working
            directory) is where the HubertWithKmeans cells read from.
    """
    os.makedirs(dest_dir, exist_ok=True)

    urls = (
        "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt",
        "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin",
    )
    for url in urls:
        target = os.path.join(dest_dir, url.rsplit("/", 1)[-1])
        if os.path.isfile(target) and os.path.getsize(target) > 0:
            continue
        download_file(url, target)


In [8]:
# Fetch the HuBERT checkpoint + k-means file into ./hubert (read by the HubertWithKmeans cells below).
get_hubert()

# Semantic Transformer Training

In [35]:
# HuBERT + k-means, loaded from the files get_hubert() fetched into ./hubert.
# Its k-means codebook size sets the semantic transformer's token count below.
wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

# dim = 512 presumably has to agree with the 'semantic' entry of the quantizer's conditioning_dims.
semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 512,
    depth = 6,
    audio_text_condition = True      # this must be set to True (same for CoarseTransformer and FineTransformer)
).cuda()

trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    audio_conditioner = quantizer,   # pass in the MuLaNEmbedQuantizer instance above
    folder ='./music_data',          # the .wav clips written by main(...) in the Data section
    batch_size = 1,
    data_max_length = 320 * 32,      # NOTE(review): presumably the per-example audio length cap, in samples — confirm
    num_train_steps = 1              # demo value only; real training needs many more steps
)

# NOTE(review): in the recorded run the trainer stopped to ask "do you want to clear
# previous experiment checkpoints and results? (y/n)". That interactive prompt blocks
# Restart & Run All in a non-interactive session.
trainer.train()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


training with dataset of 1 samples and validating with randomly splitted 1 samples
do you want to clear previous experiment checkpoints and results? (y/n) n
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
0: loss: 6.232154369354248
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
0: valid loss 5.5342841148376465
0: saving model to results
training complete


# SoundStream Training

In [33]:
# rq_num_quantizers = 8 equals num_coarse_quantizers (3) + num_fine_quantizers (5)
# used by the coarse and fine transformers below.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = './music_data',
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,         # with 9 steps, checkpoints presumably land on steps 0, 4 and 8; later cells load ./results/soundstream.8.pt
    num_train_steps = 9
).cuda()

# NOTE(review): the recorded run answered 'y' to "clear previous experiment
# checkpoints and results?". Every trainer here reports saving to `results`, so
# clearing presumably also deletes the semantic transformer checkpoint saved by the
# previous cell — confirm before relying on that checkpoint.
trainer.train()

training with dataset of 1 samples and validating with randomly splitted 1 samples
do you want to clear previous experiment checkpoints and results? (y/n) y
0: soundstream total loss: 169.771, soundstream recon loss: 0.730 | discr (scale 1) loss: 2.001 | discr (scale 0.5) loss: 2.000 | discr (scale 0.25) loss: 2.000
0: saving to results
0: saving model to results
1: soundstream total loss: 153.024, soundstream recon loss: 0.388 | discr (scale 1) loss: 1.972 | discr (scale 0.5) loss: 1.967 | discr (scale 0.25) loss: 1.976
2: soundstream total loss: 149.553, soundstream recon loss: 0.364 | discr (scale 1) loss: 1.959 | discr (scale 0.5) loss: 1.934 | discr (scale 0.25) loss: 1.952
2: saving to results
3: soundstream total loss: 150.173, soundstream recon loss: 0.386 | discr (scale 1) loss: 1.959 | discr (scale 0.5) loss: 1.923 | discr (scale 0.25) loss: 1.923
4: soundstream total loss: 117.401, soundstream recon loss: 0.156 | discr (scale 1) loss: 1.968 | discr (scale 0.5) loss: 1.961 | 

# Coarse Transformer Training

In [34]:
# Rebuilt from the same ./hubert files as in the semantic cell, so this cell does not
# depend on that cell's `wav2vec` object.
wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

# Same hyperparameters as the SoundStream training cell; weights come from its step-8 checkpoint.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load("./results/soundstream.8.pt")

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,           # same as SoundStream's codebook_size
    num_coarse_quantizers = 3,      # 3 of SoundStream's 8 quantizers; the fine transformer uses the other 5
    dim = 512,
    depth = 6,
    audio_text_condition = True
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    audio_conditioner = quantizer,
    wav2vec = wav2vec,
    folder = './music_data',
    batch_size = 1,
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason
# NOTE(review): like the other trainers, this one asked an interactive y/n question
# about clearing previous results in the recorded run.

trainer.train()

https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations


training with dataset of 1 samples and validating with randomly splitted 1 samples
do you want to clear previous experiment checkpoints and results? (y/n) n
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
0: loss: 56.20853805541992
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
0: valid loss 29.35236930847168
0: saving model to results
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
1: loss: 25.14291763305664
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
2: loss: 21.656782150268555
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
2: valid loss 34.56642532348633
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
3: loss: 19.1150474548

# Fine Transformer Training

In [36]:
# Same hyperparameters as the SoundStream training cell; weights come from its step-8 checkpoint.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load("./results/soundstream.8.pt")

# 3 coarse + 5 fine quantizers = SoundStream's rq_num_quantizers (8).
fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6,
    audio_text_condition = True
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    soundstream = soundstream,
    audio_conditioner = quantizer,
    folder = './music_data',
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 9
)
# NOTE: num_train_steps is 9 (aka 8 + 1) instead of 10000 to make things go faster for demo purposes.
# Unlike the coarse cell, no save_*_every values are passed here, so the trainer's defaults
# apply (in the visible recorded output it only saved at step 0).

trainer.train()

training with dataset of 1 samples and validating with randomly splitted 1 samples
do you want to clear previous experiment checkpoints and results? (y/n) n
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
0: loss: 64.39331817626953
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
0: valid loss 40.131996154785156
0: saving model to results
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
1: loss: 38.22500991821289
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
2: loss: 18.497600555419922
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
3: loss: 16.688297271728516
spectrogram yielded shape of (65, 854), but had to be cropped to (64, 848) to be patchified for transformer
4: loss: 14.2027502059936

# Generate

In [37]:
# Assemble the trained pieces: AudioLM runs the semantic, coarse and fine stages with
# the SoundStream loaded above, and MusicLM conditions it on the text prompt through
# the MuLaN quantizer. (MusicLM is already imported in the Imports cell.)
audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)
musiclm = MusicLM(
    audio_lm = audiolm,
    mulan_embed_quantizer = quantizer
)

# Slow: the fine stage alone took ~18 minutes in the recorded run.
prompt = 'the crystalline sounds of the piano in a ballroom'
music = musiclm([prompt]) # torch.Tensor

generating semantic:   0%|          | 3/2048 [00:00<01:09, 29.31it/s]
generating coarse: 100%|██████████| 512/512 [01:38<00:00,  5.20it/s]
generating fine: 100%|██████████| 512/512 [17:50<00:00,  2.09s/it]


In [38]:
# Write the generated audio to disk. The waveform comes out of SoundStream, which
# presumably decodes at its own `target_sample_hz`, so read the rate from the model
# instead of hard-coding it (falls back to the previous 44100 if this audiolm-pytorch
# version does not expose the attribute). torchaudio is imported in the Imports cell.
output_path = "out1.wav"
sample_rate = getattr(soundstream, "target_sample_hz", 44100)

waveform = music.detach().cpu()
if waveform.ndim == 1:
    waveform = waveform.unsqueeze(0)  # torchaudio.save expects a 2-D (channels, frames) tensor
torchaudio.save(output_path, waveform, sample_rate)