UMAP visualization of embeddings
================================

In [1]:
import os

import h5py
import matplotlib.pyplot as plt
import deem
import numpy as np
import pandas as pd
from umap import UMAP
from tqdm.notebook import tqdm
from IPython.display import display
import ipywidgets as widgets
from ipywidgets import interactive
from effect_params import effect_params_dict
# Use matplotlib's constrained layout for every figure created in this notebook.
plt.rcParams.update({'figure.constrained_layout.use': True})

Load embeddings

In [2]:
# Only the second element of the first return value is kept; it is later
# compared against `inst`, so it presumably holds per-sample instrument labels.
meta_all = pd.read_csv("train_test_split.csv")
(_, Y_train), _, _ = deem.load_feature("embeddings/embeddings.h5", "openl3", meta_all)


# Effects (figure columns): display names first, then the plotting order.
effect_human_names = {
    "lowpass_cheby": "Low-pass filtering",
    "gain": "Gain",
    "reverb": "Reverberation",
    "bitcrush": "Bitcrushing",
}
effects = ["lowpass_cheby", "gain", "reverb", "bitcrush"]


##### SELECT INSTRUMENT #####
instrument_list = list(deem.instrument_map)
inst = "flute"


# Embeddings (figure rows): display names first, then the plotting order.
embedding_human_names = {
    "openl3": "OpenL3",
    "panns": "PANNs",
    "clap": "CLAP",
}
embedding_list = ["openl3", "panns", "clap"]

In [3]:
# Shared state for the widget callbacks defined in the following cells.

# Positions (within the selected instrument's samples) that get plotted.
indices = [0]

# `interactive` widgets whose own output areas are cleared on every replot.
dropdowns: "list[interactive]" = []

# Output area the figure is rendered into.
output = widgets.Output()

In [4]:
def replot(*args):
    """Redraw the UMAP grid (one row per embedding, one column per effect).

    Intended as the click handler of the "Replot" button; positional
    arguments (the button instance passed by ipywidgets) are ignored.

    For every embedding/effect pair, the file
    ``embeddings/grouped/{emb}/{eff}.h5`` is read, restricted to the samples
    whose label in ``Y_train`` equals ``inst``, and a fresh ``UMAP`` with
    ``n_neighbors=3`` is fitted on the ``train_Bs`` rows of the selected
    samples. The ``train_A`` rows of the selected samples are drawn as black
    crosses, and each sample's ``train_Bs`` rows as dots coloured by
    ``effect_params_dict[eff]``. The figure is shown inside ``output`` and
    written to ``plots/umap/umap_{inst}_{indices}.pdf``.

    Reads the module-level ``inst``, ``indices``, ``Y_train``, ``effects``,
    ``embedding_list``, ``dropdowns`` and ``output``.

    Raises:
        AssertionError: if ``effects`` contains an effect with no colorbar
            configuration.
    """
    for dropdown in dropdowns:
        dropdown.out.clear_output()

    # The instrument mask does not depend on the embedding or effect, so
    # compute it once instead of twice per panel.
    inst_mask = Y_train == inst
    n_available = int(np.sum(inst_mask))
    selected = np.asarray(indices, dtype=int)

    # Ticks, tick labels and axis label of the colorbar for each effect.
    colorbar_specs = {
        "lowpass_cheby": (
            np.arange(2000, 20000, 4000),
            [str(k) for k in np.arange(2, 20, 4)],
            "\"Cutoff\" frequency (kHz)",
        ),
        "gain": (
            [-40, -30, -20, -10, 0],
            ["-40.0", "-30.0", "-20.0", "-10.0", "0.0"],
            "Gain (dB)",
        ),
        "bitcrush": (
            [5, 8, 11, 14],
            ["5", "8", "11", "14"],
            "Bit depth",
        ),
        "reverb": (
            [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
            ["0.0", "0.2", "0.4", "0.6", "0.8", "1.0"],
            "Room size",
        ),
    }

    with output:
        output.clear_output()

        # `indices` is only refreshed by "Shuffle indices"; after switching to
        # an instrument with fewer samples it can point past the end of the
        # data. Report that clearly instead of failing with an IndexError.
        if selected.size == 0 or selected.max() >= n_available:
            print(
                f"Selected indices {selected.tolist()} are not valid for instrument "
                f"'{inst}' ({n_available} samples). "
                "Click \"Shuffle indices\" and replot."
            )
            return

        # squeeze=False keeps `axs` two-dimensional even when there is only
        # one embedding or one effect, so `axs[row, col]` always works.
        fig, axs = plt.subplots(
            len(embedding_list), len(effects),
            figsize=(10, 7), dpi=300, layout="constrained", squeeze=False,
        )

        # Last scatter artist drawn per effect column; feeds the colorbars.
        image_per_eff = [None] * len(effects)

        for iemb, emb in enumerate(tqdm(embedding_list)):
            for ieff, eff in enumerate(tqdm(effects)):
                ax = axs[iemb, ieff]

                # The context manager closes the file even if indexing fails.
                with h5py.File(f"embeddings/grouped/{emb}/{eff}.h5", "r") as orig_h5:
                    X_train_A_inst_orig = orig_h5["train_A"][inst_mask]
                    # Swap the first two axes so that axis 0 can be masked by
                    # sample, like `train_A`.
                    paths_B_inst_orig = orig_h5["train_Bs"][:].swapaxes(0, 1)[inst_mask]

                # Fit the projection on all path points of the selected samples.
                reducer = UMAP(n_neighbors=3)
                paths = paths_B_inst_orig[selected]
                paths_flattened = paths.reshape((paths.shape[0] * paths.shape[1], paths.shape[2]))
                reducer.fit(paths_flattened)

                Y_orig = reducer.transform(X_train_A_inst_orig[selected])
                ax.scatter(Y_orig[:, 0], Y_orig[:, 1], marker='x', color='k', zorder=10)
                for n in selected:
                    Y = reducer.transform(paths_B_inst_orig[n])
                    image_per_eff[ieff] = ax.scatter(Y[:, 0], Y[:, 1], c=effect_params_dict[eff], marker='.', cmap='coolwarm')

                if iemb == 0:
                    ax.set_title(effect_human_names[eff])
                if ieff == 0:
                    ax.set_ylabel(embedding_human_names[emb])

        # One colorbar per effect column, attached below the bottom row.
        for ieff, eff in enumerate(effects):
            if eff not in colorbar_specs:
                raise AssertionError("Unknown effect " + eff + "??")
            ticks, tick_labels, label = colorbar_specs[eff]
            cbar = fig.colorbar(image_per_eff[ieff], ax=axs[-1, ieff], location="bottom")
            cbar.set_ticks(ticks, labels=tick_labels)
            cbar.set_label(label)

        plt.show()
        # Create the target directory so the save cannot fail on a fresh checkout.
        os.makedirs("plots/umap", exist_ok=True)
        fig.savefig(f"plots/umap/umap_{inst}_{'_'.join(map(str, indices))}.pdf")


# Default number of samples to plot; updated through the slider callback.
num_samples = 1

# Button that triggers a full redraw of the UMAP figure.
replot_button = widgets.Button(description="Replot")
replot_button.on_click(replot)

def set_num_samples(n):
    """Store ``n`` in the module-level ``num_samples``.

    Used as the callback of the "Number of samples" slider; the parameter
    name must stay ``n`` because the slider is bound to it by keyword.
    """
    global num_samples
    num_samples = n

# Slider (1-20) whose value is forwarded to `set_num_samples` as `n`.
num_samples_slider = interactive(
    set_num_samples,
    n=widgets.IntSlider(
        value=num_samples,
        min=1,
        max=20,
        step=1,
        description="Number of samples",
        style={"description_width": "150px"},
    ),
)

def shuffle_indices(*args):
    """Draw a new random set of sample positions for the current instrument.

    Intended as the click handler of the "Shuffle indices" button; positional
    arguments are ignored. Picks ``num_samples`` distinct positions among the
    entries of ``Y_train`` equal to ``inst`` and stores them in the
    module-level ``indices``. If fewer than ``num_samples`` entries exist,
    all of them are selected (in random order) instead of raising.

    Raises:
        ValueError: if no entry of ``Y_train`` equals ``inst``.
    """
    global indices
    n_available = int(np.sum(Y_train == inst))
    if n_available == 0:
        raise ValueError(f"No samples found for instrument {inst!r}")
    # Sampling without replacement cannot return more items than exist, so
    # cap the request at the number of available samples.
    n_draw = min(num_samples, n_available)
    indices = np.random.choice(n_available, size=n_draw, replace=False)

shuffle_button = widgets.Button(description="Shuffle indices")
shuffle_button.on_click(shuffle_indices)

Usage:  
In the cell below:
1. Select the instrument you are interested in
2. Select the number of samples to plot using the slider
3. Click "Shuffle indices"
4. Click "Replot"
5. Wait for the plots to display; this takes some time

Repeat steps 3-5 to display the paths of other randomly selected samples.

If you change instrument, rerun steps 1 and 3 before replotting.

In [5]:
def set_instrument(i):
    """Store ``i`` in the module-level ``inst``.

    Used as the callback of the instrument dropdown; the parameter name must
    stay ``i`` because the dropdown is bound to it by keyword. This does not
    touch ``indices``, which still refer to the previously selected
    instrument until they are reshuffled.
    """
    global inst
    inst = i

# Dropdown that forwards the chosen instrument to `set_instrument`.
inst_dropdown = interactive(set_instrument, i=instrument_list)
inst_selector = inst_dropdown.children[0]
inst_selector.description = "Instrument"
inst_selector.style = {"description_width": "150px"}
dropdowns.append(inst_dropdown)

# Show the controls top to bottom, followed by the area the figure is drawn in.
button_row = widgets.HBox((replot_button, shuffle_button))
for control in (inst_dropdown, button_row, num_samples_slider, output):
    display(control)

interactive(children=(Dropdown(description='Instrument', options=('clarinet', 'organ', 'cello', 'violin', 'gui…

HBox(children=(Button(description='Replot', style=ButtonStyle()), Button(description='Shuffle indices', style=…

interactive(children=(IntSlider(value=1, description='Number of samples', max=20, min=1, style=SliderStyle(des…

Output()