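Correlation plots
=================

For each audio effect and embedding model, this notebook does the following for one instrument:

- projects that instrument's training embeddings onto the corresponding deformation direction;
- plots the projection against the rank of the effect parameter;
- reports the squared Spearman correlation ($R^2$) between the projection and the parameter value.

The $R^2$ values are recorded in `results/correlations.csv`. The figure is saved to `plots/corr/corr_<inst>.png`.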

In [1]:
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import deem
import pickle
import pandas as pd
import scipy
from effect_params import effect_params_dict, effect_params_str_dict
import os
plt.rcParams['figure.constrained_layout.use'] = True

**In the cell below, select the instrument by setting the `inst` variable.**

Available instruments (use the exact spelling):
- clarinet
- organ
- cello
- violin
- guitar_acc
- voice
- guitar_ele
- saxophone
- trumpet
- piano
- flute

In [2]:
# Embedding models to evaluate, mapped to the display names used in plot labels.
# The evaluation order (and the subplot column order) is the dict's insertion order.
embedding_human_names = dict(
    openl3="OpenL3",
    panns="PANNs",
    clap="CLAP",
)
embedding_names = list(embedding_human_names)

# Audio effects to evaluate, mapped to the display names used in plot labels.
# The evaluation order (and the subplot row order) is the dict's insertion order.
effect_human_names = dict(
    bitcrush="Bitcrushing",
    gain="Gain",
    lowpass_cheby="Low-pass filtering",
    reverb="Reverberation",
)
effects = list(effect_human_names)

# Desensitization method stored alongside each saved correlation result.
desensitization_method = "cca"

# Instrument to analyse -- must be one of the names in the list above.
inst = "cello"

In [3]:
# Metadata table from train_test_split.csv; passed to deem.load_feature
# whenever an embeddings file is loaded below.
meta_all = pd.read_csv(filepath_or_buffer="train_test_split.csv")

In [4]:
def load_results(path="results/correlations.csv"):
    """Load previously saved correlation results.

    Args:
        path: CSV file in the format written by ``save_results``.

    Returns:
        dict mapping ``(embedding, effect, instrument, desensitization_method)``
        to ``(spearman_corr2, comment)``. Returns an empty dict if ``path``
        does not exist.
    """
    if not os.path.isfile(path):
        return {}
    # Empty comment cells are parsed as float NaN by default. Read the column as
    # text and replace missing cells with "" so comments always round-trip as str.
    results_csv = pd.read_csv(path, dtype={"comment": str})
    results_csv["comment"] = results_csv["comment"].fillna("")
    return {
        (row.embedding, row.effect, row.instrument, row.desensitization_method): (row.spearman_corr2, row.comment)
        for row in results_csv.itertuples(index=False)
    }

def save_results(results_dict: dict, path="results/correlations.csv"):
    """Write correlation results to CSV, overwriting any existing file.

    Args:
        results_dict: mapping
            ``(embedding, effect, instrument, desensitization_method)`` ->
            ``(spearman_corr2, comment)``, as returned by ``load_results``.
        path: destination CSV. Its parent directory is created if missing.
    """
    columns = ["embedding", "effect", "instrument", "desensitization_method", "spearman_corr2", "comment"]
    # Explicit unpacking keeps the original shape check: malformed keys/values raise ValueError.
    rows = [
        (embedding, effect, instrument, method, corr2, comment)
        for (embedding, effect, instrument, method), (corr2, comment) in results_dict.items()
    ]
    df = pd.DataFrame(rows, columns=columns)
    # Create the output directory so the first save on a fresh checkout does not fail.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    df.to_csv(path, index=False)

# Previously saved correlations and their hand-written comments, keyed by
# (embedding, effect, instrument, desensitization_method); empty if no CSV exists yet.
results_dict = dict(load_results())

In [None]:
def _load_deformdir(effect, embedding_name, inst, method):
    """Return the deformation direction of `inst` for one (effect, embedding) pair.

    The pickled dict is assumed to be keyed by "-" + method name (e.g. "-cca"),
    then by instrument -- TODO confirm for methods other than CCA.
    NOTE(review): pickle.load can execute arbitrary code; only load model
    files you produced yourself.
    """
    with open(f"models/{effect}/deformdir_{embedding_name}.pkl", "rb") as f:
        deformdirs = pickle.load(f)
    # Derive the key from the configured method, so the plotted direction and the
    # results key saved below always refer to the same desensitization method.
    return deformdirs["-" + method][inst]


def _project_effect_embeddings(effect, embedding_name, inst, deformdir):
    """Project the training embeddings of `inst` onto `deformdir` for each parameter of `effect`.

    Returns three equal-length lists with one entry per sample:
    - the projection sum(x * deformdir);
    - the parameter rank scaled as iparam / n_params;
    - the parameter's numeric value.
    """
    param_strs = effect_params_str_dict[effect]
    param_floats = effect_params_dict[effect]
    projections, ranks, values = [], [], []
    # leave=False: this bar runs once per subplot; don't keep every finished bar on screen.
    for iparam, (param_str, param_float) in enumerate(zip(tqdm(param_strs, leave=False), param_floats)):
        embeddings_fn = f"embeddings/embeddings_{effect}_{param_str}.h5"
        (X_train, Y_train), _, _ = deem.load_feature(embeddings_fn, embedding_name, meta_all)
        X_train_inst = X_train[Y_train == inst]
        n_inst = len(X_train_inst)
        projections.extend(np.sum(X_train_inst * deformdir, axis=-1))
        # Scaled by n_params (not n_params - 1), so the highest rank is slightly below 1.
        ranks.extend([iparam / len(param_strs)] * n_inst)
        values.extend([param_float] * n_inst)
    return projections, ranks, values


# Colorbar (tick positions, tick labels, axis label) for each effect.
_COLORBAR_SPECS = {
    "lowpass_cheby": (np.arange(2000, 20000, 4000), [str(k) for k in np.arange(2, 20, 4)], "\"Cutoff\" frequency (kHz)"),
    "gain": ([-40, -30, -20, -10, 0], ["-40.0", "-30.0", "-20.0", "-10.0", "0.0"], "Gain (dB)"),
    "bitcrush": ([5, 8, 11, 14], ["5", "8", "11", "14"], "Bit depth"),
    "reverb": ([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], ["0.0", "0.2", "0.4", "0.6", "0.8", "1.0"], "Room size"),
}


def _style_colorbar(cbar, effect):
    """Set the ticks and label of `cbar` for `effect`; raise AssertionError for an unknown effect."""
    if effect not in _COLORBAR_SPECS:
        raise AssertionError("Unknown effect " + effect + "??")
    ticks, tick_labels, label = _COLORBAR_SPECS[effect]
    cbar.set_ticks(ticks, labels=tick_labels)
    cbar.set_label(label)


# Grid of scatter plots: one row per effect, one column per embedding model.
fig, axs = plt.subplots(len(effects), len(embedding_names), figsize=(7, 9), dpi=300, layout="constrained", squeeze=False)

for ieff, effect in enumerate(tqdm(effects)):
    for iemb, embedding_name in enumerate(tqdm(embedding_names, leave=False)):
        ax = axs[ieff, iemb]
        deformdir = _load_deformdir(effect, embedding_name, inst, desensitization_method)
        X_all, Y_all, color_all = _project_effect_embeddings(effect, embedding_name, inst, deformdir)

        scat = ax.scatter(X_all, Y_all, c=color_all, marker='.', cmap='twilight_shifted', s=3)

        if iemb == 0:
            ax.set_ylabel(effect_human_names[effect] + "\nParam. rank")
        if iemb == len(embedding_names) - 1:
            # One colorbar per row, attached to the last column.
            _style_colorbar(fig.colorbar(scat, ax=ax), effect)
        if ieff == len(effects) - 1:
            ax.set_xlabel("$\\langle u, \\Xi \\rangle$\n" + embedding_human_names[embedding_name])

        # Squared Spearman correlation between the projection and the parameter value.
        r2 = scipy.stats.spearmanr(X_all, color_all).statistic**2
        ax.set_title(f"$R^2 \\approx {r2:.4f}$")
        ax.ticklabel_format(style="sci", axis="x", scilimits=(-3, 3))

        # Update the correlation but keep any existing hand-written comment.
        # Saving after every panel preserves progress if a later panel fails.
        key = (embedding_name, effect, inst, desensitization_method)
        comment = results_dict[key][1] if key in results_dict else ""
        results_dict[key] = (r2, comment)
        save_results(results_dict)

# Save before showing, and make sure the output directory exists, so the figure
# is not lost after the slow loop above has finished.
os.makedirs("plots/corr", exist_ok=True)
fig.savefig(f"plots/corr/corr_{inst}.png")
plt.show()