# Globals

In [None]:
import copy
import datetime
import os
from collections import defaultdict

import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn
import sklearn.metrics
import torchvision.datasets
import wandb
from IPython.display import display
from tqdm.autonotebook import tqdm

In [None]:
# Name of the output directory; it is created here if it does not exist yet
# (no error is raised when it is already present).
FIGS_DIR = "figs"
os.makedirs(FIGS_DIR, exist_ok=True)

In [None]:
# Dataset identifiers grouped under the "validation" name.
VALIDATION_DATASETS = ["imagenet", "imagenette", "imagewoof"]
# Encoder identifiers, grouped by backbone. The randomly-initialised variants
# are commented out of these two lists but still appear in the
# *_INTERLEAVED lists below.
RESNET50_MODELS = [
    # "random_resnet50",
    "resnet50",
    "mocov3_resnet50",
    "dino_resnet50",
    "vicreg_resnet50",
    "clip_RN50",
]
VITB16_MODELS = [
    # "random_vitb16",
    "vitb16",
    "mocov3_vit_base",
    "dino_vitb16",
    "timm_vit_base_patch16_224.mae",
    "mae_pretrain_vit_base_global",
    "clip_vitb16",
]
# Fine-tuned ("ft_" / "finetuned") counterparts of some of the encoders above.
FT_RESNET50_MODELS = [
    "ft_mocov3_resnet50",
    "ft_dino_resnet50",
    "ft_vicreg_resnet50",
]
FT_VITB16_MODELS = [
    "ft_mocov3_vit_base",
    "ft_dino_vitb16",
    "mae_finetuned_vit_base_global",
]
FT_MODELS = FT_RESNET50_MODELS + FT_VITB16_MODELS
ALL_MODELS = RESNET50_MODELS + VITB16_MODELS + FT_RESNET50_MODELS + FT_VITB16_MODELS

# Same encoders, ordered so each fine-tuned model directly follows its
# pretrained counterpart. The CLIP models are not part of these lists.
RESNET50_MODELS_INTERLEAVED = [
    "random_resnet50",
    "resnet50",
    "mocov3_resnet50",
    "ft_mocov3_resnet50",
    "dino_resnet50",
    "ft_dino_resnet50",
    "vicreg_resnet50",
    "ft_vicreg_resnet50",
]
VITB16_MODELS_INTERLEAVED = [
    "random_vitb16",
    "vitb16",
    "mocov3_vit_base",
    "ft_mocov3_vit_base",
    "dino_vitb16",
    "ft_dino_vitb16",
    "timm_vit_base_patch16_224.mae",
    "mae_pretrain_vit_base_global",
    "mae_finetuned_vit_base_global",
]

# DNA encoder identifiers (used as columns of the cross-modality table).
DNA_MODELS = [
    "barcodebert",
    "dnabert-2",
    "dnabert-s",
    "hyenadna",
    "NucleotideTransformer",
]

# Clustering algorithm identifiers.
CLUSTERERS = [
    "KMeans",
    "LouvainCommunities",
    "AgglomerativeClustering",
    "AffinityPropagation",
    "SpectralClustering",
    "HDBSCAN",
    "OPTICS",
]
# Independent copy, so edits to CLUSTERERS do not change ALL_CLUSTERERS.
ALL_CLUSTERERS = copy.deepcopy(CLUSTERERS)
# Distance metric identifiers.
DISTANCE_METRICS = [
    "euclidean",
    "l1",
    "chebyshev",
    "cosine",
    "arccos",
    "braycurtis",
    "canberra",
]

In [None]:
# Map each pretrained encoder to its fine-tuned counterpart. Most fine-tuned
# names are the pretrained name with an "ft_" prefix; the MAE pair is the
# exception and is added explicitly.
PRE2FT = dict(
    (pretrained, f"ft_{pretrained}")
    for pretrained in (
        "mocov3_resnet50",
        "dino_resnet50",
        "vicreg_resnet50",
        "mocov3_vit_base",
        "dino_vitb16",
    )
)
PRE2FT.update(mae_pretrain_vit_base_global="mae_finetuned_vit_base_global")
# Inverse lookup: fine-tuned encoder -> pretrained encoder.
FT2PRE = dict(zip(PRE2FT.values(), PRE2FT.keys()))

In [None]:
# Plot styling lookups. The values are matplotlib line-style / marker codes.
# Dataset name -> line style.
DATASET2LS = {
    "imagenet": "-.",
    "imagenette": "--",
    "imagewoof": ":",
}
# Architecture name -> marker ("none" is the entry for no architecture).
ARCH2MARKER = {
    "ResNet-50": "s",
    "ViT-B": "^",
    "none": "o",
}
# Architecture name -> line style.
ARCH2LS = {
    "ResNet-50": "-",
    "ViT-B": "-",
    "none": ":",
}

In [None]:
# Default run-configuration filters. The "all" entry is the base; the entry
# keyed by a clusterer name is layered on top of it (see the
# ``dict(DEFAULT_PARAMS["all"], **DEFAULT_PARAMS[clusterer])`` calls below),
# so e.g. "dim_reducer_man" becomes "UMAP" for AgglomerativeClustering.
# Note that "None" here is the string, not the Python ``None`` object.
DEFAULT_PARAMS = {
    "all": {
        "dim_reducer": "None",
        "dim_reducer_man": "None",
        "zscore": False,
        "normalize": False,
        "zscore2": False,
        "ndim_correction": False,
    },
    "AgglomerativeClustering": {
        "clusterer": "AgglomerativeClustering",
        "distance_metric": "euclidean",
        "aggclust_linkage": "ward",
        "dim_reducer_man": "UMAP",
        "ndim_reduced_man": 50,
    },
}

In [None]:
# Per-clusterer, per-model parameter sets. Every (non-fine-tuned) model starts
# from its own deep copy of the clusterer defaults, so entries can be edited
# independently. Only AgglomerativeClustering is populated here.
models = RESNET50_MODELS + VITB16_MODELS
BEST_PARAMS = {
    clusterer: {model: copy.deepcopy(DEFAULT_PARAMS[clusterer]) for model in models} for clusterer in ["AgglomerativeClustering"]
}
# Version tag; it is written into the header comment of the LaTeX table below.
BEST_PARAMS["_version"] = "recommended"

## Utility functions

In [None]:
def categorical_cmap(nc, nsc, cmap="tab10", continuous=False):
    """
    Build a colormap holding several shades of each of a set of base colours.

    Based on https://stackoverflow.com/a/47232942/1960959

    Parameters
    ----------
    nc : int
        Number of categories.
    nsc : int
        Number of shades per category.
    cmap : str, default=tab10
        Original colormap to extend into multiple shades.
    continuous : bool, default=False
        Whether ``cmap`` is continuous, in which case ``nc`` evenly spaced
        colours are sampled from it. Otherwise it is treated as categorical
        (adjacent colours unrelated) and its first ``nc`` colours are used.

    Returns
    -------
    matplotlib.colors.ListedColormap
        New cmap with ``nc * nsc`` entries: ``nsc`` consecutive shades for
        each of the ``nc`` base colours taken from ``cmap``.

    Raises
    ------
    ValueError
        If ``nc`` exceeds the number of colours in ``cmap``.
    """
    base = plt.get_cmap(cmap)
    if nc > base.N:
        raise ValueError("Too many categories for colormap.")
    if continuous:
        samples = np.linspace(0, 1, nc)
    else:
        samples = np.arange(nc, dtype=int)
    cols = np.zeros((nc * nsc, 3))
    for idx, rgba in enumerate(base(samples)):
        hue, sat, val = matplotlib.colors.rgb_to_hsv(rgba[:3])
        # Keep the hue fixed; fade saturation towards 0.25 and raise the
        # value towards 1 to obtain progressively lighter shades.
        shades_hsv = np.column_stack(
            [
                np.full(nsc, hue),
                np.linspace(sat, 0.25, nsc),
                np.linspace(val, 1, nsc),
            ]
        )
        cols[idx * nsc : (idx + 1) * nsc, :] = matplotlib.colors.hsv_to_rgb(shades_hsv)
    return matplotlib.colors.ListedColormap(cols)

In [None]:
# Preview (as the cell output) a colormap with one base colour per ResNet-50
# model and one shade per validation dataset.
categorical_cmap(len(RESNET50_MODELS), len(VALIDATION_DATASETS))

In [None]:
from zs_ssl_clustering.datasets import image_dataset_sizes


def clip_imgsize(dataset, target_image_size):
    """
    Cap a requested image size at the size reported for ``dataset``.

    The dataset size is read from the second element of
    ``image_dataset_sizes(dataset)``. If that is a sized object (anything
    with ``__len__``) its smallest entry is used.

    Returns ``None`` when ``target_image_size`` is ``None``, and
    ``target_image_size`` unchanged when no size is reported for the dataset.
    """
    if target_image_size is None:
        return None
    native_size = image_dataset_sizes(dataset)[1]
    if native_size is None:
        return target_image_size
    limit = min(native_size) if hasattr(native_size, "__len__") else native_size
    return min(target_image_size, limit)

In [None]:
def fixup_filter(filters):
    """
    Adjust dataset-dependent filter values and return the filter.

    The dataset is read from ``"dataset_name"`` (falling back to
    ``"dataset"``). When one is given:

    - ``"image_size"``, if present, is capped with ``clip_imgsize``;
    - ``"min_samples"``, if present, is set to 2 for the celeba, utkface and
      bioscan1m datasets (compared case-insensitively).

    Note that ``filters`` is modified in place; the same dict is returned.
    """
    dataset = filters.get("dataset_name", filters.get("dataset", None))
    if not dataset:
        # Nothing dataset-specific to adjust.
        return filters
    if "image_size" in filters:
        filters["image_size"] = clip_imgsize(dataset, filters["image_size"])
    if "min_samples" in filters and dataset.lower() in ["celeba", "utkface", "bioscan1m"]:
        filters["min_samples"] = 2
    return filters

In [None]:
def select_rows(df, filters, allow_missing=True, fixup=True):
    """
    Return the rows of ``df`` that satisfy every entry of ``filters``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to filter. Every filtered column must exist.
    filters : dict
        Maps column name to required value. The keys ``"dataset"`` and
        ``"clusterer"`` are aliases for the columns ``"dataset_name"`` and
        ``"clusterer_name"``.
    allow_missing : bool, default=True
        Whether rows with a missing (NA) entry in a filtered column also
        count as matching a non-null required value.
    fixup : bool, default=True
        Whether to pass ``filters`` through ``fixup_filter`` first (which
        modifies the dict in place).

    Notes
    -----
    A required value of ``None``, ``"None"`` or ``"none"`` matches cells that
    are NA or hold either of those two strings. Any other value matches cells
    equal to the value itself or to ``str(value)``.
    """
    if fixup:
        filters = fixup_filter(filters)
    column_aliases = {"dataset": "dataset_name", "clusterer": "clusterer_name"}
    keep = np.ones(len(df), dtype=bool)
    for key, wanted in filters.items():
        column = df[column_aliases.get(key, key)]
        if wanted is None or wanted == "None" or wanted == "none":
            matches = pd.isna(column) | (column == "None") | (column == "none")
        else:
            matches = (column == wanted) | (column == str(wanted))
            if allow_missing:
                matches = matches | pd.isna(column)
        keep &= matches
    return df[keep]

In [None]:
def find_differing_columns(df, cols=None):
    """
    List the columns of ``df`` that hold more than one distinct value.

    Missing values count as a distinct value. ``cols`` restricts (and orders)
    the columns that are checked; names not present in ``df`` are ignored.
    By default all columns of ``df`` are checked.
    """
    candidates = df.columns if cols is None else cols
    return [
        name
        for name in candidates
        if name in df.columns and df[name].nunique(dropna=False) > 1
    ]

In [None]:
# Dataset groups that share one memory tier in several of the tables below.
_F2C_MIDSIZE = ("places365", "imagenet-r", "svhn", "bioscan1m", "nabirds")  # 24,600 - 36,500 samples
_F2C_TEN_K = (
    "imagenetv2",
    "cifar10",
    "cifar100",
    "lsun",
    "mnist",
    "fashionmnist",
    "stanfordcars",
)  # 8,000 - 10,000 samples
_F2C_SMALL = (
    "flowers102",
    "utkface",
    "eurosat",
    "aircraft",
    "breakhis",
    "imagenet-o",
    "dtd",
)  # 1,900 - 6,200 samples


def _f2c_common_tiers(huge, large, midsize, celeba, ten_k, small, tiny):
    """Build the ordered tier list used by every clusterer except Louvain."""
    return (
        ((), ("inaturalist",), huge),  # 100,000 samples
        ((), ("imagenet-sketch", "imagenet"), large),  # 50,000 samples
        ((), _F2C_MIDSIZE, midsize),
        ((), ("celeba",), celeba),  # 20,000 samples
        ((), _F2C_TEN_K, ten_k),
        (("in9",), _F2C_SMALL, small),
        ((), ("imagenette", "imagewoof"), tiny),  # 3,930 samples
    )


# Rows of (clusterer names, tiers, fallback). Each tier is
# (dataset name prefixes, dataset names, RAM in gigabytes); tiers are tried
# in order and the fallback applies to datasets matching no tier.
_F2C_MEM_TABLE = (
    (
        ("LouvainCommunities",),
        (
            ((), ("inaturalist",), 3_700),  # 100,000 samples
            ((), ("imagenet-sketch", "imagenet"), 926),  # 50,000 samples
            ((), ("places365",), 494),  # 36,500 samples
            ((), ("imagenet-r",), 333),  # 30,000 samples
            ((), ("svhn",), 250),  # 26,000 samples
            ((), ("bioscan1m", "nabirds"), 224),  # 24,600 samples
            ((), ("celeba",), 128),  # 20,000 samples
            ((), _F2C_TEN_K + ("breakhis",), 32),  # 8,000 - 10,000 samples
            ((), ("flowers102", "utkface"), 18),  # 5,925 - 6,200 samples
            (("in9",), ("eurosat",), 8),  # 4,500 samples
            ((), ("imagenette", "imagewoof", "aircraft"), 6),  # 3,333 - 3,930 samples
            ((), ("imagenet-o", "dtd"), 4),  # 2,000 samples
        ),
        12,
    ),
    (("AffinityPropagation",), _f2c_common_tiers(292, 72, 48, 12, 6, 2, 1), 8),
    (
        ("AgglomerativeClustering", "SpectralClustering"),
        _f2c_common_tiers(72, 20, 16, 12, 6, 4, 2),
        8,
    ),
    (("HDBSCAN", "KMeans"), _f2c_common_tiers(6, 4, 4, 4, 2, 2, 1), 4),
)

# Boolean-like options: key -> (text appended when off, text appended when on).
_F2C_SWITCHES = {
    "zscore": (" --no-zscore", " --zscore"),
    "normalize": ("", " --normalize"),
    "ndim_correction": (" --no-ndim-correction", " --ndim-correction"),
    "louvain_remove_self_loops": (" --louvain-keep-self", ""),
}


def _f2c_lookup_mem(tiers, fallback, dataset):
    """Return the RAM (GB) of the first tier matching ``dataset``, else ``fallback``."""
    for prefixes, names, gigabytes in tiers:
        if any(dataset.startswith(prefix) for prefix in prefixes) or dataset in names:
            return gigabytes
    return fallback


def filter2command(*filters, partition="val"):
    """
    Build the ``sbatch`` command line that would run one configuration.

    Parameters
    ----------
    *filters : dict
        Configuration dicts, merged left to right (later values win).
    partition : str, default="val"
        Value for the ``--partition`` option placed right after the script
        name. It also selects the job array index: 100 for ``"val"``, 1 for
        ``"test"`` and 0 otherwise.

    Returns
    -------
    str
        The command, or an empty string when the memory request for the
        clusterer/dataset pair exceeds 300 GB.

    Notes
    -----
    The memory request is looked up from the ``"clusterer"`` and ``"dataset"``
    entries (2 GB for clusterers without a table). For SpectralClustering it
    is scaled down when ``"spectral_n_neighbors"`` is at most 10 or 20 and
    then rounded up. Entries whose value is ``None`` are omitted. A few
    boolean-like keys become bare switches; every other entry is emitted as
    ``--key-with-dashes=value``.
    """
    merged = {}
    for part in filters:
        merged.update(part.items())
    dataset = merged.get("dataset", "")
    clusterer = merged.get("clusterer", "")

    mem = 2  # RAM in gigabytes
    for clusterer_names, tiers, fallback in _F2C_MEM_TABLE:
        if clusterer in clusterer_names:
            mem = _f2c_lookup_mem(tiers, fallback, dataset)
            break

    if clusterer == "SpectralClustering":
        n_neighbors = merged.get("spectral_n_neighbors", 100)
        if n_neighbors <= 10:
            mem = mem * 8 / 20
        elif n_neighbors <= 20:
            mem = mem * 3 / 4
        mem = int(np.ceil(mem))

    if mem > 300:
        # Too large to schedule: signal "no command".
        return ""

    seed = 100 if partition == "val" else 1 if partition == "test" else 0
    model = merged.get("model", "")
    pieces = [
        f"sbatch --array={seed} --mem={mem}G",
        f' --job-name="zsc-{model}-{dataset}-{clusterer}"',
        f" slurm/cluster.slrm --partition={partition}",
    ]
    for key, value in merged.items():
        if value is None:
            continue
        is_off = value == "False" or not value
        if key == "zscore2":
            # Three-way option: off, "average", or anything else truthy.
            if is_off:
                pieces.append(" --no-zscore2")
            elif value == "average":
                pieces.append(" --azscore2")
            else:
                pieces.append(" --zscore2")
        elif key in _F2C_SWITCHES:
            when_off, when_on = _F2C_SWITCHES[key]
            pieces.append(when_off if is_off else when_on)
        else:
            pieces.append(f" --{key.replace('_', '-')}={value}")
    return "".join(pieces)

# Final results

In [None]:
# Exclude CLIP from analysis
# RESNET50_MODELS = [v for v in RESNET50_MODELS if not v.startswith("clip")]
# VITB16_MODELS = [v for v in VITB16_MODELS if not v.startswith("clip")]

In [None]:
# Datasets reported in the final results.
TEST_DATASETS = [
    "bioscan5m",
]
# Dataset name -> shorthand label.
# The only test dataset is "bioscan5m", but the "BS-5M" label used to be
# registered solely under "bioscan1m", so lookups by the test dataset name
# found no shorthand. The original key is kept for backward compatibility.
# NOTE(review): confirm whether the "bioscan1m" entry is still needed.
DATASET2SH = {
    "bioscan1m": "BS-5M",
    "bioscan5m": "BS-5M",
}
# Named groups of encoder identifiers.
MODEL_GROUPS = {
    "ResNet-50": RESNET50_MODELS,
    "ViT-B": VITB16_MODELS,
    "ResNet-50 [FT]": FT_RESNET50_MODELS,
    "ViT-B [FT]": FT_VITB16_MODELS,
    "all": ALL_MODELS,
}
# Image encoder identifier -> shorthand label (LaTeX-ready). The ``None`` key
# is the label for rows without an image encoder.
MODEL2SH = {
    None: r"\textit{DNA-only}",
    # "none": "Raw image",
    "random_resnet50": "Rand.",  # "Random",
    "random_vitb16": "Rand.",  # "Random",
    "resnet50": "X-Ent.",
    "mocov3_resnet50": "MoCo-v3",
    "dino_resnet50": "DINO",
    "vicreg_resnet50": "VICReg",
    "clip_RN50": "CLIP",
    "vitb16": "X-Ent.",
    "mocov3_vit_base": "MoCo-v3",
    "dino_vitb16": "DINO",
    "timm_vit_base_patch16_224.mae": "MAE (CLS)",
    "mae_pretrain_vit_base_global": "MAE (avg)",
    "clip_vitb16": "CLIP",
    "ft_mocov3_resnet50": "MoCo-v3 [FT]",
    "ft_dino_resnet50": "DINO [FT]",
    "ft_vicreg_resnet50": "VICReg [FT]",
    "ft_mocov3_vit_base": "MoCo-v3 [FT]",
    "ft_dino_vitb16": "DINO [FT]",
    "mae_finetuned_vit_base_global": "MAE (avg) [FT]",
}
# DNA encoder identifier -> shorthand label. The ``None`` key is the label
# for columns without a DNA encoder.
DNAMODEL2SH = {
    None: r"\textit{Image-only}",
    "barcodebert": "BarcodeBERT",
    "dnabert-2": "DNABERT-2",
    "dnabert-s": "DNABERT-S",
    "hyenadna": "HyenaDNA",
    "NucleotideTransformer": "NT",
}
# Clusterer identifier -> shorthand label.
CLUSTERER2SH = {
    "KMeans": "K-Means",
    "SpectralClustering": "Spectral",
    "AffinityPropagation": "Affinity Prop",
    "AgglomerativeClustering": "Agg.",  # "AC",
    "AC w/ C": "AC w/  C",
}

In [None]:
# Encoder identifier -> architecture name, for pretrained and fine-tuned models.
MODEL2ARCH = {}
for k in RESNET50_MODELS + FT_RESNET50_MODELS:
    MODEL2ARCH[k] = "ResNet-50"
for k in VITB16_MODELS + FT_VITB16_MODELS:
    MODEL2ARCH[k] = "ViT-B"

# Variant of MODEL2SH whose labels are prefixed with the architecture, chosen
# from substrings of the identifier. Identifiers matching neither pattern
# (and the ``None`` key) keep their plain MODEL2SH label.
MODEL2SH_ARCH = dict(MODEL2SH)
for k, v in MODEL2SH.items():
    if k is None:
        MODEL2SH_ARCH[k] = v
        continue
    if "resnet" in k or "RN50" in k:
        MODEL2SH_ARCH[k] = f"ResNet-50 {v}"
    elif "vit" in k:
        MODEL2SH_ARCH[k] = f"ViT-B {v}"

In [None]:
# Clusterer label -> matplotlib colour name.
CLUSTERER2COLORSTR = {
    "KMeans": "tab:purple",
    "SpectralClustering": "tab:cyan",
    "AC w/ C": "tab:red",
    "AC w/o C": "tab:orange",
    "AffinityPropagation": "tab:green",
    "HDBSCAN": "tab:blue",
}
# Same mapping with each colour name converted to an RGB tuple.
CLUSTERER2COLORRGB = {k: matplotlib.colors.to_rgb(v) for k, v in CLUSTERER2COLORSTR.items()}

In [None]:
# Encoder identifier -> matplotlib colour name. The same training method
# shares a colour across the ResNet-50 and ViT-B backbones. Of the fine-tuned
# models only "mae_finetuned_vit_base_global" has an entry here.
MODEL2COLORSTR = {
    "none": "black",
    "random_resnet50": "dimgrey",
    "random_vitb16": "dimgrey",
    "resnet50": "tab:red",
    "mocov3_resnet50": "tab:green",
    "dino_resnet50": "tab:purple",
    "vicreg_resnet50": "tab:orange",
    "clip_RN50": "tab:olive",
    "vitb16": "tab:red",
    "mocov3_vit_base": "tab:green",
    "dino_vitb16": "tab:purple",
    "timm_vit_base_patch16_224.mae": "tab:blue",
    "mae_pretrain_vit_base_global": "tab:brown",
    "clip_vitb16": "tab:olive",
    "mae_finetuned_vit_base_global": "tab:brown",
    "barcodebert": "tab:brown",
    "dnabert-2": "tab:orange",
    "dnabert-s": "tab:red",
    "hyenadna": "tab:cyan",
    "NucleotideTransformer": "tab:green",
}
# Same mapping with each colour name converted to an RGB tuple.
MODEL2COLORRGB = {k: matplotlib.colors.to_rgb(v) for k, v in MODEL2COLORSTR.items()}

In [None]:
# Give each fine-tuned model a darker version (channels scaled by 0.8) of the
# colour currently stored for its pretrained counterpart.
for model in FT_MODELS:
    MODEL2COLORRGB[model] = tuple(c * 0.8 for c in MODEL2COLORRGB[FT2PRE[model]])
# for model in RESNET50_MODELS + VITB16_MODELS:
#     MODEL2COLORRGB[model] = tuple(1 - (1 - c) * 0.7 for c in MODEL2COLORRGB[model])

In [None]:
# Separate the two backbones visually: blend ResNet-50 colours towards white
# and scale ViT-B colours towards black.
# NOTE: this rewrites the stored colours in place, so running the cell again
# shifts them further each time.
for model in RESNET50_MODELS:
    MODEL2COLORRGB[model] = tuple(1 - (1 - c) * 0.7 for c in MODEL2COLORRGB[model])
for model in VITB16_MODELS:
    MODEL2COLORRGB[model] = tuple(c * 0.8 for c in MODEL2COLORRGB[model])

## Fetch results

In [None]:
# Start with an empty runs table (only an "id" column) and empty sets for the
# config / summary keys seen while fetching runs.
runs_df_long = pd.DataFrame({"id": []})
config_keys = set()
summary_keys = set()

In [None]:
# Load previous results from CSV file
CSV_FNAME = "test_runs_df.csv"
if os.path.isfile(CSV_FNAME):
    pass
    # runs_df_long = test_runs_df = pd.read_csv(CSV_FNAME)

In [None]:
# Project is specified by <entity/project-name>
api = wandb.Api(timeout=720)
runs = api.runs(
    "uoguelph_mlrg/zs-ssl-clustering_BIOSCAN-5M_fixpth_fixDNA",
    filters={
        "state": "Finished",
    },
    per_page=10_000,
)
len(runs)

In [None]:
# Append every run that is not yet in runs_df_long, one row per run holding
# its id, name, config values and summary values.
print(f"{len(runs_df_long)} runs currently in dataframe")
rows_to_add = []
existing_ids = set(runs_df_long["id"].values)
for run in tqdm(runs):
    if run.id in existing_ids:
        # Already known. Stop early once as many new rows were collected as
        # the remote run count exceeds the local row count.
        if len(rows_to_add) >= len(runs) - len(runs_df_long):
            break
        continue
    # .summary contains the output keys/values for metrics like accuracy.
    #  We call ._json_dict to omit large files
    summary = run.summary._json_dict
    # .config contains the hyperparameters.
    #  We remove special values that start with _.
    config = {k: v for k, v in run.config.items() if not k.startswith("_")}
    # .name is the human-readable name of the run.
    row = {"id": run.id, "name": run.name}
    row.update({k: v for k, v in config.items() if not k.startswith("_")})
    # Summary values are applied last, so they win over config values that
    # share a key.
    row.update({k: v for k, v in summary.items()})
    if "_timestamp" in summary:
        row["_timestamp"] = summary["_timestamp"]
    rows_to_add.append(row)
    # Track every config / summary key seen across the fetched runs.
    config_keys = config_keys.union(config.keys())
    summary_keys = summary_keys.union(summary.keys())

if not len(rows_to_add):
    print("No new runs to add")
else:
    print(f"Adding {len(rows_to_add)} runs")
    runs_df_long = pd.concat([runs_df_long, pd.DataFrame.from_records(rows_to_add)])
print(f"{len(runs_df_long)} runs")

In [None]:
# Remove entries without an AMI metric
test_runs_df = runs_df_long[~runs_df_long["AMI"].isna()]
len(test_runs_df)

In [None]:
# Drop the worker-count and memory keys from the set of config keys.
config_keys = config_keys.difference({"workers", "memory_avail_GB", "memory_total_GB", "memory_slurm"})

In [None]:
# Display the runs table as the cell output.
test_runs_df

In [None]:
# Change list-type columns to a single string
test_runs_df["modality"] = test_runs_df["modality"].apply(lambda x: x if isinstance(x, str) else "+".join(x))
test_runs_df["partition"] = test_runs_df["partition"].apply(lambda x: x if isinstance(x, str) else "+".join(x))

In [None]:
# Show the distinct values of the "prenorm" column as the cell output.
test_runs_df["prenorm"].unique()

In [None]:
# Replace missing "prenorm" entries with the string "none".
test_runs_df.loc[test_runs_df["prenorm"].isna(), "prenorm"] = "none"

In [None]:
# Display the cleaned runs table as the cell output.
test_runs_df

In [None]:
# Save results to CSV file, so we can optionally skip downloading them
test_runs_df.to_csv(CSV_FNAME, index=False)

In [None]:
# Mean of the "_runtime" column over the "test+test_unseen" runs, divided by
# 60 twice (i.e. converted to hours if "_runtime" is in seconds — TODO confirm).
np.mean(test_runs_df.loc[test_runs_df["partition"] == "test+test_unseen", "_runtime"]) / 60 / 60

## Result loading utility functions

In [None]:
# Worked example: look up the median metric for one model/dataset/clusterer
# combination, using the same filter construction as build_results_table.
model = "mocov3_resnet50"
dataset = "bioscan5m"
clusterer = "AgglomerativeClustering"
metric_key = "AMI"

# Extra filter entries that take precedence over everything else (none here).
my_override_fields = {}

filter1 = {"model": model, "dataset": dataset}
# Base defaults, then clusterer defaults, then the selection, then overrides.
filter2 = dict(DEFAULT_PARAMS["all"], **DEFAULT_PARAMS[clusterer])
filter2.update(filter1)
filter2.update(my_override_fields)
filter2 = fixup_filter(filter2)
# allow_missing=False: a run must hold the required value in every filtered
# column; a missing entry does not count as a match.
sdf = select_rows(test_runs_df, filter2, allow_missing=False)
# Median over the matching runs, ignoring NaN entries.
my_val = np.nanmedian(sdf[metric_key])

print(f"{metric_key} = {my_val * 100:.0f}%")

In [None]:
def build_results_table(
    models=(None,),
    dna_models=(None,),
    clusterers=("AgglomerativeClustering",),
    datasets=("bioscan5m",),
    metric_keys="AMI",
    partition="test+test_unseen",
    override_fields=None,
    return_cmds=False,
    verbosity=0,
):
    """
    Collect median metric values from ``test_runs_df`` into an array.

    Parameters
    ----------
    models, dna_models : sequence
        Image / DNA encoder identifiers. ``None`` means "no encoder of this
        kind". The combination where both are ``None`` is skipped and stays
        NaN in the output.
    clusterers : sequence of str
        Clusterer names; each must be a key of ``DEFAULT_PARAMS``.
    datasets : sequence of str
        Dataset names.
    metric_keys : str or sequence of str
        Column(s) of ``test_runs_df`` to summarise.
    partition : str
        Required value of the ``partition`` column.
    override_fields : dict, optional
        Extra filter entries, applied after the defaults and the selection.
    return_cmds : bool
        Whether to also return the commands for missing entries.
    verbosity : int
        With ``verbosity >= 1`` a message is printed for each missing entry.

    Returns
    -------
    result_table : numpy.ndarray
        Shaped ``[models, dna_models, clusterers, datasets, metric_keys]``;
        the last axis is dropped when ``metric_keys`` is a single string.
        Entries without matching runs are NaN.
    cmds : list of str
        Only if ``return_cmds``: one ``filter2command`` output per
        configuration that has no matching run or a NaN metric.

    Notes
    -----
    A message is printed whenever more than one run matches a configuration;
    the median over those runs is then used.
    """
    extra_fields = {} if override_fields is None else override_fields

    single_metric = isinstance(metric_keys, str)
    if single_metric:
        metric_keys = [metric_keys]

    table_shape = (len(models), len(dna_models), len(clusterers), len(datasets), len(metric_keys))
    result_table = np.full(table_shape, np.nan)
    cmds = []

    for i_model, model in enumerate(models):
        for i_dna, dna_model in enumerate(dna_models):
            if model is None and dna_model is None:
                # Neither modality selected: nothing to look up.
                continue
            if model is None:
                modality = "dna"
            elif dna_model is None:
                modality = "image"
            else:
                modality = "image+dna"
            for i_clusterer, clusterer in enumerate(clusterers):
                for i_dataset, dataset in enumerate(datasets):
                    filter1 = {"dataset": dataset, "partition": partition, "modality": modality}
                    if model is not None:
                        filter1["model"] = model
                    if dna_model is not None:
                        filter1["model_dna"] = dna_model
                    # Base defaults, clusterer defaults, selection, overrides.
                    filter2 = {**DEFAULT_PARAMS["all"], **DEFAULT_PARAMS[clusterer]}
                    filter2.update(filter1)
                    filter2.update(extra_fields)
                    filter2 = fixup_filter(filter2)
                    sdf = select_rows(test_runs_df, filter2, allow_missing=False)

                    n_matches = len(sdf)
                    needs_run = n_matches < 1
                    if n_matches > 1:
                        print(f"{n_matches} entries for {filter1}")
                    if n_matches > 0:
                        for i_key, key in enumerate(metric_keys):
                            val = np.nanmedian(sdf[key])
                            result_table[i_model, i_dna, i_clusterer, i_dataset, i_key] = val
                            if np.isnan(val):
                                needs_run = True
                    if needs_run:
                        if verbosity >= 1:
                            print(f"No data for {model}-{dna_model}-{dataset}-{clusterer}\n{filter2}")
                        cmds.append(filter2command(filter2, partition="test"))

    if single_metric:
        result_table = np.squeeze(result_table, axis=-1)

    return (result_table, cmds) if return_cmds else result_table

In [None]:
def dict_generator(indict, pre=None):
    """
    Yield one flat path (a list) per leaf of a nested dict.

    Each path holds the keys leading to a leaf followed by the leaf itself.
    List and tuple values are expanded element by element under the same
    key; an empty list or tuple therefore yields nothing. A non-dict
    ``indict`` yields the single path ``pre + [indict]``.
    """
    prefix = list(pre) if pre else []
    if not isinstance(indict, dict):
        yield prefix + [indict]
        return
    for key, value in indict.items():
        if isinstance(value, dict):
            yield from dict_generator(value, prefix + [key])
        elif isinstance(value, (list, tuple)):
            for element in value:
                yield from dict_generator(element, prefix + [key])
        else:
            yield prefix + [key, value]

In [None]:
def make_flat_hierarchy_from_dict(indict, pad_right=True):
    """
    Flatten a nested dict into equal-length paths.

    Every path produced by ``dict_generator(indict)`` is padded with empty
    strings to the length of the longest path.

    Parameters
    ----------
    indict : dict
        Nested hierarchy to flatten.
    pad_right : bool, default=True
        Whether padding is appended after the path. Otherwise it is inserted
        before it, so that the leaf is always the last element.

    Returns
    -------
    list of list
        One padded path per leaf; an empty list when ``indict`` has no leaves.
    """
    groups_flattened = list(dict_generator(indict))
    # ``default=0`` keeps an empty hierarchy from raising ValueError in max().
    depth = max((len(m) for m in groups_flattened), default=0)
    if pad_right:
        groups_flattened = [m + [""] * (depth - len(m)) for m in groups_flattened]
    else:
        groups_flattened = [[""] * (depth - len(m)) + m for m in groups_flattened]

    return groups_flattened

## Cross-modality results table

In [None]:
# --- Configuration of the cross-modality results table -------------------
clusterers = ["AgglomerativeClustering"]
# test_datasets = ["bioscan5m"]
test_datasets = ["bioscan5m"]
metric_key = "AMI"
# override_fields = {"prenorm": "none"}
override_fields = {"prenorm": "elementwise_zscore", "n_clusters": None}
merge_model_group_column = False
partition = "test+test_unseen"
# partition = "test_unseen"

# Presentation options used by the table-rendering loop below.
use_rank = False
show_pc = True  # multiply values by 100 before formatting
show_fmt = "{:4.0f}"  # format applied to each cell value
highlight_best = True  # wrap best / second-best / best-in-group cells
use_si_num = False  # wrap values in \num{...} instead of $...$
eps = 0.005  # tolerance when comparing a value with the best ones
fixed_sc_base = None  # if set, replaces the column median as colour-scale base
show_ft = False  # add a separate column flagging fine-tuned encoders
show_color = False  # emit \cc{cbg!...} cell colours

# NOTE(review): the clusterer names set here are not keys of DEFAULT_PARAMS
# as defined in the visible part of this file, and build_results_table
# indexes DEFAULT_PARAMS[clusterer] — confirm before using this metric.
if metric_key == "num_cluster_pred":
    clusterers = ["AC w/o C", "AffinityPropagation", "HDBSCAN"]
    show_pc = False
    show_fmt = "{:4.0f}"
    highlight_best = False
    use_si_num = True

if metric_key.startswith("silhouette"):
    show_pc = False
    show_fmt = "{:5.2f}"


# Row groups (image encoders); the "" group holds the row without an image
# encoder.
model_groups = {
    "": [None],
    "RN50": RESNET50_MODELS,  # + FT_RESNET50_MODELS,
    "ViT-B": VITB16_MODELS,  # + FT_VITB16_MODELS,
}

# Column groups (DNA encoders); the "" group holds the column without a DNA
# encoder.
dna_model_groups = {
    "": [None],
    "DNA Encoder": DNA_MODELS,
}

# Short description of the clusterer selection, used in the table header
# comment and label.
if len(clusterers) == 1:
    clustererstr = clusterers[0]
    if metric_key.endswith("_true"):
        clustererstr = "GT"
else:
    clustererstr = f"{len(clusterers)}c-avg"

# Flatten the row groups and keep only the last element of each path, i.e.
# the encoder identifier (left padding keeps it in the last position).
model_groups_flattened = make_flat_hierarchy_from_dict(model_groups, pad_right=False)
model_groups_flattened = np.array(model_groups_flattened)
model_groups_flattened = model_groups_flattened[:, -1]

# Both names are bound to the same flat list of DNA encoders.
dna_models = dna_model_groups_flattened = [xi for v in dna_model_groups.values() for xi in v]

print("Image Encoders:")
print(model_groups_flattened)
print()
print("DNA Encoders:")
print(dna_model_groups_flattened)
print()
print("Datasets:")
print(test_datasets)
print()
print("Clusterers:")
print(clusterers)
print()

result_table, cmds = build_results_table(
    model_groups_flattened,
    dna_model_groups_flattened,
    clusterers,
    test_datasets,
    metric_keys=metric_key,
    partition=partition,
    override_fields=override_fields,
    return_cmds=True,
)
# Shaped [models, dna_models, clusterers, datasets]
print("result_table.shape", result_table.shape)

# Take mean over clusterers and datasets
# (NaN entries are ignored; the datasets axis is reduced first).
result_table = np.nanmean(result_table, axis=-1)
result_table = np.nanmean(result_table, axis=-1)
# Shaped [models, dna_models]

print("result_table.shape", result_table.shape)

print(model_groups)


# Per-column values, filled during the first pass of the rendering loop
# below and used in its second pass to decide which cells to highlight.
best_results = {k: [] for k in dna_model_groups_flattened}
best_results_grouped = {k: defaultdict(list) for k in dna_model_groups_flattened}

# The table is rendered in two passes over the same code. The first
# ("dummy") pass only records every cell value in ``best_results`` /
# ``best_results_grouped``; the second pass uses those values to decide the
# highlighting and builds the LaTeX that is kept in ``latex_table``.
for dummy in [True, False]:
    latex_table = r"% Results for " + f"{metric_key}, {clustererstr}, {partition}, {test_datasets[0]}" + "\n"
    now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    latex_table += r"% Generated " + now_str + "\n"
    latex_table += r"% Using hparams " + BEST_PARAMS["_version"] + "\n"
    label = clustererstr
    label = metric_key.replace("_", "-") + ":" + label
    latex_table += r"\label{tab:" + label + r"}" + "\n"
    latex_table += r"%\resizebox{\columnwidth}{!}{%" + "\n"
    latex_table += r"\begin{tabular}{"
    if not merge_model_group_column:
        latex_table += "l"
    if show_ft:
        latex_table += "l"
    latex_table += "l" + r"r" * len(dna_models) + r"}" + "\n"
    latex_table += r"\toprule" + "\n"
    # Begin DNA encoder group header row
    if len(dna_model_groups) > 1:
        if not merge_model_group_column:
            latex_table += r"& "
        latex_table += f"{'':<11s}"
        if show_ft:
            latex_table += r" &   "
        for datagroupname, datagroupset in dna_model_groups.items():
            latex_table += r" & \multicolumn{" + str(len(datagroupset)) + r"}{c}{" + datagroupname + r"}"
        latex_table += r"\\" + "\n"
        # Column index at which the first named group starts.
        icol = 3
        if not merge_model_group_column:
            icol += 1
        if show_ft:
            icol += 1
        for datagroupname, datagroupset in dna_model_groups.items():
            if datagroupname == "":
                continue
            latex_table += r"\cmidrule(l){" + f"{icol}-{icol + len(datagroupset) - 1}" + r"}"
            icol += len(datagroupset)
        latex_table += "\n"
    # Begin main header row, with the DNA encoder names
    if merge_model_group_column:
        latex_table += r"\quad "
    else:
        latex_table += r"Arch. & "
    latex_table += f"{'Image encoder':<11s}"
    if show_ft:
        latex_table += r" & FT "
    for dna_model in dna_model_groups_flattened:
        latex_table += r"& \rotatebox{90}{"
        latex_table += "{:^15s}".format(DNAMODEL2SH.get(dna_model, dna_model))  # map to a shorthand
        latex_table += r"}"
    latex_table += r"\\" + "\n"
    # Begin table contents
    latex_table += r"\midrule" + "\n"

    # Row index into result_table, running across all groups.
    i_model_o = -1
    for i_group, group in enumerate(model_groups):
        if i_group > 0:
            latex_table += r"\midrule" + "\n"
        if merge_model_group_column:
            if not group:
                latex_table += r"\quad "
            else:
                latex_table += r"\textbf{" + group + r"} --- "
        elif not group:
            latex_table += "---" + "\n"
        else:
            latex_table += group + "\n"
        for i_model, model in enumerate(list(model_groups[group])):
            i_model_o += 1
            model_sh = MODEL2SH.get(model, model)
            if not show_ft:
                pass
            elif model_sh.endswith(" [FT]"):
                model_sh = f"{model_sh[:-4]:<10s}" + r" & \checkmark"
            else:
                model_sh = f"{model_sh:<10s}" + " &"
            if merge_model_group_column and i_model > 0:
                latex_table += r"\quad "
            if not merge_model_group_column:
                latex_table += "& "
            latex_table += f"{model_sh:<23s}"
            for i_dna, dna_model in enumerate(dna_model_groups_flattened):
                latex_table += " &"
                my_val = result_table[i_model_o, i_dna]
                if dummy:
                    best_results[dna_model].append(my_val)
                    best_results_grouped[dna_model][group].append(my_val)
                    continue
                if np.isnan(my_val):
                    latex_table += r"   --  "
                    continue
                # Rank against finite values only. The table holds NaN for
                # every combination without data (always the case for the
                # cell with neither an image nor a DNA encoder); a NaN makes
                # np.max return NaN and sorts last in np.sort, so the best
                # value would never be marked and would be labelled as the
                # second best instead.
                col_vals = np.asarray(best_results[dna_model], dtype=float)
                col_vals = col_vals[~np.isnan(col_vals)]
                grp_vals = np.asarray(best_results_grouped[dna_model][group], dtype=float)
                grp_vals = grp_vals[~np.isnan(grp_vals)]
                # my_val itself is finite and was recorded in the first pass,
                # so both arrays hold at least one element here.
                is_best = my_val + eps >= np.max(col_vals)
                if len(col_vals) > 1:
                    is_secd = my_val + eps >= np.sort(col_vals)[-2]
                else:
                    is_secd = False
                is_best_grp = my_val + eps >= np.max(grp_vals)
                is_best_grp &= len(grp_vals) > 1
                # Colour scale: 0 at the column median (or the fixed base),
                # 100 at the column maximum.
                sc_base = np.nanmedian(best_results[dna_model])
                if fixed_sc_base is not None:
                    sc_base = fixed_sc_base
                sc_top = np.max(col_vals)
                if sc_top > sc_base:
                    sc = 100 * max(0, (my_val - sc_base) / (sc_top - sc_base))
                else:
                    # Degenerate scale (no spread): avoid dividing by zero.
                    sc = 0
                if show_color and sc_top >= sc_base:
                    latex_table += r"\cc{cbg!" + f"{sc:.0f}" + "}"
                if show_pc:
                    my_val = my_val * 100
                latex_table += r"\num{" if use_si_num else r"$"
                if not highlight_best:
                    pass
                elif is_best:
                    latex_table += r"\tcf{"
                elif is_secd:
                    latex_table += r"\tcs{"
                else:
                    # latex_table += "     "
                    pass
                if not highlight_best:
                    pass
                elif is_best_grp:
                    latex_table += r"\tcg{"
                else:
                    # latex_table += "     "
                    pass
                latex_table += show_fmt.format(my_val)
                if highlight_best:
                    # Close the wrappers opened above; pad with a space
                    # otherwise.
                    latex_table += r"}" if is_best or is_secd else " "
                    latex_table += r"}" if is_best_grp else " "
                latex_table += r"}" if use_si_num else r"$"
            latex_table += r" \\" + "\n"
    latex_table += r"\bottomrule" + "\n"
    latex_table += r"\end{tabular}" + "\n"
    latex_table += r"%}" + "\n"

# Report how many configurations lacked data, then print the finished table.
print()
print(f"There are {len(cmds)} commands to execute to generate missing datapoints")

print()
print("Done!")
print()
print(f"Here is your results table for {metric_key}, {clustererstr}:")
print()
print()
print(latex_table)

In [None]:
# Print the command for each configuration that lacked data. Entries can be
# empty strings, which filter2command returns for memory requests above 300 GB.
for cmd in cmds:
    print(cmd)

## Plot AMI over taxonomic rank

In [None]:
import re

from zs_ssl_clustering.io import sanitize_filename


def get_pred_path(row):
    """
    Build the path of the ``.npz`` file that stores a run's predictions.

    The file name has the form
    ``<partition>-<dataset_name>[__<model>][__<model_dna>]__<run_id>.npz`` where

    - ``+`` in ``row["partition"]`` is replaced by ``-``,
    - ``<run_id>`` is the text after the last ``__`` in ``row["name"]``,
    - the ``model`` / ``model_dna`` parts are present only when
      ``row["modality"]`` contains ``"image"`` / ``"dna"`` respectively.

    The file is placed in ``row["predictions_dir"]`` under a sub-directory
    named after the partition. Both the sub-directory and the file name are
    passed through ``sanitize_filename``.
    """
    run_id = row["name"].split("__")[-1]
    safe_partition = row["partition"].replace("+", "-")

    name_parts = [f"{safe_partition}-{row['dataset_name']}"]
    # Each modality present in the run contributes its encoder name, image first.
    for modality_tag, model_field in (("image", "model"), ("dna", "model_dna")):
        if modality_tag in row["modality"]:
            name_parts.append(f"{row[model_field]}")
    name_parts.append(f"{run_id}.npz")

    fname = sanitize_filename("__".join(name_parts))
    return os.path.join(
        row["predictions_dir"],
        sanitize_filename(safe_partition),
        fname,
    )

In [None]:
import bioscan5m

In [None]:
# Label levels evaluated below, in the order they are plotted. Each name is
# also the prefix of a ``<name>_index`` metadata column.
label_cols = [
    "class",
    "order",
    "family",
    "subfamily",
    "genus",
    "species",
    "dna_bin",
]
# Both names refer to the very same list object.
annotation_levels = label_cols

In [None]:
# Load BIOSCAN-5M with the DNA modality and all splits; the following cells
# read its ``metadata`` table.
ds = bioscan5m.BIOSCAN5M("~/Datasets/BIOSCAN-5M", modality="dna", split="all")

In [None]:
# Keep only the metadata rows whose split is "test" or "test_unseen".
ds.metadata = ds.metadata[ds.metadata["split"].isin(["test", "test_unseen"])]

In [None]:
# Display the distinct class indices and the number of metadata rows for each.
np.unique(ds.metadata["class_index"], return_counts=True)

In [None]:
# Stack the ``<level>_index`` column of every label level along the last axis,
# so that ``attrs[:, j]`` holds the label indices for ``label_cols[j]``.
attrs = np.stack(
    [ds.metadata[f"{taxa}_index"].to_numpy() for taxa in label_cols],
    axis=-1,
)

In [None]:
# Display the shape of the stacked label array.
attrs.shape

### Images

In [None]:
# --- Configuration ------------------------------------------------------------
dataset = "bioscan5m"  # or "bioscan5m_per-barcode-dedupNs"
partition = "test+test_unseen"
modality = "image"
metric_key = "AMI"
# Field values applied on top of the default parameters when selecting runs.
override_fields = {
    "predictions_dir": "y_pred",
    "prenorm": "none",
}

# Only AMI is implemented below; check once, before any data is loaded.
if metric_key.lower() != "ami":
    raise NotImplementedError(f"Unsupported metric {metric_key!r}; only AMI is implemented")

# --- Ground-truth labels --------------------------------------------------------
# The dataset object is only used for its metadata table in this cell.
ds_args = {}
if dataset == "bioscan5m_per-barcode-dedupNs":
    ds_args["reduce_repeated_barcodes"] = "rstrip_Ns"
ds = bioscan5m.BIOSCAN5M("~/Datasets/BIOSCAN-5M", modality="dna", split="all", **ds_args)
if partition == "test+test_unseen":
    # Rows are ordered "test" first, then "test_unseen".
    # NOTE(review): presumably this matches the sample order of the saved
    # predictions - confirm.
    ds.metadata = pd.concat([ds.metadata[ds.metadata["split"] == "test"], ds.metadata[ds.metadata["split"] == "test_unseen"]])
else:
    ds.metadata = ds.metadata[ds.metadata["split"] == partition]

# attrs[:, j] holds the label indices for label_cols[j].
attrs = np.stack(
    [ds.metadata[f"{taxa}_index"].to_numpy() for taxa in label_cols],
    axis=-1,
)

# Number of distinct labels at each level.
for i_attr, level in enumerate(annotation_levels):
    print(level, len(np.unique(attrs[:, i_attr])))

# Sanity check: show ten randomly chosen rows. The sample is drawn from the
# actual number of rows so it cannot index out of bounds on a small partition.
print(attrs[np.random.choice(len(attrs), 10)], attrs.shape)

print(MODEL_GROUPS)

TEST_ATTRS = annotation_levels
print(TEST_ATTRS)

# --- Score each model's clustering against every label level ------------------
# plot_data maps model name -> list of AMI values, one per entry of TEST_ATTRS.
plot_data = {}
for backbone in ["ResNet-50", "ViT-B"]:
    print(MODEL_GROUPS[backbone])

    for model in MODEL_GROUPS[backbone]:
        # Identify the run by dataset/partition/modality plus the encoder name,
        # which is stored under a modality-specific key.
        filter1 = {"dataset": dataset, "partition": partition, "modality": modality}
        if modality == "image":
            filter1["model"] = model
        if modality == "dna":
            filter1["model_dna"] = model

        # Precedence, lowest to highest: global defaults, clusterer defaults,
        # run identity, explicit overrides.
        # NOTE(review): `clusterer` is not assigned in this cell; it is taken
        # from whatever an earlier cell left in the notebook namespace.
        filter2 = dict(DEFAULT_PARAMS["all"], **DEFAULT_PARAMS[clusterer])
        filter2.update(filter1)
        filter2.update(override_fields)
        filter2 = fixup_filter(filter2)
        sdf = select_rows(test_runs_df, filter2, allow_missing=False)

        if len(sdf) < 1:
            print()
            print(filter2)
            print(f"No data for {model}-{dataset}-{clusterer}")
            continue
        if len(sdf) > 1:
            # Several matching runs are only reported when their scores disagree.
            perf = sdf.iloc[0][metric_key]
            if sum(sdf[metric_key] != perf) > 0:
                print()
                print(
                    f"More than one result with {metric_key} values",
                    list(sdf[metric_key]),
                )
                # Report the filters built in this cell (previously this printed
                # the unrelated name `filter`).
                print(f"for search {filter1}\nand {filter2}")
                dif_cols = find_differing_columns(sdf, config_keys)
                print(f"columns which differ: {dif_cols}")
                if dif_cols:
                    for col in dif_cols:
                        print(f"  {col}: {list(sdf[col])}")

        # The first matching run is used, also when several were found.
        # The archive is closed as soon as the array has been read.
        with np.load("../" + get_pred_path(sdf.iloc[0])) as pred_file:
            y_pred = pred_file["y_pred"]

        # Compare the same cluster assignment with the labels of each level.
        plot_data[model] = [
            sklearn.metrics.adjusted_mutual_info_score(attrs[:, i_attr], y_pred)
            for i_attr in range(len(TEST_ATTRS))
        ]

In [None]:
# Figure: AMI (in percent) of every model in plot_data, one point per label level.
fig = plt.figure(figsize=(5, 2.5))

# Tick text for the x axis; the "dna_bin" level is displayed as "BIN".
x_labels = [("BIN" if level == "dna_bin" else level) for level in annotation_levels]
x_ranges = np.arange(len(x_labels))

# One line per model: marker and line style are looked up via the model's
# architecture, colour and legend name via the model itself. Unknown models
# fall back to a solid black line with round markers and their raw name.
for model, data_vec in plot_data.items():
    plt.plot(
        x_ranges,
        100 * np.array(data_vec),
        marker=ARCH2MARKER.get(MODEL2ARCH.get(model, ""), "o"),
        ls=ARCH2LS.get(MODEL2ARCH.get(model, ""), "-"),
        color=MODEL2COLORRGB.get(model, (0.0, 0.0, 0.0)),
        label=MODEL2SH.get(model, model),
    )

axs = plt.gca()
# Horizontal grid lines only: solid for major ticks, faint dashed for minor ticks.
axs.grid(True, axis="y", which="major", linestyle="-", alpha=0.8)
axs.grid(True, axis="y", which="minor", linestyle="--", alpha=0.3)
axs.set_xticks(x_ranges)
axs.set_xticklabels(x_labels, rotation=45)
# Major y ticks every 10, minor y ticks every 5.
axs.set_yticks(np.arange(0, 110, 10))
axs.set_yticks(np.arange(0, 110, 5), minor=True)
axs.set_ylabel("AMI (%)", fontsize=12)


def colorMarker(m):
    """Plot an empty line in colour ``m`` on the current axes and return its artist."""
    (proxy_line,) = plt.plot([], [], color=m)
    return proxy_line


# Models shown in the legend: every ResNet-50 entry except the last, followed
# by the last three ViT-B entries.
model_names = list(MODEL_GROUPS["ResNet-50"])[:-1] + MODEL_GROUPS["ViT-B"][-3:]
# plt_labels is the same list object as model_names, so the in-place extension
# of plt_labels further down also changes model_names.
plt_labels = model_names
# One colour-only legend handle per model.
handles = [colorMarker(MODEL2COLORSTR[name]) for name in model_names]


def backboneMarker(m, color="black"):
    """Plot an empty, line-less series with marker ``m`` in ``color`` and return its artist."""
    (proxy_marker,) = plt.plot([], [], marker=m, ls="none", color=color)
    return proxy_marker


# Append two marker-only handles for the backbones: grey for ResNet-50 and the
# default black for ViT-B, labelled "RN50" and "ViT-B".
handles.append(backboneMarker(ARCH2MARKER["ResNet-50"], color=(0.4, 0.4, 0.4)))
handles.append(backboneMarker(ARCH2MARKER["ViT-B"]))
plt_labels.extend(["RN50", "ViT-B"])

# Legend above the axes, stretched over their full width in three columns.
# Labels are shortened via MODEL2SH where a short name exists.
ncols = 3
axs.legend(
    handles=handles,
    labels=[MODEL2SH.get(entry, entry) for entry in plt_labels],
    bbox_to_anchor=(0.0, 1.1, 1.0, 0.402),
    loc="lower left",
    ncol=ncols,
    mode="expand",
    borderaxespad=0.0,
)

# Save as PDF; "+" in the partition name is written as "-" in the file name.
fig.savefig(
    f"zsc_{dataset}_{partition.replace('+', '-')}_image_taxonomic_performance.pdf",
    bbox_inches="tight",
)

### DNA

In [None]:
# --- Configuration ------------------------------------------------------------
dataset = "bioscan5m"  # or "bioscan5m_per-barcode-dedupNs"
partition = "test+test_unseen"
modality = "dna"
models = DNA_MODELS
metric_key = "AMI"
# Field values applied on top of the default parameters when selecting runs.
override_fields = {
    "predictions_dir": "y_pred",
    "prenorm": "none",
}

# Only AMI is implemented below; check once, before any data is loaded.
if metric_key.lower() != "ami":
    raise NotImplementedError(f"Unsupported metric {metric_key!r}; only AMI is implemented")

# --- Ground-truth labels --------------------------------------------------------
# The dataset object is only used for its metadata table in this cell.
ds_args = {}
if dataset == "bioscan5m_per-barcode-dedupNs":
    ds_args["reduce_repeated_barcodes"] = "rstrip_Ns"
ds = bioscan5m.BIOSCAN5M("~/Datasets/BIOSCAN-5M", modality="dna", split="all", **ds_args)
if partition == "test+test_unseen":
    # Rows are ordered "test" first, then "test_unseen".
    # NOTE(review): presumably this matches the sample order of the saved
    # predictions - confirm.
    ds.metadata = pd.concat([ds.metadata[ds.metadata["split"] == "test"], ds.metadata[ds.metadata["split"] == "test_unseen"]])
else:
    ds.metadata = ds.metadata[ds.metadata["split"] == partition]

# attrs[:, j] holds the label indices for label_cols[j].
attrs = np.stack(
    [ds.metadata[f"{taxa}_index"].to_numpy() for taxa in label_cols],
    axis=-1,
)

# Number of distinct labels at each level.
for i_attr, level in enumerate(annotation_levels):
    print(level, len(np.unique(attrs[:, i_attr])))

# Sanity check: show ten randomly chosen rows. The sample is drawn from the
# actual number of rows so it cannot index out of bounds on a small partition.
print(attrs[np.random.choice(len(attrs), 10)], attrs.shape)


print(MODEL_GROUPS)

TEST_ATTRS = annotation_levels
print(TEST_ATTRS)

# --- Score each model's clustering against every label level ------------------
# plot_data maps model name -> list of AMI values, one per entry of TEST_ATTRS.
plot_data = {}
for model in models:
    # Identify the run by dataset/partition/modality plus the encoder name,
    # which is stored under a modality-specific key.
    filter1 = {"dataset": dataset, "partition": partition, "modality": modality}
    if modality == "image":
        filter1["model"] = model
    if modality == "dna":
        filter1["model_dna"] = model

    # Precedence, lowest to highest: global defaults, clusterer defaults,
    # run identity, explicit overrides.
    # NOTE(review): `clusterer` is not assigned in this cell; it is taken from
    # whatever an earlier cell left in the notebook namespace.
    filter2 = dict(DEFAULT_PARAMS["all"], **DEFAULT_PARAMS[clusterer])
    filter2.update(filter1)
    filter2.update(override_fields)
    filter2 = fixup_filter(filter2)
    sdf = select_rows(test_runs_df, filter2, allow_missing=False)

    if len(sdf) < 1:
        print()
        print(filter2)
        print(f"No data for {model}-{dataset}-{clusterer}")
        continue
    if len(sdf) > 1:
        # Several matching runs are only reported when their scores disagree.
        perf = sdf.iloc[0][metric_key]
        if sum(sdf[metric_key] != perf) > 0:
            print()
            print(
                f"More than one result with {metric_key} values",
                list(sdf[metric_key]),
            )
            # Report the filters built in this cell (previously this printed
            # the unrelated name `filter`).
            print(f"for search {filter1}\nand {filter2}")
            dif_cols = find_differing_columns(sdf, config_keys)
            print(f"columns which differ: {dif_cols}")
            if dif_cols:
                for col in dif_cols:
                    print(f"  {col}: {list(sdf[col])}")

    # The first matching run is used, also when several were found.
    # The archive is closed as soon as the array has been read.
    with np.load("../" + get_pred_path(sdf.iloc[0])) as pred_file:
        y_pred = pred_file["y_pred"]

    # Compare the same cluster assignment with the labels of each level.
    plot_data[model] = [
        sklearn.metrics.adjusted_mutual_info_score(attrs[:, i_attr], y_pred)
        for i_attr in range(len(TEST_ATTRS))
    ]

In [None]:
# Figure: AMI (in percent) of every model in plot_data, one point per label level.
fig = plt.figure(figsize=(5, 2.5))

# Tick text for the x axis; the "dna_bin" level is displayed as "BIN".
x_labels = [("BIN" if level == "dna_bin" else level) for level in annotation_levels]
x_ranges = np.arange(len(x_labels))

# One solid line with round markers per model. Colour and legend name are
# looked up per model, falling back to black and the raw model name.
for model, data_vec in plot_data.items():
    plt.plot(
        x_ranges,
        100 * np.array(data_vec),
        marker="o",
        ls="-",
        color=MODEL2COLORRGB.get(model, (0.0, 0.0, 0.0)),
        label=DNAMODEL2SH.get(model, model),
    )

axs = plt.gca()
# Horizontal grid lines only: solid for major ticks, faint dashed for minor ticks.
axs.grid(True, axis="y", which="major", linestyle="-", alpha=0.8)
axs.grid(True, axis="y", which="minor", linestyle="--", alpha=0.3)
axs.set_xticks(x_ranges)
axs.set_xticklabels(x_labels, rotation=45)
# Major y ticks every 10; explicit minor ticks are left disabled here.
axs.set_yticks(np.arange(0, 110, 10))
# axs.set_yticks(np.arange(0, 110, 5), minor=True)
axs.set_ylabel("AMI (%)", fontsize=12)


def colorMarker(m):
    """Plot an empty line in colour ``m`` on the current axes and return its artist."""
    (proxy_line,) = plt.plot([], [], color=m)
    return proxy_line


# Legend entries: one per plotted model, in plotting order. Both names refer
# to the same list object.
plt_labels = model_names = list(plot_data)
# Use the same black fallback as the plotting loop, so a model without an entry
# in MODEL2COLORRGB gets a matching legend line instead of raising KeyError.
handles = [colorMarker(MODEL2COLORRGB.get(name, (0.0, 0.0, 0.0))) for name in model_names]

# Legend above the axes, stretched over their full width in three columns.
# Labels are shortened via DNAMODEL2SH where a short name exists.
ncols = 3
axs.legend(
    handles=handles,
    labels=[DNAMODEL2SH.get(entry, entry) for entry in plt_labels],
    bbox_to_anchor=(0.0, 1.1, 1.0, 0.402),
    loc="lower left",
    ncol=ncols,
    mode="expand",
    borderaxespad=0.0,
)

# Save as PDF; "+" in the partition name is written as "-" in the file name.
fig.savefig(
    f"zsc_{dataset}_{partition.replace('+', '-')}_DNA_taxonomic_performance.pdf",
    bbox_inches="tight",
)