In [1]:
import random
import os
import numpy as np
import torch

# Relax float32 matmul precision in exchange for speed
# (see the torch.set_float32_matmul_precision docs for what 'medium' permits).
torch.set_float32_matmul_precision('medium')

# Single seed shared by the RNG seeding below and by the dataset splits.
SEED: int = 654

def seed_everything(seed: int) -> None:
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and every CUDA
    device) and configures cuDNN for deterministic execution.

    Args:
        seed: Value used to seed all generators.
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already running
    # interpreter is not changed by setting this variable.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may pick a different algorithm from run to run,
    # so it has to be off for the deterministic flag above to be meaningful.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before any dataset, split or model is created.
seed_everything(SEED)

from dataclasses import dataclass

@dataclass
class Hyperparameters:
    """Tunable settings for the data pipeline, the model and the training run."""
    # Samples per batch in the train/validation dataloaders.
    BATCH_SIZE: int = 8
    # Side length (pixels) the spectrograms are resized to.
    IMAGE_SIZE: int = 256
    # Upper bound on training epochs passed to the trainer.
    EPOCHS: int = 100
    # NOTE(review): not referenced in the visible cells (the optimizer hard-codes 1e-3).
    LEARNING_RATE: float = 0.001
    # NOTE(review): not referenced in the visible cells.
    LEARNING_RATE_DECAY: int = 0
    # NOTE(review): not referenced in the visible cells (the split call hard-codes 0.7).
    TRAIN_SIZE: float = 0.7
    # Channel count of the first convolution of the encoder.
    BASE_CHANNEL_SIZE: int = 16
    # Size of the latent vector produced by the encoder.
    LATENT_DIM: int = 128
    # NOTE(review): not referenced in the visible cells (the model defaults to 2 itself).
    NUM_INPUT_CHANNELS: int = 2


# Shared configuration instance read by the cells below.
cfg: Hyperparameters = Hyperparameters()

In [2]:
class StereoToMono(torch.nn.Module):
    """Collapse a multi-channel waveform into a single channel.

    All size-one dimensions are squeezed away first; if more than one
    dimension remains, the first one is reduced with the configured
    reduction.

    Args:
        reduction: ``"avg"`` averages over the first dimension, ``"sum"``
            adds it up.
    """
    def __init__(self, reduction: str = "avg", *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(reduction, str)
        assert reduction in ["avg", "sum"]
        self.reduction = reduction

    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """Return ``sample`` squeezed and reduced over its first dimension.

        Args:
            sample: Waveform tensor — presumably (channels, frames); confirm
                against the caller.

        Returns:
            The squeezed tensor unchanged when it has at most one dimension,
            otherwise the mean/sum over dimension 0.
        """
        sample = sample.squeeze()
        # Check the rank before touching shape[0]: a tensor that squeezes to
        # 0-D (e.g. shape (1, 1)) has no first dimension to index.
        if sample.dim() <= 1:
            return sample
        return sample.mean(dim=0) if self.reduction == "avg" else sample.sum(dim=0)

class AudioCrop(torch.nn.Module):
    """Randomly crop a waveform to a fixed duration.

    Args:
        sample_rate: Frames per second of the waveform.
        crop_size: Length of the crop in seconds.
    """
    def __init__(self, sample_rate: int, crop_size: int = 60) -> None:
        super().__init__()
        self.crop_size: int = crop_size # in seconds
        self.sample_rate: int = sample_rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a random window of ``crop_size * sample_rate`` frames.

        Args:
            x: Tensor cropped along dimension 1 — presumably
                (channels, frames); confirm against the caller.

        Returns:
            ``x`` itself when it is not longer than the crop, otherwise a
            slice of exactly ``crop_size * sample_rate`` frames.
        """
        crop_frames = self.crop_size * self.sample_rate
        num_frames = x.shape[1]
        if num_frames <= crop_frames:
            return x

        # `high` is exclusive, so add 1 to make the last valid start offset
        # (num_frames - crop_frames) reachable.
        start_frame: int = int(torch.randint(
            low=0,
            high=num_frames - crop_frames + 1,
            size=(1,)
        ).item())
        return x[:, start_frame:start_frame + crop_frames]

class AddGaussianNoise(torch.nn.Module):
    """Add Gaussian noise to a tensor with probability ``p``.

    Args:
        mean: Offset added to the noise.
        std: Scale applied to the standard-normal noise.
        p: Probability of applying the noise; must lie in [0, 1].
    """
    def __init__(self, mean: float = 0., std: float = 1., p: float = 0.5) -> None:
        super().__init__()
        assert 0 <= p <= 1
        self.std: float = std
        self.mean: float = mean
        self.p: float = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus noise, or ``x`` itself when the draw skips it."""
        if random.random() >= self.p:
            return x
        # Create the noise on the input's device; a CPU-only noise tensor
        # cannot be added to an input living on another device.
        noise = torch.randn(x.size(), device=x.device)
        return x + noise * self.std + self.mean

class Squeeze(torch.nn.Module):
    """Module form of ``tensor.squeeze(0)``: drops dimension 0 when its size is one."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A leading dimension of any other size is left untouched.
        return torch.squeeze(x, dim=0)

In [3]:
from safetensors import safe_open
from copy import deepcopy
from torch_audiomentations import Compose, OneOf, SomeOf, Gain, HighPassFilter, LowPassFilter, PeakNormalization, PitchShift
from torchvision.transforms import v2
import torchaudio.transforms as T
import torchaudio
import random

class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, data_path: np.ndarray | list[str], image_size: int, sample_rate: int = 44100, crop_size: int = 60, mode: str = "train") -> None:
        assert mode in ["train", "valid", "test"]
        super().__init__()
        self.data_path: np.ndarray | list[str] = data_path
        self.image_size: int = image_size
        self.sample_rate: int = sample_rate
        self.crop_size: int = crop_size
        self.mode: str = mode
        self._init_transforms()
        
    def _init_transforms(self) -> None:
        self.y_transforms = Compose([
            T.MelSpectrogram(
                sample_rate=self.sample_rate,
                n_fft=512,
                win_length=512,
                hop_length=256,
                n_mels=256
            ),
            v2.Resize(size=(self.image_size, self.image_size)),
            v2.ToDtype(torch.float16, scale=True)
        ])
        
        if self.mode == "train":
            self.x_transforms = Compose([
                AddGaussianNoise(p=0.5),
                T.MelSpectrogram(
                    sample_rate=self.sample_rate,
                    n_fft=512,
                    win_length=512,
                    hop_length=256,
                    n_mels=256
                ),
                OneOf([
                    T.TimeMasking(time_mask_param=100),
                    T.FrequencyMasking(freq_mask_param=100)
                ]),
                v2.Resize(size=(self.image_size, self.image_size)),
                v2.ToDtype(torch.float16, scale=True)
            ])
        else:
            self.x_transforms = deepcopy(self.y_transforms)
    
    def __len__(self) -> int:
        return len(self.data_path)
    
    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        # print(self.data_path[index])
        with safe_open(self.data_path[index], framework="pt", device="cpu") as f:
            sample_rate = f.get_tensor("sample_rate")
            audio = f.get_tensor("audio")

        num_frames: int = audio.shape[1]
        crop_frames: int = self.crop_size * sample_rate
        # original = T.MelSpectrogram(
        #     sample_rate=self.sample_rate,
        #     n_fft=512,
        #     win_length=512,
        #     hop_length=256,
        #     n_mels=256
        # )(audio)
        
        frame_offset = -1
        if num_frames > crop_frames:
            frame_offset: int = random.randint(0, num_frames-crop_frames)
            audio = audio[:, frame_offset:frame_offset+crop_frames]

        # original_cropped = T.MelSpectrogram(
        #     sample_rate=self.sample_rate,
        #     n_fft=512,
        #     win_length=512,
        #     hop_length=256,
        #     n_mels=256
        # )(audio)

        # print(f"sample_rate: {sample_rate} - num_frames: {num_frames} - frame_offset: {frame_offset} - crop_frames: {crop_frames}")
        
        return self.x_transforms(audio), self.y_transforms(audio)


  torchaudio.set_audio_backend("soundfile")


## Train / validation / test split

In [4]:
import pathlib
from sklearn.model_selection import train_test_split
from typing import Any
import pandas as pd

def get_splits(data: pd.DataFrame | np.ndarray | list[Any], train_size: float, valid_size: float, test_size: float, stratify_col: str | None = None) -> tuple[Any, Any, Any]:
    """Split ``data`` into train, validation and test subsets.

    The data is split twice with ``train_test_split`` (seeded with ``SEED``):
    first into train vs. the rest, then the rest into validation vs. test.

    Args:
        data: Collection to split; must support ``data[stratify_col]`` when
            ``stratify_col`` is given.
        train_size: Fraction of ``data`` assigned to the training split.
        valid_size: Fraction of ``data`` assigned to the validation split.
        test_size: Fraction of ``data`` assigned to the test split.
        stratify_col: Optional column used for stratified splitting.

    Returns:
        ``(train_split, valid_split, test_split)``.

    Raises:
        AssertionError: If the three fractions add up to more than 1.
    """
    tolerance = 1e-9
    total = train_size + valid_size + test_size
    # Tolerate float rounding: three fractions meant to sum to 1 may land a few ulps above it.
    assert total <= 1. + tolerance

    holdout = 1 - train_size
    valid_fraction = valid_size / holdout
    # When the fractions sum to 1, the test split is simply the complement
    # (test_size=None). When they leave a remainder, request test_size
    # explicitly so the test split does not silently absorb the unused share.
    test_fraction = test_size / holdout if total < 1. - tolerance else None

    if stratify_col:
        train_split, valid_test = train_test_split(data, train_size=train_size, stratify=data[stratify_col], random_state=SEED)
        valid_split, test_split = train_test_split(valid_test, train_size=valid_fraction, test_size=test_fraction, stratify=valid_test[stratify_col], random_state=SEED)
    else:
        train_split, valid_test = train_test_split(data, train_size=train_size, stratify=None, random_state=SEED)
        valid_split, test_split = train_test_split(valid_test, train_size=valid_fraction, test_size=test_fraction, stratify=None, random_state=SEED)

    return train_split, valid_split, test_split

# Sort the discovered files: rglob yields them in filesystem order, which can
# differ between machines and runs, and the seeded split below is only
# reproducible when its input order is stable.
songs_path: list[pathlib.Path] = sorted(pathlib.Path(os.getcwd()).parent.rglob("*.safetensors"))
train, valid, test = get_splits(songs_path, train_size=0.7, valid_size=0.2, test_size=0.1, stratify_col=None)

## Model

In [5]:
from torch import nn

class Encoder(nn.Module):
    def __init__(self, num_input_channels: int, base_channel_size: int, latent_dim: int, act_fn: type[nn.Module] = nn.Mish, image_size: int = 256):
        """Convolutional encoder mapping an image to a latent vector.

        Four stride-2 convolutions halve the spatial size each time; the
        result is flattened and projected to ``latent_dim`` by a linear layer.

        Args:
           num_input_channels : Number of channels of the input image
           base_channel_size : Number of channels of the first convolution. Deeper layers use twice as many.
           latent_dim : Dimensionality of latent representation z
           act_fn : Activation class instantiated after every convolution
           image_size : Side length of the (square) input image; sizes the linear layer
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, base_channel_size, kernel_size=3, padding=1, stride=2),  # S => S/2
            act_fn(),
            nn.Conv2d(base_channel_size, 2 * base_channel_size, kernel_size=3, padding=1, stride=2),  # S/2 => S/4
            act_fn(),
            nn.Conv2d(2 * base_channel_size, 2 * base_channel_size, kernel_size=3, padding=1, stride=2),  # S/4 => S/8
            act_fn(),
            nn.Conv2d(2 * base_channel_size, 2 * base_channel_size, kernel_size=3, padding=1, stride=2),  # S/8 => S/16
            act_fn(),
        )
        self.flatten = nn.Flatten()
        # Each convolution (kernel 3, padding 1, stride 2) maps a side length
        # s to ceil(s / 2). For the default 256 this gives 16, i.e. the
        # previously hard-coded 512 * base_channel_size input features.
        feature_size = image_size
        for _ in range(4):
            feature_size = (feature_size + 1) // 2
        self.fc = nn.Linear(2 * base_channel_size * feature_size * feature_size, latent_dim)

    def forward(self, x):
        """Encode a batch of images into latent vectors of size ``latent_dim``."""
        x = self.net(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x


class Decoder(nn.Module):
    def __init__(self, num_input_channels: int, base_channel_size: int, latent_dim: int, act_fn: type[nn.Module] = nn.Mish, image_size: int = 256):
        """Convolutional decoder mapping a latent vector back to an image.

        A linear layer expands the latent vector to a small feature map,
        which four stride-2 transposed convolutions upsample by a factor of
        16 in total.

        Args:
           num_input_channels : Number of channels of the image to reconstruct
           base_channel_size : Number of channels of the last hidden layer. Earlier layers use twice as many.
           latent_dim : Dimensionality of latent representation z
           act_fn : Activation class instantiated after every layer
           image_size : Side length of the (square) output image; must be a positive multiple of 16

        Raises:
           ValueError : If ``image_size`` is not a positive multiple of 16
        """
        super().__init__()
        if image_size <= 0 or image_size % 16 != 0:
            raise ValueError(f"image_size must be a positive multiple of 16, got {image_size}")
        # Side length of the feature map fed to the first transposed
        # convolution (16 for the default 256x256 output).
        self.feature_size: int = image_size // 16
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2 * base_channel_size * self.feature_size * self.feature_size),
            act_fn(),
        )
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2 * base_channel_size, 2 * base_channel_size, kernel_size=3, padding=1, stride=2, output_padding=1),  # F => 2F
            act_fn(),
            nn.ConvTranspose2d(2 * base_channel_size, 2 * base_channel_size, kernel_size=3, padding=1, stride=2, output_padding=1),  # 2F => 4F
            act_fn(),
            nn.ConvTranspose2d(2 * base_channel_size, base_channel_size, kernel_size=3, padding=1, stride=2, output_padding=1), # 4F => 8F
            act_fn(),
            nn.ConvTranspose2d(base_channel_size, num_input_channels, kernel_size=3, padding=1, stride=2, output_padding=1), # 8F => 16F
            # NOTE(review): this activation runs right before the Sigmoid. With
            # the default Mish (whose minimum is about -0.31) the output cannot
            # drop below roughly 0.42 — confirm this is intended.
            act_fn(),
            nn.Sigmoid(),  # Bounds the reconstruction to the (0, 1) range
        )

    def forward(self, x):
        """Decode a batch of latent vectors into images."""
        x = self.linear(x)
        # Unflatten to (batch, 2 * base_channel_size, F, F).
        x = x.reshape(x.shape[0], -1, self.feature_size, self.feature_size)
        x = self.net(x)
        return x

In [6]:
from torch import optim
import lightning as L

class Autoencoder(L.LightningModule):
    """Lightning module wrapping an encoder/decoder pair trained with MSE.

    Batches are ``(x, y)`` pairs: ``x`` is fed to the network and the
    reconstruction is compared with ``y``.

    Args:
        base_channel_size: Base channel count passed to encoder and decoder.
        latent_dim: Size of the latent representation.
        encoder_class: Class instantiated as the encoder.
        decoder_class: Class instantiated as the decoder.
        num_input_channels: Channel count of the input/output images.
        width: Stored as a hyperparameter only; not otherwise used here.
        height: Stored as a hyperparameter only; not otherwise used here.
    """
    def __init__(
            self,
            base_channel_size: int,
            latent_dim: int,
            encoder_class: type[Encoder] = Encoder,
            decoder_class: type[Decoder] = Decoder,
            num_input_channels: int = 2,
            width: int = cfg.IMAGE_SIZE,
            height: int = cfg.IMAGE_SIZE,
    ):
        super().__init__()
        # Saving hyperparameters of autoencoder
        self.save_hyperparameters()
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        self.loss = nn.MSELoss(reduction="mean")

    def forward(self, x):
        """The forward function takes in an image and returns the reconstructed image."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

    def _get_reconstruction_loss(self, batch):
        """Return the MSE between the reconstruction of ``x`` and the target ``y``."""
        # The second element is not a class label: the dataset yields the
        # clean spectrogram there, while x may be noised/masked. Reconstruct
        # towards y so the network is not trained to reproduce the corruption.
        x, y = batch
        x_hat = self.forward(x)
        return self.loss(x_hat, y)

    def configure_optimizers(self):
        """Adam plus a plateau scheduler driven by the logged ``valid_loss``."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, patience=5, min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "valid_loss"}

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the loss for one batch."""
        loss = self._get_reconstruction_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the loss for one validation batch as ``valid_loss``."""
        loss = self._get_reconstruction_loss(batch)
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        """Compute and log the loss for one test batch as ``test_loss``."""
        loss = self._get_reconstruction_loss(batch)
        self.log("test_loss", loss)

In [7]:
from lightning.pytorch.callbacks import TQDMProgressBar
import sys

class MyProgressBar(TQDMProgressBar):
    """TQDM progress bar whose validation, predict and test bars are disabled
    when stdout is not a terminal."""

    @staticmethod
    def _mute_unless_tty(bar):
        # Shared by all three overrides: switch the freshly created bar off
        # whenever stdout is not attached to a terminal.
        if not sys.stdout.isatty():
            bar.disable = True
        return bar

    def init_validation_tqdm(self):
        return self._mute_unless_tty(super().init_validation_tqdm())

    def init_predict_tqdm(self):
        return self._mute_unless_tty(super().init_predict_tqdm())

    def init_test_tqdm(self):
        return self._mute_unless_tty(super().init_test_tqdm())

In [8]:
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger
from torch.utils.data import DataLoader

# Stop training once `valid_loss` has not improved for 3 validation runs.
early_stop_callback = EarlyStopping(monitor="valid_loss", min_delta=0.00, patience=3, verbose=True, mode="min")
# The filename template must name the metric exactly as it is logged
# ("valid_loss"); the previous "{val_loss}" placeholder matched no logged metric.
# NOTE(review): dirpath is an absolute, machine-specific path — consider making it configurable.
checkpoint_callback = ModelCheckpoint(dirpath='/home/paolo/git/spotify-playlist-generator/models', filename='{epoch}-{valid_loss:.5f}', verbose=True, monitor="valid_loss", mode="min")


# Settings shared by the training and validation loaders.
_common_loader_options = dict(
    batch_size=cfg.BATCH_SIZE,
    num_workers=1,
    pin_memory=True,
    persistent_workers=True,
)

# Training batches are reshuffled every epoch and use the augmenting dataset mode.
train_dataloader = DataLoader(
    dataset=AudioDataset(data_path=train, image_size=cfg.IMAGE_SIZE, mode="train"),
    shuffle=True,
    **_common_loader_options,
)

# Validation keeps the file order and uses the non-augmenting dataset mode.
valid_dataloader = DataLoader(
    dataset=AudioDataset(data_path=valid, image_size=cfg.IMAGE_SIZE, mode="valid"),
    shuffle=False,
    **_common_loader_options,
)

In [9]:
# import matplotlib.pyplot as plt
# import librosa
# 
# def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
#     if ax is None:
#         _, ax = plt.subplots(1, 1)
#     if title is not None:
#         ax.set_title(title)
#     ax.set_ylabel(ylabel)
#     ax.imshow(librosa.power_to_db(specgram[0]), origin="lower", aspect="auto", interpolation="nearest")
# 
# idx = 0
# for x, y, orig, orig_crop in valid_dataloader:
#     x, y, orig, orig_crop = x[idx].numpy(), y[idx].numpy(), orig[idx].numpy(), orig_crop[idx].numpy()
#     print(f"x shape: {x.shape} - y shape: {y.shape} - orig shape: {orig.shape} - orig_crop shape: {orig_crop.shape}")
#     
#     print(f"x min: {x.min()} - x max: {x.max()} - x mean: {x.mean()}")
#     print(f"y min: {y.min()} - y max: {y.max()} - y mean: {y.mean()}")
#     print(f"orig min: {orig.min()} - orig max: {orig.max()} - orig mean: {orig.mean()}")
#     print(f"orig_crop min: {orig_crop.min()} - orig_crop max: {orig_crop.max()} - orig_crop mean: {orig_crop.mean()}")
# 
#     plot_spectrogram(x, "x")
#     plot_spectrogram(y, "y")
#     plot_spectrogram(orig, "orig")
#     plot_spectrogram(orig_crop, "orig cropped")
#     break

In [None]:
from lightning.pytorch.callbacks import RichProgressBar

# Debugging run: overfit a single batch with anomaly detection enabled,
# without logger or checkpoints.
model = Autoencoder(
    base_channel_size=cfg.BASE_CHANNEL_SIZE,
    latent_dim=cfg.LATENT_DIM,
    encoder_class=Encoder,
    decoder_class=Decoder,
)


trainer: L.Trainer = L.Trainer(
    accelerator = "gpu",
    num_nodes = 1,
    precision = 16,
    logger = None,
    callbacks = [RichProgressBar()],
    fast_dev_run = False,
    max_epochs = cfg.EPOCHS,
    min_epochs = 1,
    overfit_batches = 1,  # debugging: train on one batch only
    log_every_n_steps=100,
    check_val_every_n_epoch = 1,
    enable_checkpointing = False,
    enable_progress_bar = True,
    enable_model_summary = True,
    deterministic = "warn",
    # Kept off so it agrees with `deterministic` and the global seeding:
    # the cuDNN auto-tuner may select different algorithms between runs.
    benchmark = False,
    inference_mode = True,
    profiler = None,
    detect_anomaly = True,  # debugging: slows training down considerably
    barebones = False,
)


trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=valid_dataloader,
    ckpt_path=None
)

Using 16bit Automatic Mixed Precision (AMP)
You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and is recommended only for model debugging.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(overfit_batches=1)` was configured so 1 batch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()

In [None]:


import mlflow
# Disabled experiment: the `if False:` guard means nothing below ever runs.
# It configures a training run tracked with MLflow that uses the early-stopping
# and checkpoint callbacks together with MyProgressBar.
if False:
    # Enable MLflow's PyTorch autologging.
    mlflow.pytorch.autolog(
        log_every_n_epoch=1,
        log_every_n_step=None,
        log_models=True,
        log_datasets=False,
        disable=False,
        exclusive=False,
        disable_for_unsupported_versions=False,
        silent=False,
        registered_model_name="model",
        extra_tags=None
    )
    
    model = Autoencoder(
        base_channel_size=cfg.BASE_CHANNEL_SIZE,
        latent_dim=cfg.LATENT_DIM,
        encoder_class=Encoder,
        decoder_class=Decoder,
    )
    
    
    # NOTE(review): overfit_batches=1 and detect_anomaly=True are still set here — confirm before enabling.
    trainer: L.Trainer = L.Trainer(
        accelerator = "gpu",
        num_nodes = 1,
        precision = 16,
        logger = MLFlowLogger(experiment_name="lightning_experiment"),
        callbacks = [early_stop_callback, checkpoint_callback, MyProgressBar()],
        fast_dev_run = False,
        max_epochs = cfg.EPOCHS,
        min_epochs = 1,
        overfit_batches = 1,
        log_every_n_steps=50,
        check_val_every_n_epoch = 1,
        enable_checkpointing = True,
        enable_progress_bar = True,
        enable_model_summary = True,
        deterministic = "warn",
        benchmark = True,
        inference_mode = True,
        profiler = None,
        detect_anomaly = True,
        barebones = False,
    )
    
    # Run the fit inside an explicit MLflow run.
    with mlflow.start_run():
        trainer.fit(
            model=model,
            train_dataloaders=train_dataloader,
            val_dataloaders=valid_dataloader,
            ckpt_path=None
        )