In [2]:
import sys
sys.path.append("../")
from os.path import join
import os

import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.rc('text', usetex=False)

import scipy
import faiss
import matplotlib.pyplot as plt
import joblib
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm
import torchvision
from torchvision import transforms
import torch.nn.functional as F

from lolip.models.torch_utils import archs
from lolip.variables import auto_var

# Pixel

In [8]:
def chisquare_vs_chance(ncg_acc, n_class):
    """Chi-square test of agreement counts against a 1/n_class agreement rate.

    Defined in this cell so the cell runs on its own.

    Args:
        ncg_acc: boolean array-like, True where the two prediction vectors agree.
        n_class: number of classes; the expected share of True entries is
            ``1 / n_class``.

    Returns:
        The ``scipy.stats.chisquare`` result for the observed counts
        ``[n_false, n_true]`` against the expected counts
        ``[N * (n_class - 1) / n_class, N / n_class]``.
    """
    ncg_acc = np.asarray(ncg_acc, dtype=bool).reshape(-1)
    n_total = ncg_acc.size
    # minlength=2 keeps both bins even when no entry is True; without it
    # bincount returns a single bin that no longer lines up with `expected`.
    observed = np.bincount(ncg_acc.astype(np.intp), minlength=2)
    expected = [n_total * (n_class - 1) / n_class, n_total / n_class]
    return scipy.stats.chisquare(observed, expected)


# One entry per held-out class index for each of the four dataset families.
datasets = [f"mnistwo{i}" for i in range(10)] + [f"cifar10wo{i}" for i in range(10)] + [f"cifar100coarsewo{i}" for i in range(10)] + [f"aug10-imgnet100wo{i}" for i in range(10)]
# Result files, listed in the same order as `datasets`.
paths = [f"./results/oos_repr/cwl2-128-mnistwo{i}-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl" for i in range(10)]
paths += [f"./results/oos_repr/cwl2-64-cifar10wo{i}-70-1.0-0.01-ce-vtor2-WRN_40_10-0.0-2-adam-0-0.0.pkl" for i in range(10)]
paths += [f"./results/oos_repr/cwl2-64-cifar100coarsewo{i}-70-1.0-0.01-ce-vtor2-WRN_40_10-0.0-2-adam-0-0.0.pkl" for i in range(10)]
paths += [f"./results/oos_repr/cwl2-128-aug10-imgnet100wo{i}-70-1.0-0.01-ce-vtor2-ResNet50Norm01-0.0-2-adam-0-0.0.pkl" for i in range(10)]
# Class count used for the expected agreement rate of each dataset.
n_classes = [9] * 10 + [9] * 10 + [19] * 10 + [99] * 10

test_data = {}  # dataset name -> chi-square result
data = {}       # dataset name -> fraction of agreeing predictions
for dset, path, n_class in zip(datasets, paths, n_classes):
    if not os.path.exists(path):
        # Result files that are not on disk are skipped.
        continue
    try:
        # NOTE: joblib.load unpickles the file; only load trusted result files.
        res = joblib.load(path)
    except Exception as exc:
        # Report the unreadable file together with its cause and move on.
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        print(f"problem with {path}: {exc!r}")
        continue
    # presumably nearest-neighbour reference labels for the held-out samples -- TODO confirm
    nn_res = joblib.load(f"./out_of_sample/{dset}.pkl")
    nn_preds = np.concatenate((nn_res['miss_trn_knn_y_2'].reshape(-1), nn_res['miss_tst_knn_y_2'].reshape(-1)))
    ood_preds = np.concatenate((res['oos_trn_pred'].reshape(-1), res['oos_tst_pred'].reshape(-1)))
    if ood_preds.shape != nn_preds.shape:
        # An element-wise comparison is meaningless when the vectors differ in length.
        raise ValueError(
            f"{dset}: {ood_preds.shape[0]} model predictions vs "
            f"{nn_preds.shape[0]} reference predictions"
        )
    ncg_acc = (ood_preds == nn_preds)
    test_data[dset] = chisquare_vs_chance(ncg_acc, n_class)
    data[dset] = ncg_acc.mean()

In [None]:
def chisquare_vs_chance(ncg_acc, n_class):
    """Chi-square test of agreement counts against a 1/n_class agreement rate.

    Defined in this cell so the cell runs on its own.

    Args:
        ncg_acc: boolean array-like, True where the two prediction vectors agree.
        n_class: number of classes; the expected share of True entries is
            ``1 / n_class``.

    Returns:
        The ``scipy.stats.chisquare`` result for the observed counts
        ``[n_false, n_true]`` against the expected counts
        ``[N * (n_class - 1) / n_class, N / n_class]``.
    """
    ncg_acc = np.asarray(ncg_acc, dtype=bool).reshape(-1)
    n_total = ncg_acc.size
    # minlength=2 keeps both bins even when no entry is True; without it
    # bincount returns a single bin that no longer lines up with `expected`.
    observed = np.bincount(ncg_acc.astype(np.intp), minlength=2)
    expected = [n_total * (n_class - 1) / n_class, n_total / n_class]
    return scipy.stats.chisquare(observed, expected)


# One entry per held-out class index for each of the four dataset families.
datasets = [f"mnistwo{i}" for i in range(10)] + [f"cifar10wo{i}" for i in range(10)] + [f"cifar100coarsewo{i}" for i in range(10)] + [f"aug10-imgnet100wo{i}" for i in range(10)]
# Result files, listed in the same order as `datasets`.
paths = [f"./results/oos_repr/cwl2-128-mnistwo{i}-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl" for i in range(10)]
paths += [f"./results/oos_repr/cwl2-64-cifar10wo{i}-70-1.0-0.01-ce-vtor2-WRN_40_10-0.0-2-adam-0-0.0.pkl" for i in range(10)]
paths += [f"./results/oos_repr/cwl2-64-cifar100coarsewo{i}-70-1.0-0.01-ce-vtor2-WRN_40_10-0.0-2-adam-0-0.0.pkl" for i in range(10)]
paths += [f"./results/oos_repr/cwl2-128-aug10-imgnet100wo{i}-70-1.0-0.01-ce-vtor2-ResNet50Norm01-0.0-2-adam-0-0.0.pkl" for i in range(10)]
# Class count used for the expected agreement rate of each dataset.
n_classes = [9] * 10 + [9] * 10 + [19] * 10 + [99] * 10

test_data = {}  # dataset name -> chi-square result
data = {}       # dataset name -> fraction of agreeing predictions
for dset, path, n_class in zip(datasets, paths, n_classes):
    if not os.path.exists(path):
        # Result files that are not on disk are skipped.
        continue
    try:
        # NOTE: joblib.load unpickles the file; only load trusted result files.
        res = joblib.load(path)
    except Exception as exc:
        # Report the unreadable file together with its cause and move on.
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        print(f"problem with {path}: {exc!r}")
        continue
    # presumably nearest-neighbour reference labels for the held-out samples -- TODO confirm
    nn_res = joblib.load(f"./out_of_sample/{dset}.pkl")
    nn_preds = np.concatenate((nn_res['miss_trn_knn_y_2'].reshape(-1), nn_res['miss_tst_knn_y_2'].reshape(-1)))
    ood_preds = np.concatenate((res['oos_trn_pred'].reshape(-1), res['oos_tst_pred'].reshape(-1)))
    if ood_preds.shape != nn_preds.shape:
        # An element-wise comparison is meaningless when the vectors differ in length.
        raise ValueError(
            f"{dset}: {ood_preds.shape[0]} model predictions vs "
            f"{nn_preds.shape[0]} reference predictions"
        )
    ncg_acc = (ood_preds == nn_preds)
    test_data[dset] = chisquare_vs_chance(ncg_acc, n_class)
    data[dset] = ncg_acc.mean()

# Feature

In [8]:
def chisquare_vs_chance(ncg_acc, n_class):
    """Chi-square test of agreement counts against a 1/n_class agreement rate.

    Defined in this cell so the cell runs on its own.

    Args:
        ncg_acc: boolean array-like, True where the two prediction vectors agree.
        n_class: number of classes; the expected share of True entries is
            ``1 / n_class``.

    Returns:
        The ``scipy.stats.chisquare`` result for the observed counts
        ``[n_false, n_true]`` against the expected counts
        ``[N * (n_class - 1) / n_class, N / n_class]``.
    """
    ncg_acc = np.asarray(ncg_acc, dtype=bool).reshape(-1)
    n_total = ncg_acc.size
    # minlength=2 keeps both bins even when no entry is True; without it
    # bincount returns a single bin that no longer lines up with `expected`.
    observed = np.bincount(ncg_acc.astype(np.intp), minlength=2)
    expected = [n_total * (n_class - 1) / n_class, n_total / n_class]
    return scipy.stats.chisquare(observed, expected)


# Dataset keys; each one is the file name of a precomputed-representation pickle.
datasets = [
    f"calcedrepr-mnistwo{i}-cwl2-128-mnistwo{i}-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl" for i in range(10)
] + [f"calcedrepr-cifar10wo{i}-cwl2-64-cifar10wo{i}-70-1.0-0.01-ce-vtor2-WRN_40_10-0.0-2-adam-0-0.0.pkl" for i in range(10)
] + [f"calcedrepr-cifar100coarsewo{i}-cwl2-64-cifar100coarsewo{i}-70-1.0-0.01-ce-vtor2-WRN_40_10-0.0-2-adam-0-0.0.pkl" for i in range(10)
] + [f"calcedrepr-aug10-imgnet100wo{i}-cwl2-128-aug10-imgnet100wo{i}-70-1.0-0.01-ce-vtor2-ResNet50Norm01-0.0-2-adam-0-0.0.pkl" for i in range(10)]

# Result files, listed in the same order as `datasets`.
paths = [f"./results/out_of_sample/cwl2-256-calcedrepr-mnistwo{i}-cwl2-128-mnistwo{i}-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl-70-1.0-0.01-ce-vtor2-LargeMLP-0.9-2-sgd-0-0.0.pkl" for i in range(10)]
paths += [f"./results/out_of_sample/cwl2-128-calcedrepr-cifar10wo{i}-cwl2-64-cifar10wo{i}-70-1.0-0.01-ce-vtor2-WRN_40_10-0.0-2-adam-0-0.0.pkl-70-1.0-0.01-ce-vtor2-LargeMLP-0.9-2-sgd-0-0.0.pkl" for i in range(10)]
paths += [f"./results/out_of_sample/cwl2-128-calcedrepr-cifar100coarsewo{i}-cwl2-64-cifar100coarsewo{i}-70-1.0-0.01-ce-vtor2-WRN_40_10-0.0-2-adam-0-0.0.pkl-70-1.0-0.01-ce-vtor2-LargeMLP-0.9-2-sgd-0-0.0.pkl" for i in range(10)]
paths += [f"./results/out_of_sample/cwl2-128-calcedrepr-aug10-imgnet100wo{i}-cwl2-128-aug10-imgnet100wo{i}-70-1.0-0.01-ce-vtor2-ResNet50Norm01-0.0-2-adam-0-0.0.pkl-70-1.0-0.01-ce-vtor2-LargeMLPv4-0.9-2-sgd-0-0.0.pkl" for i in range(10)]
# Class count used for the expected agreement rate of each dataset.
n_classes = [9] * 10 + [9] * 10 + [19] * 10 + [99] * 10

data = {}  # dataset key -> chi-square result
for dset, path, n_class in zip(datasets, paths, n_classes):
    if not os.path.exists(path):
        # Result files that are not on disk are skipped.
        continue
    try:
        # NOTE: joblib.load unpickles the file; only load trusted result files.
        res = joblib.load(path)
    except Exception as exc:
        # Report the unreadable file together with its cause and move on.
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        print(f"problem with {path}: {exc!r}")
        continue
    # NOTE(review): `dset` already ends in ".pkl", so this path ends in ".pkl.pkl";
    # confirm that this matches the file names on disk.
    nn_res = joblib.load(f"./out_of_sample/{dset}.pkl")
    nn_preds = np.concatenate((nn_res['miss_trn_knn_y_2'].reshape(-1), nn_res['miss_tst_knn_y_2'].reshape(-1)))
    ood_preds = np.concatenate((res['oos_trn_pred'].reshape(-1), res['oos_tst_pred'].reshape(-1)))
    if ood_preds.shape != nn_preds.shape:
        # An element-wise comparison is meaningless when the vectors differ in length.
        raise ValueError(
            f"{dset}: {ood_preds.shape[0]} model predictions vs "
            f"{nn_preds.shape[0]} reference predictions"
        )
    ncg_acc = (ood_preds == nn_preds)
    data[dset] = chisquare_vs_chance(ncg_acc, n_class)
    

In [9]:
# Display `data`; the recorded output below holds one chi-square result per dataset key.
data

{'calcedrepr-mnistwo0-cwl2-128-mnistwo0-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl': Power_divergenceResult(statistic=1902.8502281616688, pvalue=0.0),
 'calcedrepr-mnistwo1-cwl2-128-mnistwo1-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl': Power_divergenceResult(statistic=0.029341754475053956, pvalue=0.8639923225491427),
 'calcedrepr-mnistwo2-cwl2-128-mnistwo2-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl': Power_divergenceResult(statistic=6506.0349785407725, pvalue=0.0),
 'calcedrepr-mnistwo3-cwl2-128-mnistwo3-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl': Power_divergenceResult(statistic=23597.26676936003, pvalue=0.0),
 'calcedrepr-mnistwo4-cwl2-128-mnistwo4-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl': Power_divergenceResult(statistic=30943.951513042197, pvalue=0.0),
 'calcedrepr-mnistwo5-cwl2-128-mnistwo5-70-1.0-0.01-ce-vtor2-CNN002-0.9-2-sgd-0-0.0.pkl': Power_divergenceResult(statistic=15909.546095358786, pvalue=0.0),
 'calcedrepr-mnistwo6-cwl2-128-mnistwo6-70-1.0-0