# RNN Style Metric Model inference

## Model

In [50]:
from ss_vq_vae.models.vqvae_oneshot import Model
import confugue

cfg_path = "/mnt/vdb/model-original-no-style-pretraining-19-11-2023/config.yaml"
cfg = confugue.Configuration.from_yaml_file(cfg_path)

In [51]:
from ss_vq_vae.nn.nn import ResidualWrapper
from ss_vq_vae.nn.bilinear_similarity import BilinearSimilarity
from torch import nn

class StyleEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.style_encoder_1d = nn.Sequential(*cfg['model']['style_encoder_1d'].configure_list())
        self.style_encoder_rnn = cfg['model']['style_encoder_rnn'].maybe_configure(nn.GRU, batch_first=True)
        self.style_encoder_0d = nn.Sequential(*cfg['model']['style_encoder_0d'].configure_list())
        
    def forward(self, input, length):
        encoded = self.style_encoder_1d(input)

        # Mask positions corresponding to padding
        length = (length // (input.shape[2] / encoded.shape[2])).to(torch.int)
        mask = (torch.arange(encoded.shape[2], device=encoded.device) < length[:, None])[:, None, :]
        encoded = encoded * mask

        if self.style_encoder_rnn is not None:
            encoded = encoded.transpose(1, 2)
            encoded = nn.utils.rnn.pack_padded_sequence(
                encoded, length.clamp(min=1).to('cpu'),
                batch_first=True, enforce_sorted=False)
            _, encoded = self.style_encoder_rnn(encoded)
            # Get rid of layer dimension
            encoded = encoded.transpose(0, 1).reshape(input.shape[0], -1)
        else:
            # Compute the Gram matrix, normalized by the length squared
            encoded = encoded / mask.sum(dim=2, keepdim=True) + torch.finfo(encoded.dtype).eps
            encoded = torch.matmul(encoded, encoded.transpose(1, 2))
        encoded = encoded.reshape(encoded.shape[0], -1)

        encoded = self.style_encoder_0d(encoded)

        return encoded, {}

In [52]:
import os
import argparse
import torch
from torch.utils.data import DataLoader
from ss_vq_vae.nn.bilinear_similarity import BilinearSimilarity


RNN_PATH = "/mnt/vdb/run-contrastive-original-style-metric-08-07-2024/style_encoder_5508.pth"
SIMILARITY_PATH = "/mnt/vdb/run-contrastive-original-style-metric-08-07-2024/bilinear_similarity_5508.pth"


def load_model(rnn_path, similarity_path, cfg):
    style_metric_model = StyleEncoder(cfg)
    bilinear_similarity = BilinearSimilarity(cfg['model']['style_encoder_rnn']['hidden_size'].get())
    
    style_metric_model.load_state_dict(torch.load(rnn_path))
    bilinear_similarity.load_state_dict(torch.load(similarity_path))
    
    style_metric_model.cuda()
    bilinear_similarity.cuda()
    
    return style_metric_model, bilinear_similarity

style_metric_model, bilinear_similarity = load_model(RNN_PATH, SIMILARITY_PATH, cfg)

## Data

In [53]:
import os
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import librosa


class LocalTestSet(Dataset):
    def __init__(self, input_triples, generated_files, cfg, sampling_rate=16000):
        super(Dataset, self).__init__()
        if len(input_triples) != len(generated_files):
            raise ValueError(f"Input pairs and generated files lengths do not match: ({len(input_triples), len(generated_files)}")
        self.melody_paths = list(input_triples[0])
        self.style_paths = list(input_triples[1])
        self.ground_paths = list(input_triples[2])
        self.generated_paths = list(generated_files[0])
        self.sr = sampling_rate
        self.spec_fn = cfg['spectrogram'].bind(librosa.stft)
        
    def process_audio(self, audio_path):
        audio, _ = librosa.load(audio_path, sr=self.sr)
        if len(audio) == 0:
            audio = np.zeros(shape=[1], dtype=audio.dtype)
        return np.log1p(np.abs(self.spec_fn(y=audio)))

    def __getitem__(self, ix):
        melody_path = self.melody_paths[ix]
        style_path = self.style_paths[ix]
        ground_path = self.ground_paths[ix]
        generated_path = self.generated_paths[ix]
        
        melody_stft = self.process_audio(melody_path)
        style_stft = self.process_audio(style_path)
        ground_stft = self.process_audio(ground_path)
        generated_stft = self.process_audio(generated_path)
        
        return (melody_path, style_path, ground_path, generated_path), (melody_stft, style_stft, ground_stft, generated_stft)

    def __len__(self):
        return len(self.melody_paths)
    
def extract_segment(audio_stft_batch, segment_len=96, start_frame=0):
    assert len(audio_stft_batch) == 1, "This function should only be used for a single audio batch"
    return audio_stft_batch[:, :, start_frame:start_frame+segment_len]

## Calculate similarities

In [54]:
def evaluate_generative_model(segment_M, segment_S, segment_Ground, segment_Gen, style_metric_model, similarity):
    # Each segment_X is actually a batch of bs=1
    M_lengths = torch.as_tensor([segment.shape[1] for segment in segment_M], device='cuda')
    S_lengths = torch.as_tensor([segment.shape[1] for segment in segment_S], device='cuda')
    Ground_lengths = torch.as_tensor([segment.shape[1] for segment in segment_Ground], device='cuda')
    Gen_lengths = torch.as_tensor([segment.shape[1] for segment in segment_Gen], device='cuda')
    
    y_anchors, _ = style_metric_model(segment_M.cuda(), M_lengths)
    y_positives, _ = style_metric_model(segment_Gen.cuda(), Gen_lengths)
    melody_similarity = similarity(y_anchors, y_positives).item()
    
    
    y_anchors, _ = style_metric_model(segment_S.cuda(), S_lengths)
    y_positives, _ = style_metric_model(segment_Gen.cuda(), Gen_lengths)
    style_similarity = similarity(y_anchors, y_positives).item()
    
    y_anchors, _ = style_metric_model(segment_Ground.cuda(), Ground_lengths)
    y_positives, _ = style_metric_model(segment_Gen.cuda(), Gen_lengths)
    ground_similarity = similarity(y_anchors, y_positives).item()
    
    return melody_similarity, style_similarity, ground_similarity

## Val2 Data

In [55]:
import os
import pandas as pd

GENERATED_PREFIX = '/home/user/ss-vq-vae/experiments/outputs/model-original-no-style-pretraining-19-11-2023/val2'

input_triples = pd.read_csv('/mnt/vdb/validation_set_2.csv', header=None, sep='\t')
generated_list = pd.read_csv(f'{GENERATED_PREFIX}/vqvae_list', header=None).applymap(lambda path: f"{GENERATED_PREFIX}/{path}")


assert os.path.exists(input_triples.iloc[0][0]), "The input pairs file contains non-valid paths"
assert os.path.exists(generated_list.iloc[0][0]), "The generated file contains non-valid paths"

In [56]:
from torch.utils.data import DataLoader

val2_dataset = LocalTestSet(input_triples, generated_list, cfg)
val2_loader = DataLoader(val2_dataset, shuffle=False, batch_size=1)

In [57]:
import os
import itertools
from tqdm import tqdm
import numpy as np
from typing import List


DATASETS = {
    'val2': '/mnt/vdb/validation_set_2.csv',
    'test': '/mnt/vdb/test_set.csv',
}

MODELS = {
    model_name: os.path.join("/home/user/ss-vq-vae/experiments/outputs/", model_name)
    for model_name in [
        'model-original-no-style-pretraining-19-11-2023',
        'model-original-frozen-style-pretraining-21-11-2023',
        'model-original-finetuned-style-pretraining-22-11-2023',
        'model-leaky-relu-no-style-pretraining-13-11-2023',
        'model-leaky-relu-frozen-style-pretraining-15-11-2023',
        'model-leaky-relu-finetuned-style-pretraining-15-11-2023'
    ]
}


def empty_results_df():
    return pd.DataFrame(columns=[
        "model", "dataset", 
        "melody_path", "style_path", "ground_path", "generated_path", 
        "melody_similarity", "style_similarity", "ground_similarity"]
    )


def run_on_dataset_model(dataset_name, model, style_metric_model, similarity):
    input_triples = pd.read_csv(DATASETS[dataset_name], header=None, sep='\t')
    
    generated_prefix = MODELS[model]
    if dataset_name == 'val2':
        generated_prefix = os.path.join(generated_prefix, "val2/")
    generated_list = pd.read_csv(f'{generated_prefix}/vqvae_list', header=None).applymap(lambda path: f"{generated_prefix}/{path}")
    
    # Create dataset
    dataset = LocalTestSet(input_triples, generated_list, cfg)
    dataset_loader = DataLoader(dataset, shuffle=False, batch_size=1)
        
    results = empty_results_df()
    pbar = tqdm(dataset_loader)
    pbar.set_description(f"Processing dataset: {dataset_name}, model: {model}")
    for paths, stfts in pbar:
        m_path, s_path, ground_path, gen_path = paths
        m_stft, s_stft, ground_stft, gen_stft = stfts
        m_segment = extract_segment(m_stft, segment_len=128)
        s_segment = extract_segment(s_stft, segment_len=128)
        ground_segment = extract_segment(ground_stft, segment_len=128)
        gen_segment = extract_segment(gen_stft, segment_len=128)

        with torch.no_grad():
            melody_similarity, style_similarity, ground_similarity = evaluate_generative_model(
                m_segment, 
                s_segment, 
                ground_segment, 
                gen_segment, 
                style_metric_model, 
                similarity
            )

        row = {
            "model": model,
            "dataset": dataset_name,
            "melody_path": m_path[0],
            "style_path": s_path[0],
            "ground_path": ground_path[0],
            "generated_path": gen_path[0],
            "melody_similarity": melody_similarity,
            "style_similarity": style_similarity,
            "ground_similarity": ground_similarity,
        }
        results = pd.concat([results, pd.DataFrame([row])], ignore_index=True)
        
    return results

def run_style_evaluations(datasets: List[str], models: List[str], style_metric_model: StyleEncoder, similarity: BilinearSimilarity) -> pd.DataFrame:
    results = empty_results_df()
    
    for dataset, model in itertools.product(datasets, models):
        try:
            results_dataset_model = run_on_dataset_model(dataset, model, style_metric_model, similarity)
            results = pd.concat([results, results_dataset_model], ignore_index=True)
        except Exception as e:
            print(f"Exception raised in while processing dataset {dataset} and model {model}: {e}")
            print("Moving onto the next one...")
        
    return results


In [58]:
results = run_style_evaluations(DATASETS.keys(), MODELS.keys(), style_metric_model, bilinear_similarity)

Processing dataset: val2, model: model-original-no-style-pretraining-19-11-2023: 100%|██████████| 845/845 [03:32<00:00,  3.98it/s]
Processing dataset: val2, model: model-original-frozen-style-pretraining-21-11-2023: 100%|██████████| 845/845 [00:55<00:00, 15.35it/s]
Processing dataset: val2, model: model-original-finetuned-style-pretraining-22-11-2023: 100%|██████████| 845/845 [00:52<00:00, 16.24it/s]
Processing dataset: val2, model: model-leaky-relu-no-style-pretraining-13-11-2023: 100%|██████████| 845/845 [00:49<00:00, 16.93it/s]
Processing dataset: val2, model: model-leaky-relu-frozen-style-pretraining-15-11-2023: 100%|██████████| 845/845 [01:04<00:00, 13.16it/s]
Processing dataset: val2, model: model-leaky-relu-finetuned-style-pretraining-15-11-2023: 100%|██████████| 845/845 [01:05<00:00, 12.91it/s]
Processing dataset: test, model: model-original-no-style-pretraining-19-11-2023: 100%|██████████| 1661/1661 [01:52<00:00, 14.74it/s]
Processing dataset: test, model: model-original-froze

In [59]:
results.head()

Unnamed: 0,model,dataset,melody_path,style_path,ground_path,generated_path,melody_similarity,style_similarity,ground_similarity
0,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_719_084...,/mnt/vdb/random_audios_patch_16k//data_850_042...,/mnt/vdb/random_audios_patch_16k//data_719_042...,/home/user/ss-vq-vae/experiments/outputs/model...,131.70401,93.903595,146.0271
1,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_952_023...,/mnt/vdb/random_audios_patch_16k//data_212_045...,/mnt/vdb/random_audios_patch_16k//data_952_045...,/home/user/ss-vq-vae/experiments/outputs/model...,118.698265,129.622452,167.668091
2,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_378_040...,/mnt/vdb/random_audios_patch_16k//data_361_042...,/mnt/vdb/random_audios_patch_16k//data_378_042...,/home/user/ss-vq-vae/experiments/outputs/model...,201.656265,94.673439,166.175674
3,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_984_015...,/mnt/vdb/random_audios_patch_16k//data_712_049...,/mnt/vdb/random_audios_patch_16k//data_984_049...,/home/user/ss-vq-vae/experiments/outputs/model...,51.088585,147.91188,69.226852
4,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_558_068...,/mnt/vdb/random_audios_patch_16k//data_582_044...,/mnt/vdb/random_audios_patch_16k//data_558_044...,/home/user/ss-vq-vae/experiments/outputs/model...,184.517868,136.75882,112.75882


In [60]:
results.to_csv('results-08-07-2024.csv')

In [61]:
results_copy = results.copy()

In [62]:
# Function to handle the log-mean-log transformation
def log_mean_exp(df, groupby_cols, transform_cols):
    # Copy the groupby columns to ensure they are available for grouping after transformation
    exp_df = df[groupby_cols + transform_cols].copy()
    # Exponentiate the log values
    exp_df[transform_cols] = np.exp(exp_df[transform_cols])
    # Compute the mean of the exponentiated values grouped by specified columns
    mean_exp_df = exp_df.groupby(by=groupby_cols).mean()
    # Log the mean values
    log_mean_df = np.log(mean_exp_df)
    return log_mean_df

# Specify the columns to group by and to transform
groupby_cols = ['model', 'dataset']
transform_cols = ['melody_similarity', 'style_similarity', 'ground_similarity']

# Apply the transformation
log_mean_results = log_mean_exp(results, groupby_cols, transform_cols)

# Optional: Display the result
log_mean_results

Unnamed: 0_level_0,Unnamed: 1_level_0,melody_similarity,style_similarity,ground_similarity
model,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
model-leaky-relu-finetuned-style-pretraining-15-11-2023,test,254.733244,259.067104,287.010367
model-leaky-relu-finetuned-style-pretraining-15-11-2023,val2,250.435135,248.748528,290.17729
model-leaky-relu-frozen-style-pretraining-15-11-2023,test,266.298821,282.241685,260.644536
model-leaky-relu-frozen-style-pretraining-15-11-2023,val2,254.986476,274.39067,275.253688
model-leaky-relu-no-style-pretraining-13-11-2023,test,259.369804,254.101406,278.396963
model-leaky-relu-no-style-pretraining-13-11-2023,val2,253.45026,247.877713,263.865531
model-original-finetuned-style-pretraining-22-11-2023,test,267.579919,253.738228,309.467882
model-original-finetuned-style-pretraining-22-11-2023,val2,259.33577,250.423995,275.173839
model-original-frozen-style-pretraining-21-11-2023,test,257.525902,264.106557,313.269639
model-original-frozen-style-pretraining-21-11-2023,val2,272.061628,254.865895,278.967371


In [64]:
results[['model', 'dataset', 'melody_similarity', 'style_similarity', 'ground_similarity']].groupby(by=['model', 'dataset']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,melody_similarity,style_similarity,ground_similarity
model,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
model-leaky-relu-finetuned-style-pretraining-15-11-2023,test,118.615144,116.533439,120.66232
model-leaky-relu-finetuned-style-pretraining-15-11-2023,val2,118.752463,113.216349,122.4693
model-leaky-relu-frozen-style-pretraining-15-11-2023,test,119.701317,117.435,121.849097
model-leaky-relu-frozen-style-pretraining-15-11-2023,val2,119.324216,114.483524,124.551299
model-leaky-relu-no-style-pretraining-13-11-2023,test,115.175725,111.356153,117.509949
model-leaky-relu-no-style-pretraining-13-11-2023,val2,114.31962,107.723954,118.396356
model-original-finetuned-style-pretraining-22-11-2023,test,116.384651,116.071041,129.012125
model-original-finetuned-style-pretraining-22-11-2023,val2,116.08159,113.046906,124.435363
model-original-frozen-style-pretraining-21-11-2023,test,114.946917,116.138929,121.964397
model-original-frozen-style-pretraining-21-11-2023,val2,114.271262,113.834104,121.600113
