# COLA encoder inference

In [1]:
import sys

sys.path.append('/home/user/COLA-PyTorch')

## Model

In [2]:
import os
import argparse
import torch
from torch.utils.data import DataLoader
from models.cola import COLA
from models.similarity import BilinearSimilarity
from data.audioset import Audioset, LocalAudioset


COLA_PATH = "/mnt/vdb/run-style-metric-model-03-07-2024/cola_30_305.pth"
SIMILARITY_PATH = "/mnt/vdb/run-style-metric-model-03-07-2024/similarity_30_305.pth"


def load_model(cola_path, similarity_path, hidden_size, output_size):
    cola = COLA(hidden_size, output_size)
    similarity = BilinearSimilarity(output_size)
    
    cola.load_state_dict(torch.load(cola_path))
    similarity.load_state_dict(torch.load(similarity_path))
    
    cola.cuda()
    similarity.cuda()
    return cola, similarity

cola, similarity = load_model(COLA_PATH, SIMILARITY_PATH, hidden_size=1280, output_size=1024)



## Data

In [65]:
import os
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import librosa
from data.utils import collate_audio_data
from data.audioset import process_audio


class TestSet(Dataset):
    def __init__(self, input_triples, generated_files):
        super(TestSet, self).__init__()
        if len(input_triples) != len(generated_files):
            raise ValueError(f"Input pairs and generated files lengths do not match: ({len(input_triples), len(generated_files)}")
        self.melody_paths = list(input_triples[0])
        self.style_paths = list(input_triples[1])
        self.ground_paths = list(input_triples[2])
        self.generated_paths = list(generated_files[0])
            
    def __len__(self):
        return len(self.melody_paths)
    
    def __getitem__(self, ix):
        melody_path = self.melody_paths[ix]
        style_path = self.style_paths[ix]
        ground_path = self.ground_paths[ix]
        generated_path = self.generated_paths[ix]
        
        melody_stft = process_audio(melody_path)
        style_stft = process_audio(style_path)
        ground_stft = process_audio(ground_path)
        generated_stft = process_audio(generated_path)
        
        return (melody_path, style_path, ground_path, generated_path), (melody_stft, style_stft, ground_stft, generated_stft)
    
def extract_segment(audio_stft_batch, segment_len=96, start_frame=0):
    assert len(audio_stft_batch) == 1, "This function should only be used for a single audio batch"
    return audio_stft_batch[:, :, start_frame:start_frame+segment_len]

## Calculate similarities

In [66]:
def evaluate_generative_model(segment_M, segment_S, segment_Ground, segment_Gen, cola, similarity):
    y_anchors = cola(segment_M.cuda())
    y_positives = cola(segment_Gen.cuda())
    melody_similarity = similarity(y_anchors, y_positives).item()
    
    
    y_anchors = cola(segment_S.cuda())
    y_positives = cola(segment_Gen.cuda())
    style_similarity = similarity(y_anchors, y_positives).item()
    
    y_anchors = cola(segment_Ground.cuda())
    y_positives = cola(segment_Gen.cuda())
    ground_similarity = similarity(y_anchors, y_positives).item()
    
    return melody_similarity, style_similarity, ground_similarity

## Val2 Data

In [67]:
import os
import pandas as pd

GENERATED_PREFIX = '/home/user/ss-vq-vae/experiments/outputs/model-original-no-style-pretraining-19-11-2023/val2'

input_triples = pd.read_csv('/mnt/vdb/validation_set_2.csv', header=None, sep='\t')
generated_list = pd.read_csv(f'{GENERATED_PREFIX}/vqvae_list', header=None).applymap(lambda path: f"{GENERATED_PREFIX}/{path}")


assert os.path.exists(input_triples.iloc[0][0]), "The input pairs file contains non-valid paths"
assert os.path.exists(generated_list.iloc[0][0]), "The generated file contains non-valid paths"

In [68]:
from torch.utils.data import DataLoader

val2_dataset = TestSet(input_triples, generated_list)
val2_loader = DataLoader(val2_dataset, shuffle=False, batch_size=1)

In [80]:
import os
import itertools
from tqdm import tqdm
import numpy as np
from typing import List


DATASETS = {
    'val2': '/mnt/vdb/validation_set_2.csv',
    'test': '/mnt/vdb/test_set.csv',
}

MODELS = {
    # TODO
    'model-original-no-style-pretraining-19-11-2023': '/home/user/ss-vq-vae/experiments/outputs/model-original-no-style-pretraining-19-11-2023/',
    'model-original-frozen-style-pretraining-21-11-2023': '/home/user/ss-vq-vae/experiments/outputs/model-original-frozen-style-pretraining-21-11-2023/'
    
}

MODELS = {
    model_name: os.path.join("/home/user/ss-vq-vae/experiments/outputs/", model_name)
    for model_name in [
        'model-original-no-style-pretraining-19-11-2023',
        'model-original-frozen-style-pretraining-21-11-2023',
        'model-original-finetuned-style-pretraining-22-11-2023',
        'model-leaky-relu-no-style-pretraining-13-11-2023',
        'model-leaky-relu-frozen-style-pretraining-15-11-2023',
        'model-leaky-relu-finetuned-style-pretraining-15-11-2023'
    ]
}


def empty_results_df():
    return pd.DataFrame(columns=[
        "model", "dataset", 
        "melody_path", "style_path", "ground_path", "generated_path", 
        "melody_similarity", "style_similarity", "ground_similarity"]
    )


def run_on_dataset_model(dataset_name, model, cola, similarity):
    input_triples = pd.read_csv(DATASETS[dataset_name], header=None, sep='\t')
    
    generated_prefix = MODELS[model]
    if dataset_name == 'val2':
        generated_prefix = os.path.join(generated_prefix, "val2/")
    generated_list = pd.read_csv(f'{generated_prefix}/vqvae_list', header=None).applymap(lambda path: f"{generated_prefix}/{path}")
    
    # Create dataset
    dataset = TestSet(input_triples, generated_list)
    dataset_loader = DataLoader(dataset, shuffle=False, batch_size=1)
        
    results = empty_results_df()
    pbar = tqdm(dataset_loader)
    pbar.set_description(f"Processing dataset: {dataset_name}, model: {model}")
    for paths, stfts in pbar:
        m_path, s_path, ground_path, gen_path = paths
        m_stft, s_stft, ground_stft, gen_stft = stfts
        m_segment = extract_segment(m_stft, segment_len=128)
        s_segment = extract_segment(s_stft, segment_len=128)
        ground_segment = extract_segment(ground_stft, segment_len=128)
        gen_segment = extract_segment(gen_stft, segment_len=128)

        with torch.no_grad():
            melody_similarity, style_similarity, ground_similarity = evaluate_generative_model(m_segment, s_segment, ground_segment, gen_segment, cola, similarity)

        row = {
            "model": model,
            "dataset": dataset_name,
            "melody_path": m_path[0],
            "style_path": s_path[0],
            "ground_path": ground_path[0],
            "generated_path": gen_path[0],
            "melody_similarity": melody_similarity,
            "style_similarity": style_similarity,
            "ground_similarity": ground_similarity,
        }
        results = pd.concat([results, pd.DataFrame([row])], ignore_index=True)
        
    return results

def run_style_evaluations(datasets: List[str], models: List[str], cola: COLA, similarity: BilinearSimilarity) -> pd.DataFrame:
    results = empty_results_df()
    
    for dataset, model in itertools.product(datasets, models):
        try:
            results_dataset_model = run_on_dataset_model(dataset, model, cola, similarity)
            results = pd.concat([results, results_dataset_model], ignore_index=True)
        except Exception as e:
            print(f"Exception raised in while processing dataset {dataset} and model {model}: {e}")
            print("Moving onto the next one...")
        
    return results


In [81]:
results = run_style_evaluations(DATASETS.keys(), MODELS.keys(), cola, similarity)

Processing dataset: val2, model: model-original-no-style-pretraining-19-11-2023: 100%|██████████| 845/845 [01:13<00:00, 11.45it/s]
Processing dataset: val2, model: model-original-frozen-style-pretraining-21-11-2023: 100%|██████████| 845/845 [01:13<00:00, 11.47it/s]
Processing dataset: val2, model: model-original-finetuned-style-pretraining-22-11-2023: 100%|██████████| 845/845 [01:21<00:00, 10.41it/s]
Processing dataset: val2, model: model-leaky-relu-no-style-pretraining-13-11-2023: 100%|██████████| 845/845 [01:26<00:00,  9.83it/s]
Processing dataset: val2, model: model-leaky-relu-frozen-style-pretraining-15-11-2023: 100%|██████████| 845/845 [01:13<00:00, 11.48it/s]
Processing dataset: val2, model: model-leaky-relu-finetuned-style-pretraining-15-11-2023: 100%|██████████| 845/845 [01:19<00:00, 10.66it/s]
Processing dataset: test, model: model-original-no-style-pretraining-19-11-2023: 100%|██████████| 1661/1661 [02:28<00:00, 11.21it/s]
Processing dataset: test, model: model-original-froze

In [82]:
results.head()

Unnamed: 0,model,dataset,melody_path,style_path,ground_path,generated_path,melody_similarity,style_similarity,ground_similarity
0,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_719_084...,/mnt/vdb/random_audios_patch_16k//data_850_042...,/mnt/vdb/random_audios_patch_16k//data_719_042...,/home/user/ss-vq-vae/experiments/outputs/model...,-520.356628,-591.278198,-328.321503
1,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_952_023...,/mnt/vdb/random_audios_patch_16k//data_212_045...,/mnt/vdb/random_audios_patch_16k//data_952_045...,/home/user/ss-vq-vae/experiments/outputs/model...,-370.708008,-480.811462,-303.603668
2,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_378_040...,/mnt/vdb/random_audios_patch_16k//data_361_042...,/mnt/vdb/random_audios_patch_16k//data_378_042...,/home/user/ss-vq-vae/experiments/outputs/model...,-339.902649,-352.415527,-387.449768
3,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_984_015...,/mnt/vdb/random_audios_patch_16k//data_712_049...,/mnt/vdb/random_audios_patch_16k//data_984_049...,/home/user/ss-vq-vae/experiments/outputs/model...,-442.523499,-235.312454,-224.609222
4,model-original-no-style-pretraining-19-11-2023,val2,/mnt/vdb/random_audios_patch_16k//data_558_068...,/mnt/vdb/random_audios_patch_16k//data_582_044...,/mnt/vdb/random_audios_patch_16k//data_558_044...,/home/user/ss-vq-vae/experiments/outputs/model...,-397.747162,-348.875916,-267.944885


In [83]:
results.to_csv('results-07-07-2024.csv')

In [85]:
results_copy = results.copy()

In [90]:
# Function to handle the log-mean-log transformation
def log_mean_exp(df, groupby_cols, transform_cols):
    # Copy the groupby columns to ensure they are available for grouping after transformation
    exp_df = df[groupby_cols + transform_cols].copy()
    # Exponentiate the log values
    exp_df[transform_cols] = np.exp(exp_df[transform_cols])
    # Compute the mean of the exponentiated values grouped by specified columns
    mean_exp_df = exp_df.groupby(by=groupby_cols).mean()
    # Log the mean values
    log_mean_df = np.log(mean_exp_df)
    return log_mean_df

# Specify the columns to group by and to transform
groupby_cols = ['model', 'dataset']
transform_cols = ['melody_similarity', 'style_similarity', 'ground_similarity']

# Apply the transformation
log_mean_results = log_mean_exp(results, groupby_cols, transform_cols)

# Optional: Display the result
log_mean_results

Unnamed: 0_level_0,Unnamed: 1_level_0,melody_similarity,style_similarity,ground_similarity
model,dataset,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
model-leaky-relu-finetuned-style-pretraining-15-11-2023,test,-72.438705,-50.494162,-42.949407
model-leaky-relu-finetuned-style-pretraining-15-11-2023,val2,-74.72542,-69.786165,-80.643193
model-leaky-relu-frozen-style-pretraining-15-11-2023,test,-30.979037,-51.386755,-50.912211
model-leaky-relu-frozen-style-pretraining-15-11-2023,val2,-47.402041,-61.44234,-63.051549
model-leaky-relu-no-style-pretraining-13-11-2023,test,-31.128218,-64.07383,-54.77398
model-leaky-relu-no-style-pretraining-13-11-2023,val2,-77.771801,-67.675484,-40.595294
model-original-finetuned-style-pretraining-22-11-2023,test,-63.622674,-48.106265,-54.325623
model-original-finetuned-style-pretraining-22-11-2023,val2,-80.095247,-45.504458,-79.357689
model-original-frozen-style-pretraining-21-11-2023,test,-47.775841,-53.066908,-59.478079
model-original-frozen-style-pretraining-21-11-2023,val2,-70.322937,-42.699336,-77.535152
