# Animate t-SNE of speaker embeddings across SSL training checkpoints

In [None]:
# Auto-reload imported modules from sslsv so edits to the package are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load sslsv as a package from the parent folder.
# The working directory is moved two levels up (presumably the repository root,
# since later cells use paths like './models/tests/simclr/...' — TODO confirm),
# then a path two levels above sys.path[0] is put on the import path.
# NOTE(review): if sys.path[0] is '' (i.e. "current directory"), this inserts a
# directory two levels above the *new* cwd rather than the repo root — confirm
# against the kernel's sys.path.
import os
import sys
os.chdir('../..')
sys.path.insert(1, os.path.join(sys.path[0], '../..'))

# Embed fonts when saving figures as PDF (Type 42 / TrueType instead of Type 3).
import matplotlib
matplotlib.rc('pdf', fonttype=42)

In [None]:
from notebooks.notebooks_utils import load_models, evaluate_models

from sv_visualization import _filter_embeddings

from sslsv.evaluations.CosineSVEvaluation import CosineSVEvaluation, CosineSVEvaluationTaskConfig

In [None]:
import numpy as np
import scipy
import pandas as pd

from sklearn.manifold import TSNE

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import seaborn as sns


from glob import glob
from tqdm import tqdm
from pathlib import Path

from IPython.display import Image, display

In [None]:
def generate_tsne_embeddings(config, checkpoints, nb_speakers=10, nb_samples=150, seed=0):
    """Compute one 2-D t-SNE layout of speaker embeddings per checkpoint.

    For every checkpoint, the model described by ``config`` is loaded with that
    checkpoint and evaluated with cosine scoring. The evaluation presumably
    populates ``model.embeddings``; verify in notebooks_utils. A subset of
    ``nb_speakers`` speakers x ``nb_samples`` samples is then selected with
    ``_filter_embeddings`` and projected to 2-D with t-SNE. Each layout is
    initialised from the previous checkpoint's layout, so consecutive layouts
    stay visually aligned.

    Args:
        config: Model config path passed to ``load_models``.
        checkpoints: Ordered checkpoint paths; only their file names are used.
        nb_speakers: Number of speakers to keep.
        nb_samples: Number of samples kept per speaker.
        seed: Seed of the random initial layout used for the first checkpoint
            (also passed to TSNE as ``random_state``).

    Returns:
        ``(embeddings, labels)``. ``embeddings`` has shape (N, S, 2), where N
        is the number of points and S is ``len(checkpoints)``. ``labels`` holds
        the speaker label of each point, taken from the first checkpoint.

    Raises:
        ValueError: if ``checkpoints`` is empty.
    """
    checkpoints = list(checkpoints)
    if not checkpoints:
        raise ValueError('generate_tsne_embeddings: no checkpoints given.')

    embeddings = []
    labels = []

    # TSNE uses an explicit array init as-is, so its random_state alone does
    # not make the first layout reproducible; seed the init itself.
    rng = np.random.default_rng(seed)
    init = rng.standard_normal((nb_speakers * nb_samples, 2))

    for checkpoint_path in tqdm(checkpoints, desc='Generating t-SNE of embeddings'):
        checkpoint_name = Path(checkpoint_path).name

        models = load_models(
            [config],
            checkpoint_name=checkpoint_name,
        )

        evaluate_models(
            models,
            CosineSVEvaluation,
            CosineSVEvaluationTaskConfig(__type__='sv_cosine'),
            verbose=False
        )

        # Only one config is loaded, so the first (presumably only) model is used.
        model = next(iter(models.values()))

        # NOTE(review): interpolation across checkpoints assumes
        # _filter_embeddings returns the same points in the same order for
        # every checkpoint — confirm it is deterministic.
        Z, y = _filter_embeddings(model.embeddings, nb_speakers=nb_speakers, nb_samples=nb_samples)
        Z_2d = TSNE(
            n_components=2,
            init=init,
            random_state=seed
        ).fit_transform(Z)

        # Warm-start the next checkpoint from this layout for temporal coherence.
        init = Z_2d

        embeddings.append(Z_2d)
        labels.append(y)

    embeddings = np.stack(embeddings).transpose(1, 0, 2)  # (S, N, 2) -> (N, S, 2)
    labels = np.asarray(labels[0])

    return embeddings, labels

In [None]:
def generate_interpolated_embeddings(embeddings, nb_frames):
    """Linearly interpolate per-checkpoint layouts into ``nb_frames`` frames.

    Args:
        embeddings: Array of shape (N, S, D), with N points, S checkpoints and
            D coordinates per point.
        nb_frames: Number of output frames F. Frames are sampled uniformly from
            the first to the last checkpoint, both included.

    Returns:
        Float array of shape (F, N, D). Every coordinate dimension is
        interpolated. With a single checkpoint, its layout is repeated F times.
    """
    # Import the submodule explicitly: `import scipy` alone does not
    # guarantee that `scipy.interpolate` is loaded on every SciPy version.
    from scipy.interpolate import interp1d

    embeddings = np.asarray(embeddings, dtype=float)
    _, S, _ = embeddings.shape

    if S == 1:
        # interp1d needs at least two samples; a single layout is simply held.
        return np.repeat(embeddings.transpose(1, 0, 2), nb_frames, axis=0)

    t = np.linspace(0, S - 1, nb_frames)
    # Interpolate along the checkpoint axis for all points and dims at once.
    res = interp1d(np.arange(S), embeddings, axis=1)(t)  # (N, F, D)

    return np.ascontiguousarray(res.transpose(1, 0, 2))  # (F, N, D)

In [None]:
def _checkpoint_step(checkpoint_path):
    """Return the training step encoded as ``__step_<int>.`` in a checkpoint path."""
    return int(checkpoint_path.split('__step_')[1].split('.')[0])


def _frame_to_checkpoint(frame, nb_total_frames, nb_checkpoints):
    """Return the index of the checkpoint closest to animation frame ``frame``.

    Frames are taken to be sampled uniformly from the first to the last
    checkpoint, both included. Frame ``i`` therefore sits at fractional
    checkpoint position ``i * (S - 1) / (F - 1)``. The result is clamped to a
    valid index.
    """
    if nb_total_frames <= 1 or nb_checkpoints <= 1:
        return 0
    position = frame * (nb_checkpoints - 1) / (nb_total_frames - 1)
    return min(int(position + 0.5), nb_checkpoints - 1)


def create_animation(interpolated_embeddings, labels, frame_interval, nb_frames, checkpoints):
    """Build a matplotlib animation of 2-D speaker embeddings over training.

    Args:
        interpolated_embeddings: Iterable of per-frame point arrays; frame
            ``Z`` provides x = ``Z[:, 0]`` and y = ``Z[:, 1]``.
        labels: Speaker label of each point, shared by all frames.
        frame_interval: Delay between frames in milliseconds
            (``FuncAnimation`` ``interval``).
        nb_frames: Requested number of frames. Kept for backward
            compatibility. The frame-to-checkpoint mapping uses the actual
            number of frames, so it never indexes past ``checkpoints``.
        checkpoints: Ordered checkpoint paths whose names contain
            ``__step_<int>.``; used to title each frame.

    Returns:
        The ``FuncAnimation``. As a side effect, the global matplotlib style
        is switched to ``dark_background``.
    """
    dpi = 100
    padding = 0.01

    # One long-format DataFrame per frame, as expected by seaborn.
    df_list = [
        pd.DataFrame({
            "Speaker": labels,
            "x": Z[:, 0],
            "y": Z[:, 1],
        })
        for Z in interpolated_embeddings
    ]
    nb_total_frames = len(df_list)

    # Identical for every frame: compute once instead of in each update.
    palette = sns.color_palette("hls", len(np.unique(labels)))
    steps = [_checkpoint_step(path) for path in checkpoints]

    def update(i, ax):
        ax.cla()
        sns.scatterplot(
            x="x",
            y="y",
            hue="Speaker",
            palette=palette,
            data=df_list[i],
            legend="full",
            alpha=0.6,
            ax=ax
        )

        step = steps[_frame_to_checkpoint(i, nb_total_frames, len(steps))]
        ax.set_title(f'Training iteration: {step}', fontsize=12)

        ax.axis('off')

        legend = ax.legend(loc='lower left')
        legend.set_title("Speaker", prop={'size': 11})
        legend.get_frame().set_linewidth(0)

    # Apply the style before creating the figure: style changes only affect
    # figures and artists created afterwards.
    plt.style.use("dark_background")

    # Pass dpi explicitly so the figure is 1920x1080 px at its own dpi,
    # independent of the backend's default figure dpi.
    fig = plt.figure(figsize=(1920 / dpi, 1080 / dpi), dpi=dpi)
    ax = fig.gca()
    anim = FuncAnimation(fig, update, frames=nb_total_frames, fargs=(ax,), interval=frame_interval)

    # Near full-bleed layout, leaving room at the top for the title.
    fig.subplots_adjust(left=padding, bottom=padding, right=1 - padding, top=0.95, wspace=None, hspace=None)

    return anim

In [None]:
checkpoints = glob('models/tests/simclr/model__step*')
if not checkpoints:
    raise FileNotFoundError("No checkpoints match 'models/tests/simclr/model__step*'.")
checkpoints = sorted(checkpoints, key=lambda f: int(f.split("__step_")[1].split(".")[0]))

config = './models/tests/simclr/config.yml'
output = "output.gif"

# Subsample the sorted checkpoints: every 5th of the first 80, every 50th up to
# index 480, every 100th after that, plus the final checkpoint.
checkpoints_ = (
    checkpoints[0:80:5]
    + checkpoints[80:80 * 6:50]
    + checkpoints[80 * 6::100]
    + [checkpoints[-1]]
)
# The final checkpoint may already come from the last slice; drop duplicates
# (order preserved) so it is not processed and rendered twice.
checkpoints = list(dict.fromkeys(checkpoints_))

nb_speakers = 10
nb_samples = 150
# 30 frames per transition between consecutive checkpoints; at least one frame
# so a single checkpoint still renders.
nb_frames = max(1, 30 * (len(checkpoints) - 1))
frame_interval = 30  # milliseconds between frames

print("Number of checkpoints:", len(checkpoints))
print("Number of frames:", nb_frames)
print("GIF duration:", nb_frames*frame_interval/1000)
print("FPS:", 1000*1/frame_interval)

In [None]:
# checkpoints = [
#     'model__step_780.pt',
#     'model__step_790.pt',
# ]
# nb_frames = 2
# frame_interval = 1000

In [None]:
# One t-SNE layout per checkpoint: `embeddings` is (N, S, 2).
embeddings, labels = generate_tsne_embeddings(
    config,
    checkpoints,
    nb_speakers,
    nb_samples
)

# Set to False to render one frame per checkpoint without interpolation
# (a quick way to inspect the raw t-SNE layouts).
interpolate = True
if interpolate:
    # (N, S, 2) -> (nb_frames, N, 2)
    interpolated_embeddings = generate_interpolated_embeddings(embeddings, nb_frames)
else:
    # (N, S, 2) -> (S, N, 2): exactly one frame per checkpoint.
    interpolated_embeddings = embeddings.transpose(1, 0, 2)
    nb_frames = len(checkpoints)

anim = create_animation(
    interpolated_embeddings,
    labels,
    frame_interval,
    nb_frames,
    checkpoints
)
anim.save(output, writer="pillow")

display(Image(output))