# Study of SSPS

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
# Move the working directory two levels up; the cells below use paths such as
# 'models/...' and 'data/...' relative to the current directory.
# NOTE(review): the chdir is relative, so re-running this cell moves the
# working directory up two more levels each time.
os.chdir('../..')
# Register a path two levels above sys.path[0] on the import search path
# (presumably so that the `notebooks.*` imports below resolve — confirm).
sys.path.insert(1, os.path.join(sys.path[0], '../..'))

In [None]:
from notebooks.ssps.utils import (
    evaluate_sv,
    plot_inter_speaker_center_similarity,
    plot_inter_class_similarity,
    plot_intra_class_similarity,
    plot_intra_class_similarity_by_class
)

from notebooks.evaluation.sv_visualization import (
    det_curve,
    scores_distribution,
    tsne_2D,
)

## Evaluation on SV (Vox1-O/E/H)

In [None]:
# Run evaluate_sv on four model configs, using the epoch-100 Vox1 embeddings
# file and the Vox1 O/E/H trial lists. The result is kept in `scores` and is
# consumed later by get_models_for_visualization.
scores = evaluate_sv([
    "models/ssps/voxceleb2/simclr/baseline/config.yml",
    "models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/config.yml",
    "models/ssps/voxceleb2/simclr/baseline_sup/config.yml",
    "models/ssps/voxceleb2/simclr/baseline_sup_aam/config.yml",
], 'embeddings_vox1_epoch-100.pt', trials=[
    "voxceleb1_test_O",
    "voxceleb1_test_E",
    "voxceleb1_test_H",
])

## Inter-speaker similarity

### VoxCeleb1-O

In [None]:
# Inter-class similarity with 'speaker' as the class, computed from the
# epoch-100 Vox1-O embeddings of each labelled model (nb_samples=1000).
plot_inter_class_similarity('speaker', {
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1o_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1o_epoch-100.pt',
    # 'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1o_epoch-100.pt',
    'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1o_epoch-100.pt',
}, nb_samples=1000)

In [None]:
# Similarity between speaker centers, from the same Vox1-O embeddings.
# The disabled entry is kept on the Vox1-O file so that re-enabling it stays
# consistent with the other entries of this cell.
plot_inter_speaker_center_similarity({
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1o_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1o_epoch-100.pt',
    # 'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1o_epoch-100.pt',
    'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1o_epoch-100.pt',
})

### VoxCeleb1

In [None]:
# Inter-class similarity with 'speaker' as the class, computed from the
# epoch-100 Vox1 embeddings of each labelled model (nb_samples=1000).
plot_inter_class_similarity('speaker', {
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1_epoch-100.pt',
    # 'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1_epoch-100.pt',
    'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1_epoch-100.pt',
}, nb_samples=1000)

In [None]:
# Similarity between speaker centers, from the epoch-100 Vox1 embeddings.
plot_inter_speaker_center_similarity({
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1_epoch-100.pt',
    # 'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1_epoch-100.pt',
    'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1_epoch-100.pt',
})

### VoxCeleb2

In [None]:
# Inter-class similarity with 'speaker' as the class, computed from the
# epoch-100 VoxCeleb2 embeddings (nb_samples=1000). Only the SSL and SSPS
# entries are enabled here; the two supervised entries are commented out.
plot_inter_class_similarity('speaker', {
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox2_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox2_epoch-100.pt',
    # 'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox2_epoch-100.pt',
    # 'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox2_epoch-100.pt',
}, nb_samples=1000)

In [None]:
# Similarity between speaker centers, from the epoch-100 VoxCeleb2 embeddings.
# The disabled entries are kept on the VoxCeleb2 files so that re-enabling
# them stays consistent with the enabled entries of this cell.
plot_inter_speaker_center_similarity({
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox2_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox2_epoch-100.pt',
    # 'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox2_epoch-100.pt',
    # 'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox2_epoch-100.pt',
})

## Intra-speaker similarity

In [None]:
# Intra-class similarity with 'speaker' as the class, for all four models,
# computed from their epoch-100 Vox1 embeddings.
plot_intra_class_similarity('speaker', {
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1_epoch-100.pt',
    'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1_epoch-100.pt',
    'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1_epoch-100.pt',
})

In [None]:
# Per-class view of the intra-speaker similarity for SSL vs. SSPS, with
# nb_classes set to 20.
plot_intra_class_similarity_by_class('speaker', {
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1_epoch-100.pt',
}, nb_classes=20)

In [None]:
#

## Inter-video similarity

In [None]:
# Inter-class similarity with 'video' (instead of 'speaker') as the class,
# computed from the epoch-100 Vox1 embeddings (nb_samples=1000).
plot_inter_class_similarity('video', {
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1_epoch-100.pt',
    # 'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1_epoch-100.pt',
    'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1_epoch-100.pt',
}, nb_samples=1000)

## Intra-video similarity

In [None]:
# Intra-class similarity with 'video' as the class, for all four models,
# computed from their epoch-100 Vox1 embeddings.
plot_intra_class_similarity('video', {
    'SSL': 'models/ssps/voxceleb2/simclr/baseline/embeddings_vox1_epoch-100.pt',
    'SSPS': 'models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1_epoch-100.pt',
    'SSL (supervised)': 'models/ssps/voxceleb2/simclr/baseline_sup/embeddings_vox1_epoch-100.pt',
    'AAM-Softmax': 'models/ssps/voxceleb2/simclr/baseline_sup_aam/embeddings_vox1_epoch-100.pt',
})

## Scores distribution

In [None]:
from dataclasses import dataclass
from typing import List, Dict

import torch


@dataclass
class Model:
    """Lightweight container handed to the visualization helpers.

    Every field defaults to None, so an instance can carry only scores and
    targets (as built by `get_models_for_visualization`) or only embeddings
    (as built in the t-SNE cell).
    """

    # Filled from the 'scores' entry of an evaluation result.
    scores: List[float] = None
    # Filled from the 'targets' entry of an evaluation result.
    targets: List[int] = None
    # Object returned by torch.load on an embeddings file; annotated as a
    # mapping from string keys to tensors.
    embeddings: Dict[str, torch.Tensor] = None


def get_models_for_visualization(scores, names=None):
    """Wrap evaluation results into `Model` objects for the plotting helpers.

    Args:
        scores: Mapping from a model name to a mapping that holds at least
            the 'scores' and 'targets' entries.
        names: Model names to keep. When None, every entry of `scores` is
            kept.

    Returns:
        Dict mapping each retained name to a `Model` built from its scores
        and targets, in the iteration order of `scores`.
    """
    selected = list(scores.keys()) if names is None else names

    models = {}
    for model_name, result in scores.items():
        if model_name not in selected:
            continue
        models[model_name] = Model(result['scores'], result['targets'])

    return models

In [None]:
# Score distributions of the two self-supervised models (baseline vs. SSPS),
# with use_angle disabled.
scores_distribution(get_models_for_visualization(scores, [
    "baseline",
    "2-kmeans_exp-10-0.5",
]), use_angle=False)

In [None]:
# Score distributions of the two supervised baselines, with use_angle disabled.
scores_distribution(get_models_for_visualization(scores, [
    "baseline_sup",
    "baseline_sup_aam",
]), use_angle=False)

## DET curves

In [None]:
# DET curves for every model present in `scores` (no name filter is passed).
det_curve(get_models_for_visualization(scores))

## t-SNE

In [None]:
# t-SNE of the baseline embeddings; the value returned by tsne_2D is kept so
# it can be passed as `init` to the second projection.
tsne_init = tsne_2D(Model(
    embeddings=torch.load("models/ssps/voxceleb2/simclr/baseline/embeddings_vox1_epoch-100.pt")
))

# t-SNE of the SSPS embeddings, seeded with the baseline result (presumably so
# that both projections share a comparable layout — confirm against tsne_2D).
_ = tsne_2D(Model(
    embeddings=torch.load("models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1_epoch-100.pt")
), init=tsne_init)

## Predict Vox1 metadata from representations

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split


def fit_mlp_on_representations(embeddings, y_key_pos, test_size=0.2):
    """Probe how well a label derived from the keys is predicted from embeddings.

    Despite the name, the classifier is a scikit-learn `LogisticRegression`
    with default settings, not an MLP. The samples are split once with
    `train_test_split` (fixed `random_state=0`), the classifier is fitted on
    the training part and both accuracies are printed.

    Args:
        embeddings: Mapping from a '/'-separated key to a tensor; only the
            first row (`[0]`) of each entry is used as the feature vector.
        y_key_pos: Index of the '/'-separated key component used as the
            label, or None to use the whole key as the label.
        test_size: Forwarded to `train_test_split`.

    Returns:
        Tuple `(train_accuracy, test_accuracy)`. Callers that bind the result
        to `_` keep working unchanged.
    """
    keys = list(embeddings.keys())

    # detach() and cpu() make the conversion work for tensors that track
    # gradients or live on an accelerator; a bare .numpy() raises for both.
    X = [embeddings[key][0].detach().cpu().numpy() for key in keys]
    if y_key_pos is None:
        y = keys
    else:
        y = [key.split('/')[y_key_pos] for key in keys]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=0
    )

    clf = LogisticRegression()
    clf.fit(X_train, y_train)

    # Score once and reuse the values for both the printout and the result.
    train_accuracy = clf.score(X_train, y_train)
    test_accuracy = clf.score(X_test, y_test)

    print(f'Train accuracy: {train_accuracy}')
    print(f'Test accuracy: {test_accuracy}')

    return train_accuracy, test_accuracy

In [None]:
# Epoch-100 Vox1-O embeddings of the baseline and SSPS models, shared by the
# probing cells below.
baseline_embeddings = torch.load("models/ssps/voxceleb2/simclr/baseline/embeddings_vox1o_epoch-100.pt")
ssps_embeddings = torch.load("models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/embeddings_vox1o_epoch-100.pt")

### Speaker

In [None]:
# Baseline: the label is the third-from-last '/'-separated component of each key.
_ = fit_mlp_on_representations(baseline_embeddings, y_key_pos=-3)

In [None]:
# SSPS: the label is the third-from-last '/'-separated component of each key.
_ = fit_mlp_on_representations(ssps_embeddings, y_key_pos=-3)

### Video

In [None]:
# Baseline: the label is the second-to-last '/'-separated component of each key.
_ = fit_mlp_on_representations(baseline_embeddings, y_key_pos=-2)

In [None]:
# SSPS: the label is the second-to-last '/'-separated component of each key.
_ = fit_mlp_on_representations(ssps_embeddings, y_key_pos=-2)

### Segment

In [None]:
# Baseline: the label is the last '/'-separated component of each key.
# NOTE(review): if that component is unique per key, held-out labels never
# appear in the training split — confirm this is the intended probe.
_ = fit_mlp_on_representations(baseline_embeddings, y_key_pos=-1)

In [None]:
# SSPS: the label is the last '/'-separated component of each key.
_ = fit_mlp_on_representations(ssps_embeddings, y_key_pos=-1)

## K-means assignments distribution

In [None]:
import torch
import pandas as pd
from plotnine import ggplot, aes, geom_bar, labs, theme_bw, theme, element_text, coord_cartesian


def plot_kmeans_cluster_distribution(checkpoint):
    """Plot how many k-means clusters contain a given number of samples.

    Args:
        checkpoint: Path to a file that `torch.load` turns into an integer
            tensor of cluster assignments (presumably one cluster id per
            training sample — confirm against the code writing the file).

    Returns:
        Tuple `(stats, cluster_histogram)`. `stats` holds the min, max,
        median and mean cluster size plus `zero_count`, the number of empty
        clusters. `cluster_histogram` is a DataFrame with the columns
        'x_elements' (cluster size), 'count' (number of clusters of that
        size) and 'type' (always 'K-means').
    """
    assignments = torch.load(checkpoint).cpu()

    # Keep every non-negative cluster id. torch.bincount rejects negative
    # values, so negative entries (presumably "unassigned" markers) must be
    # dropped. The previous `> 1` mask also discarded clusters 0 and 1,
    # which then showed up as two spurious empty clusters in `zero_count`.
    cluster_counts = torch.bincount(assignments[assignments >= 0])

    data = pd.DataFrame({'elements_in_cluster': cluster_counts.numpy()})
    # Histogram of cluster sizes: one row per distinct size.
    cluster_histogram = data['elements_in_cluster'].value_counts().reset_index()
    cluster_histogram.columns = ['x_elements', 'count']

    p = (
        ggplot(cluster_histogram, aes(x='x_elements', y='count'))
        + geom_bar(stat='identity')
        + labs(
            x='Number of samples',
            y='Count',
            title='K-means cluster distribution'
        )
        + theme_bw()
        + theme(figure_size=(12, 8), text=element_text(size=14))
    )
    print(p)

    # Tag the rows so the frame can be concatenated with other histograms.
    cluster_histogram['type'] = 'K-means'

    stats = {
        'min': data['elements_in_cluster'].min(),
        'max': data['elements_in_cluster'].max(),
        'median': data['elements_in_cluster'].median(),
        'mean': data['elements_in_cluster'].mean(),
        'zero_count': (data['elements_in_cluster'] == 0).sum()
    }

    return stats, cluster_histogram

In [None]:
# Cluster-size distribution of the 2-kmeans_exp-10-0.5 run at epoch 100.
stats, kmeans_50k_hist = plot_kmeans_cluster_distribution("models/ssps/voxceleb2/simclr/2-kmeans_exp-10-0.5/assignments_epoch-100.pt")
# Bare expression so the notebook displays the statistics dictionary.
stats

In [None]:
# Same plot for the 2-kmeans-repr_150k run; `stats` is rebound here and the
# histogram is reused by the combined VoxCeleb2/K-means plot.
stats, kmeans_150k_hist = plot_kmeans_cluster_distribution("models/ssps/voxceleb2/simclr/2-kmeans-repr_150k/assignments_epoch-100.pt")

In [None]:
from glob import glob
from collections import defaultdict


def plot_vox_video_distribution(path):
    """Plot how many videos contain a given number of audio files.

    Files matched by `path` are grouped by the name of their parent
    directory, which is treated as the video identifier.

    Args:
        path: Glob pattern selecting the audio files to count.

    Returns:
        Tuple `(stats, cluster_histogram)`. `stats` holds the min, max,
        median and mean number of files per video. `cluster_histogram` is a
        DataFrame with the columns 'x_elements' (files per video), 'count'
        (number of videos of that size) and 'type' (always 'VoxCeleb2').
    """
    files_per_video = defaultdict(int)
    for audio_file in glob(path):
        # The second-to-last path component is the directory holding the file.
        files_per_video[audio_file.split('/')[-2]] += 1

    video_sizes = pd.DataFrame(
        {'elements_in_video': list(files_per_video.values())}
    )
    sizes = video_sizes['elements_in_video']

    # One row per distinct video size.
    histogram = sizes.value_counts().reset_index()
    histogram.columns = ['x_elements', 'count']

    plot = ggplot(histogram, aes(x='x_elements', y='count'))
    plot = plot + geom_bar(stat='identity')
    plot = plot + labs(
        x='Number of samples',
        y='Count',
        title='VoxCeleb2 video distribution'
    )
    plot = plot + theme_bw()
    plot = plot + theme(figure_size=(12, 8), text=element_text(size=14))
    print(plot)

    # Tag the rows so the frame can be concatenated with other histograms.
    histogram['type'] = 'VoxCeleb2'

    stats = {
        'min': sizes.min(),
        'max': sizes.max(),
        'median': sizes.median(),
        'mean': sizes.mean(),
    }

    return stats, histogram

In [None]:
# Files-per-video distribution over every .wav file found three directory
# levels below data/voxceleb2.
stats, vox2_hist = plot_vox_video_distribution('data/voxceleb2/*/*/*.wav')
# Bare expression so the notebook displays the statistics dictionary.
stats

In [None]:
# Stack both histograms; the 'type' column set by the two plotting helpers
# tells the rows apart and drives the fill colour.
combined_data = pd.concat([kmeans_150k_hist, vox2_hist], ignore_index=True)

p = (
    ggplot(combined_data, aes(x='x_elements', y='count', fill='type'))
    # position='identity' overlays the two series; alpha keeps both visible.
    + geom_bar(stat='identity', position='identity', alpha=0.7)
    + labs(
        x='Number of samples per cluster/video',
        y='Count',
        title='VoxCeleb2 videos and K-means clusters distribution'
    )
    # Restrict the x-axis view to the 0-50 range.
    + coord_cartesian(xlim=(0, 50))
    + theme_bw()
    + theme(figure_size=(12, 8), text=element_text(size=14))
)

# Bare expression so the notebook renders the figure.
p

## Dependency on data-aug

In [None]:
#  train SSL without aug
# "val/sv_cosine/voxceleb1_test_O/eer": 15.551537070524413,
# "val/sv_cosine/voxceleb1_test_O/mindcf": 0.7550674893955924

#  train SSPS without aug
# val/sv_cosine/voxceleb1_test_O/eer: 10.388644
# val/sv_cosine/voxceleb1_test_O/mindcf: 0.691761

## Thresholds: SSPS metrics vs SV metrics

In [None]:
import json
from glob import glob
import re

from plotnine import ggplot, aes, geom_line, geom_vline, geom_point, theme, theme_bw, labs, scale_x_continuous, element_text
import patchworklib as pw
import pandas as pd


# Collect the epoch-100 metrics of every tau1/tau2 threshold experiment into
# `data`, one row per experiment.
exps = glob("models/ssps/voxceleb2/simclr/2-kmeans-repr_tau1-*_tau2-*")

# Compiled once; a single match per experiment yields both thresholds.
tau_pattern = re.compile(r'tau1-([\d.]+)_tau2-([\d.]+)')

res = []
for exp in exps:
    try:
        with open(exp + "/training.json", "r") as f:
            training_log = json.load(f)
    except (OSError, ValueError):
        # Best effort: skip experiments whose training log is missing or
        # unreadable (JSON and text decoding errors are ValueError
        # subclasses). Anything else, e.g. KeyboardInterrupt, propagates.
        continue

    match = tau_pattern.search(exp.split('/')[-1])
    if match is None:
        # The glob wildcards also accept names without numeric thresholds.
        continue
    tau1 = float(match.group(1))
    tau2 = float(match.group(2))

    # These two tau2 values are left out of the study.
    if tau2 in [0.95, 0.975]:
        continue

    epoch_metrics = training_log["100"]

    # Speaker accuracy and coverage count as gains and video accuracy as a
    # loss; the sum of the three terms is subtracted from 3 to give a cost.
    cost = (
        epoch_metrics["ssps_speaker_acc"]
        + (1 - epoch_metrics["ssps_video_acc"])
        + epoch_metrics["ssps_coverage"]
    )
    cost = 3 - cost

    res.append({
        'tau1': tau1,
        'tau2': tau2,
        'ssps_cost': cost,
        **epoch_metrics
    })

data = pd.DataFrame(res)

def create_plot(y, label):
    """Build one panel of the threshold study and load it into patchworklib.

    Reads the module-level `data` frame: 'tau1' goes on the x-axis, column
    `y` on the y-axis, with one coloured line per distinct 'tau2' value.

    Args:
        y: Name of the `data` column to plot.
        label: Title shown above the panel.

    Returns:
        The object returned by `pw.load_ggplot` for the assembled plot.
    """
    # Only the speaker-accuracy panel gets no bottom margin under its title.
    title_bottom_margin = 0 if y == 'ssps_speaker_acc' else 90

    panel_theme = theme(
        figure_size=(6, 5),
        text=element_text(size=14),
        plot_title=element_text(ha="left", margin={'b': title_bottom_margin}),
        axis_text_x=element_text(angle=45, ha="right"),
    )

    panel = ggplot(data, aes(x='tau1', y=y, color='factor(tau2)'))
    panel = panel + geom_line()
    panel = panel + geom_point()
    # Dashed vertical marker at tau1 = 0.835.
    panel = panel + geom_vline(xintercept=0.835, linetype='dashed', color='black')
    panel = panel + labs(title=label, x='τ₁', y=None, color='τ₂')
    # One tick per tau1 value present in the data.
    panel = panel + scale_x_continuous(breaks=data['tau1'])
    panel = panel + panel_theme

    return pw.load_ggplot(panel)


# One panel per metric column of `data`.
g_spkacc = create_plot('ssps_speaker_acc', 'Speaker Accuracy (%)')
g_vidacc = create_plot('ssps_video_acc', 'Video Accuracy (%)')
g_coverage = create_plot('ssps_coverage', 'Coverage (%)')
g_cost = create_plot('ssps_cost', 'SSPS cost')
g_interpool = create_plot('ssps_inter_sampling_pool', 'Inter-sampling pool size')
g_intrapool = create_plot('ssps_intra_sampling_pool', 'Intra-sampling pool size')
g_eer = create_plot('val/sv_cosine/voxceleb1_test_O/eer', 'EER (%)')
g_mindcf = create_plot('val/sv_cosine/voxceleb1_test_O/mindcf', 'minDCF (p=0.01)')

# Alternative layouts kept for reference; only the last assignment is active.
# p = (g_eer|g_mindcf)/(g_spkacc|g_vidacc|g_coverage)
# p = (g_eer|g_spkacc|g_vidacc)/(g_coverage|g_interpool|g_intrapool)
# p = (g_eer|g_mindcf)/(g_spkacc|g_vidacc)/(g_coverage|g_interpool|g_intrapool)
# `|` and `/` combine the panels: SV metrics and cost in the first group,
# SSPS metrics in the second. The pool-size panels are built but not placed.
p = (g_eer|g_mindcf|g_cost)/(g_spkacc|g_vidacc|g_coverage)

p.set_suptitle(
    "SV and SSPS metrics with different thresholds",
    fontsize=20,
    pad=40
)
# Called without a file name.
# NOTE(review): presumably this only renders the figure — confirm against
# patchworklib's savefig.
p.savefig()

In [None]:
from pathlib import Path
import shutil
import yaml
from tqdm import tqdm

# Create one experiment folder per inter-sampling lambda value, each with a
# config derived from the default one, and collect the quoted experiment names.
cmds = []

# Read the default config once; it is re-parsed per experiment below so every
# experiment starts from a fresh, unshared dictionary.
with open('models/ssps/voxceleb2/simclr/default/config.yml') as f:
    default_config_text = f.read()

# `lamba` is spelled this way because `lambda` is a reserved word.
for lamba in tqdm((1.2, 0.8, 0.5, 0.2, 0.0, -0.2, -0.5, -0.8, -1.2)):
    exp = f'models/ssps/voxceleb2/simclr/default_inter-10-{lamba}'

    # Create experiment folder
    Path(exp).mkdir(exist_ok=True)

    # Create config file. The dictionary is named `config` rather than `data`
    # so this cell no longer overwrites the `data` DataFrame that
    # `create_plot` reads.
    config = yaml.safe_load(default_config_text)
    config['method']['ssps']['inter_sampling_prob_exp_lambda'] = lamba
    config['trainer']['epochs'] = 101
    with open(exp + '/config.yml', 'w') as f:
        yaml.dump(config, f)

    # Create the checkpoints folder; copying the base checkpoint into it is
    # currently disabled.
    (Path(exp) / "checkpoints").mkdir(exist_ok=True)
    # shutil.copy(
    #     "models/ssps/voxceleb2/simclr/model_base.pt",
    #     (Path(exp) / "checkpoints" / "model_latest.pt")
    # )

    # Quoted experiment name (last path component).
    cmds.append(f"\"{exp.split('/')[-1]}\"")

# Bare expression so the notebook displays the space-separated names.
' '.join(cmds)