# Train the RNN style metric model (AudioSet)

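The model is trained contrastively. Two random crops of the same clip form a positive pair, and the other clips in the batch serve as negatives. Pairs are scored with a bilinear similarity.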

## Setup

In [1]:
from ss_vq_vae.models.vqvae_oneshot import Model
import confugue

2024-07-08 16:59:00.625323: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Load the confugue configuration. Its 'model' section defines the style
# encoder built below, and its 'spectrogram' section supplies the STFT
# parameters used by the dataset.
cfg_path = "/mnt/vdb/model-original-no-style-pretraining-19-11-2023/config.yaml"
cfg = confugue.Configuration.from_yaml_file(cfg_path)

In [3]:
from ss_vq_vae.nn.nn import ResidualWrapper
from ss_vq_vae.nn.bilinear_similarity import BilinearSimilarity
from torch import nn

class StyleEncoder(nn.Module):
    """Style encoder: maps a padded spectrogram batch to one embedding per item.

    The sub-networks are built from the ``cfg['model']`` section of a
    confugue configuration:

    * ``style_encoder_1d`` -- layers applied along the time axis;
    * ``style_encoder_rnn`` -- optional GRU (created with ``batch_first=True``)
      whose final hidden state summarizes the sequence; when it is None, a
      length-normalized Gram matrix is used instead;
    * ``style_encoder_0d`` -- layers applied to the flattened summary vector.
    """

    def __init__(self, cfg):
        super().__init__()
        self.style_encoder_1d = nn.Sequential(*cfg['model']['style_encoder_1d'].configure_list())
        # May be None (presumably when the config defines no RNN); forward()
        # then falls back to the Gram-matrix summary.
        self.style_encoder_rnn = cfg['model']['style_encoder_rnn'].maybe_configure(nn.GRU, batch_first=True)
        self.style_encoder_0d = nn.Sequential(*cfg['model']['style_encoder_0d'].configure_list())

    def forward(self, input, length):
        """Encode a batch.

        Args:
            input: tensor of shape (batch, channels, time); padding of shorter
                items is assumed to be at the end of the time axis.
            length: integer tensor of shape (batch,) holding the number of
                valid time steps per item, measured in ``input`` frames.

        Returns:
            ``(encoded, {})`` where ``encoded`` has shape (batch, features).
        """
        encoded = self.style_encoder_1d(input)

        # The 1-D stack may shorten the time axis (e.g. strided convolutions),
        # so rescale the lengths to encoder frames, then zero out padding.
        length = (length // (input.shape[2] / encoded.shape[2])).to(torch.int)
        mask = (torch.arange(encoded.shape[2], device=encoded.device) < length[:, None])[:, None, :]
        encoded = encoded * mask

        if self.style_encoder_rnn is not None:
            # (batch, time, features) as expected by a batch_first RNN.
            encoded = encoded.transpose(1, 2)
            # pack_padded_sequence rejects zero lengths, hence the clamp.
            encoded = nn.utils.rnn.pack_padded_sequence(
                encoded, length.clamp(min=1).to('cpu'),
                batch_first=True, enforce_sorted=False)
            # Keep only the final hidden state: (layers * directions, batch, hidden).
            _, encoded = self.style_encoder_rnn(encoded)
            # Put batch first and merge the layer/direction axis into features.
            encoded = encoded.transpose(0, 1).reshape(input.shape[0], -1)
        else:
            # Gram matrix normalized by the squared valid length. The epsilon
            # must sit in the denominator: it protects all-padding items
            # (length 0), which would otherwise produce 0 / 0 = NaN.
            denominator = mask.sum(dim=2, keepdim=True) + torch.finfo(encoded.dtype).eps
            encoded = encoded / denominator
            encoded = torch.matmul(encoded, encoded.transpose(1, 2))
        encoded = encoded.reshape(encoded.shape[0], -1)

        encoded = self.style_encoder_0d(encoded)

        return encoded, {}


In [4]:
import os
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import librosa
import torch


def collate_audio_data(samples):
    """Collate ``(audio_name, anchor, positive)`` samples into one batch.

    Samples whose first element is None are skipped. The remaining anchors
    and positives are converted with ``torch.tensor`` and stacked along a new
    leading batch axis.

    Returns:
        ``(audio_names, anchors, positives)``: a tuple of names followed by
        the two stacked tensors.
    """
    def _to_batch(arrays):
        # torch.tensor copies each array; torch.stack adds the batch axis.
        return torch.stack([torch.tensor(array) for array in arrays])

    kept_samples = (sample for sample in samples if sample[0] is not None)
    names, anchor_arrays, positive_arrays = zip(*kept_samples)
    return names, _to_batch(anchor_arrays), _to_batch(positive_arrays)


class LocalAudioset(Dataset):
    """Dataset yielding two random spectrogram crops taken from the same clip.

    Each item is ``(audio_path, anchor, positive)``. ``anchor`` and
    ``positive`` are independently positioned windows of ``sample_len``
    frames cut from the log(1 + |STFT|) spectrogram of one file. They are
    used as a positive pair for contrastive training.

    Args:
        cfg: confugue configuration; its ``'spectrogram'`` section supplies
            the keyword arguments bound to ``librosa.stft``.
        audio_folder: directory scanned (non-recursively) for ``.wav`` files.
        audio_paths: text file listing audio paths, whitespace-separated.
        sample_len: number of spectrogram frames per crop.
        sampling_rate: sampling rate passed to ``librosa.load``.

    Exactly one of ``audio_folder`` / ``audio_paths`` must be given,
    otherwise ``ValueError`` is raised.
    """

    def __init__(self, cfg, audio_folder=None, audio_paths=None, sample_len=96, sampling_rate=16000):
        super().__init__()
        if (audio_folder and audio_paths) or (not audio_folder and not audio_paths):
            raise ValueError("You must set only one of the audio_folder/audio_paths")
        if audio_folder:
            # Sorted so dataset indices are reproducible (os.listdir order is arbitrary).
            self.audio_paths = [
                os.path.join(audio_folder, filename)
                for filename in sorted(os.listdir(audio_folder))
                if filename.endswith('.wav')
            ]
        else:
            with open(audio_paths, 'r') as f:
                self.audio_paths = f.read().split()
        self.sample_len = sample_len
        self.sr = sampling_rate
        # librosa.stft with its keyword arguments bound from cfg['spectrogram'].
        self.spec_fn = cfg['spectrogram'].bind(librosa.stft)

    def preprocess_audio(self, audio_path):
        """Load ``audio_path`` and return its log(1 + |STFT|) spectrogram.

        The result is 2-D with time on axis 1. An empty file is replaced by a
        single zero sample so the STFT is never computed on an empty signal.
        """
        audio, _ = librosa.load(audio_path, sr=self.sr)
        if len(audio) == 0:
            audio = np.zeros(shape=[1], dtype=audio.dtype)
        return np.log1p(np.abs(self.spec_fn(y=audio)))

    def __getitem__(self, ix):
        audio_path = self.audio_paths[ix]
        audio = self.preprocess_audio(audio_path)

        # Zero-pad clips shorter than one crop on the right. The extra 50
        # frames leave room for the two crops to land at different offsets.
        if audio.shape[1] < self.sample_len:
            padding = self.sample_len - audio.shape[1] + 50
            audio = np.pad(audio, ((0, 0), (0, padding)), mode='constant')

        # Valid start offsets are 0 .. shape[1] - sample_len inclusive, and
        # randint's upper bound is exclusive, hence the +1. Without it, a clip
        # of exactly sample_len frames made randint(0, 0) raise, and the last
        # offset could never be drawn.
        max_begin = audio.shape[1] - self.sample_len
        anchor_begin, positive_begin = np.random.randint(0, max_begin + 1, size=2)
        anchor = audio[:, anchor_begin:anchor_begin + self.sample_len]
        positive = audio[:, positive_begin:positive_begin + self.sample_len]
        return audio_path, anchor, positive

    def __len__(self):
        return len(self.audio_paths)

### Training loop

In [5]:
# Build the style encoder from cfg; the bare name on the last line makes the
# notebook display its module structure.
style_encoder = StyleEncoder(cfg)
style_encoder

StyleEncoder(
  (style_encoder_1d): Sequential(
    (0): Conv1d(1025, 1024, kernel_size=(4,), stride=(2,))
    (1): ResidualWrapper(
      (module): Sequential(
        (0): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): LeakyReLU(negative_slope=0.1)
        (2): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))
      )
    )
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): LeakyReLU(negative_slope=0.1)
  )
  (style_encoder_rnn): GRU(1024, 1024, batch_first=True)
  (style_encoder_0d): Sequential()
)

In [6]:
# Data locations and hyper-parameters of this training run.
train_audios_path = "/mnt/vdb/audioset-large/train_list.txt"  # text file listing training clips
valid_audios_folder = "/mnt/vdb/audioset-large/valid_wav_16k"  # folder of validation .wav files
output_path = "/mnt/vdb/run-contrastive-original-style-metric-08-07-2024"  # checkpoint directory

batch_size = 64
learning_rate = 0.001
no_of_epochs = 500
# Compute device; fall back to the CPU when no CUDA device is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [7]:
import wandb
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from tqdm import tqdm


# Contrastive training: two crops of the same clip form a positive pair and
# the other clips of the batch act as negatives (targets 0 .. n_batch - 1).
wandb.init(project='original_style_metric_training', config={
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "epochs": no_of_epochs,
    "train_audios_path": train_audios_path,
    "valid_audios_path": valid_audios_folder,
    "output_path": output_path,
    "device": device,
})
config = wandb.config

# Make sure the output directory exists.
os.makedirs(config.output_path, exist_ok=True)

train_dataset = LocalAudioset(cfg, audio_paths=train_audios_path)
valid_dataset = LocalAudioset(cfg, audio_folder=valid_audios_folder)

train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    # Reshuffle every epoch so each clip meets different in-batch negatives;
    # without this the batches were identical in every epoch.
    shuffle=True,
    num_workers=8,
    collate_fn=collate_audio_data,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=config.batch_size,
    num_workers=4,
    collate_fn=collate_audio_data,
)

bilinear_similarity = BilinearSimilarity(cfg['model']['style_encoder_rnn']['hidden_size'].get())
bilinear_similarity.to(device)
style_encoder.to(device)

optimizer = AdamW(
    [{'params': style_encoder.parameters()}, {'params': bilinear_similarity.parameters()}],
    lr=config.learning_rate,
)
cross_entropy = nn.CrossEntropyLoss()

latest_encoder_path = os.path.join(config.output_path, 'style_encoder_latest.pth')
latest_similarity_path = os.path.join(config.output_path, 'bilinear_similarity_latest.pth')


def _contrastive_loss(anchors, positives):
    """Return the in-batch contrastive cross-entropy loss for one batch.

    ``bilinear_similarity`` scores anchors against positives (presumably a
    (batch, batch) matrix); row i's correct column is i.
    """
    n_batch = anchors.shape[0]
    anchors = anchors.to(device)
    positives = positives.to(device)
    # The crops were stacked into one tensor, so every item spans the full
    # time axis (dim 2) and has no padding.
    anchors_lengths = torch.full((n_batch,), anchors.shape[2], device=device)
    positives_lengths = torch.full((n_batch,), positives.shape[2], device=device)

    y_anchors, _ = style_encoder(anchors, anchors_lengths)
    y_positives, _ = style_encoder(positives, positives_lengths)

    similarities = bilinear_similarity(y_anchors, y_positives)
    return cross_entropy(similarities, torch.arange(n_batch, device=device))


step = 0
try:
    for epoch in range(config.epochs):
        # Training loop
        style_encoder.train()
        bilinear_similarity.train()
        for _, anchors, positives in tqdm(train_loader, desc=f'epoch {epoch} train'):
            step += 1
            optimizer.zero_grad()
            loss = _contrastive_loss(anchors, positives)
            loss.backward()
            optimizer.step()
            wandb.log({'train_loss': loss.item()}, step=step)

        # Validation loop
        style_encoder.eval()
        bilinear_similarity.eval()
        epoch_val_loss = 0.0
        with torch.no_grad():
            for _, anchors, positives in tqdm(valid_loader, desc=f'epoch {epoch} valid'):
                epoch_val_loss += _contrastive_loss(anchors, positives).item()
        # Mean of the per-batch losses.
        wandb.log({'val_loss': epoch_val_loss / len(valid_loader)}, step=step)

        torch.save(style_encoder.state_dict(), latest_encoder_path)
        torch.save(bilinear_similarity.state_dict(), latest_similarity_path)
        wandb.save(latest_encoder_path)
        wandb.save(latest_similarity_path)
finally:
    # Runs after normal completion and also when training is interrupted
    # (e.g. KeyboardInterrupt), so the final weights are not lost. Saved into
    # the run's output directory rather than the current working directory.
    final_encoder_path = os.path.join(config.output_path, 'style_encoder.pth')
    torch.save(style_encoder.state_dict(), final_encoder_path)
    torch.save(bilinear_similarity.state_dict(),
               os.path.join(config.output_path, 'bilinear_similarity.pth'))
    wandb.save(final_encoder_path)
    wandb.finish()


[34m[1mwandb[0m: Currently logged in as: [33mwojtekk23[0m. Use [1m`wandb login --relogin`[0m to force relogin


306it [00:56,  5.39it/s]
281it [01:42,  2.73it/s]
306it [00:58,  5.24it/s]
281it [01:44,  2.69it/s]
306it [00:55,  5.49it/s]
281it [01:39,  2.83it/s]
306it [00:55,  5.53it/s]
281it [01:43,  2.70it/s]
306it [00:52,  5.82it/s]
281it [01:42,  2.74it/s]
306it [00:42,  7.19it/s]
281it [01:42,  2.73it/s]
306it [00:42,  7.14it/s]
281it [01:13,  3.82it/s]
306it [00:56,  5.38it/s]
281it [01:42,  2.73it/s]
306it [00:56,  5.39it/s]
281it [01:42,  2.73it/s]
306it [00:55,  5.51it/s]
281it [00:48,  5.82it/s]
306it [00:56,  5.38it/s]
281it [01:42,  2.74it/s]
306it [00:56,  5.37it/s]
281it [00:57,  4.89it/s]
306it [00:56,  5.40it/s]
281it [01:43,  2.71it/s]
306it [00:55,  5.56it/s]
281it [01:43,  2.72it/s]
306it [00:57,  5.37it/s]
281it [00:51,  5.43it/s]
306it [00:58,  5.24it/s]
281it [01:42,  2.75it/s]
306it [00:57,  5.36it/s]
281it [01:42,  2.74it/s]
306it [00:58,  5.26it/s]
281it [00:57,  4.86it/s]
306it [00:56,  5.38it/s]
281it [01:42,  2.74it/s]
306it [00:52,  5.88it/s]
281it [01:04,  4.33it/s]


KeyboardInterrupt: 