# Globals

In [None]:
import copy
import datetime
from collections import defaultdict

import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

In [None]:
# Names of the validation datasets.
VALIDATION_DATASETS = ["imagenet", "imagenette", "imagewoof"]
# Encoder identifiers with a ResNet-50 backbone (grouped as "ResNet-50" in MODEL_GROUPS).
RESNET50_MODELS = [
    "resnet50",
    "mocov3_resnet50",
    "vicreg_resnet50",
    "dino_resnet50",
    "clip_RN50",
]
# Encoder identifiers with a ViT-B/16 backbone (grouped as "ViT-B" in MODEL_GROUPS).
VITB16_MODELS = [
    "vitb16",
    "mocov3_vit_base",
    "timm_vit_base_patch16_224.mae",
    "dino_vitb16",
    "clip_vitb16",
]
# Clustering methods; each name is also a key of DEFAULT_PARAMS.
CLUSTERERS = [
    "KMeans",
    "AgglomerativeClustering",
    "AffinityPropagation",
    "SpectralClustering",
    "HDBSCAN",
    "OPTICS",
]
# Independent copy of the full list: later cells rebind CLUSTERERS to
# table-specific subsets, whereas ALL_CLUSTERERS is never reassigned.
ALL_CLUSTERERS = copy.deepcopy(CLUSTERERS)
# Distance metric names considered for the clusterers that accept a metric.
DISTANCE_METRICS = [
    "euclidean",
    "l1",
    "chebyshev",
    "cosine",
    "arccos",
    "braycurtis",
    "canberra",
]

In [None]:
# Per-dataset style string for plots.
# NOTE(review): the values look like matplotlib line styles ("LS") — confirm at the call sites.
DATASET2LS = {
    "imagenet": "-.",
    "imagenette": "--",
    "imagewoof": ":",
}

In [None]:
# Default hyperparameters.
# - "all" holds settings shared by every clusterer; the results cells merge it
#   with the per-clusterer/per-model entry when building a row filter.
# - Every other key is a clusterer name mapping to that clusterer's defaults.
DEFAULT_PARAMS = {
    "all": {
        "dim_reducer": "None",
        "dim_reducer_man": "None",
        "zscore": False,
        "normalize": False,
        "zscore2": False,
        "ndim_correction": False,
    },
    "KMeans": {"clusterer": "KMeans"},
    "AffinityPropagation": {
        "clusterer": "AffinityPropagation",
        "affinity_damping": 0.5,
        "affinity_conv_iter": 15,
    },
    "SpectralClustering": {
        "clusterer": "SpectralClustering",
        "spectral_assigner": "kmeans",
    },
    "AgglomerativeClustering": {
        "clusterer": "AgglomerativeClustering",
        "distance_metric": "euclidean",
        "aggclust_linkage": "ward",
    },
    "HDBSCAN": {
        "clusterer": "HDBSCAN",
        "hdbscan_method": "eom",
        "min_samples": 5,
        "max_samples": 0.2,
        "distance_metric": "euclidean",
    },
    "OPTICS": {
        "clusterer": "OPTICS",
        "optics_method": "xi",
        "optics_xi": 0.05,
        "distance_metric": "euclidean",
    },
}

## Set best params

These values were discovered by the hyperparameter search that follows. They are defined here, ahead of the search, to keep the notebook modular: you can use them without re-running the hyperparameter search analysis code.

### Num dims

In [None]:
# Start every (clusterer, model) pair from an independent copy of that
# clusterer's defaults, then override the dimensionality-reduction settings
# per clusterer below.
models = RESNET50_MODELS + VITB16_MODELS
BEST_PARAMS = {
    clusterer: {model: copy.deepcopy(DEFAULT_PARAMS[clusterer]) for model in models}
    for clusterer in ALL_CLUSTERERS
}

# KMeans
# Use UMAP (num dims unimportant; we select 50d for consistency) for every encoder except
# - clip_RN50 : a little better to use PCA with 500d than UMAP. UMAP beats PCA if you
#   reduce the PCA dims below 500.
# - clip_vitb16 : same behaviour as clip_RN50
# - timm_vit_base_patch16_224.mae : best is PCA 0.85 variance explained. Need at least
#   200 PCA dims, and PCA perf beats UMAP throughout

for model in RESNET50_MODELS + VITB16_MODELS:
    if model.startswith("clip") or model == "timm_vit_base_patch16_224.mae":
        continue
    BEST_PARAMS["KMeans"][model].update(
        {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50}
    )

BEST_PARAMS["KMeans"]["clip_RN50"].update(
    {"dim_reducer": "PCA", "ndim_reduced": 500, "zscore": True, "pca_variance": None}
)
BEST_PARAMS["KMeans"]["clip_vitb16"].update(
    {"dim_reducer": "PCA", "ndim_reduced": 500, "zscore": True, "pca_variance": None}
)
BEST_PARAMS["KMeans"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True, "ndim_reduced": None}
)

# AffinityPropagation
# Use PCA with 10 dims for every encoder except
# - resnet50 (supervised) : original embeddings, no reduction (AMI=0.62);
#   perf gets worse if they are whitened (AMI=0.55) and although the perf increases
#   as num dims are reduced it doesn't quite recover. PCA perf peaks at 10-20 dim (AMI=0.57).
# - dino_resnet50 : does marginally better at UMAP 50 (AMI=0.52495) than PCA 10 (AMI=0.5044)
# - timm_vit_base_patch16_224.mae : PCA 0.95 variance explained (AMI=0.303).
#   Definite improvement from 10 to 20 dims, but not much improvement above that.

for model in models:
    if model in ["resnet50", "dino_resnet50", "timm_vit_base_patch16_224.mae"]:
        continue
    BEST_PARAMS["AffinityPropagation"][model].update(
        {
            "dim_reducer": "PCA",
            "ndim_reduced": 10,
            "zscore": True,
            "pca_variance": None,
            "dim_reducer_man": "None",
        }
    )

BEST_PARAMS["AffinityPropagation"]["resnet50"].update(
    {"dim_reducer": "None", "dim_reducer_man": "None", "zscore": False}
)
# NOTE(review): the comment above reports UMAP 50 as marginally better for
# dino_resnet50, but the value set here is PCA with 0.95 variance explained —
# confirm which one is intended.
BEST_PARAMS["AffinityPropagation"]["dino_resnet50"].update(
    {
        "dim_reducer": "PCA",
        "pca_variance": 0.95,
        "zscore": True,
        "ndim_reduced": None,
        "dim_reducer_man": "None",
    }
)
BEST_PARAMS["AffinityPropagation"]["timm_vit_base_patch16_224.mae"].update(
    {
        "dim_reducer": "PCA",
        "pca_variance": 0.95,
        "zscore": True,
        "ndim_reduced": None,
        "dim_reducer_man": "None",
    }
)

# AgglomerativeClustering
# Use UMAP (num dims unimportant; we select 50d for consistency) for every encoder except
# - timm_vit_base_patch16_224.mae : PCA 0.98 variance explained (i.e. nearly all
#   dimensions kept), which is not noticeably better than using 500 dim PCA but there is
#   an increase compared to using less than 500d.

for model in models:
    if model == "timm_vit_base_patch16_224.mae":
        continue
    BEST_PARAMS["AgglomerativeClustering"][model].update(
        {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"}
    )

BEST_PARAMS["AgglomerativeClustering"]["timm_vit_base_patch16_224.mae"].update(
    {
        "dim_reducer": "PCA",
        "pca_variance": 0.98,
        "zscore": True,
        "ndim_reduced": None,
        "dim_reducer_man": "None",
    }
)

# HDBSCAN
# Use UMAP for every encoder except
# - timm_vit_base_patch16_224.mae : PCA 0.95 variance explained (AMI=0.085) which is
#   not noticeably better than PCA with 50 dim

for model in models:
    if model in ["timm_vit_base_patch16_224.mae"]:
        continue
    BEST_PARAMS["HDBSCAN"][model].update(
        {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"}
    )

BEST_PARAMS["HDBSCAN"]["timm_vit_base_patch16_224.mae"].update(
    {
        "dim_reducer": "PCA",
        "pca_variance": 0.95,
        "zscore": True,
        "ndim_reduced": None,
        "dim_reducer_man": "None",
    }
)

# OPTICS
# Use UMAP for every encoder, no exceptions necessary
for model in models:
    BEST_PARAMS["OPTICS"][model].update(
        {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"}
    )

In [None]:
# v1 is an independent snapshot of the choices made so far.
# v2 is NOT a copy: it is the same object as BEST_PARAMS, so every update
# below also changes BEST_PARAMS (until BEST_PARAMS is rebound later on).
BEST_PARAMS_v1 = copy.deepcopy(BEST_PARAMS)
BEST_PARAMS_v2 = BEST_PARAMS

print("Updating dim choices for new method")
# Updated dim choices
# (changed to this when we swapped to using weighted average instead of straight
# average between Imagenet-1k, Imagenette, Imagewoof)

# Changed KMeans clip_RN50 from PCA 500 to UMAP 50, so it uses fewer dimensions
# (probably more stable than using 500-d which is what PCA needs to marginally beat UMAP)
# NOTE: "dim_reducer" is set to None here rather than the string "None" used
# elsewhere. select_rows treats the two alike, but filter2command omits
# None-valued keys from the generated command.
BEST_PARAMS_v2["KMeans"]["clip_RN50"].update(
    {"dim_reducer": None, "ndim_reduced": None, "zscore": False, "pca_variance": None}
)
BEST_PARAMS_v2["KMeans"]["clip_RN50"].update(
    {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50}
)
# Changed KMeans MAE from PCA 85% to PCA 200
# (since we see perf above plateaus at 200-d, there is no point going above that)
BEST_PARAMS_v2["KMeans"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "zscore": True, "ndim_reduced": 200, "pca_variance": None}
)
# Changed KMeans clip_vitb16 from PCA 500 to PCA 75%
# (gives a notably better train set AMI measurement above)
BEST_PARAMS_v2["KMeans"]["clip_vitb16"].update(
    {"dim_reducer": "PCA", "zscore": True, "pca_variance": 0.75, "ndim_reduced": None}
)

# Changed AffinityPropagation dino_resnet50 from PCA 95% to PCA 10
# (performance is basically equal, so no point using higher-dim space;
# could have done UMAP 50 instead with basically equal train AMI to PCA 10,
# but didn't for consistency with other models)
BEST_PARAMS_v2["AffinityPropagation"]["dino_resnet50"].update(
    {"dim_reducer": "PCA", "zscore": True, "ndim_reduced": 10, "pca_variance": None}
)
# Changed AffinityPropagation MAE from PCA 95% to PCA 100
BEST_PARAMS_v2["AffinityPropagation"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "zscore": True, "ndim_reduced": 100, "pca_variance": None}
)

### Agglomerative specific settings

In [None]:
# v1 AgglomerativeClustering settings: distance metric and linkage per encoder.
# Encoders using euclidean distance with ward linkage.
for model in [
    "resnet50",
    "mocov3_resnet50",
    "vicreg_resnet50",
    "vitb16",
    "timm_vit_base_patch16_224.mae",
]:
    BEST_PARAMS_v1["AgglomerativeClustering"][model].update(
        {
            "distance_metric": "euclidean",
            "aggclust_linkage": "ward",
        }
    )
# Encoders using euclidean distance with average linkage.
for model in ["dino_resnet50", "clip_RN50", "dino_vitb16"]:
    BEST_PARAMS_v1["AgglomerativeClustering"][model].update(
        {
            "distance_metric": "euclidean",
            "aggclust_linkage": "average",
        }
    )
# Encoders using chebyshev distance with average linkage.
for model in ["mocov3_vit_base", "clip_vitb16"]:
    BEST_PARAMS_v1["AgglomerativeClustering"][model].update(
        {
            "distance_metric": "chebyshev",
            "aggclust_linkage": "average",
        }
    )

In [None]:
# v2 AgglomerativeClustering settings: distance metric and linkage per encoder.
# vicreg_resnet50 is the only change from v1 to v2
# (it moves from ward linkage to average linkage).
for model in ["resnet50", "mocov3_resnet50", "vitb16", "timm_vit_base_patch16_224.mae"]:
    BEST_PARAMS_v2["AgglomerativeClustering"][model].update(
        {
            "distance_metric": "euclidean",
            "aggclust_linkage": "ward",
        }
    )
for model in ["vicreg_resnet50", "dino_resnet50", "clip_RN50", "dino_vitb16"]:
    BEST_PARAMS_v2["AgglomerativeClustering"][model].update(
        {
            "distance_metric": "euclidean",
            "aggclust_linkage": "average",
        }
    )
for model in ["mocov3_vit_base", "clip_vitb16"]:
    BEST_PARAMS_v2["AgglomerativeClustering"][model].update(
        {
            "distance_metric": "chebyshev",
            "aggclust_linkage": "average",
        }
    )

In [None]:
# Derive two AgglomerativeClustering variants, each an independent deep copy
# of the settings above, so they can be customised separately afterwards.
# NOTE(review): "w/ C" / "w/o C" presumably mean with / without a known number
# of clusters — confirm against the paper or experiment scripts.
BEST_PARAMS_v1["AC w/ C"] = copy.deepcopy(BEST_PARAMS_v1["AgglomerativeClustering"])
BEST_PARAMS_v1["AC w/o C"] = copy.deepcopy(BEST_PARAMS_v1["AgglomerativeClustering"])
BEST_PARAMS_v2["AC w/ C"] = copy.deepcopy(BEST_PARAMS_v2["AgglomerativeClustering"])
BEST_PARAMS_v2["AC w/o C"] = copy.deepcopy(BEST_PARAMS_v2["AgglomerativeClustering"])

In [None]:
# "AC w/ C": no distance threshold is used (aggclust_dist_thresh is None)
# for every encoder, in both parameter versions.
for model in RESNET50_MODELS + VITB16_MODELS:
    BEST_PARAMS_v1["AC w/ C"][model].update({"aggclust_dist_thresh": None})
    BEST_PARAMS_v2["AC w/ C"][model].update({"aggclust_dist_thresh": None})

In [None]:
# v2 only: "AC w/o C" additionally enables the "average" second z-score
# option and the ndim correction for every encoder (v1 is left unchanged).
for model in RESNET50_MODELS + VITB16_MODELS:
    BEST_PARAMS_v2["AC w/o C"][model].update(
        {"zscore2": "average", "ndim_correction": True}
    )

In [None]:
# Run AgglomerativeClustering experiments with number of clusters unknown
# v1 distance thresholds (aggclust_dist_thresh) per encoder:
#   resnet50                        20.0
#   mocov3_resnet50                 20.0
#   vicreg_resnet50                 20.0
#   vitb16                          20.0
#   dino_resnet50                    1.0
#   clip_RN50                        1.0
#   dino_vitb16                      2.0
#   mocov3_vit_base                  1.0
#   clip_vitb16                      0.5
#   timm_vit_base_patch16_224.mae  200.0

for model in ["resnet50", "mocov3_resnet50", "vicreg_resnet50", "vitb16"]:
    BEST_PARAMS_v1["AC w/o C"][model].update({"aggclust_dist_thresh": 20.0})
for model in ["dino_resnet50", "clip_RN50", "mocov3_vit_base"]:
    BEST_PARAMS_v1["AC w/o C"][model].update({"aggclust_dist_thresh": 1.0})
BEST_PARAMS_v1["AC w/o C"]["dino_vitb16"]["aggclust_dist_thresh"] = 2.0
BEST_PARAMS_v1["AC w/o C"]["clip_vitb16"]["aggclust_dist_thresh"] = 0.5
BEST_PARAMS_v1["AC w/o C"]["timm_vit_base_patch16_224.mae"][
    "aggclust_dist_thresh"
] = 200.0

In [None]:
# v2 distance thresholds (aggclust_dist_thresh) for "AC w/o C", one per encoder.
BEST_PARAMS_v2["AC w/o C"]["resnet50"]["aggclust_dist_thresh"] = 2.0
BEST_PARAMS_v2["AC w/o C"]["mocov3_resnet50"]["aggclust_dist_thresh"] = 10.0
BEST_PARAMS_v2["AC w/o C"]["vicreg_resnet50"]["aggclust_dist_thresh"] = 0.5
BEST_PARAMS_v2["AC w/o C"]["dino_resnet50"]["aggclust_dist_thresh"] = 0.5
BEST_PARAMS_v2["AC w/o C"]["clip_RN50"]["aggclust_dist_thresh"] = 0.5
BEST_PARAMS_v2["AC w/o C"]["vitb16"]["aggclust_dist_thresh"] = 2.0
BEST_PARAMS_v2["AC w/o C"]["mocov3_vit_base"]["aggclust_dist_thresh"] = 1.0
BEST_PARAMS_v2["AC w/o C"]["timm_vit_base_patch16_224.mae"][
    "aggclust_dist_thresh"
] = 5.0
BEST_PARAMS_v2["AC w/o C"]["dino_vitb16"]["aggclust_dist_thresh"] = 0.2
BEST_PARAMS_v2["AC w/o C"]["clip_vitb16"]["aggclust_dist_thresh"] = 1.0

### HDBSCAN

In [None]:
# v1 HDBSCAN settings: euclidean distance and the "eom" method for every encoder.
for model in RESNET50_MODELS + VITB16_MODELS:
    BEST_PARAMS_v1["HDBSCAN"][model].update(
        {
            "distance_metric": "euclidean",
            "hdbscan_method": "eom",
        }
    )

v2 selection

|    | model                         | distance_metric   | hdbscan_method   |      AMI |
|---:|:------------------------------|:------------------|:-----------------|---------:|
|  0 | resnet50                      | euclidean         | eom              | 0.828368 |
|  1 | mocov3_resnet50               | euclidean         | eom              | 0.531644 |
|  2 | vicreg_resnet50               | l1                | eom              | 0.472324 |
|  3 | dino_resnet50                 | l1                | eom              | 0.503147 |
|  4 | clip_RN50                     | l1                | eom              | 0.461363 |
|  5 | vitb16                        | chebyshev         | eom              | 0.906110 |
|  6 | mocov3_vit_base               | euclidean         | eom              | 0.629966 |
|  7 | timm_vit_base_patch16_224.mae | euclidean         | eom              | 0.070495 |
|  8 | dino_vitb16                   | l1                | eom              | 0.691547 |
|  9 | clip_vitb16                   | l1                | eom              | 0.592489 |

In [None]:
# v2 HDBSCAN settings, following the selection table above:
# start every encoder at euclidean / eom, then override the distance metric
# for the encoders where l1 or chebyshev was selected.
for model in RESNET50_MODELS + VITB16_MODELS:
    BEST_PARAMS_v2["HDBSCAN"][model].update(
        {
            "distance_metric": "euclidean",
            "hdbscan_method": "eom",
        }
    )
for model in [
    "vicreg_resnet50",
    "dino_resnet50",
    "clip_RN50",
    "dino_vitb16",
    "clip_vitb16",
]:
    BEST_PARAMS_v2["HDBSCAN"][model].update(
        {
            "distance_metric": "l1",
        }
    )
BEST_PARAMS_v2["HDBSCAN"]["vitb16"]["distance_metric"] = "chebyshev"

### Define placeholder for interactive hp search

In [None]:
# Fresh table holding only the per-clusterer defaults for every encoder.
# BEST_PARAMS is rebound to it; BEST_PARAMS_v1 and BEST_PARAMS_v2 keep
# referring to the tables built above and are not affected.
BEST_PARAMS_vX = {
    clusterer: {
        model: copy.deepcopy(DEFAULT_PARAMS[clusterer])
        for model in RESNET50_MODELS + VITB16_MODELS
    }
    for clusterer in ALL_CLUSTERERS
}
BEST_PARAMS = BEST_PARAMS_vX

## Utility functions

In [None]:
def categorical_cmap(nc, nsc, cmap="tab10", continuous=False):
    """
    Build a ListedColormap with ``nsc`` shades for each of ``nc`` base colours.

    The base colours are read from the matplotlib colormap named ``cmap``:
    its first ``nc`` entries by default, or ``nc`` evenly spaced samples over
    [0, 1] when ``continuous`` is true. For each base colour the shades keep
    the hue while saturation is interpolated towards 0.25 and value towards 1.

    Raises ValueError if ``nc`` exceeds the number of colours in ``cmap``.

    Adapted from https://stackoverflow.com/a/47232942/1960959
    """
    base = plt.get_cmap(cmap)
    if nc > base.N:
        raise ValueError("Too many categories for colormap.")
    samples = np.linspace(0, 1, nc) if continuous else np.arange(nc, dtype=int)
    base_colours = base(samples)
    shades = np.zeros((nc * nsc, 3))
    for idx, rgba in enumerate(base_colours):
        # Drop the alpha channel before converting to HSV.
        hue, sat, val = matplotlib.colors.rgb_to_hsv(rgba[:3])
        block = np.empty((nsc, 3))
        block[:, 0] = hue
        block[:, 1] = np.linspace(sat, 0.25, nsc)
        block[:, 2] = np.linspace(val, 1, nsc)
        start = idx * nsc
        shades[start : start + nsc, :] = matplotlib.colors.hsv_to_rgb(block)
    return matplotlib.colors.ListedColormap(shades)

In [None]:
# Preview the colormap: 5 base colours (one per ResNet-50 encoder) with
# 3 shades each (one per validation dataset).
categorical_cmap(len(RESNET50_MODELS), len(VALIDATION_DATASETS))

In [None]:
def select_rows(df, filters, allow_missing=True):
    """
    Return the rows of ``df`` that satisfy every ``column -> value`` filter.

    The filter keys "dataset" and "clusterer" are looked up in the columns
    "dataset_name" and "clusterer_name" respectively.

    A wanted value of None, "None" or "none" matches cells that are missing
    or hold the string "None" / "none". Any other wanted value matches cells
    equal to the value or to its string form; when ``allow_missing`` is true,
    missing cells match as well.
    """
    aliases = {"dataset": "dataset_name", "clusterer": "clusterer_name"}
    mask = np.ones(len(df), dtype=bool)
    for key, wanted in filters.items():
        column = df[aliases.get(key, key)]
        if wanted is None or wanted == "None" or wanted == "none":
            hit = pd.isna(column) | (column == "None") | (column == "none")
        else:
            hit = (column == wanted) | (column == str(wanted))
            if allow_missing:
                hit = hit | pd.isna(column)
        mask &= hit
    return df[mask]

In [None]:
def find_differing_columns(df, cols=None):
    """
    List the columns of ``df`` that hold more than one distinct value.

    Only the names in ``cols`` are examined (every column when ``cols`` is
    None); names that are not columns of ``df`` are ignored. Missing values
    count as a distinct value. The result follows the iteration order of the
    examined names.
    """
    candidates = df.columns if cols is None else cols
    return [
        name
        for name in candidates
        if name in df.columns and df[name].nunique(dropna=False) > 1
    ]

In [None]:
def filter2command(*filters, partition="val"):
    """
    Build the ``sbatch`` command line that runs the job described by ``filters``.

    The filter dicts are merged left to right (later values win). The merged
    "dataset", "clusterer" and "model" entries feed the job name; 20G of
    memory is requested for AgglomerativeClustering on imagenet and 4G
    otherwise. The array index is 100 for the "val" partition, 1 for "test"
    and 0 for anything else.

    Every merged entry whose value is not None becomes a command-line option.
    "zscore", "normalize", "zscore2" and "ndim_correction" are emitted as
    on/off flags (the string "False" counts as off); all other keys are
    emitted as ``--key-with-dashes=value``.
    """
    merged = {}
    for mapping in filters:
        for key, value in mapping.items():
            merged[key] = value

    dataset = merged.get("dataset", "")
    clusterer = merged.get("clusterer", "")
    model = merged.get("model", "")

    needs_big_mem = dataset == "imagenet" and clusterer == "AgglomerativeClustering"
    mem = 20 if needs_big_mem else 4
    seed = 100 if partition == "val" else (1 if partition == "test" else 0)

    def is_off(value):
        # Both a falsy value and the literal string "False" disable a flag.
        return value == "False" or not value

    parts = [
        f"sbatch --array={seed} --mem={mem}G",
        f' --job-name="zsc-{model}-{dataset}-{clusterer}"',
        f" slurm/cluster.slrm --partition={partition}",
    ]
    for key, value in merged.items():
        if value is None:
            continue
        if key == "zscore":
            parts.append(" --no-zscore" if is_off(value) else " --zscore")
        elif key == "normalize":
            # There is no negative flag for normalize: emit nothing when off.
            if not is_off(value):
                parts.append(" --normalize")
        elif key == "zscore2":
            if is_off(value):
                parts.append(" --no-zscore2")
            elif value == "average":
                parts.append(" --azscore2")
            else:
                parts.append(" --zscore2")
        elif key == "ndim_correction":
            parts.append(
                " --no-ndim-correction" if is_off(value) else " --ndim-correction"
            )
        else:
            parts.append(f" --{key.replace('_', '-')}={value}")
    return "".join(parts)

# Final results

In [None]:
# Datasets reported in the results tables, in column order.
TEST_DATASETS = [
    "imagenet",
    "cifar10",
    "cifar100",
    "mnist",
    "fashionmnist",
    "svhn",
    "flowers102",
    "aircraft",
    "nabirds",
    "inaturalist",
]
# Short dataset names used as table column headers.
DATASET2SH = {
    "aircraft": "Aircraft",
    "cifar10": "C10",
    "cifar100": "C100",
    "flowers102": "Flowers",
    "fashionmnist": "fMNIST",
    "imagenet": "IN1k",
    "imagenette": "IN10",
    "imagewoof": "INwf",
    "inaturalist": "iNat21",
    "mnist": "MNIST",
    "nabirds": "NABirds",
    "svhn": "SVHN",
}
# Encoders grouped by backbone; the keys are used as group labels in the tables.
MODEL_GROUPS = {
    "ResNet-50": RESNET50_MODELS,
    "ViT-B": VITB16_MODELS,
}
# Short encoder names used as table row labels. The same label appears once
# per backbone, so it is only unique within a model group.
MODEL2SH = {
    "resnet50": "Supervised",
    "mocov3_resnet50": "MoCo-v3",
    "vicreg_resnet50": "VICReg",
    "dino_resnet50": "DINO",
    "clip_RN50": "CLIP",
    "vitb16": "Supervised",
    "mocov3_vit_base": "MoCo-v3",
    "timm_vit_base_patch16_224.mae": "MAE",
    "dino_vitb16": "DINO",
    "clip_vitb16": "CLIP",
}
# Short clusterer names used as table row labels; clusterers without an entry
# are shown under their full name.
CLUSTERER2SH = {
    "KMeans": "K-Means",
    "AffinityPropagation": "Affinity Prop",
    "AgglomerativeClustering": "AC",
    "AC w/ C": "AC w/  C",
}

## Fetch results

In [None]:
# Project is specified by <entity/project-name>
api = wandb.Api()
runs_test = api.runs(
    "uoguelph_mlrg/zs-ssl-clustering",
    filters={"state": "Finished", "config.partition": "test"},
)
len(runs_test)

In [None]:
# Collect the summary metrics, config and name of every fetched run.
summary_list, config_list, name_list = [], [], []
for run in runs_test:
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files
    summary_list.append(run.summary._json_dict)
    # .config contains the hyperparameters.
    #  We remove special values that start with _.
    config_list.append({k: v for k, v in run.config.items() if not k.startswith("_")})
    # .name is the human-readable name of the run.
    name_list.append(run.name)

# Flatten each run into one record: name, then config, then summary values
# (summary wins when a key is present in both). Keys starting with "_" are
# dropped, except "_timestamp" which is kept explicitly.
rows = []
config_keys = set()
summary_keys = set()
for summary, config, name in zip(summary_list, config_list, name_list):
    row = {"name": name}
    row.update({k: v for k, v in config.items() if not k.startswith("_")})
    row.update({k: v for k, v in summary.items() if not k.startswith("_")})
    # Raises KeyError if a run's summary has no "_timestamp" entry.
    row["_timestamp"] = summary["_timestamp"]
    rows.append(row)
    # Remember which columns came from configs and which from summaries.
    config_keys = config_keys.union(config.keys())
    summary_keys = summary_keys.union(summary.keys())

test_runs_df = pd.DataFrame.from_records(rows)

# Handle changed default value for spectral_assigner after config arg was introduced
# - runs of other clusterers: spectral_assigner is cleared;
# - SpectralClustering runs with no recorded value: filled in as "kmeans".
if "spectral_assigner" not in test_runs_df.columns:
    test_runs_df["spectral_assigner"] = None
select = test_runs_df["clusterer_name"] != "SpectralClustering"
test_runs_df.loc[select, "spectral_assigner"] = None
select = (test_runs_df["clusterer_name"] == "SpectralClustering") & pd.isna(
    test_runs_df["spectral_assigner"]
)
test_runs_df.loc[select, "spectral_assigner"] = "kmeans"

# Runs with no recorded zscore2 value are treated as zscore2=False.
if "zscore2" not in test_runs_df.columns:
    test_runs_df["zscore2"] = False
test_runs_df.loc[pd.isna(test_runs_df["zscore2"]), "zscore2"] = False

# Likewise, runs with no recorded ndim_correction value are treated as False.
if "ndim_correction" not in test_runs_df.columns:
    test_runs_df["ndim_correction"] = False
test_runs_df.loc[pd.isna(test_runs_df["ndim_correction"]), "ndim_correction"] = False

In [None]:
# Inspect the config keys of the last run processed above
# (`config` is the loop variable left over from the previous cell).
config.keys()

In [None]:
# Drop the worker/memory keys so that they are not listed among the
# differing columns when duplicate runs are compared in the results cells.
config_keys = config_keys.difference(
    {"workers", "memory_avail_GB", "memory_total_GB", "memory_slurm"}
)

In [None]:
# Display the assembled table of test runs.
test_runs_df

In [None]:
# List the distinct dataset names present in the fetched runs.
list(test_runs_df["dataset_name"].unique())

In [None]:
# Build a LaTeX table of test-set results for a single clusterer:
# one row per encoder (grouped by backbone), one column per test dataset.
metric_key = "AMI"
show_pc = True  # report the metric multiplied by 100
show_fmt = "{:5.1f}"
eps = 0.001  # tolerance when testing for a tie with the best / second-best value
override_fields = {
    # "aggclust_dist_thresh": None,  # to flip between unknown/known n clusters for AC
    "predictions_dir": "y_pred",
}
BEST_PARAMS = BEST_PARAMS_v2

# KMeans  AffinityPropagation  AgglomerativeClustering  HDBSCAN
clusterer = "AgglomerativeClustering"

best_results = {k: [] for k in TEST_DATASETS}
# Two passes over the same loops: the first (dummy=True) only collects the
# per-dataset scores into best_results; the second builds the table and uses
# those scores to highlight the best and second-best entry of each column.
for dummy in [True, False]:
    cmds = []  # sbatch commands for the combinations that have no result yet
    latex_table = r"% Results for " + f"{metric_key}, {clusterer}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"\label{tab:" + clusterer + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_DATASETS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for dataset in TEST_DATASETS:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    for i_group, model_group_name in enumerate(list(MODEL_GROUPS.keys())):
        if i_group > 0:
            latex_table += r"\midrule" + "\n"
        group_models = MODEL_GROUPS[model_group_name]
        for i_model, model in enumerate(group_models):
            if i_model == 0:
                # Rotated group label spanning every row of the group.
                latex_table += (
                    r"\parbox[t]{2mm}{\multirow{"
                    + str(len(group_models))
                    + r"}{*}{\rotatebox[origin=c]{90}{"
                    + model_group_name
                    + "}}}"
                )
                latex_table += "\n"
            latex_table += f"& {MODEL2SH.get(model, model):<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                # Select the runs for this (model, dataset, clusterer) ...
                filter = {
                    "model": model,
                    "dataset": dataset,
                    "clusterer": clusterer,
                }
                sdf = select_rows(test_runs_df, filter, allow_missing=False)
                # ... then narrow them down to the chosen hyperparameters.
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2 = {k: v for k, v in filter2.items() if k not in filter}
                filter2.update(override_fields)
                sdf = select_rows(sdf, filter2, allow_missing=False)
                if len(sdf) < 1:
                    print(f"No data for {filter} {filter2}")
                    if clusterer == "AffinityPropagation" and dataset in [
                        "imagenet",
                        "inaturalist",
                    ]:
                        # No job is queued for these combinations.
                        continue
                    cmds.append(filter2command(filter, filter2, partition="test"))
                    continue
                if len(sdf) > 1:
                    # Report duplicate runs that disagree on the reported metric.
                    perf = sdf.iloc[0][metric_key]
                    if sum(sdf[metric_key] != perf) > 0:
                        print()
                        print(
                            f"More than one result with {metric_key}s:",
                            list(sdf[metric_key]),
                        )
                        print(f"for search {filter}\nand {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                # Duplicate runs are summarised by their median.
                my_val = np.median(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    continue
                is_best = my_val + eps >= np.max(best_results[dataset])
                if len(best_results[dataset]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[dataset])[-2]
                else:
                    is_secd = False
                if show_pc:
                    my_val = my_val * 100
                latex_table += " $"
                if is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                latex_table += r"}" if is_best or is_secd else " "
                latex_table += "$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"


# Print the commands that would produce the missing results, then the table.
if len(cmds) > 0:
    print()
for cmd in cmds:
    print(cmd)

print()
print("Done!")
print()
print(f"Here is your results table for {clusterer}:")
print()
print()
print(latex_table)

## Grouping by encoder

In [None]:
# Build a LaTeX table of test-set results for a single backbone:
# one group of rows per encoder, one row per clusterer, one column per dataset.
metric_key = "AMI"
show_pc = True  # report the metric multiplied by 100
show_fmt = "{:4.0f}"
eps = 0.001  # tolerance when testing for a tie with the best / second-best value
override_fields = {
    # "aggclust_dist_thresh": None,  # to flip between unknown/known n clusters for AC
    "predictions_dir": "y_pred",
}
BEST_PARAMS = BEST_PARAMS_v2

backbone = "ResNet-50"

# AgglomerativeClustering is listed twice on purpose: the first occurrence is
# reported as "AC  w/ C" (aggclust_dist_thresh forced to None), the second as
# "AC w/o C".
CLUSTERERS = [
    "KMeans",
    "AgglomerativeClustering",
    "AgglomerativeClustering",
    "AffinityPropagation",
    "HDBSCAN",
]
print(MODEL2SH)

best_results = {k: [] for k in TEST_DATASETS}
# Two passes over the same loops: the first (dummy=True) only collects the
# per-dataset scores into best_results; the second builds the table and uses
# those scores to highlight the best and second-best entry of each column.
for dummy in [True, False]:
    cmds = []  # sbatch commands for the combinations that have no result yet
    latex_table = r"% Results for " + f"{metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"\label{tab:" + backbone + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_DATASETS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Clusterer':<11s}"
    for dataset in TEST_DATASETS:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
        print(model)
        if i_group > 0:
            latex_table += r"\midrule" + "\n"

        first_agg = True
        for i_clusters, clusterer in enumerate(CLUSTERERS):
            if i_clusters == 0:
                # Rotated encoder label spanning every clusterer row.
                latex_table += (
                    r"\parbox[t]{2mm}{\multirow{"
                    + str(len(CLUSTERERS))
                    + r"}{*}{\rotatebox[origin=c]{90}{"
                    + MODEL2SH[model]
                    + "}}}"
                )
                latex_table += "\n"
            # NOTE(review): this reset discards the cell-level override_fields
            # defined above, so the "predictions_dir" filter is not applied in
            # this table — confirm whether that is intended.
            override_fields = {}
            clusterername = CLUSTERER2SH.get(clusterer, clusterer)
            if first_agg and clusterer == "AgglomerativeClustering":
                first_agg = False
                override_fields = {"aggclust_dist_thresh": None}
                clusterername = "AC  w/ C"
            elif clusterer == "AgglomerativeClustering":
                clusterername = "AC w/o C"
            latex_table += f"& {clusterername:<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                # Select the runs for this (model, dataset, clusterer) ...
                filter = {
                    "model": model,
                    "dataset": dataset,
                    "clusterer": clusterer,
                }
                sdf = select_rows(test_runs_df, filter, allow_missing=False)
                # ... then narrow them down to the chosen hyperparameters.
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2 = {k: v for k, v in filter2.items() if k not in filter}
                filter2.update(override_fields)
                sdf = select_rows(sdf, filter2, allow_missing=False)
                if len(sdf) < 1:
                    print(f"No data for {filter} {filter2}")
                    cmds.append(filter2command(filter, filter2, partition="test"))
                    continue
                if len(sdf) > 1:
                    # Report duplicate runs that disagree on the reported metric.
                    perf = sdf.iloc[0][metric_key]
                    if sum(sdf[metric_key] != perf) > 0:
                        print()
                        print(
                            f"More than one result with {metric_key}s:",
                            list(sdf[metric_key]),
                        )
                        print(f"for search {filter}\nand {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                # Duplicate runs are summarised by their median.
                my_val = np.median(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    continue
                is_best = my_val + eps >= np.max(best_results[dataset])
                if len(best_results[dataset]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[dataset])[-2]
                else:
                    is_secd = False
                if show_pc:
                    my_val = my_val * 100
                latex_table += " $"
                if is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                latex_table += r"}" if is_best or is_secd else " "
                latex_table += "$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"


# Print the commands that would produce the missing results, then the table.
if len(cmds) > 0:
    print()
for cmd in cmds:
    print(cmd)

print()
print("Done!")
print()
# The table covers one backbone, so name the backbone here (not the clusterer
# left over from the last loop iteration).
print(f"Here is your results table for {backbone}:")
print()
print()
print(latex_table)

## Grouping by clusterer

In [None]:
# Show the current clusterer list (rebound by the table cell above).
CLUSTERERS

In [None]:
# Build a LaTeX results table for one metric and one backbone: encoders are
# the rows (grouped by clustering method) and the test datasets the columns.
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True  # if True, displayed values are multiplied by 100
show_fmt = "{:4.0f}"
highlight_best = True
use_si_num = False  # if True, numbers are wrapped in \num{...} rather than $...$
eps = 0.005  # tolerance used when deciding whether a value ties with the best
override_fields = {
    "predictions_dir": "y_pred",
    # "aggclust_dist_thresh": None,  # to flip between unknown/known n clusters for AC
}
BEST_PARAMS = BEST_PARAMS_v2

backbone = "ResNet-50"  # "ResNet-50" or "ViT-B"

if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
else:
    CLUSTERERS = ["KMeans", "AC w/ C", "AC w/o C", "AffinityPropagation", "HDBSCAN"]
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

print(MODEL2SH)

best_results = {k: [] for k in TEST_DATASETS}
best_results_grouped = {k: defaultdict(list) for k in TEST_DATASETS}

# Two passes over the same loops: the first (dummy=True) only collects the
# values per dataset / clusterer so that the second pass knows which entries
# are best and second best; the table text built during the first pass is
# discarded when `latex_table` is re-initialised.
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    label = backbone
    if metric_key == "AMI":
        # The AMI table additionally gets a label without the metric prefix.
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_DATASETS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for dataset in TEST_DATASETS:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        # Reference row: the true number of clusters of each dataset.
        latex_table += r"& Num targets"
        for dataset in TEST_DATASETS:
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference group: the same metric evaluated on the ground-truth
        # labels, read from the matching "*_true" column.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{5}{*}{\rotatebox[origin=c]{90}{"
            + clusterername
            + "}}}"
        )
        latex_table += "\n"
        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for dataset in TEST_DATASETS:
                latex_table += " &"
                row_filter = {"model": model, "dataset": dataset}
                if model == "timm_vit_base_patch16_224.mae":
                    row_filter["dim_reducer"] = "PCA"
                    row_filter["pca_variance"] = 0.95
                else:
                    row_filter["dim_reducer_man"] = "UMAP"
                    row_filter["ndim_reduced_man"] = 50
                    row_filter["dim_reducer_man_metric"] = "euclidean"
                sdf = select_rows(test_runs_df, row_filter, allow_missing=False)
                sdf = sdf[~pd.isna(sdf[metric_key2])]
                # Fix: report the ground-truth column (metric_key2); the
                # previous code took the median of the *_pred column here.
                my_val = np.nanmedian(sdf[metric_key2])
                if dummy:
                    if not np.isnan(my_val):
                        best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    # No usable value (e.g. no rows survived the filter).
                    latex_table += r"   --  "
                    continue
                is_best_grp = my_val + eps >= np.max(
                    best_results_grouped[dataset][clusterername]
                )
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    for i_clusters, clusterer in enumerate(CLUSTERERS):
        # The agglomerative variants ("AC w/ C", "AC w/o C") are separate
        # entries of CLUSTERERS and BEST_PARAMS is indexed by those names, so
        # `override_fields` is applied unchanged to every clusterer.
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)

        if i_clusters > 0:
            latex_table += r"\midrule" + "\n"

        latex_table += (
            r"\parbox[t]{2mm}{\multirow{5}{*}{\rotatebox[origin=c]{90}{"
            + clusterername
            + "}}}"
        )
        latex_table += "\n"

        for model in MODEL_GROUPS[backbone]:
            print(model)

            latex_table += f"& {MODEL2SH[model]:<10s}"
            for dataset in TEST_DATASETS:
                latex_table += " &"
                row_filter = {
                    "model": model,
                    "dataset": dataset,
                }
                sdf = select_rows(test_runs_df, row_filter, allow_missing=False)
                # Defaults, overlaid with the best hyperparameters for this
                # clusterer/model, minus keys already used in `row_filter`,
                # then overlaid with the explicit overrides.
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2 = {k: v for k, v in filter2.items() if k not in row_filter}
                filter2.update(override_fields)
                sdf = select_rows(sdf, filter2, allow_missing=False)
                if len(sdf) < 1:
                    print(f"No data for {row_filter} {filter2}")
                    cmds.append(filter2command(row_filter, filter2, partition="test"))
                    if not dummy:
                        # latex_table += r"\multicolumn{1}{c}{--}"
                        latex_table += r"   --  "
                    continue
                if len(sdf) > 1:
                    # Several matching runs: report when they disagree on AMI
                    # and which config columns differ between them.
                    perf = sdf.iloc[0]["AMI"]
                    if (sdf["AMI"] != perf).any():
                        print()
                        print("More than one result with AMIs:", list(sdf["AMI"]))
                        print(f"for search {row_filter}\nand {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.nanmedian(sdf[metric_key])
                if dummy:
                    # Fix: keep NaN out of the reference lists; a single NaN
                    # would make np.max / np.sort return NaN and silently
                    # switch off all highlighting for this dataset.
                    if not np.isnan(my_val):
                        best_results[dataset].append(my_val)
                        best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best = my_val + eps >= np.max(best_results[dataset])
                if len(best_results[dataset]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[dataset])[-2]
                else:
                    is_secd = False
                is_best_grp = my_val + eps >= np.max(
                    best_results_grouped[dataset][clusterername]
                )
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                # Outer wrapper: best / second best across all clusterers.
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                # Inner wrapper: best within this clusterer group.
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"


# Print the commands queued for configurations that had no matching rows.
if len(cmds) > 0:
    print()
for cmd in cmds:
    print(cmd)

print()
print("Done!")
print()
print(f"Here is your results table for {metric_key}, {backbone}:")
print()
print()
print(latex_table)

In [None]:
# Display the current value of override_fields.
override_fields

In [None]:
# Display the BEST_PARAMS entry stored under the "AC w/o C" key.
BEST_PARAMS["AC w/o C"]

In [None]:
# Display the BEST_PARAMS entry stored under the "AgglomerativeClustering" key.
BEST_PARAMS["AgglomerativeClustering"]

In [None]:
# Select the test-run rows for one model / dataset / clusterer combination.
model, clusterer = "resnet50", "AgglomerativeClustering"
filter = dict(model=model, dataset="imagenet", clusterer=clusterer)
sdf = select_rows(test_runs_df, filter, allow_missing=False)

In [None]:
# Display the current selection.
sdf

In [None]:
# Defaults overlaid with the best hyperparameters for this clusterer/model.
filter2 = {**DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model]}

In [None]:
# Display the combined filter.
filter2

In [None]:
# Drop the keys that `filter` already constrains.
filter2 = {key: filter2[key] for key in filter2 if key not in filter}

In [None]:
# Display the pruned filter.
filter2

In [None]:
# Display the current value of override_fields.
override_fields

In [None]:
# Narrow the current selection with filter2 and display the remaining rows.
sdf = select_rows(sdf, filter2, allow_missing=False)
sdf

In [None]:
# Display the AMI column of the current selection.
sdf["AMI"]

In [None]:
# Select the test-run rows for one model / dataset pair (the clusterer is not
# part of this filter) and display them.
model, clusterer = "resnet50", "AC w/o C"
filter = dict(model=model, dataset="imagenet")
sdf = select_rows(test_runs_df, filter, allow_missing=False)
sdf

In [None]:
# Display the best hyperparameters stored for the current clusterer and model.
BEST_PARAMS[clusterer][model]

In [None]:
# Defaults overlaid with the best hyperparameters for this clusterer/model;
# display the result.
filter2 = {**DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model]}
filter2

In [None]:
# Display the current value of override_fields.
override_fields

In [None]:
# Narrow the current selection with filter2 and display the remaining rows.
sdf = select_rows(sdf, filter2, allow_missing=False)
sdf

In [None]:
# Display the AMI column of the current selection.
sdf["AMI"]

In [None]:
# Display the BEST_PARAMS_v2 entry stored under "AgglomerativeClustering".
BEST_PARAMS_v2["AgglomerativeClustering"]

## Correlation between AMI and Silhouette

In [None]:
# Display the current value of best_results_grouped.
best_results_grouped

In [None]:
# Scatter metric_key1 (y axis) against metric_key2 (x axis), one panel per
# backbone and one colour per test dataset, and save the figure as a PDF.
metric_key1 = "AMI"
metric_key2 = "silhouette-euclidean_pred"
BEST_PARAMS = BEST_PARAMS_v2


# "AgglomerativeClustering" is listed twice on purpose: its first occurrence
# is evaluated with aggclust_dist_thresh overridden to None, its second with
# the BEST_PARAMS values unchanged.
CLUSTERERS = [
    "KMeans",
    "AffinityPropagation",
    "AgglomerativeClustering",
    "AgglomerativeClustering",
    "HDBSCAN",
]
print(MODEL2SH)

fig, ax = plt.subplots(1, 2, sharey=True, figsize=(5, 3))


colors = [
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:red",
    "tab:purple",
    "tab:brown",
    "tab:pink",
    "tab:gray",
    "tab:olive",
    "tab:cyan",
]

# Commands for configurations without any matching rows. Defined here so the
# cell does not depend on a `cmds` list left behind by another cell.
cmds = []

correlations = {"ResNet-50": [], "ViT-B": []}
for i_backbone, backbone in enumerate(["ResNet-50", "ViT-B"]):
    my_valx_overall = []
    my_valy_overall = []

    # Per-clusterer values; both AgglomerativeClustering variants share one
    # key. These are collected but not read again within this cell.
    my_valx_method = {clusterer: [] for clusterer in CLUSTERERS}
    my_valy_method = {clusterer: [] for clusterer in CLUSTERERS}

    for i_dataset, dataset in enumerate(TEST_DATASETS):
        my_valx = []
        my_valy = []
        first_agg = True
        for clusterer in CLUSTERERS:
            # Fix: choose the overrides explicitly for every clusterer.
            # Previously only the AgglomerativeClustering branches assigned
            # `override_fields`, so other clusterers inherited whatever value
            # an earlier iteration (or another cell) had left behind.
            if first_agg and clusterer == "AgglomerativeClustering":
                first_agg = False
                override_fields = {"aggclust_dist_thresh": None}
            else:
                override_fields = {}

            for model in MODEL_GROUPS[backbone]:
                row_filter = {
                    "model": model,
                    "dataset": dataset,
                    "clusterer": clusterer,
                }
                sdf = select_rows(test_runs_df, row_filter, allow_missing=False)
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2 = {k: v for k, v in filter2.items() if k not in row_filter}
                filter2.update(override_fields)
                sdf = select_rows(sdf, filter2, allow_missing=False)
                if len(sdf) < 1:
                    cmds.append(filter2command(row_filter, filter2, partition="test"))
                    continue
                val_x = np.nanmedian(sdf[metric_key1])
                val_y = np.nanmedian(sdf[metric_key2])
                my_valx.append(val_x)
                my_valy.append(val_y)

                my_valx_method[clusterer].append(val_x)
                my_valy_method[clusterer].append(val_y)

        # Per-dataset Pearson correlation, ignoring pairs with a NaN in
        # either metric (one NaN would otherwise make the result NaN).
        xs = np.asarray(my_valx, dtype=float)
        ys = np.asarray(my_valy, dtype=float)
        valid = ~(np.isnan(xs) | np.isnan(ys))
        correlations[backbone].append(np.corrcoef(xs[valid], ys[valid])[0, 1])

        ax[i_backbone].scatter(
            my_valy,
            my_valx,
            color=colors[i_dataset],
            alpha=0.5,
            label=dataset,
        )
        my_valx_overall.extend(my_valx)
        my_valy_overall.extend(my_valy)

    # Decorate the panel once, after all datasets have been plotted.
    ax[i_backbone].set_xlabel(r"$S$")
    if i_backbone == 0:
        ax[i_backbone].set_ylabel(metric_key1)
    ax[i_backbone].set_ylim(-0.05, 1.05)
    ax[i_backbone].set_xlim(-1.05, 1.05)
    xs_all = np.asarray(my_valx_overall, dtype=float)
    ys_all = np.asarray(my_valy_overall, dtype=float)
    valid_all = ~(np.isnan(xs_all) | np.isnan(ys_all))
    pcc_all = np.corrcoef(xs_all[valid_all], ys_all[valid_all])[0, 1]
    ax[i_backbone].set_title(f"{backbone}\nPCC: {pcc_all:.2f}")


def label_fn(c, marker):
    """Return an empty line artist usable as a legend handle."""
    return plt.plot([], [], color=c, ls="None", marker=marker, linewidth=6)[0]


handles = [label_fn(colors[idx], "o") for idx in range(len(TEST_DATASETS))]
data_labels = [DATASET2SH.get(dataset, dataset) for dataset in TEST_DATASETS]

ax[1].legend(handles, data_labels, loc="center left", bbox_to_anchor=(1, 0.5))

print(data_labels)
print(correlations["ResNet-50"], len(correlations["ResNet-50"]))
print(correlations["ViT-B"], len(correlations["ViT-B"]))

fig.savefig("ami_silhouette.pdf", bbox_inches="tight")

In [None]:
# Scatter metric_key1 (y axis) against metric_key2 (x axis), one panel per
# backbone and one colour per test dataset; the overall Pearson correlation
# is written inside each panel and the figure is saved as a PDF.
metric_key1 = "AMI"
metric_key2 = "silhouette-og-euclidean_pred"
BEST_PARAMS = BEST_PARAMS_v2

# "AgglomerativeClustering" is listed twice on purpose: its first occurrence
# is evaluated with aggclust_dist_thresh overridden to None, its second with
# the BEST_PARAMS values unchanged.
CLUSTERERS = [
    "KMeans",
    "AffinityPropagation",
    "AgglomerativeClustering",
    "AgglomerativeClustering",
    "HDBSCAN",
]
print(MODEL2SH)

fig, ax = plt.subplots(1, 2, sharey=True, figsize=(5.5, 3))

colors = [
    "tab:red",
    "tab:blue",
    "tab:orange",
    "tab:green",
    "tab:purple",
    "tab:brown",
    "tab:pink",
    "tab:gray",
    "tab:olive",
    "tab:cyan",
]

# Commands for configurations without any matching rows. Defined here so the
# cell does not depend on a `cmds` list left behind by another cell.
cmds = []

correlations = {"ResNet-50": [], "ViT-B": []}
for i_backbone, backbone in enumerate(["ResNet-50", "ViT-B"]):
    my_valx_overall = []
    my_valy_overall = []

    for i_dataset, dataset in enumerate(TEST_DATASETS):
        my_valx = []
        my_valy = []
        first_agg = True
        for clusterer in CLUSTERERS:
            # Fix: choose the overrides explicitly for every clusterer.
            # Previously only the AgglomerativeClustering branches assigned
            # `override_fields`, so other clusterers inherited whatever value
            # an earlier iteration (or another cell) had left behind.
            if first_agg and clusterer == "AgglomerativeClustering":
                first_agg = False
                override_fields = {"aggclust_dist_thresh": None}
            else:
                override_fields = {}

            for model in MODEL_GROUPS[backbone]:
                row_filter = {
                    "model": model,
                    "dataset": dataset,
                    "clusterer": clusterer,
                }
                sdf = select_rows(test_runs_df, row_filter, allow_missing=False)
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2 = {k: v for k, v in filter2.items() if k not in row_filter}
                filter2.update(override_fields)
                sdf = select_rows(sdf, filter2, allow_missing=False)
                if len(sdf) < 1:
                    cmds.append(filter2command(row_filter, filter2, partition="test"))
                    continue
                my_valx.append(np.nanmedian(sdf[metric_key1]))
                my_valy.append(np.nanmedian(sdf[metric_key2]))

        # Per-dataset Pearson correlation with the same NaN masking that is
        # applied to the overall correlation below.
        xs = np.asarray(my_valx, dtype=float)
        ys = np.asarray(my_valy, dtype=float)
        valid = ~(np.isnan(xs) | np.isnan(ys))
        correlations[backbone].append(np.corrcoef(xs[valid], ys[valid])[0, 1])

        ax[i_backbone].scatter(
            my_valy,
            my_valx,
            color=colors[i_dataset],
            alpha=0.5,
            s=8,
            label=dataset,
        )
        my_valx_overall.extend(my_valx)
        my_valy_overall.extend(my_valy)

    ax[i_backbone].set_xlabel(r"$S$")
    if i_backbone == 0:
        ax[i_backbone].set_ylabel(metric_key1)
    ax[i_backbone].set_ylim(-0.05, 1.05)
    ax[i_backbone].set_xlim(-1.05, 1.05)
    ax[i_backbone].set_title(backbone)
    # Overall correlation for this backbone, ignoring pairs with a NaN.
    my_valx_overall = np.array(my_valx_overall)
    my_valy_overall = np.array(my_valy_overall)
    select = ~(np.isnan(my_valx_overall) | np.isnan(my_valy_overall))
    cor = np.corrcoef(my_valx_overall[select], my_valy_overall[select])
    ax[i_backbone].text(-0.85, 0.95, f"$r={cor[0,1]:.2f}$")


def label_fn(c, marker):
    """Return an empty line artist usable as a legend handle."""
    return plt.plot([], [], color=c, ls="None", marker=marker, linewidth=6)[0]


handles = [label_fn(colors[idx], "o") for idx in range(len(TEST_DATASETS))]
data_labels = [DATASET2SH.get(dataset, dataset) for dataset in TEST_DATASETS]

ax[1].legend(handles, data_labels, loc="center left", bbox_to_anchor=(1, 0.5))

print(data_labels)
print(correlations["ResNet-50"], len(correlations["ResNet-50"]))
print(correlations["ViT-B"], len(correlations["ViT-B"]))

fig.savefig(
    f"{metric_key1}_{metric_key2.replace('-euclidean', '')}.pdf", bbox_inches="tight"
)

In [None]:
# Display the current value of cor.
cor

In [None]:
# Rank encoders and clusterers by metric_key1 and draw the mean rank (with
# standard deviation as error bar) as horizontal bar charts, one panel per
# backbone; both figures are saved as PDFs.
metric_key1 = "AMI"
metric_key2 = "silhouette-euclidean_pred"
BEST_PARAMS = BEST_PARAMS_v2

# "AgglomerativeClustering" is listed twice on purpose: its first occurrence
# is evaluated with aggclust_dist_thresh overridden to None, its second with
# the BEST_PARAMS values unchanged.
CLUSTERERS = [
    "KMeans",
    "AffinityPropagation",
    "AgglomerativeClustering",
    "AgglomerativeClustering",
    "HDBSCAN",
]
# Display names, position-aligned with CLUSTERERS.
CLUSTERERS2 = ["K-Means", "Affinity Prop", "Agg w/ C", "Agg w/o C", "HDBSCAN"]
colors = ["tab:blue", "tab:orange", "tab:red", "tab:green", "tab:olive", "tab:cyan"]
cluster_to_color = dict(zip(CLUSTERERS2, colors))
print(MODEL2SH)

figenc, axenc = plt.subplots(1, 2, figsize=(6, 2))
figclus, axclus = plt.subplots(1, 2, figsize=(6, 2))

for i_backbone, backbone in enumerate(["ResNet-50", "ViT-B"]):
    models = list(MODEL_GROUPS[backbone])
    n_encoders = len(models)
    # Axes: encoders, clusterers, datasets.
    result_table = np.zeros((n_encoders, len(CLUSTERERS), len(TEST_DATASETS)))
    # The table is filled in a single pass; the previous `for dummy in
    # [True, False]` wrapper repeated exactly the same work twice.
    cmds = []

    for i_group, model in enumerate(models):
        first_agg = True
        for i_clusters, clusterer in enumerate(CLUSTERERS):
            # Fix: choose the overrides explicitly for every clusterer.
            # Previously only the AgglomerativeClustering branches assigned
            # `override_fields`, so other clusterers inherited whatever value
            # an earlier iteration (or another cell) had left behind.
            if first_agg and clusterer == "AgglomerativeClustering":
                first_agg = False
                override_fields = {"aggclust_dist_thresh": None}
            else:
                override_fields = {}

            for i_dataset, dataset in enumerate(TEST_DATASETS):
                row_filter = {
                    "model": model,
                    "dataset": dataset,
                    "clusterer": clusterer,
                }
                sdf = select_rows(test_runs_df, row_filter, allow_missing=False)
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2 = {k: v for k, v in filter2.items() if k not in row_filter}
                filter2.update(override_fields)
                sdf = select_rows(sdf, filter2, allow_missing=False)
                if len(sdf) < 1:
                    cmds.append(filter2command(row_filter, filter2, partition="test"))
                    # Sentinel far below any real score: missing results
                    # rank last.
                    result_table[i_group, i_clusters, i_dataset] = -100.0
                    continue
                median_val = np.nanmedian(sdf[metric_key1])
                # Fix: NaN sorts last in argsort, so after the reversal below
                # it would be ranked *best*; treat it like a missing result.
                if np.isnan(median_val):
                    median_val = -100.0
                result_table[i_group, i_clusters, i_dataset] = median_val

    print(result_table[0])

    print(backbone)
    print(MODEL_GROUPS[backbone])

    # Colour per encoder, chosen by substring of the model name; later
    # matches overwrite earlier ones.
    encoder_to_color = {}
    for model in models:
        if model == "resnet50" or model == "vitb16":
            encoder_to_color[model] = colors[0]
        if "mae" in model:
            encoder_to_color[model] = colors[1]
        if "vicreg" in model:
            encoder_to_color[model] = colors[2]
        if "clip" in model:
            encoder_to_color[model] = colors[3]
        if "moco" in model:
            encoder_to_color[model] = colors[4]
        if "dino" in model:
            encoder_to_color[model] = colors[5]

    print(encoder_to_color)
    # RANK PER ENCODER - go through each dataset, look at each clusterer,
    # and determine the rank of each encoder in that setting
    print(models)
    encoder_rank_values = np.arange(1, n_encoders + 1)
    ranks_encoders = np.zeros((n_encoders, len(CLUSTERERS), len(TEST_DATASETS)))
    for i_dataset in range(len(TEST_DATASETS)):
        for i_clusters in range(len(CLUSTERERS)):
            cluster_data = result_table[:, i_clusters, i_dataset]
            # Indices sorted by descending score; the argsort of that
            # permutation gives each encoder's 0-based rank position.
            rank = np.argsort(cluster_data)[::-1]
            ranks_encoders[:, i_clusters, i_dataset] = encoder_rank_values[
                rank.argsort()
            ]
    mean_rank_encoders = np.mean(ranks_encoders, axis=(1, 2))
    std_rank_encoders = np.std(ranks_encoders, axis=(1, 2))
    order = [
        (models[idx], mean_rank_encoders[idx], std_rank_encoders[idx])
        for idx in np.argsort(mean_rank_encoders)
    ]

    # Draw from worst to best so the best-ranked encoder ends up on top.
    for idx, (enc_name, enc_mean, enc_std) in enumerate(order[::-1]):
        axenc[i_backbone].barh(
            idx,
            enc_mean,
            xerr=enc_std,
            align="center",
            alpha=0.6,
            ecolor="black",
            color=encoder_to_color[enc_name],
            capsize=2,
            zorder=10,
        )

    encoder_ticks = list(range(1, n_encoders + 1))
    axenc[i_backbone].set_yticks([])
    axenc[i_backbone].set_yticklabels([])
    axenc[i_backbone].set_xticks(encoder_ticks)
    axenc[i_backbone].set_xticklabels(encoder_ticks)
    axenc[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
    axenc[i_backbone].set_title(f"{backbone}")

    # RANK PER CLUSTERER - go through each dataset, look at each encoder,
    # and determine the rank of each clusterer in that setting

    print(CLUSTERERS2)
    clusterer_rank_values = np.arange(1, len(CLUSTERERS) + 1)
    ranks_clusterers = np.zeros((n_encoders, len(CLUSTERERS2), len(TEST_DATASETS)))
    for i_dataset in range(len(TEST_DATASETS)):
        for i_encoder in range(n_encoders):
            encoder_data = result_table[i_encoder, :, i_dataset]
            rank = np.argsort(encoder_data)[::-1]
            ranks_clusterers[i_encoder, :, i_dataset] = clusterer_rank_values[
                rank.argsort()
            ]
    mean_rank_clusters = np.mean(ranks_clusterers, axis=(0, 2))
    std_rank_clusters = np.std(ranks_clusterers, axis=(0, 2))
    order = [
        (CLUSTERERS2[idx], mean_rank_clusters[idx], std_rank_clusters[idx])
        for idx in np.argsort(mean_rank_clusters)
    ]

    for idx, (clus_name, clus_mean, clus_std) in enumerate(order[::-1]):
        axclus[i_backbone].barh(
            idx,
            clus_mean,
            xerr=clus_std,
            align="center",
            alpha=0.6,
            ecolor="black",
            color=cluster_to_color[clus_name],
            capsize=2,
            zorder=10,
        )

    clusterer_ticks = list(range(1, len(CLUSTERERS2) + 1))
    axclus[i_backbone].set_yticks([])
    axclus[i_backbone].set_yticklabels([])
    axclus[i_backbone].set_xticks(clusterer_ticks)
    axclus[i_backbone].set_xticklabels(clusterer_ticks)
    axclus[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
    axclus[i_backbone].set_title(f"{backbone}")

    axclus[i_backbone].set_xlabel("Rank")
    axenc[i_backbone].set_xlabel("Rank")

    print(order)


# The legends are built from the models of the last backbone processed above,
# plus the VICReg ResNet-50 entry, which is added explicitly here.
encoder_to_color["vicreg_resnet50"] = colors[2]
legend_models = list(MODEL_GROUPS[backbone]) + ["vicreg_resnet50"]


def label_fn(c, ls):
    """Return an empty line artist usable as a legend handle."""
    return plt.plot([], [], color=c, ls=ls, linewidth=3)[0]


handles_clus = [label_fn(cluster_to_color[name], "-") for name in CLUSTERERS2]
handles_enc = [label_fn(encoder_to_color[name], "-") for name in legend_models]

axenc[1].legend(
    handles_enc,
    [MODEL2SH[x] for x in legend_models],
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)
axclus[1].legend(handles_clus, CLUSTERERS2, loc="center left", bbox_to_anchor=(1, 0.5))

figenc.savefig("ranking_enc.pdf", bbox_inches="tight")
figclus.savefig("ranking_clus.pdf", bbox_inches="tight")