In [64]:
import os
from functools import partial, wraps
from pathlib import Path

import torch
import torchaudio
from beartype.door import is_bearable
from beartype.typing import Tuple
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchaudio.functional import resample

from audiolm_pytorch.data import exists, cast_tuple, collate_one_or_multiple_tensors, curtail_to_multiple, curtail_to_shortest_collate
from audiolm_pytorch.utils import curtail_to_multiple

class MusicDataset(Dataset):
    """Audio files found recursively under ``folder``, served as mono float waveforms.

    Each item is loaded with ``torchaudio``, downmixed to mono, resampled to every
    rate in ``target_sample_hz``, randomly cropped or zero-padded on the right to
    ``max_length`` samples, optionally curtailed to a multiple of
    ``seq_len_multiple_of``, and cast to float.

    Args:
        folder: root directory searched recursively (via ``os.walk``) for audio files.
        exts: accepted file-name endings, matched with ``str.endswith``.
        max_length: samples per item after crop/pad; ``None`` keeps the whole clip.
        target_sample_hz: a sample rate or tuple of sample rates; ``None`` keeps the
            file's native rate. With more than one rate, ``__getitem__`` returns a
            tuple with one waveform per rate; otherwise a single tensor.
        seq_len_multiple_of: if given, each waveform is curtailed to a multiple of it.
    """

    def __init__(
        self,
        folder,
        exts = ['mp3', 'wav'],
        max_length = None,
        target_sample_hz = None,
        seq_len_multiple_of = None
    ):
        super().__init__()
        path = Path(folder)
        assert path.exists(), 'folder does not exist'

        ext_suffixes = tuple(exts)
        files = [
            os.path.join(root, name)
            for root, _, names in os.walk(path)
            for name in names
            if name.endswith(ext_suffixes)
        ]
        assert len(files) > 0, 'no sound files found'

        # sorted so the file order (and therefore a seeded random_split) does not
        # depend on the filesystem's os.walk ordering
        self.files = sorted(files)
        self.max_length = max_length

        # always a tuple; (None,) means "keep the native sample rate"
        self.target_sample_hz = cast_tuple(target_sample_hz)
        self.seq_len_multiple_of = seq_len_multiple_of

    def __len__(self):
        return len(self.files)

    def _load_first_readable(self, idx):
        """Load ``self.files[idx]``, falling through to the following files if it cannot be read.

        Returns:
            ``(waveform, sample_hz)`` as returned by ``torchaudio.load``.

        Raises:
            RuntimeError: if no file in the dataset can be loaded.
        """
        num_files = len(self.files)
        last_error = None

        for offset in range(num_files):
            file = self.files[(idx + offset) % num_files]
            try:
                return torchaudio.load(file)
            except Exception as err:
                # some corpora (e.g. FMA) contain unreadable files; skip to the next one
                # instead of returning None, which the DataLoader collate cannot handle
                last_error = err

        raise RuntimeError(f'no loadable audio file found starting from index {idx}') from last_error

    def _crop_or_pad(self, data):
        """Randomly crop, or right-pad with zeros, a ``(1, samples)`` waveform to ``self.max_length``."""
        if not exists(self.max_length):
            return data

        num_samples = data.size(1)

        if num_samples > self.max_length:
            max_start = num_samples - self.max_length
            # randint's upper bound is exclusive; +1 makes the final window reachable
            start = int(torch.randint(0, max_start + 1, (1, )).item())
            return data[:, start:start + self.max_length]

        return torch.nn.functional.pad(data, (0, self.max_length - num_samples), 'constant')

    def __getitem__(self, idx):
        data, sample_hz = self._load_first_readable(idx)

        # downmix multi-channel audio to mono while keeping a leading channel dim of 1
        if data.shape[0] != 1:
            data = data.mean(dim = 0, keepdim = True)

        outputs = []

        for target_hz in self.target_sample_hz:
            output = data

            if exists(target_hz) and target_hz != sample_hz:
                output = resample(output, orig_freq = sample_hz, new_freq = target_hz)

            output = self._crop_or_pad(output)
            output = rearrange(output, '1 ... -> ...')

            if exists(self.seq_len_multiple_of):
                output = curtail_to_multiple(output, self.seq_len_multiple_of)

            outputs.append(output.float())

        if len(outputs) == 1:
            return outputs[0]

        return tuple(outputs)
    

In [60]:
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree

from beartype.typing import Union, List, Optional, Tuple
from typing_extensions import Annotated

from beartype import beartype
from beartype.door import is_bearable
from beartype.vale import Is

import torch
import torchaudio
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split

from einops import rearrange

from audiolm_pytorch.optimizer import get_optimizer

from ema_pytorch import EMA

from audiolm_pytorch.soundstream import SoundStream

from audiolm_pytorch.audiolm_pytorch import (
    SemanticTransformer,
    SemanticTransformerWrapper,
    CoarseTransformer,
    CoarseTransformerWrapper,
    FineTransformer,
    FineTransformerWrapper,
    FairseqVQWav2Vec,
    HubertWithKmeans
)

from audiolm_pytorch.data import SoundDataset, get_dataloader

from accelerate import Accelerator

from audiolm_pytorch.trainer import (
    noop,
    cycle,
    yes_or_no,
    accum_log,
    has_duplicates,
    determine_types,
    DEFAULT_SAMPLE_RATE
)

class CustomSoundStreamTrainer(nn.Module):
    """SoundStream trainer that can train on a caller-supplied ``Dataset``.

    Each ``train_step`` accumulates ``grad_accum_every`` batches of generator
    (codec) loss and steps ``optim``. It then accumulates ``grad_accum_every``
    batches of discriminator losses and steps the STFT discriminator optimizer
    together with one optimizer per multiscale discriminator. The model,
    optimizers and dataloaders are prepared with ``accelerate``. An EMA copy of
    the model is updated on the main process.

    Args:
        soundstream: model to train. Attributes used: ``non_discr_parameters()``,
            ``stft_discriminator``, ``discriminators``, ``target_sample_hz``,
            ``seq_len_multiple_of``.
        num_train_steps: total number of ``train_step`` calls made by ``train``.
        batch_size: batch size of both the train and validation dataloaders.
        data_max_length: ``max_length`` of the ``SoundDataset`` built from
            ``folder``; ignored when ``dataset`` is given.
        folder: audio folder, used only when ``dataset`` is None.
        lr, wd: learning rate and weight decay for every optimizer.
        grad_accum_every: batches accumulated per generator / discriminator update.
        max_grad_norm: gradient-norm clip applied to all model parameters in the
            generator update (None disables clipping).
        discr_max_grad_norm: gradient-norm clip for the STFT discriminator
            (None disables clipping).
        save_results_every: steps between writing reconstructed validation samples.
        save_model_every: steps between writing model / EMA checkpoints.
        results_folder: output directory. If it is non-empty, the user is asked
            whether to clear it.
        valid_frac: fraction of the dataset held out for validation; 0 reuses
            the training set for validation.
        random_split_seed: seed of the train/validation split.
        ema_beta, ema_update_after_step, ema_update_every: forwarded to ``EMA``.
        apply_grad_penalty_every: steps between discriminator gradient penalties.
        accelerate_kwargs: keyword arguments for ``Accelerator``.
        dataset: optional pre-built dataset that replaces ``folder``.
    """

    def __init__(
        self,
        soundstream: SoundStream,
        *,
        num_train_steps,
        batch_size,
        data_max_length = None,
        folder='',
        lr = 3e-4,
        grad_accum_every = 4,
        wd = 0.,
        max_grad_norm = 0.5,
        discr_max_grad_norm = None,
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        ema_beta = 0.995,
        ema_update_after_step = 500,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
        accelerate_kwargs: dict = dict(),
        dataset:Dataset = None
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.soundstream = soundstream
        self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every)

        self.register_buffer('steps', torch.Tensor([0]))

        self.num_train_steps = num_train_steps
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        # optimizers: one for the codec, one for the STFT discriminator,
        # and one per multiscale discriminator (stored as attributes by name)

        self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd)

        for discr_optimizer_key, discr in self.multiscale_discriminator_iter():
            one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd)
            setattr(self, discr_optimizer_key, one_multiscale_discr_optimizer)

        self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd)

        # max grad norm

        self.max_grad_norm = max_grad_norm
        self.discr_max_grad_norm = discr_max_grad_norm

        # create dataset

        if dataset is not None:
            self.ds = dataset
        else:
            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
                target_sample_hz = soundstream.target_sample_hz,
                seq_len_multiple_of = soundstream.seq_len_multiple_of
            )

        # split for validation

        if valid_frac > 0:
            train_size = int((1 - valid_frac) * len(self.ds))
            valid_size = len(self.ds) - train_size
            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
            self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
        else:
            self.valid_ds = self.ds
            self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

        # dataloader

        self.dl = get_dataloader(self.ds, batch_size = batch_size, shuffle = True)

        self.valid_dl = get_dataloader(self.valid_ds, batch_size = batch_size, shuffle = True)

        # prepare with accelerator

        (
            self.soundstream,
            self.optim,
            self.discr_optim,
            self.dl,
            self.valid_dl
        ) = self.accelerator.prepare(
            self.soundstream,
            self.optim,
            self.discr_optim,
            self.dl,
            self.valid_dl
        )

        # prepare the multiscale discriminators with accelerator

        for name, _ in self.multiscale_discriminator_iter():
            optimizer = getattr(self, name)
            optimizer = self.accelerator.prepare(optimizer)
            setattr(self, name, optimizer)

        # dataloader iterators

        self.dl_iter = cycle(self.dl)
        self.valid_dl_iter = cycle(self.valid_dl)

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every

        self.apply_grad_penalty_every = apply_grad_penalty_every

        self.results_folder = Path(results_folder)

        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)

    def save(self, path):
        """Save model, EMA model and all optimizer states to ``path``."""
        pkg = dict(
            model = self.accelerator.get_state_dict(self.soundstream),
            ema_model = self.ema_soundstream.state_dict(),
            optim = self.optim.state_dict(),
            discr_optim = self.discr_optim.state_dict()
        )

        for key, _ in self.multiscale_discriminator_iter():
            discr_optim = getattr(self, key)
            pkg[key] = discr_optim.state_dict()

        torch.save(pkg, path)

    def load(self, path):
        """Restore the state written by ``save``.

        Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        path = Path(path)
        assert path.exists()
        # load onto CPU first; load_state_dict copies values onto each parameter's device
        pkg = torch.load(str(path), map_location = 'cpu')

        self.unwrapped_soundstream.load_state_dict(pkg['model'])

        self.ema_soundstream.load_state_dict(pkg['ema_model'])
        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

        for key, _ in self.multiscale_discriminator_iter():
            discr_optim = getattr(self, key)
            discr_optim.load_state_dict(pkg[key])

    @property
    def unwrapped_soundstream(self):
        """The SoundStream model with any accelerate / distributed wrapper removed."""
        return self.accelerator.unwrap_model(self.soundstream)

    def multiscale_discriminator_iter(self):
        """Yield ``(optimizer_attribute_name, discriminator)`` for each multiscale discriminator."""
        # go through the unwrapped model: a DDP wrapper does not expose `.discriminators`
        for ind, discr in enumerate(self.unwrapped_soundstream.discriminators):
            yield f'multiscale_discr_optimizer_{ind}', discr

    def print(self, msg):
        self.accelerator.print(msg)

    @property
    def device(self):
        return self.accelerator.device

    @property
    def is_distributed(self):
        # imported locally: the notebook-level imports do not bring DistributedType into scope
        from accelerate import DistributedType
        return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

    @property
    def is_main(self):
        return self.accelerator.is_main_process

    @property
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def _generator_step(self, logs):
        """Accumulate codec (generator) gradients over ``grad_accum_every`` batches and step ``optim``."""
        # drop any gradients left on generator parameters by an earlier discriminator pass
        self.optim.zero_grad()

        for _ in range(self.grad_accum_every):
            wave, = next(self.dl_iter)
            wave = wave.to(self.device)

            loss, (recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True)

            self.accelerator.backward(loss / self.grad_accum_every)

            # log plain floats so logs do not keep device tensors alive
            accum_log(logs, dict(
                loss = loss.item() / self.grad_accum_every,
                recon_loss = float(recon_loss) / self.grad_accum_every
            ))

        if self.max_grad_norm is not None:
            self.accelerator.clip_grad_norm_(self.soundstream.parameters(), self.max_grad_norm)

        self.optim.step()
        self.optim.zero_grad()

    def _discriminator_step(self, logs, apply_grad_penalty):
        """Accumulate every discriminator loss over ``grad_accum_every`` batches and step all discriminator optimizers."""
        discr_optimizers = [self.discr_optim, *(getattr(self, name) for name, _ in self.multiscale_discriminator_iter())]

        # the generator backward may also have deposited gradients on discriminator
        # parameters; clear them so only discriminator losses drive this update
        for optimizer in discr_optimizers:
            optimizer.zero_grad()

        for _ in range(self.grad_accum_every):
            wave, = next(self.dl_iter)
            wave = wave.to(self.device)

            discr_losses = self.soundstream(
                wave,
                apply_grad_penalty = apply_grad_penalty,
                return_discr_loss = True,
                return_discr_losses_separately = True
            )

            for name, discr_loss in discr_losses:
                # the separate losses are backpropagated one by one through one forward pass
                self.accelerator.backward(discr_loss / self.grad_accum_every, retain_graph = True)
                accum_log(logs, {name: discr_loss.item() / self.grad_accum_every})

        if self.discr_max_grad_norm is not None:
            self.accelerator.clip_grad_norm_(self.unwrapped_soundstream.stft_discriminator.parameters(), self.discr_max_grad_norm)

        for optimizer in discr_optimizers:
            optimizer.step()
            optimizer.zero_grad()

    @staticmethod
    def _format_losses(steps, logs):
        """Build the one-line progress message from the accumulated ``logs``."""
        losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}"

        for key, loss in logs.items():
            if not key.startswith('scale:'):
                continue
            _, scale_factor = key.split(':')

            losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}"

        return losses_str

    @torch.no_grad()
    def _save_samples(self, steps):
        """Reconstruct one validation batch with the EMA and online models and save every item as FLAC."""
        # save at the model's own rate so the written audio plays at the correct speed
        sample_hz = getattr(self.unwrapped_soundstream, 'target_sample_hz', DEFAULT_SAMPLE_RATE)

        for model, label in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.soundstream, str(steps))):
            model.eval()

            wave, = next(self.valid_dl_iter)
            wave = wave.to(self.device)

            recons = model(wave, return_recons_only = True)

            for ind, recon in enumerate(recons.unbind(dim = 0)):
                # model label + batch index keep files from overwriting each other
                filename = str(self.results_folder / f'sample_{label}_{ind}.flac')
                torchaudio.save(filename, recon.cpu().detach(), sample_hz)

        self.print(f'{steps}: saving to {str(self.results_folder)}')

    def _save_checkpoints(self, steps):
        """Write the online and EMA model state dicts for ``steps`` into ``results_folder``."""
        # unwrapped so the keys carry no distributed-wrapper prefix, matching `save`
        state_dict = self.unwrapped_soundstream.state_dict()
        model_path = str(self.results_folder / f'soundstream.{steps}.pt')
        torch.save(state_dict, model_path)

        ema_state_dict = self.ema_soundstream.state_dict()
        model_path = str(self.results_folder / f'soundstream.{steps}.ema.pt')
        torch.save(ema_state_dict, model_path)

        self.print(f'{steps}: saving model to {str(self.results_folder)}')

    def train_step(self):
        """Run one generator update and one discriminator update, then sample / checkpoint on schedule.

        Returns:
            dict mapping loss names to values averaged over gradient accumulation.
        """
        steps = int(self.steps.item())
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)

        self.soundstream.train()

        logs = {}

        self._generator_step(logs)
        self._discriminator_step(logs, apply_grad_penalty)

        self.print(self._format_losses(steps, logs))

        # update exponential moving averaged generator

        if self.is_main:
            self.ema_soundstream.update()

        if self.is_main and not (steps % self.save_results_every):
            self._save_samples(steps)

        if self.is_main and not (steps % self.save_model_every):
            self._save_checkpoints(steps)

        self.steps += 1
        return logs

    def train(self, log_fn = noop):
        """Call ``train_step`` until ``num_train_steps`` is reached, passing each step's logs to ``log_fn``."""

        while self.steps < self.num_train_steps:
            logs = self.train_step()
            log_fn(logs)

        self.print('training complete')


In [None]:
from audiolm_pytorch import SoundStream, SoundStreamTrainer
from audiolm_pytorch.data import SoundDataset

import os

# Root of the FMA "small" split; override with the FMA_SMALL_DIR environment variable
FMA_SMALL_DIR = os.environ.get('FMA_SMALL_DIR', '/media/philip/ferous/FMA/fma_small/')

# Samples per training clip.
# NOTE(review): 320 presumably matches SoundStream's total downsampling factor; confirm.
DATA_MAX_LENGTH = 320 * 32

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

dataset = MusicDataset(
    FMA_SMALL_DIR,
    max_length = DATA_MAX_LENGTH,
    target_sample_hz = soundstream.target_sample_hz,
    seq_len_multiple_of = soundstream.seq_len_multiple_of
)

trainer = CustomSoundStreamTrainer(
    soundstream,
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32 per process
    data_max_length = DATA_MAX_LENGTH,
    num_train_steps = 1000,
    dataset = dataset
)

# Move the whole trainer to the accelerator's device, including the EMA copy
# and the step buffer; unlike .cuda(), this also works on CPU-only machines.
trainer = trainer.to(trainer.device)

trainer.train()

training with dataset of 7597 samples and validating with randomly splitted 400 samples


do you want to clear previous experiment checkpoints and results? (y/n)  y


0: soundstream total loss: 24.334, soundstream recon loss: 0.042 | discr (scale 1) loss: 2.000 | discr (scale 0.5) loss: 2.000 | discr (scale 0.25) loss: 2.000
0: saving to results
0: saving model to results
1: soundstream total loss: 26.065, soundstream recon loss: 0.045 | discr (scale 1) loss: 1.995 | discr (scale 0.5) loss: 1.995 | discr (scale 0.25) loss: 1.996
2: soundstream total loss: 31.013, soundstream recon loss: 0.052 | discr (scale 1) loss: 1.984 | discr (scale 0.5) loss: 1.984 | discr (scale 0.25) loss: 1.989
3: soundstream total loss: 30.583, soundstream recon loss: 0.049 | discr (scale 1) loss: 1.971 | discr (scale 0.5) loss: 1.985 | discr (scale 0.25) loss: 1.983
4: soundstream total loss: 24.303, soundstream recon loss: 0.033 | discr (scale 1) loss: 1.959 | discr (scale 0.5) loss: 2.000 | discr (scale 0.25) loss: 1.981
5: soundstream total loss: 21.495, soundstream recon loss: 0.034 | discr (scale 1) loss: 1.950 | discr (scale 0.5) loss: 2.011 | discr (scale 0.25) loss