<a href="https://colab.research.google.com/github/softmurata/colab_notebooks/blob/main/audio/audiolm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title AudioLM Test
# GitHub: https://github.com/lucidrains/audiolm-pytorch

In [None]:
# LibriSpeech dataset (https://www.openslr.org/12/), "test-clean" archive.
# -c resumes a partial download and avoids fetching a second copy
# (test-clean.tar.gz.1) when this cell is run again.
!wget -c https://www.openslr.org/resources/12/test-clean.tar.gz

In [None]:
# tar usage reference: https://qiita.com/supersaiakujin/items/c6b54e9add21d375161f
# Unpack the downloaded archive into the current working directory.
# Verbose mode (-v) is intentionally left off so the cell output is not
# flooded with one line per extracted file.
!tar -zxf test-clean.tar.gz

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel;
# -q keeps the install log short.
# NOTE(review): the package is not pinned. The keyword names used further down
# (e.g. AudioLM(soundstream=...)) may differ between releases — pin the version
# this notebook was written against once it is confirmed.
%pip install -q audiolm-pytorch

In [None]:
#@title Train SoundStream
# Folder with the extracted LibriSpeech "test-clean" audio (Colab form field).
data_folder = "/content/LibriSpeech/test-clean"  #@param

In [None]:
from audiolm_pytorch import SoundStream, SoundStreamTrainer

# SoundStream model configured with a codebook size of 1024 and 8 residual
# quantizers (values taken as-is from this notebook's configuration).
soundstream = SoundStream(
    codebook_size=1024,
    rq_num_quantizers=8,
)

# Train on the folder chosen in the "Train SoundStream" form cell instead of a
# second hard-coded copy of the same path, so editing the form field actually
# changes the training data.
trainer = SoundStreamTrainer(
    soundstream,
    folder=data_folder,
    batch_size=4,
    grad_accum_every=8,
    data_max_length=320 * 32,
    num_train_steps=1500
).cuda()

# NOTE(review): checkpoints presumably land under ./results (a later cell reads
# '/content/results/soundstream.0.pt') — confirm against the trainer defaults.
trainer.train()



In [None]:
#@title Train 3 transformers
# SemanticTransformer, CoarseTransformer and FineTransformer. Trained SoundStream weights are needed for the coarse and fine transformers; the semantic transformer is trained with HuBERT + k-means only.

In [None]:
# Download the HuBERT weights and the matching k-means quantizer.
# HuBERT checkpoints are listed at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
# -p: do not fail if the directory already exists, so the cell can be re-run.
!mkdir -p hubert
# -c: resume a partial file / skip one that is already complete instead of
# truncating it and downloading everything again.
!wget -c -O ./hubert/hubert_base_ls960.pt https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt
!wget -c -O ./hubert/hubert_base_ls960_L9_km500.bin https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin

In [None]:
import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer

In [None]:
#@title params
# Paths used by the transformer-training cells (Colab form fields).
# HuBERT checkpoint fetched by the download cell.
hubert_weight_path = '/content/hubert/hubert_base_ls960.pt' #@param
# NOTE(review): this path is relative while the others are absolute under
# /content; they only agree when the working directory is /content — confirm.
kmeans_weight_path = './hubert/hubert_base_ls960_L9_km500.bin' #@param
# NOTE(review): the ".0" in the file name looks like a training-step index;
# confirm this is the checkpoint intended rather than a later one.
soundstream_weight_path = '/content/results/soundstream.0.pt'  #@param
# Folder with the extracted LibriSpeech "test-clean" audio.
data_folder = "/content/LibriSpeech/test-clean"  #@param

In [None]:
# HuBERT + k-means model built from the checkpoint paths set in the params cell.
wav2vec = HubertWithKmeans(checkpoint_path=hubert_weight_path, kmeans_path=kmeans_weight_path)

# The semantic transformer's vocabulary size is taken from wav2vec's codebook size.
semantic_transformer = SemanticTransformer(
    num_semantic_tokens=wav2vec.codebook_size,
    dim=1024,
    depth=6,
)
semantic_transformer = semantic_transformer.cuda()

# NOTE(review): only a single training step is requested here, whereas the
# coarse and fine trainers in this notebook use 1500 — confirm this is intended.
trainer = SemanticTransformerTrainer(
    transformer=semantic_transformer,
    wav2vec=wav2vec,
    folder=data_folder,
    batch_size=1,
    data_max_length=320 * 32,
    num_train_steps=1,
)
trainer.train()

In [None]:
# CoarseTransformer / CoarseTransformerTrainer are not imported by any earlier
# cell, so import them here; SoundStream and HubertWithKmeans are included so
# this cell also works when only the transformer section is run.
from audiolm_pytorch import (
    CoarseTransformer,
    CoarseTransformerTrainer,
    HubertWithKmeans,
    SoundStream,
)

# HuBERT + k-means model built from the checkpoint paths set in the params cell.
wav2vec = HubertWithKmeans(
    checkpoint_path = hubert_weight_path,
    kmeans_path = kmeans_weight_path,
)

# Rebuild SoundStream with the same hyper-parameters used for its training cell
# and load the saved weights into it.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

# soundstream.load('/path/to/trained/soundstream.pt')
soundstream.load(soundstream_weight_path)

# Coarse transformer over 3 coarse quantizers; its semantic vocabulary size is
# taken from wav2vec's codebook size.
coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    wav2vec = wav2vec,
    folder = data_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1500
)

trainer.train()

In [None]:
# FineTransformer / FineTransformerTrainer are not imported by any earlier
# cell, so import them here; SoundStream is included so this cell also works
# when only the transformer section is run.
from audiolm_pytorch import FineTransformer, FineTransformerTrainer, SoundStream

# Rebuild SoundStream with the same hyper-parameters used for its training cell
# and load the saved weights into it.
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load(soundstream_weight_path)

# 3 coarse + 5 fine quantizers add up to the 8 quantizers SoundStream is
# configured with above.
fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    soundstream = soundstream,
    folder = data_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1500
)

trainer.train()

In [None]:
#@title inference
from audiolm_pytorch import AudioLM

# Combine the models built in the cells above into a single AudioLM model.
audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

# Unconditional generation.
generated_wav = audiolm(batch_size = 1)

# or with priming (random noise is used here as a placeholder prime wave)

generated_wav_with_prime = audiolm(prime_wave = torch.randn(1, 320 * 8))

# or with text condition, if given
# None of the transformers in this notebook is constructed with text
# conditioning, so only attempt the text-conditioned call when the semantic
# transformer reports that it supports it.
# NOTE(review): relies on the transformer exposing a `has_condition` flag —
# confirm against the installed audiolm-pytorch version.
if getattr(semantic_transformer, 'has_condition', False):
    generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the distant echos of bells'])
else:
    generated_wav_with_text_condition = None