# Train the RNN style metric model (NSynth)

The style encoder is trained contrastively, using a bilinear similarity score.

## Setup

In [1]:
from ss_vq_vae.models.vqvae_oneshot import Model
import confugue

2024-07-12 19:09:54.751485: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Load the experiment configuration (a YAML file read through confugue).
# Later cells read its 'model' and 'spectrogram' sections.
cfg_path = "/mnt/vdb/model-original-no-style-pretraining-19-11-2023/config.yaml"
cfg = confugue.Configuration.from_yaml_file(cfg_path)

In [3]:
from ss_vq_vae.nn.nn import ResidualWrapper
from ss_vq_vae.nn.bilinear_similarity import BilinearSimilarity
from torch import nn

class StyleEncoder(nn.Module):
    """Encode a padded feature sequence into a fixed-size style embedding.

    The layers come from the ``model`` section of the configuration:
    ``style_encoder_1d`` is applied first, then either the optional
    ``style_encoder_rnn`` (a GRU) or a Gram-matrix pooling, and finally
    ``style_encoder_0d``.
    """

    def __init__(self, cfg):
        """Build the sub-modules from the ``cfg['model']`` configuration section."""
        super().__init__()
        self.style_encoder_1d = nn.Sequential(*cfg['model']['style_encoder_1d'].configure_list())
        self.style_encoder_rnn = cfg['model']['style_encoder_rnn'].maybe_configure(nn.GRU, batch_first=True)
        self.style_encoder_0d = nn.Sequential(*cfg['model']['style_encoder_0d'].configure_list())

    def forward(self, input, length):
        """Compute the style embedding of each item in the batch.

        Args:
            input: Tensor indexed as (batch, channels, time).
            length: 1-D tensor with the number of valid (non-padding) time
                steps of each item, measured at the resolution of ``input``.

        Returns:
            A tuple ``(encoded, {})`` where ``encoded`` holds one flattened
            embedding per batch item.
        """
        encoded = self.style_encoder_1d(input)

        # Mask positions corresponding to padding. The lengths are first
        # rescaled to the time resolution of `encoded`.
        length = (length // (input.shape[2] / encoded.shape[2])).to(torch.int)
        mask = (torch.arange(encoded.shape[2], device=encoded.device) < length[:, None])[:, None, :]
        encoded = encoded * mask

        if self.style_encoder_rnn is not None:
            encoded = encoded.transpose(1, 2)
            # Packing requires lengths >= 1 that live on the CPU.
            encoded = nn.utils.rnn.pack_padded_sequence(
                encoded, length.clamp(min=1).to('cpu'),
                batch_first=True, enforce_sorted=False)
            _, encoded = self.style_encoder_rnn(encoded)
            # Get rid of layer dimension
            encoded = encoded.transpose(0, 1).reshape(input.shape[0], -1)
        else:
            # Compute the Gram matrix, normalized by the length squared.
            # eps belongs in the denominator: it keeps zero-length items
            # finite (0 / eps == 0) instead of producing 0 / 0 == NaN, and it
            # is no longer added to the masked-out positions.
            valid_steps = mask.sum(dim=2, keepdim=True)
            encoded = encoded / (valid_steps + torch.finfo(encoded.dtype).eps)
            encoded = torch.matmul(encoded, encoded.transpose(1, 2))
        encoded = encoded.reshape(encoded.shape[0], -1)

        encoded = self.style_encoder_0d(encoded)

        return encoded, {}


In [17]:
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
import librosa
import numpy as np

class NSynthDataset(Dataset):
    """Dataset of log-magnitude spectrograms computed from NSynth audio files.

    ``base_folder`` must contain a JSON metadata file (``json_file``) mapping
    each sample key to a dict with at least an ``'instrument'`` entry, and an
    ``audio`` sub-folder with one ``<key>.wav`` file per sample.
    """

    def __init__(self, cfg, base_folder, json_file='examples.json', sampling_rate=16_000):
        """Load the metadata and prepare the spectrogram function.

        Args:
            cfg: Configuration whose ``'spectrogram'`` section supplies the
                arguments bound to ``librosa.stft``.
            base_folder: Root folder of the dataset split.
            json_file: Name of the metadata file inside ``base_folder``.
            sampling_rate: Rate the audio is resampled to when loaded.
        """
        self.base_folder = base_folder
        self.json_file = json_file
        self.data = self._load_metadata()
        # Index -> key lookup built once, so __getitem__ does not have to
        # materialise the full key list on every access.
        self._keys = list(self.data)
        self.instrument_to_samples = self._group_by_instrument()
        self.sr = sampling_rate
        self.spec_fn = cfg['spectrogram'].bind(librosa.stft)

    def _load_metadata(self):
        """Return the parsed content of the JSON metadata file."""
        with open(os.path.join(self.base_folder, self.json_file), 'r') as f:
            return json.load(f)

    def preprocess_audio(self, audio_path):
        """Load an audio file and return ``log1p(|spec_fn(audio)|)``."""
        audio, _ = librosa.load(audio_path, sr=self.sr)
        if len(audio) == 0:
            # Replace an empty signal by a single zero sample before the STFT.
            audio = np.zeros(shape=[1], dtype=audio.dtype)
        return np.log1p(np.abs(self.spec_fn(y=audio)))

    def _group_by_instrument(self):
        """Return a dict mapping each instrument to the list of its sample keys."""
        instrument_to_samples = {}
        for key, value in self.data.items():
            instrument_to_samples.setdefault(value['instrument'], []).append(key)
        return instrument_to_samples

    def __len__(self):
        """Return the number of samples listed in the metadata."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the log-magnitude spectrogram of the ``idx``-th sample."""
        key = self._keys[idx]
        audio_path = os.path.join(self.base_folder, 'audio', f"{key}.wav")
        return self.preprocess_audio(audio_path)

In [32]:
from torch.utils.data import Sampler
import random

class PairSampler(Sampler):
    """Batch sampler yielding index lists made of same-instrument pairs.

    Every batch is a flat list ``[a0, b0, a1, b1, ...]`` in which each
    consecutive pair of dataset indices belongs to the same instrument.
    An instrument contributes at most one pair to a given batch.
    """

    def __init__(self, dataset, batch_size):
        """Store the dataset and group its indices by instrument.

        Args:
            dataset: Object with a ``data`` dict whose values contain an
                ``'instrument'`` entry, and supporting ``len()``.
            batch_size: Number of pairs per full batch.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.instrument_to_indices = self._group_indices_by_instrument()

    def _group_indices_by_instrument(self):
        """Return a dict mapping each instrument to its dataset indices."""
        instrument_to_indices = {}
        for idx, sample in enumerate(self.dataset.data.values()):
            instrument_to_indices.setdefault(sample['instrument'], []).append(idx)
        return instrument_to_indices

    def __iter__(self):
        """Yield ``len(self)`` batches of paired indices.

        Instruments with fewer than two samples can never form a pair and are
        ignored. If no instrument qualifies, nothing is yielded; previously
        this case looped forever.
        """
        total_batches = len(self)
        instruments = [
            instrument
            for instrument, indices in self.instrument_to_indices.items()
            if len(indices) >= 2
        ]
        if not instruments or total_batches == 0:
            return
        random.shuffle(instruments)

        no_of_batches = 0
        while True:
            batch = []
            for instrument in instruments:
                if no_of_batches == total_batches:
                    return
                indices = self.instrument_to_indices[instrument]
                # Shuffling in place, then taking the head, draws a random pair.
                random.shuffle(indices)
                batch.extend(indices[:2])
                if len(batch) == 2 * self.batch_size:
                    yield batch
                    no_of_batches += 1
                    batch = []
            # Flush the partial batch at the end of a pass so that a batch
            # never contains the same instrument twice.
            if batch:
                yield batch
                no_of_batches += 1

    def __len__(self):
        """Return the number of batches produced per iteration."""
        return len(self.dataset) // (2 * self.batch_size)

### Training loop

In [33]:
# Instantiate the style encoder from the loaded configuration.
style_encoder = StyleEncoder(cfg)
# Bare expression: shows the module structure as the cell output.
style_encoder

StyleEncoder(
  (style_encoder_1d): Sequential(
    (0): Conv1d(1025, 1024, kernel_size=(4,), stride=(2,))
    (1): ResidualWrapper(
      (module): Sequential(
        (0): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): LeakyReLU(negative_slope=0.1)
        (2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))
      )
    )
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.1)
  )
  (style_encoder_rnn): GRU(1024, 1024, batch_first=True)
  (style_encoder_0d): Sequential()
)

In [34]:
# Dataset roots for the training / validation splits and the folder
# that receives the checkpoints of this run.
train_audios_folder = "/mnt/vdc/nsynth-train/"
valid_audios_folder = "/mnt/vdc/nsynth-valid/"
output_path = "/mnt/vdc/run-contrastive-original-style-metric-nsynth-12-07-2024"

# Optimisation settings; batch_size is handed to PairSampler together with the dataset.
batch_size = 64
learning_rate = 0.001
no_of_epochs = 500
# TODO: add the device setting to the config below
device = 'cuda'

In [35]:
import wandb
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from tqdm import tqdm


# Start the Weights & Biases run and record the run settings with it.
wandb.init(project='original_style_metric_training_nsynth', config={
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "epochs": no_of_epochs,
    "train_audios_path": train_audios_folder,
    "valid_audios_path": valid_audios_folder,
    "output_path": output_path
})

# Make sure the output directory exists (no error if it is already there).
os.makedirs(wandb.config.output_path, exist_ok=True)

config = wandb.config

train_dataset = NSynthDataset(cfg, base_folder=config.train_audios_path)
valid_dataset = NSynthDataset(cfg, base_folder=config.valid_audios_path)

train_sampler = PairSampler(train_dataset, config.batch_size)
valid_sampler = PairSampler(valid_dataset, config.batch_size)

train_loader = DataLoader(
    train_dataset,
    num_workers=8,
    batch_sampler=train_sampler,
)
valid_loader = DataLoader(
    valid_dataset,
    num_workers=4,
    batch_sampler=valid_sampler,
)

bilinear_similarity = BilinearSimilarity(cfg['model']['style_encoder_rnn']['hidden_size'].get())
bilinear_similarity.to(device)
bilinear_similarity.train()

style_encoder.to(device)
style_encoder.train()

optimizer = AdamW(
    [{'params': style_encoder.parameters()}, {'params': bilinear_similarity.parameters()}],
    lr=config.learning_rate,
)
cross_entropy = nn.CrossEntropyLoss()

# Upper bound on the number of optimisation steps taken in one epoch.
max_steps_per_epoch = 1000


def _contrastive_loss(batch):
    """Return the contrastive loss of one batch produced by PairSampler.

    Even positions of ``batch`` are the positives and odd positions the
    anchors, so row ``i`` of each half comes from the same instrument. The
    target of anchor ``i`` is therefore index ``i``.
    """
    positives = batch[0::2].to(device)
    anchors = batch[1::2].to(device)
    n_batch = anchors.shape[0]

    # Items in a collated batch share one size, so every item is full-length.
    anchors_lengths = torch.full(
        (n_batch,), anchors.shape[2], dtype=torch.long, device=device)
    positives_lengths = torch.full(
        (positives.shape[0],), positives.shape[2], dtype=torch.long, device=device)

    y_anchors, _ = style_encoder(anchors, anchors_lengths)
    y_positives, _ = style_encoder(positives, positives_lengths)

    # NOTE(review): presumably one row of scores per anchor against all positives - confirm in BilinearSimilarity.
    similarities = bilinear_similarity(y_anchors, y_positives)
    return cross_entropy(similarities, torch.arange(n_batch, device=device))


step = 0
for epoch in range(config.epochs):
    style_encoder.train()
    bilinear_similarity.train()

    # Training loop
    for ix, batch in tqdm(enumerate(train_loader)):
        if ix == max_steps_per_epoch:
            break

        step += 1
        optimizer.zero_grad()
        loss = _contrastive_loss(batch)
        loss.backward()
        optimizer.step()

        wandb.log({'train_loss': loss.item()}, step=step)

    # Validation loop
    style_encoder.eval()
    bilinear_similarity.eval()
    epoch_val_loss = 0.0
    n_val_batches = 0

    with torch.no_grad():
        for batch in tqdm(valid_loader):
            epoch_val_loss += _contrastive_loss(batch).item()
            n_val_batches += 1

    # Average over the batches actually seen; skip logging when there were
    # none instead of dividing by zero.
    if n_val_batches:
        wandb.log({'val_loss': epoch_val_loss / n_val_batches}, step=step)

    # Overwrite the "latest" checkpoints after every epoch.
    latest_checkpoint_path = os.path.join(config.output_path, 'style_encoder_latest.pth')
    torch.save(style_encoder.state_dict(), latest_checkpoint_path)
    torch.save(bilinear_similarity.state_dict(), os.path.join(config.output_path, 'bilinear_similarity_latest.pth'))
    wandb.save(latest_checkpoint_path)

# Save the final model checkpoint (written to the current working directory).
torch.save(style_encoder.state_dict(), 'style_encoder.pth')
wandb.save('style_encoder.pth')


VBox(children=(Label(value='0.002 MB of 0.007 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.286193…

0,1
train_loss,▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_loss,4.84941


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669640566669842, max=1.0…

1000it [06:02,  2.76it/s]
Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f394b22d430>Traceback (most recent call last):
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    Exception ignored in: self._shutdown_workers()<function _MultiProcessingDataLoaderIter.__del__ at 0x7f394b22d430>

0it [00:00, ?it/s]Traceback (most recent call last):
  File "/home/user/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
      File "/home/user/miniconda3/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
if w.is_alive():    
  File "/home/user/miniconda3/lib/python3.8/multiprocessing/process.py", line 160, in is_alive
self._shutdown_workers()Exception ignored in:     Exception ignored in: assert self._parent_pid == os.getpid(), 'can only test a child process'

AssertionError<function _MultiProcessingDataLoaderIter.__del__ 

KeyboardInterrupt: 