In [1]:
import torch
import scipy.stats
from experiments.mnli import *
from run_experiments import NUM_RETRAINING_EXPERIMENTS
from tqdm import trange



In [2]:
from experiments import constants

In [3]:
def run_retraining_main_modified(
        mode: str,
        num_examples_to_test: int):
    """Collect stored influence scores for the leading test examples.

    For every position in `range(num_examples_to_test)` the function looks up
    one entry of `CORRECT_INDICES` and one entry of `INCORRECT_INDICES`,
    loads the matching record from disk and keeps its `"influences"` value.

    Args:
        mode: One of "full", "KNN-1000", "KNN-10000" or "random".
            "full" reads one file per example. The "KNN-*" modes read one
            file per correctness class and pick the record by position.
            "random" passes validation but has no loader in this function,
            so it always produces an empty list.
        num_examples_to_test: How many leading positions of each index list
            to process.

    Returns:
        A list with the `"influences"` value of each loaded record, ordered
        as (position 0 "correct", position 0 "incorrect", position 1
        "correct", ...).

    Raises:
        ValueError: If `mode` is not recognized, or if a loaded record does
            not carry the expected example index / correctness flag.
    """
    if mode not in ["full", "KNN-1000", "KNN-10000", "random"]:
        raise ValueError(f"Unrecognized `mode` {mode}")

    knn_sizes = {"KNN-1000": 1000, "KNN-10000": 10000}
    # A KNN file depends only on `correct_mode` (and k), not on the example,
    # so each one is read from disk once and reused for every position.
    knn_records_cache = {}
    influences_dicts = []

    for example_relative_index in range(num_examples_to_test):
        for correct_mode in ["correct", "incorrect"]:
            if correct_mode == "correct":
                example_index = CORRECT_INDICES[example_relative_index]
            else:
                example_index = INCORRECT_INDICES[example_relative_index]

            if mode == "full":
                file_name = os.path.join(
                    constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR,
                    f"KNN-recall.only-{correct_mode}.50.{example_index}"
                    f".pth.g0301.ll.unc.edu")

                # NOTE(review): torch.load relies on pickle; only point this
                # at trusted experiment outputs.
                influences_dict = torch.load(file_name)
                if example_index != influences_dict["test_index"]:
                    raise ValueError(
                        f"{file_name}: expected test_index {example_index}, "
                        f"found {influences_dict['test_index']}")

                # The stored flag must be exactly `True` for "correct"
                # examples and anything else for "incorrect" ones.
                stored_as_correct = influences_dict["correct"] is True
                if stored_as_correct != (correct_mode == "correct"):
                    raise ValueError(
                        f"{file_name}: stored correctness flag "
                        f"{influences_dict['correct']!r} does not match "
                        f"`{correct_mode}`")

                influences_dicts.append(influences_dict["influences"])

            elif mode in knn_sizes:
                if correct_mode not in knn_records_cache:
                    kNN_k = knn_sizes[mode]
                    file_name = os.path.join(
                        constants.MNLI_RETRAINING_INFLUENCE_OUTPUT_BASE_DIR2,
                        f"visualization"
                        f".only-{correct_mode}"
                        f".5.mnli-mnli-None-mnli"
                        f".{kNN_k}.True.pth.g0306.ll.unc.edu")
                    # NOTE(review): torch.load relies on pickle; only point
                    # this at trusted experiment outputs.
                    knn_records_cache[correct_mode] = torch.load(file_name)

                influences_dict = (
                    knn_records_cache[correct_mode][example_relative_index])
                if example_index != influences_dict["index"]:
                    raise ValueError(
                        f"{mode}/{correct_mode} record "
                        f"{example_relative_index}: expected index "
                        f"{example_index}, found {influences_dict['index']}")

                influences_dicts.append(influences_dict["influences"])

    return influences_dicts


def compute_pairwise_correlations(d1: Dict[str, float], d2: Dict[str, float]) -> Dict[str, float]:
    """Correlate the values two mappings hold for their shared keys.

    Args:
        d1: First mapping from key to score.
        d2: Second mapping from key to score.

    Returns:
        A dict with the Pearson ("pr"), Spearman ("sr") and Kendall-tau
        ("kt") coefficients computed over the shared keys, plus the number
        of shared keys ("size").

    Side effects:
        Prints the three coefficients on one line.

    NOTE(review): with fewer than two shared keys the scipy calls cannot
    produce a meaningful coefficient; callers are assumed to pass mappings
    that overlap.
    """
    # Walk `d1` in insertion order rather than iterating a set intersection:
    # set order depends on hashing, so the pairing order (and with it the
    # last bits of the floating-point results) could change between runs.
    joint_keys = [key for key in d1 if key in d2]
    vals_1 = [d1[key] for key in joint_keys]
    vals_2 = [d2[key] for key in joint_keys]

    pr = scipy.stats.pearsonr(vals_1, vals_2)[0]
    sr = scipy.stats.spearmanr(vals_1, vals_2)[0]
    kt = scipy.stats.kendalltau(vals_1, vals_2)[0]
    print(f"Pearson={pr:.5f} Spearman={sr:.5f} Kendall={kt:.5f}")
    return {"pr": pr, "sr": sr, "kt": kt, "size": len(joint_keys)}

In [4]:
# Load the stored influence scores once per retrieval mode, keyed by mode name.
influences_dicts_all = {
    mode_name: run_retraining_main_modified(mode_name, NUM_RETRAINING_EXPERIMENTS)
    for mode_name in ("full", "KNN-1000", "KNN-10000")
}

In [5]:
# Sanity check: report how many influence dicts were loaded for each mode.
for mode_name, loaded_dicts in influences_dicts_all.items():
    print(mode_name, len(loaded_dicts))

full 6
KNN-1000 6
KNN-10000 6


In [6]:
# Which two modes each comparison correlates, as (d1 mode, d2 mode).
comparison_modes = {
    "knn1000_full_corr": ("KNN-1000", "full"),
    "knn10000_full_corr": ("KNN-10000", "full"),
    "knn1000_knn10000_corr": ("KNN-1000", "KNN-10000"),
}
correlations = {comparison_name: [] for comparison_name in comparison_modes}

# The per-mode lists are assumed to be aligned position by position, so the
# same `index` refers to the same test example in every mode.
for index in range(len(influences_dicts_all["full"])):
    for comparison_name, (first_mode, second_mode) in comparison_modes.items():
        pair_result = compute_pairwise_correlations(
            d1=influences_dicts_all[first_mode][index],
            d2=influences_dicts_all[second_mode][index])
        correlations[comparison_name].append(pair_result)

Pearson=0.99968 Spearman=0.99969 Kendall=0.98631
Pearson=0.99903 Spearman=0.99929 Kendall=0.97830
Pearson=0.99961 Spearman=0.99959 Kendall=0.98427
Pearson=0.99931 Spearman=0.99931 Kendall=0.97946
Pearson=0.99873 Spearman=0.99886 Kendall=0.97553
Pearson=0.99787 Spearman=0.99794 Kendall=0.96331
Pearson=0.99988 Spearman=0.99970 Kendall=0.98775
Pearson=0.99990 Spearman=0.99994 Kendall=0.99397
Pearson=0.99986 Spearman=0.99971 Kendall=0.98830
Pearson=0.99823 Spearman=0.98878 Kendall=0.93757
Pearson=0.99989 Spearman=0.99910 Kendall=0.98274
Pearson=0.99817 Spearman=0.98993 Kendall=0.94083
Pearson=0.99925 Spearman=0.99916 Kendall=0.97694
Pearson=0.99918 Spearman=0.99906 Kendall=0.97771
Pearson=0.99929 Spearman=0.99938 Kendall=0.98043
Pearson=0.99939 Spearman=0.99878 Kendall=0.97566
Pearson=0.99910 Spearman=0.99710 Kendall=0.96882
Pearson=0.99976 Spearman=0.99947 Kendall=0.98334


In [7]:
# Print one row per comparison: mean/std (in percent) of the Pearson, Spearman
# and Kendall coefficients, followed by mean/std of the shared-key count.

# Each experiment contributes one "correct" and one "incorrect" example, so
# every comparison must hold exactly two entries per experiment.
expected_num_entries = 2 * NUM_RETRAINING_EXPERIMENTS

for comparison, comparison_results in correlations.items():
    # Validate before printing so a failure does not leave a partial row.
    if len(comparison_results) != expected_num_entries:
        raise ValueError(
            f"{comparison}: expected {expected_num_entries} entries, "
            f"found {len(comparison_results)}")

    print(f"{comparison:<25}", end=" ")
    for metric_name in ["pr", "sr", "kt"]:
        corr_values = [c[metric_name] * 100 for c in comparison_results]
        corr_mean = np.mean(corr_values)
        corr_std = np.std(corr_values)
        print(f"{corr_mean:.1f}/{corr_std:.2f}", end=" ")

    sizes = [c["size"] for c in comparison_results]
    size_mean = np.mean(sizes)
    size_std = np.std(sizes)
    print(f"{size_mean:.1f}/{size_std:.2f}", end=" ")

    print()

knn1000_full_corr         99.9/0.05 99.8/0.39 97.4/1.69 1000.0/0.00 
knn10000_full_corr        99.9/0.04 99.9/0.09 98.0/0.77 10000.0/0.00 
knn1000_knn10000_corr     99.9/0.08 99.8/0.35 97.3/1.66 1000.0/0.00 
