# Globals

In [None]:
import copy
import datetime
import os
from collections import defaultdict

import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb
from IPython.display import display
from tqdm.autonotebook import tqdm

In [None]:
# Relative path of the figures directory; it is created here if it is not already present.
FIGS_DIR = "figs"
if not os.path.isdir(FIGS_DIR):
    os.makedirs(FIGS_DIR, exist_ok=True)

In [None]:
# Dataset identifiers.
VALIDATION_DATASETS = ["imagenet", "imagenette", "imagewoof"]

# Encoder identifiers, grouped by the list they belong to.
RESNET50_MODELS = [
    "random_resnet50", "resnet50", "mocov3_resnet50",
    "dino_resnet50", "vicreg_resnet50", "clip_RN50",
]
VITB16_MODELS = [
    "random_vitb16", "vitb16", "mocov3_vit_base", "dino_vitb16",
    "timm_vit_base_patch16_224.mae", "mae_pretrain_vit_base_global", "clip_vitb16",
]
FT_RESNET50_MODELS = ["ft_mocov3_resnet50", "ft_dino_resnet50", "ft_vicreg_resnet50"]
FT_VITB16_MODELS = ["ft_mocov3_vit_base", "ft_dino_vitb16", "mae_finetuned_vit_base_global"]

# Combined lists. ALL_MODELS additionally starts with the "none" entry.
FT_MODELS = [*FT_RESNET50_MODELS, *FT_VITB16_MODELS]
ALL_MODELS = [
    "none",
    *RESNET50_MODELS,
    *VITB16_MODELS,
    *FT_RESNET50_MODELS,
    *FT_VITB16_MODELS,
]

# Clusterer identifiers. ALL_CLUSTERERS is a separate list object with the same
# contents, so appending to one does not affect the other.
CLUSTERERS = [
    "KMeans", "AgglomerativeClustering", "AffinityPropagation",
    "SpectralClustering", "HDBSCAN", "OPTICS",
]
ALL_CLUSTERERS = list(CLUSTERERS)

# Distance metric identifiers.
DISTANCE_METRICS = [
    "euclidean", "l1", "chebyshev", "cosine", "arccos", "braycurtis", "canberra",
]

In [None]:
# Per-dataset style string (presumably a matplotlib linestyle -- confirm at the plotting call).
DATASET2LS = dict(imagenet="-.", imagenette="--", imagewoof=":")

In [None]:
# Baseline settings. The "all" entry holds the clusterer-independent keys; every
# other entry is keyed by clusterer name and holds that clusterer's own defaults.
DEFAULT_PARAMS = {
    "all": dict(
        dim_reducer="None",
        dim_reducer_man="None",
        zscore=False,
        normalize=False,
        zscore2=False,
        ndim_correction=False,
    ),
    "KMeans": dict(clusterer="KMeans"),
    "AffinityPropagation": dict(
        clusterer="AffinityPropagation",
        affinity_damping=0.9,
        affinity_conv_iter=15,
    ),
    "SpectralClustering": dict(
        clusterer="SpectralClustering",
        spectral_assigner="cluster_qr",
    ),
    "AgglomerativeClustering": dict(
        clusterer="AgglomerativeClustering",
        distance_metric="euclidean",
        aggclust_linkage="ward",
    ),
    "HDBSCAN": dict(
        clusterer="HDBSCAN",
        hdbscan_method="eom",
        min_samples=5,
        max_samples=0.2,
        distance_metric="euclidean",
    ),
    "OPTICS": dict(
        clusterer="OPTICS",
        optics_method="xi",
        optics_xi=0.05,
        distance_metric="euclidean",
    ),
}

## Set best params

These parameter values were found by the hyperparameter search in `hpsearch.ipynb`.

### Num dims

In [None]:
# Encoders covered here: the entries of RESNET50_MODELS and VITB16_MODELS only
# (the fine-tuned encoders and the "none" entry are not part of this table).
models = RESNET50_MODELS + VITB16_MODELS

# Give every (clusterer, encoder) pair its own deep copy of that clusterer's
# defaults, so the in-place overrides below cannot leak between encoders.
BEST_PARAMS = {}
for _clusterer in ALL_CLUSTERERS:
    BEST_PARAMS[_clusterer] = {
        _encoder: copy.deepcopy(DEFAULT_PARAMS[_clusterer]) for _encoder in models
    }

# Override templates shared by several clusterers below. dict.update copies the
# key/value pairs, so sharing these templates does not alias the target dicts.
_MAE = "timm_vit_base_patch16_224.mae"
_UMAP_50 = {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50}
_UMAP_50_NO_PCA = {**_UMAP_50, "dim_reducer": "None"}


def _pca_by_variance(fraction):
    """Return an override selecting PCA by `pca_variance` with `dim_reducer_man` set to "None"."""
    return {
        "dim_reducer": "PCA",
        "pca_variance": fraction,
        "zscore": True,
        "ndim_reduced": None,
        "dim_reducer_man": "None",
    }


# --- KMeans ---
# UMAP for every encoder (num dims unimportant; 50d selected for consistency), except:
# - clip_RN50: PCA with 500d is a little better than UMAP. UMAP beats PCA once the
#   PCA dims are reduced below 500.
# - clip_vitb16: same behaviour as clip_RN50.
# - timm_vit_base_patch16_224.mae: best is PCA at 0.85 variance explained. Needs at
#   least 200 PCA dims, and PCA perf beats UMAP throughout.
_kmeans = BEST_PARAMS["KMeans"]
for model in models:
    if not (model.startswith("clip") or model == _MAE):
        _kmeans[model].update(_UMAP_50)
for model in ("clip_RN50", "clip_vitb16"):
    _kmeans[model].update(
        {"dim_reducer": "PCA", "ndim_reduced": 500, "zscore": True, "pca_variance": None}
    )
_kmeans[_MAE].update(
    {"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True, "ndim_reduced": None}
)

# --- AffinityPropagation ---
# PCA with 10 dims for every encoder, except:
# - resnet50 (supervised): original embeddings, no reduction (AMI=0.62). Perf gets
#   worse if they are whitened (AMI=0.55), and although perf increases as num dims
#   are reduced it doesn't quite recover. PCA perf peaks at 10-20 dim (AMI=0.57).
# - dino_resnet50: set to PCA at 0.95 variance explained here. (Search notes: UMAP 50
#   gave AMI=0.52495, marginally better than PCA 10 at AMI=0.5044.)
# - timm_vit_base_patch16_224.mae: PCA at 0.95 variance explained (AMI=0.303).
#   Definite improvement from 10 to 20 dims, but not much improvement above that.
_affinity = BEST_PARAMS["AffinityPropagation"]
_affinity_exceptions = ("resnet50", "dino_resnet50", _MAE)
for model in models:
    if model not in _affinity_exceptions:
        _affinity[model].update(
            {
                "dim_reducer": "PCA",
                "ndim_reduced": 10,
                "zscore": True,
                "pca_variance": None,
                "dim_reducer_man": "None",
            }
        )
_affinity["resnet50"].update(
    {"dim_reducer": "None", "dim_reducer_man": "None", "zscore": False}
)
for model in ("dino_resnet50", _MAE):
    _affinity[model].update(_pca_by_variance(0.95))

# --- AgglomerativeClustering ---
# UMAP for every encoder (num dims unimportant; 50d selected for consistency), except:
# - timm_vit_base_patch16_224.mae: PCA at 0.98 variance explained (i.e. nearly all
#   dimensions kept), which is not noticeably better than 500 dim PCA, but there is
#   an increase compared to using less than 500d.
_agglomerative = BEST_PARAMS["AgglomerativeClustering"]
for model in models:
    if model != _MAE:
        _agglomerative[model].update(_UMAP_50_NO_PCA)
_agglomerative[_MAE].update(_pca_by_variance(0.98))

# --- HDBSCAN ---
# UMAP for every encoder, except:
# - timm_vit_base_patch16_224.mae: PCA at 0.95 variance explained (AMI=0.085), which
#   is not noticeably better than PCA with 50 dim.
_hdbscan = BEST_PARAMS["HDBSCAN"]
for model in models:
    if model != _MAE:
        _hdbscan[model].update(_UMAP_50_NO_PCA)
_hdbscan[_MAE].update(_pca_by_variance(0.95))

# --- OPTICS ---
# UMAP for every encoder, no exceptions necessary.
for model in models:
    BEST_PARAMS["OPTICS"][model].update(_UMAP_50_NO_PCA)

In [None]:
# Version 1 is an independent deep copy of BEST_PARAMS, tagged with its version string.
BEST_PARAMS_v1 = {**copy.deepcopy(BEST_PARAMS), "_version": "v1.0"}

In [None]:
# Version 2 starts from an independent deep copy of BEST_PARAMS and then overrides a
# handful of dimensionality-reduction choices.
BEST_PARAMS_v2 = copy.deepcopy(BEST_PARAMS)
BEST_PARAMS_v2["_version"] = "v2.0"

print("Updating dim choices for new method")
# Updated dim choices
# (changed to this when we swapped to using weighted average instead of straight
# average between Imagenet-1k, Imagenette, Imagewoof)

# Changed KMeans clip_RN50 from PCA 500 to UMAP 50, so it uses fewer dimensions
# (probably more stable than using 500-d which is what PCA needs to marginally beat UMAP).
# The PCA stage is switched off with the string "None" -- the same sentinel that
# DEFAULT_PARAMS["all"] and the other overrides in this file use for "no reducer" --
# rather than the Python None object, so this entry matches the other no-PCA settings.
BEST_PARAMS_v2["KMeans"]["clip_RN50"].update(
    {
        "dim_reducer": "None",
        "ndim_reduced": None,
        "zscore": False,
        "pca_variance": None,
        "dim_reducer_man": "UMAP",
        "ndim_reduced_man": 50,
    }
)
# Changed KMeans MAE from PCA 85% to PCA 200
# (since we see perf above plateaus at 200-d, there is no point going above that)
BEST_PARAMS_v2["KMeans"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "zscore": True, "ndim_reduced": 200, "pca_variance": None}
)
# Changed KMeans clip_vitb16 from PCA 500 to PCA 75%
# (gives a notably better train set AMI measurement above)
BEST_PARAMS_v2["KMeans"]["clip_vitb16"].update(
    {"dim_reducer": "PCA", "zscore": True, "pca_variance": 0.75, "ndim_reduced": None}
)

# Changed AffinityPropagation dino_resnet50 from PCA 95% to PCA 10
# (performance is basically equal, so no point using higher-dim space;
# could have done UMAP 50 instead with basically equal train AMI to PCA 10,
# but didn't for consistency with other models)
BEST_PARAMS_v2["AffinityPropagation"]["dino_resnet50"].update(
    {"dim_reducer": "PCA", "zscore": True, "ndim_reduced": 10, "pca_variance": None}
)
# Changed AffinityPropagation MAE from PCA 95% to PCA 100
BEST_PARAMS_v2["AffinityPropagation"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "zscore": True, "ndim_reduced": 100, "pca_variance": None}
)

In [None]:
print(
    "Updating dim choices to use Affinity Prop dim results found with 0.9 damping,"
    " prefering PCA reduction by percentage variance explained"
)
# Version 3 is rebuilt from scratch rather than copied from v2: it covers every entry
# of ALL_MODELS, and each (clusterer, encoder) pair starts from its own deep copy of
# that clusterer's defaults.
BEST_PARAMS_v3 = {
    clusterer: {model: copy.deepcopy(DEFAULT_PARAMS[clusterer]) for model in ALL_MODELS}
    for clusterer in ALL_CLUSTERERS
}
BEST_PARAMS_v3["_version"] = "v3.0"

# KMeans
# UMAP 50 for the encoders in the loop below; the skipped ones ("none", random_*,
# clip_* and timm MAE) get explicit PCA settings afterwards.
for model in RESNET50_MODELS + VITB16_MODELS + FT_MODELS:
    if (
        model == "none"
        or model.startswith("random")
        or model.startswith("clip")
        or model == "timm_vit_base_patch16_224.mae"
    ):
        continue
    BEST_PARAMS_v3["KMeans"][model].update(
        {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50}
    )

BEST_PARAMS_v3["KMeans"]["none"].update(
    {"image_size": 32, "dim_reducer": "PCA", "pca_variance": 0.98, "zscore": True}
)
BEST_PARAMS_v3["KMeans"]["random_resnet50"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.95, "zscore": True}
)
BEST_PARAMS_v3["KMeans"]["random_vitb16"].update(
    {"dim_reducer": "PCA", "ndim_reduced": 100, "zscore": True}
)

BEST_PARAMS_v3["KMeans"]["clip_RN50"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True}
)
BEST_PARAMS_v3["KMeans"]["clip_vitb16"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.75, "zscore": True}
)
BEST_PARAMS_v3["KMeans"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True}
)

# AffinityPropagation
# Damping is set to 0.9 for every encoder in this cell.
for model in ALL_MODELS:
    BEST_PARAMS_v3["AffinityPropagation"][model].update({"affinity_damping": 0.9})

for model in (
    [
        "resnet50",
        "clip_RN50",
        "vitb16",
        "mocov3_vit_base",
        "mae_pretrain_vit_base_global",
        "dino_vitb16",
        "clip_vitb16",
    ] + FT_MODELS
):
    BEST_PARAMS_v3["AffinityPropagation"][model].update(
        {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50}
    )
for model in ["mocov3_resnet50", "vicreg_resnet50", "dino_resnet50"]:
    BEST_PARAMS_v3["AffinityPropagation"][model].update(
        {"dim_reducer_man": "PaCMAP", "ndim_reduced_man": 50, "dim_reducer_man_nn": None}
    )

BEST_PARAMS_v3["AffinityPropagation"]["none"].update(
    {"image_size": 32, "dim_reducer": "PCA", "pca_variance": 0.8, "zscore": True}
)
BEST_PARAMS_v3["AffinityPropagation"]["random_resnet50"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.99, "zscore": True}
)
BEST_PARAMS_v3["AffinityPropagation"]["random_vitb16"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.98, "zscore": True}
)
# NOTE: this update previously targeted BEST_PARAMS_v3["KMeans"], which overwrote the
# KMeans MAE choice made above (0.85 -> 0.99) and left the AffinityPropagation MAE
# entry with no reduction at all. It sits in the AffinityPropagation section and the
# v4 table sets the AffinityPropagation MAE entry at the same point, so it is applied
# to AffinityPropagation here.
BEST_PARAMS_v3["AffinityPropagation"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.99, "zscore": True}
)

# AgglomerativeClustering
# UMAP 50 except for "none", random_* and timm MAE, which use PCA settings below.
for model in ALL_MODELS:
    if (
        model == "none"
        or model.startswith("random")
        or model == "timm_vit_base_patch16_224.mae"
    ):
        continue
    BEST_PARAMS_v3["AgglomerativeClustering"][model].update(
        {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"}
    )

BEST_PARAMS_v3["AgglomerativeClustering"]["none"].update(
    {"image_size": 32, "dim_reducer": "PCA", "pca_variance": 0.75, "zscore": True}
)
BEST_PARAMS_v3["AgglomerativeClustering"]["random_resnet50"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.98, "zscore": True}
)
BEST_PARAMS_v3["AgglomerativeClustering"]["random_vitb16"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.85, "zscore": True}
)
BEST_PARAMS_v3["AgglomerativeClustering"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.98, "zscore": True}
)

# HDBSCAN
# UMAP 50 for every encoder except timm MAE, which uses PCA at 0.95 variance.
for model in ALL_MODELS:
    if model in ["timm_vit_base_patch16_224.mae"]:
        continue
    BEST_PARAMS_v3["HDBSCAN"][model].update(
        {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"}
    )

BEST_PARAMS_v3["HDBSCAN"]["none"].update(
    {"image_size": 32}
)
BEST_PARAMS_v3["HDBSCAN"]["timm_vit_base_patch16_224.mae"].update(
    {"dim_reducer": "PCA", "pca_variance": 0.95, "zscore": True}
)

# OPTICS - TODO
# Use UMAP for every encoder, no exceptions necessary (not checked raw or random)
for model in ALL_MODELS:
    BEST_PARAMS_v3["OPTICS"][model].update(
        {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50, "dim_reducer": "None"}
    )

In [None]:
_v4_banner = (
    "Updating dim choices to use Affinity Prop dim results found with 0.9 damping,"
    " stop PCA at 95%"
)
print(_v4_banner)

# Version 4 is rebuilt from the defaults for every entry of ALL_MODELS; each
# (clusterer, encoder) pair owns an independent deep copy of the clusterer defaults.
BEST_PARAMS_v4 = {}
for _name in ALL_CLUSTERERS:
    BEST_PARAMS_v4[_name] = {
        _encoder: copy.deepcopy(DEFAULT_PARAMS[_name]) for _encoder in ALL_MODELS
    }
BEST_PARAMS_v4["_version"] = "v4.0"

# The "none" entry of every clusterer gets image_size 32; keys starting with an
# underscore (the version tag) are metadata rather than clusterers and are skipped.
for clusterer in BEST_PARAMS_v4:
    if not clusterer.startswith("_"):
        BEST_PARAMS_v4[clusterer]["none"]["image_size"] = 32

_MAE = "timm_vit_base_patch16_224.mae"
_MAE_GLOBAL = "mae_pretrain_vit_base_global"
_UMAP_50 = {"dim_reducer_man": "UMAP", "ndim_reduced_man": 50}
_UMAP_50_NO_PCA = {**_UMAP_50, "dim_reducer": "None"}


def _pca(**size):
    """Return a PCA override; `size` carries either `pca_variance` or `ndim_reduced`."""
    return {"dim_reducer": "PCA", **size, "zscore": True}


# --- KMeans ---
# UMAP 50 unless the encoder is listed in the PCA table below.
_kmeans = BEST_PARAMS_v4["KMeans"]
_kmeans_skip = ("none", _MAE, _MAE_GLOBAL)
for model in RESNET50_MODELS + VITB16_MODELS + FT_MODELS:
    if model in _kmeans_skip or model.startswith("random") or model.startswith("clip"):
        continue
    _kmeans[model].update(_UMAP_50)

_kmeans_pca = {
    "none": _pca(pca_variance=0.90),
    "random_resnet50": _pca(pca_variance=0.95),
    "random_vitb16": _pca(ndim_reduced=100),
    "clip_RN50": _pca(pca_variance=0.85),
    "clip_vitb16": _pca(pca_variance=0.75),
    _MAE: _pca(pca_variance=0.95),
    _MAE_GLOBAL: _pca(pca_variance=0.9),
}
for model, _override in _kmeans_pca.items():
    _kmeans[model].update(_override)

# --- AffinityPropagation ---
_affinity = BEST_PARAMS_v4["AffinityPropagation"]
for model in ALL_MODELS:
    _affinity[model]["affinity_damping"] = 0.9

_affinity_umap = [
    "resnet50",
    "clip_RN50",
    "vitb16",
    "mocov3_vit_base",
    "mae_pretrain_vit_base_global",
    "dino_vitb16",
    "clip_vitb16",
] + FT_MODELS
for model in _affinity_umap:
    _affinity[model].update(_UMAP_50)
for model in ("mocov3_resnet50", "vicreg_resnet50", "dino_resnet50"):
    # tbc
    _affinity[model].update({**_UMAP_50, "dim_reducer_man_nn": None})

_affinity_pca = {
    "none": _pca(pca_variance=0.8),
    "random_resnet50": _pca(pca_variance=0.9),
    "random_vitb16": _pca(pca_variance=0.9),
    _MAE: _pca(ndim_reduced=200),
}
for model, _override in _affinity_pca.items():
    _affinity[model].update(_override)

# --- AgglomerativeClustering ---
# UMAP 50 unless the encoder is listed in the PCA table below.
_agglomerative = BEST_PARAMS_v4["AgglomerativeClustering"]
_agglomerative_skip = ("none", _MAE, _MAE_GLOBAL)
for model in ALL_MODELS:
    if model in _agglomerative_skip or model.startswith("random"):
        continue
    _agglomerative[model].update(_UMAP_50_NO_PCA)

_agglomerative_pca = {
    "none": _pca(ndim_reduced=200),
    "random_resnet50": _pca(ndim_reduced=200),
    "random_vitb16": _pca(pca_variance=0.85),
    _MAE: _pca(pca_variance=0.90),
    _MAE_GLOBAL: _pca(pca_variance=0.85),
}
for model, _override in _agglomerative_pca.items():
    _agglomerative[model].update(_override)

# --- HDBSCAN ---
# UMAP 50 for every encoder except timm MAE, which uses PCA at 0.95 variance.
_hdbscan = BEST_PARAMS_v4["HDBSCAN"]
for model in ALL_MODELS:
    if model != _MAE:
        _hdbscan[model].update(_UMAP_50_NO_PCA)
_hdbscan[_MAE].update(_pca(pca_variance=0.95))

# --- OPTICS - TODO ---
# Use UMAP for every encoder, no exceptions necessary (not checked raw or random)
for model in ALL_MODELS:
    BEST_PARAMS_v4["OPTICS"][model].update(_UMAP_50_NO_PCA)

### Agglomerative specific settings

In [None]:
# v1 AgglomerativeClustering choices as (distance metric, linkage, encoders) groups.
_v1_agg_groups = (
    (
        "euclidean",
        "ward",
        (
            "resnet50",
            "mocov3_resnet50",
            "vicreg_resnet50",
            "vitb16",
            "timm_vit_base_patch16_224.mae",
        ),
    ),
    ("euclidean", "average", ("dino_resnet50", "clip_RN50", "dino_vitb16")),
    ("chebyshev", "average", ("mocov3_vit_base", "clip_vitb16")),
)
for _metric, _linkage, _encoders in _v1_agg_groups:
    for model in _encoders:
        _settings = BEST_PARAMS_v1["AgglomerativeClustering"][model]
        _settings["distance_metric"] = _metric
        _settings["aggclust_linkage"] = _linkage

In [None]:
# v2 AgglomerativeClustering choices as (distance metric, linkage, encoders) groups.
# Relative to v1 the only change is vicreg_resnet50, which moves from ward to
# average linkage.
_v2_agg_groups = (
    (
        "euclidean",
        "ward",
        ("resnet50", "mocov3_resnet50", "vitb16", "timm_vit_base_patch16_224.mae"),
    ),
    (
        "euclidean",
        "average",
        ("vicreg_resnet50", "dino_resnet50", "clip_RN50", "dino_vitb16"),
    ),
    ("chebyshev", "average", ("mocov3_vit_base", "clip_vitb16")),
)
for _metric, _linkage, _encoders in _v2_agg_groups:
    for model in _encoders:
        _settings = BEST_PARAMS_v2["AgglomerativeClustering"][model]
        _settings["distance_metric"] = _metric
        _settings["aggclust_linkage"] = _linkage

In [None]:
# v3 AgglomerativeClustering choices as (distance metric, linkage, encoders) groups.
# Encoders not listed in any group keep the metric/linkage they already have.
_v3_agg_groups = (
    ("euclidean", "ward", ["none", "resnet50", "mocov3_resnet50", "vitb16"] + FT_MODELS),
    (
        "euclidean",
        "average",
        ["vicreg_resnet50", "dino_resnet50", "clip_RN50", "dino_vitb16"],
    ),
    (
        "chebyshev",
        "average",
        ["mocov3_vit_base", "clip_vitb16", "random_resnet50", "random_vitb16"],
    ),
    ("cosine", "average", ["timm_vit_base_patch16_224.mae"]),
)
for _metric, _linkage, _encoders in _v3_agg_groups:
    for model in _encoders:
        _settings = BEST_PARAMS_v3["AgglomerativeClustering"][model]
        _settings["distance_metric"] = _metric
        _settings["aggclust_linkage"] = _linkage

In [None]:
# TODO:
# - mae_pretrain_vit_base_global
# - clip_vitb16 (leaving as-is for now)

# Mark every encoder as undecided first, so any encoder missing from the groups
# below is left with the visible placeholder "tbd".
for model in ALL_MODELS:
    _settings = BEST_PARAMS_v4["AgglomerativeClustering"][model]
    _settings["distance_metric"] = "tbd"
    _settings["aggclust_linkage"] = "tbd"

# v4 AgglomerativeClustering choices as (distance metric, linkage, encoders) groups.
_v4_agg_groups = (
    ("euclidean", "ward", ["resnet50", "mocov3_resnet50", "vitb16"] + FT_MODELS),
    (
        "euclidean",
        "average",
        ["vicreg_resnet50", "dino_resnet50", "clip_RN50", "dino_vitb16"],
    ),
    (
        "chebyshev",
        "average",
        ["mocov3_vit_base", "clip_vitb16", "random_resnet50", "random_vitb16"],
    ),
    (
        "cosine",
        "average",
        ["none", "timm_vit_base_patch16_224.mae", "mae_pretrain_vit_base_global"],
    ),
)
for _metric, _linkage, _encoders in _v4_agg_groups:
    for model in _encoders:
        _settings = BEST_PARAMS_v4["AgglomerativeClustering"][model]
        _settings["distance_metric"] = _metric
        _settings["aggclust_linkage"] = _linkage

In [None]:
# For every version, derive two independent variants from the AgglomerativeClustering
# settings: "AC w/ C" and "AC w/o C". Each is a deep copy, so later edits to one
# variant do not affect the other or the source table.
for _params in (BEST_PARAMS_v1, BEST_PARAMS_v2, BEST_PARAMS_v3, BEST_PARAMS_v4):
    for _variant in ("AC w/ C", "AC w/o C"):
        _params[_variant] = copy.deepcopy(_params["AgglomerativeClustering"])

In [None]:
# In the "AC w/ C" variant of every version, the distance threshold is explicitly None.
for _params in (BEST_PARAMS_v1, BEST_PARAMS_v2, BEST_PARAMS_v3, BEST_PARAMS_v4):
    for model in _params["AC w/ C"]:
        _params["AC w/ C"][model]["aggclust_dist_thresh"] = None

In [None]:
# From v2 onwards, every "AC w/o C" entry sets zscore2 to "average" and turns
# ndim_correction on. v1 is not touched.
for _params in (BEST_PARAMS_v2, BEST_PARAMS_v3, BEST_PARAMS_v4):
    for model in _params["AC w/o C"]:
        _params["AC w/o C"][model].update(zscore2="average", ndim_correction=True)

In [None]:
# Run AgglomerativeClustering experiments with number of clusters unknown.
# v1 distance threshold per encoder:
_v1_dist_thresh = {
    "resnet50": 20.0,
    "mocov3_resnet50": 20.0,
    "vicreg_resnet50": 20.0,
    "vitb16": 20.0,
    "dino_resnet50": 1.0,
    "clip_RN50": 1.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 2.0,
    "clip_vitb16": 0.5,
    "timm_vit_base_patch16_224.mae": 200.0,
}
for model, _thresh in _v1_dist_thresh.items():
    BEST_PARAMS_v1["AC w/o C"][model]["aggclust_dist_thresh"] = _thresh

In [None]:
# v2 distance threshold per encoder for the "AC w/o C" variant.
_v2_dist_thresh = {
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "vicreg_resnet50": 0.5,
    "dino_resnet50": 0.5,
    "clip_RN50": 0.5,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "timm_vit_base_patch16_224.mae": 5.0,
    "dino_vitb16": 0.2,
    "clip_vitb16": 1.0,
}
for _encoder, _thresh in _v2_dist_thresh.items():
    BEST_PARAMS_v2["AC w/o C"][_encoder]["aggclust_dist_thresh"] = _thresh

In [None]:
# v3 distance threshold per encoder for the "AC w/o C" variant.
# Encoders not listed here are left without an explicit threshold by this cell.
_v3_dist_thresh = {
    "none": 10.0,
    "random_resnet50": 10.0,
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "dino_resnet50": 0.5,
    "vicreg_resnet50": 0.5,
    "clip_RN50": 0.5,
    "random_vitb16": 2.0,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 0.2,
    "timm_vit_base_patch16_224.mae": 0.5,
    "clip_vitb16": 1.0,
    "ft_mocov3_resnet50": 1.0,
    "ft_dino_resnet50": 2.0,
    "ft_vicreg_resnet50": 2.0,
    "ft_mocov3_vit_base": 2.0,
    "ft_dino_vitb16": 2.0,
}
for _encoder, _thresh in _v3_dist_thresh.items():
    BEST_PARAMS_v3["AC w/o C"][_encoder]["aggclust_dist_thresh"] = _thresh

In [None]:
# TODO:
# - none
# - random_resnet50
# - timm_vit_base_patch16_224.mae
# - mae_pretrain_vit_base_global
# - clip_vitb16 (leave as-is)
# - ft_mocov3_resnet50 (tbc)
# - mae_finetuned_vit_base_global

# v4.0 distance threshold per encoder for the "AC w/o C" variant.
_v4_0_dist_thresh = {
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "dino_resnet50": 0.5,
    "vicreg_resnet50": 0.5,
    "random_vitb16": 2.0,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 0.2,
    "ft_mocov3_resnet50": 2.0,  # tbc
    "ft_dino_resnet50": 2.0,
    "ft_vicreg_resnet50": 2.0,
    "ft_mocov3_vit_base": 2.0,
    "ft_dino_vitb16": 2.0,
}
for _encoder, _thresh in _v4_0_dist_thresh.items():
    BEST_PARAMS_v4["AC w/o C"][_encoder]["aggclust_dist_thresh"] = _thresh
BEST_PARAMS_v4["_version"] = "v4.0"

In [None]:
# TODO:
# - none
# - timm_vit_base_patch16_224.mae (tbc)
# - mae_pretrain_vit_base_global
# - clip_vitb16 (leave as-is)

# v4.1 distance threshold per encoder for the "AC w/o C" variant.
_v4_1_dist_thresh = {
    "random_resnet50": 10.0,
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "dino_resnet50": 0.5,
    "vicreg_resnet50": 0.5,
    "clip_RN50": 0.5,
    "random_vitb16": 2.0,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 0.2,
    "timm_vit_base_patch16_224.mae": 0.5,  # tbc
    "clip_vitb16": 1.0,
    "ft_mocov3_resnet50": 2.0,
    "ft_dino_resnet50": 2.0,
    "ft_vicreg_resnet50": 2.0,
    "ft_mocov3_vit_base": 2.0,
    "ft_dino_vitb16": 2.0,
    "mae_finetuned_vit_base_global": 2.0,
}
for _encoder, _thresh in _v4_1_dist_thresh.items():
    BEST_PARAMS_v4["AC w/o C"][_encoder]["aggclust_dist_thresh"] = _thresh
BEST_PARAMS_v4["_version"] = "v4.1"

In [None]:
# v4.4 distance threshold per encoder for the "AC w/o C" variant; this table
# assigns a threshold to every encoder in ALL_MODELS.
_v4_4_dist_thresh = {
    "none": 0.71,
    "random_resnet50": 10.0,
    "resnet50": 2.0,
    "mocov3_resnet50": 10.0,
    "dino_resnet50": 0.5,
    "vicreg_resnet50": 0.5,
    "clip_RN50": 0.5,
    "random_vitb16": 2.0,
    "vitb16": 2.0,
    "mocov3_vit_base": 1.0,
    "dino_vitb16": 0.2,
    "timm_vit_base_patch16_224.mae": 0.71,
    "mae_pretrain_vit_base_global": 0.71,
    "clip_vitb16": 1.0,
    "ft_mocov3_resnet50": 2.0,
    "ft_dino_resnet50": 2.0,
    "ft_vicreg_resnet50": 2.0,
    "ft_mocov3_vit_base": 2.0,
    "ft_dino_vitb16": 2.0,
    "mae_finetuned_vit_base_global": 2.0,
}
for _encoder, _thresh in _v4_4_dist_thresh.items():
    BEST_PARAMS_v4["AC w/o C"][_encoder]["aggclust_dist_thresh"] = _thresh
BEST_PARAMS_v4["_version"] = "v4.4"

### Affinity Prop

In [None]:
# Blanket AffinityPropagation damping per version: 0.5 for v1 and v2, 0.9 for v3 and v4.
_damping_by_version = (
    (BEST_PARAMS_v1, 0.5),
    (BEST_PARAMS_v2, 0.5),
    (BEST_PARAMS_v3, 0.9),
    (BEST_PARAMS_v4, 0.9),
)
for _params, _damping in _damping_by_version:
    for model in _params["AffinityPropagation"]:
        _params["AffinityPropagation"][model]["affinity_damping"] = _damping

In [None]:
# v3 AffinityPropagation damping per encoder.
_v3_damping = {
    "none": 0.85,
    "random_resnet50": 0.5,
    "resnet50": 0.9,
    "mocov3_resnet50": 0.8,
    "dino_resnet50": 0.8,
    "vicreg_resnet50": 0.75,
    "clip_RN50": 0.85,
    "random_vitb16": 0.7,
    "vitb16": 0.9,
    "mocov3_vit_base": 0.75,
    "dino_vitb16": 0.85,
    "timm_vit_base_patch16_224.mae": 0.5,
    "mae_pretrain_vit_base_global": 0.9,  # To match
    "clip_vitb16": 0.95,
    "ft_mocov3_resnet50": 0.9,  # Match supervised/ft resnet50
    "ft_vicreg_resnet50": 0.9,
    "ft_dino_vitb16": 0.9,
    "ft_mocov3_vit_base": 0.9,  # Match supervised/ft resnet50
    "mae_finetuned_vit_base_global": 0.9,
}
for _encoder, _damping in _v3_damping.items():
    BEST_PARAMS_v3["AffinityPropagation"][_encoder]["affinity_damping"] = _damping

In [None]:
# Per-model AffinityPropagation damping for hparams v4, applied in listed order.
_AP_DAMPING_V4 = {
    "none": 0.85,
    "random_resnet50": 0.9,
    "resnet50": 0.9,
    "mocov3_resnet50": 0.75,
    "dino_resnet50": 0.9,
    "vicreg_resnet50": 0.8,
    "clip_RN50": 0.85,
    "random_vitb16": 0.95,
    "vitb16": 0.9,
    "mocov3_vit_base": 0.75,
    "dino_vitb16": 0.85,
    "timm_vit_base_patch16_224.mae": 0.6,
    "mae_pretrain_vit_base_global": 0.6,
    "clip_vitb16": 0.95,
    "ft_mocov3_resnet50": 0.95,
    "ft_dino_resnet50": 0.9,
    "ft_vicreg_resnet50": 0.9,
    "ft_mocov3_vit_base": 0.95,
    "ft_dino_vitb16": 0.9,
    "mae_finetuned_vit_base_global": 0.9,
}
for _model, _damping in _AP_DAMPING_V4.items():
    BEST_PARAMS_v4["AffinityPropagation"][_model]["affinity_damping"] = _damping

### HDBSCAN

In [None]:
# hparams v1: every (non-finetuned) ResNet-50 and ViT-B encoder uses the same
# HDBSCAN settings.
_HDBSCAN_V1 = {"distance_metric": "euclidean", "hdbscan_method": "eom"}
for model in RESNET50_MODELS + VITB16_MODELS:
    BEST_PARAMS_v1["HDBSCAN"][model].update(_HDBSCAN_V1)

HDBSCAN v2 selection: the distance metric and HDBSCAN method chosen for each model, with the corresponding AMI.

|    | model                         | distance_metric   | hdbscan_method   |      AMI |
|---:|:------------------------------|:------------------|:-----------------|---------:|
|  0 | resnet50                      | euclidean         | eom              | 0.828368 |
|  1 | mocov3_resnet50               | euclidean         | eom              | 0.531644 |
|  2 | vicreg_resnet50               | l1                | eom              | 0.472324 |
|  3 | dino_resnet50                 | l1                | eom              | 0.503147 |
|  4 | clip_RN50                     | l1                | eom              | 0.461363 |
|  5 | vitb16                        | chebyshev         | eom              | 0.906110 |
|  6 | mocov3_vit_base               | euclidean         | eom              | 0.629966 |
|  7 | timm_vit_base_patch16_224.mae | euclidean         | eom              | 0.070495 |
|  8 | dino_vitb16                   | l1                | eom              | 0.691547 |
|  9 | clip_vitb16                   | l1                | eom              | 0.592489 |

In [None]:
# hparams v2: start every encoder at euclidean/eom, then switch the distance
# metric for the models listed in the selection table above.
for model in RESNET50_MODELS + VITB16_MODELS:
    BEST_PARAMS_v2["HDBSCAN"][model].update(
        distance_metric="euclidean",
        hdbscan_method="eom",
    )
for model in (
    "vicreg_resnet50",
    "dino_resnet50",
    "clip_RN50",
    "dino_vitb16",
    "clip_vitb16",
):
    BEST_PARAMS_v2["HDBSCAN"][model]["distance_metric"] = "l1"
BEST_PARAMS_v2["HDBSCAN"]["vitb16"].update({"distance_metric": "chebyshev"})

In [None]:
# hparams v3: HDBSCAN distance metric per model, grouped by metric.
# All listed models use the "eom" HDBSCAN method.
_HDBSCAN_V3_METRIC2MODELS = {
    "euclidean": [
        "resnet50",
        "mocov3_resnet50",
        "mocov3_vit_base",
        "timm_vit_base_patch16_224.mae",
    ],
    "l1": [
        "random_resnet50",
        "vicreg_resnet50",
        "dino_resnet50",
        "clip_RN50",
        "random_vitb16",
        "dino_vitb16",
        "clip_vitb16",
    ],
    "chebyshev": ["vitb16"],
}
for _metric, _models in _HDBSCAN_V3_METRIC2MODELS.items():
    for model in _models:
        BEST_PARAMS_v3["HDBSCAN"][model].update(
            {
                "distance_metric": _metric,
                "hdbscan_method": "eom",
            }
        )

In [None]:
# hparams v4: HDBSCAN distance metric per model; every model uses the "eom" method.
_HDBSCAN_V4_METRIC = {
    "none": "euclidean",
    "random_resnet50": "l1",
    "resnet50": "euclidean",
    "mocov3_resnet50": "euclidean",
    "dino_resnet50": "l1",
    "vicreg_resnet50": "l1",
    "clip_RN50": "l1",
    "random_vitb16": "l1",
    "vitb16": "chebyshev",
    "mocov3_vit_base": "euclidean",
    "dino_vitb16": "l1",
    "timm_vit_base_patch16_224.mae": "euclidean",
    "mae_pretrain_vit_base_global": "l1",
    "clip_vitb16": "l1",
    "ft_mocov3_resnet50": "chebyshev",
    "ft_dino_resnet50": "l1",
    "ft_vicreg_resnet50": "euclidean",
    "ft_mocov3_vit_base": "chebyshev",
    "ft_dino_vitb16": "chebyshev",
    "mae_finetuned_vit_base_global": "chebyshev",
}
for _model, _metric in _HDBSCAN_V4_METRIC.items():
    BEST_PARAMS_v4["HDBSCAN"][_model]["distance_metric"] = _metric
    BEST_PARAMS_v4["HDBSCAN"][_model]["hdbscan_method"] = "eom"

### Finally, set overall hparams

In [None]:
# Hyperparameter set used by the result tables below. This is a plain name
# binding, so BEST_PARAMS and BEST_PARAMS_v4 refer to the same object.
BEST_PARAMS = BEST_PARAMS_v4

## Utility functions

In [None]:
def categorical_cmap(nc, nsc, cmap="tab10", continuous=False):
    """
    Build a listed colormap with ``nc`` base colours, each in ``nsc`` shades.

    The base colours are taken from the matplotlib colormap ``cmap``: sampled
    evenly over [0, 1] when ``continuous`` is true, otherwise taken as its
    first ``nc`` entries. For every base colour, ``nsc`` shades are produced
    in HSV space by keeping the hue and interpolating saturation towards 0.25
    and value towards 1. The result holds ``nc * nsc`` RGB colours, with the
    shades of each base colour stored contiguously.

    Raises ``ValueError`` if ``nc`` exceeds the number of colours in ``cmap``.

    https://stackoverflow.com/a/47232942/1960959
    """
    base_cmap = plt.get_cmap(cmap)
    if nc > base_cmap.N:
        raise ValueError("Too many categories for colormap.")
    sample_points = np.linspace(0, 1, nc) if continuous else np.arange(nc, dtype=int)
    shades = np.zeros((nc * nsc, 3))
    for idx, rgba in enumerate(base_cmap(sample_points)):
        hue, sat, val = matplotlib.colors.rgb_to_hsv(rgba[:3])
        hsv_block = np.empty((nsc, 3))
        hsv_block[:, 0] = hue
        hsv_block[:, 1] = np.linspace(sat, 0.25, nsc)
        hsv_block[:, 2] = np.linspace(val, 1, nsc)
        start = idx * nsc
        shades[start : start + nsc, :] = matplotlib.colors.hsv_to_rgb(hsv_block)
    return matplotlib.colors.ListedColormap(shades)

In [None]:
categorical_cmap(len(RESNET50_MODELS), len(VALIDATION_DATASETS))

In [None]:
from zs_ssl_clustering.datasets import image_dataset_sizes


def clip_imgsize(dataset, target_image_size):
    """
    Cap ``target_image_size`` at the size reported for ``dataset``.

    The dataset size is element 1 of ``image_dataset_sizes(dataset)``
    (presumably the native image size; confirm against that helper).
    ``target_image_size`` is returned unchanged when it is ``None`` or when
    no dataset size is reported.
    """
    if target_image_size is not None:
        native_size = image_dataset_sizes(dataset)[1]
        if native_size is not None:
            return min(target_image_size, native_size)
    return target_image_size

In [None]:
def fixup_filter(filters):
    """
    Adjust dataset-dependent entries of a filter dict.

    The dataset is read from ``"dataset_name"``, falling back to
    ``"dataset"``. When one is given:

    - ``"image_size"`` (if present) is clipped with ``clip_imgsize``;
    - ``"min_samples"`` (if present) is forced to 2 for celeba and utkface.

    ``filters`` is modified in place and also returned.
    """
    dataset = filters.get("dataset_name", filters.get("dataset", None))
    if not dataset:
        return filters
    if "image_size" in filters:
        filters["image_size"] = clip_imgsize(dataset, filters["image_size"])
    if "min_samples" in filters and dataset.lower() in ("celeba", "utkface"):
        filters["min_samples"] = 2
    return filters

In [None]:
def select_rows(df, filters, allow_missing=True, fixup=True):
    """
    Return the rows of ``df`` that satisfy every ``column: value`` filter.

    The keys ``"dataset"`` and ``"clusterer"`` are matched against the
    ``"dataset_name"`` and ``"clusterer_name"`` columns respectively.

    A filter value of ``None``, ``"None"`` or ``"none"`` selects rows where
    the column is missing or holds the string ``"None"``/``"none"``. Any other
    value selects rows equal to the value or to its ``str()`` form, plus rows
    with a missing value when ``allow_missing`` is true.

    When ``fixup`` is true the filters first go through ``fixup_filter``,
    which modifies the passed dict in place.
    """
    column_aliases = {"dataset": "dataset_name", "clusterer": "clusterer_name"}
    criteria = fixup_filter(filters) if fixup else filters
    keep = np.ones(len(df), dtype=bool)
    for key, wanted in criteria.items():
        column = df[column_aliases.get(key, key)]
        if wanted is None or wanted == "None" or wanted == "none":
            matches = pd.isna(column) | (column == "None") | (column == "none")
        else:
            matches = (column == wanted) | (column == str(wanted))
            if allow_missing:
                matches = matches | pd.isna(column)
        keep &= matches
    return df[keep]

In [None]:
def find_differing_columns(df, cols=None):
    """
    List the columns of ``df`` that hold more than one distinct value.

    Only the names in ``cols`` are checked (all columns when ``cols`` is
    ``None``); names absent from ``df`` are ignored. Missing values count as
    a distinct value. The result follows the iteration order of ``cols``.
    """
    candidates = df.columns if cols is None else cols
    return [
        name
        for name in candidates
        if name in df.columns and df[name].nunique(dropna=False) > 1
    ]

In [None]:
def filter2command(*filters, partition="val"):
    """
    Build the ``sbatch`` command line that would run one clustering job.

    All ``filters`` dicts are merged (later ones win) and turned into
    arguments for ``slurm/cluster.slrm``:

    - ``--mem`` is looked up from the clusterer and dataset; "2G" is used when
      the clusterer has no memory plan.
    - ``--array`` carries the seed: 100 for the "val" partition, 1 for "test",
      0 otherwise.
    - Entries whose value is ``None`` are omitted.
    - ``zscore``, ``normalize``, ``zscore2`` and ``ndim_correction`` become
      on/off switches; every other entry becomes ``--key-with-dashes=value``.

    Returns the command as a single string.
    """
    merged = {}
    for one_filter in filters:
        merged.update(one_filter.items())
    dataset = merged.get("dataset", "")
    clusterer = merged.get("clusterer", "")
    model = merged.get("model", "")

    # Dataset tiers, checked in this order. Any dataset whose name starts with
    # "in9" also belongs to the tier at index ``in9_tier``.
    dataset_tiers = (
        ("inaturalist",),
        ("imagenet-sketch", "imagenet"),
        ("places365", "imagenet-r", "svhn", "nabirds"),
        ("celeba",),
        ("imagenetv2", "cifar10", "cifar100", "lsun", "mnist", "fashionmnist", "stanfordcars"),
        ("flowers102", "utkface", "eurosat", "aircraft", "imagenet-o", "dtd"),
        ("imagenette", "imagewoof"),
    )
    in9_tier = 5
    # (clusterers, memory per dataset tier, memory for any other dataset)
    memory_plans = (
        (("AffinityPropagation",), ("292G", "96G", "64G", "12G", "6G", "2G", "1G"), "8G"),
        (("AgglomerativeClustering",), ("72G", "20G", "16G", "12G", "6G", "4G", "2G"), "8G"),
        (("HDBSCAN", "KMeans"), ("6G", "4G", "4G", "4G", "2G", "2G", "1G"), "4G"),
    )

    mem = "2G"
    for plan_clusterers, tier_mems, other_mem in memory_plans:
        if clusterer not in plan_clusterers:
            continue
        mem = other_mem
        for tier, tier_datasets in enumerate(dataset_tiers):
            if (tier == in9_tier and dataset.startswith("in9")) or dataset in tier_datasets:
                mem = tier_mems[tier]
                break
        break

    seed = 100 if partition == "val" else 1 if partition == "test" else 0

    parts = [
        f"sbatch --array={seed} --mem={mem}",
        f' --job-name="zsc-{model}-{dataset}-{clusterer}"',
        f" slurm/cluster.slrm --partition={partition}",
    ]

    # key -> (text when disabled, text when enabled); "" emits nothing
    switches = {
        "zscore": (" --no-zscore", " --zscore"),
        "normalize": ("", " --normalize"),
        "zscore2": (" --no-zscore2", " --zscore2"),
        "ndim_correction": (" --no-ndim-correction", " --ndim-correction"),
    }
    for key, value in merged.items():
        if value is None:
            continue
        if key not in switches:
            parts.append(f" --{key.replace('_', '-')}={value}")
            continue
        disabled, enabled = switches[key]
        if value == "False" or not value:
            parts.append(disabled)
        elif key == "zscore2" and value == "average":
            # zscore2 has a third state with its own flag
            parts.append(" --azscore2")
        else:
            parts.append(enabled)
    return "".join(parts)

# Final results

In [None]:
# Datasets reported in the result tables below, in column order.
TEST_DATASETS = [
    "imagenet",
    "imagenetv2",
    "imagenet-o",
    "cifar10",
    "cifar100",
    "in9original",
    "in9mixednext",
    "in9onlybgt",
    "in9onlyfg",
    "imagenet-r",
    "imagenet-sketch",
    "aircraft",
    "stanfordcars",
    "flowers102",
    "nabirds",
    "inaturalist",
    "celeba",
    "utkface",
    "dtd",
    "eurosat",
    "lsun",
    "places365",
    "mnist",
    "fashionmnist",
    "svhn",
]
# Dataset name -> short label used for table column headers.
DATASET2SH = {
    "aircraft": "Air",
    "celeba": "CelA",
    "cifar10": "C10",
    "cifar100": "C100",
    "dtd": "DTD",
    "eurosat": "ESAT",
    "flowers102": "F102",
    "fashionmnist": "fMNIST",
    "imagenet": "IN1k",
    "imagenet-o": "IN-O",
    "imagenet-r": "IN-R",
    "imagenet-sketch": "IN-S",
    "imagenetv2": "INv2",
    "imagenette": "IN10",
    "imagewoof": "INwf",
    "in9original": "IN9",
    "in9mixednext": "IN9-MN",
    "in9mixedrand": "IN9-MR",
    "in9mixedsame": "IN9-MS",
    "in9nofg": "IN9-NoFG",
    "in9onlybgb": "IN9-BGB",
    "in9onlybgt": "IN9-BGT",
    "in9onlyfg": "IN9-FG",
    "inaturalist": "iNat21",
    "lsun": "LSUN",
    "mnist": "MNIST",
    "nabirds": "Birds",
    "places365": "P365",
    "stanfordcars": "Cars",
    "svhn": "SVHN",
    "utkface": "UTKF",
}
# Group label -> list of model names. Note that "all" repeats every model that
# is also listed in one of the other groups.
MODEL_GROUPS = {
    "ResNet-50": RESNET50_MODELS,
    "ViT-B": VITB16_MODELS,
    "ResNet-50 [FT]": FT_RESNET50_MODELS,
    "ViT-B [FT]": FT_VITB16_MODELS,
    "all": ALL_MODELS,
}
# Model name -> short label used for table rows. The same label is shared by
# the ResNet-50 and ViT-B variants of a method.
MODEL2SH = {
    "none": "Raw image",
    "random_resnet50": "Random",
    "random_vitb16": "Random",
    "resnet50": "Supervised",
    "mocov3_resnet50": "MoCo-v3",
    "dino_resnet50": "DINO",
    "vicreg_resnet50": "VICReg",
    "clip_RN50": "CLIP",
    "vitb16": "Supervised",
    "mocov3_vit_base": "MoCo-v3",
    "dino_vitb16": "DINO",
    "timm_vit_base_patch16_224.mae": "MAE (CLS)",
    "mae_pretrain_vit_base_global": "MAE (avg)",
    "clip_vitb16": "CLIP",
    "ft_mocov3_resnet50": "MoCo-v3 [FT]",
    "ft_dino_resnet50": "DINO [FT]",
    "ft_vicreg_resnet50": "VICReg [FT]",
    "ft_mocov3_vit_base": "MoCo-v3 [FT]",
    "ft_dino_vitb16": "DINO [FT]",
    "mae_finetuned_vit_base_global": "MAE (avg) [FT]",
}
# Clusterer name -> short label; clusterers not listed here have no short form.
# NOTE(review): the "AC w/ C" label contains two spaces, presumably so that it
# is as wide as "AC w/o C" in aligned output - confirm this is intentional.
CLUSTERER2SH = {
    "KMeans": "K-Means",
    "AffinityPropagation": "Affinity Prop",
    "AgglomerativeClustering": "AC",
    "AC w/ C": "AC w/  C",
}

In [None]:
# Test datasets arranged into named groups (group label -> dataset names).
TEST_DATASETS_GROUPED = {
    "In-domain": [
        "imagenet",
        "imagenetv2",
        "imagenet-o",
        "in9original",
        "cifar10",
        "cifar100",
    ],
    "Domain-shift": [
        "imagenet-r",
        "imagenet-sketch",
        "in9mixednext",
        "in9onlybgt",
        "in9onlyfg",
    ],
    "Fine-grained": [
        "aircraft",
        "stanfordcars",
        "flowers102",
        "nabirds",
        "inaturalist",
    ],
    "Near-OOD": [
        "lsun",
        "places365",
    ],
    "Far-OOD": [
        "celeba",
        "utkface",
        "dtd",
        "eurosat",
        "mnist",
        "fashionmnist",
        "svhn",
    ],
}
# Alternative display titles for some group labels.
# NOTE(review): "Out-of-distribution" is not a key of TEST_DATASETS_GROUPED
# (the OOD groups there are "Near-OOD" and "Far-OOD"), so this entry looks
# stale - confirm whether it is still needed.
DATASETGROUP2TITLE = {
    "Domain-shift": "Domain-shifted",
    "Out-of-distribution": "OOD",
}

In [None]:
# Colour (matplotlib colour name) assigned to each clustering method.
CLUSTERER2COLORSTR = {
    "KMeans": "tab:purple",
    "AC w/ C": "tab:red",
    "AC w/o C": "tab:orange",
    "AffinityPropagation": "tab:green",
    "HDBSCAN": "tab:blue",
}
# The same colours converted to RGB tuples.
CLUSTERER2COLORRGB = {k: matplotlib.colors.to_rgb(v) for k, v in CLUSTERER2COLORSTR.items()}

In [None]:
# Colour (matplotlib colour name) assigned to each model. The ResNet-50, ViT-B
# and fine-tuned variants of a method share one colour.
MODEL2COLORSTR = {
    "none": "black",
    "random_resnet50": "tab:grey",
    "random_vitb16": "tab:grey",
    "resnet50": "tab:red",
    "mocov3_resnet50": "tab:cyan",
    "dino_resnet50": "tab:green",
    "vicreg_resnet50": "tab:purple",
    "clip_RN50": "tab:blue",
    "vitb16": "tab:red",
    "mocov3_vit_base": "tab:cyan",
    "dino_vitb16": "tab:green",
    "timm_vit_base_patch16_224.mae": "tab:olive",
    "mae_pretrain_vit_base_global": "tab:brown",
    "clip_vitb16": "tab:blue",
    "ft_mocov3_resnet50": "tab:cyan",
    "ft_dino_resnet50": "tab:green",
    "ft_vicreg_resnet50": "tab:purple",
    "ft_mocov3_vit_base": "tab:cyan",
    "ft_dino_vitb16": "tab:green",
    "mae_finetuned_vit_base_global": "tab:brown",
}
# The same colours converted to RGB tuples.
MODEL2COLORRGB = {k: matplotlib.colors.to_rgb(v) for k, v in MODEL2COLORSTR.items()}

## Fetch results

In [None]:
# Start from an empty table of runs; the "id" column must exist because the
# download cell below reads it to skip runs that were already fetched.
runs_df_long = pd.DataFrame({"id": []})
# Names of all config / summary entries seen while downloading runs.
config_keys = set()
summary_keys = set()

In [None]:
# Load previous results from CSV file
# NOTE: loading is currently disabled - the read_csv line is commented out, so
# this cell only defines CSV_FNAME, which is used when saving the results below.
CSV_FNAME = "test_runs_df.csv"
if os.path.isfile(CSV_FNAME):
    pass
    # runs_df_long = test_runs_df = pd.read_csv(CSV_FNAME)

In [None]:
# Project is specified by <entity/project-name>
# Only runs in the "Finished" state whose config has partition == "test" are requested.
api = wandb.Api(timeout=720)
runs = api.runs(
    "uoguelph_mlrg/zs-ssl-clustering",
    filters={"state": "Finished", "config.partition": "test"},  # "config.predictions_dir": "y_pred"},
    per_page=10_000,
)
# Cell output: number of matching runs.
len(runs)

In [None]:
# Download every run that is not yet in runs_df_long and append it as one row
# holding the run's id, name, config values and summary metrics.
print(f"{len(runs_df_long)} runs currently in dataframe")
rows_to_add = []
existing_ids = set(runs_df_long["id"].values)
for run in tqdm(runs):
    if run.id not in existing_ids:
        # .summary holds the output metrics; ._json_dict omits large files.
        summary = run.summary._json_dict
        # .config holds the hyperparameters; special keys starting with "_" are dropped.
        config = {k: v for k, v in run.config.items() if not k.startswith("_")}
        # .name is the human-readable name of the run.
        row = {"id": run.id, "name": run.name, **config}
        row.update({k: v for k, v in summary.items() if not k.startswith("_")})
        # The timestamp is the one "_" summary entry that is kept.
        if "_timestamp" in summary:
            row["_timestamp"] = summary["_timestamp"]
        rows_to_add.append(row)
        config_keys = config_keys.union(config.keys())
        summary_keys = summary_keys.union(summary.keys())

if rows_to_add:
    print(f"Adding {len(rows_to_add)} runs")
    runs_df_long = pd.concat([runs_df_long, pd.DataFrame.from_records(rows_to_add)])
else:
    print("No new runs to add")
print(f"{len(runs_df_long)} runs")

In [None]:
# Remove entries without an AMI metric.
# Take an explicit copy: the following cells add and overwrite columns on
# test_runs_df, and doing that on a filtered selection of runs_df_long
# triggers pandas' SettingWithCopyWarning.
test_runs_df = runs_df_long[~runs_df_long["AMI"].isna()].copy()
# Cell output: number of runs that have an AMI value.
len(test_runs_df)

In [None]:
# Make sure the hyperparameter columns used as filters below exist and hold
# consistent values, whichever runs were downloaded.

# Handle changed default value for spectral_assigner after config arg was introduced
if "spectral_assigner" not in test_runs_df.columns:
    test_runs_df["spectral_assigner"] = None
# Clear the value for every clusterer other than SpectralClustering ...
select = test_runs_df["clusterer_name"] != "SpectralClustering"
test_runs_df.loc[select, "spectral_assigner"] = None
# ... and fill in "kmeans" for SpectralClustering runs that have no value recorded.
select = (test_runs_df["clusterer_name"] == "SpectralClustering") & pd.isna(
    test_runs_df["spectral_assigner"]
)
test_runs_df.loc[select, "spectral_assigner"] = "kmeans"

# Accidentally wasn't clearing this hparam when it was unused
if "spectral_affinity" not in test_runs_df.columns:
    test_runs_df["spectral_affinity"] = None
select = test_runs_df["clusterer_name"] != "SpectralClustering"
test_runs_df.loc[select, "spectral_affinity"] = None

# Runs with no recorded zscore2 value are treated as zscore2=False.
if "zscore2" not in test_runs_df.columns:
    test_runs_df["zscore2"] = False
test_runs_df.loc[pd.isna(test_runs_df["zscore2"]), "zscore2"] = False

# Runs with no recorded ndim_correction value are treated as ndim_correction=False.
if "ndim_correction" not in test_runs_df.columns:
    test_runs_df["ndim_correction"] = False
test_runs_df.loc[pd.isna(test_runs_df["ndim_correction"]), "ndim_correction"] = False

# These columns only need to exist; missing values are left as None.
if "dim_reducer_man_nn" not in test_runs_df.columns:
    test_runs_df["dim_reducer_man_nn"] = None

if "image_size" not in test_runs_df.columns:
    test_runs_df["image_size"] = None

In [None]:
# Save results to CSV file, so we can optionally skip downloading them
# (index=False: the DataFrame index is not written to the file)
test_runs_df.to_csv(CSV_FNAME, index=False)

In [None]:
# Drop these keys from config_keys so they are not listed when the table cells
# below report which config columns differ between duplicate results
# (presumably they describe the machine/job rather than hyperparameters).
config_keys = config_keys.difference(
    {"workers", "memory_avail_GB", "memory_total_GB", "memory_slurm"}
)

In [None]:
test_runs_df

#### Draft table

In [None]:
# Build a LaTeX results table for one clusterer: one row per encoder (grouped
# by MODEL_GROUPS) and one column per test dataset.
#
# The table is assembled twice. The first ("dummy") pass only collects every
# cell's value into best_results; the second pass uses those values to wrap the
# best and second-best result per dataset in \tcf{} / \tcs{}.
# Combinations with no matching run produce no cell value; the command that
# would generate them is collected in cmds instead.
metric_key = "AMI"
show_pc = True
show_fmt = "{:5.1f}"
show_commands = False
eps = 0.001
override_fields = {
    # "aggclust_dist_thresh": None,  # to flip between unknown/known n clusters for AC
    # "predictions_dir": "y_pred",
}

# KMeans  AffinityPropagation  AgglomerativeClustering  HDBSCAN
backbones = MODEL_GROUPS.keys()
clusterer = "AgglomerativeClustering"

best_results = {k: [] for k in TEST_DATASETS}
# MODEL_GROUPS["all"] repeats every model that is also listed in its own group.
# Remember which (dataset, model) pairs were already recorded so each model is
# counted once when ranking; otherwise the duplicate of the best value always
# occupies the second-best slot and no genuine runner-up is ever marked.
ranked_pairs = set()
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{metric_key}, {clusterer}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    latex_table += r"\label{tab:" + clusterer + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_DATASETS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for dataset in TEST_DATASETS:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    for i_group, model_group_name in enumerate(list(backbones)):
        if i_group > 0:
            latex_table += r"\midrule" + "\n"
        group_models = MODEL_GROUPS[model_group_name]
        for i_model, model in enumerate(group_models):
            if i_model == 0:
                # The rotated group label has to span exactly as many rows as
                # the group has models (the groups differ in size).
                latex_table += (
                    r"\parbox[t]{2mm}{\multirow{"
                    + str(len(group_models))
                    + r"}{*}{\rotatebox[origin=c]{90}{"
                    + model_group_name
                    + "}}}"
                )
                latex_table += "\n"
            latex_table += f"& {MODEL2SH.get(model, model):<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                filter1 = {
                    "model": model,
                    "dataset": dataset,
                    "clusterer": clusterer,
                }
                # Defaults, then the best hparams for this clusterer/model,
                # then the identifying fields, then any manual overrides.
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    # print(f"No data for {filter2}")
                    # No command is queued for AffinityPropagation on these
                    # two datasets; the cell is simply left empty.
                    if clusterer == "AffinityPropagation" and dataset in [
                        "imagenet",
                        "inaturalist",
                    ]:
                        continue
                    cmds.append(filter2command(filter2, partition="test"))
                    continue
                if len(sdf) > 1:
                    # Several runs match: report them if their metric values disagree.
                    if sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                        print()
                        print(f"More than one result with {metric_key} values", list(sdf[metric_key]))
                        print(f"for search {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.median(sdf[metric_key])
                if dummy:
                    if (dataset, model) not in ranked_pairs:
                        ranked_pairs.add((dataset, model))
                        best_results[dataset].append(my_val)
                    continue
                # Values within eps of the best / second-best count as ties.
                is_best = my_val + eps >= np.max(best_results[dataset])
                if len(best_results[dataset]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[dataset])[-2]
                else:
                    is_secd = False
                if show_pc:
                    my_val = my_val * 100
                latex_table += " $"
                if is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                latex_table += r"}" if is_best or is_secd else " "
                latex_table += "$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your results table for {clusterer}:")
print()
print()
print(latex_table)

## Grouping by encoder

In [None]:
metric_key = "AMI"
show_pc = True
show_fmt = "{:4.0f}"
show_commands = False
eps = 0.001
override_fields = {
    # "predictions_dir": "y_pred",
}

backbone = "ResNet-50"

CLUSTERERS = [
    "KMeans",
    "AgglomerativeClustering",
    "AgglomerativeClustering",
    "AffinityPropagation",
    "HDBSCAN",
]
print(MODEL2SH)

best_results = {k: [] for k in TEST_DATASETS}
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    latex_table += r"\label{tab:" + backbone + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_DATASETS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Clusterer':<11s}"
    for dataset in TEST_DATASETS:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
        print(model)
        if i_group > 0:
            latex_table += r"\midrule" + "\n"

        first_agg = True
        for i_clusterer, clusterer in enumerate(CLUSTERERS):
            if i_clusterer == 0:
                latex_table += (
                    r"\parbox[t]{2mm}{\multirow{5}{*}{\rotatebox[origin=c]{90}{"
                    + MODEL2SH[model]
                    + "}}}"
                )
                latex_table += "\n"
            clusterername = CLUSTERER2SH.get(clusterer, clusterer)

            my_override_fields = override_fields.copy()
            if (
                first_agg
                and clusterer == "AgglomerativeClustering"
                and metric_key != "num_cluster_pred"
            ):
                first_agg = False
                my_override_fields["aggclust_dist_thresh"] = None
                clusterername = "AC  w/ C"
            elif clusterer == "AgglomerativeClustering":
                clusterername = "AC w/o C"
                if "aggclust_dist_thresh" in my_override_fields:
                    del my_override_fields["aggclust_dist_thresh"]

            if clusterer == "HDBSCAN" and dataset in ["celeba", "utkface"]:
                my_override_fields["min_samples"] = 2
            elif "min_samples" in my_override_fields:
                del my_override_fields["min_samples"]

            latex_table += f"& {clusterername:<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                filter1 = {
                    "model": model,
                    "dataset": dataset,
                    "clusterer": clusterer,
                }
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(my_override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    # print(f"No data for {filter2}")
                    cmds.append(filter2command(filter2, partition="test"))
                    continue
                if len(sdf) > 1:
                    if sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                        print()
                        print(f"More than one result with {metric_key} values", list(sdf[metric_key]))
                        print(f"for search {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.median(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    continue
                is_best = my_val + eps >= np.max(best_results[dataset])
                if len(best_results[dataset]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[dataset])[-2]
                else:
                    is_secd = False
                if show_pc:
                    my_val = my_val * 100
                latex_table += " $"
                if is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                latex_table += r"}" if is_best or is_secd else " "
                latex_table += "$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

# Report how many (model, dataset, clusterer) lookups found no rows; one command
# was queued in cmds for each of them while the table was being built.
print(f"\nThere are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

# Blank-line padded banner followed by the LaTeX source of the table.
print("\nDone!\n")
print(f"Here is your results table for {clusterer}:\n\n")
print(latex_table)

## Grouping by clusterer

In [None]:
# ---- What to tabulate ----
# Column of test_runs_df to report. Values used in this notebook:
#   AMI | num_cluster_pred | silhouette-euclidean_pred | silhouette-og-euclidean_pred
metric_key = "AMI"
# Key into MODEL_GROUPS selecting the set of encoders: "ResNet-50" or "ViT-B"
backbone = "ResNet-50"
# Extra filter fields layered on top of the per-clusterer best parameters,
# e.g. {"predictions_dir": "y_pred"}
override_fields = {}
# Slack added to a value before comparing it against the best / second-best result
eps = 0.005
# Whether to list the commands for missing runs once the table has been built
show_commands = False

# ---- How to display it ----
# show_pc: multiply values by 100; show_fmt: format spec for each cell;
# highlight_best: wrap leading entries in highlight macros;
# use_si_num: wrap numbers in \num{} rather than $...$
if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc, show_fmt = False, "{:4.0f}"
    highlight_best, use_si_num = False, True
else:
    CLUSTERERS = [
        "KMeans",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    ]
    show_pc, show_fmt = True, "{:4.0f}"
    highlight_best, use_si_num = True, False
if metric_key.startswith("silhouette"):
    # Silhouette scores are shown raw with two decimals instead of as percentages
    show_pc, show_fmt = False, "{:5.2f}"

print(MODEL_GROUPS)

# Accumulators filled during the first (dummy) pass of the table loop:
# every value per dataset, and every value per dataset and row group.
best_results = {name: [] for name in TEST_DATASETS}
best_results_grouped = {name: defaultdict(list) for name in TEST_DATASETS}

# Build the LaTeX table in two passes over identical loops.
# Pass 1 (dummy=True) only gathers every cell's value into best_results and
# best_results_grouped; latex_table and cmds are reset at the start of each
# pass, so whatever pass 1 wrote is thrown away.
# Pass 2 (dummy=False) writes the real table, using the gathered values to
# pick the best / second-best entries and to scale the cell shading.
for dummy in [True, False]:
    cmds = []
    # Leading LaTeX comments recording the metric, backbone, time and hparams version
    latex_table = r"% Results for " + f"{metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    # For AMI both a bare backbone label and a metric-prefixed label are emitted
    label = backbone
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    # Two left-aligned label columns, then one right-aligned column per dataset
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_DATASETS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for dataset in TEST_DATASETS:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        # Reference row: num_cluster_true of the first matching run per dataset
        latex_table += r"& Num targets"
        for i_dataset, dataset in enumerate(TEST_DATASETS):
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference block "G.T.": the "_true" counterpart of the metric, one row
        # per encoder.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                # Rows are picked by dimensionality-reduction setting: PCA for the
                # timm MAE encoder, 50-d euclidean UMAP for every other encoder.
                filter1 = {"model": model, "dataset": dataset}
                if model == "timm_vit_base_patch16_224.mae":
                    filter1["dim_reducer"] = "PCA"
                    filter1["pca_variance"] = 0.95
                else:
                    filter1["dim_reducer_man"] = "UMAP"
                    filter1["ndim_reduced_man"] = 50
                    filter1["dim_reducer_man_metric"] = "euclidean"
                sdf = select_rows(test_runs_df, filter1, allow_missing=False)
                sdf = sdf[~pd.isna(sdf[metric_key2])]
                # The G.T. row reports the "_true" column (metric_key2); reading
                # metric_key here would repeat the predicted-clustering values.
                my_val = np.nanmedian(sdf[metric_key2])
                if dummy:
                    best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                # nanmax: an undefined value for one encoder must not hide the best
                is_best_grp = my_val + eps >= np.nanmax(
                    best_results_grouped[dataset][clusterername]
                )
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        # The two branches below only fire when CLUSTERERS contains the literal
        # name "AgglomerativeClustering": its first occurrence is shown as
        # "AC  w/ C" with aggclust_dist_thresh forced to None, later ones as
        # "AC w/o C" without that override.
        if (
            first_agg
            and clusterer == "AgglomerativeClustering"
            and metric_key != "num_cluster_pred"
        ):
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        # Rotated clusterer name spanning all encoder rows of this group
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for i_dataset, dataset in enumerate(TEST_DATASETS):
                latex_table += " &"
                # Filter precedence (lowest to highest): global defaults, best
                # params for this clusterer/model, model+dataset, overrides.
                filter1 = {"model": model, "dataset": dataset}
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(my_override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    # No run matches: queue the command that would produce it and
                    # leave a placeholder in the real table.
                    cmds.append(filter2command(filter2, partition="test"))
                    if not dummy:
                        latex_table += r"   --  "
                    continue
                if len(sdf) > 1:
                    # Several matching runs are fine if they agree; otherwise
                    # report which config columns distinguish them.
                    if sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                        print()
                        print(f"More than one result with {metric_key} values", list(sdf[metric_key]))
                        print(f"for search {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.nanmedian(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Values gathered for this dataset in pass 1, sorted ascending with
                # NaNs dropped so one undefined result cannot poison the maximum,
                # the runner-up or the shading scale. my_val itself is in here
                # (appended in pass 1), so the array is never empty.
                dataset_vals = np.asarray(best_results[dataset], dtype=float)
                dataset_vals = np.sort(dataset_vals[~np.isnan(dataset_vals)])
                is_best = my_val + eps >= dataset_vals[-1]
                is_secd = len(dataset_vals) > 1 and my_val + eps >= dataset_vals[-2]
                is_best_grp = my_val + eps >= np.nanmax(
                    best_results_grouped[dataset][clusterername]
                )
                # Shading: 0 at or below the dataset median, 100 at the dataset best
                sc_base = np.median(dataset_vals)
                sc_top = dataset_vals[-1]
                sc_range = sc_top - sc_base
                if sc_range > 0:
                    sc = 100 * max(0, (my_val - sc_base) / sc_range)
                else:
                    # Best equals median: nothing to scale against
                    sc = 0
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                # \tcf = best for the dataset, \tcs = second best; when neither
                # applies, spaces keep the source columns aligned
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                # \tcg = best within this clusterer's group of rows
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    # Closes the \resizebox opened in the header
    latex_table += r"}" + "\n"

# Report how many lookups found no rows; one command was queued in cmds for each.
print(f"\nThere are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

# Blank-line padded banner followed by the LaTeX source of the table.
print("\nDone!\n")
print(f"Here is your results table for {metric_key}, {backbone}:\n\n")
print(latex_table)

In [None]:
# Print every command currently held in cmds, one per line, regardless of the
# show_commands setting.
for cmd in cmds:
    print(cmd)

### With grouped datasets

In [None]:
# ---- What to tabulate ----
# Column of test_runs_df to report. Values used in this notebook:
#   AMI | num_cluster_pred | silhouette-euclidean_pred | silhouette-og-euclidean_pred
metric_key = "AMI"
# Key into MODEL_GROUPS selecting the set of encoders: "ResNet-50" or "ViT-B"
backbone = "ViT-B"
# Extra filter fields layered on top of the per-clusterer best parameters,
# e.g. {"predictions_dir": "y_pred"}
override_fields = {}
# Slack added to a value before comparing it against the best / second-best result
eps = 0.005
# Whether to list the commands for missing runs once the table has been built
show_commands = False

# ---- How to display it ----
# show_pc: multiply values by 100; show_fmt: format spec for each cell;
# highlight_best: wrap leading entries in highlight macros;
# use_si_num: wrap numbers in \num{} rather than $...$
if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc, show_fmt = False, "{:4.0f}"
    highlight_best, use_si_num = False, True
else:
    CLUSTERERS = [
        "KMeans",
        "AC w/ C",
        "AC w/o C",
        "AffinityPropagation",
        "HDBSCAN",
    ]
    show_pc, show_fmt = True, "{:4.0f}"
    highlight_best, use_si_num = True, False
if metric_key.startswith("silhouette"):
    # Silhouette scores are shown raw with two decimals instead of as percentages
    show_pc, show_fmt = False, "{:5.2f}"

print(MODEL_GROUPS)

# Flatten the dataset groups into one list, keeping the group order, so the
# table has one column per dataset.
test_datasets = []
for datagroupname, datagroupset in TEST_DATASETS_GROUPED.items():
    test_datasets += datagroupset

# Accumulators filled during the first (dummy) pass of the table loop:
# every value per dataset, and every value per dataset and row group.
best_results = {name: [] for name in test_datasets}
best_results_grouped = {name: defaultdict(list) for name in test_datasets}

# Build the LaTeX table in two passes over identical loops.
# Pass 1 (dummy=True) only gathers every cell's value into best_results and
# best_results_grouped; latex_table and cmds are reset at the start of each
# pass, so whatever pass 1 wrote is thrown away.
# Pass 2 (dummy=False) writes the real table, using the gathered values to
# pick the best / second-best entries and to scale the cell shading.
for dummy in [True, False]:
    cmds = []
    # Leading LaTeX comments recording the metric, backbone, time and hparams version
    latex_table = r"% Results for " + f"{metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    # For AMI both a bare backbone label and a metric-prefixed label are emitted
    label = backbone
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    # Two left-aligned label columns, then one right-aligned column per dataset
    latex_table += r"\begin{tabular}{ll" + r"r" * len(test_datasets) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    # First header row: one centred multicolumn heading per dataset group
    latex_table += r"& " + f"{'':<11s}"
    for datagroupname, datagroupset in TEST_DATASETS_GROUPED.items():
        latex_table += r" & \multicolumn{" + str(len(datagroupset)) + r"}{c}{" + datagroupname + r"}"
    latex_table += r"\\" + "\n"
    # Partial rule under each group heading; data columns start at column 3
    icol = 3
    for datagroupname, datagroupset in TEST_DATASETS_GROUPED.items():
        latex_table += r"\cmidrule(l){" + f"{icol}-{icol + len(datagroupset) - 1}" + r"}"
        icol += len(datagroupset)
    latex_table += "\n"
    # Second header row: the individual dataset names
    latex_table += r"& " + f"{'Encoder':<11s}"
    for dataset in test_datasets:
        latex_table += r"&" + "{:^15s}".format(DATASET2SH.get(dataset, dataset))
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        # Reference row: num_cluster_true of the first matching run per dataset
        latex_table += r"& Num targets"
        for i_dataset, dataset in enumerate(test_datasets):
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference block "G.T.": the "_true" counterpart of the metric, one row
        # per encoder.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for i_dataset, dataset in enumerate(test_datasets):
                latex_table += " &"
                # Rows are picked by dimensionality-reduction setting: PCA for the
                # timm MAE encoder, 50-d euclidean UMAP for every other encoder.
                filter1 = {"model": model, "dataset": dataset}
                if model == "timm_vit_base_patch16_224.mae":
                    filter1["dim_reducer"] = "PCA"
                    filter1["pca_variance"] = 0.95
                else:
                    filter1["dim_reducer_man"] = "UMAP"
                    filter1["ndim_reduced_man"] = 50
                    filter1["dim_reducer_man_metric"] = "euclidean"
                sdf = select_rows(test_runs_df, filter1, allow_missing=False)
                sdf = sdf[~pd.isna(sdf[metric_key2])]
                # The G.T. row reports the "_true" column (metric_key2); reading
                # metric_key here would repeat the predicted-clustering values.
                my_val = np.nanmedian(sdf[metric_key2])
                if dummy:
                    best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                # nanmax: an undefined value for one encoder must not hide the best
                is_best_grp = my_val + eps >= np.nanmax(
                    best_results_grouped[dataset][clusterername]
                )
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        # The two branches below only fire when CLUSTERERS contains the literal
        # name "AgglomerativeClustering": its first occurrence is shown as
        # "AC  w/ C" with aggclust_dist_thresh forced to None, later ones as
        # "AC w/o C" without that override.
        if (
            first_agg
            and clusterer == "AgglomerativeClustering"
            and metric_key != "num_cluster_pred"
        ):
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        # Rotated clusterer name spanning all encoder rows of this group
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for i_dataset, dataset in enumerate(test_datasets):
                latex_table += " &"
                # Filter precedence (lowest to highest): global defaults, best
                # params for this clusterer/model, model+dataset, overrides.
                filter1 = {"model": model, "dataset": dataset}
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(my_override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    # No run matches: queue the command that would produce it and
                    # leave a placeholder in the real table.
                    cmds.append(filter2command(filter2, partition="test"))
                    if not dummy:
                        latex_table += r"   --  "
                    continue
                if len(sdf) > 1:
                    # Several matching runs are fine if they agree; otherwise
                    # report which config columns distinguish them.
                    if sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                        print()
                        print(f"More than one result with {metric_key} values", list(sdf[metric_key]))
                        print(f"for search {filter2}")
                        dif_cols = find_differing_columns(sdf, config_keys)
                        print(f"columns which differ: {dif_cols}")
                        if dif_cols:
                            for col in dif_cols:
                                print(f"  {col}: {list(sdf[col])}")
                my_val = np.nanmedian(sdf[metric_key])
                if dummy:
                    best_results[dataset].append(my_val)
                    best_results_grouped[dataset][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Values gathered for this dataset in pass 1, sorted ascending with
                # NaNs dropped so one undefined result cannot poison the maximum,
                # the runner-up or the shading scale. my_val itself is in here
                # (appended in pass 1), so the array is never empty.
                dataset_vals = np.asarray(best_results[dataset], dtype=float)
                dataset_vals = np.sort(dataset_vals[~np.isnan(dataset_vals)])
                is_best = my_val + eps >= dataset_vals[-1]
                is_secd = len(dataset_vals) > 1 and my_val + eps >= dataset_vals[-2]
                is_best_grp = my_val + eps >= np.nanmax(
                    best_results_grouped[dataset][clusterername]
                )
                # Shading: 0 at or below the dataset median, 100 at the dataset best
                sc_base = np.median(dataset_vals)
                sc_top = dataset_vals[-1]
                sc_range = sc_top - sc_base
                if sc_range > 0:
                    sc = 100 * max(0, (my_val - sc_base) / sc_range)
                else:
                    # Best equals median: nothing to scale against
                    sc = 0
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                # \tcf = best for the dataset, \tcs = second best; when neither
                # applies, spaces keep the source columns aligned
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                # \tcg = best within this clusterer's group of rows
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    # Closes the \resizebox opened in the header
    latex_table += r"}" + "\n"

# Report how many lookups found no rows; one command was queued in cmds for each.
print(f"\nThere are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

# Blank-line padded banner followed by the LaTeX source of the table.
print("\nDone!\n")
print(f"Here is your results table for {metric_key}, {backbone}:\n\n")
print(latex_table)

## Correlation between AMI and Silhouette

In [None]:
# x-axis metric: silhouette-euclidean_pred | silhouette-og-euclidean_pred
metric_key1 = "silhouette-og-euclidean_pred"
# y-axis metric
metric_key2 = "AMI"

# Extra filter fields layered on top of the per-clusterer best parameters
override_fields = {}

# One subplot per backbone; each backbone is a key of MODEL_GROUPS
backbones = ["ResNet-50", "ViT-B"]
CLUSTERERS = [
    "KMeans",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]

# Side-by-side panels sharing the y axis
fig, ax = plt.subplots(
    nrows=1,
    ncols=len(backbones),
    sharey=True,
    figsize=(6, 3),
)

# For each backbone: collect the median metric_key1 / metric_key2 of every
# (clusterer, dataset, model) combination, scatter them in that backbone's panel
# and report the correlation per clusterer, averaged, and pooled.
for i_backbone, backbone in enumerate(backbones):
    # Per-clusterer value lists; entry order is (dataset, model)
    my_valx_method = {clusterer: [] for clusterer in CLUSTERERS}
    my_valy_method = {clusterer: [] for clusterer in CLUSTERERS}

    print(backbone)
    print(CLUSTERERS)
    print(TEST_DATASETS)
    print(MODEL_GROUPS[backbone])
    print()

    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        for i_dataset, dataset in enumerate(TEST_DATASETS):
            for i_model, model in enumerate(list(MODEL_GROUPS[backbone])):
                # Filter precedence (lowest to highest): global defaults, best
                # params for this clusterer/model, model+dataset, overrides.
                filter1 = {"model": model, "dataset": dataset}
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                # NaN is stored when no usable value exists; such points are
                # masked out before any correlation is computed below.
                my_valx_method[clusterer].append(np.nanmedian(sdf[metric_key1]))
                my_valy_method[clusterer].append(np.nanmedian(sdf[metric_key2]))

    my_valx_method = {k: np.array(v) for k, v in my_valx_method.items()}
    my_valy_method = {k: np.array(v) for k, v in my_valy_method.items()}
    my_valx_overall = np.concatenate([my_valx_method[clusterer] for clusterer in CLUSTERERS])
    my_valy_overall = np.concatenate([my_valy_method[clusterer] for clusterer in CLUSTERERS])
    # One colour row per point, in the same order as the concatenated values
    my_cols = np.concatenate(
        [
            np.tile(CLUSTERER2COLORRGB.get(clusterer, "k"), [len(my_valx_method[clusterer]), 1])
            for clusterer in CLUSTERERS
        ]
    )
    # Draw the points in random order (unseeded, so overlap differs between runs);
    # presumably so that no clusterer is systematically drawn on top.
    indices = np.arange(len(my_valx_overall))
    np.random.shuffle(indices)
    ax[i_backbone].scatter(
        my_valx_overall[indices],
        my_valy_overall[indices],
        color=my_cols[indices],
        s=20,
        alpha=0.5,
    )
    ax[i_backbone].set_xlabel(r"$S$" if metric_key1.startswith("silhouette") else metric_key1)
    if i_backbone == 0:
        ax[i_backbone].set_ylabel(metric_key2)
    ax[i_backbone].set_xlim(-1.05, 1.05)
    # Upper limit is the largest plotted value but at least 0.95. nanmax over the
    # values plus the 0.95 floor stays finite even when some values are NaN,
    # whereas the builtin max() is order-dependent in the presence of NaN.
    ax[i_backbone].set_ylim(-0.05, np.nanmax(np.append(my_valy_overall, 0.95)))
    ax[i_backbone].set_title(backbone)
    print(f"{backbone:<20s} Correlation coef")
    cors = []
    for clusterer in CLUSTERERS:
        # Only use points where both metrics are defined
        sel = (~np.isnan(my_valx_method[clusterer])) & (~np.isnan(my_valy_method[clusterer]))
        cor = np.corrcoef(my_valx_method[clusterer][sel], my_valy_method[clusterer][sel])[0, 1]
        cors.append(cor)
        print(f"{clusterer:<20s} {cor:.4f}")
    # Average of the per-clusterer coefficients, ignoring undefined ones; the same
    # value is printed here and annotated on the figure below.
    mean_cor = np.nanmean(cors)
    print(f"{'Average':<20s} {mean_cor:.4f}")
    sel = (~np.isnan(my_valx_overall)) & (~np.isnan(my_valy_overall))
    cor = np.corrcoef(my_valx_overall[sel], my_valy_overall[sel])[0, 1]
    print(f"{'Overall':<20s} {cor:.4f}")
    print()
    # Annotate the panel with the pooled r and the mean per-clusterer r
    ax[i_backbone].text(-0.85, 0.85, f"$r={cor:.2f}$")
    ax[i_backbone].text(-0.85, 0.75, r"$\bar{r}=" + f"{mean_cor:.2f}$")

def label_fn(c, marker):
    """Plot an empty, marker-only line in colour ``c`` and return its artist.

    The artist carries no data; it is only used as a legend handle.
    """
    (handle,) = plt.plot([], [], color=c, ls="None", marker=marker, linewidth=6)
    return handle


# One round-marker handle per clusterer, labelled with the clusterer name
handles = [label_fn(CLUSTERER2COLORRGB.get(name), "o") for name in CLUSTERERS]
data_labels = CLUSTERERS
# Place the legend just outside the right edge of the second panel
ax[1].legend(
    handles,
    data_labels,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)

fig.savefig(
    os.path.join(FIGS_DIR, f"scatter__{metric_key1}__{metric_key2}.pdf"),
    bbox_inches="tight",
)

## Rankings

In [None]:
# Metric whose medians are ranked below
metric_key1 = "AMI"
metric_key2 = "silhouette-euclidean_pred"

# Extra filter fields layered on top of the per-clusterer best parameters
override_fields = {}

# One panel per backbone; each backbone is a key of MODEL_GROUPS
backbones = ["ResNet-50", "ViT-B"]
CLUSTERERS = [
    "KMeans",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]
test_datasets = TEST_DATASETS

# Two figures with two side-by-side panels each:
# encoder rankings (figenc) and clusterer rankings (figclus)
figenc, axenc = plt.subplots(nrows=1, ncols=2, figsize=(6, 2))
figclus, axclus = plt.subplots(nrows=1, ncols=2, figsize=(6, 2))

# For each backbone: tabulate the median metric_key1 of every
# (encoder, clusterer, dataset) combination, then plot the mean rank (with std
# as error bar) of each encoder and of each clusterer.
for i_backbone, backbone in enumerate(backbones):
    # result_table[encoder, clusterer, dataset]; NaN = undefined or never filled,
    # -100.0 = sentinel for "no matching run found".
    result_table = np.nan * np.ones(
        (len(MODEL_GROUPS[backbone]), len(CLUSTERERS), len(test_datasets))
    )
    for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
        for i_clusterer, clusterer in enumerate(CLUSTERERS):
            for i_dataset, dataset in enumerate(test_datasets):
                # Filter precedence (lowest to highest): global defaults, best
                # params for this clusterer/model, model+dataset, overrides.
                filter1 = {"model": model, "dataset": dataset}
                filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                filter2.update(filter1)
                filter2.update(override_fields)
                filter2 = fixup_filter(filter2)
                sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                if len(sdf) < 1:
                    result_table[i_group, i_clusterer, i_dataset] = -100.0
                    continue
                # nanmedian, matching the results tables: a single NaN run must
                # not turn the whole entry into NaN.
                result_table[i_group, i_clusterer, i_dataset] = np.nanmedian(
                    sdf[metric_key1]
                )

    print(backbone)
    print(MODEL_GROUPS[backbone])

    # RANK PER ENCODER - go through each dataset, look at each clusterer,
    # and determine the rank of each encoder in that setting
    print(list(MODEL_GROUPS[backbone]))
    ranks_encoders = np.nan * np.ones((len(MODEL_GROUPS[backbone]), len(CLUSTERERS), len(test_datasets)))
    for i_dataset, dataset in enumerate(test_datasets):
        for i_clusterer, clusterer in enumerate(CLUSTERERS):
            cluster_data = result_table[:, i_clusterer, i_dataset]
            if np.all(cluster_data == cluster_data[0]) or np.all(np.isnan(cluster_data)):
                print(f"Skipping {dataset} {clusterer} (all same)")
                continue
            # A NaN entry is as unusable as a missing run: argsort places NaN
            # last, so after the reversal below it would be ranked best.
            if np.any(cluster_data == -100.0) or np.any(np.isnan(cluster_data)):
                print(f"Skipping {dataset} {clusterer} (incomplete)")
                continue
            # Descending order of value; argsort of that order gives each
            # encoder's 0-based position, hence rank 1 = highest value.
            rank = np.argsort(cluster_data)[::-1]
            ranks_encoders[:, i_clusterer, i_dataset] = 1 + rank.argsort()
    # Skipped settings stay NaN and are ignored by the nan-aware statistics
    mean_rank_encoders = np.nanmean(ranks_encoders, axis=(1, 2))
    std_rank_encoders = np.nanstd(ranks_encoders, axis=(1, 2))
    # order = np.argsort(mean_rank_encoders)
    order = np.arange(len(MODEL_GROUPS[backbone]))

    for i_plot, i_model in enumerate(order):
        axenc[i_backbone].barh(
            i_plot,
            mean_rank_encoders[i_model],
            xerr=std_rank_encoders[i_model],
            align="center",
            alpha=0.6,
            ecolor="black",
            color=MODEL2COLORRGB.get(MODEL_GROUPS[backbone][i_model], "k"),
            capsize=2,
            zorder=10,
        )

    # First encoder at the top; bars are identified by colour, not tick labels
    axenc[i_backbone].invert_yaxis()
    axenc[i_backbone].set_yticks([])
    axenc[i_backbone].set_yticklabels([])
    axenc[i_backbone].set_xticks(np.arange(1, 1 + len(MODEL_GROUPS[backbone])))
    axenc[i_backbone].set_xlim([0, 0.5 + len(MODEL_GROUPS[backbone])])
    axenc[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
    axenc[i_backbone].set_title(backbone)

    # RANK PER CLUSTERER - go through each dataset, look at each encoder,
    # and determine the rank of each clusterer in that setting

    print(CLUSTERERS)
    ranks_clusterers = np.nan * np.ones((len(MODEL_GROUPS[backbone]), len(CLUSTERERS), len(test_datasets)))
    for i_dataset, dataset in enumerate(test_datasets):
        for i_encoder, encoder in enumerate(MODEL_GROUPS[backbone]):
            encoder_data = result_table[i_encoder, :, i_dataset]
            if np.all(encoder_data == encoder_data[0]) or np.all(np.isnan(encoder_data)):
                print(f"Skipping {dataset} {encoder} (all same)")
                continue
            # Same policy as above: NaN counts as incomplete rather than as best
            if np.any(encoder_data == -100.0) or np.any(np.isnan(encoder_data)):
                print(f"Skipping {dataset} {encoder} (incomplete)")
                continue
            rank = np.argsort(encoder_data)[::-1]
            ranks_clusterers[i_encoder, :, i_dataset] = 1 + rank.argsort()
    mean_rank_clusters = np.nanmean(ranks_clusterers, axis=(0, 2))
    std_rank_clusters = np.nanstd(ranks_clusterers, axis=(0, 2))
    # order = np.argsort(mean_rank_clusters)
    order = np.arange(len(CLUSTERERS))

    for i_plot, i_clusterer in enumerate(order):
        axclus[i_backbone].barh(
            i_plot,
            mean_rank_clusters[i_clusterer],
            xerr=std_rank_clusters[i_clusterer],
            align="center",
            alpha=0.6,
            ecolor="black",
            # NOTE(review): bars read CLUSTERER2COLORSTR while other cells of this
            # notebook read CLUSTERER2COLORRGB - confirm both maps hold the same colours.
            color=CLUSTERER2COLORSTR.get(CLUSTERERS[i_clusterer], "k"),
            capsize=2,
            zorder=10,
        )

    axclus[i_backbone].invert_yaxis()
    axclus[i_backbone].set_yticks([])
    axclus[i_backbone].set_yticklabels([])
    axclus[i_backbone].set_xticks(np.arange(1, 1 + len(CLUSTERERS)))
    axclus[i_backbone].set_xlim([0, 0.6 + len(CLUSTERERS)])
    axclus[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
    axclus[i_backbone].set_title(backbone)

    axclus[i_backbone].set_xlabel("Rank")
    axenc[i_backbone].set_xlabel("Rank")

def label_fn(c, ls):  # noqa: E731
    """Plot an empty line in colour ``c`` with line style ``ls`` and return its artist.

    The artist carries no data; it is only used as a legend handle.
    """
    (handle,) = plt.plot([], [], color=c, ls=ls, linewidth=3)
    return handle


# Encoder legend: the ResNet-50 group's models plus the timm MAE model,
# attached outside the right edge of the second encoder panel
model_names = [*MODEL_GROUPS["ResNet-50"], "timm_vit_base_patch16_224.mae"]
handles_enc = [label_fn(MODEL2COLORRGB[name], "-") for name in model_names]
axenc[1].legend(
    handles_enc,
    [MODEL2SH[name] for name in model_names],
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)

# Clusterer legend, attached outside the right edge of the second clusterer panel
handles_clus = [label_fn(CLUSTERER2COLORRGB[name], "-") for name in CLUSTERERS]
axclus[1].legend(
    handles_clus,
    CLUSTERERS,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
)

for figure, filename in ((figenc, "ranking_enc.pdf"), (figclus, "ranking_clus.pdf")):
    figure.savefig(os.path.join(FIGS_DIR, filename), bbox_inches="tight")

### With grouped datasets

In [None]:
# Metric whose medians are ranked below
metric_key1 = "AMI"
metric_key2 = "silhouette-euclidean_pred"

# Extra filter fields layered on top of the per-clusterer best parameters
override_fields = dict()

# One panel per backbone
backbones = [
    "ResNet-50",
    "ViT-B",
]
CLUSTERERS = [
    "KMeans",
    "AC w/ C",
    "AC w/o C",
    "AffinityPropagation",
    "HDBSCAN",
]

for test_group, test_datasets in TEST_DATASETS_GROUPED.items():

    figenc, axenc = plt.subplots(1, 2, figsize=(6, 1.6))
    figclus, axclus = plt.subplots(1, 2, figsize=(6, 2))

    for i_backbone, backbone in enumerate(backbones):
        result_table = np.nan * np.ones(
            (len(MODEL_GROUPS[backbone]), len(CLUSTERERS), len(test_datasets))
        )
        for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
            for i_clusterer, clusterer in enumerate(CLUSTERERS):
                for i_dataset, dataset in enumerate(test_datasets):
                    filter1 = {"model": model, "dataset": dataset}
                    filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
                    filter2.update(filter1)
                    filter2.update(override_fields)
                    filter2 = fixup_filter(filter2)
                    sdf = select_rows(test_runs_df, filter2, allow_missing=False)
                    if len(sdf) < 1:
                        result_table[i_group, i_clusterer, i_dataset] = -100.0
                        continue
                    result_table[i_group, i_clusterer, i_dataset] = np.median(
                        sdf[metric_key1]
                    )

        print(backbone)
        print(MODEL_GROUPS[backbone])

        # RANK PER ENCODER - go through each dataset, look at each clusterer,
        # and determine the rank of each encoder in that setting
        print(list(MODEL_GROUPS[backbone]))
        ranks_encoders = np.nan * np.ones((len(MODEL_GROUPS[backbone]), len(CLUSTERERS), len(test_datasets)))
        for i_dataset, dataset in enumerate(test_datasets):
            for i_clusterer, clusterer in enumerate(CLUSTERERS):
                cluster_data = result_table[:, i_clusterer, i_dataset]
                if np.all(cluster_data == cluster_data[0]) or np.all(np.isnan(cluster_data)):
                    print(f"Skipping {dataset} {clusterer} (all same)")
                    continue
                if np.any(cluster_data == -100.0):
                    print(f"Skipping {dataset} {clusterer} (incomplete)")
                    continue
                rank = np.argsort(cluster_data)[::-1]
                ranks_encoders[:, i_clusterer, i_dataset] = 1 + rank.argsort()
        mean_rank_encoders = np.nanmean(ranks_encoders, axis=(1, 2))
        std_rank_encoders = np.nanstd(ranks_encoders, axis=(1, 2))
        # order = np.argsort(mean_rank_encoders)
        order = np.arange(len(MODEL_GROUPS[backbone]))

        for i_plot, i_model in enumerate(order):
            axenc[i_backbone].barh(
                i_plot,
                mean_rank_encoders[i_model],
                xerr=std_rank_encoders[i_model],
                align="center",
                alpha=0.6,
                ecolor="black",
                color=MODEL2COLORRGB.get(MODEL_GROUPS[backbone][i_model], "k"),
                capsize=2,
                zorder=10,
            )

        axenc[i_backbone].invert_yaxis()
        axenc[i_backbone].set_yticks([])
        axenc[i_backbone].set_yticklabels([])
        axenc[i_backbone].set_xticks(np.arange(1, 1 + len(MODEL_GROUPS[backbone])))
        axenc[i_backbone].set_xlim([0, 0.5 + len(MODEL_GROUPS[backbone])])
        axenc[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
        axenc[i_backbone].set_title(f"{DATASETGROUP2TITLE.get(test_group, test_group)}, {backbone}")
        axenc[i_backbone].set_xlabel("Rank")

        # RANK PER CLUSTERER - go through each dataset, look at each encoder,
        # and determine the rank of each clusterer in that setting

        print(CLUSTERERS)
        ranks_clusterers = np.nan * np.ones((len(MODEL_GROUPS[backbone]), len(CLUSTERERS), len(test_datasets)))
        for i_dataset, dataset in enumerate(test_datasets):
            for i_encoder, encoder in enumerate(MODEL_GROUPS[backbone]):
                encoder_data = result_table[i_encoder, :, i_dataset]
                if np.all(encoder_data == encoder_data[0]) or np.all(np.isnan(encoder_data)):
                    print(f"Skipping {dataset} {encoder} (all same)")
                    continue
                if np.any(encoder_data == -100.0):
                    print(f"Skipping {dataset} {encoder} (incomplete)")
                    continue
                rank = np.argsort(encoder_data)[::-1]
                ranks_clusterers[i_encoder, :, i_dataset] = 1 + rank.argsort()
        mean_rank_clusters = np.nanmean(ranks_clusterers, axis=(0, 2))
        std_rank_clusters = np.nanstd(ranks_clusterers, axis=(0, 2))
        # order = np.argsort(mean_rank_clusters)
        order = np.arange(len(CLUSTERERS))

        for i_plot, i_clusterer in enumerate(order):
            axclus[i_backbone].barh(
                i_plot,
                mean_rank_clusters[i_clusterer],
                xerr=std_rank_clusters[i_clusterer],
                align="center",
                alpha=0.6,
                ecolor="black",
                color=CLUSTERER2COLORSTR.get(CLUSTERERS[i_clusterer], "k"),
                capsize=2,
                zorder=10,
            )

        axclus[i_backbone].invert_yaxis()
        axclus[i_backbone].set_yticks([])
        axclus[i_backbone].set_yticklabels([])
        axclus[i_backbone].set_xticks(np.arange(1, 1 + len(CLUSTERERS)))
        axclus[i_backbone].set_xlim([0, 0.6 + len(CLUSTERERS)])
        axclus[i_backbone].xaxis.grid(True, zorder=1, alpha=0.5)
        axclus[i_backbone].set_title(f"{DATASETGROUP2TITLE.get(test_group, test_group)}, {backbone}")
        axclus[i_backbone].set_xlabel("Rank")

    label_fn = lambda c, ls: plt.plot([], [], color=c, ls=ls, linewidth=3)[0]  # noqa:E731

    model_names = list(MODEL_GROUPS["ResNet-50"]) + ["timm_vit_base_patch16_224.mae"]
    handles_enc = [label_fn(MODEL2COLORRGB[idx], "-") for idx in model_names]
    axenc[1].legend(
        handles_enc,
        [MODEL2SH[x] for x in model_names],
        loc="center left",
        bbox_to_anchor=(1, 0.5),
    )

    handles_clus = [label_fn(CLUSTERER2COLORRGB[clusterer], "-") for clusterer in CLUSTERERS]
    axclus[1].legend(handles_clus, CLUSTERERS, loc="center left", bbox_to_anchor=(1, 0.5))

    figenc.savefig(os.path.join(FIGS_DIR, f"ranking_enc__{test_group}.pdf"), bbox_inches="tight")
    figclus.savefig(os.path.join(FIGS_DIR, f"ranking_clus__{test_group}.pdf"), bbox_inches="tight")

## Plot sample images

In [None]:
import re

from zs_ssl_clustering.io import sanitize_filename


def get_pred_path(row):
    """
    Build the path of the ``.npz`` prediction file belonging to a run.

    The path has the form
    ``<predictions_dir>/<partition>__z<zoom_ratio>/<partition>-<dataset_name>__<model>__<run_id>.npz``
    where ``run_id`` is the last ``"__"``-separated piece of ``row["name"]``.
    The sub-directory and the file name are each passed through
    ``sanitize_filename``.
    """
    identifier = row["name"].split("__")[-1]
    basename = sanitize_filename(
        f"{row['partition']}-{row['dataset_name']}__{row['model']}__{identifier}.npz"
    )
    return os.path.join(
        row["predictions_dir"],
        sanitize_filename(row["partition"] + f"__z{float(row['zoom_ratio'])}"),
        basename,
    )

In [None]:
import torch
from torchvision.transforms.functional import crop, get_dimensions
from torchvision.utils import _log_api_usage_once


def center_squaring(img):
    """Crop the given image to its largest centered square.

    The side of the square is ``min(height, width)``. An image that is
    already square is returned unchanged (same object); no padding or
    resizing is performed.

    Args:
        img (PIL Image or Tensor): Image to be cropped. Its height and width
            are read with ``get_dimensions``.

    Returns:
        PIL Image or Tensor: The input itself when it is already square,
        otherwise the centered square crop produced by ``crop``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(center_squaring)

    # get_dimensions yields three values; the first (channels) is not needed.
    _, image_height, image_width = get_dimensions(img)

    if image_height == image_width:
        return img

    # Shorter edge defines the square; only the longer edge is trimmed.
    crop_height = crop_width = min(image_height, image_width)

    # Split the surplus evenly between both sides of the longer edge.
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop(img, crop_top, crop_left, crop_height, crop_width)


class CenterSquaring(torch.nn.Module):
    """Transform module that crops an image to its centered square."""

    def __init__(self):
        super().__init__()
        _log_api_usage_once(self)

    def forward(self, img):
        """
        Apply ``center_squaring`` to ``img``.

        Args:
            img (PIL Image or Tensor): Image to be cropped.

        Returns:
            PIL Image or Tensor: Cropped image.
        """
        squared = center_squaring(img)
        return squared

    def __repr__(self) -> str:
        # The transform has no parameters, so the class name alone identifies it.
        return f"{type(self).__name__}"

In [None]:
from zs_ssl_clustering import datasets


def show_samples(
    row, nsamp=12, ds=None, save=False, clusterer="", nclusters=None, skip_existing=None
):
    """Plot a grid of sample images with one row per predicted cluster.

    Args:
        row: Run record (mapping-like) providing ``dataset_name``, ``model``,
            ``partition``, ``clusterer_name`` and the fields read by
            ``get_pred_path``.
        nsamp (int): Number of images (columns) shown per cluster.
        ds: Dataset to draw images from. When ``None`` it is fetched with
            ``datasets.fetch_image_dataset`` using a ``CenterSquaring``
            evaluation transform, taking the first split for the ``"train"``
            partition and the last split for ``"test"``.
        save (bool): Whether to write the figure to ``../samples``.
        clusterer (str): Name used in the output file name; ``"/"`` and
            spaces are stripped. Falls back to ``row["clusterer_name"]``.
        nclusters (int or None): Maximum number of clusters (rows) to show;
            ``None`` shows all of them.
        skip_existing (bool or None): Skip all work when the output file
            already exists. Defaults to the value of ``save``.

    Raises:
        NotImplementedError: If ``ds`` is ``None`` and the row's partition is
            neither ``"train"`` nor ``"test"``.
    """
    if skip_existing is None:
        skip_existing = save

    if clusterer:
        clusterer = clusterer.replace("/", "").replace(" ", "")
    else:
        clusterer = row["clusterer_name"]

    output_dir = "../samples"
    output_fname = f"samples__{row['dataset_name']}__{row['model']}__{clusterer}.png"
    output_fname = os.path.join(output_dir, output_fname)
    if skip_existing and os.path.exists(output_fname):
        print(f"Output {output_fname} already exists. Skipping.")
        return

    if ds is None:
        dses = datasets.fetch_image_dataset(
            row["dataset_name"], transform_eval=CenterSquaring()
        )
        if row["partition"] == "train":
            ds = dses[0]
        elif row["partition"] == "test":
            ds = dses[-1]
        else:
            raise NotImplementedError(f"Unsupported partition: {row['partition']!r}")

    y_pred = np.load("../" + get_pred_path(row))["y_pred"]

    u_labels = np.unique(y_pred)

    if nclusters is None:
        nclusters = len(u_labels)
    else:
        nclusters = min(nclusters, len(u_labels))

    # squeeze=False keeps `axs` two-dimensional even for a single row or
    # column, so the [row, col] indexing below always works.
    fig, axs = plt.subplots(
        nclusters, nsamp, figsize=(nsamp / 2, nclusters / 2), squeeze=False
    )

    for i_label, label in enumerate(u_labels[:nclusters]):
        indices = np.where(y_pred == label)[0]
        # Per-label seed keeps the sample choice reproducible. default_rng
        # rejects negative seeds, so a negative label (e.g. a -1 "noise"
        # label) is mapped to its absolute value; non-negative labels keep
        # exactly the seed they had before.
        np.random.default_rng(seed=abs(int(label))).shuffle(indices)
        for i in range(nsamp):
            # Clusters smaller than nsamp leave the remaining cells blank.
            if i < len(indices):
                idx = indices[i]
                axs[i_label, i].imshow(ds[idx][0].convert("RGB"))
            axs[i_label, i].axis("off")

    if save:
        print(f"Saving to {output_fname}")
        os.makedirs(output_dir, exist_ok=True)
        # Save this figure explicitly rather than whichever is "current".
        fig.savefig(output_fname, bbox_inches="tight")


def fetch_row(dataset, model, clusterer):
    """Return the unique test run for a dataset/model/clusterer combination.

    The query is built from ``DEFAULT_PARAMS["all"]``, then
    ``BEST_PARAMS[clusterer][model]``, then the model/dataset pair, then the
    fixed overrides (``predictions_dir="y_pred"`` and, for HDBSCAN on celeba
    or utkface, ``min_samples=2``), and finally passed through
    ``fixup_filter``.

    Returns:
        The single matching row of ``test_runs_df``, or ``None`` when there is
        no match (the command to generate it is printed) or more than one
        match (details are printed only if their AMI values differ).
    """
    extra_fields = {"predictions_dir": "y_pred"}
    if dataset in ["celeba", "utkface"] and clusterer == "HDBSCAN":
        extra_fields["min_samples"] = 2

    # Later layers take precedence over earlier ones.
    query = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
    for layer in ({"model": model, "dataset": dataset}, extra_fields):
        query.update(layer)
    query = fixup_filter(query)

    matches = select_rows(test_runs_df, query, allow_missing=False)
    n_matches = len(matches)

    if n_matches == 1:
        return matches.iloc[0]

    if n_matches < 1:
        print(f"No data for {query}")
        print(filter2command(query, partition="test"))
        return None

    # Ambiguous result: report it only when the candidates actually disagree.
    reference_ami = matches.iloc[0]["AMI"]
    if sum(matches["AMI"] != reference_ami) > 0:
        print()
        print("More than one result with AMIs:", list(matches["AMI"]))
        print(f"for search {query}")
        varying = find_differing_columns(matches, config_keys)
        print(f"columns which differ: {varying}")
        if varying:
            for column in varying:
                print(f"  {column}: {list(matches[column])}")
    return None

In [None]:
# Manual, cell-level version of fetch_row(): look up the run for one
# dataset/model/clusterer combination and print its key metrics.
# `row` is only assigned when exactly one run matches.
dataset = "svhn"
model = "mocov3_resnet50"
clusterer = "AC w/ C"

override_fields = {
    "predictions_dir": "y_pred",
    # "aggclust_dist_thresh": None,  # Use this to flip between unknown/known num clusters for Agglom
}
# Layer the filter: global defaults < best params < model+dataset < overrides.
filter1 = {"model": model, "dataset": dataset}
filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
filter2.update(filter1)
filter2.update(override_fields)
filter2 = fixup_filter(filter2)
sdf = select_rows(test_runs_df, filter2, allow_missing=False)
if len(sdf) < 1:
    # Nothing found: print the command built by filter2command for this filter.
    print(f"No data for {filter2}")
    print(filter2command(filter2, partition="test"))
elif len(sdf) > 1:
    # Several matches: only report them if their AMI values disagree.
    perf = sdf.iloc[0]["AMI"]
    if sum(sdf["AMI"] != perf) > 0:
        print()
        print("More than one result with AMIs:", list(sdf["AMI"]))
        print(f"for search {filter2}")
        dif_cols = find_differing_columns(sdf, config_keys)
        print(f"columns which differ: {dif_cols}")
        if dif_cols:
            for col in dif_cols:
                print(f"  {col}: {list(sdf[col])}")
else:
    display(sdf)
    row = sdf.iloc[0]
    # First printed item is the run id (last "__"-separated part of the name).
    print(
        row["name"].split("__")[-1],
        "\n" + row["name"],
        "\n  " + row["dataset_name"],
        "\n  " + row["model"],
        "\n  " + row["clusterer_name"],
        f"\n  AMI={row['AMI']}",
        f"\n  S_reduced={row['silhouette-euclidean_pred']}",
        f"\n  S_originl={row['silhouette-og-euclidean_pred']}",
    )

In [None]:
# Load the predicted cluster labels stored for `row`; the path from
# get_pred_path is resolved relative to the parent directory.
y_pred = np.load("../" + get_pred_path(row))["y_pred"]

In [None]:
# Number of predicted labels loaded.
len(y_pred)

In [None]:
# Keep the last split returned by fetch_image_dataset (show_samples uses the
# same index for the "test" partition). No transform is passed here.
ds = datasets.fetch_image_dataset(row["dataset_name"])[-1]

In [None]:
# Sample indices assigned to cluster 0, shuffled in place with a fixed seed.
indices = np.where(y_pred == 0)[0]
np.random.default_rng(seed=0).shuffle(indices)

In [None]:
# First ten shuffled indices.
indices[:10]

In [None]:
# Distinct predicted cluster labels.
np.unique(y_pred)

In [None]:
# Show `nsamp` randomly chosen images from one predicted cluster in a row.
label = 3
nsamp = 10

indices = np.where(y_pred == label)[0]
# Seeding with the label keeps the selection reproducible per cluster.
np.random.default_rng(seed=label).shuffle(indices)

fig, axs = plt.subplots(1, nsamp, figsize=(6, 2))

# Iterate over the number of axes actually created (previously a hard-coded
# 10) and leave cells blank when the cluster has fewer than `nsamp` members.
for i in range(nsamp):
    if i < len(indices):
        idx = indices[i]
        axs[i].imshow(ds[idx][0])
    axs[i].axis("off")

In [None]:
# Show `nsamp` randomly chosen images for every predicted cluster,
# one cluster per row.
nsamp = 10

u_labels = np.unique(y_pred)

# figsize is (width, height): width follows the number of columns (nsamp) and
# height the number of rows (clusters), matching show_samples().
# squeeze=False keeps `axs` two-dimensional even with a single cluster.
fig, axs = plt.subplots(
    len(u_labels), nsamp, figsize=(nsamp / 2, len(u_labels) / 2), squeeze=False
)

for i_label, label in enumerate(u_labels):
    indices = np.where(y_pred == label)[0]
    # default_rng rejects negative seeds, so a negative label (e.g. a -1
    # "noise" label) is mapped to its absolute value.
    np.random.default_rng(seed=abs(int(label))).shuffle(indices)
    for i in range(nsamp):
        # Clusters with fewer than nsamp members leave blank cells.
        if i < len(indices):
            idx = indices[i]
            axs[i_label, i].imshow(ds[idx][0])
        axs[i_label, i].axis("off")

# plt.savefig(f"{row['dataset_name']}_{row['model']}_{row['clusterer_name']}.png", bbox_inches='tight')
plt.show()

In [None]:
# Last split of flowers102 with no transform passed, for comparison with the
# CenterSquaring version fetched further down.
ds = datasets.fetch_image_dataset("flowers102")[-1]

In [None]:
# Display the first element of sample 1 (used as the image elsewhere in this notebook).
ds[1][0]

In [None]:
# Same split, now with the CenterSquaring transform passed as transform_eval.
ds = datasets.fetch_image_dataset("flowers102", transform_eval=CenterSquaring())[-1]

In [None]:
# Display the same sample again to check the effect of CenterSquaring.
ds[1][0]

In [None]:
# Fetch the run for one dataset/model/clusterer combination, print its key
# metrics and save the sample grid.
dataset = "svhn"
model = "mocov3_resnet50"
clusterer = "AC w/ C"

row = fetch_row(dataset, model, clusterer)
# fetch_row returns None when there is no unique matching run; guard against
# it the same way the batch cells below do instead of failing on row["name"].
if row is None:
    print("No data with y_pred for", dataset, model, clusterer)
else:
    print(
        row["name"].split("__")[-1],
        "\n" + row["name"],
        "\n  " + row["dataset_name"],
        "\n  " + row["model"],
        "\n  " + row["clusterer_name"],
        f"\n  AMI        = {row['AMI']}",
        f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
        f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
    )
    show_samples(row, save=True, clusterer=clusterer)

In [None]:
# Generate sample grids for "AC w/ C" over a few encoders and datasets.
# NOTE: the trailing `break` statements stop all three loops after the first
# combination that has data ("Stopping early!"); combinations without data
# are skipped via `continue`.
for clusterer in ["AC w/ C"]:  # , "AC w/o C"]:
    for model in ["mocov3_resnet50", "mocov3_vit_base", "dino_resnet50", "dino_vitb16"]:
        for dataset in [
            "mnist",
            "fashionmnist",
            "svhn",
            "cifar10",
            "cifar100",
            "flowers102",
            "aircraft",
        ]:
            print()
            print(f"{dataset:<16s} {model:<32s} {clusterer}")
            row = fetch_row(dataset, model, clusterer)
            if row is None:
                print("No data with y_pred for", dataset, model, clusterer)
                continue
            print(
                row["name"].split("__")[-1],
                "\n" + row["name"],
                "\n  " + row["dataset_name"],
                "\n  " + row["model"],
                "\n  " + row["clusterer_name"],
                f"\n  AMI        = {row['AMI']}",
                f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
                f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
            )
            # show_samples has no return statement with a value, so `fig` is None.
            fig = show_samples(row, save=True, clusterer=clusterer, nclusters=150)
            # plt.show()
            print("\n\nStopping early!")
            break
        break
    break

In [None]:
# Generate sample grids for "AC w/o C" over all ResNet-50 and ViT-B encoders.
# NOTE: the trailing `break` statements stop all three loops after the first
# combination that has data ("Stopping early!").
for clusterer in ["AC w/o C"]:
    for model in RESNET50_MODELS + VITB16_MODELS:
        for dataset in [
            "mnist",
            "fashionmnist",
            "svhn",
            "cifar10",
            "cifar100",
            "flowers102",
            "aircraft",
        ]:
            print()
            print(f"{dataset:<16s} {model:<32s} {clusterer}")
            row = fetch_row(dataset, model, clusterer)
            if row is None:
                print("No data with y_pred for", dataset, model, clusterer)
                continue
            print(
                row["name"].split("__")[-1],
                "\n" + row["name"],
                "\n  " + row["dataset_name"],
                "\n  " + row["model"],
                "\n  " + row["clusterer_name"],
                f"\n  AMI        = {row['AMI']}",
                f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
                f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
            )
            # show_samples has no return statement with a value, so `fig` is None.
            fig = show_samples(row, save=True, clusterer=clusterer, nclusters=150)
            # plt.show()
            print("\n\nStopping early!")
            break
        break
    break

In [None]:
# Generate sample grids for "AC w/o C" on iNaturalist, closing each figure
# after it has been saved.
# NOTE: the trailing `break` statements stop all three loops after the first
# combination that has data ("Stopping early!").
for clusterer in ["AC w/o C"]:
    for model in RESNET50_MODELS + VITB16_MODELS:
        for dataset in ["inaturalist"]:
            print()
            print(f"{dataset:<16s} {model:<32s} {clusterer}")
            row = fetch_row(dataset, model, clusterer)
            if row is None:
                print("No data with y_pred for", dataset, model, clusterer)
                continue
            print(
                row["name"].split("__")[-1],
                "\n" + row["name"],
                "\n  " + row["dataset_name"],
                "\n  " + row["model"],
                "\n  " + row["clusterer_name"],
                f"\n  AMI        = {row['AMI']}",
                f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
                f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
            )
            # show_samples has no return statement with a value, so `fig` is None.
            fig = show_samples(row, save=True, clusterer=clusterer, nclusters=150)
            # plt.show()
            plt.close()
            print("\n\nStopping early!")
            break
        break
    break

In [None]:
# Generate sample grids for "AC w/o C" over every test dataset, tolerating
# failures for individual datasets.
# NOTE: the trailing `break` statements stop all three loops after the first
# combination that has data ("Stopping early!").
for clusterer in ["AC w/o C"]:
    for model in RESNET50_MODELS + VITB16_MODELS:
        for dataset in TEST_DATASETS:
            print()
            print(f"{dataset:<16s} {model:<32s} {clusterer}")
            row = fetch_row(dataset, model, clusterer)
            if row is None:
                print("No data with y_pred for", dataset, model, clusterer)
                continue
            print(
                row["name"].split("__")[-1],
                "\n" + row["name"],
                "\n  " + row["dataset_name"],
                "\n  " + row["model"],
                "\n  " + row["clusterer_name"],
                f"\n  AMI        = {row['AMI']}",
                f"\n  S_reduced  = {row['silhouette-euclidean_pred']}",
                f"\n  S_original = {row['silhouette-og-euclidean_pred']}",
            )
            try:
                fig = show_samples(row, save=True, clusterer=clusterer, nclusters=150)
            except Exception as err:
                # Best-effort: keep going, but report what actually failed so
                # that errors other than a missing dataset are not hidden
                # behind the "not found" message.
                print(f"{dataset} not found")
                print(f"  ({type(err).__name__}: {err})")
            # Release the current figure (if any) whether or not plotting
            # succeeded.
            plt.close()
            print("\n\nStopping early!")
            break
        break
    break

## Breakdown information about datasets with multiple labels

In [None]:
import sklearn.metrics
import torchvision.datasets

### CelebA attributes

In [None]:
# CelebA test split with attribute targets, read from ~/Datasets.
# Its `identity`, `attr` and `attr_names` fields are used by the table below.
celeba_test = torchvision.datasets.CelebA(
    os.path.expanduser("~/Datasets"),
    target_type="attr",
    split="test",
)

In [None]:
# Configuration for the CelebA per-attribute LaTeX table built below.
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True  # multiply values by 100 before formatting
show_fmt = "{:4.0f}"
show_commands = False  # print the commands collected for missing runs
highlight_best = True  # wrap best / second-best / best-in-group values in \tcf / \tcs / \tcg
use_si_num = False  # wrap numbers in \num{...} instead of $...$
eps = 0.005  # tolerance when comparing a value against the best one
override_fields = {
    "predictions_dir": "y_pred",
}

backbone = "ViT-B"  # "ResNet-50" or "ViT-B"

# NOTE: this rebinds the notebook-wide CLUSTERERS global.
if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
else:
    CLUSTERERS = ["KMeans", "AC w/ C", "AC w/o C", "AffinityPropagation", "HDBSCAN"]
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

print(MODEL_GROUPS)

dataset = "celeba"

# Table columns: identity plus every attribute name except the last entry
# (presumably an empty trailing name - TODO confirm).
TEST_ATTRS = ["Identity"] + celeba_test.attr_names[:-1]
print(TEST_ATTRS)
# Values collected during the first ("dummy") pass of the loop below:
# per attribute over all rows, and per attribute within each clusterer.
best_results = {k: [] for k in TEST_ATTRS}
best_results_grouped = {k: defaultdict(list) for k in TEST_ATTRS}

# Build the LaTeX table in two passes. The first ("dummy") pass only collects
# every value into best_results / best_results_grouped; the second pass
# rebuilds the table from scratch and uses those collections to highlight the
# best values and to scale the cell background colour.
for dummy in [True, False]:
    cmds = []
    latex_table = r"% Results for " + f"{dataset} breakdown, {metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = backbone
    if metric_key == "AMI":
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"\resizebox{\textwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_ATTRS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for attr in TEST_ATTRS:
        latex_table += r"&" + f"{attr.replace('_', ' '):^15s}"
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        # Header row with the true number of clusters.
        latex_table += r"& Num targets"
        for i_attr, attr in enumerate(TEST_ATTRS):
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Extra "G.T." block shown for *_pred metrics.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for i_attr, attr in enumerate(TEST_ATTRS):
                latex_table += " &"
                filter1 = {"model": model, "dataset": dataset}
                if model == "timm_vit_base_patch16_224.mae":
                    filter1["dim_reducer"] = "PCA"
                    filter1["pca_variance"] = 0.95
                else:
                    filter1["dim_reducer_man"] = "UMAP"
                    filter1["ndim_reduced_man"] = 50
                    filter1["dim_reducer_man_metric"] = "euclidean"
                sdf = select_rows(test_runs_df, filter1, allow_missing=False)
                sdf = sdf[~pd.isna(sdf[metric_key2])]
                # NOTE(review): rows are filtered on metric_key2 (the *_true
                # column) but the median is taken over metric_key - confirm
                # which column the G.T. block is meant to show.
                my_val = np.nanmedian(sdf[metric_key])
                if sum(sdf[metric_key2] != my_val) > 0:
                    pass
                if dummy:
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                is_best_grp = my_val + eps >= np.max(
                    best_results_grouped[attr][clusterername]
                )
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        # Special handling for a literal "AgglomerativeClustering" entry: the
        # first occurrence is the known-number-of-clusters variant, later
        # occurrences the unknown-number variant.
        if (
            first_agg
            and clusterer == "AgglomerativeClustering"
            and metric_key != "num_cluster_pred"
        ):
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        # Rotated clusterer name spanning all encoder rows of this block.
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for i_group, model in enumerate(list(MODEL_GROUPS[backbone])):
            latex_table += f"& {MODEL2SH[model]:<10s}"
            # Layer the filter: defaults < best params < model+dataset < overrides.
            filter1 = {"model": model, "dataset": dataset}
            filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
            filter2.update(filter1)
            filter2.update(my_override_fields)
            filter2 = fixup_filter(filter2)
            sdf = select_rows(test_runs_df, filter2, allow_missing=False)
            if len(sdf) < 1:
                print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                cmds.append(filter2command(filter2, partition="test"))
                if not dummy:
                    # Emit one placeholder cell per attribute and terminate
                    # the row so the table keeps a consistent column count.
                    # (Previously a single placeholder was written and the
                    # row terminator was skipped, merging this row into the
                    # next one.)
                    latex_table += " &   --  " * len(TEST_ATTRS)
                    latex_table += r" \\" + "\n"
                continue
            if len(sdf) > 1:
                if sum(np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6) > 0:
                    print()
                    print(f"More than one result with {metric_key} values", list(sdf[metric_key]))
                    print(f"for search {filter2}")
                    dif_cols = find_differing_columns(sdf, config_keys)
                    print(f"columns which differ: {dif_cols}")
                    if dif_cols:
                        for col in dif_cols:
                            print(f"  {col}: {list(sdf[col])}")
            # Predictions of the first matching run are scored against every
            # attribute in turn.
            y_pred = np.load("../" + get_pred_path(sdf.iloc[0]))["y_pred"]
            for i_attr, attr in enumerate(TEST_ATTRS):
                latex_table += " &"
                if metric_key.lower() != "ami":
                    raise NotImplementedError()
                if attr.lower() == "identity":
                    my_val = sklearn.metrics.adjusted_mutual_info_score(celeba_test.identity[:, 0], y_pred)
                else:
                    # TEST_ATTRS has "Identity" prepended, hence the -1 offset
                    # into the attribute matrix.
                    my_val = sklearn.metrics.adjusted_mutual_info_score(celeba_test.attr[:, i_attr - 1], y_pred)
                if dummy:
                    best_results[attr].append(my_val)
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                is_best = my_val + eps >= np.max(best_results[attr])
                if len(best_results[attr]) > 1:
                    is_secd = my_val + eps >= np.sort(best_results[attr])[-2]
                else:
                    is_secd = False
                is_best_grp = my_val + eps >= np.max(
                    best_results_grouped[attr][clusterername]
                )
                # Background colour intensity: 0 at (or below) the median of
                # the column, 100 at its maximum.
                sc_base = np.nanmedian(best_results[attr])
                sc_top = np.max(best_results[attr])
                sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"}" + "\n"

# Report how many runs are missing (optionally listing the commands collected
# for them) and print the finished LaTeX table from the final pass.
print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")
if show_commands:
    for cmd in cmds:
        print(cmd)

print()
print("Done!")
print()
print(f"Here is your {dataset} results table for {metric_key}, {backbone}:")
print()
print()
print(latex_table)

### ImageNet-R

In [None]:
# Last split returned for ImageNet-R (the index show_samples uses for "test").
imagenetr_test = datasets.fetch_image_dataset("imagenet-r")[-1]

In [None]:
# Artform name per image: the part of the file name before the first "_".
# Presumably each entry of `imgs` is a (path, class_index) pair - TODO confirm.
artforms = [os.path.basename(fname[0]).split("_")[0] for fname in imagenetr_test.imgs]

In [None]:
# Distinct artform names and, per image, the integer index of its artform.
u_artform, artform_ids = np.unique(artforms, return_inverse=True)

In [None]:
# Number of distinct artforms.
len(u_artform)

In [None]:
# Distinct class targets present in the split.
np.unique(imagenetr_test.targets)

In [None]:
# Combined (class, artform) id. The encoding is unique only if every target
# is below 200 - presumably the number of ImageNet-R classes; TODO confirm.
class_artform_ids = imagenetr_test.targets + artform_ids * 200

In [None]:
# Number of distinct (class, artform) combinations that actually occur.
len(np.unique(class_artform_ids))

In [None]:
# One row per image with three label columns: class, artform, combined id
# (same order as TEST_ATTRS = ["Class", "Artform", "Both"] in the next cell).
attrs = np.stack([imagenetr_test.targets, artform_ids, class_artform_ids], axis=-1)

In [None]:
# Configuration for the ImageNet-R breakdown LaTeX table built below.
metric_key = "AMI"  # AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
show_pc = True
show_fmt = "{:4.0f}"
show_commands = False
highlight_best = True
use_si_num = False  # wrap numbers in \num{...} instead of $...$
eps = 0.005  # tolerance when comparing a value against the best one
override_fields = {
    "predictions_dir": "y_pred",
}

backbone = "ViT-B"  # "ResNet-50" or "ViT-B"

# NOTE: this rebinds the notebook-wide CLUSTERERS global.
if metric_key == "num_cluster_pred":
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True
else:
    CLUSTERERS = ["KMeans", "AC w/ C", "AC w/o C", "AffinityPropagation", "HDBSCAN"]
if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"

print(MODEL_GROUPS)

dataset = "imagenet-r"

# Table columns: class labels, artform labels, and their combination.
TEST_ATTRS = ["Class", "Artform", "Both"]
print(TEST_ATTRS)
# Values collected during the first ("dummy") pass of the loop below.
best_results = {k: [] for k in TEST_ATTRS}
best_results_grouped = {k: defaultdict(list) for k in TEST_ATTRS}

# Build the table in two passes over identical code. The first pass
# (dummy=True) only gathers every score into best_results and
# best_results_grouped; the LaTeX it assembles is discarded when latex_table is
# reassigned at the top of the second pass. The second pass (dummy=False)
# renders the table and uses the gathered scores to choose the highlighted
# cells and the cell shading.
for dummy in [True, False]:
    # Reset on every pass, so the list reported afterwards is from the last pass.
    cmds = []
    latex_table = r"% Results for " + f"{dataset} breakdown, {metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = backbone
    if metric_key == "AMI":
        # AMI tables additionally get a short label without the metric prefix.
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"%\resizebox{\textwidth}{!}{%" + "\n"  # Disabled
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_ATTRS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for attr in TEST_ATTRS:
        latex_table += r"&" + f"{attr.replace('_', ' '):^15s}"
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        # Reference row holding the logged num_cluster_true value.
        # NOTE(review): the same dataset-level value is written into every
        # attribute column - confirm this is intended for a per-attribute table.
        latex_table += r"& Num targets"
        for attr in TEST_ATTRS:
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference block of rows labelled "G.T.", one row per encoder.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for attr in TEST_ATTRS:
                latex_table += " &"
                filter1 = {"model": model, "dataset": dataset}
                if model == "timm_vit_base_patch16_224.mae":
                    filter1["dim_reducer"] = "PCA"
                    filter1["pca_variance"] = 0.95
                else:
                    filter1["dim_reducer_man"] = "UMAP"
                    filter1["ndim_reduced_man"] = 50
                    filter1["dim_reducer_man_metric"] = "euclidean"
                sdf = select_rows(test_runs_df, filter1, allow_missing=False)
                sdf = sdf[~pd.isna(sdf[metric_key2])]
                # NOTE(review): rows are filtered on metric_key2 but the median
                # is taken over metric_key - confirm which column the "G.T."
                # row is meant to report.
                my_val = np.nanmedian(sdf[metric_key])
                if dummy:
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                is_best_grp = my_val + eps >= np.nanmax(
                    best_results_grouped[attr][clusterername]
                )
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        # The first "AgglomerativeClustering" entry is rendered as "AC  w/ C"
        # with aggclust_dist_thresh forced to None; any later entry is rendered
        # as "AC w/o C" with that override removed.
        if (
            first_agg
            and clusterer == "AgglomerativeClustering"
            and metric_key != "num_cluster_pred"
        ):
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        # Rotated clusterer name spanning all encoder rows of this group.
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            # Later updates win: defaults < best params < model/dataset < overrides.
            filter1 = {"model": model, "dataset": dataset}
            filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
            filter2.update(filter1)
            filter2.update(my_override_fields)
            filter2 = fixup_filter(filter2)
            sdf = select_rows(test_runs_df, filter2, allow_missing=False)
            if len(sdf) < 1:
                print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                cmds.append(filter2command(filter2, partition="test"))
                if not dummy:
                    # Fill every attribute column with a placeholder and close
                    # the row; otherwise the next encoder's cells would be
                    # appended to this row and break the tabular.
                    latex_table += " &   --  " * len(TEST_ATTRS)
                    latex_table += r" \\" + "\n"
                continue
            if len(sdf) > 1 and (np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6).any():
                # Several matching runs disagree on the metric: report which
                # config columns differ. The first row is used regardless.
                print()
                print(f"More than one result with {metric_key} values", list(sdf[metric_key]))
                print(f"for search {filter2}")
                dif_cols = find_differing_columns(sdf, config_keys)
                print(f"columns which differ: {dif_cols}")
                if dif_cols:
                    for col in dif_cols:
                        print(f"  {col}: {list(sdf[col])}")
            # Scores are recomputed from the stored predictions against each
            # attribute's labels rather than read from the logged metric.
            y_pred = np.load("../" + get_pred_path(sdf.iloc[0]))["y_pred"]
            for i_attr, attr in enumerate(TEST_ATTRS):
                latex_table += " &"
                if metric_key.lower() != "ami":
                    raise NotImplementedError()
                my_val = sklearn.metrics.adjusted_mutual_info_score(attrs[:, i_attr], y_pred)
                if dummy:
                    best_results[attr].append(my_val)
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Rank against the non-NaN scores only, so that one NaN entry
                # does not switch off highlighting for the whole column.
                scores = np.asarray(best_results[attr], dtype=float)
                scores = np.sort(scores[~np.isnan(scores)])
                sc_top = scores[-1]
                is_best = my_val + eps >= sc_top
                is_secd = len(scores) > 1 and my_val + eps >= scores[-2]
                is_best_grp = my_val + eps >= np.nanmax(
                    best_results_grouped[attr][clusterername]
                )
                # Shade from 0 at the column median up to 100 at the column
                # maximum; values below the median are clamped to 0.
                sc_base = np.nanmedian(best_results[attr])
                sc_range = sc_top - sc_base
                if sc_range > 0:
                    sc = 100 * max(0, (my_val - sc_base) / sc_range)
                else:
                    # Median equals maximum: nothing to scale against.
                    sc = 0
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"%}" + "\n"  # Disabled

# Summarise the missing runs, optionally list the commands that would generate
# them, then print the finished LaTeX table.
print("", f"There are {len(cmds)} commands to execute to generate missing datapoints", sep="\n")
if show_commands and cmds:
    print(*cmds, sep="\n")

print("", "Done!", "", sep="\n")
print(f"Here is your {dataset} results table for {metric_key}, {backbone}:", "", "", sep="\n")
print(latex_table)

### FGVC Aircraft

In [None]:
annotation_levels = ["manufacturer", "family", "variant"]

# Load the FGVC Aircraft test split once per annotation level and keep only its
# label list; the lists are stacked along the last axis, so attrs[:, i] holds
# the labels for annotation_levels[i].
# NOTE(review): `_labels` is a private torchvision attribute, presumably one
# entry per test image - confirm against the installed torchvision version.
_label_columns = []
for annotation_level in annotation_levels:
    _aircraft_split = torchvision.datasets.FGVCAircraft(
        os.path.expanduser("~/Datasets"), split="test", annotation_level=annotation_level
    )
    _label_columns.append(_aircraft_split._labels)
attrs = np.stack(_label_columns, axis=-1)

In [None]:
# Print, for each annotation level, how many distinct labels its column of
# attrs contains.
for i_attr, level_name in enumerate(annotation_levels):
    print(level_name, len(np.unique(attrs[:, i_attr])))

In [None]:
# Metric to tabulate. Options noted by the author:
#   AMI  num_cluster_pred  silhouette-euclidean_pred  silhouette-og-euclidean_pred
metric_key = "AMI"
backbone = "ResNet-50"  # "ResNet-50" or "ViT-B"

show_commands = False
eps = 0.005
override_fields = dict(predictions_dir="y_pred")

# Presentation defaults, adjusted below for specific metrics.
CLUSTERERS = ["KMeans", "AC w/ C", "AC w/o C", "AffinityPropagation", "HDBSCAN"]
show_pc, show_fmt = True, "{:4.0f}"
highlight_best, use_si_num = True, False

if metric_key == "num_cluster_pred":
    # Cluster-count tables: fewer clusterers, raw numbers, no highlighting.
    CLUSTERERS = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc, highlight_best, use_si_num = False, False, True
elif metric_key.startswith("silhouette"):
    # Silhouette scores are shown with two decimals instead of as percentages.
    show_pc, show_fmt = False, "{:5.2f}"

print(MODEL_GROUPS)

dataset = "aircraft"
TEST_ATTRS = annotation_levels
print(TEST_ATTRS)

# Score accumulators, keyed by attribute name:
#   best_results[attr]                -> flat list of scores
#   best_results_grouped[attr][group] -> list of scores per group key
best_results = {}
best_results_grouped = {}
for attr in TEST_ATTRS:
    best_results[attr] = []
    best_results_grouped[attr] = defaultdict(list)

# Build the table in two passes over identical code. The first pass
# (dummy=True) only gathers every score into best_results and
# best_results_grouped; the LaTeX it assembles is discarded when latex_table is
# reassigned at the top of the second pass. The second pass (dummy=False)
# renders the table and uses the gathered scores to choose the highlighted
# cells and the cell shading.
for dummy in [True, False]:
    # Reset on every pass, so the list reported afterwards is from the last pass.
    cmds = []
    latex_table = r"% Results for " + f"{dataset} breakdown, {metric_key}, {backbone}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = backbone
    if metric_key == "AMI":
        # AMI tables additionally get a short label without the metric prefix.
        latex_table += r"\label{tab:" + label + r"}" + "\n"
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"%\resizebox{\textwidth}{!}{%" + "\n"  # Disabled
    latex_table += r"\begin{tabular}{ll" + r"r" * len(TEST_ATTRS) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    latex_table += r"& " + f"{'Encoder':<11s}"
    for attr in TEST_ATTRS:
        latex_table += r"&" + f"{attr.replace('_', ' '):^15s}"
    latex_table += r"\\" + "\n"
    latex_table += r"\toprule" + "\n"
    print(MODEL_GROUPS[backbone])
    if metric_key == "num_cluster_pred":
        # Reference row holding the logged num_cluster_true value.
        # NOTE(review): the same dataset-level value is written into every
        # attribute column - confirm this is intended for a per-attribute table.
        latex_table += r"& Num targets"
        for attr in TEST_ATTRS:
            sdf = select_rows(test_runs_df, {"dataset": dataset}, allow_missing=False)
            sdf = sdf[~pd.isna(sdf["num_cluster_true"])]
            latex_table += r"& "
            latex_table += r"\num{" if use_si_num else r"$"
            latex_table += f"{sdf.iloc[0]['num_cluster_true'].item()}"
            latex_table += r"}" if use_si_num else r"$"
        latex_table += r"\\" + "\n"
        latex_table += r"\toprule" + "\n"
    elif metric_key.endswith("_pred"):
        # Reference block of rows labelled "G.T.", one row per encoder.
        metric_key2 = metric_key.replace("_pred", "_true")
        clusterername = "G.T."
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"
        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            for attr in TEST_ATTRS:
                latex_table += " &"
                filter1 = {"model": model, "dataset": dataset}
                if model == "timm_vit_base_patch16_224.mae":
                    filter1["dim_reducer"] = "PCA"
                    filter1["pca_variance"] = 0.95
                else:
                    filter1["dim_reducer_man"] = "UMAP"
                    filter1["ndim_reduced_man"] = 50
                    filter1["dim_reducer_man_metric"] = "euclidean"
                sdf = select_rows(test_runs_df, filter1, allow_missing=False)
                sdf = sdf[~pd.isna(sdf[metric_key2])]
                # NOTE(review): rows are filtered on metric_key2 but the median
                # is taken over metric_key - confirm which column the "G.T."
                # row is meant to report.
                my_val = np.nanmedian(sdf[metric_key])
                if dummy:
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                is_best_grp = my_val + eps >= np.nanmax(
                    best_results_grouped[attr][clusterername]
                )
                latex_table += r"\num{" if use_si_num else r"$"
                latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"

            latex_table += r" \\" + "\n"
        latex_table += r"\toprule" + "\n"

    first_agg = True
    for i_clusterer, clusterer in enumerate(CLUSTERERS):
        clusterername = CLUSTERER2SH.get(clusterer, clusterer)
        my_override_fields = override_fields.copy()
        # The first "AgglomerativeClustering" entry is rendered as "AC  w/ C"
        # with aggclust_dist_thresh forced to None; any later entry is rendered
        # as "AC w/o C" with that override removed.
        # NOTE(review): the configuration for this table sets CLUSTERERS to
        # short names such as "AC w/ C" / "AC w/o C", which never equal
        # "AgglomerativeClustering" - confirm that BEST_PARAMS and fixup_filter
        # handle those names, since this branch will not fire for them.
        if (
            first_agg
            and clusterer == "AgglomerativeClustering"
            and metric_key != "num_cluster_pred"
        ):
            first_agg = False
            my_override_fields["aggclust_dist_thresh"] = None
            clusterername = "AC  w/ C"
        elif clusterer == "AgglomerativeClustering":
            clusterername = "AC w/o C"
            if "aggclust_dist_thresh" in my_override_fields:
                del my_override_fields["aggclust_dist_thresh"]

        if i_clusterer > 0:
            latex_table += r"\midrule" + "\n"

        # Rotated clusterer name spanning all encoder rows of this group.
        latex_table += (
            r"\parbox[t]{2mm}{\multirow{"
            + str(len(MODEL_GROUPS[backbone]))
            + r"}{*}{"
            + r"\scalebox{0.9}{"
            + r"\rotatebox[origin=c]{90}{"
            + clusterername
            + r"}}}"
            + r"}"
        )
        latex_table += "\n"

        for model in MODEL_GROUPS[backbone]:
            latex_table += f"& {MODEL2SH[model]:<10s}"
            # Later updates win: defaults < best params < model/dataset < overrides.
            filter1 = {"model": model, "dataset": dataset}
            filter2 = dict(DEFAULT_PARAMS["all"], **BEST_PARAMS[clusterer][model])
            filter2.update(filter1)
            filter2.update(my_override_fields)
            filter2 = fixup_filter(filter2)
            sdf = select_rows(test_runs_df, filter2, allow_missing=False)
            if len(sdf) < 1:
                print(f"No data for {model}-{dataset}-{clusterer}\n{filter2}")
                cmds.append(filter2command(filter2, partition="test"))
                if not dummy:
                    # Fill every attribute column with a placeholder and close
                    # the row; otherwise the next encoder's cells would be
                    # appended to this row and break the tabular.
                    latex_table += " &   --  " * len(TEST_ATTRS)
                    latex_table += r" \\" + "\n"
                continue
            if len(sdf) > 1 and (np.abs(sdf[metric_key] - sdf.iloc[0][metric_key]) > 1e-6).any():
                # Several matching runs disagree on the metric: report which
                # config columns differ. The first row is used regardless.
                print()
                print(f"More than one result with {metric_key} values", list(sdf[metric_key]))
                print(f"for search {filter2}")
                dif_cols = find_differing_columns(sdf, config_keys)
                print(f"columns which differ: {dif_cols}")
                if dif_cols:
                    for col in dif_cols:
                        print(f"  {col}: {list(sdf[col])}")
            # Scores are recomputed from the stored predictions against each
            # attribute's labels rather than read from the logged metric.
            y_pred = np.load("../" + get_pred_path(sdf.iloc[0]))["y_pred"]
            for i_attr, attr in enumerate(TEST_ATTRS):
                latex_table += " &"
                if metric_key.lower() != "ami":
                    raise NotImplementedError()
                my_val = sklearn.metrics.adjusted_mutual_info_score(attrs[:, i_attr], y_pred)
                if dummy:
                    best_results[attr].append(my_val)
                    best_results_grouped[attr][clusterername].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Rank against the non-NaN scores only, so that one NaN entry
                # does not switch off highlighting for the whole column.
                scores = np.asarray(best_results[attr], dtype=float)
                scores = np.sort(scores[~np.isnan(scores)])
                sc_top = scores[-1]
                is_best = my_val + eps >= sc_top
                is_secd = len(scores) > 1 and my_val + eps >= scores[-2]
                is_best_grp = my_val + eps >= np.nanmax(
                    best_results_grouped[attr][clusterername]
                )
                # Shade from 0 at the column median up to 100 at the column
                # maximum; values below the median are clamped to 0.
                sc_base = np.nanmedian(best_results[attr])
                sc_range = sc_top - sc_base
                if sc_range > 0:
                    sc = 100 * max(0, (my_val - sc_base) / sc_range)
                else:
                    # Median equals maximum: nothing to scale against.
                    sc = 0
                latex_table += r"\cellcolor{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    latex_table += "     "
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    latex_table += "     "
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"%}" + "\n"  # Disabled

# Summarise the missing runs, optionally list the commands that would generate
# them, then print the finished LaTeX table.
print("", f"There are {len(cmds)} commands to execute to generate missing datapoints", sep="\n")
if show_commands and cmds:
    print(*cmds, sep="\n")

print("", "Done!", "", sep="\n")
print(f"Here is your {dataset} results table for {metric_key}, {backbone}:", "", "", sep="\n")
print(latex_table)