Computation of the sample-wise CCA correlation coefficients
===========================================================

In [1]:
import h5py
import matplotlib.pyplot as plt
from IPython.display import display
import ipywidgets as widgets
from ipywidgets import interactive
import numpy as np
import scipy
from tqdm import tqdm
from effect_params import effect_params_str_dict

In [2]:
effects = ["bitcrush", "gain", "lowpass_cheby", "reverb"]
effect = effects[0]

def update_params():
    """Set the global ``params`` to the parameter strings of the current ``effect``."""
    global params
    params = effect_params_str_dict[effect]

# Initialize ``params`` right away instead of relying on a widget callback to
# do it, so cells that call replot() work even if the dropdowns were not built.
update_params()

embeddings = ["openl3", "panns", "clap"]
embedding = embeddings[0]

# Open handle to the CCA directions of the current (embedding, effect) pair.
# replot() reads from it on every call; the effect/embedding setters replace it.
ccadirs_h5 = h5py.File(f"embeddings/averaged/{embedding}/ccadirs_{effect}.h5", "r")
num_samples = ccadirs_h5["cca_dirs"].shape[0]

isample = 0  # index of the sample currently shown by the interactive plot

In [3]:
output = widgets.Output()

def replot(plot=True, isample_override=None):
    """Project one sample's per-parameter embeddings onto its CCA direction.

    Uses the module globals ``effect``, ``embedding``, ``params`` and
    ``ccadirs_h5``. For every parameter value in ``params``, the embedding of
    the chosen sample is loaded and its inner product with that sample's CCA
    direction is taken. The squared Spearman rank correlation between these
    projections and the numeric parameter values is then computed.

    Args:
        plot: If True, draw a scatter plot into ``output`` and print R^2 there.
            If False, skip plotting and return R^2.
        isample_override: Sample index to use instead of the global ``isample``
            (which the slider controls). ``None`` means use ``isample``.

    Returns:
        The squared Spearman correlation if ``plot`` is False, otherwise None.
    """
    actual_isample = isample if isample_override is None else isample_override
    ccadir = ccadirs_h5["cca_dirs"][actual_isample]

    X_all = []      # projection of each parameter's embedding onto the CCA direction
    Y_all = []      # parameter index scaled to [0, 1); only used as the plot's y-axis
    color_all = []  # numeric parameter value

    for iparam, param in enumerate(params):
        with h5py.File(f"embeddings/averaged/{embedding}/embeddings_{effect}_{param}.h5", "r") as embeddings_h5:
            embedding_sample = embeddings_h5["X_train"][actual_isample]
        X_all.append(np.sum(embedding_sample * ccadir))
        Y_all.append(iparam / len(params))
        color_all.append(float(param))

    if plot:
        with output:
            output.clear_output(wait=True)

            plt.scatter(X_all, Y_all, c=color_all, marker='.', cmap='twilight_shifted', s=3)
            plt.colorbar()
            # Label with the sample actually plotted, which differs from the
            # slider's ``isample`` when an override is given.
            plt.title(f"Correlation of embeddings of {actual_isample}-th sample with corresponding CCA direction ({effect})")
            plt.show()

    r2 = scipy.stats.spearmanr(X_all, color_all).statistic ** 2
    if plot:
        with output:
            print(r2)
    else:
        return r2

In [4]:
def _switch_ccadirs(new_embedding, new_effect):
    """Make the CCA-direction file for (new_embedding, new_effect) the open one.

    The new file is opened *before* the old handle is closed or any global is
    changed. If the open fails (e.g. the file is missing), the previous handle
    and globals stay consistent. ``num_samples`` is refreshed from the new
    file so code iterating over all samples uses the right range.
    """
    global ccadirs_h5, num_samples
    # NOTE(review): the sample slider's max is fixed when the slider is created
    # and is not updated here. Confirm that every (embedding, effect) file has
    # the same number of samples, or the slider can pick an out-of-range index.
    new_h5 = h5py.File(f"embeddings/averaged/{new_embedding}/ccadirs_{new_effect}.h5", "r")
    ccadirs_h5.close()
    ccadirs_h5 = new_h5
    num_samples = ccadirs_h5["cca_dirs"].shape[0]

def set_effect(e):
    """Dropdown callback: switch to audio effect ``e`` and redraw the plot."""
    global effect
    _switch_ccadirs(embedding, e)
    effect = e
    update_params()  # the parameter list depends on the effect
    replot()

effect_dropdown = interactive(set_effect, e=widgets.Dropdown(options=effects, description="Effect"))

def set_embedding(e):
    """Dropdown callback: switch to audio embedding ``e`` and redraw the plot."""
    global embedding
    _switch_ccadirs(e, effect)
    embedding = e
    replot()

embedding_dropdown = interactive(set_embedding, e=widgets.Dropdown(options=embeddings, description="Embedding"))

def set_sample(i):
    """Slider callback: make sample ``i`` the current one and redraw it."""
    global isample
    isample = i
    replot()

_sample_widget = widgets.IntSlider(
    value=isample,
    min=0,
    max=num_samples - 1,
    step=1,
    description="Sample #",
)
sample_slider = interactive(set_sample, i=_sample_widget)

In [None]:
# Show the three controls followed by the plot/R^2 output area, in that order.
for _widget in (effect_dropdown, embedding_dropdown, sample_slider, output):
    display(_widget)

In [None]:
# Run this cell to compute the R² (squared Spearman correlation) of every sample
# for the selected audio embedding and audio effect, and write the values to a
# text file, one per line.
from pathlib import Path

out_path = Path("results") / effect / f"cca_r2s_{embedding}.txt"
# Create the output directory before the (slow) loop, so a missing folder
# cannot throw away all the computed values at the very end.
out_path.parent.mkdir(parents=True, exist_ok=True)

r2s = [replot(plot=False, isample_override=i) for i in tqdm(range(num_samples))]
with open(out_path, "w") as f:
    for r2 in r2s:
        print(r2, file=f)