In [1]:
# Imports and global plotting configuration shared by every cell below.
import sys
# Make the parent directory importable -- presumably so that `lolip` (imported
# below) resolves from the repository root; TODO confirm.
sys.path.append("../")
from os.path import join
from IPython.display import display

import matplotlib
# Embed TrueType (Type 42) fonts in PDF/PS output instead of Type 3 fonts.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# Render text (including the $\ell_2$ axis labels) with matplotlib's own mathtext
# rather than an external LaTeX installation.
matplotlib.rc('text', usetex=False)

import scipy
# NOTE(review): `mquantiles` itself is unused in the cells shown, but this import
# also loads the `scipy.stats` subpackage, which the cells use as `scipy.stats.sem`;
# import `scipy.stats` explicitly before removing it.
from scipy.stats.mstats import mquantiles
# NOTE(review): display, faiss, pd, torch, tqdm, torchvision and transforms are not
# used by the cells shown here.
import faiss
import matplotlib.pyplot as plt
import joblib
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm
import torchvision
from torchvision import transforms

from lolip.variables import auto_var

# Font size shared by tick labels, axis labels and legends in every figure below.
fontsize=18

In [2]:
# Training labels keyed by display name; the plotting cells fill this in.
dataset_info = dict()
# Number of (near) equal-size distance bins drawn per figure.
bins = 5

## MNIST

In [3]:
# Pixel-space MNIST variants ("wo{i}" presumably = class i withheld; TODO confirm).
# For each variant, record the dataset id, its display name, the plotted model
# labels and the matching result files (same order as the labels).
paths, datasets, ds_names, model_names = [], [], [], []
for i in [0, 4, 9]:
    datasets.append(f"mnistwo{i}")
    ds_names.append(f"MNIST-wo{i}")
    model_names.append(["natural", "TRADES(2)", "AT(2)"])
    # One result file per (param, loss) pair: natural, TRADES(2), AT(2).
    paths.append([
        f"cwl2-128-mnistwo{i}-70-{param}-0.01-{loss}-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl"
        for param, loss in [("1.0", "ce"), ("2.0", "trades6ce"), ("2.0", "advce")]
    ])

In [4]:
# Pixel-space MNIST: for each variant, split the out-of-sample examples (the
# `miss_trn` / `miss_tst` entries) into `bins` near-equal-size groups by their l2
# distance to the closest training example, and plot each model's per-bin NCG
# score (prediction == label of that closest training example) as grouped bars.
bar_width = 0.2
for paths_i, dataset, ds_name, model_names_i in zip(paths, datasets, ds_names, model_names):
    # zip() over result files and labels silently truncates, which would mislabel bars.
    assert len(paths_i) == len(model_names_i), (ds_name, len(paths_i), len(model_names_i))

    _, trny, _, _, _ = auto_var.get_var_with_argument("dataset", dataset)
    dataset_info[ds_name] = trny

    # Nearest-neighbour distances/indices depend only on the dataset, not on the
    # model, so load and bin them once per figure instead of once per model.
    nn_res = joblib.load(f"./out_of_sample/{dataset}.pkl")
    distances = np.concatenate((nn_res['miss_trn_knn_y_2_dist'].reshape(-1), nn_res['miss_tst_knn_y_2_dist'].reshape(-1)))
    nnidxs = np.concatenate((nn_res['miss_trn_knn_y_2_ind'].reshape(-1), nn_res['miss_tst_knn_y_2_ind'].reshape(-1)))
    nn_labels = trny[nnidxs]

    # Contiguous slices of the distance-sorted examples, one per bin.
    sort_idx = np.argsort(distances)
    n_points = len(sort_idx)
    bin_idxs = [sort_idx[(b * n_points) // bins: ((b + 1) * n_points) // bins] for b in range(bins)]
    xticks = ["%.2f" % distances[idx].mean() for idx in bin_idxs]

    fig, ax = plt.subplots()
    for offset, (path, model_name) in enumerate(zip(paths_i, model_names_i)):
        res = joblib.load(join("./results/oos_repr/", path))
        ood_preds = np.concatenate((res['oos_trn_pred'].reshape(-1), res['oos_tst_pred'].reshape(-1)))
        accs = (ood_preds == nn_labels)
        # Center each group of bars on its tick (-0.2, 0, +0.2 for three models).
        shift = (offset - (len(model_names_i) - 1) / 2) * bar_width
        ax.bar(np.arange(bins) + shift,
               [accs[idx].mean() for idx in bin_idxs],
               yerr=[scipy.stats.sem(accs[idx]) for idx in bin_idxs],
               width=bar_width, label=model_name)

    ax.set_xticks(np.arange(bins))
    ax.set_xticklabels(xticks, fontsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.set_ylabel("NCG score", fontsize=fontsize)
    # Raw string: "\e" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    ax.set_xlabel(r"Avg. $\ell_2$ dist. to the closest training example", fontsize=fontsize)
    ax.legend(fontsize=fontsize)
    fig.tight_layout()
    fig.savefig(f"figs/ncg_binned_dists/{dataset}.png", bbox_inches='tight')
    plt.close(fig)

## CIFAR

In [5]:
# Pixel-space CIFAR results are stored in two pickles (CIFAR-10 and CIFAR-100
# coarse), each holding (predictions, nearest-neighbour indices, distances) dicts.
# Fold the CIFAR-100 dicts into the CIFAR-10 ones so the next cell can look up
# either family by (dataset, model) key.
preds, nnidxs, dists = joblib.load("nb_results/cifar10-c.pkl")
cifar100_res = joblib.load("nb_results/cifar100-c.pkl")
for pos, merged in enumerate((preds, nnidxs, dists)):
    merged.update(cifar100_res[pos])

In [6]:
# Pixel-space CIFAR-10 / CIFAR-100 (coarse): per dataset, bin examples by their
# l2 distance to the closest training example and plot each model's per-bin NCG
# score (argmax prediction == label of that closest training example).
bins = 5
bar_width = 0.2

model_names = ["natural", "TRADES(2)", "AT(2)"]
ds_names = [f"cifar10wo{i}" for i in [0, 4, 9]] + [f"cifar100coarsewo{i}" for i in [0, 4, 9]]

for ds_name in ds_names:
    _, trny, _, _, _ = auto_var.get_var_with_argument("dataset", ds_name)

    fig, ax = plt.subplots()
    for offset, model_name in enumerate(model_names):
        distances = dists[ds_name, model_name]['ncg']
        accs = (preds[ds_name, model_name]['ncg'].argmax(1) == trny[nnidxs[ds_name, model_name]['ncg']])

        # Contiguous slices of the distance-sorted examples, one per bin.
        sort_idx = np.argsort(distances)
        n_points = len(sort_idx)
        bin_idxs = [sort_idx[(b * n_points) // bins: ((b + 1) * n_points) // bins] for b in range(bins)]
        # NOTE(review): the tick labels come from the last model's distances; they
        # describe every model only if all models share the same distances.
        xticks = ["%.2f" % distances[idx].mean() for idx in bin_idxs]

        # Center each group of bars on its tick (-0.2, 0, +0.2 for three models).
        shift = (offset - (len(model_names) - 1) / 2) * bar_width
        ax.bar(np.arange(bins) + shift,
               [accs[idx].mean() for idx in bin_idxs],
               yerr=[scipy.stats.sem(accs[idx]) for idx in bin_idxs],
               width=bar_width, label=model_name)

    ax.set_xticks(np.arange(bins))
    ax.set_xticklabels(xticks, fontsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.set_ylabel("NCG score", fontsize=fontsize)
    # Raw string: "\e" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    ax.set_xlabel(r"Avg. $\ell_2$ dist. to the closest training example", fontsize=fontsize)
    fig.tight_layout()
    # Saved without a legend; for a legend variant call ax.legend(fontsize=fontsize)
    # and save to f"figs/ncg_binned_dists/{ds_name}_with_legend.png".
    fig.savefig(f"figs/ncg_binned_dists/{ds_name}.png", bbox_inches='tight')
    plt.close(fig)

## ImgNet100

In [23]:
# Start the ImgNet100 section with an empty label cache (read/filled by the next cell).
dataset_info = dict()

In [24]:
# Pixel-space ImgNet100: per dataset, bin examples by their l2 distance to the
# closest training example and plot each model's per-bin NCG score; these figures
# are saved with a legend.
preds, nnidxs, dists = joblib.load("nb_results/imgnet.pkl")
bins = 5
bar_width = 0.2

model_names = ["natural", "TRADES(2)", "AT(2)"]
ds_names = [f"aug10-imgnet100wo{i}" for i in [0, 1, 2]]

for ds_name in ds_names:
    # Reuse training labels cached in dataset_info; load them only on first use.
    if ds_name not in dataset_info:
        _, dataset_info[ds_name], _, _, _ = auto_var.get_var_with_argument("dataset", ds_name)
    trny = dataset_info[ds_name]

    fig, ax = plt.subplots()
    for offset, model_name in enumerate(model_names):
        distances = dists[ds_name, model_name]['ncg']
        accs = (preds[ds_name, model_name]['ncg'].argmax(1) == trny[nnidxs[ds_name, model_name]['ncg']])

        # Contiguous slices of the distance-sorted examples, one per bin.
        sort_idx = np.argsort(distances)
        n_points = len(sort_idx)
        bin_idxs = [sort_idx[(b * n_points) // bins: ((b + 1) * n_points) // bins] for b in range(bins)]
        # NOTE(review): the tick labels come from the last model's distances; they
        # describe every model only if all models share the same distances.
        xticks = ["%.2f" % distances[idx].mean() for idx in bin_idxs]

        # Center each group of bars on its tick (-0.2, 0, +0.2 for three models).
        shift = (offset - (len(model_names) - 1) / 2) * bar_width
        ax.bar(np.arange(bins) + shift,
               [accs[idx].mean() for idx in bin_idxs],
               yerr=[scipy.stats.sem(accs[idx]) for idx in bin_idxs],
               width=bar_width, label=model_name)

    ax.set_xticks(np.arange(bins))
    ax.set_xticklabels(xticks, fontsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.set_ylabel("NCG score", fontsize=fontsize)
    # Raw string: "\e" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    ax.set_xlabel(r"Avg. $\ell_2$ dist. to the closest training example", fontsize=fontsize)
    ax.legend(fontsize=fontsize)
    fig.tight_layout()
    fig.savefig(f"figs/ncg_binned_dists/{ds_name}_with_legend.png", bbox_inches='tight')
    plt.close(fig)

100%|██████████| 3919/3919 [00:25<00:00, 153.51it/s]
100%|██████████| 155/155 [00:03<00:00, 51.66it/s] 
100%|██████████| 41/41 [00:02<00:00, 18.04it/s]
100%|██████████| 2/2 [00:01<00:00,  1.04it/s]
100%|██████████| 3919/3919 [00:25<00:00, 152.37it/s]
100%|██████████| 155/155 [00:02<00:00, 58.89it/s] 
100%|██████████| 41/41 [00:01<00:00, 24.26it/s]
100%|██████████| 2/2 [00:01<00:00,  1.51it/s]
100%|██████████| 3919/3919 [00:25<00:00, 153.60it/s]
100%|██████████| 155/155 [00:02<00:00, 64.54it/s] 
100%|██████████| 41/41 [00:01<00:00, 24.72it/s]
100%|██████████| 2/2 [00:01<00:00,  1.46it/s]


# Feature space

In [25]:
# Feature space: each "dataset" is a `calcedrepr-...` pickle -- presumably the
# representations computed by the pixel-space model named inside it (TODO
# confirm) -- and each plotted model is an MLP trained on top of it. Entries are
# appended in plotting order: MNIST, then CIFAR10/CIFAR100 interleaved per
# held-out class, then ImgNet100.


def _feature_model_paths(dataset, batch_tag, arch, runs):
    """Return the result-file name of every model trained on ``dataset``.

    One name per ``(param, loss)`` pair in ``runs``, in the same order, following
    ``cwl2-{batch_tag}-{dataset}-70-{param}-0.01-{loss}-vtor2-{arch}-0.9-2-sgd-0-0.0.pkl``.
    """
    return [
        f"cwl2-{batch_tag}-{dataset}-70-{param}-0.01-{loss}-vtor2-{arch}-0.9-2-sgd-0-0.0.pkl"
        for param, loss in runs
    ]


# (param, loss) of the plotted models, in label order natural / TRADES / AT; for
# the TRADES/AT entries `param` is the number shown in the label (TRADES(2), AT(1)).
# TRADES(4) / TRADES(8) runs ("4.0" / "8.0" with "trades6ce") are deliberately left
# out so that every result file lines up with exactly one label.
RUNS_AT2 = [("1.0", "ce"), ("2.0", "trades6ce"), ("2.0", "advce")]
RUNS_AT1 = [("1.0", "ce"), ("2.0", "trades6ce"), ("1.0", "advce")]

paths = []
datasets = []
ds_names = []
model_names = []
for i in [0, 4, 9]:
    repr_ds = f"calcedrepr-mnistwo{i}-cwl2-128-mnistwo{i}-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl"
    datasets.append(repr_ds)
    ds_names.append(f"MNIST-wo{i}")
    model_names.append(["natural", "TRADES(2)", "AT(2)"])
    paths.append(_feature_model_paths(repr_ds, 256, "LargeMLP", RUNS_AT2))
for i in [0, 4, 9]:
    for ds_prefix, pretty_name in [("cifar10", "CIFAR10"), ("cifar100coarse", "CIFAR100")]:
        repr_ds = (f"calcedrepr-{ds_prefix}wo{i}-cwl2-64-{ds_prefix}wo{i}"
                   "-70-1.0-0.01-ce-vtor2-WRN_40_10-0.0-2-adam-0-0.0.pkl")
        datasets.append(repr_ds)
        ds_names.append(f"{pretty_name}-wo{i}")
        model_names.append(["natural", "TRADES(2)", "AT(1)"])
        paths.append(_feature_model_paths(repr_ds, 128, "LargeMLP", RUNS_AT1))
for i in [0, 1, 2]:
    repr_ds = (f"calcedrepr-aug10-imgnet100wo{i}-cwl2-128-aug10-imgnet100wo{i}"
               "-70-1.0-0.01-ce-vtor2-ResNet50Norm01-0.0-2-adam-0-0.0.pkl")
    datasets.append(repr_ds)
    ds_names.append(f"ImgNet100-wo{i}")
    model_names.append(["natural", "TRADES(2)", "AT(1)"])
    # Previously this list also held the TRADES(4)/TRADES(8) files, so zip() in the
    # plotting cell drew the TRADES(4) run under the "AT(1)" label.
    paths.append(_feature_model_paths(repr_ds, 128, "LargeMLPv3", RUNS_AT1))

assert len(paths) == len(datasets) == len(ds_names) == len(model_names)
# The plotting cell zips files with labels; a length mismatch would mislabel bars.
assert all(len(p) == len(m) for p, m in zip(paths, model_names))

In [26]:
# Feature space: for each representation dataset, split the out-of-sample examples
# (the `miss_trn` / `miss_tst` entries) into `bins` near-equal-size groups by their
# l2 distance to the closest training example, and plot each model's per-bin NCG
# score (prediction == label of that closest training example) as grouped bars.
bar_width = 0.2
for paths_i, dataset, ds_name, model_names_i in zip(paths, datasets, ds_names, model_names):
    # zip() over result files and labels silently truncates, which would mislabel bars.
    assert len(paths_i) == len(model_names_i), (ds_name, len(paths_i), len(model_names_i))

    _, trny, _, _, _ = auto_var.get_var_with_argument("dataset", dataset)
    dataset_info[ds_name] = trny

    # Nearest-neighbour distances/indices depend only on the dataset, not on the
    # model, so load and bin them once per figure instead of once per model.
    nn_res = joblib.load(f"./out_of_sample/{dataset}.pkl")
    distances = np.concatenate((nn_res['miss_trn_knn_y_2_dist'].reshape(-1), nn_res['miss_tst_knn_y_2_dist'].reshape(-1)))
    nnidxs = np.concatenate((nn_res['miss_trn_knn_y_2_ind'].reshape(-1), nn_res['miss_tst_knn_y_2_ind'].reshape(-1)))
    nn_labels = trny[nnidxs]

    # Contiguous slices of the distance-sorted examples, one per bin.
    sort_idx = np.argsort(distances)
    n_points = len(sort_idx)
    bin_idxs = [sort_idx[(b * n_points) // bins: ((b + 1) * n_points) // bins] for b in range(bins)]
    xticks = ["%.2f" % distances[idx].mean() for idx in bin_idxs]

    fig, ax = plt.subplots()
    for offset, (path, model_name) in enumerate(zip(paths_i, model_names_i)):
        res = joblib.load(join("./results/out_of_sample/", path))
        ood_preds = np.concatenate((res['oos_trn_pred'].reshape(-1), res['oos_tst_pred'].reshape(-1)))
        accs = (ood_preds == nn_labels)
        # Center each group of bars on its tick (-0.2, 0, +0.2 for three models).
        shift = (offset - (len(model_names_i) - 1) / 2) * bar_width
        ax.bar(np.arange(bins) + shift,
               [accs[idx].mean() for idx in bin_idxs],
               yerr=[scipy.stats.sem(accs[idx]) for idx in bin_idxs],
               width=bar_width, label=model_name)

    ax.set_xticks(np.arange(bins))
    ax.set_xticklabels(xticks, fontsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)
    ax.set_ylabel("NCG score", fontsize=fontsize)
    # Raw string: "\e" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    ax.set_xlabel(r"Avg. $\ell_2$ dist. to the closest training example", fontsize=fontsize)
    ax.legend(fontsize=fontsize)
    fig.tight_layout()
    fig.savefig(f"figs/ncg_binned_dists/feature_{ds_name}.png", bbox_inches='tight')
    plt.close(fig)

100%|██████████| 3919/3919 [00:25<00:00, 154.68it/s]
100%|██████████| 155/155 [00:02<00:00, 64.22it/s] 
100%|██████████| 41/41 [00:01<00:00, 26.71it/s]
100%|██████████| 2/2 [00:01<00:00,  1.53it/s]
100%|██████████| 3919/3919 [00:25<00:00, 151.92it/s]
100%|██████████| 155/155 [00:02<00:00, 65.07it/s] 
100%|██████████| 41/41 [00:01<00:00, 24.62it/s]
100%|██████████| 2/2 [00:01<00:00,  1.51it/s]
100%|██████████| 3919/3919 [00:25<00:00, 151.08it/s]
100%|██████████| 155/155 [00:02<00:00, 65.38it/s] 
100%|██████████| 41/41 [00:01<00:00, 24.75it/s]
100%|██████████| 2/2 [00:01<00:00,  1.45it/s]
