In [1]:
from itertools import product
from typing import Any, Callable

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from scipy.io import loadmat
from scipy.spatial.distance import pdist
from tqdm import tqdm

In [2]:
# Collect the variables of both MATLAB summary files in one dict: loadmat
# inserts what it reads into the dict passed as its second argument.
data = {}
for mat_path in ("assets/labels_summary.mat", "assets/results_summary.mat"):
    loadmat(mat_path, data)
print("loaded data")

loaded data


In [3]:
def mat_to_df(m):
    """
    Convert a structured array (e.g. an entry of the dict filled by loadmat) into a DataFrame

    Args:
    m: structured numpy array. Every field is read through m[field][0], one
       cell per row. An "index" field is required; it becomes the DataFrame
       index instead of a column.

    Returns:
    DataFrame with one column per remaining field. A cell that wraps a single
    scalar is unwrapped to that scalar; any other cell of length 1 contributes
    its only element; everything else becomes None.

    Raises:
    KeyError: if m has no "index" field
    """
    columns = {}
    for field in m.dtype.names:
        values = []
        for cell in m[field][0]:
            try:
                # usual case: the cell wraps exactly one scalar
                value = cell[0].item()
            except Exception:
                # Not a single scalar (empty cell, multi-element payload, ...).
                # Catch Exception rather than everything so KeyboardInterrupt
                # and SystemExit still stop a long conversion.
                value = cell[0] if len(cell) == 1 else None
            values.append(value)
        columns[field] = values
    index = columns.pop("index")
    return pd.DataFrame(columns, index)

In [4]:
# identifying columns: a label record is keyed by (dataset, ood split),
# a result record additionally by the model
idx_dfl = ["dataset_name", "ood"]
idx_dfm = ["model_name", *idx_dfl]

# convert the loaded "label" and "result" entries into DataFrames
dfl, dfm_raw = (mat_to_df(data[key]) for key in ("label", "result"))

In [5]:
# flatten dataframe: the id columns (idx_dfm) are kept, every other column of
# dfm_raw becomes one (metric, value) row
df_flat = dfm_raw.melt(id_vars=idx_dfm, var_name="metric").copy()

# rename fpi metrics and remove z_ prefix from their model_name
# (rows whose model_name starts with "z_" get an "i_" prefix on their metric and
# lose the first two characters of model_name; presumably this files them under
# the same-named model without the prefix -- TODO confirm)
query = df_flat.model_name.str.startswith("z_")
df_flat.loc[query, "metric"] = "i_" + df_flat.loc[query, "metric"]
df_flat.loc[query, "model_name"] = df_flat.loc[query, "model_name"].str[2:]

# recreate dfm-like structure: one row per idx_dfm combination, one column per metric
dfm = df_flat.pivot(index=idx_dfm, columns="metric", values="value").reset_index()

# rows with missing entries are dropped only after dfm was pivoted from the full frame
df_flat = df_flat.dropna()

In [6]:
# model names per family
autoencoders = ["resnet_mse"]
classifiers = ["resnet50_vicreg_ce"]
classifiers_ae = ["resnet_ce_mse", "resnet_edl_mse"]
flows = ["flow_ss_vcr_mse"]

# rows of the flattened frame that belong to each family
df_autoencoders, df_classifiers, df_classifiers_ae, df_flows = (
    df_flat.loc[df_flat["model_name"].isin(family)]
    for family in (autoencoders, classifiers, classifiers_ae, flows)
)

In [7]:
def permutation_test_2samp(sa: np.ndarray, sb: np.ndarray, stat: Callable, tail: str, n: int):
    """
    Given two samples (sa, sb), compute the test statistic (stat) and approximate its p value through n permutations

    The samples are pooled and standardized per column (zero mean, unit
    standard deviation over the pooled rows) before the statistic is evaluated.
    The random generator is re-seeded with 42 on every call, so repeated calls
    with the same inputs return identical results.

    Args:
    sa: sample A, observations along axis 0
    sb: sample B, observations along axis 0
    stat: test statistic. invoked as f(sa, sb)
    tail: which tail counts as extreme: "left" (stat <= observed),
          "right" (stat >= observed) or "both" (twice the smaller tail
          count around +/- |observed|)
    n: number of permutations

    Returns:
    (obs, p, stats): the observed statistic, its p value (c + 1) / (n + 1)
    capped at 1, and the array of the n statistics under the null

    Raises:
    ValueError: if tail is not "left", "right" or "both"
    """
    if tail not in ("left", "right", "both"):
        raise ValueError(f"tail must be 'left', 'right' or 'both', got {tail!r}")

    # pool the samples and standardize every column
    pooled = np.concatenate([sa, sb], axis=0)
    center = pooled.mean(axis=0, keepdims=True)
    scale = pooled.std(axis=0, keepdims=True)
    # A constant column has std 0; dividing by it would turn every statistic
    # into NaN, all comparisons below would be False and p would collapse to
    # its minimum. Such a column is already all zeros after centering.
    scale[scale == 0] = 1.0
    pooled = (pooled - center) / scale

    partition = [len(sa)]  # split point that restores the original sample sizes
    rng = np.random.RandomState(42)

    # statistic for the true assignment of rows to samples
    obs = stat(*np.split(pooled, partition))

    # statistic for n random re-assignments (rows shuffled along axis 0)
    stats = np.array([stat(*np.split(rng.permutation(pooled), partition)) for _ in range(n)])

    # count the permutations at least as extreme as the observation
    if tail == "left":
        c = (stats <= obs).sum()
    elif tail == "right":
        c = (stats >= obs).sum()
    else:  # "both"
        aobs = abs(obs)
        cl, cr = (stats <= -aobs).sum(), (stats >= aobs).sum()
        c = 2 * min(cl, cr)

    # +1 counts the true observation; doubling the smaller tail can overshoot, so cap at 1
    p = np.minimum(1.0, (c + 1) / (n + 1))

    # return statistic, p-value, and the n stats under the null
    return obs, p, stats

In [8]:
def mmd(x: np.ndarray, y: np.ndarray, **kwargs):
    """
    Average, over the feature columns, of (column mean of x - column mean of y)

    Extra keyword arguments are accepted and ignored.
    """
    mean_x = x.mean(axis=0)
    mean_y = y.mean(axis=0)
    gap = mean_x - mean_y
    return gap.mean()


def auc(*s: np.ndarray):
    """
    Score how well a linear model tells the two samples apart

    A LinearRegression is fitted on the stacked samples against a 0/1
    membership target and evaluated on those same rows. The ROC AUC of its
    predictions is returned, mirrored so that the result is never below 0.5.
    """
    first, second = s

    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import roc_auc_score

    features = np.concatenate([first, second], axis=0)
    # membership target: 0 for rows of the first sample, 1 for rows of the second
    membership = np.repeat([0, 1], [len(first), len(second)])

    # in-sample predictions of the fitted separating direction
    scores = LinearRegression().fit(features, membership).predict(features)

    score = roc_auc_score(membership, scores)
    return max(score, 1 - score)


def mrpp(*s: np.ndarray, metric: Any = "euclidean", **kwargs):
    """
    MRPP (Multi-Response Permutation Procedure) statistic of two samples

    For each sample the mean pairwise distance between its rows is computed
    with scipy's pdist under the given metric; the statistic is the sum of
    those two means, each weighted by its sample's share of all rows.
    Extra keyword arguments are accepted and ignored.
    """
    group_a, group_b = s

    # each group is weighted by its share of the pooled rows
    size_a, size_b = len(group_a), len(group_b)
    total = size_a + size_b
    weight_a = size_a / total
    weight_b = size_b / total
    assert weight_a + weight_b == 1.0

    # mean within-group pairwise distance, in the order (a, b)
    spread_a, spread_b = (pdist(group, metric=metric).mean() for group in (group_a, group_b))

    return spread_a * weight_a + spread_b * weight_b

In [9]:
def get_metrics(dfm, dfl, model_name, dataset_name, ood):
    """
    Look up the labels and the per-metric data of one (model, dataset, ood) run

    Args:
    dfm: result frame with model_name, dataset_name and ood columns plus
         "<metric>_ind_data" / "<metric>_ood_data" columns
    dfl: label frame with dataset_name, ood, ind_labels and ood_labels columns
    model_name, dataset_name, ood: values selecting the record; the first
         matching row of each frame is used

    Returns:
    (metric names, in-distribution frame, ood frame). Each frame has one
    column per metric with non-missing data, plus a "label" column.
    """
    # labels of the run; a missing entry is replaced by an empty list
    label_row = dfl[(dfl.dataset_name == dataset_name) & (dfl.ood == ood)].iloc[0]
    ind_labels = [] if label_row.ind_labels is None else label_row.ind_labels
    ood_labels = [] if label_row.ood_labels is None else label_row.ood_labels

    result_row = dfm[
        (dfm.model_name == model_name) & (dfm.dataset_name == dataset_name) & (dfm.ood == ood)
    ].iloc[0]

    def _cells_for(suffix):
        # non-missing cells whose column name ends with suffix, re-keyed by the bare metric name
        cells = result_row[result_row.index.str.endswith(suffix)].dropna()
        names = list(cells.index.str[: -len(suffix)])
        cells.index = names
        return names, cells

    ind_names, ind_cells = _cells_for("_ind_data")
    ood_names, ood_cells = _cells_for("_ood_data")
    # both splits must provide exactly the same metrics, in the same order
    assert ind_names == ood_names

    frame_ind = pd.DataFrame({**ind_cells.to_dict(), "label": ind_labels})
    frame_ood = pd.DataFrame({**ood_cells.to_dict(), "label": ood_labels})

    return ind_names, frame_ind, frame_ood

In [10]:
def pairwise_hypothesis_test(dfm, dfl, model, dataset, ood, stat_name, sample_size, num_perms):
    """
    Run a two-sample permutation test for every (label, in-distribution label) pair

    Args:
    dfm, dfl: result and label frames, forwarded to get_metrics
    model, dataset, ood: select the run to test
    stat_name: "mrpp" (left tail), "mmd" (both tails) or "auc" (right tail)
    sample_size: rows taken from each label for a test
    num_perms: permutations per test

    Returns:
    (metrics, δ, P, N): the metric names, the observed statistics and the
    p values as DataFrames (rows: all labels, in-distribution first;
    columns: in-distribution labels), and the array of null statistics with
    shape (rows, columns, num_perms)

    Raises:
    ValueError: for an unknown stat_name
    """
    desc = f"{model},{dataset}-{ood},{stat_name}"

    # statistic name -> (statistic, tail used by the permutation test)
    dispatch = {"mrpp": (mrpp, "left"), "mmd": (mmd, "both"), "auc": (auc, "right")}
    if stat_name not in dispatch:
        raise ValueError(stat_name)
    stat, tail = dispatch[stat_name]

    # get metric names and values
    metrics, frame_ind, frame_ood = get_metrics(dfm, dfl, model, dataset, ood)
    labels_ind, labels_ood = np.unique(frame_ind.label), np.unique(frame_ood.label)
    ni = len(labels_ind)
    n_rows = ni + len(labels_ood)

    L = np.concatenate([labels_ind, labels_ood], axis=0).astype(int)
    S = pd.concat([frame_ind, frame_ood], axis=0, ignore_index=True)

    # compute normalizing hyperparameters
    # FIXME precompute for each (model, dataset) from TRAIN data
    pooled = S[metrics].astype(float).to_numpy()
    mu = pooled.mean(axis=0, keepdims=True)
    sigma = pooled.std(axis=0, keepdims=True)

    # variables to store test results
    observed = np.zeros((n_rows, ni))
    p_values = np.zeros((n_rows, ni))
    null_stats = np.zeros((n_rows, ni, num_perms))

    def _rows_of(label):
        # metric values of every row carrying this label
        return S.loc[S.label == label, metrics].astype(float).to_numpy()

    # compute (obs, p) for all label-label pairs
    pairs = list(product(range(n_rows), range(ni)))
    for i, j in tqdm(pairs, desc=desc):
        sa, sb = _rows_of(L[i]), _rows_of(L[j])
        if i == j:
            # same label on both sides: needs enough rows for two disjoint samples
            assert len(sa) >= (2 * sample_size)
        else:
            assert len(sa) >= sample_size
            assert len(sb) >= sample_size

        # fixed samples (head of sa, tail of sb), scaled by the hyperparameters
        sa = (sa[:sample_size] - mu) / sigma
        sb = (sb[-sample_size:] - mu) / sigma

        observed[i, j], p_values[i, j], null_stats[i, j] = permutation_test_2samp(
            sa, sb, stat=stat, tail=tail, n=num_perms
        )

    # return the statistics and p values as dataframes
    observed = pd.DataFrame(data=observed, index=L, columns=L[:ni])
    p_values = pd.DataFrame(data=p_values, index=L, columns=L[:ni])

    return metrics, observed, p_values, null_stats

In [11]:
def infer_single_model(model: str, stat: str, sample_size: int, num_perms: int, basepath: str):
    """
    Run the pairwise tests of one model on every dataset / ood split and save the results

    Reads the notebook-level frames dfm and dfl. For each run the observed
    statistics, p values and null statistics are written to
    {basepath}/stats/{model}_{stat}_{dataset}_{ood}.npz; heatmaps of the
    statistics (top row) and p values (bottom row) of all runs are saved to
    {basepath}/{model}_{stat}.pdf.

    Args:
    model: model name to test
    stat: statistic name, forwarded to pairwise_hypothesis_test
    sample_size: rows per label in each test
    num_perms: permutations per test
    basepath: output directory, created if missing

    Returns:
    dict keyed by "{dataset}_{ood}" with the "δ" and "p" frames as nested dicts
    """
    import os
    from matplotlib.colors import LogNorm

    runs = [
        ("CIFAR10", "A"),
        ("CIFAR10", "B"),
        ("MNIST", "A"),
        ("MNIST", "B"),
        ("QPM_species", "A"),
        ("QPM_species", "B"),
        ("QPM2_species", "A"),
        ("QPM2_species", "B"),
    ]

    # create the output directories once, not on every loop iteration;
    # this also guarantees basepath exists when the pdf is written
    stats_dir = f"{basepath}/stats"
    os.makedirs(stats_dir, exist_ok=True)

    fig, axs = plt.subplots(nrows=2, ncols=len(runs), figsize=(80, 15))
    stats = {}
    try:
        for i, (dataset, ood) in enumerate(runs):
            m, δ, p, n = pairwise_hypothesis_test(dfm, dfl, model, dataset, ood, stat, sample_size, num_perms)
            axa, axb = axs[0][i], axs[1][i]
            sns.heatmap(δ, annot=True, fmt=".3f", ax=axa)
            # log colour scale from the smallest attainable p value, 1/(num_perms+1), up to 1
            sns.heatmap(p, annot=True, fmt=".3f", ax=axb, norm=LogNorm(vmin=1/(num_perms+1), vmax=1.0))
            # NOTE(review): the frames are indexed by all labels (rows) and have the
            # in-distribution labels as columns -- confirm the Test/Train axis naming
            axa.set_title(f"δ_obs: {dataset}-{ood}")
            axa.set_xlabel("Test Label")
            axa.set_ylabel("Train Label")
            axb.set_title(f"P(δ<=δ_obs|H0): {dataset}-{ood}")
            axb.set_xlabel("Test Label")
            axb.set_ylabel("Train Label")
            np.savez_compressed(f"{stats_dir}/{model}_{stat}_{dataset}_{ood}.npz", δ=δ.to_numpy(), p=p.to_numpy(), n=n)
            stats[f"{dataset}_{ood}"] = {"δ": δ.to_dict(), "p": p.to_dict()}
        fig.suptitle(f"Test Statistic ({stat}) and P Value\nModel={model}, Metrics={len(m)}")
        fig.tight_layout()
        fig.savefig(f"{basepath}/{model}_{stat}.pdf")
    finally:
        # always release this (very large) figure, even when a test fails midway
        plt.close(fig)
    return stats

In [None]:
import json
import os

# save test stats
test_stats = {}

perm_counts = [1000, 2000]
sample_sizes = [100, 200, 300]
model_names = ["resnet50_vicreg_ce", "resnet_ce_mse", "resnet_mse"]
stat_names = ["auc", "mmd", "mrpp"]

basepath = "tests_v4"
os.makedirs(basepath, exist_ok=True)
summary_path = f"{basepath}/summary.json"

# product varies the right-most list fastest, i.e. the same order as nesting
# the four loops (permutations > sample size > model > statistic)
for num_perms, sample_size, model_name, stat_name in product(perm_counts, sample_sizes, model_names, stat_names):
    experiment = f"{model_name}_{stat_name}_s{sample_size}_p{num_perms}"
    print(f"Experiment: {experiment}")
    fp = f"{basepath}/s{sample_size}/p{num_perms}"
    stats = infer_single_model(model_name, stat_name, sample_size, num_perms, fp)
    test_stats[experiment] = stats

    # Checkpoint after every experiment: the full grid runs for hours, and an
    # interrupted kernel would otherwise lose every finished result. Writing to
    # a temporary file and renaming keeps the summary readable at all times.
    with open(f"{summary_path}.tmp", "w") as f:
        json.dump(test_stats, f)
    os.replace(f"{summary_path}.tmp", summary_path)

Experiment: resnet50_vicreg_ce_auc_s100_p1000


resnet50_vicreg_ce,CIFAR10-A,auc: 100%|██████████| 50/50 [01:07<00:00,  1.35s/it]
resnet50_vicreg_ce,CIFAR10-B,auc: 100%|██████████| 50/50 [01:05<00:00,  1.31s/it]
resnet50_vicreg_ce,MNIST-A,auc: 100%|██████████| 50/50 [01:03<00:00,  1.28s/it]
resnet50_vicreg_ce,MNIST-B,auc: 100%|██████████| 50/50 [01:03<00:00,  1.27s/it]
resnet50_vicreg_ce,QPM_species-A,auc: 100%|██████████| 378/378 [08:00<00:00,  1.27s/it]
resnet50_vicreg_ce,QPM_species-B,auc: 100%|██████████| 63/63 [01:20<00:00,  1.28s/it]
resnet50_vicreg_ce,QPM2_species-A,auc: 100%|██████████| 21/21 [00:26<00:00,  1.27s/it]
resnet50_vicreg_ce,QPM2_species-B,auc: 100%|██████████| 28/28 [00:35<00:00,  1.27s/it]


Experiment: resnet50_vicreg_ce_mmd_s100_p1000


resnet50_vicreg_ce,CIFAR10-A,mmd: 100%|██████████| 50/50 [00:01<00:00, 27.35it/s]
resnet50_vicreg_ce,CIFAR10-B,mmd: 100%|██████████| 50/50 [00:01<00:00, 27.20it/s]
resnet50_vicreg_ce,MNIST-A,mmd: 100%|██████████| 50/50 [00:01<00:00, 27.71it/s]
resnet50_vicreg_ce,MNIST-B,mmd: 100%|██████████| 50/50 [00:01<00:00, 27.23it/s]
resnet50_vicreg_ce,QPM_species-A,mmd: 100%|██████████| 378/378 [00:13<00:00, 27.13it/s]
resnet50_vicreg_ce,QPM_species-B,mmd: 100%|██████████| 63/63 [00:02<00:00, 27.50it/s]
resnet50_vicreg_ce,QPM2_species-A,mmd: 100%|██████████| 21/21 [00:00<00:00, 26.76it/s]
resnet50_vicreg_ce,QPM2_species-B,mmd: 100%|██████████| 28/28 [00:01<00:00, 27.13it/s]


Experiment: resnet50_vicreg_ce_mrpp_s100_p1000


resnet50_vicreg_ce,CIFAR10-A,mrpp: 100%|██████████| 50/50 [00:03<00:00, 16.13it/s]
resnet50_vicreg_ce,CIFAR10-B,mrpp: 100%|██████████| 50/50 [00:03<00:00, 16.22it/s]
resnet50_vicreg_ce,MNIST-A,mrpp: 100%|██████████| 50/50 [00:03<00:00, 16.17it/s]
resnet50_vicreg_ce,MNIST-B,mrpp: 100%|██████████| 50/50 [00:03<00:00, 16.17it/s]
resnet50_vicreg_ce,QPM_species-A,mrpp: 100%|██████████| 378/378 [00:23<00:00, 16.11it/s]
resnet50_vicreg_ce,QPM_species-B,mrpp: 100%|██████████| 63/63 [00:03<00:00, 15.99it/s]
resnet50_vicreg_ce,QPM2_species-A,mrpp: 100%|██████████| 21/21 [00:01<00:00, 16.07it/s]
resnet50_vicreg_ce,QPM2_species-B,mrpp: 100%|██████████| 28/28 [00:01<00:00, 15.93it/s]


Experiment: resnet_ce_mse_auc_s100_p1000


resnet_ce_mse,CIFAR10-A,auc: 100%|██████████| 50/50 [01:12<00:00,  1.46s/it]
resnet_ce_mse,CIFAR10-B,auc: 100%|██████████| 50/50 [01:12<00:00,  1.45s/it]
resnet_ce_mse,MNIST-A,auc: 100%|██████████| 50/50 [01:12<00:00,  1.44s/it]
resnet_ce_mse,MNIST-B,auc: 100%|██████████| 50/50 [01:12<00:00,  1.45s/it]
resnet_ce_mse,QPM_species-A,auc: 100%|██████████| 378/378 [09:10<00:00,  1.46s/it]
resnet_ce_mse,QPM_species-B,auc: 100%|██████████| 63/63 [01:30<00:00,  1.44s/it]
resnet_ce_mse,QPM2_species-A,auc: 100%|██████████| 21/21 [00:30<00:00,  1.45s/it]
resnet_ce_mse,QPM2_species-B,auc: 100%|██████████| 28/28 [00:40<00:00,  1.45s/it]


Experiment: resnet_ce_mse_mmd_s100_p1000


resnet_ce_mse,CIFAR10-A,mmd: 100%|██████████| 50/50 [00:02<00:00, 24.72it/s]
resnet_ce_mse,CIFAR10-B,mmd: 100%|██████████| 50/50 [00:02<00:00, 24.71it/s]
resnet_ce_mse,MNIST-A,mmd: 100%|██████████| 50/50 [00:02<00:00, 24.95it/s]
resnet_ce_mse,MNIST-B,mmd: 100%|██████████| 50/50 [00:02<00:00, 24.74it/s]
resnet_ce_mse,QPM_species-A,mmd: 100%|██████████| 378/378 [00:15<00:00, 24.73it/s]
resnet_ce_mse,QPM_species-B,mmd: 100%|██████████| 63/63 [00:02<00:00, 24.92it/s]
resnet_ce_mse,QPM2_species-A,mmd: 100%|██████████| 21/21 [00:00<00:00, 24.24it/s]
resnet_ce_mse,QPM2_species-B,mmd: 100%|██████████| 28/28 [00:01<00:00, 24.13it/s]


Experiment: resnet_ce_mse_mrpp_s100_p1000


resnet_ce_mse,CIFAR10-A,mrpp: 100%|██████████| 50/50 [00:05<00:00,  8.34it/s]
resnet_ce_mse,CIFAR10-B,mrpp: 100%|██████████| 50/50 [00:05<00:00,  8.34it/s]
resnet_ce_mse,MNIST-A,mrpp: 100%|██████████| 50/50 [00:05<00:00,  8.38it/s]
resnet_ce_mse,MNIST-B,mrpp: 100%|██████████| 50/50 [00:06<00:00,  8.31it/s]
resnet_ce_mse,QPM_species-A,mrpp: 100%|██████████| 378/378 [00:45<00:00,  8.32it/s]
resnet_ce_mse,QPM_species-B,mrpp: 100%|██████████| 63/63 [00:07<00:00,  8.33it/s]
resnet_ce_mse,QPM2_species-A,mrpp: 100%|██████████| 21/21 [00:02<00:00,  8.18it/s]
resnet_ce_mse,QPM2_species-B,mrpp: 100%|██████████| 28/28 [00:03<00:00,  8.18it/s]


Experiment: resnet_mse_auc_s100_p1000


resnet_mse,CIFAR10-A,auc: 100%|██████████| 50/50 [01:03<00:00,  1.27s/it]
resnet_mse,CIFAR10-B,auc: 100%|██████████| 50/50 [01:03<00:00,  1.27s/it]
resnet_mse,MNIST-A,auc: 100%|██████████| 50/50 [01:03<00:00,  1.27s/it]
resnet_mse,MNIST-B,auc: 100%|██████████| 50/50 [01:03<00:00,  1.27s/it]
resnet_mse,QPM_species-A,auc: 100%|██████████| 378/378 [07:59<00:00,  1.27s/it]
resnet_mse,QPM_species-B,auc: 100%|██████████| 63/63 [01:20<00:00,  1.27s/it]
resnet_mse,QPM2_species-A,auc: 100%|██████████| 21/21 [00:26<00:00,  1.28s/it]
resnet_mse,QPM2_species-B,auc: 100%|██████████| 28/28 [00:35<00:00,  1.28s/it]


Experiment: resnet_mse_mmd_s100_p1000


resnet_mse,CIFAR10-A,mmd: 100%|██████████| 50/50 [00:01<00:00, 27.30it/s]
resnet_mse,CIFAR10-B,mmd: 100%|██████████| 50/50 [00:01<00:00, 27.47it/s]
resnet_mse,MNIST-A,mmd: 100%|██████████| 50/50 [00:01<00:00, 27.49it/s]
resnet_mse,MNIST-B,mmd: 100%|██████████| 50/50 [00:01<00:00, 27.35it/s]
resnet_mse,QPM_species-A,mmd: 100%|██████████| 378/378 [00:13<00:00, 27.22it/s]
resnet_mse,QPM_species-B,mmd: 100%|██████████| 63/63 [00:02<00:00, 27.20it/s]
resnet_mse,QPM2_species-A,mmd: 100%|██████████| 21/21 [00:00<00:00, 26.97it/s]
resnet_mse,QPM2_species-B,mmd: 100%|██████████| 28/28 [00:01<00:00, 26.77it/s]


Experiment: resnet_mse_mrpp_s100_p1000


resnet_mse,CIFAR10-A,mrpp: 100%|██████████| 50/50 [00:03<00:00, 16.14it/s]
resnet_mse,CIFAR10-B,mrpp: 100%|██████████| 50/50 [00:03<00:00, 16.11it/s]
resnet_mse,MNIST-A,mrpp: 100%|██████████| 50/50 [00:03<00:00, 16.10it/s]
resnet_mse,MNIST-B,mrpp: 100%|██████████| 50/50 [00:03<00:00, 15.90it/s]
resnet_mse,QPM_species-A,mrpp: 100%|██████████| 378/378 [00:23<00:00, 16.05it/s]
resnet_mse,QPM_species-B,mrpp: 100%|██████████| 63/63 [00:03<00:00, 16.21it/s]
resnet_mse,QPM2_species-A,mrpp: 100%|██████████| 21/21 [00:01<00:00, 16.00it/s]
resnet_mse,QPM2_species-B,mrpp: 100%|██████████| 28/28 [00:01<00:00, 15.93it/s]


Experiment: resnet50_vicreg_ce_auc_s200_p1000


resnet50_vicreg_ce,CIFAR10-A,auc: 100%|██████████| 50/50 [01:07<00:00,  1.34s/it]
resnet50_vicreg_ce,CIFAR10-B,auc: 100%|██████████| 50/50 [01:06<00:00,  1.33s/it]
resnet50_vicreg_ce,MNIST-A,auc: 100%|██████████| 50/50 [01:06<00:00,  1.34s/it]
resnet50_vicreg_ce,MNIST-B,auc: 100%|██████████| 50/50 [01:06<00:00,  1.32s/it]
resnet50_vicreg_ce,QPM_species-A,auc: 100%|██████████| 378/378 [08:24<00:00,  1.34s/it]
resnet50_vicreg_ce,QPM_species-B,auc: 100%|██████████| 63/63 [01:23<00:00,  1.33s/it]
resnet50_vicreg_ce,QPM2_species-A,auc: 100%|██████████| 21/21 [00:27<00:00,  1.33s/it]
resnet50_vicreg_ce,QPM2_species-B,auc: 100%|██████████| 28/28 [00:37<00:00,  1.33s/it]


Experiment: resnet50_vicreg_ce_mmd_s200_p1000


resnet50_vicreg_ce,CIFAR10-A,mmd: 100%|██████████| 50/50 [00:02<00:00, 22.70it/s]
resnet50_vicreg_ce,CIFAR10-B,mmd: 100%|██████████| 50/50 [00:02<00:00, 22.76it/s]
resnet50_vicreg_ce,MNIST-A,mmd: 100%|██████████| 50/50 [00:02<00:00, 22.85it/s]
resnet50_vicreg_ce,MNIST-B,mmd: 100%|██████████| 50/50 [00:02<00:00, 22.87it/s]
resnet50_vicreg_ce,QPM_species-A,mmd: 100%|██████████| 378/378 [00:16<00:00, 22.61it/s]
resnet50_vicreg_ce,QPM_species-B,mmd: 100%|██████████| 63/63 [00:02<00:00, 22.63it/s]
resnet50_vicreg_ce,QPM2_species-A,mmd: 100%|██████████| 21/21 [00:00<00:00, 22.53it/s]
resnet50_vicreg_ce,QPM2_species-B,mmd: 100%|██████████| 28/28 [00:01<00:00, 22.51it/s]


Experiment: resnet50_vicreg_ce_mrpp_s200_p1000


resnet50_vicreg_ce,CIFAR10-A,mrpp: 100%|██████████| 50/50 [00:07<00:00,  6.34it/s]
resnet50_vicreg_ce,CIFAR10-B,mrpp: 100%|██████████| 50/50 [00:07<00:00,  6.34it/s]
resnet50_vicreg_ce,MNIST-A,mrpp: 100%|██████████| 50/50 [00:07<00:00,  6.32it/s]
resnet50_vicreg_ce,MNIST-B,mrpp: 100%|██████████| 50/50 [00:07<00:00,  6.33it/s]
resnet50_vicreg_ce,QPM_species-A,mrpp: 100%|██████████| 378/378 [00:59<00:00,  6.31it/s]
resnet50_vicreg_ce,QPM_species-B,mrpp: 100%|██████████| 63/63 [00:09<00:00,  6.36it/s]
resnet50_vicreg_ce,QPM2_species-A,mrpp: 100%|██████████| 21/21 [00:03<00:00,  6.32it/s]
resnet50_vicreg_ce,QPM2_species-B,mrpp: 100%|██████████| 28/28 [00:04<00:00,  6.30it/s]


Experiment: resnet_ce_mse_auc_s200_p1000


resnet_ce_mse,CIFAR10-A,auc: 100%|██████████| 50/50 [01:19<00:00,  1.59s/it]
resnet_ce_mse,CIFAR10-B,auc: 100%|██████████| 50/50 [01:20<00:00,  1.61s/it]
resnet_ce_mse,MNIST-A,auc: 100%|██████████| 50/50 [01:19<00:00,  1.59s/it]
resnet_ce_mse,MNIST-B,auc: 100%|██████████| 50/50 [01:19<00:00,  1.59s/it]
resnet_ce_mse,QPM_species-A,auc:  32%|███▏      | 120/378 [03:12<06:54,  1.61s/it]