# Plots for SSLSV

In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
# Move the working directory two levels up so the relative `models/...` paths
# and the `notebooks.*` imports used below resolve (presumably this lands on
# the repository root — confirm against where the notebook lives).
os.chdir('../..')
# Additionally expose '<sys.path[0]>/../..' to the import system.
sys.path.insert(1, os.path.join(sys.path[0], '../..'))

In [None]:
import logging
# Only let errors through for these third-party loggers.
for noisy_logger in ('fontTools', 'matplotlib'):
    logging.getLogger(noisy_logger).setLevel(logging.ERROR)

In [None]:
from notebooks.articles.utils import (
    evaluate_sv,
    plot_inter_class_similarity,
    plot_intra_class_similarity,
)

from notebooks.evaluation.sv_visualization import (
    det_curve,
)

from plotnine import *
import patchworklib as pw

import numpy as np
import pandas as pd

import json

In [None]:
from dataclasses import dataclass
from typing import List, Dict

import torch


@dataclass
class Model:
    """Per-model record passed to the visualization helpers.

    Every field defaults to ``None``; ``get_models_for_visualization`` fills
    only ``scores`` and ``targets``.
    """

    # Verification scores (presumably one value per trial — confirm).
    scores: List[float] = None
    # Labels aligned with ``scores`` (presumably 1 = target trial — confirm).
    targets: List[int] = None
    # Tensors keyed by string; never populated in this notebook.
    embeddings: Dict[str, torch.Tensor] = None


def get_models_for_visualization(scores, names=None):
    """Wrap evaluation results into ``Model`` records for plotting.

    Args:
        scores: Mapping from model name to a mapping that provides the
            ``'scores'`` and ``'targets'`` entries.
        names: Optional collection of model names to keep. When ``None``,
            every entry of ``scores`` is kept.

    Returns:
        Dict mapping each kept model name to a ``Model`` built from its
        scores and targets, in the iteration order of ``scores``.
    """
    wanted = scores.keys() if names is None else names

    models = {}
    for model_name, result in scores.items():
        if model_name not in wanted:
            continue
        models[model_name] = Model(result['scores'], result['targets'])

    return models

In [None]:
# Checkpoint directories, keyed by method name, for the runs whose directory
# name carries no encoder tag (presumably the ResNet encoder, per the
# constant's name — confirm). Only referenced by a commented-out call below.
MODELS_RESNET = {
    "SimCLR":       "models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03/",
    "MoCo":         "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/",
    "SwAV":         "models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1/",
    "VICReg":       "models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1/",
    "DINO":         "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/",
    "Supervised":   "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2/",
}

# Checkpoint directories of the `enc-ECAPATDNN-1024` runs, keyed by method
# name. The insertion order of this dict defines the model order in the plots.
MODELS_ECAPA = {
    "SimCLR":       "models/ssl/voxceleb2/simclr/simclr_enc-ECAPATDNN-1024_proj-none_t-0.03/",
    "MoCo":         "models/ssl/voxceleb2/moco/moco_enc-ECAPATDNN-1024_proj-none_Q-32768_t-0.03_m-0.999/",
    "SwAV":         "models/ssl/voxceleb2/swav/swav_enc-ECAPATDNN-1024_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1/",
    "VICReg":       "models/ssl/voxceleb2/vicreg/vicreg_enc-ECAPATDNN-1024_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1/",
    "DINO":         "models/ssl/voxceleb2/dino/dino_enc-ECAPATDNN-1024_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/",
    "Supervised":   "models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2/",
}

## Palette

In [None]:
from mizani.palettes import hue_pal
import matplotlib.pyplot as plt

# One evenly spaced hue per ECAPA model.
model_names = list(MODELS_ECAPA)
swatches = hue_pal(h=0.01, l=0.6, s=0.65, color_space="hls")(len(model_names))

# Preview: one unit-height bar per color, labelled with the model it is for.
plt.figure(figsize=(8, 1))
for position, swatch in enumerate(swatches):
    plt.bar(position, 1, color=swatch)
plt.xticks(range(len(model_names)), model_names)
plt.show()

# Model name -> color lookup used by the following cell.
palette = dict(zip(model_names, swatches))
print(palette)

In [None]:
# Model order used for legends/axes: the insertion order of MODELS_ECAPA.
MODELS_ECAPA_ORDER = list(MODELS_ECAPA.keys())
# Model name -> color mapping taken from the generated `palette`.
MODELS_ECAPA_PALETTE = palette

# Display both values as the cell output.
MODELS_ECAPA_ORDER, MODELS_ECAPA_PALETTE

In [None]:
# Hard-coded model -> hex color mapping. Rebinding the name here replaces the
# palette assigned from the generated hues in the previous cell, so these are
# the colors the plots below actually use.
MODELS_ECAPA_PALETTE = {
    'SimCLR': '#db5f57',
    'MoCo': '#57db5f',
    'SwAV': '#d3db57',
    'VICReg': '#57d3db',
    'DINO': '#5f57db',
    'Supervised': '#01041a'
}

## Metrics

In [None]:
# vox1o_scores = evaluate_sv(MODELS_RESNET, 'embeddings_vox1_avg.pt', trials=["voxceleb1_test_O"])

In [None]:
# Evaluate every ECAPA model on the VoxCeleb1-O trial list, passing the
# 'embeddings_vox1_avg.pt' file name. Later cells read the "scores",
# "targets" and "evaluation" entries of each per-model result.
vox1o_scores_ecapa = evaluate_sv(MODELS_ECAPA, 'embeddings_vox1_avg.pt', trials=["voxceleb1_test_O"])

In [None]:
# Same evaluation on the VoxCeleb1-H trial list. Only referenced by a
# commented-out alternative in the correlation cell.
vox1h_scores_ecapa = evaluate_sv(MODELS_ECAPA, 'embeddings_vox1_avg.pt', trials=["voxceleb1_test_H"])

## Complementarity (Correlation)

In [None]:
# Trial scores to compare (use vox1h_scores_ecapa here for VoxCeleb1-H).
score_source = vox1o_scores_ecapa

# One column of trial scores per model, then the pairwise correlation matrix.
df = pd.DataFrame({name: score_source[name]["scores"] for name in MODELS_ECAPA})
correlation_matrix = df.corr()

# Long format: one row per (x model, y model) pair.
corr_long = (
    correlation_matrix
    .reset_index()
    .melt(id_vars="index")
    .rename(columns={"index": "x", "variable": "y", "value": "correlation"})
)

# Fix the axis order; y is reversed relative to x.
reversed_order = MODELS_ECAPA_ORDER[::-1]
corr_long["x"] = pd.Categorical(corr_long["x"], categories=MODELS_ECAPA_ORDER, ordered=True)
corr_long["y"] = pd.Categorical(corr_long["y"], categories=reversed_order, ordered=True)

# Two-decimal text printed inside each tile.
corr_long["label"] = [f"{value:.2f}" for value in corr_long["correlation"]]

# Heatmap of the pairwise score correlations with the value printed per tile.
p = (
    ggplot(corr_long, aes(x='x', y='y', fill='correlation'))
    + geom_tile(color='white')
    + geom_text(aes(label='label'), size=10)
    # NOTE(review): the fill scale is limited to [0.9, 1.0]; correlations
    # below 0.9 fall outside the gradient — confirm that is intended.
    + scale_fill_gradient(
        low='#c2d1ff', high='#4a78ff',
        limits=(0.9, 1.0)
    )
    + labs(x="", y="", fill="Correlation")
    + theme_bw()
    # Applied after theme_bw() so these settings override it.
    + theme(
        figure_size=(5, 4.9),
        text=element_text(size=14),
        legend_title=element_blank(),
        legend_position="none",
        panel_border=element_blank(),
        axis_text_x=element_text(rotation=45, hjust=1)
    )
)

# Uncomment to export the figure.
# p.save('correlation.pdf')

p

## Fusion (score-level)

In [None]:
from sslsv.evaluations.CosineSVEvaluation import SpeakerVerificationEvaluation, SpeakerVerificationEvaluationTaskConfig
from notebooks.evaluation.ScoreCalibration import ScoreCalibration


class FusedAndCalibratedSVEvaluation(SpeakerVerificationEvaluation):
    """Speaker-verification evaluation whose trial score is a fusion of several systems.

    Each trial is scored by every evaluation in ``test_evaluations``; the
    resulting vector of scores is turned into a single score by a
    ``ScoreCalibration`` object built from ``train_evaluations``.
    Remaining positional/keyword arguments are forwarded to the parent class.
    """

    def __init__(self, train_evaluations, test_evaluations, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Systems queried for every trial.
        self.evaluations = test_evaluations
        # Calibration/fusion model constructed from the training evaluations.
        self.sc = ScoreCalibration(train_evaluations)

    def _prepare_evaluation(self):
        """Train the calibration model (presumably invoked by the parent before scoring — confirm)."""
        self.sc.train()

    def _get_sv_score(self, a, b):
        """Return the fused score for the trial (``a``, ``b``) as a Python number."""
        per_system = [system._get_sv_score(a, b) for system in self.evaluations]
        # Add a leading dimension of size 1 before handing the scores to the calibrator.
        features = torch.tensor(per_system).unsqueeze(0)
        fused = self.sc.predict(features)
        return fused.detach().item()

In [None]:
# ECAPA models without the "Supervised" entry.
MODELS_ECAPA_SSL = {k:v for k, v in MODELS_ECAPA.items() if k != "Supervised"}

In [None]:
# Evaluate the SSL models on the 'voxceleb2_test_fusion' trial list; the
# resulting evaluations are the ones the score calibration is built from.
vox2fusion_scores_ecapa = evaluate_sv(MODELS_ECAPA_SSL, 'embeddings_vox2f_avg.pt', trials=["voxceleb2_test_fusion"])

In [None]:
# Evaluation objects of the SSL models: the VoxCeleb2 fusion trials feed the
# calibration, the VoxCeleb1 evaluations provide the per-system trial scores.
train_evals = [vox2fusion_scores_ecapa[model]["evaluation"] for model in MODELS_ECAPA_SSL]
test_evals = [vox1o_scores_ecapa[model]["evaluation"] for model in MODELS_ECAPA_SSL]

# No model is passed; the configuration is borrowed from the first test evaluation.
evaluation = FusedAndCalibratedSVEvaluation(
    train_evaluations=train_evals,
    test_evaluations=test_evals,
    model=None,
    config=test_evals[0].config,
    task_config=SpeakerVerificationEvaluationTaskConfig(
        trials=['voxceleb1_test_O', 'voxceleb1_test_E', 'voxceleb1_test_H']
        # trials=['voxceleb1_test_O']
    ),
    device='cpu'
)

evaluation.evaluate()

In [None]:
# Print the entry of the first row of the calibration layer's weight matrix
# associated with each SSL model (same order as MODELS_ECAPA_SSL).
for i, model in enumerate(MODELS_ECAPA_SSL):
    print(model, evaluation.sc.model.W.weight[0, i].item())

In [None]:
# Display the full weight and bias of the calibration layer.
evaluation.sc.model.W.weight, evaluation.sc.model.W.bias

## Fusion (representations-level)

In [None]:
from sslsv.evaluations.CosineSVEvaluation import CosineSVEvaluation, CosineSVEvaluationTaskConfig
import torch.nn.functional as F


class ReprConcatenationSVEvaluation(CosineSVEvaluation):
    """Cosine evaluation run on the concatenated embeddings of several evaluations.

    Remaining positional/keyword arguments are forwarded to the parent class.
    """

    def __init__(self, evaluations, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Evaluations whose ``test_embeddings`` are joined together.
        self.evaluations = evaluations

    def _prepare_evaluation(self):
        """Build ``self.test_embeddings`` by concatenating every system's embedding per key.

        The keys are taken from the first evaluation; every other evaluation
        must provide the same keys. Tensors are joined along the last dimension.
        """
        fused = self.test_embeddings = {}

        reference = self.evaluations[0].test_embeddings
        for key in reference.keys():
            parts = [system.test_embeddings[key] for system in self.evaluations]
            fused[key] = torch.cat(parts, dim=-1)

In [None]:
# Evaluation objects of the SSL models, whose embeddings get concatenated.
test_evals = [vox1o_scores_ecapa[model]["evaluation"] for model in MODELS_ECAPA_SSL]

# No model is passed; the configuration is borrowed from the first evaluation.
evaluation = ReprConcatenationSVEvaluation(
    evaluations=test_evals,
    model=None,
    config=test_evals[0].config,
    task_config=CosineSVEvaluationTaskConfig(
        # trials=['voxceleb1_test_O', 'voxceleb1_test_E', 'voxceleb1_test_H']
        trials=['voxceleb1_test_O']
    ),
    device='cpu'
)

evaluation.evaluate()

## DET

In [None]:
# DET curves of all ECAPA models on VoxCeleb1-O, using the shared colors and model order.
p = det_curve(get_models_for_visualization(vox1o_scores_ecapa))
p += scale_color_manual(values=MODELS_ECAPA_PALETTE, limits=MODELS_ECAPA_ORDER)
# Uncomment to export the figure.
# p.save('det.pdf')
p

## Convergence

In [None]:
# Gather the per-epoch metrics stored in each model's training.json into one
# long-format table: one row per (model, epoch).
res = []
for name, path in MODELS_ECAPA.items():
    try:
        with open(f'{path}/training.json', "r") as f:
            train = json.load(f)
    except (OSError, json.JSONDecodeError) as err:
        # Best effort: a model without a readable training log is left out of
        # the plot, but say so instead of hiding it. Unrelated errors (and
        # KeyboardInterrupt) are no longer swallowed.
        print(f"Skipping {name}: could not load training metrics ({err})")
        continue
    for epoch, metrics in train.items():
        # Epoch keys are strings in the JSON file; convert for a numeric x axis.
        res.append({'Epoch': int(epoch), 'Model': name, **metrics})

data = pd.DataFrame(res)

def create_plot(y, label):
    """Line plot of one metric against the epoch, one colored line per model.

    Reads the module-level ``data`` table at call time.

    Args:
        y: Name of the ``data`` column plotted on the y axis.
        label: Text shown as the y-axis label.

    Returns:
        The assembled plotnine plot.
    """
    model_colors = scale_color_manual(values=MODELS_ECAPA_PALETTE, limits=MODELS_ECAPA_ORDER)
    layout = theme(
        figure_size=(8, 4.75),
        text=element_text(size=14),
        legend_position='top',
        legend_title=element_blank(),
        legend_key_spacing_x=15
    )

    chart = ggplot(data, aes(x='Epoch', y=y, color='factor(Model)'))
    chart = chart + geom_line(size=1)
    chart = chart + model_colors
    chart = chart + labs(x='Epoch', y=label, color='Models')
    # theme_bw() first so the custom layout overrides it.
    chart = chart + theme_bw()
    chart = chart + layout
    # Keep all legend entries on a single row.
    chart = chart + guides(color=guide_legend(nrow=1))
    return chart


# Build the three convergence plots; only the EER plot is displayed below.
g_loss = create_plot('train/loss', 'Train loss')
g_eer = create_plot('val/sv_cosine/voxceleb1_test_O/eer', 'EER (%)')
g_mindcf = create_plot('val/sv_cosine/voxceleb1_test_O/mindcf', 'minDCF (p=0.01)')

# Plot shown as the cell output (swap in g_loss or g_mindcf to view the others).
p = g_eer

# Uncomment to export the figure.
# p.save("convergence.pdf")
p

## Label-efficient

In [None]:
# Checkpoint directories for the label-efficiency experiment:
# curve name -> {percentage of labeled data -> run directory}.
MODELS = {
    "SSL (DINO)": {
        100: "models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-1.0/",
         50: "models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.5/",
         20: "models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.2/",
         10: "models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.1/",
          5: "models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.05/",
          2: "models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.02/",
          1: "models/ssl/voxceleb2/dino/dino+_e-ecapa-1024_label-efficient-0.01/",
    },
    "Supervised": {
        100: "models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2/",
         50: "models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.5/",
         20: "models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.2/",
         10: "models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.1/",
          5: "models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.05/",
          2: "models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.02/",
          1: "models/ssl/voxceleb2/supervised/supervised_enc-ECAPATDNN-1024_loss-AAM_s-30_m-0.2_label-efficient-0.01/",
    }
}

# One row per (curve, labeled-data percentage), holding the metrics stored
# under epoch key "99" of that run's training.json.
res = []
for name, runs in MODELS.items():
    for labeled_pct, run_dir in runs.items():
        with open(f'{run_dir}/training.json', "r") as fh:
            history = json.load(fh)
        row = {'x': labeled_pct, 'Model': name}
        row.update(history["99"])
        res.append(row)

data = pd.DataFrame(res)

# EER against the percentage of labeled data (log-scaled x axis), one line per curve.
p = (
    ggplot(data, aes(x='x', y='val/sv_cosine/voxceleb1_test_O/eer', color='Model', group='Model'))
    + geom_line(size=1.25)
    + geom_point(size=3)

    # Arrow from 100% to 50% at EER 2.5 with its caption.
    # NOTE(review): the coordinates are constants inside aes(), so the segment
    # may be drawn once per row of `data` — confirm, or switch to annotate().
    + geom_segment(aes(x=100, y=2.5, xend=50, yend=2.5), size=0.75, color='#2b2b2b', arrow=arrow(type='closed', ends='last', length=0.1))
    + annotate("text", x=70, y=3.25, label='2x fewer\nlabels', color='#2b2b2b', size=12)

    # Arrow from 5% to 1% at EER 4 with its caption.
    + geom_segment(aes(x=5, y=4, xend=1, yend=4), size=0.75, color='#2b2b2b', arrow=arrow(type='closed', ends='last', length=0.1))
    + annotate("text", x=2.25, y=4.75, label='5x fewer\nlabels', color='#2b2b2b', size=12)

    + scale_colour_manual(values=["#4a78ff", "#01041a"])
    # + scale_colour_manual(values=["#57d3db", "#2db4bd", "#db5f57"])
    + scale_x_log10(breaks=[1, 2, 5, 10, 20, 50, 100])
    + scale_y_continuous(breaks=[1, 2, 4, 6, 8, 10])
    + xlab("% of labeled data")
    + ylab("EER (%)")
    + theme_bw()
    # Applied after theme_bw() so these settings override it.
    + theme(
        figure_size=(5, 4.75),
        text=element_text(size=14),
        legend_title=element_blank(),
        legend_position="top",
        legend_key_spacing_x=20
        # legend_background=element_rect(fill='white', alpha=1.0, linetype='solid', color='#ebebeb')
    )
)

# Uncomment to export the figure.
# p.save('label_efficient.pdf')

p

## Data-augmentation

In [None]:
# Checkpoint directories for the data-augmentation experiment:
# model name -> {augmentation percentage (x axis of the plot) -> run directory}.
# The 100 entries have no `_aug-*` suffix in their directory name.
MODELS = {
    "MoCo": {
        100: "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/",
         75: "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_aug-75/",
         50: "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_aug-50/",
         25: "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_aug-25/",
          0: "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_aug-none/",
    },
    "DINO": {
        100: "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/",
         75: "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_aug-75/",
         50: "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_aug-50/",
         25: "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_aug-25/",
          0: "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_aug-none/",
    },
    "Supervised": {
        100: "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2/",
         75: "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_aug-75/",
         50: "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_aug-50/",
         25: "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_aug-25/",
          0: "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_aug-none/",
    }
}

# One row per (model, augmentation percentage), holding the metrics stored in
# that run's evaluation.json.
res = []
for name, entry in MODELS.items():
    for augprob, path in entry.items():
        with open(f'{path}/evaluation.json', "r") as f:
            # Named `metrics` rather than `eval` so the builtin `eval` is not
            # shadowed for the rest of the notebook session.
            metrics = json.load(f)
        res.append({'AugProb': augprob, 'Model': name, **metrics})

data = pd.DataFrame(res)

# EER on VoxCeleb1-O against the augmentation percentage, one line per model.
p = (
    ggplot(data, aes(x='AugProb', y='test/sv_cosine/voxceleb1_test_O/eer', color='factor(Model)'))
    + geom_line(size=1)
    + geom_point(size=2)
    # Reuse the shared palette, restricted to the three models of this experiment.
    + scale_color_manual(values=MODELS_ECAPA_PALETTE, limits=["MoCo", "DINO", "Supervised"])
    + labs(x='% of data-augmentation', y='EER (%)', color='Models')
    + theme_bw()
    # Applied after theme_bw() so these settings override it.
    + theme(
        figure_size=(8, 4.75),
        text=element_text(size=14),
        legend_position="top",
        legend_title=element_blank(),
        legend_key_spacing_x=20
    )
)
# Uncomment to export the figure.
# p.save("data-aug.pdf")
p

## Intra/inter-speaker similarity

In [None]:
# Same palette with the hex byte "B3" appended to each color string, i.e. an
# #RRGGBBAA value with alpha 0xB3/255 (about 0.70).
MODELS_ECAPA_PALETTE_ALPHA = {k:v + "B3" for k, v in MODELS_ECAPA_PALETTE.items()}
# Display the mapping as the cell output.
MODELS_ECAPA_PALETTE_ALPHA

In [None]:
# Model name -> path of the 'embeddings_vox1_avg.pt' file inside its checkpoint directory.
MODELS = {k:f'{v}/embeddings_vox1_avg.pt' for k, v in MODELS_ECAPA.items()}
# MODELS.update({f'{k}-sup':f'{v}/embeddings_vox1_avg.pt' for k, v in MODELS_ECAPA.items() if k == 'Supervised'})

In [None]:
# Intra-speaker similarity plot (and its statistics) for every model, filled
# with the semi-transparent palette in the shared model order.
p, stats = plot_intra_class_similarity('speaker', MODELS)
p += scale_fill_manual(values=MODELS_ECAPA_PALETTE_ALPHA, limits=MODELS_ECAPA_ORDER)
# Uncomment to export the figure.
# p.save("intra-speaker.pdf")
p

In [None]:
# Inter-speaker similarity plot (and its statistics), called with
# nb_samples=100, filled with the semi-transparent palette.
p, stats = plot_inter_class_similarity('speaker', MODELS, nb_samples=100)
p += scale_fill_manual(values=MODELS_ECAPA_PALETTE_ALPHA, limits=MODELS_ECAPA_ORDER)
# Uncomment to export the figure.
# p.save("inter-speaker.pdf")
p

## Training distribution

In [None]:
# Checkpoint directories for the training-distribution experiment:
# model name -> {training subset label -> run directory}. Directory suffixes
# are `train-half-*` / `train-quarter-*` with `spk` or `utt`.
MODELS = {
    "MoCo": {
        "Full":     "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/",
        "50% spk.": "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_train-half-spk/",
        "50% utt.": "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_train-half-utt/",
        "25% spk.": "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_train-quarter-spk/",
        "25% utt.": "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_train-quarter-utt/",
    },
    "DINO": {
        "Full":     "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/",
        "50% spk.": "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_train-half-spk/",
        "50% utt.": "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_train-half-utt/",
        "25% spk.": "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_train-quarter-spk/",
        "25% utt.": "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_train-quarter-utt/",
    },
    "Supervised": {
        "Full":     "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2/",
        "50% spk.": "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_train-half-spk/",
        "50% utt.": "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_train-half-utt/",
        "25% spk.": "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_train-quarter-spk/",
        "25% utt.": "models/ssl/voxceleb2/supervised/supervised_loss-AAM_s-30_m-0.2_train-quarter-utt/",
    }
}

# Training subset label -> bar fill color. The key order also defines the
# category order and the legend order used below.
OPTIONS = {
    "Full": "#001233",
    "50% spk.": "#023e7d",
    "50% utt.": "#0466c8",
    "25% spk.": "#76c893",
    "25% utt.": "#d9ed92"
}

# One row per (model, training subset), holding the metrics stored in that
# run's evaluation.json.
res = []
for name, entry in MODELS.items():
    for option, path in entry.items():
        with open(f'{path}/evaluation.json', "r") as f:
            # Named `metrics` rather than `eval` so the builtin `eval` is not
            # shadowed for the rest of the notebook session.
            metrics = json.load(f)
        res.append({'Option': option, 'Model': name, **metrics})

data = pd.DataFrame(res)
# Ordered categorical so the subsets keep the order of the OPTIONS keys;
# the keys are passed as an explicit list instead of the dict itself.
data['Option'] = pd.Categorical(data['Option'], categories=list(OPTIONS), ordered=True)

# Grouped bars: EER on VoxCeleb1-O per model, one bar per training subset.
p = (
    ggplot(data, aes(x='factor(Model)', y='test/sv_cosine/voxceleb1_test_O/eer', fill='factor(Option)'))
    + geom_bar(stat='identity', position='dodge', width=0.7)
    + scale_fill_manual(values=OPTIONS, limits=list(OPTIONS.keys()))
    # + scale_fill_brewer(type="seq", palette="Blues", direction=-1)
    + scale_x_discrete(limits=["MoCo", "DINO", "Supervised"])
    # The visible y range starts at 2.0 rather than 0.
    + coord_cartesian(ylim=(2.0, 10.5))
    + labs(x='', y='EER (%)', fill='Training distribution')
    + theme_bw()
    # Applied after theme_bw() so these settings override it.
    + theme(
        figure_size=(8, 4.75),
        text=element_text(size=14),
        legend_title=element_blank(),
        legend_position="top",
        legend_key_spacing_x=20
    )
)
# Uncomment to export the figure.
# p.save("training_distribution.pdf")
p

## NMI

In [None]:
# Checkpoint directories for the NMI experiment:
# model name -> {positive-sampling variant -> run directory}. The
# "Supervised pos. sampling" runs are the directories with the `_sup2` suffix.
MODELS = {
    "SimCLR": {
        "SSL pos. sampling":  "models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03/",
        "Supervised pos. sampling": "models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03_sup2/",
    },
    "MoCo": {
        "SSL pos. sampling":  "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/",
        "Supervised pos. sampling": "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_sup2/",
    },
    "SwAV": {
        "SSL pos. sampling":  "models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1/",
        "Supervised pos. sampling": "models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1_sup2/",
    },
    "VICReg": {
        "SSL pos. sampling":  "models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1/",
        "Supervised pos. sampling": "models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1_sup2/",
    },
    "DINO": {
        "SSL pos. sampling":  "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/",
        "Supervised pos. sampling": "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_sup2/",
    },
}

# Positive-sampling variant -> bar fill color; the key order also defines the
# category order used below.
OPTIONS = {
    "SSL pos. sampling": "#01041a",
    "Supervised pos. sampling": "#4a78ff",
}

# One row per (model, positive-sampling variant) with the ratio of the
# speaker NMI to the video NMI read from that run's nmi.json.
res = []
for name, entry in MODELS.items():
    for option, path in entry.items():
        with open(f'{path}/nmi.json', "r") as f:
            # Named `nmi` rather than `eval` so the builtin `eval` is not
            # shadowed for the rest of the notebook session.
            nmi = json.load(f)
        res.append({
            'Option': option,
            'Model': name,
            'nmi_ratio': nmi['vox1_nmi_speaker'] / nmi['vox1_nmi_video']
        })

data = pd.DataFrame(res)
# Ordered categorical so the variants keep the order of the OPTIONS keys;
# the keys are passed as an explicit list instead of the dict itself.
data['Option'] = pd.Categorical(data['Option'], categories=list(OPTIONS), ordered=True)

# Grouped bars: speaker/video NMI ratio per model, one bar per sampling variant.
p = (
    ggplot(data, aes(x='factor(Model)', y='nmi_ratio', fill='factor(Option)'))
    + geom_bar(stat='identity', position='dodge', width=0.7)
    + scale_fill_manual(values=OPTIONS, limits=["SSL pos. sampling", "Supervised pos. sampling"])
    + scale_x_discrete(limits=["SimCLR", "MoCo", "SwAV", "VICReg", "DINO"])
    # The visible y range is zoomed to [0.92, 1.12] rather than starting at 0.
    + coord_cartesian(ylim=(0.92, 1.12))
    + labs(x='', y='Speaker-to-Recording NMI Ratio', fill='Pos. sampling.')
    + theme_bw()
    # Applied after theme_bw() so these settings override it.
    + theme(
        figure_size=(8, 4.75),
        text=element_text(size=14),
        legend_title=element_blank(),
        legend_position="top",
        legend_key_spacing_x=20
    )
)
# Uncomment to export the figure.
# p.save("nmi.pdf")
p

In [None]:
# Wide table: one row per model, one column per positive-sampling variant.
ssl_col = "SSL pos. sampling"
sup_col = "Supervised pos. sampling"
pivoted = data.pivot(index="Model", columns="Option", values="nmi_ratio")

# Gain of supervised over SSL positive sampling, absolute and relative to SSL.
gain = pivoted[sup_col] - pivoted[ssl_col]
pivoted["relative_improvement"] = gain / pivoted[ssl_col]
pivoted["absolute_improvement"] = gain

# Average of both gains over the models, displayed as the cell output.
pivoted["relative_improvement"].mean(), pivoted["absolute_improvement"].mean()

## Collapse

In [None]:
# MoCo collapse experiment: curve label -> run directory (`collapse-*` suffix).
MODELS = {
    "Baseline": "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_collapse-default",
    "Without negs.": "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_collapse-nonegs",
    "High temp.": "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_collapse-hightemp",
    "Low temp.": "models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_collapse-lowtemp",
}

# Keep every 50th logged step, up to step 12000, from each run's debug.json.
max_step, stride = 12000, 50

res = []
for model, run_dir in MODELS.items():
    with open(f'{run_dir}/debug.json', "r") as fh:
        debug = json.load(fh)
    for raw_step, metrics in debug.items():
        step = int(raw_step)
        if step > max_step:
            # Later entries are not inspected, so this assumes the file lists
            # steps in increasing order.
            break
        if step % stride == 0:
            res.append({'Step': step, 'Model': model, **metrics})

data = pd.DataFrame(res)

def create_plot(y, label):
    """Line plot of one logged metric against the training step, one line per run.

    Reads the module-level ``data`` table at call time.

    Args:
        y: Name of the ``data`` column plotted on the y axis.
        label: Text shown as the y-axis label.

    Returns:
        The assembled plotnine plot.
    """
    layout = theme(
        figure_size=(6, 3.5),
        text=element_text(size=14),
        legend_title=element_blank(),
        legend_position="top",
        legend_key_spacing_x=7
    )

    chart = ggplot(data, aes(x='Step', y=y, color='factor(Model)'))
    chart = chart + geom_line(size=1)
    chart = chart + labs(x='Step', y=label, color='Models')
    # theme_bw() first so the custom layout overrides it.
    chart = chart + theme_bw()
    chart = chart + layout
    # Keep all legend entries on a single row.
    chart = chart + guides(color=guide_legend(nrow=1))
    return chart


# Build the three MoCo collapse plots; g_loss is created but neither saved nor shown.
g_loss = create_plot('train/loss', 'Loss')
g_h = create_plot('train/h', 'Contrastive Entropy')
g_std = create_plot('train/std', 'Embeddings Std.')


# NOTE: unlike most cells of this notebook, the exports here are active, so
# running the cell writes both PDFs to the current working directory.
g_h.save("collapse_moco_h.pdf")
g_std.save("collapse_moco_std.pdf")

# Display both plots as the cell output.
g_h, g_std

In [None]:
# Largest 'train/h' value among the kept steps of the "Without negs." run.
data[data["Model"] == "Without negs."]["train/h"].max()

In [None]:
# DINO collapse experiment: curve label -> run directory (`collapse-*` suffix).
MODELS = {
    "Baseline": "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_collapse-default",
    "Without centering": "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_collapse-nocentering",
    "Without sharpening": "models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_collapse-nosharpening",
}

# Keep every 100th logged step, up to step 25000, from each run's debug.json.
max_step, stride = 25000, 100

res = []
for model, run_dir in MODELS.items():
    with open(f'{run_dir}/debug.json', "r") as fh:
        debug = json.load(fh)
    for raw_step, metrics in debug.items():
        step = int(raw_step)
        if step > max_step:
            # Later entries are not inspected, so this assumes the file lists
            # steps in increasing order.
            break
        if step % stride == 0:
            res.append({'Step': step, 'Model': model, **metrics})

data = pd.DataFrame(res)

def create_plot(y, label):
    """Line plot of one logged metric against the training step, one line per run.

    Reads the module-level ``data`` table at call time.

    Args:
        y: Name of the ``data`` column plotted on the y axis.
        label: Text shown as the y-axis label.

    Returns:
        The assembled plotnine plot.
    """
    layout = theme(
        figure_size=(6, 3.5),
        text=element_text(size=14),
        legend_title=element_blank(),
        legend_position="top",
        legend_key_spacing_x=10
    )

    chart = ggplot(data, aes(x='Step', y=y, color='factor(Model)'))
    chart = chart + geom_line(size=1)
    chart = chart + labs(x='Step', y=label, color='Models')
    # theme_bw() first so the custom layout overrides it.
    chart = chart + theme_bw()
    chart = chart + layout
    # Keep all legend entries on a single row.
    chart = chart + guides(color=guide_legend(nrow=1))
    return chart


# Build the three DINO collapse plots from the logged teacher/student metrics.
g_h = create_plot('train/teacher_h', 'Teacher Entropy')
g_kl = create_plot('train/kl_div', 'Teacher-Student KL div.')
g_std = create_plot('train/teacher_std', 'Embeddings Std.')

# NOTE: the exports here are active, so running the cell writes the three
# PDFs to the current working directory.
g_h.save("collapse_dino_h.pdf")
g_kl.save("collapse_dino_kl.pdf")
g_std.save("collapse_dino_std.pdf")

# Display all three plots as the cell output.
g_h, g_kl, g_std

In [None]:
# Largest 'train/teacher_h' value among the kept steps of the "Without sharpening" run.
data[data["Model"] == "Without sharpening"]["train/teacher_h"].max()