In [1]:
import os
import pickle
import torch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import Tuple, List, Dict

from cl_explain.metrics.ablation import compute_auc

In [2]:
RESULT_PATH = "/projects/leelab/cl-explainability/results"
SUPERPIXEL_ATTRIBUTION_METHODS = ["kernel_shap"]
SEED_LIST = [123, 456, 789, 42, 91]

In [3]:
def get_eval_filename(
    different_classes: bool,
    comprehensive: bool,
    corpus_size: int,
    explanation_name: str,
    foil_size: int,
    explicand_size: int,
    attribution_name: str,
    superpixel_dim: int,
    removal: str,
    blur_strength: float,
    eval_superpixel_dim: int,
    eval_foil_size: int,
    take_attribution_abs: bool,
) -> str:
    """Get eval filename."""
    if different_classes:
        eval_filename = "diff_class"
    else:
        eval_filename = "same_class"
    if comprehensive:
        eval_filename += "_comprehensive"
        
    eval_filename += "_eval_results"
    eval_filename += f"_explicand_size={explicand_size}"
    if "corpus" in explanation_name:
        eval_filename += f"_corpus_size={corpus_size}"
    if "contrastive" in explanation_name:
        eval_filename += f"_foil_size={foil_size}"
    if attribution_name in SUPERPIXEL_ATTRIBUTION_METHODS:
        eval_filename += f"_superpixel_dim={superpixel_dim}"
    eval_filename += f"_removal={removal}"
    if removal == "blurring":
        eval_filename += f"_blur_strength={blur_strength:.1f}"
    eval_filename += f"_eval_superpixel_dim={eval_superpixel_dim}"
    if not comprehensive:
        eval_filename += f"_eval_foil_size={eval_foil_size}"
    if take_attribution_abs:
        eval_filename += "_abs"
    eval_filename += ".pkl"
    return eval_filename


def get_mean_curves(outputs, curve_kind) -> Tuple[List[torch.Tensor], int]:
    available_curve_kinds = ["insertion", "deletion"]
    assert curve_kind in available_curve_kinds, (
        f"curve_kind={curve_kind} is not one of {available_curve_kinds}!"
    )
    target_list = [key for key in outputs.keys()]
    eval_name_list = (
        outputs[target_list[0]]["eval_model_names"]
        + outputs[target_list[0]]["eval_measure_names"]
    )
    eval_mean_curve_dict = {}
    for j, eval_name in enumerate(eval_name_list):
        
        curve_list = []
        num_features = None

        for target, output in outputs.items():
            target_curve_list = (
                output[f"model_{curve_kind}_curves"]
                + output[f"measure_{curve_kind}_curves"]
            )
            curve_list.append(target_curve_list[j])
            num_features = output[f"{curve_kind}_num_features"]
        
        curves = torch.cat(curve_list)
        mean_curve = curves.mean(dim=0).cpu()
        eval_mean_curve_dict[eval_name] = mean_curve
        
    return eval_mean_curve_dict, num_features


def get_auc_stats(
    dataset: str,
    encoder: str,
    attribution: str,
    eval_name: str,
    normalize_similarity: bool,
    different_classes: bool,
    comprehensive: bool = False,
    explicand_size: int = 25,
    removal: str = "blurring",
    blur_strength: float = 5.0,
    superpixel_dim: int = 1,
    eval_superpixel_dim: int = 1,
    foil_size: int = 1500,
    corpus_size: int = 100,
    eval_foil_size: int = 1500,
    take_attribution_abs: bool = False,
) -> Dict[str, Dict[str, List]]:
    if attribution == "random_baseline":
        explanation_list = ["self_weighted"]
    else:
        explanation_list = [  # Make sure to order this way.
            "self_weighted",
            "contrastive_self_weighted",
            "corpus",
            "contrastive_corpus",
        ]
        
    insertion_mean_list = []
    insertion_ci_list = []
    deletion_mean_list = []
    deletion_ci_list = []

    for explanation in explanation_list:
        insertion_list = []
        deletion_list = []
        for seed in SEED_LIST:            
            eval_filename = get_eval_filename(
                different_classes=different_classes,
                comprehensive=comprehensive,
                corpus_size=corpus_size,
                explanation_name=explanation,
                foil_size=foil_size,
                explicand_size=explicand_size,
                attribution_name=attribution,
                superpixel_dim=superpixel_dim,
                removal=removal,
                blur_strength=blur_strength,
                eval_superpixel_dim=eval_superpixel_dim,
                eval_foil_size=eval_foil_size,
                take_attribution_abs=take_attribution_abs,
            )

            if normalize_similarity:
                method_name = "normalized"
            else:
                method_name = "unnormalized"
            method_name += f"_{explanation}_{attribution}"
            with open(
                os.path.join(
                    RESULT_PATH,
                    dataset,
                    encoder,
                    method_name,
                    f"{seed}",
                    eval_filename,
                ),
                "rb",
            ) as handle:
                outputs = pickle.load(handle)
            insertion_curve_dict, insertion_num_features = get_mean_curves(
                outputs, "insertion"
            )
            deletion_curve_dict, deletion_num_features = get_mean_curves(
                outputs, "deletion"
            )
            insertion_list.append(
                compute_auc(
                    curve=insertion_curve_dict[eval_name],
                    num_features=insertion_num_features,
                )
            )
            deletion_list.append(
                compute_auc(
                    curve=deletion_curve_dict[eval_name],
                    num_features=deletion_num_features,
                )
            )
        insertion_mean_list.append(np.mean(insertion_list))
        insertion_ci_list.append(1.96 * np.std(insertion_list) / np.sqrt(len(SEED_LIST)))
        deletion_mean_list.append(np.mean(deletion_list))
        deletion_ci_list.append(1.96 * np.std(deletion_list) / np.sqrt(len(SEED_LIST)))
    return {
        "insertion": {"mean": insertion_mean_list, "ci": insertion_ci_list},
        "deletion": {"mean": deletion_mean_list, "ci": deletion_ci_list},
    }


def get_formatted_aucs(
    insertion_direction: str,
    deletion_direction: str,
    bold_best: bool = True,
    **kwargs,
):
    auc_stats = get_auc_stats(**kwargs)
    
    insertion_mean_list = auc_stats["insertion"]["mean"]
    insertion_ci_list = auc_stats["insertion"]["ci"]
    if insertion_direction == "max":
        insertion_best_idx = np.argmax(insertion_mean_list)
    elif insertion_direction == "min":
        insertion_best_idx = np.argmin(insertion_mean_list)
    else:
        raise ValueError(
            f"insertion_direction={insertion_direction} should be max or min!"
        )
        
    deletion_mean_list = auc_stats["deletion"]["mean"]
    deletion_ci_list = auc_stats["deletion"]["ci"]
    if deletion_direction == "max":
        deletion_best_idx = np.argmax(deletion_mean_list)
    elif deletion_direction == "min":
        deletion_best_idx = np.argmin(deletion_mean_list)
    else:
        raise ValueError(
            f"deietion_direction={deietion_direction} should be max or min!"
        )
    
    text_list = []
    for i in range(len(insertion_mean_list)):
        insertion_mean = insertion_mean_list[i]
        insertion_ci = insertion_ci_list[i]
        if insertion_mean < 0.01:
            insertion_text = (
                "{:.2e}".format(insertion_mean)
                + " $\pm$ "
                + "{:.2e}".format(insertion_ci)
            )
        else:
            insertion_text = f"{insertion_mean:.3f} $\pm$ {insertion_ci:.3f}"
        if i == insertion_best_idx and bold_best:
            insertion_text = "\\textbf{" + insertion_text + "}"
            
        deletion_mean = deletion_mean_list[i]
        deletion_ci = deletion_ci_list[i]
        if deletion_mean < 0.01:
            deletion_text = (
                "{:.2e}".format(deletion_mean)
                + " $\pm$ "
                + "{:.2e}".format(deletion_ci)
            )
        else:
            deletion_text = f"{deletion_mean:.3f} $\pm$ {deletion_ci:.3f}"
        if i == deletion_best_idx and bold_best:
            deletion_text = "\\textbf{" + deletion_text + "}"
            
        text = insertion_text + " & " + deletion_text
        text_list.append(text)
    return text_list

In [4]:
def print_aucs(
    eval_name: str,
    normalize_similarity: bool,
    different_classes: bool,
    insertion_direction: str = "max",
    deletion_direction: str = "min",
    comprehensive: bool = False,
):
    attribution_list = ["int_grad", "gradient_shap", "rise"]
    dataset_encoder_combos = [
        ("imagenet", "simclr_x1"),
        ("cifar", "simsiam_18"),
        ("mura", "classifier_18"),
    ]
    for attribution in attribution_list:
        print(attribution)
        print("-" * len(attribution))
        label_free_text = "Label-Free"
        contrastive_label_free_text = "Contrastive Label-Free"
        corpus_text = "Corpus"
        cocoa_text = "COCOA"
        for dataset_encoder in dataset_encoder_combos:
            text_list = get_formatted_aucs(
                insertion_direction=insertion_direction,
                deletion_direction=deletion_direction,
                dataset=dataset_encoder[0],
                encoder=dataset_encoder[1],
                attribution=attribution,
                eval_name=eval_name,
                normalize_similarity=normalize_similarity,
                different_classes=different_classes,
                comprehensive=comprehensive,
            )
            label_free_text += f" & {text_list[0]}"
            contrastive_label_free_text += f" & {text_list[1]}"
            corpus_text += f" & {text_list[2]}"
            cocoa_text += f" & {text_list[3]}"
        print("& " + label_free_text + " \\\\")
        print("& " + contrastive_label_free_text + " \\\\")
        print("& " + corpus_text + " \\\\")
        print("& " + cocoa_text + " \\\\")
        print("")

    print("random")
    print("------")
    random_text = "None"
    for dataset_encoder in dataset_encoder_combos:
        text_list = get_formatted_aucs(
            insertion_direction=insertion_direction,
            deletion_direction=deletion_direction,
            bold_best=False,
            dataset=dataset_encoder[0],
            encoder=dataset_encoder[1],
            attribution="random_baseline",
            eval_name=eval_name,
            normalize_similarity=True,  # Does not matter for random baseline.
            different_classes=different_classes,
            comprehensive=comprehensive,
        )
        random_text += f" & {text_list[0]}"
    print("& " + random_text + " \\\\")

## Corpus Majority Probability (Cosine Similarity & Same Class)

In [5]:
print_aucs(
    eval_name="corpus_majority_prob",
    normalize_similarity=True,
    different_classes=False,
)

int_grad
--------
& Label-Free & 0.362 $\pm$ 0.005 & 0.136 $\pm$ 0.004 & \textbf{0.403 $\pm$ 0.010} & 0.249 $\pm$ 0.007 & 0.631 $\pm$ 0.040 & 0.513 $\pm$ 0.027 \\
& Contrastive Label-Free & 0.377 $\pm$ 0.005 & 0.125 $\pm$ 0.003 & 0.401 $\pm$ 0.011 & 0.243 $\pm$ 0.007 & 0.690 $\pm$ 0.019 & 0.453 $\pm$ 0.025 \\
& Corpus & 0.377 $\pm$ 0.005 & 0.147 $\pm$ 0.002 & 0.355 $\pm$ 0.014 & 0.249 $\pm$ 0.010 & 0.653 $\pm$ 0.014 & 0.500 $\pm$ 0.039 \\
& COCOA & \textbf{0.422 $\pm$ 0.006} & \textbf{0.119 $\pm$ 0.003} & 0.386 $\pm$ 0.012 & \textbf{0.230 $\pm$ 0.011} & \textbf{0.807 $\pm$ 0.013} & \textbf{0.330 $\pm$ 0.030} \\

gradient_shap
-------------
& Label-Free & 0.409 $\pm$ 0.004 & 0.131 $\pm$ 0.001 & 0.500 $\pm$ 0.008 & 0.244 $\pm$ 0.013 & 0.691 $\pm$ 0.038 & 0.523 $\pm$ 0.033 \\
& Contrastive Label-Free & 0.411 $\pm$ 0.003 & 0.127 $\pm$ 0.002 & 0.500 $\pm$ 0.009 & 0.238 $\pm$ 0.012 & 0.697 $\pm$ 0.037 & 0.510 $\pm$ 0.018 \\
& Corpus & 0.421 $\pm$ 0.006 & 0.136 $\pm$ 0.001 & 0.478 $\pm$ 0.008

## Contrastive Corpus Similarity (Cosine Similarity & Same Class)

In [6]:
print_aucs(
    eval_name="contrastive_corpus_cosine_similarity",
    normalize_similarity=True,
    different_classes=False,
    comprehensive=True,
)

int_grad
--------
& Label-Free & 0.154 $\pm$ 0.003 & 0.076 $\pm$ 0.001 & 0.037 $\pm$ 0.002 & 0.021 $\pm$ 0.002 & 0.046 $\pm$ 0.010 & 9.77e-03 $\pm$ 1.15e-02 \\
& Contrastive Label-Free & 0.161 $\pm$ 0.003 & 0.071 $\pm$ 0.002 & 0.036 $\pm$ 0.003 & 0.020 $\pm$ 0.002 & 0.058 $\pm$ 0.007 & -4.26e-03 $\pm$ 1.20e-02 \\
& Corpus & 0.157 $\pm$ 0.002 & 0.081 $\pm$ 0.002 & 0.033 $\pm$ 0.002 & 0.020 $\pm$ 0.002 & 0.051 $\pm$ 0.007 & 9.39e-03 $\pm$ 1.30e-02 \\
& COCOA & \textbf{0.172 $\pm$ 0.002} & \textbf{0.067 $\pm$ 0.002} & \textbf{0.037 $\pm$ 0.002} & \textbf{0.018 $\pm$ 0.002} & \textbf{0.091 $\pm$ 0.006} & \textbf{-3.62e-02 $\pm$ 1.26e-02} \\

gradient_shap
-------------
& Label-Free & 0.171 $\pm$ 0.004 & 0.067 $\pm$ 0.001 & 0.048 $\pm$ 0.002 & 0.019 $\pm$ 0.002 & 0.053 $\pm$ 0.013 & 0.017 $\pm$ 0.010 \\
& Contrastive Label-Free & 0.173 $\pm$ 0.004 & 0.064 $\pm$ 0.001 & 0.048 $\pm$ 0.002 & 0.019 $\pm$ 0.002 & 0.056 $\pm$ 0.013 & 0.011 $\pm$ 0.007 \\
& Corpus & 0.172 $\pm$ 0.004 & 0.071 $\pm$

## Corpus Majority Probability (Cosine Similarity & Different Classes)

In [7]:
print_aucs(
    eval_name="corpus_majority_prob",
    normalize_similarity=True,
    different_classes=True,
)

int_grad
--------
& Label-Free & 3.53e-04 $\pm$ 1.06e-04 & 4.44e-04 $\pm$ 1.01e-04 & 0.061 $\pm$ 0.004 & 0.079 $\pm$ 0.004 & 0.394 $\pm$ 0.028 & 0.481 $\pm$ 0.031 \\
& Contrastive Label-Free & 3.36e-04 $\pm$ 1.13e-04 & 4.44e-04 $\pm$ 1.00e-04 & 0.062 $\pm$ 0.004 & 0.082 $\pm$ 0.004 & 0.354 $\pm$ 0.022 & 0.518 $\pm$ 0.027 \\
& Corpus & 1.09e-03 $\pm$ 2.85e-04 & 2.26e-04 $\pm$ 5.31e-05 & 0.094 $\pm$ 0.003 & 0.066 $\pm$ 0.003 & 0.609 $\pm$ 0.017 & 0.262 $\pm$ 0.029 \\
& COCOA & \textbf{1.69e-03 $\pm$ 5.07e-04} & \textbf{1.55e-04 $\pm$ 2.21e-05} & \textbf{0.099 $\pm$ 0.004} & \textbf{0.059 $\pm$ 0.005} & \textbf{0.647 $\pm$ 0.017} & \textbf{0.213 $\pm$ 0.030} \\

gradient_shap
-------------
& Label-Free & 2.05e-04 $\pm$ 5.96e-05 & 5.55e-04 $\pm$ 1.11e-04 & 0.054 $\pm$ 0.004 & 0.080 $\pm$ 0.004 & 0.362 $\pm$ 0.031 & 0.469 $\pm$ 0.021 \\
& Contrastive Label-Free & 2.02e-04 $\pm$ 5.06e-05 & 5.10e-04 $\pm$ 9.65e-05 & 0.053 $\pm$ 0.002 & 0.080 $\pm$ 0.004 & 0.361 $\pm$ 0.022 & 0.477 $\pm$ 0.018

## Contrastive Corpus Similarity (Cosine Similarity & Different Classes)

In [8]:
print_aucs(
    eval_name="contrastive_corpus_cosine_similarity",
    normalize_similarity=True,
    different_classes=True,
    comprehensive=True,
)

int_grad
--------
& Label-Free & -1.10e-02 $\pm$ 1.15e-03 & -1.67e-02 $\pm$ 1.04e-03 & -5.34e-03 $\pm$ 1.46e-03 & -3.32e-03 $\pm$ 1.36e-03 & -1.98e-02 $\pm$ 8.69e-03 & -1.50e-04 $\pm$ 8.78e-03 \\
& Contrastive Label-Free & -1.06e-02 $\pm$ 1.13e-03 & -1.71e-02 $\pm$ 1.07e-03 & -5.37e-03 $\pm$ 1.50e-03 & -3.21e-03 $\pm$ 1.36e-03 & -3.47e-02 $\pm$ 5.42e-03 & 0.012 $\pm$ 0.009 \\
& Corpus & -4.94e-03 $\pm$ 1.38e-03 & -1.96e-02 $\pm$ 7.58e-04 & -2.94e-04 $\pm$ 1.78e-03 & -6.37e-03 $\pm$ 1.47e-03 & 0.040 $\pm$ 0.008 & -6.18e-02 $\pm$ 9.26e-03 \\
& COCOA & \textbf{5.24e-03 $\pm$ 1.50e-03} & \textbf{-2.84e-02 $\pm$ 9.65e-04} & \textbf{1.47e-03 $\pm$ 1.74e-03} & \textbf{-8.79e-03 $\pm$ 1.50e-03} & \textbf{0.048 $\pm$ 0.006} & \textbf{-7.57e-02 $\pm$ 8.75e-03} \\

gradient_shap
-------------
& Label-Free & -9.20e-03 $\pm$ 1.33e-03 & -1.65e-02 $\pm$ 1.13e-03 & -6.27e-03 $\pm$ 1.62e-03 & -3.13e-03 $\pm$ 1.57e-03 & -3.72e-02 $\pm$ 7.90e-03 & 3.50e-03 $\pm$ 1.03e-02 \\
& Contrastive Label-Free & -8.

## Corpus Majority Probability (Dot Product & Same Class)

In [9]:
print_aucs(
    eval_name="corpus_majority_prob",
    normalize_similarity=False,
    different_classes=False,
)

int_grad
--------
& Label-Free & 0.364 $\pm$ 0.004 & 0.125 $\pm$ 0.003 & 0.343 $\pm$ 0.014 & 0.237 $\pm$ 0.012 & 0.696 $\pm$ 0.020 & 0.451 $\pm$ 0.018 \\
& Contrastive Label-Free & 0.372 $\pm$ 0.005 & 0.119 $\pm$ 0.003 & 0.354 $\pm$ 0.017 & 0.234 $\pm$ 0.012 & 0.690 $\pm$ 0.022 & 0.452 $\pm$ 0.029 \\
& Corpus & 0.384 $\pm$ 0.004 & 0.126 $\pm$ 0.002 & 0.346 $\pm$ 0.012 & 0.240 $\pm$ 0.010 & 0.767 $\pm$ 0.010 & 0.392 $\pm$ 0.018 \\
& COCOA & \textbf{0.415 $\pm$ 0.006} & \textbf{0.113 $\pm$ 0.002} & \textbf{0.379 $\pm$ 0.012} & \textbf{0.222 $\pm$ 0.009} & \textbf{0.806 $\pm$ 0.012} & \textbf{0.325 $\pm$ 0.031} \\

gradient_shap
-------------
& Label-Free & 0.408 $\pm$ 0.004 & 0.130 $\pm$ 0.002 & 0.475 $\pm$ 0.011 & 0.235 $\pm$ 0.013 & 0.699 $\pm$ 0.040 & 0.515 $\pm$ 0.018 \\
& Contrastive Label-Free & 0.413 $\pm$ 0.003 & 0.126 $\pm$ 0.002 & 0.487 $\pm$ 0.011 & 0.229 $\pm$ 0.012 & 0.700 $\pm$ 0.041 & 0.510 $\pm$ 0.019 \\
& Corpus & 0.423 $\pm$ 0.004 & 0.130 $\pm$ 0.002 & 0.444 $\pm$ 0.011

## Contrastive Corpus Similarity (Dot Product & Same Class)

In [10]:
print_aucs(
    eval_name="contrastive_corpus_dot_product_similarity",
    normalize_similarity=False,
    different_classes=False,
    comprehensive=True,
)

int_grad
--------
& Label-Free & 34.890 $\pm$ 0.940 & 12.226 $\pm$ 0.397 & 5.90e-04 $\pm$ 2.13e-05 & 3.54e-04 $\pm$ 1.70e-05 & 26.227 $\pm$ 3.234 & 0.667 $\pm$ 4.312 \\
& Contrastive Label-Free & 35.551 $\pm$ 0.965 & 11.865 $\pm$ 0.399 & 6.10e-04 $\pm$ 2.44e-05 & 3.46e-04 $\pm$ 1.85e-05 & 25.672 $\pm$ 3.828 & 0.213 $\pm$ 5.267 \\
& Corpus & 35.390 $\pm$ 0.922 & 12.306 $\pm$ 0.413 & 5.94e-04 $\pm$ 1.81e-05 & 3.52e-04 $\pm$ 1.62e-05 & 33.620 $\pm$ 4.638 & -4.67e+00 $\pm$ 3.56e+00 \\
& COCOA & \textbf{36.824 $\pm$ 0.917} & \textbf{11.470 $\pm$ 0.429} & \textbf{6.68e-04 $\pm$ 1.60e-05} & \textbf{3.09e-04 $\pm$ 1.53e-05} & \textbf{36.534 $\pm$ 3.802} & \textbf{-1.09e+01 $\pm$ 4.78e+00} \\

gradient_shap
-------------
& Label-Free & 38.328 $\pm$ 1.479 & 11.689 $\pm$ 0.179 & 8.92e-04 $\pm$ 3.05e-05 & 3.37e-04 $\pm$ 1.85e-05 & 21.137 $\pm$ 3.693 & 5.585 $\pm$ 3.576 \\
& Contrastive Label-Free & 38.754 $\pm$ 1.497 & 11.383 $\pm$ 0.167 & 9.22e-04 $\pm$ 2.78e-05 & 3.22e-04 $\pm$ 2.15e-05 & 21.496

## Corpus Majority Probability (Dot Product & Different Classes)

In [11]:
print_aucs(
    eval_name="corpus_majority_prob",
    normalize_similarity=False,
    different_classes=True,
)

int_grad
--------
& Label-Free & 2.74e-04 $\pm$ 9.69e-05 & 5.42e-04 $\pm$ 1.12e-04 & 0.066 $\pm$ 0.003 & 0.078 $\pm$ 0.004 & 0.352 $\pm$ 0.024 & 0.522 $\pm$ 0.029 \\
& Contrastive Label-Free & 2.75e-04 $\pm$ 1.00e-04 & 5.00e-04 $\pm$ 1.07e-04 & 0.064 $\pm$ 0.004 & 0.078 $\pm$ 0.004 & 0.347 $\pm$ 0.024 & 0.529 $\pm$ 0.023 \\
& Corpus & 5.60e-04 $\pm$ 2.54e-04 & 4.13e-04 $\pm$ 8.00e-05 & 0.079 $\pm$ 0.004 & 0.070 $\pm$ 0.004 & 0.529 $\pm$ 0.018 & 0.381 $\pm$ 0.026 \\
& COCOA & \textbf{1.58e-03 $\pm$ 4.99e-04} & \textbf{1.64e-04 $\pm$ 2.97e-05} & \textbf{0.100 $\pm$ 0.005} & \textbf{0.060 $\pm$ 0.005} & \textbf{0.657 $\pm$ 0.017} & \textbf{0.210 $\pm$ 0.024} \\

gradient_shap
-------------
& Label-Free & 1.88e-04 $\pm$ 4.25e-05 & 6.01e-04 $\pm$ 1.13e-04 & 0.055 $\pm$ 0.005 & 0.082 $\pm$ 0.006 & 0.343 $\pm$ 0.016 & 0.488 $\pm$ 0.020 \\
& Contrastive Label-Free & 2.03e-04 $\pm$ 3.73e-05 & 5.30e-04 $\pm$ 7.78e-05 & 0.053 $\pm$ 0.004 & 0.082 $\pm$ 0.005 & 0.353 $\pm$ 0.019 & 0.478 $\pm$ 0.019

## Contrastive Corpus Similarity (Dot Product & Different Classes)

In [12]:
print_aucs(
    eval_name="contrastive_corpus_dot_product_similarity",
    normalize_similarity=False,
    different_classes=True,
    comprehensive=True,
)

int_grad
--------
& Label-Free & -1.89e+00 $\pm$ 1.86e-01 & -1.77e+00 $\pm$ 1.25e-01 & -7.55e-05 $\pm$ 4.52e-05 & -4.71e-05 $\pm$ 4.61e-05 & -1.33e+01 $\pm$ 2.46e+00 & 7.604 $\pm$ 3.695 \\
& Contrastive Label-Free & -1.78e+00 $\pm$ 1.77e-01 & -1.90e+00 $\pm$ 1.31e-01 & -7.93e-05 $\pm$ 4.80e-05 & -4.74e-05 $\pm$ 4.57e-05 & -1.22e+01 $\pm$ 1.89e+00 & 7.106 $\pm$ 3.249 \\
& Corpus & -8.71e-01 $\pm$ 1.87e-01 & -2.25e+00 $\pm$ 1.09e-01 & -3.32e-05 $\pm$ 4.68e-05 & -7.55e-05 $\pm$ 4.34e-05 & 5.922 $\pm$ 3.119 & -7.35e+00 $\pm$ 3.00e+00 \\
& COCOA & \textbf{1.525 $\pm$ 0.224} & \textbf{-4.54e+00 $\pm$ 1.24e-01} & \textbf{3.69e-05 $\pm$ 5.23e-05} & \textbf{-1.41e-04 $\pm$ 4.58e-05} & \textbf{19.365 $\pm$ 2.867} & \textbf{-2.48e+01 $\pm$ 2.37e+00} \\

gradient_shap
-------------
& Label-Free & -1.34e+00 $\pm$ 2.19e-01 & -1.99e+00 $\pm$ 9.53e-02 & -1.03e-04 $\pm$ 5.27e-05 & -4.11e-05 $\pm$ 4.69e-05 & -1.16e+01 $\pm$ 1.91e+00 & 3.784 $\pm$ 2.952 \\
& Contrastive Label-Free & -1.26e+00 $\pm$ 2.13e