# RNN Style Metric Model inference

## Model

In [1]:
COLA is from ss_vq_vae.models.vqvae_oneshot import Model
import confugue

cfg_path = "/mnt/vdb/model-original-no-style-pretraining-19-11-2023/config.yaml"
cfg = confugue.Configuration.from_yaml_file(cfg_path)

2024-09-08 18:48:04.096989: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from ss_vq_vae.nn.nn import ResidualWrapper
from ss_vq_vae.nn.bilinear_similarity import BilinearSimilarity
from torch import nn

class StyleEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.style_encoder_1d = nn.Sequential(*cfg['model']['style_encoder_1d'].configure_list())
        self.style_encoder_rnn = cfg['model']['style_encoder_rnn'].maybe_configure(nn.GRU, batch_first=True)
        self.style_encoder_0d = nn.Sequential(*cfg['model']['style_encoder_0d'].configure_list())
        
    def forward(self, input, length):
        encoded = self.style_encoder_1d(input)

        # Mask positions corresponding to padding
        length = (length // (input.shape[2] / encoded.shape[2])).to(torch.int)
        mask = (torch.arange(encoded.shape[2], device=encoded.device) < length[:, None])[:, None, :]
        encoded = encoded * mask

        if self.style_encoder_rnn is not None:
            encoded = encoded.transpose(1, 2)
            encoded = nn.utils.rnn.pack_padded_sequence(
                encoded, length.clamp(min=1).to('cpu'),
                batch_first=True, enforce_sorted=False)
            _, encoded = self.style_encoder_rnn(encoded)
            # Get rid of layer dimension
            encoded = encoded.transpose(0, 1).reshape(input.shape[0], -1)
        else:
            # Compute the Gram matrix, normalized by the length squared
            encoded = encoded / mask.sum(dim=2, keepdim=True) + torch.finfo(encoded.dtype).eps
            encoded = torch.matmul(encoded, encoded.transpose(1, 2))
        encoded = encoded.reshape(encoded.shape[0], -1)

        encoded = self.style_encoder_0d(encoded)

        return encoded, {}

In [3]:
import os
import argparse
import torch
from torch.utils.data import DataLoader
from ss_vq_vae.nn.bilinear_similarity import BilinearSimilarity


# RNN_PATH = "/mnt/vdb/run-contrastive-original-style-metric-08-07-2024/style_encoder_5508.pth"
# SIMILARITY_PATH = "/mnt/vdb/run-contrastive-original-style-metric-08-07-2024/bilinear_similarity_5508.pth"
RNN_PATH = "/mnt/vdc/run-contrastive-original-style-metric-nsynth-12-07-2024/style_encoder_9000.pth"
SIMILARITY_PATH = "/mnt/vdc/run-contrastive-original-style-metric-nsynth-12-07-2024/bilinear_similarity_9000.pth"


def load_model(rnn_path, similarity_path, cfg):
    style_metric_model = StyleEncoder(cfg)
    bilinear_similarity = BilinearSimilarity(cfg['model']['style_encoder_rnn']['hidden_size'].get())
    
    style_metric_model.load_state_dict(torch.load(rnn_path))
    bilinear_similarity.load_state_dict(torch.load(similarity_path))
    
    style_metric_model.cuda()
    bilinear_similarity.cuda()
    
    return style_metric_model, bilinear_similarity

style_metric_model, bilinear_similarity = load_model(RNN_PATH, SIMILARITY_PATH, cfg)

## Data

In [4]:
import os
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import librosa


class LocalTestSet(Dataset):
    def __init__(self, input_triples, generated_files, cfg, sampling_rate=16000):
        super(Dataset, self).__init__()
        if len(input_triples) != len(generated_files):
            raise ValueError(f"Input pairs and generated files lengths do not match: ({len(input_triples), len(generated_files)}")
        self.melody_paths = list(input_triples[0])
        self.style_paths = list(input_triples[1])
        self.ground_paths = list(input_triples[2])
        self.generated_paths = list(generated_files[0])
        self.sr = sampling_rate
        self.spec_fn = cfg['spectrogram'].bind(librosa.stft)
        
    def process_audio(self, audio_path):
        audio, _ = librosa.load(audio_path, sr=self.sr)
        if len(audio) == 0:
            audio = np.zeros(shape=[1], dtype=audio.dtype)
        return np.log1p(np.abs(self.spec_fn(y=audio)))

    def __getitem__(self, ix):
        melody_path = self.melody_paths[ix]
        style_path = self.style_paths[ix]
        ground_path = self.ground_paths[ix]
        generated_path = self.generated_paths[ix]
        
        melody_stft = self.process_audio(melody_path)
        style_stft = self.process_audio(style_path)
        ground_stft = self.process_audio(ground_path)
        generated_stft = self.process_audio(generated_path)
        
        return (melody_path, style_path, ground_path, generated_path), (melody_stft, style_stft, ground_stft, generated_stft)

    def __len__(self):
        return len(self.melody_paths)
    
def extract_segment(audio_stft_batch, segment_len=96, start_frame=0):
    assert len(audio_stft_batch) == 1, "This function should only be used for a single audio batch"
    return audio_stft_batch[:, :, start_frame:start_frame+segment_len]

## Calculate similarities

In [5]:
def evaluate_generative_model(segment_M, segment_S, segment_Ground, segment_Gen, style_metric_model, similarity):
    # Each segment_X is actually a batch of bs=1
    M_lengths = torch.as_tensor([segment.shape[1] for segment in segment_M], device='cuda')
    S_lengths = torch.as_tensor([segment.shape[1] for segment in segment_S], device='cuda')
    Ground_lengths = torch.as_tensor([segment.shape[1] for segment in segment_Ground], device='cuda')
    Gen_lengths = torch.as_tensor([segment.shape[1] for segment in segment_Gen], device='cuda')
    
    y_anchors, _ = style_metric_model(segment_M.cuda(), M_lengths)
    y_positives, _ = style_metric_model(segment_Gen.cuda(), Gen_lengths)
    melody_similarity = similarity(y_anchors, y_positives).item()
    
    
    y_anchors, _ = style_metric_model(segment_S.cuda(), S_lengths)
    y_positives, _ = style_metric_model(segment_Gen.cuda(), Gen_lengths)
    style_similarity = similarity(y_anchors, y_positives).item()
    
    y_anchors, _ = style_metric_model(segment_Ground.cuda(), Ground_lengths)
    y_positives, _ = style_metric_model(segment_Gen.cuda(), Gen_lengths)
    ground_similarity = similarity(y_anchors, y_positives).item()
    
    return melody_similarity, style_similarity, ground_similarity

## Val2 Data

In [6]:
import os
import pandas as pd

GENERATED_PREFIX = '/home/user/ss-vq-vae/experiments/outputs/model-original-no-style-pretraining-19-11-2023/val2'

input_triples = pd.read_csv('/mnt/vdb/validation_set_2.csv', header=None, sep='\t')
generated_list = pd.read_csv(f'{GENERATED_PREFIX}/vqvae_list', header=None).applymap(lambda path: f"{GENERATED_PREFIX}/{path}")


assert os.path.exists(input_triples.iloc[0][0]), "The input pairs file contains non-valid paths"
assert os.path.exists(generated_list.iloc[0][0]), "The generated file contains non-valid paths"

In [7]:
from torch.utils.data import DataLoader

val2_dataset = LocalTestSet(input_triples, generated_list, cfg)
val2_loader = DataLoader(val2_dataset, shuffle=False, batch_size=1)

In [8]:
import os
import itertools
from tqdm import tqdm
import numpy as np
from typing import List


DATASETS = {
    'val2': '/mnt/vdb/validation_set_2.csv',
    'test': '/mnt/vdb/test_set.csv',
}

MODELS = {
    model_name: os.path.join("/home/user/ss-vq-vae/experiments/outputs/", model_name)
    for model_name in [
        'model-original-no-style-pretraining-19-11-2023',
        'model-original-frozen-style-pretraining-21-11-2023',
        'model-original-finetuned-style-pretraining-22-11-2023',
        'model-original-no-style-pretraining-with-ssl-dataloader-07-09-2024',
        'model-leaky-relu-no-style-pretraining-30-08-2024',
        'model-leaky-relu-frozen-style-pretraining-01-09-2024',
        'model-leaky-relu-finetuned-style-pretraining-29-08-2024',
    ]
}


def empty_results_df():
    return pd.DataFrame(columns=[
        "model", "dataset", 
        "melody_path", "style_path", "ground_path", "generated_path", 
        "melody_similarity", "style_similarity", "ground_similarity"]
    )


def run_on_dataset_model(dataset_name, model, style_metric_model, similarity):
    input_triples = pd.read_csv(DATASETS[dataset_name], header=None, sep='\t')
    
    generated_prefix = MODELS[model]
    if dataset_name == 'val2':
        generated_prefix = os.path.join(generated_prefix, "val2/")
    generated_list = pd.read_csv(f'{generated_prefix}/vqvae_list', header=None).applymap(lambda path: f"{generated_prefix}/{path}")
    
    # Create dataset
    dataset = LocalTestSet(input_triples, generated_list, cfg)
    dataset_loader = DataLoader(dataset, shuffle=False, batch_size=1)
        
    results = empty_results_df()
    pbar = tqdm(dataset_loader)
    pbar.set_description(f"Processing dataset: {dataset_name}, model: {model}")
    for paths, stfts in pbar:
        m_path, s_path, ground_path, gen_path = paths
        m_stft, s_stft, ground_stft, gen_stft = stfts
        m_segment = extract_segment(m_stft, segment_len=128)
        s_segment = extract_segment(s_stft, segment_len=128)
        ground_segment = extract_segment(ground_stft, segment_len=128)
        gen_segment = extract_segment(gen_stft, segment_len=128)

        with torch.no_grad():
            melody_similarity, style_similarity, ground_similarity = evaluate_generative_model(
                m_segment, 
                s_segment, 
                ground_segment, 
                gen_segment, 
                style_metric_model, 
                similarity
            )

        row = {
            "model": model,
            "dataset": dataset_name,
            "melody_path": m_path[0],
            "style_path": s_path[0],
            "ground_path": ground_path[0],
            "generated_path": gen_path[0],
            "melody_similarity": melody_similarity,
            "style_similarity": style_similarity,
            "ground_similarity": ground_similarity,
        }
        results = pd.concat([results, pd.DataFrame([row])], ignore_index=True)
        
    return results

def run_style_evaluations(datasets: List[str], models: List[str], style_metric_model: StyleEncoder, similarity: BilinearSimilarity) -> pd.DataFrame:
    results = empty_results_df()
    
    for dataset, model in itertools.product(datasets, models):
        try:
            results_dataset_model = run_on_dataset_model(dataset, model, style_metric_model, similarity)
            results = pd.concat([results, results_dataset_model], ignore_index=True)
        except Exception as e:
            print(f"Exception raised in while processing dataset {dataset} and model {model}: {e}")
            print("Moving onto the next one...")
        
    return results


In [9]:
results = run_style_evaluations(DATASETS.keys(), MODELS.keys(), style_metric_model, bilinear_similarity)

Processing dataset: val2, model: model-original-no-style-pretraining-19-11-2023: 100%|██████████| 845/845 [01:07<00:00, 12.48it/s]
Processing dataset: val2, model: model-original-frozen-style-pretraining-21-11-2023: 100%|██████████| 845/845 [00:49<00:00, 16.95it/s]
Processing dataset: val2, model: model-original-finetuned-style-pretraining-22-11-2023: 100%|██████████| 845/845 [01:03<00:00, 13.32it/s]
Processing dataset: val2, model: model-original-no-style-pretraining-with-ssl-dataloader-07-09-2024: 100%|██████████| 845/845 [01:13<00:00, 11.54it/s]
Processing dataset: val2, model: model-leaky-relu-no-style-pretraining-30-08-2024: 100%|██████████| 845/845 [00:50<00:00, 16.72it/s]
Processing dataset: val2, model: model-leaky-relu-frozen-style-pretraining-01-09-2024: 100%|██████████| 845/845 [01:18<00:00, 10.79it/s]
Processing dataset: val2, model: model-leaky-relu-finetuned-style-pretraining-29-08-2024: 100%|██████████| 845/845 [01:11<00:00, 11.85it/s]
Processing dataset: test, model: mo

In [10]:
results.head()

Unnamed: 0,model,dataset,melody_path,style_path,ground_path,generated_path,melody_similarity,style_similarity,ground_similarity
0,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_719_084...,/mnt/vdb/random_audios_patch_16k//data_850_042...,/mnt/vdb/random_audios_patch_16k//data_719_042...,/home/user/ss-vq-vae/experiments/outputs/model...,10116.464844,10086.172852,9860.506836
1,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_952_023...,/mnt/vdb/random_audios_patch_16k//data_212_045...,/mnt/vdb/random_audios_patch_16k//data_952_045...,/home/user/ss-vq-vae/experiments/outputs/model...,13336.756836,13292.535156,13336.854492
2,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_378_040...,/mnt/vdb/random_audios_patch_16k//data_361_042...,/mnt/vdb/random_audios_patch_16k//data_378_042...,/home/user/ss-vq-vae/experiments/outputs/model...,12211.90625,8761.861328,13312.101562
3,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_984_015...,/mnt/vdb/random_audios_patch_16k//data_712_049...,/mnt/vdb/random_audios_patch_16k//data_984_049...,/home/user/ss-vq-vae/experiments/outputs/model...,10176.09375,9567.772461,9389.15918
4,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_558_068...,/mnt/vdb/random_audios_patch_16k//data_582_044...,/mnt/vdb/random_audios_patch_16k//data_558_044...,/home/user/ss-vq-vae/experiments/outputs/model...,13309.070312,13302.0,13297.214844


In [11]:
results.to_csv('results-08-09-2024.csv')

In [12]:
results_copy = results.copy()

In [13]:
# Function to handle the log-mean-log transformation
def log_mean_exp(df, groupby_cols, transform_cols):
    # Copy the groupby columns to ensure they are available for grouping after transformation
    exp_df = df[groupby_cols + transform_cols].copy()
    # Exponentiate the log values
    exp_df[transform_cols] = np.exp(exp_df[transform_cols])
    # Compute the mean of the exponentiated values grouped by specified columns
    mean_exp_df = exp_df.groupby(by=groupby_cols).mean()
    # Log the mean values
    log_mean_df = np.log(mean_exp_df)
    return log_mean_df

# Specify the columns to group by and to transform
groupby_cols = ['model', 'dataset']
transform_cols = ['melody_similarity', 'style_similarity', 'ground_similarity']

# Apply the transformation
log_mean_results = log_mean_exp(results, groupby_cols, transform_cols)

# Optional: Display the result
log_mean_results

  result = func(self.values, **kwargs)


Unnamed: 0_level_0,Unnamed: 1_level_0,melody_similarity,style_similarity,ground_similarity
model,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
model-leaky-relu-finetuned-style-pretraining-29-08-2024,test,,,
model-leaky-relu-finetuned-style-pretraining-29-08-2024,val2,,,
model-leaky-relu-frozen-style-pretraining-01-09-2024,test,,,
model-leaky-relu-frozen-style-pretraining-01-09-2024,val2,,,
model-leaky-relu-no-style-pretraining-30-08-2024,test,,,
model-leaky-relu-no-style-pretraining-30-08-2024,val2,,,
model-original-finetuned-style-pretraining-22-11-2023,test,,,
model-original-finetuned-style-pretraining-22-11-2023,val2,,,
model-original-frozen-style-pretraining-21-11-2023,test,,,
model-original-frozen-style-pretraining-21-11-2023,val2,,,


In [14]:
# TODO: wylicz, czy te różnice są statystycznie istotne
results[['model', 'dataset', 'melody_similarity', 'style_similarity', 'ground_similarity']].groupby(by=['model', 'dataset']).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,melody_similarity,style_similarity,ground_similarity
model,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
model-leaky-relu-finetuned-style-pretraining-29-08-2024,test,11898.55785,11864.868179,11921.362109
model-leaky-relu-finetuned-style-pretraining-29-08-2024,val2,11887.297824,11774.535468,11806.595713
model-leaky-relu-frozen-style-pretraining-01-09-2024,test,11615.551227,11601.777771,11635.896481
model-leaky-relu-frozen-style-pretraining-01-09-2024,val2,11805.541332,11715.731475,11738.418292
model-leaky-relu-no-style-pretraining-30-08-2024,test,11919.650807,11895.598026,11950.788841
model-leaky-relu-no-style-pretraining-30-08-2024,val2,11892.82207,11780.600119,11812.517602
model-original-finetuned-style-pretraining-22-11-2023,test,11693.286532,11666.060908,11829.203245
model-original-finetuned-style-pretraining-22-11-2023,val2,11701.001772,11594.37667,11669.906287
model-original-frozen-style-pretraining-21-11-2023,test,11496.877452,11486.084107,11563.76569
model-original-frozen-style-pretraining-21-11-2023,val2,11365.404685,11272.404566,11326.638875


In [15]:
import pandas as pd

results = pd.read_csv('results-08-09-2024.csv')

In [16]:
print(results[['model', 'dataset', 'melody_similarity', 'style_similarity', 'ground_similarity']].groupby(by=['model', 'dataset']).mean().to_latex())

\begin{tabular}{llrrr}
\toprule
 &  & melody_similarity & style_similarity & ground_similarity \\
model & dataset &  &  &  \\
\midrule
\multirow[t]{2}{*}{model-leaky-relu-finetuned-style-pretraining-29-08-2024} & test & 11898.557850 & 11864.868179 & 11921.362109 \\
 & val2 & 11887.297824 & 11774.535468 & 11806.595713 \\
\cline{1-5}
\multirow[t]{2}{*}{model-leaky-relu-frozen-style-pretraining-01-09-2024} & test & 11615.551227 & 11601.777771 & 11635.896481 \\
 & val2 & 11805.541332 & 11715.731475 & 11738.418292 \\
\cline{1-5}
\multirow[t]{2}{*}{model-leaky-relu-no-style-pretraining-30-08-2024} & test & 11919.650807 & 11895.598026 & 11950.788841 \\
 & val2 & 11892.822070 & 11780.600119 & 11812.517602 \\
\cline{1-5}
\multirow[t]{2}{*}{model-original-finetuned-style-pretraining-22-11-2023} & test & 11693.286532 & 11666.060908 & 11829.203245 \\
 & val2 & 11701.001772 & 11594.376670 & 11669.906287 \\
\cline{1-5}
\multirow[t]{2}{*}{model-original-frozen-style-pretraining-21-11-2023} & test & 1