# AudioLM

Implementation of <a href="https://google-research.github.io/seanet/audiolm/examples/">AudioLM</a> in PyTorch Lightning.

This implementation is based on [audiolm-pytorch](https://github.com/lucidrains/audiolm-pytorch). Here, however, we wrap their model in a `LightningModule` to get a ready-to-use object that sets up everything you need: not only the model, but also the optimizers, the training loop, and so on.

We hope this repo is also easier to read and understand, both for users and for developers who wish to contribute.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from sai.utils import nb_init

nb_init()

INFO | nb_init | Set current dir to synthetic-data
INFO | nb_init | You are using Python 3.10.10 (main, Sep 14 2023, 16:59:47) [Clang 14.0.3 (clang-1403.0.22.14.1)]


In [3]:
from loguru import logger
import lightning.pytorch as pl
import torch

from sai.datasets import MusicCaps
from sai.models import AudioLMLightning

## Data

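AudioLM can be trained on the MusicCaps dataset, a collection of 10-second YouTube music clips annotated with a free-text caption and a list of musical aspects (see `caption` and `aspect_list` in the sample below).

We load only a subset of the audio files, set by `samples_to_load`; downloading the full dataset would take a long time and a lot of disk space.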

In [4]:
ROOT = ".data/music_data"

# Load dataset
dm = MusicCaps(
    root=ROOT,
    samples_to_load=32,
    batch_size=1,
)
dm.prepare_data()
dm.setup()

Each batch from the dataloader comes in the form of a dictionary.

In [5]:
for batch in dm.train_dataloader():
    break

batch

{'ytid': ['-BHPu-dPmWQ'],
 'start_s': tensor([30]),
 'end_s': tensor([40]),
 'audioset_positive_labels': ['/m/04rlf,/t/dd00032'],
 'aspect_list': ["['intimate wide mixed vocals', 'synth lead melody', 'punchy kick', 'noisy snare', 'claps', 'groovy bass guitar', 'tinny wide hi hats', 'short snare roll', 'alternative/indie', 'electric guitar melody', 'easygoing', 'melancholic']"],
 'caption': ['The Alternative/Indie song features an intimate, widely spread, mixed vocals singing over noisy snare, punchy kick, wide tinny hi hats, electric guitar melody, synth lead melody and groovy bass guitar. At the end of the loop there is a short snare roll and some claps. It sounds easygoing and melancholic thanks to those vocals.'],
 'author_id': tensor([4]),
 'is_balanced_subset': tensor([False]),
 'is_audioset_eval': tensor([True]),
 'audio': {'path': ['.data/music_data/-BHPu-dPmWQ.wav'],
  'array': tensor([[ 0.0237,  0.0276,  0.0278,  ..., -0.3301, -0.4594,  0.0000]]),
  'sampling_rate': tensor([44
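Two details of this output are worth noting. Because `batch_size=1`, PyTorch's default collate function wraps every field in a length-1 list (for strings) or tensor (for numbers), so index with `[0]` to get the sample itself. Also, `aspect_list` is the string representation of a Python list rather than an actual list, and `audioset_positive_labels` is a single comma-separated string. A minimal sketch for decoding both with the standard library, using the values printed above (the variable names are ours, for illustration):

```python
import ast

# Values copied from the printed batch, i.e. batch["aspect_list"][0]
# and batch["audioset_positive_labels"][0]
raw_aspects = (
    "['intimate wide mixed vocals', 'synth lead melody', 'punchy kick', "
    "'noisy snare', 'claps', 'groovy bass guitar', 'tinny wide hi hats', "
    "'short snare roll', 'alternative/indie', 'electric guitar melody', "
    "'easygoing', 'melancholic']"
)
raw_labels = "/m/04rlf,/t/dd00032"

# ast.literal_eval only parses Python literals, so it never executes code
aspects = ast.literal_eval(raw_aspects)
labels = raw_labels.split(",")

print(len(aspects), aspects[:2])  # → 12 ['intimate wide mixed vocals', 'synth lead melody']
print(labels)                     # → ['/m/04rlf', '/t/dd00032']
```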

## Model

Next, we initialize and train `AudioLMLightning`. It looks for audio files in `data_folder`, which the `MusicCaps` datamodule above has already populated.

Alternatively, we could skip `MusicCaps` entirely and just pass a `data_folder` to `AudioLMLightning`; in that case, `AudioLMLightning` would download the dataset for us.

In the original [audiolm repo](https://github.com/lucidrains/audiolm-pytorch), you would need to download checkpoints, initialize all models and transformers (`SoundStream`, etc.), train them one by one, and then combine them into an `AudioLM` object. You do not need to do any of this here: as you can see, everything is set up automatically by default.
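For intuition, the stages that `AudioLMLightning` wires together run in a fixed order at generation time: semantic tokens first, then coarse acoustic tokens conditioned on them, then fine acoustic tokens conditioned on the coarse ones, and finally the SoundStream decoder turns the acoustic tokens into a waveform. The sketch below only illustrates that data flow. `toy_stage` and `toy_decode` are hypothetical stubs with a made-up next-token rule, not the audiolm-pytorch API, and the vocabulary sizes and lengths are arbitrary.

```python
from typing import List

Tokens = List[int]


def toy_stage(condition: Tokens, length: int, vocab_size: int) -> Tokens:
    """Hypothetical stand-in for one autoregressive transformer stage.

    Every new token is computed from the previous one, starting from a
    seed derived from the conditioning tokens (a toy rule, not a model).
    """
    prev = sum(condition) % vocab_size
    tokens: Tokens = []
    for _ in range(length):
        prev = (prev * 31 + 7) % vocab_size  # toy "next-token prediction"
        tokens.append(prev)
    return tokens


def toy_decode(acoustic: Tokens, vocab_size: int) -> List[float]:
    """Hypothetical codec decoder: map each token id to a sample in [-1, 1]."""
    return [2.0 * t / (vocab_size - 1) - 1.0 for t in acoustic]


# 1. Semantic tokens (the real semantic transformer is trained on tokens that
#    the HuBERT + k-means `wave2vec` module extracts; text prompt omitted here)
semantic = toy_stage(condition=[], length=8, vocab_size=500)
# 2. Coarse acoustic tokens, conditioned on the semantic tokens
coarse = toy_stage(condition=semantic, length=8, vocab_size=1024)
# 3. Fine acoustic tokens, conditioned on the coarse tokens
fine = toy_stage(condition=coarse, length=8, vocab_size=1024)
# 4. The codec decodes all acoustic tokens (coarse + fine) into a waveform
waveform = toy_decode(coarse + fine, vocab_size=1024)
print(len(semantic), len(coarse), len(fine), len(waveform))  # → 8 8 8 16
```

The same ordering shows up in the log messages printed during generation (semantic, then coarse, then fine).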

In [6]:
model = AudioLMLightning(data_folder=ROOT)

We'll be training for a few steps only, for brevity.

In [7]:
# Trainer
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=False,
    max_steps=4,
    accelerator="cpu",
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
# Train
trainer.fit(model)


  | Name                         | Type                       | Params
----------------------------------------------------------------------------
0 | wave2vec                     | HubertWithKmeans           | 94.7 M
1 | soundstream                  | SoundStream                | 48.8 M
2 | semantic_transformer         | SemanticTransformer        | 59.2 M
3 | semantic_transformer_wrapper | SemanticTransformerWrapper | 153 M 
4 | coarse_transformer           | CoarseTransformer          | 18.6 M
5 | coarse_transformer_wrapper   | CoarseTransformerWrapper   | 162 M 
6 | fine_transformer             | FineTransformer            | 22.7 M
7 | fine_transformer_wrapper     | FineTransformerWrapper     | 71.5 M
8 | model                        | AudioLM                    | 243 M 
----------------------------------------------------------------------------
243 M     Trainable params
0         Non-trainable params
243 M     Total params
975.968   Total estimated model params size (MB)
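The rows of this summary are not additive. The counts are consistent with each wrapper (and `model`) holding the transformer it trains plus the `wave2vec` and/or `soundstream` modules it shares with the other stages; `Total params` counts each shared parameter once. The size estimate is simply the parameter count times 4 bytes per float32 weight. A quick arithmetic check on the printed figures (the per-wrapper grouping is our inference from the numbers, not something the summary states):

```python
# Per-module parameter counts in millions, copied from the summary table above
wave2vec, soundstream = 94.7, 48.8
semantic, coarse, fine = 59.2, 18.6, 22.7

# Inferred composition: each wrapper = its transformer + the shared encoder(s)
semantic_wrapper = wave2vec + semantic              # ~153.9 (table: 153 M)
coarse_wrapper = wave2vec + soundstream + coarse    # ~162.1 (table: 162 M)
fine_wrapper = soundstream + fine                   # ~71.5  (table: 71.5 M)

# Shared modules are counted once, so the full model is not the sum of all rows
audiolm = wave2vec + soundstream + semantic + coarse + fine  # ~244.0 (table: 243 M)

# Lightning's size estimate: params * 4 bytes (float32), in MB of 1e6 bytes
total_params_millions = 975.968 / 4                 # 243.992 M parameters
```

The small mismatches (153.9 vs 153, 244.0 vs 243) come from the table rounding each row independently.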


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=4` reached.


Now we can generate audio. A text prompt is optional: we can pass one, as below, or leave it out.

(This takes a while on CPU: in the run below, the fine stage alone took almost eight minutes.)

In [11]:
generated_wav: torch.Tensor = model(
    text='chirping of birds and the distant echos of bells',
    max_length=1024,
)

INFO | forward | Generating semantic token...
generating semantic:  61%|██████▏   | 628/1024 [00:07<00:04, 82.05it/s]
INFO | forward | Generating coarse token...
generating coarse: 100%|██████████| 512/512 [00:40<00:00, 12.62it/s]
INFO | forward | Generating wave...
generating fine: 100%|██████████| 512/512 [07:43<00:00,  1.10it/s]


In [12]:
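import sounddevice as sd

# sd.play expects an array shaped (frames,) or (frames, channels). The generated
# waveform has a leading batch dimension, which sounddevice would read as a single
# frame with thousands of channels ("Invalid number of channels"), so squeeze it
# down to a 1-D mono signal first.
audio = generated_wav.detach().cpu().squeeze().numpy()
sd.play(audio, 44100)
sd.wait()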
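If no audio output device is available (for example on a remote machine), you can instead write the waveform to a WAV file and listen to it later. A minimal sketch with the standard library `wave` module, assuming mono float samples in [-1, 1] and the 44100 Hz rate used above; `write_wav` is a hypothetical helper, not part of `sai`:

```python
import array
import sys
import wave
from typing import Sequence


def write_wav(path: str, samples: Sequence[float], sample_rate: int = 44100) -> None:
    """Hypothetical helper: save mono float samples in [-1, 1] as 16-bit PCM."""
    # Clip each sample to [-1, 1], then scale it to the signed 16-bit range
    pcm = array.array("h", (int(max(-1.0, min(1.0, s)) * 32767) for s in samples))
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV data is little-endian
    with wave.open(path, "wb") as f:
        f.setnchannels(1)           # mono
        f.setsampwidth(2)           # 2 bytes per sample = 16-bit
        f.setframerate(sample_rate)
        f.writeframes(pcm.tobytes())
```

Usage with the generated tensor would be `write_wav("generated.wav", generated_wav.detach().cpu().flatten().tolist())`. Writing 16-bit PCM keeps the file playable by virtually any audio player, at the cost of quantizing the model's float output.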