SVD of sample-wise CCA directions for a fixed instrument
========================================================

## Imports and parameters

In [6]:
from pathlib import Path

import h5py
import matplotlib.axes as maxes
import matplotlib.pyplot as plt
import numpy as np

This notebook displays several plots in one row. The cell below controls which combinations of audio embeddings, audio effects and instruments are displayed in the plots.

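The number of plots can be different from 3, as long as `embeddings`, `effects` and `instruments` all have the same length: the i-th plot combines the i-th entry of each list. If you change the number of plots, make sure to update the figure size in the last cell accordingly.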

In [7]:
# Display names used in the axis labels, keyed by the identifiers that also
# appear in the on-disk file paths.
embedding_human_names = dict(
    clap="CLAP",
    panns="PANNs",
    openl3="OpenL3",
)
effect_human_names = dict(
    gain="gain",
    reverb="reverb.",
    lowpass_cheby="low-pass",
)

# One entry per plot: plot i combines embeddings[i], effects[i] and instruments[i].
embeddings = ["clap", "panns", "openl3"]
effects = ["lowpass_cheby", "reverb", "gain"]
instruments = ["cello", "cello", "cello"]

## Performing SVD and computing the plots

In [9]:
def replot_perinst(ax: maxes.Axes, embedding: str, effect: str, instrument: str, show_legend: bool):
    """Plot normalized singular-value spectra for one embedding/effect/instrument.

    Two curves are drawn on ``ax``, both restricted to the training samples
    whose label equals ``instrument``:

    * the singular values of the stored CCA directions for ``effect``;
    * the singular values of the training embeddings after subtracting their
      mean over the selected samples.

    Each spectrum is divided by its first singular value (the largest one, as
    ``np.linalg.svd`` returns them in descending order), so both curves start
    at 1.

    Parameters
    ----------
    ax
        Axes to draw on.
    embedding
        Embedding identifier; selects ``embeddings/averaged/{embedding}/`` and
        the display name in ``embedding_human_names``.
    effect
        Effect identifier; selects ``ccadirs_{effect}.h5`` and the display name
        in ``effect_human_names``.
    instrument
        Label compared against ``Y_train`` to select samples.
    show_legend
        Whether to draw the legend on this axes.

    Raises
    ------
    ValueError
        If no training sample is labelled ``instrument``.
    """
    with h5py.File(f"embeddings/averaged/{embedding}/embeddings.h5", "r") as f:
        X_train = f["X_train"][...]
        Y_train = np.array(f["Y_train"][...], dtype=str)

    sample_indices = np.argwhere(Y_train == instrument)[:, 0]
    if sample_indices.size == 0:
        # Without this guard the failure would be an opaque IndexError on S[0].
        raise ValueError(
            f"No training samples labelled {instrument!r} for embedding {embedding!r}"
        )
    X_train_inst = X_train[sample_indices]

    # Read only the selected rows, and close the file as soon as they are in
    # memory (the handle was previously left open on every call).
    with h5py.File(f"embeddings/averaged/{embedding}/ccadirs_{effect}.h5", "r") as ccadirs_h5:
        ccadirs_inst = ccadirs_h5["cca_dirs"][sample_indices]

    # Only the singular values are plotted, so skip computing U and Vh.
    S = np.linalg.svd(ccadirs_inst, compute_uv=False)
    S_train = np.linalg.svd(X_train_inst - np.mean(X_train_inst, axis=0), compute_uv=False)

    ax.plot(S / S[0], label="CCA directions", color="tab:blue")
    ax.plot(S_train / S_train[0], label="Original data", color="tab:orange")
    if show_legend:
        ax.legend()
    ax.set_xlabel(f"Singular vector #\n{embedding_human_names[embedding]}, {effect_human_names[effect]}")

## Displaying the plots

In [None]:
# The plots are built position-wise from the three parameter lists; zip()
# would silently drop plots if their lengths differed.
if not (len(embeddings) == len(effects) == len(instruments)):
    raise ValueError(
        "embeddings, effects and instruments must have the same length, got "
        f"{len(embeddings)}, {len(effects)} and {len(instruments)}"
    )

# squeeze=False always returns a 2-D array of Axes, so a single plot works too
# (by default plt.subplots returns a bare Axes object in that case).
fig, axs_grid = plt.subplots(
    1, len(embeddings), figsize=(6.5, 1.625), dpi=300, sharey=True, layout="constrained", squeeze=False
)
axs = axs_grid[0]

# Only the right-most plot carries the legend.
last_plot = len(axs) - 1
for i, (ax, embedding, effect, instrument) in enumerate(zip(axs, embeddings, effects, instruments)):
    replot_perinst(ax, embedding, effect, instrument, i == last_plot)

axs[0].set_ylabel("Normalized\nsing. val.")

# Save before showing: with non-inline backends the figure may already be
# torn down once plt.show() returns, which would yield an empty file.
output_path = Path("plots/svd_singvals_ccadirs/svd_singvals_ccadirs_paper.pdf")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight")

plt.show()