# Study of SSPS

In [None]:
# Enable IPython's autoreload extension so edits to project modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
import sys
# Move two directories up (presumably from notebooks/ssps/ to the repository
# root) so the relative 'models/...' paths used throughout resolve.
os.chdir('../..')
# Make project packages such as `notebooks.*` importable.
# NOTE(review): os.chdir does not update sys.path[0]; if it is '' (common in
# Jupyter) this inserts '../..' relative to the NEW cwd — confirm intent.
sys.path.insert(1, os.path.join(sys.path[0], '../..'))

In [None]:
from notebooks.ssps.utils import (
    evaluate_sv,
    plot_inter_speaker_center_similarity,
    plot_inter_class_similarity,
    plot_intra_class_similarity,
    plot_intra_class_similarity_by_class
)

from notebooks.evaluation.sv_visualization import (
    det_curve,
    scores_distribution,
    tsne_2D,
)

In [None]:
from dataclasses import dataclass
from typing import List, Dict, Optional

import torch


@dataclass
class Model:
    """Container for one model's results, passed to the visualization helpers.

    Every field defaults to ``None`` so a ``Model`` can carry only what a given
    plot needs (e.g. just ``embeddings`` for t-SNE, or ``scores``/``targets``
    for score distributions and DET curves).
    """

    # Verification trial scores.
    scores: Optional[List[float]] = None
    # Trial labels aligned with ``scores`` (presumably 1 = target, 0 = non-target — TODO confirm).
    targets: Optional[List[int]] = None
    # Embeddings keyed by utterance identifier; keys appear to be
    # '<speaker>/<video>/<segment>'-style paths (they are split on '/').
    embeddings: Optional[Dict[str, torch.Tensor]] = None


def get_models_for_visualization(scores, names=None):
    """Wrap SV evaluation results into ``Model`` objects for plotting.

    Args:
        scores: Mapping from model name to a dict holding at least the
            ``'scores'`` and ``'targets'`` entries (the structure produced by
            ``evaluate_sv`` in this notebook).
        names: Optional collection of model names to keep; ``None`` keeps
            every model in ``scores``.

    Returns:
        Dict mapping each selected name to a ``Model``, in the iteration order
        of ``scores``. Names in ``names`` that are absent from ``scores`` are
        silently ignored.
    """
    if names is None:
        names = scores.keys()

    return {
        name: Model(result['scores'], result['targets'])
        for name, result in scores.items()
        if name in names
    }

## Sampling hyper-params

### Stochastic sampling

In [None]:
import torch
import numpy as np
import pandas as pd
from plotnine import (
    ggplot,
    aes,
    geom_line,
    geom_point,
    labs,
    theme_bw,
    scale_x_continuous,
    scale_y_continuous,
    scale_color_discrete,
    geom_text,
    theme,
    element_text,
    element_blank,
    element_rect,
    guides,
    guide_legend
)


# Parameters
N = 10  # number of indices the distribution is defined over
decays = [0.2, 0.5, 1.0, 2.0]  # exponential decay rates to compare
# decays = [0.1, 0.4, 0.6, 0.8, 1.0, 1.2, 1.7, 2.5, 4.0]

# Generate data for each decay
data = []
for decay in decays:
    if decay == 0.0:
        # A zero decay rate degenerates to a uniform distribution.
        method = 'uniform'
        probs = torch.ones(N) / N
    else:
        # Exponential-decay weights over indices 0..N-1. The leading `decay`
        # factor cancels out in the normalization below.
        method = f'λ={decay}'
        probs = decay * torch.exp(-decay * torch.arange(N).float())
    probs = (probs / probs.sum()).numpy()

    # One row per (index, method) pair: long format for plotnine.
    for idx, prob in enumerate(probs):
        data.append({'Index': idx, 'Probability': prob, 'Method': method})

# Create the plot
df = pd.DataFrame(data)
p = (
    ggplot(df, aes(x='Index', y='Probability', color='Method'))
    + geom_line(aes(group='Method'), size=0.75)
    + geom_point(size=1)
    + labs(title=f'Sampling probability distribution (N={N})')
    + scale_x_continuous(breaks=range(N))
    + scale_y_continuous(breaks=np.arange(0, 1.1, 0.1))
    # Legend entries follow the order in which the methods were generated.
    + scale_color_discrete(
        limits=df.Method.unique()
    )
    + guides(color=guide_legend(nrow=1, byrow=True))
    + theme_bw()
    + theme(
        figure_size=(12, 7),
        text=element_text(size=10),
        axis_title_x=element_blank(),
        axis_title_y=element_blank(),
        legend_title=element_blank(),
        legend_position=(0.5, 0.83),
        legend_direction='horizontal',
    )
)

p

### SSPS-NN

In [None]:
import json
from glob import glob
import re

from plotnine import ggplot, aes, geom_line, geom_vline, geom_point, theme, theme_bw, labs, scale_x_continuous, scale_color_discrete, element_text
import patchworklib as pw
import pandas as pd


exps = [
    "models/ssps/voxceleb2/simclr/ssps_knn_uni-1",
    "models/ssps/voxceleb2/simclr/ssps_knn_uni-10",
    "models/ssps/voxceleb2/simclr/ssps_knn_uni-50",
    "models/ssps/voxceleb2/simclr/ssps_knn_uni-100",
    "models/ssps/voxceleb2/simclr/ssps_knn_uni-200",
]


# One row per run: sampling value + epoch-109 training metrics + test metrics.
res = []
for exp in exps:
    with open(exp + "/training.json", "r") as f:
        train = json.load(f)

    # Named `evaluation` (not `eval`) so the builtin is not shadowed.
    with open(exp + "/evaluation.json", "r") as f:
        evaluation = json.load(f)

    # Value of the `uni-<n>` run-name suffix; 0 when the suffix is absent.
    match = re.search(r'uni-([\w\d.]+)', exp.split('/')[-1])
    sampling = int(match.group(1)) if match else 0

    res.append({
        'sampling': sampling,
        # NOTE(review): epoch key "109" is hard-coded — presumably the last epoch.
        **train["109"],
        **evaluation,
    })

data = pd.DataFrame(res)

def create_plot(y, label):
    """Plot column `y` of the global `data` frame against `sampling`.

    Args:
        y: Name of the metric column in `data`.
        label: Panel title.

    Returns:
        The plot wrapped as a patchworklib brick, composable with `|` and `/`.
    """
    p = (
        ggplot(data, aes(x='sampling', y=y))
        + geom_line()
        + geom_point()
        + labs(title=label, x='Sampling', y=None)
        # One tick per evaluated sampling value.
        + scale_x_continuous(
            breaks=data['sampling'],
            labels=data['sampling']
        )
        + theme(
            figure_size=(6, 5),
            text=element_text(size=14),
            plot_title=element_text(
                ha="left",
                # Extra space between the panel title and the plot area.
                margin={'b': 90}
            ),
        )
    )
    return pw.load_ggplot(p)


# One brick per metric. The minDCF brick is built but not placed in the layout.
g_spkacc = create_plot('ssps_speaker_acc', 'Pseudo-Positives Speaker Accuracy (%)')
g_vidacc = create_plot('ssps_video_acc', 'Pseudo-Positives Video Accuracy (%)')
g_eer = create_plot('test/sv_cosine/voxceleb1_test_O/eer', 'EER (%)')
g_mindcf = create_plot('test/sv_cosine/voxceleb1_test_O/mindcf', 'minDCF (p=0.01)')

# Single row: EER | speaker accuracy | video accuracy.
p = g_eer | g_spkacc | g_vidacc
p.set_suptitle("SSPS-NN: Metrics with different sampling hyper-parameters", fontsize=20, pad=40)
p.savefig()

### SSPS-Clustering

In [None]:
import json
from glob import glob
import re

from plotnine import ggplot, aes, geom_line, geom_vline, geom_point, theme, theme_bw, labs, scale_x_continuous, scale_color_discrete, element_text
import patchworklib as pw
import pandas as pd


exps = [
    "models/ssps/voxceleb2/simclr/ssps_kmeans_6k",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_6k_uni-1",

    "models/ssps/voxceleb2/simclr/ssps_kmeans_10k",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_10k_uni-1",

    "models/ssps/voxceleb2/simclr/ssps_kmeans_25k",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-3",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-5",

    # "models/ssps/voxceleb2/simclr/ssps_kmeans-centroid_25k",
    # "models/ssps/voxceleb2/simclr/ssps_kmeans-centroid_25k_uni-1",

    "models/ssps/voxceleb2/simclr/ssps_kmeans_50k",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_50k_uni-1",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_50k_uni-3",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_50k_uni-5",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_50k_uni-10",

    "models/ssps/voxceleb2/simclr/ssps_kmeans_75k",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_75k_uni-1",

    "models/ssps/voxceleb2/simclr/ssps_kmeans_150k",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_150k_uni-1",
]

# One row per run: inter-cluster sampling, K, epoch-109 training metrics, test metrics.
res = []
for exp in exps:
    run_name = exp.split('/')[-1]

    with open(exp + "/training.json", "r") as f:
        train = json.load(f)

    # Named `evaluation` (not `eval`) so the builtin is not shadowed.
    with open(exp + "/evaluation.json", "r") as f:
        evaluation = json.load(f)

    # Value of the `uni-<n>` run-name suffix; 0 when the suffix is absent.
    match = re.search(r'uni-([\w\d.]+)', run_name)
    inter_sampling = int(match.group(1)) if match else 0

    # Number of k-means clusters as a label, e.g. "25k" from "ssps_kmeans_25k_uni-1".
    k_match = re.search(r'(\d+)k', run_name)
    if k_match is None:
        raise ValueError(f"Cannot parse the number of clusters from run name {run_name!r}")
    K = k_match.group(1) + "k"

    res.append({
        'inter_sampling': inter_sampling,
        'K': K,
        # NOTE(review): epoch key "109" is hard-coded — presumably the last epoch.
        **train["109"],
        **evaluation,
    })

data = pd.DataFrame(res)

def create_plot(y, label):
    """Plot column `y` of the global `data` frame against `inter_sampling`,
    with one coloured line per number of clusters `K`.

    Args:
        y: Name of the metric column in `data`.
        label: Panel title.

    Returns:
        The plot wrapped as a patchworklib brick, composable with `|` and `/`.
    """
    p = (
        ggplot(data, aes(x='inter_sampling', y=y, color='factor(K)'))
        + geom_line()
        + geom_point()
        + labs(title=label, x='Inter-cluster sampling', y=None, color='K')
        + scale_x_continuous(
            breaks=data['inter_sampling'],
            labels=data['inter_sampling']
        )
        # Legend follows the order in which K values first appear in `data`.
        + scale_color_discrete(limits=data['K'].unique())
        + theme(
            figure_size=(6, 5),
            text=element_text(size=14),
            plot_title=element_text(
                ha="left",
                # Extra space between the panel title and the plot area.
                margin={'b': 90}
            ),
        )
    )
    return pw.load_ggplot(p)


# One brick per metric. The minDCF brick is built but not placed in the layout.
g_spkacc = create_plot('ssps_speaker_acc', 'Pseudo-Positives Speaker Accuracy (%)')
g_vidacc = create_plot('ssps_video_acc', 'Pseudo-Positives Video Accuracy (%)')
g_eer = create_plot('test/sv_cosine/voxceleb1_test_O/eer', 'EER (%)')
g_mindcf = create_plot('test/sv_cosine/voxceleb1_test_O/mindcf', 'minDCF (p=0.01)')

# Single row: EER | speaker accuracy | video accuracy.
p = g_eer | g_spkacc | g_vidacc
p.set_suptitle("SSPS-Clustering: Metrics with different sampling hyper-parameters", fontsize=20, pad=40)
p.savefig()

### NMI

In [None]:
import yaml
import subprocess


config_path = 'models/ssps/voxceleb2/simclr/DEBUG/config.yml'

# For each K (in thousands of prototypes), patch the DEBUG config, run training
# and record the video/speaker NMI values it reports.
res = []
for K in [6, 10, 25, 50, 75, 150]:
    # Update config. NOTE: this rewrites the config file on disk; it keeps the
    # last K value after the loop finishes.
    with open(config_path) as f:
        data = yaml.safe_load(f)
    data['method']['ssps']['kmeans_nb_prototypes'] = K * 1000
    with open(config_path, 'w') as f:
        yaml.dump(data, f)

    # Start training -> capture output
    train = subprocess.run(
        [
            "./train_ddp.sh",
            "2",
            config_path,
        ],
        capture_output=True,
        text=True,
    )
    # Surface the script's stderr instead of an opaque JSON decode error.
    if train.returncode != 0:
        raise RuntimeError(
            f"train_ddp.sh failed for K={K}k (exit code {train.returncode}):\n{train.stderr}"
        )

    # The last stdout line is expected to be a JSON object with the NMI values.
    out = json.loads(train.stdout.strip().split("\n")[-1])

    res.extend([
        {'K': K, 'NMI': out["nmi_video"], 'Labels': 'Video'},
        {'K': K, 'NMI': out["nmi_speaker"], 'Labels': 'Speaker'},
    ])

res

In [None]:
from plotnine import scale_x_log10


data = pd.DataFrame(res)

# NMI as a function of K, one line per label type (video vs speaker).
p = (
    ggplot(data, aes(x='K', y='NMI', color='Labels'))
    + geom_line()
    + geom_point()
    + labs(title="NMI for different values of K", x='K', y='NMI')
    # Log-scaled x axis; K is stored in thousands, so ticks are labelled "<K>k".
    + scale_x_log10(
        breaks=data['K'].unique(),
        labels=[f"{k}k" for k in data['K'].unique()]
    )
    + theme_bw()
    + theme(
        figure_size=(12, 8),
        text=element_text(size=14),
        axis_text_x=element_text(angle=45, ha="right")
    )
)

p

## Results on SV

### Metrics

In [None]:
# Checkpoints compared in this section: the SSL baseline, SSPS, and the two
# supervised references. Shared by both evaluations so they cannot drift apart.
sv_configs = [
    "models/ssps/voxceleb2/simclr/baseline/config.yml",
    "models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/config.yml",
    "models/ssps/voxceleb2/simclr/baseline_sup/config.yml",
    "models/ssps/voxceleb2/simclr/baseline_sup_aam/config.yml",
]

# VoxCeleb1-O only.
vox1o_scores = evaluate_sv(sv_configs, 'embeddings_vox1.pt', trials=[
    "voxceleb1_test_O",
])

# All three VoxCeleb1 trial lists (O, E, H).
vox1_scores = evaluate_sv(sv_configs, 'embeddings_vox1.pt', trials=[
    "voxceleb1_test_O",
    "voxceleb1_test_E",
    "voxceleb1_test_H",
])

### Scores distribution

In [None]:
# Score distributions of the self-supervised baseline vs SSPS on VoxCeleb1-O.
scores_distribution(
    get_models_for_visualization(
        vox1o_scores,
        ["baseline", "ssps_kmeans_25k_uni-1"],
    ),
    use_angle=False,
)

In [None]:
# Score distributions of the two supervised references on VoxCeleb1-O.
scores_distribution(
    get_models_for_visualization(
        vox1o_scores,
        ["baseline_sup", "baseline_sup_aam"],
    ),
    use_angle=False,
)

### DET

In [None]:
# DET curves for every model evaluated on VoxCeleb1-O.
det_curve(
    get_models_for_visualization(vox1o_scores)
)

## Convergence

In [None]:
import json
import pandas as pd
from plotnine import ggplot, aes, geom_line, geom_point, theme, labs, scale_x_continuous, element_text
import patchworklib as pw


# Per-epoch metrics of SSPS and of the SSL baseline, in long format
# (one row per epoch and model).
with open('models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/training.json', "r") as f:
    train = json.load(f)

res = [
    {'Epoch': int(epoch), 'Model': 'SSPS', **metrics}
    for epoch, metrics in train.items()
]

with open('models/ssps/voxceleb2/simclr/baseline/training.json', "r") as f:
    train = json.load(f)

# Keep only baseline epochs below 110. The filter compares numerically, so it
# does not depend on the key order of the JSON file or on key '110' existing.
res.extend(
    {'Epoch': int(epoch), 'Model': 'Baseline', **metrics}
    for epoch, metrics in train.items()
    if int(epoch) < 110
)

data = pd.DataFrame(res)

def create_plot(y, label):
    """Draw metric `y` of the global `data` frame per epoch, one colour per model.

    Returns the plot wrapped as a patchworklib brick.
    """
    epochs = data['Epoch']
    title_style = element_text(ha='left', margin={'b': 90})

    plot = (
        ggplot(data, aes(x='Epoch', y=y, color='factor(Model)'))
        + geom_line()
        + geom_point()
        + labs(title=label, x='Epoch', y=None, color='Model')
        + scale_x_continuous(breaks=epochs, labels=epochs)
        + theme(figure_size=(6, 5), text=element_text(size=14), plot_title=title_style)
    )
    return pw.load_ggplot(plot)


# Top row: loss and validation SV metrics.
g_loss = create_plot('train/loss', 'Train loss')
g_eer = create_plot('val/sv_cosine/voxceleb1_test_O/eer', 'EER (%)')
g_mindcf = create_plot('val/sv_cosine/voxceleb1_test_O/mindcf', 'minDCF (p=0.01)')

# Bottom row: pseudo-positive accuracies and k-means NMI.
g_spkacc = create_plot('ssps_speaker_acc', 'Pseudo-Positives Speaker Accuracy (%)')
g_vidacc = create_plot('ssps_video_acc', 'Pseudo-Positives Video Accuracy (%)')
g_nmi = create_plot('ssps_kmeans_nmi', 'NMI on video labels')

p = (g_loss | g_eer | g_mindcf) / (g_spkacc | g_vidacc | g_nmi)
p.set_suptitle("Convergence of SSPS", fontsize=20, pad=40)
p.savefig()

## Intra-speaker similarity

In [None]:
# Intra-speaker similarity on VoxCeleb1 for all four models.
plot_intra_class_similarity(
    'speaker',
    {
        'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1.pt',
        'SSPS': 'models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox1.pt',
        'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1.pt',
        'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1.pt',
    },
)

In [None]:
# Intra-speaker similarity on VoxCeleb2 (self-supervised models only).
plot_intra_class_similarity(
    'speaker',
    {
        'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox2.pt',
        'SSPS': 'models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox2.pt',
    },
)

## Inter-speaker similarity

In [None]:
# Inter-speaker similarity on VoxCeleb1, using 1000 samples.
plot_inter_class_similarity(
    'speaker',
    {
        'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1.pt',
        'SSPS': 'models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox1.pt',
        'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1.pt',
        'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1.pt',
    },
    nb_samples=1000,
)

In [None]:
# Similarity between speaker centers on VoxCeleb1 for all four models.
plot_inter_speaker_center_similarity(
    {
        'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1.pt',
        'SSPS': 'models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox1.pt',
        'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1.pt',
        'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1.pt',
    }
)

In [None]:
# Similarity between speaker centers on VoxCeleb2 (self-supervised models only).
plot_inter_speaker_center_similarity(
    {
        'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox2.pt',
        'SSPS': 'models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox2.pt',
    }
)

## Intra-video similarity

In [None]:
# Intra-video similarity on VoxCeleb1 for all four models.
plot_intra_class_similarity(
    'video',
    {
        'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1.pt',
        'SSPS': 'models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox1.pt',
        'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1.pt',
        'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1.pt',
    },
)

## Inter-video similarity

In [None]:
# Inter-video similarity on VoxCeleb1, using 1000 samples.
plot_inter_class_similarity(
    'video',
    {
        'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1.pt',
        'SSPS': 'models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox1.pt',
        'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1.pt',
        'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1.pt',
    },
    nb_samples=1000,
)

## t-SNE

In [None]:
# Saved VoxCeleb1 embeddings of the SSL baseline and of SSPS.
baseline_embeddings_vox1 = torch.load(
    "models/ssps/voxceleb2/simclr/baseline/embeddings_vox1.pt"
)
ssps_embeddings_vox1 = torch.load(
    "models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox1.pt"
)

# Same two models, VoxCeleb2 embeddings.
baseline_embeddings_vox2 = torch.load(
    "models/ssps/voxceleb2/simclr/baseline/embeddings_vox2.pt"
)
ssps_embeddings_vox2 = torch.load(
    "models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox2.pt"
)

In [None]:
from plotnine import labs, theme, element_text
import patchworklib as pw


def plot_tsne(baseline_embeddings, ssps_embeddings, speakers):
    """Show baseline and SSPS t-SNE projections of `speakers` side by side.

    The second value returned by the baseline `tsne_2D` call is passed as
    `init` to the SSPS call. `savefig()` is called on the composed figure
    before it is returned.

    Returns:
        The composed patchworklib figure (baseline on the left, SSPS on the right).
    """
    def titled(plot, title):
        # Attach a left-aligned title with extra space below it, then wrap as a brick.
        title_theme = theme(plot_title=element_text(ha='left', margin={'b': 90}))
        return pw.load_ggplot(plot + labs(title=title) + title_theme)

    baseline_plot, baseline_init = tsne_2D(
        Model(embeddings=baseline_embeddings), speakers=speakers
    )
    ssps_plot, _ = tsne_2D(
        Model(embeddings=ssps_embeddings), speakers=speakers, init=baseline_init
    )

    figure = titled(baseline_plot, "Baseline") | titled(ssps_plot, "SSPS")
    figure.set_suptitle("t-SNE of speaker representations", fontsize=18, pad=40)
    figure.savefig()
    return figure

### VoxCeleb1

In [None]:
# First hand-picked set of ten VoxCeleb1 speakers.
plot_tsne(
    baseline_embeddings=baseline_embeddings_vox1,
    ssps_embeddings=ssps_embeddings_vox1,
    speakers=['id10200', 'id10564', 'id11129', 'id10983', 'id10270', 'id11086', 'id10356', 'id10218', 'id10757', 'id10140'],
)

In [None]:
# Second hand-picked set of ten VoxCeleb1 speakers.
plot_tsne(
    baseline_embeddings=baseline_embeddings_vox1,
    ssps_embeddings=ssps_embeddings_vox1,
    speakers=['id10505', 'id10209', 'id10762', 'id10059', 'id10020', 'id10113', 'id10709', 'id10443', 'id11169', 'id10309'],
)

### VoxCeleb2

In [None]:
# Hand-picked set of ten VoxCeleb2 speakers.
plot_tsne(
    baseline_embeddings=baseline_embeddings_vox2,
    ssps_embeddings=ssps_embeddings_vox2,
    speakers=['id00568', 'id00736', 'id00417', 'id00992', 'id00270', 'id00018', 'id00234', 'id00521', 'id00777', 'id00584'],
)

### Find speakers for t-SNE

In [None]:
from plotnine import labs, theme, element_text
import patchworklib as pw


import random
from collections import Counter

# Candidate VoxCeleb2 speakers: those with at least 150 baseline embeddings.
# This does not depend on the loop index, so it is computed once, with a
# Counter (linear time). The list is sorted so the sample depends only on the
# random state, not on hash-randomised set ordering.
speaker_counts = Counter(key.split("/")[-3] for key in baseline_embeddings_vox2.keys())
candidate_speakers = sorted(s for s, count in speaker_counts.items() if count >= 150)

# Draw 50 random sets of 10 speakers and save one t-SNE figure per set.
for i in range(50):
    speakers = random.sample(candidate_speakers, 10)
    print(i, speakers)

    # Same two-panel figure as `plot_tsne` (which already calls savefig()),
    # additionally written to output<i>.png.
    p = plot_tsne(baseline_embeddings_vox2, ssps_embeddings_vox2, speakers)
    p.savefig(f"output{i}.png")

## Predict Vox1 metadata

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def fit_mlp_on_representations(embeddings, y_key_pos, test_size=0.2, max_iter=100):
    """Probe how well `embeddings` predict a label encoded in their keys.

    Despite the function name, the probe is a logistic-regression classifier.
    Train and test accuracies are printed.

    Args:
        embeddings: Mapping from key to tensor; row ``[0]`` of each tensor is
            used as the feature vector.
        y_key_pos: Index into ``key.split('/')`` that selects the label, or
            ``None`` to use the full key as the label.
        test_size: Fraction of samples held out for testing.
        max_iter: Maximum solver iterations for ``LogisticRegression``
            (100 matches scikit-learn's default).

    Returns:
        The fitted ``LogisticRegression`` classifier.
    """
    keys = list(embeddings.keys())

    X = [embeddings[key][0].numpy() for key in keys]
    if y_key_pos is None:
        y = keys
    else:
        y = [key.split('/')[y_key_pos] for key in keys]

    # Fixed random_state: the split is reproducible across runs (and identical
    # across models whose embeddings share the same key order).
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=0
    )

    clf = LogisticRegression(max_iter=max_iter)
    clf.fit(X_train, y_train)

    print(f'Train accuracy: {clf.score(X_train, y_train)}')
    print(f'Test accuracy: {clf.score(X_test, y_test)}')
    return clf

In [None]:
# VoxCeleb1-O embeddings (epoch 109) of the SSL baseline and of SSPS, used by the probes below.
baseline_embeddings = torch.load(
    "models/ssps/voxceleb2/simclr/baseline/embeddings_vox1o_epoch-109.pt"
)
ssps_embeddings = torch.load(
    "models/ssps/voxceleb2/simclr/ssps_kmeans_25k_uni-1/embeddings_vox1o_epoch-109.pt"
)

### Speaker

In [None]:
# Speaker probe on the baseline: label is key component -3.
_ = fit_mlp_on_representations(embeddings=baseline_embeddings, y_key_pos=-3)

In [None]:
# Speaker probe on SSPS: label is key component -3.
_ = fit_mlp_on_representations(embeddings=ssps_embeddings, y_key_pos=-3)

### Video

In [None]:
# Video probe on the baseline: label is key component -2.
_ = fit_mlp_on_representations(embeddings=baseline_embeddings, y_key_pos=-2)

In [None]:
# Video probe on SSPS: label is key component -2.
_ = fit_mlp_on_representations(embeddings=ssps_embeddings, y_key_pos=-2)

### Segment

In [None]:
# Segment probe on the baseline: label is the last key component.
_ = fit_mlp_on_representations(embeddings=baseline_embeddings, y_key_pos=-1)

In [None]:
# Segment probe on SSPS: label is the last key component.
_ = fit_mlp_on_representations(embeddings=ssps_embeddings, y_key_pos=-1)