In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
os.chdir('../..')
sys.path.insert(1, os.path.join(sys.path[0], '../..'))

In [None]:
# Checkpoint directories to evaluate, listed per SSL method as a
# (base run, `_sup2` run) pair. The flattened list keeps each method's base run
# immediately followed by its `_sup2` counterpart.
MODELS = [
    checkpoint_dir
    for method_runs in (
        # SimCLR
        ('models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03/',
         'models/ssl/voxceleb2/simclr/simclr_proj-none_t-0.03_sup2/'),
        # MoCo
        ('models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999/',
         'models/ssl/voxceleb2/moco/moco_proj-none_Q-32768_t-0.03_m-0.999_sup2/'),
        # SwAV
        ('models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1/',
         'models/ssl/voxceleb2/swav/swav_proj-2048-BN-R-2048-BN-R-512_K-6000_t-0.1_sup2/'),
        # VICReg
        ('models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1/',
         'models/ssl/voxceleb2/vicreg/vicreg_proj-2048-BN-R-2048-BN-R-512_inv-1.0_var-1.0_cov-0.1_sup2/'),
        # DINO
        ('models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04/',
         'models/ssl/voxceleb2/dino/dino_proj-2048-BN-G-2048-BN-G-256-L2-65536_G-2x4_L-4x2_t-0.04_sup2/'),
    )
    for checkpoint_dir in method_runs
]

In [None]:
import torch
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score
import json
from tqdm import tqdm


def _embedding_matrix(embeddings):
    """Stack the dict's embeddings into a 2-D array with one row per key.

    Each value may be a torch tensor or an array-like holding a single embedding,
    shaped either ``(dim,)`` or ``(1, dim)``; both become one row of the result.
    """
    rows = []
    for value in embeddings.values():
        if isinstance(value, torch.Tensor):
            # detach/cpu so tensors that carry grad or live on an accelerator convert cleanly.
            value = value.detach().cpu().numpy()
        # atleast_2d turns a 1-D vector into a single row; (1, dim) inputs are unchanged.
        rows.append(np.atleast_2d(np.asarray(value)))
    return np.concatenate(rows)


def compute_nmi(model, K=1251, expected_count=153516, random_state=0):
    """Cluster a model's VoxCeleb1 embeddings with k-means and score the clusters by NMI.

    Loads ``<model>/embeddings_vox1_avg.pt`` (a dict mapping an utterance path to its
    embedding) and runs k-means with ``K`` clusters. It then compares the cluster
    assignments against two label sets parsed from each key: the third-to-last path
    component (speaker) and the second-to-last path component (video).

    Args:
        model: Directory containing ``embeddings_vox1_avg.pt``.
        K: Number of k-means clusters. Presumably the number of VoxCeleb1
            speakers (1251).
        expected_count: Number of embeddings the file should contain. A mismatch
            is reported with ``print`` but the computation continues. ``None``
            skips the check.
        random_state: Seed passed to ``KMeans`` for reproducible clustering.

    Returns:
        Tuple ``(nmi_speaker, nmi_video)`` of normalized mutual information scores.

    Raises:
        ValueError: If the file contains no embeddings, or the embeddings do not
            produce exactly one row per key.
    """
    embeddings_path = os.path.join(model, 'embeddings_vox1_avg.pt')
    # map_location='cpu' lets embeddings that were saved from a GPU load on any machine.
    embeddings = torch.load(embeddings_path, map_location='cpu')

    # Best-effort sanity check: report an unexpected size but still compute the scores.
    if expected_count is not None and len(embeddings) != expected_count:
        print('Invalid embeddings', model, len(embeddings))
    if not embeddings:
        raise ValueError(f'No embeddings found in {embeddings_path}')

    keys = list(embeddings.keys())
    X = _embedding_matrix(embeddings)
    if X.shape[0] != len(keys):
        # Labels are built per key, so every key must contribute exactly one row.
        raise ValueError(
            f'Expected one embedding per key in {embeddings_path}: '
            f'got {X.shape[0]} rows for {len(keys)} keys'
        )

    y_speaker = [key.split('/')[-3] for key in keys]
    y_video = [key.split('/')[-2] for key in keys]

    # n_init is pinned to 10: this equals the default for init='random' in every
    # scikit-learn version and silences the n_init FutureWarning in 1.2/1.3.
    kmeans = KMeans(
        n_clusters=K,
        init='random',
        algorithm='lloyd',
        n_init=10,
        random_state=random_state,
    ).fit(X)

    nmi_speaker = normalized_mutual_info_score(y_speaker, kmeans.labels_)
    nmi_video = normalized_mutual_info_score(y_video, kmeans.labels_)

    return nmi_speaker, nmi_video

In [None]:
# Score every checkpoint and store its NMI results next to the model as nmi.json.
for model in tqdm(MODELS):
    nmi_speaker, nmi_video = compute_nmi(model)
    # Guard the ratio: a video NMI of exactly 0 would otherwise raise
    # ZeroDivisionError after the expensive clustering, before nmi.json is written.
    ratio = nmi_speaker / nmi_video if nmi_video else float('nan')
    # tqdm.write prints above the progress bar instead of breaking it apart.
    tqdm.write(f'Model: {model} - NMI Speaker: {nmi_speaker} - NMI Video: {nmi_video} - Ratio: {ratio}')
    with open(os.path.join(model, 'nmi.json'), 'w') as f:
        json.dump({
            "vox1_nmi_speaker": nmi_speaker,
            "vox1_nmi_video": nmi_video
        }, f, indent=4)