In [None]:
import os
import pickle
import torch

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from typing import Tuple, List, Dict

from cl_explain.metrics.ablation import compute_auc

In [None]:
# Root directory read by get_auc_stats for explanations whose name starts with
# "randomized_model_".
# NOTE(review): absolute, machine-specific path -- the notebook only runs where
# this directory exists; consider making it configurable.
RANDOMIZED_RESULT_PATH = "/homes/gws/clin25/cl-explainability/results"
# Root directory read by get_auc_stats for all other explanations.
# NOTE(review): absolute, machine-specific path as well.
ORIGINAL_RESULT_PATH = "/projects/leelab/cl-explainability/results"
# Attribution methods for which get_eval_filename adds a "_superpixel_dim="
# field to the eval filename.
SUPERPIXEL_ATTRIBUTION_METHODS = ["kernel_shap"]
# Seeds iterated over by get_auc_stats; each seed is used as a subdirectory
# name when locating a result file, and the AUCs are aggregated across seeds.
SEED_LIST = [123, 456, 789, 42, 91]

In [None]:
def get_eval_filename(
    different_classes: bool,
    comprehensive: bool,
    corpus_size: int,
    explanation_name: str,
    foil_size: int,
    explicand_size: int,
    attribution_name: str,
    superpixel_dim: int,
    removal: str,
    blur_strength: float,
    eval_superpixel_dim: int,
    eval_foil_size: int,
    take_attribution_abs: bool,
) -> str:
    """Build the name of the pickle file holding evaluation results.

    The name is a fixed-order concatenation of fragments; several fragments
    are only present under certain settings.

    Args:
        different_classes: Selects the "diff_class" prefix when truthy and
            "same_class" otherwise.
        comprehensive: Adds "_comprehensive" after the prefix when truthy.
            When falsy, "_eval_foil_size=..." is added instead (later in the
            name).
        corpus_size: Included only if "corpus" occurs in ``explanation_name``.
        explanation_name: Decides whether the corpus/foil size fields appear.
        foil_size: Included only if "contrastive" occurs in
            ``explanation_name``.
        explicand_size: Always included.
        attribution_name: If listed in ``SUPERPIXEL_ATTRIBUTION_METHODS``,
            ``superpixel_dim`` is included.
        superpixel_dim: See ``attribution_name``.
        removal: Always included.
        blur_strength: Included (one decimal place) only when ``removal`` is
            "blurring"; it is not formatted otherwise.
        eval_superpixel_dim: Always included.
        eval_foil_size: Included only when ``comprehensive`` is falsy.
        take_attribution_abs: Adds the "_abs" suffix when truthy.

    Returns:
        The filename, ending in ".pkl".
    """
    # Collect fragments in order and join once at the end.
    parts = ["diff_class" if different_classes else "same_class"]
    if comprehensive:
        parts.append("_comprehensive")

    parts.append("_eval_results")
    parts.append(f"_explicand_size={explicand_size}")

    # Size fields depend on which kind of explanation was evaluated.
    if "corpus" in explanation_name:
        parts.append(f"_corpus_size={corpus_size}")
    if "contrastive" in explanation_name:
        parts.append(f"_foil_size={foil_size}")

    if attribution_name in SUPERPIXEL_ATTRIBUTION_METHODS:
        parts.append(f"_superpixel_dim={superpixel_dim}")

    parts.append(f"_removal={removal}")
    # Format blur_strength lazily so a non-numeric value is harmless for
    # other removal modes.
    if removal == "blurring":
        parts.append(f"_blur_strength={blur_strength:.1f}")

    parts.append(f"_eval_superpixel_dim={eval_superpixel_dim}")
    if not comprehensive:
        parts.append(f"_eval_foil_size={eval_foil_size}")
    if take_attribution_abs:
        parts.append("_abs")

    parts.append(".pkl")
    return "".join(parts)


def get_mean_curves(outputs, curve_kind) -> Tuple[Dict[str, torch.Tensor], int]:
    """Average insertion or deletion curves over all targets, per eval name.

    Args:
        outputs: Mapping from target to a result dict. Each result dict is
            read for the keys "eval_model_names", "eval_measure_names",
            f"model_{curve_kind}_curves", f"measure_{curve_kind}_curves" and
            f"{curve_kind}_num_features". The two "*_curves" entries are
            lists that get concatenated with `+`, ordered like the eval names.
            Presumably each curve tensor is (num_explicands, num_steps) --
            TODO confirm against the code that writes the results.
        curve_kind: Either "insertion" or "deletion".

    Returns:
        A tuple ``(mean_curves, num_features)``. ``mean_curves`` maps each
        eval name (model names first, then measure names, taken from the
        first target) to the curves of all targets concatenated along the
        first dimension, averaged over that dimension and moved to CPU.
        ``num_features`` is the f"{curve_kind}_num_features" entry of the
        last target.

    Raises:
        AssertionError: If ``curve_kind`` is not "insertion" or "deletion".
        ValueError: If ``outputs`` is empty.
    """
    available_curve_kinds = ["insertion", "deletion"]
    assert curve_kind in available_curve_kinds, (
        f"curve_kind={curve_kind} is not one of {available_curve_kinds}!"
    )
    if not outputs:
        raise ValueError("outputs is empty; there are no curves to average!")

    output_list = list(outputs.values())
    first_output = output_list[0]
    eval_name_list = (
        first_output["eval_model_names"] + first_output["eval_measure_names"]
    )

    # Build each target's combined curve list once, not once per eval name.
    target_curve_lists = [
        output[f"model_{curve_kind}_curves"] + output[f"measure_{curve_kind}_curves"]
        for output in output_list
    ]
    # Bound outside the loop so it is defined even when there are no eval
    # names; the value of the last target is kept.
    num_features = output_list[-1][f"{curve_kind}_num_features"]

    eval_mean_curve_dict = {}
    for j, eval_name in enumerate(eval_name_list):
        curves = torch.cat([curve_list[j] for curve_list in target_curve_lists])
        eval_mean_curve_dict[eval_name] = curves.mean(dim=0).cpu()

    return eval_mean_curve_dict, num_features


def get_auc_stats(
    dataset: str,
    encoder: str,
    attribution: str,
    eval_name: str,
    normalize_similarity: bool,
    different_classes: bool,
    comprehensive: bool = False,
    explicand_size: int = 25,
    removal: str = "blurring",
    blur_strength: float = 5.0,
    superpixel_dim: int = 1,
    eval_superpixel_dim: int = 1,
    foil_size: int = 1500,
    corpus_size: int = 100,
    eval_foil_size: int = 1500,
    take_attribution_abs: bool = False,
) -> Dict[str, Dict[str, List]]:
    """Load per-seed eval results and summarise insertion/deletion AUCs.

    For every explanation (see below) and every seed in ``SEED_LIST``, the
    pickled results at
    ``<root>/<dataset>/<encoder>/<method>/<seed>/<eval filename>`` are loaded,
    mean curves are computed with ``get_mean_curves`` and turned into one AUC
    per curve kind with ``compute_auc``. The AUCs are then aggregated across
    seeds.

    Explanations considered:
        * ``attribution == "random_baseline"``: only "self_weighted".
        * otherwise: "contrastive_corpus" read from ``ORIGINAL_RESULT_PATH``
          and "randomized_model_contrastive_corpus" read from
          ``RANDOMIZED_RESULT_PATH``, in that order.

    Args:
        dataset: Dataset directory name.
        encoder: Encoder directory name.
        attribution: Attribution method name; part of the method directory
            name and forwarded to ``get_eval_filename``.
        eval_name: Key of the mean curve (as returned by ``get_mean_curves``)
            to compute AUCs for.
        normalize_similarity: Chooses the "normalized" or "unnormalized"
            method directory prefix.
        different_classes, comprehensive, explicand_size, removal,
        blur_strength, superpixel_dim, eval_superpixel_dim, foil_size,
        corpus_size, eval_foil_size, take_attribution_abs: Forwarded to
            ``get_eval_filename``.

    Returns:
        ``{"insertion": {"mean": [...], "ci": [...]}, "deletion": {...}}``
        with one list entry per explanation. "mean" is ``np.mean`` over the
        seed AUCs; "ci" is ``1.96 * np.std(aucs) / sqrt(len(SEED_LIST))`` with
        ``np.std`` at its default arguments (presumably meant as the
        half-width of a 95% normal-approximation interval).
    """
    randomized_prefix = "randomized_model_"
    similarity_tag = "normalized" if normalize_similarity else "unnormalized"
    curve_kinds = ("insertion", "deletion")

    if attribution == "random_baseline":
        explanations = ["self_weighted"]
    else:
        explanations = [
            "contrastive_corpus",
            "randomized_model_contrastive_corpus",
        ]

    stats = {
        "insertion": {"mean": [], "ci": []},
        "deletion": {"mean": [], "ci": []},
    }

    for explanation in explanations:
        # Randomized-model results live under a different root and carry an
        # extra prefix in the method directory name.
        if explanation.startswith(randomized_prefix):
            base_explanation = explanation.replace(randomized_prefix, "")
            root_dir = RANDOMIZED_RESULT_PATH
            method_dir = randomized_prefix + similarity_tag
        else:
            base_explanation = explanation
            root_dir = ORIGINAL_RESULT_PATH
            method_dir = similarity_tag
        method_dir += f"_{base_explanation}_{attribution}"

        # Neither the method directory nor the filename depends on the seed.
        eval_file = get_eval_filename(
            different_classes=different_classes,
            comprehensive=comprehensive,
            corpus_size=corpus_size,
            explanation_name=base_explanation,
            foil_size=foil_size,
            explicand_size=explicand_size,
            attribution_name=attribution,
            superpixel_dim=superpixel_dim,
            removal=removal,
            blur_strength=blur_strength,
            eval_superpixel_dim=eval_superpixel_dim,
            eval_foil_size=eval_foil_size,
            take_attribution_abs=take_attribution_abs,
        )

        seed_aucs = {kind: [] for kind in curve_kinds}
        for seed in SEED_LIST:
            result_file = os.path.join(
                root_dir, dataset, encoder, method_dir, f"{seed}", eval_file
            )
            # NOTE: pickle.load can execute arbitrary code embedded in the
            # file; only point this at trusted result files.
            with open(result_file, "rb") as handle:
                outputs = pickle.load(handle)

            # Compute both sets of mean curves first, then both AUCs.
            curves_by_kind = {
                kind: get_mean_curves(outputs, kind) for kind in curve_kinds
            }
            for kind, (mean_curves, num_features) in curves_by_kind.items():
                seed_aucs[kind].append(
                    compute_auc(
                        curve=mean_curves[eval_name],
                        num_features=num_features,
                    )
                )

        for kind in curve_kinds:
            auc_list = seed_aucs[kind]
            stats[kind]["mean"].append(np.mean(auc_list))
            stats[kind]["ci"].append(
                1.96 * np.std(auc_list) / np.sqrt(len(SEED_LIST))
            )

    return stats


def _validate_direction(arg_name: str, direction: str) -> None:
    """Raise ``ValueError`` unless ``direction`` is "max" or "min"."""
    if direction not in ("max", "min"):
        raise ValueError(f"{arg_name}={direction} should be max or min!")


def _format_auc_text(mean, ci, bold: bool) -> str:
    """Format one mean/CI pair as a LaTeX table cell, optionally in bold."""
    if np.abs(mean) < 0.01:
        # Tiny values would read as 0.000 in fixed-point, so use scientific
        # notation. The backslash is escaped explicitly ("\p" is not a valid
        # escape sequence); the resulting text is unchanged.
        text = "{:.2e}".format(mean) + " $\\pm$ " + "{:.2e}".format(ci)
    else:
        text = f"{mean:.3f} ({ci:.3f})"
    if bold:
        text = "\\textbf{" + text + "}"
    return text


def get_formatted_aucs(
    insertion_direction: str,
    deletion_direction: str,
    bold_best: bool = True,
    **kwargs,
) -> List[str]:
    """Format insertion/deletion AUC statistics as LaTeX table cells.

    Args:
        insertion_direction: "max" if a larger insertion AUC is better, "min"
            if a smaller one is.
        deletion_direction: Same, for the deletion AUC.
        bold_best: Whether to wrap the best entry of each column in
            ``\\textbf{...}``.
        **kwargs: Forwarded unchanged to ``get_auc_stats``.

    Returns:
        One string per entry of the statistics returned by ``get_auc_stats``,
        each of the form ``"<insertion cell> & <deletion cell>"``. A cell is
        ``"mean (ci)"`` with three decimals, or ``"mean $\\pm$ ci"`` in
        scientific notation when ``|mean| < 0.01``.

    Raises:
        ValueError: If either direction is not "max" or "min". Both are
            checked before any results are loaded.
    """
    # Fail fast on bad arguments before get_auc_stats reads any files.
    _validate_direction("insertion_direction", insertion_direction)
    _validate_direction("deletion_direction", deletion_direction)
    pick_best = {"max": np.argmax, "min": np.argmin}

    auc_stats = get_auc_stats(**kwargs)

    insertion_mean_list = auc_stats["insertion"]["mean"]
    insertion_ci_list = auc_stats["insertion"]["ci"]
    deletion_mean_list = auc_stats["deletion"]["mean"]
    deletion_ci_list = auc_stats["deletion"]["ci"]

    insertion_best_idx = pick_best[insertion_direction](insertion_mean_list)
    deletion_best_idx = pick_best[deletion_direction](deletion_mean_list)

    text_list = []
    for i in range(len(insertion_mean_list)):
        insertion_text = _format_auc_text(
            insertion_mean_list[i],
            insertion_ci_list[i],
            bold=bold_best and i == insertion_best_idx,
        )
        deletion_text = _format_auc_text(
            deletion_mean_list[i],
            deletion_ci_list[i],
            bold=bold_best and i == deletion_best_idx,
        )
        text_list.append(insertion_text + " & " + deletion_text)
    return text_list

In [None]:
def print_aucs(
    eval_name: str,
    normalize_similarity: bool,
    different_classes: bool,
    insertion_direction: str = "max",
    deletion_direction: str = "min",
    comprehensive: bool = True,
):
    """Print LaTeX table rows comparing trained vs. randomized model AUCs.

    For each attribution method ("int_grad", "gradient_shap", "rise") this
    prints the method name, an underline of dashes, one row labelled
    "COCOA (trained model)", one row labelled "COCOA (randomized model)" and
    a blank line. Each row holds one formatted cell pair per
    (dataset, encoder) combination, in the order imagenet/simclr_x1,
    cifar/simsiam_18, mura/classifier_18.

    Args:
        eval_name: Evaluation measure name, forwarded to
            ``get_formatted_aucs``.
        normalize_similarity: Forwarded to ``get_formatted_aucs``.
        different_classes: Forwarded to ``get_formatted_aucs``.
        insertion_direction: "max" or "min"; which insertion AUC is best.
        deletion_direction: "max" or "min"; which deletion AUC is best.
        comprehensive: Forwarded to ``get_formatted_aucs``.
    """
    attribution_methods = ("int_grad", "gradient_shap", "rise")
    dataset_encoder_pairs = (
        ("imagenet", "simclr_x1"),
        ("cifar", "simsiam_18"),
        ("mura", "classifier_18"),
    )

    for attribution in attribution_methods:
        print(attribution)
        print("-" * len(attribution))

        # Each row is its label followed by one cell pair per dataset.
        trained_cells = ["COCOA (trained model)"]
        randomized_cells = ["COCOA (randomized model)"]
        for dataset, encoder in dataset_encoder_pairs:
            formatted = get_formatted_aucs(
                insertion_direction=insertion_direction,
                deletion_direction=deletion_direction,
                dataset=dataset,
                encoder=encoder,
                attribution=attribution,
                eval_name=eval_name,
                normalize_similarity=normalize_similarity,
                different_classes=different_classes,
                comprehensive=comprehensive,
            )
            # Entry 0 is the trained model, entry 1 the randomized model.
            trained_cells.append(f"{formatted[0]}")
            randomized_cells.append(f"{formatted[1]}")

        print(" & ".join(trained_cells) + " \\\\")
        print(" & ".join(randomized_cells) + "\\\\")
        print("")

## Randomized Model Corpus Majority Probability (Cosine Similarity & Same Class)

In [None]:
print_aucs(
    eval_name="corpus_majority_prob",
    normalize_similarity=True,
    different_classes=False,
)

## Randomized Model Corpus Majority Probability (Cosine Similarity & Different Classes)

In [None]:
print_aucs(
    eval_name="corpus_majority_prob",
    normalize_similarity=True,
    different_classes=True,
)

## Randomized Model Contrastive Corpus Similarity (Cosine Similarity & Same Class)

In [None]:
print_aucs(
    eval_name="contrastive_corpus_cosine_similarity",
    normalize_similarity=True,
    different_classes=False,
    comprehensive=True,
)

## Randomized Model Contrastive Corpus Similarity (Cosine Similarity & Different Classes)

In [None]:
print_aucs(
    eval_name="contrastive_corpus_cosine_similarity",
    normalize_similarity=True,
    different_classes=True,
    comprehensive=True,
)