In [None]:
import sys
sys.path.append("../src")
import pickle
import scipy
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from data import preprocess_dataset

In [None]:
def _finalise_mem_map(grid, legend_title=None):
    """Decorate a mem-map joint plot on its joint axes and show it.

    Draws the dashed y = x diagonal, rebuilds the legend (optionally titled),
    clamps both axes to [0, 1] and sets the axis labels.

    Args:
        grid: The object returned by ``sns.jointplot`` (exposes ``ax_joint``).
        legend_title: Optional legend title; ``None`` leaves the legend untitled.
    """
    joint_ax = grid.ax_joint
    # Target the joint axes explicitly rather than relying on pyplot's
    # "current axes" state. Points on this diagonal have equal mean
    # train and held-out probability.
    sns.lineplot(x=[0, 1], y=[0, 1], color="black", linestyle='--', ax=joint_ax)
    handles, labels = joint_ax.get_legend_handles_labels()
    joint_ax.legend(handles, labels, title=legend_title,
                    frameon=False, fontsize=13, title_fontsize=13)
    joint_ax.set_xlim(0, 1)
    joint_ax.set_ylim(0, 1)
    joint_ax.set_xlabel("training memorisation")
    joint_ax.set_ylabel("generalisation score")
    plt.show()


def store_mmap(dataset_name, model_name):
    """Plot memorisation maps for one dataset/model pair and store generalisation scores.

    Loads ``../checkpoints/{dataset_name}/mmap_{dataset_name}_{model_name}.pickle``,
    and for every example index averages the softmax probability assigned to its
    gold label (``dataset["train"][idx]["labels"]``) separately over the entries
    where the index appears in the train split and in the test split. Examples
    seen in both splits are shown in two joint plots (coloured by the rounded,
    non-negative train-minus-test gap, and by the category found in ``idxs``).
    The per-example mean test probability is written to
    ``../checkpoints/{dataset_name}/genscore_{dataset_name}_{model_name}.pickle``.

    Args:
        dataset_name: Dataset identifier passed to ``preprocess_dataset`` and
            used in the checkpoint paths.
        model_name: Model identifier used in the checkpoint file names.

    Returns:
        None. Side effects: two figures are shown and one pickle file is written.
    """
    dataset, idxs = preprocess_dataset(dataset_name)
    checkpoint_dir = f"../checkpoints/{dataset_name}"

    # SECURITY: unpickling can execute arbitrary code -- only load checkpoints
    # produced by a trusted run of this project.
    with open(f"{checkpoint_dir}/mmap_{dataset_name}_{model_name}.pickle", 'rb') as handle:
        data = pickle.load(handle)

    train_prob = defaultdict(list)
    test_prob = defaultdict(list)

    def gold_probability(example_idx, logits):
        """Softmax probability of the example's gold label under ``logits``."""
        # Test indices are also looked up in the "train" split, as in the
        # original code -- presumably the held-out examples are drawn from it.
        gold_label = dataset["train"][example_idx]["labels"]
        return scipy.special.softmax(logits)[gold_label]

    # NOTE(review): the last entry of `data` is deliberately left out, exactly
    # as before; confirm whether it is a sentinel / incomplete run.
    for i in range(len(data) - 1):
        train_idx, test_idx, train_logits, test_logits = data[i]
        for example_idx, logits in zip(train_idx, train_logits):
            train_prob[example_idx].append(gold_probability(example_idx, logits))
        for example_idx, logits in zip(test_idx, test_logits):
            test_prob[example_idx].append(gold_probability(example_idx, logits))

    # post process into CM scores
    x, y, hue = [], [], []
    for k in train_prob:
        if k not in test_prob:
            # Both a train-side and a test-side estimate are required.
            continue

        # The last category in `idxs` containing k wins; unmatched -> "other".
        hue_label = None
        for key in idxs:
            if k in idxs[key]:
                hue_label = key
        if hue_label is None:
            hue_label = "other"

        x.append(np.mean(train_prob[k]))
        y.append(np.mean(test_prob[k]))
        hue.append(hue_label)

    # Train-minus-test gap, rounded to one decimal and floored at zero.
    cm = [max(0, round(train_mean - test_mean, 1))
          for train_mean, test_mean in zip(x, y)]

    # visualise full mem-map
    sns.set_context("talk")
    grid = sns.jointplot(x=x, y=y, alpha=1, hue=cm, palette="Spectral")
    _finalise_mem_map(grid, legend_title="counterfactual\nmemorisation")

    # visual mem map using the category labels
    grid = sns.jointplot(x=x, y=y, hue=hue, palette="viridis",
                         joint_kws={'style': hue,
                                    'alpha': [0.1 if label == "other" else 1 for label in hue]},
                         marginal_kws={'common_norm': False})
    _finalise_mem_map(grid)

    # Store generalisation scores to file
    with open(f"{checkpoint_dir}/genscore_{dataset_name}_{model_name}.pickle", 'wb') as handle:
        pickle.dump({k: np.mean(v) for k, v in test_prob.items()}, handle)

In [None]:
# Model identifiers, as they appear in the checkpoint file names.
models = [
    "EleutherAI_pythia-160m-deduped",
    "EleutherAI_gpt-neo-125m",
    "bert-base-cased",
    "facebook_opt-125m",
]

# Dataset identifiers processed for every model.
datasets = [
    "wic",
    "rte",
    "mrpc",
    "cola",
    "boolq",
    "sst2",
    "sst5",
    "emotion",
    "implicithate",
    "stormfront",
    "reuters",
    "trec",
]

# Visit every (model, dataset) pair in model-major order, announcing each
# pair before its plots are produced and its scores are stored.
for model, dataset in ((m, d) for m in models for d in datasets):
    print(model, dataset)
    store_mmap(dataset, model)