In [31]:
import sys
sys.path.append("../")
import re
from os.path import join
from IPython.display import display
import os
from collections import OrderedDict

import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.rc('text', usetex=False)

import dill
import scipy
import faiss
import matplotlib.pyplot as plt
import joblib
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm
import torchvision
from torchvision import transforms
import torch.nn.functional as F

from lolip.extra_ood_utils import get_ood_data_paths
from madry_robustness_resnet import resnet50
from lolip.variables import auto_var

fontsize=15

# for auto-reloading external modules 
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython 
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [96]:
def get_preds(model, dset, batch_size=64, device="cuda"):
    model.eval().to(device)
    if isinstance(dset, np.ndarray):
        dset = torch.utils.data.TensorDataset(
            torch.from_numpy(dset.transpose(0, 3, 1, 2)).float(),
            torch.ones(len(dset))
        )
    loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=False, num_workers=24)

    ret = []
    for (x, _) in tqdm(loader, desc="[get_preds]"):
        with torch.no_grad():
            output = F.softmax(model(x.to(device)), dim=1)
        ret.append(output.cpu().numpy())
    del loader
    return np.concatenate(ret, axis=0)

## CIFAR10

In [105]:
model_paths = [
    './pretrained/madry_robustness/cifar_nat.pt',
    './pretrained/madry_robustness/cifar_l2_0_25.pt',
    './pretrained/madry_robustness/cifar_l2_0_5.pt',
    './pretrained/madry_robustness/cifar_l2_1_0.pt',
]
model_names = ['natural', 'AT(.25)', 'AT(.5)', 'AT(1.0)']

normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                std=[0.2023, 0.1994, 0.2010])
def np_normalize(X):
    X = X.transpose(0, 3, 1, 2)
    return normalize(torch.from_numpy(X)).numpy().transpose(0, 2, 3, 1)

trnX, trny, tstX, tsty, _ = auto_var.get_var_with_argument("dataset", "cifar10")
trnX = np_normalize(trnX)
tstX = np_normalize(tstX)

results = {}
for name, path in zip(model_names, model_paths):
    
    model = resnet50()
    res = torch.load(path, pickle_module=dill)
    state_dict = {}
    sd = res['model'] if 'model' in res else res['state_dict']
    for k, v in sd.items():
        if "module.model." in k:
            state_dict[k.replace("module.model.", "")] = v
    model.load_state_dict(state_dict)
    
    trn_pred = get_preds(model, trnX).argmax(1)
    tst_pred = get_preds(model, tstX).argmax(1)
    results[name] = [(trn_pred == trny).mean(), (tst_pred == tsty).mean()]
    

[get_preds]:   0%|          | 0/782 [00:00<?, ?it/s]

[get_preds]:   0%|          | 0/157 [00:00<?, ?it/s]

[get_preds]:   0%|          | 0/782 [00:00<?, ?it/s]

[get_preds]:   0%|          | 0/157 [00:00<?, ?it/s]

[get_preds]:   0%|          | 0/782 [00:00<?, ?it/s]

[get_preds]:   0%|          | 0/157 [00:00<?, ?it/s]

[get_preds]:   0%|          | 0/782 [00:00<?, ?it/s]

[get_preds]:   0%|          | 0/157 [00:00<?, ?it/s]

In [107]:
df = pd.DataFrame.from_dict(results, orient="index", columns = ['trn acc.', 'tst acc.'])
print(df.transpose().to_latex(float_format="%.2f"))

\begin{tabular}{lrrrr}
\toprule
{} &  natural &  AT(.25) &  AT(.5) &  AT(1.0) \\
\midrule
trn acc. &     1.00 &     0.97 &    0.98 &     0.86 \\
tst acc. &     0.95 &     0.93 &    0.91 &     0.82 \\
\bottomrule
\end{tabular}



## NCG vs. tst

In [46]:
_, ood_names = get_ood_data_paths("cifar10", "not_matter")

preds, nnidxs, dists = joblib.load("nb_results/madry-cifar10-c.pkl")
model_names = ['natural', 'TRADES(2)', 'TRADES(4)', 'TRADES(8)', 'AT(2)',]
model_names = ['natural', 'AT(.25)', 'AT(.5)', 'AT(1.0)']

data = {}
for key in preds.keys():
    ds_name = key[0]
    if key[1] not in model_names:
        continue
    for ood_name in ood_names:
        _, trny, _, tsty, _ = auto_var.get_var_with_argument("dataset", "cifar10")
        ncg_ind = (preds[key][ood_name].argmax(1) == trny[nnidxs[key][ood_name]])
        tstacc_ind = (preds[key][ood_name].argmax(1) == np.tile(tsty, 5))
        if preds[key][ood_name].argmax(1).shape[0] != (len(tsty) * 5):
            import ipdb; ipdb.set_trace()

        for i in range(5):
            counts = 10000
            ttest = scipy.stats.ttest_ind(
                np.logical_and(tstacc_ind[counts*i: counts*(i+1)], ncg_ind[counts*i: counts*(i+1)])[ncg_ind[counts*i: counts*(i+1)] == 1],
                np.logical_and(tstacc_ind[counts*i: counts*(i+1)], ncg_ind[counts*i: counts*(i+1)] == 0)[ncg_ind[counts*i: counts*(i+1)] == 0],
                equal_var=False,
                alternative="greater",
            )
            data[(key[0], key[1], ood_name, i)] = [
                ncg_ind[counts*i: counts*(i+1)].mean(),
                tstacc_ind[counts*i: counts*(i+1)].mean(),
                np.logical_and(tstacc_ind, ncg_ind)[counts*i: counts*(i+1)].mean(),
                (np.logical_and(tstacc_ind[counts*i: counts*(i+1)], ncg_ind[counts*i: counts*(i+1)])[ncg_ind[counts*i: counts*(i+1)] == 1]).mean(),
                (np.logical_and(tstacc_ind[counts*i: counts*(i+1)], ncg_ind[counts*i: counts*(i+1)] == 0)[ncg_ind[counts*i: counts*(i+1)] == 0]).mean(),
                ttest[0],
                ttest[1],
                dists[key][ood_name][counts*i: counts*(i+1)].mean(),
            ]

In [47]:
df = pd.DataFrame.from_dict(data, orient="index", columns=["NCG acc.", "tst acc.", "both", "[NCG correct] tst acc", "[NCG incorrect] tst acc", "[ttest] t-stats", "[ttest] p-value", "dist"])
df.index = pd.MultiIndex.from_tuples(df.index, names=["dataset", "model", "corruption", "level"])

In [49]:
tdd = df[["NCG acc.", "tst acc.", "[NCG correct] tst acc", "[NCG incorrect] tst acc"]]
tdd = tdd.loc[
    [("CIFAR10", "natural", "gaussian", i) for i in range(5)]
    + [("CIFAR10", "AT(1.0)", "gaussian", i) for i in range(5)]
]
tdd.index = tdd.index.droplevel(2)
tdd = tdd.unstack(1)
tdd.columns = tdd.columns.swaplevel(0, 1)
tdd = tdd.sort_index(axis=1, level=0, ascending=False)
text = tdd.to_latex(multirow=True, float_format="%.2f")
text = text.replace("CIFAR10", "C10")
text = text.replace("NCG acc.", "\\thead{NCG \\\\ acc.}")
text = text.replace("tst acc.", "\\thead{tst \\\\ acc.}")
text = text.replace("[NCG correct] tst acc", "\\thead{NCG \\\\ correct \\\\ tst acc.}")
text = text.replace("[NCG incorrect] tst acc", "\\thead{NCG \\\\ incorrect \\\\ tst acc.}")
text = text.replace("[ttest] p-value", "\\thead{p-value}")
text = text.replace("dist", "\\thead{dist}")
text = text.replace("}{l}", "}{c}")
text = text.replace("lrrrrr", "cccccc")
print(text)

\begin{tabular}{lccccccrrr}
\toprule
        & model & \multicolumn{4}{c}{natural} & \multicolumn{4}{c}{AT(1.0)} \\
        & {} & \thead{tst \\ acc.} & \thead{NCG \\ incorrect \\ tst acc.} & \thead{NCG \\ correct \\ tst acc.} & \thead{NCG \\ acc.} & \thead{tst \\ acc.} & \thead{NCG \\ incorrect \\ tst acc.} & \thead{NCG \\ correct \\ tst acc.} & \thead{NCG \\ acc.} \\
dataset & level &          &                         &                       &          &          &                         &                       &          \\
\midrule
\multirow{5}{*}{C10} & 0 &     0.52 &                    0.50 &                  0.68 &     0.13 &     0.21 &                    0.17 &                  0.31 &     0.30 \\
        & 1 &     0.37 &                    0.35 &                  0.49 &     0.12 &     0.20 &                    0.16 &                  0.31 &     0.29 \\
        & 2 &     0.28 &                    0.26 &                  0.38 &     0.13 &     0.20 &                    0.16 &   

## Feature space

In [58]:
# run scripts/get_preds_on_corrupted_feature.py first
preds, nnidxs, dists = joblib.load("./nb_results/feature-mdary-cifar10.pkl")

In [59]:
_, ood_names = get_ood_data_paths("cifar10", "../data/cifar-ood/")

_, trny, _, tsty, _ = auto_var.get_var_with_argument("dataset", "cifar10")

data = {}
for key in preds.keys():
    for ood_name in ood_names:
        count = 10000
        for i in range(5):
            res = preds[key][ood_name][i*count: (i+1)*count].argmax(1)
            nnidx_res = nnidxs[key][ood_name][i*count: (i+1)*count]
            tstacc = (tsty == res)
            ncgacc = (trny[nnidx_res] == res)
            ncgcorrect_tstacc = (tsty == res)[ncgacc == 1]
            ncgincorrect_tstacc = (tsty == res)[ncgacc == 0]
            data[(key[0], key[1], ood_name, i+1)] = [tstacc.mean(), ncgacc.mean(), ncgcorrect_tstacc.mean(), ncgincorrect_tstacc.mean()]

In [60]:
df = pd.DataFrame.from_dict(data, orient="index",
                            columns=["tst acc.", "NCG acc.", "[NCG correct] tst acc", "[NCG incorrect] tst acc"])
df.index = pd.MultiIndex.from_tuples(df.index)
df

Unnamed: 0,Unnamed: 1,Unnamed: 2,Unnamed: 3,tst acc.,NCG acc.,[NCG correct] tst acc,[NCG incorrect] tst acc
CIFAR10,natural,gaussian,1,0.8235,0.9460,0.845137,0.444444
CIFAR10,natural,gaussian,2,0.6659,0.9050,0.696354,0.375789
CIFAR10,natural,gaussian,3,0.4866,0.8631,0.515236,0.306063
CIFAR10,natural,gaussian,4,0.4179,0.8507,0.439638,0.294039
CIFAR10,natural,gaussian,5,0.3574,0.8426,0.373012,0.273825
CIFAR10,...,...,...,...,...,...,...
CIFAR10,AT(1),speckle_noise,1,0.8802,0.9746,0.889185,0.535433
CIFAR10,AT(1),speckle_noise,2,0.7725,0.9579,0.785259,0.482185
CIFAR10,AT(1),speckle_noise,3,0.7153,0.9439,0.729844,0.470588
CIFAR10,AT(1),speckle_noise,4,0.5873,0.9228,0.605440,0.370466


In [61]:
tdd = df[["NCG acc.", "tst acc.", "[NCG correct] tst acc", "[NCG incorrect] tst acc"]]
tdd = tdd.loc[
    [("CIFAR10", "natural", f"gaussian", i) for i in range(1, 6)]
    + [("CIFAR10", "TRADES(2)", f"gaussian", i) for i in range(1, 6)]
]
tdd.index = tdd.index.droplevel(2)
tdd = tdd.unstack(1)
tdd.columns = tdd.columns.swaplevel(0, 1)
tdd = tdd.sort_index(axis=1, level=0, ascending=False)
text = tdd.to_latex(multirow=True, float_format="%.2f")
text = text.replace("CIFAR10", "C10")
text = text.replace("NCG acc.", "\\thead{NCG \\\\ acc.}")
text = text.replace("tst acc.", "\\thead{tst \\\\ acc.}")
text = text.replace("[NCG correct] tst acc", "\\thead{NCG \\\\ correct \\\\ tst acc.}")
text = text.replace("[NCG incorrect] tst acc", "\\thead{NCG \\\\ incorrect \\\\ tst acc.}")
text = text.replace("[ttest] p-value", "\\thead{p-value}")
text = text.replace("dist", "\\thead{dist}")
text = text.replace("}{l}", "}{c}")
text = text.replace("llrrrrrrrr", "lccccc|cccc")
print(text)

\begin{tabular}{lccccc|cccc}
\toprule
        &   & \multicolumn{4}{c}{natural} & \multicolumn{4}{c}{TRADES(2)} \\
        &   & \thead{tst \\ acc.} & \thead{NCG \\ incorrect \\ tst acc.} & \thead{NCG \\ correct \\ tst acc.} & \thead{NCG \\ acc.} &  \thead{tst \\ acc.} & \thead{NCG \\ incorrect \\ tst acc.} & \thead{NCG \\ correct \\ tst acc.} & \thead{NCG \\ acc.} \\
\midrule
\multirow{5}{*}{C10} & 1 &     0.82 &                    0.44 &                  0.85 &     0.95 &      0.82 &                    0.42 &                  0.83 &     0.97 \\
        & 2 &     0.67 &                    0.38 &                  0.70 &     0.91 &      0.66 &                    0.41 &                  0.68 &     0.95 \\
        & 3 &     0.49 &                    0.31 &                  0.52 &     0.86 &      0.48 &                    0.30 &                  0.49 &     0.93 \\
        & 4 &     0.42 &                    0.29 &                  0.44 &     0.85 &      0.41 &                    0.31 &    

In [52]:
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,NCG acc.,tst acc.,both,[NCG correct] tst acc,[NCG incorrect] tst acc,[ttest] t-stats,[ttest] p-value,dist
dataset,model,corruption,level,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
CIFAR10,natural,gaussian,0,0.1265,0.5217,0.0861,0.680632,0.498683,12.846638,1.973668e-36,18.245575
CIFAR10,natural,gaussian,1,0.1242,0.3686,0.0610,0.491143,0.351222,9.278584,2.736960e-20,18.382887
CIFAR10,natural,gaussian,2,0.1287,0.2788,0.0485,0.376845,0.264318,7.860665,3.462525e-15,18.567127
CIFAR10,natural,gaussian,3,0.1379,0.2518,0.0457,0.331400,0.239067,6.846043,5.226890e-12,18.676235
CIFAR10,natural,gaussian,4,0.1372,0.2275,0.0442,0.322157,0.212448,8.207567,2.186847e-16,18.793898
CIFAR10,...,...,...,...,...,...,...,...,...,...,...
CIFAR10,AT(1.0),speckle_noise,0,0.2978,0.2074,0.0912,0.306246,0.165480,14.753360,1.732120e-48,18.191040
CIFAR10,AT(1.0),speckle_noise,1,0.2904,0.2048,0.0892,0.307163,0.162909,14.996475,6.025016e-50,18.295261
CIFAR10,AT(1.0),speckle_noise,2,0.2873,0.2018,0.0867,0.301775,0.161499,14.595599,1.831734e-47,18.365593
CIFAR10,AT(1.0),speckle_noise,3,0.2769,0.1971,0.0830,0.299747,0.157793,14.625035,1.407053e-47,18.538107


In [86]:
model_names = ["natural", "AT(1)", "TRADES(2)", "TRADES(4)", "TRADES(8)"]

count_data = {}
for ds_name, ent in df.groupby(axis=0, level=0):
    print(ds_name)
    ent = ent.unstack(2).unstack(2)
    for model_name in model_names[1:]:
        for ood_name in ood_names:
            for i in range(1, 6):
                model_res = ent.loc[(ds_name, model_name)]
                nat_res = ent.loc[(ds_name, "natural")]
                key = ("tst acc.", ood_name, i)
                count_data[(ds_name, f"{model_name}", ood_name, i)] = \
                    [model_res[key] - nat_res[key], model_res[key] / nat_res[key]]
        
ddff = pd.DataFrame.from_dict(count_data, orient="index", columns=["difference", "ratio"])
ddff.index = pd.MultiIndex.from_tuples(ddff.index.tolist())
ddff = ddff.unstack(2).unstack(2)

CIFAR10


In [87]:
columns = [("difference", ood_name, i) for i in range(1, 5) for ood_name in ood_names]
ratio_columns = [("ratio", ood_name, i) for i in range(1, 5) for ood_name in ood_names]
ddff['better counts'] = (ddff[columns] > 0).sum(1)
ddff['mean'] = ddff[columns].mean(1)
ddff['std'] = ddff[columns].sem(1)
ddff['avg ratio'] = ddff[ratio_columns].mean(1)
ddff['std ratio'] = ddff[ratio_columns].sem(1)
text = ddff[['better counts', 'mean', 'std', 'avg ratio', 'std ratio']].to_latex(multirow=True, float_format="%.2f")
text = text.replace("cifar100coarsewo", "CIFAR100-wo")
text = text.replace("cifar10wo", "CIFAR10-wo")
print(text)

\begin{tabular}{llrrrrr}
\toprule
        &           & better counts & mean &  std & avg ratio & std ratio \\
        &           \\
        &           \\
\midrule
\multirow{4}{*}{CIFAR10} & AT(1) &            70 & 0.00 & 0.00 &      1.01 &      0.00 \\
        & TRADES(2) &            55 & 0.00 & 0.00 &      1.00 &      0.00 \\
        & TRADES(4) &            52 & 0.00 & 0.00 &      1.00 &      0.00 \\
        & TRADES(8) &            55 & 0.00 & 0.00 &      1.00 &      0.00 \\
\bottomrule
\end{tabular}



In [23]:
ood_names = ['gaussian', 'impulse', 'shot', 'defocus', 'motion', 'zoom', 'glass', 'snow', 'fog', 'contrast', 'pixelate', 'brightness',
             'elastic_transform', 'gaussian_blur', 'jpeg_compression', 'saturate', 'spatter', 'speckle_noise']

def get_results(preds, nnidxs, dists):
    data = {}
    
    for key in preds.keys():
        ds_name, model_name = key
        _, trny, _, tsty, _ = auto_var.get_var_with_argument("dataset", "cifar10")
        
        data[(ds_name, model_name, 'tst acc', None)] = (preds[key]['tst'].argmax(1) == tsty).mean()
        for i, ood_name in enumerate(ood_names):
            res = (preds[key][ood_name].argmax(1) == trny[nnidxs[key][ood_name]])
            
            counts = 10000
            for i in range(5):
                data[(ds_name, model_name, f"{ood_name}", int(i)+1)] = res[i * counts: (i+1) * counts].mean()
                data[(ds_name, model_name, f"dist_{ood_name}", int(i)+1)] = [dists[key][ood_name][i * counts: (i+1) * counts].mean()]
    
    return pd.DataFrame.from_dict(data)

In [28]:
preds, nnidxs, dists = joblib.load("nb_results/madry-cifar10-c.pkl")

df = get_results(preds, nnidxs, dists)
#df = df.stack(level=0).stack(level=0)[[('tst acc', None)] + [(ood_name, int(i)) for i in range(1, 6) for ood_name in ood_names] + [f"dist_{ood_name}_{i}" for i in range(1, 6) for ood_name in ood_names]]
df = df.stack(level=0).stack(level=0)
df.index = pd.MultiIndex.from_tuples([i[1:] for i in df.index])
#df = df.reindex(sorted(df.index, key=lambda x: (x[0], model_names.index(x[1]))))

In [29]:
df

Unnamed: 0_level_0,Unnamed: 1_level_0,brightness,brightness,brightness,brightness,brightness,contrast,contrast,contrast,contrast,contrast,...,speckle_noise,speckle_noise,speckle_noise,speckle_noise,speckle_noise,zoom,zoom,zoom,zoom,zoom
Unnamed: 0_level_1,Unnamed: 1_level_1,1,2,3,4,5,1,2,3,4,5,...,1,2,3,4,5,1,2,3,4,5
CIFAR10,AT(0.25),0.157,0.1587,0.1508,0.135,0.1017,0.0976,0.0651,0.0598,0.0559,0.0379,...,0.1403,0.135,0.1317,0.1221,0.1145,0.1247,0.1252,0.1215,0.122,0.1191
CIFAR10,AT(0.5),0.2599,0.2814,0.2904,0.2972,0.3203,0.2827,0.4354,0.4798,0.4923,0.5007,...,0.2281,0.2212,0.2164,0.2042,0.1889,0.2571,0.2638,0.2689,0.2749,0.2855
CIFAR10,AT(1.0),0.2799,0.275,0.2849,0.3087,0.383,0.287,0.2967,0.3134,0.3254,0.3213,...,0.2978,0.2904,0.2873,0.2769,0.2669,0.3103,0.3143,0.3168,0.3191,0.3228
CIFAR10,natural,0.1428,0.1329,0.1232,0.1148,0.1035,0.0943,0.0382,0.0359,0.0724,0.4001,...,0.1138,0.1117,0.1149,0.1238,0.1349,0.0811,0.0783,0.0792,0.0804,0.0882


In [9]:
ood_names = ['gaussian', 'impulse', 'shot', 'defocus', 'motion', 'zoom', 'glass',
             'snow', 'fog', 'contrast', 'pixelate', 'brightness', 'elastic_transform',
             'gaussian_blur', 'jpeg_compression', 'saturate', 'spatter', 'speckle_noise']
dist_data = {}
for ds_name, ent in df.groupby(axis=0, level=0):
    print(ds_name)
    for model_name in model_names:
        if any(df.index.isin([(ds_name, model_name)])):
            dist_data[(ds_name, f"{model_name}")] = (
                (ent.loc[(ds_name, model_name)][[f'dist_{ood_name}_{i}' for ood_name in ood_names for i in range(1, 6)]]).tolist(),
                (ent.loc[(ds_name, model_name)][[f'{ood_name}_{i}' for ood_name in ood_names for i in range(1, 6)]]).tolist()
            )

CIFAR10


In [10]:
count_data = {}
for ds_name, ent in df.groupby(axis=0, level=0):
    print(ds_name)
    if not any(ent.index.isin([(ds_name, "natural")])):
        continue
    for model_name in model_names[1:]:
        if any(df.index.isin([(ds_name, model_name)])):
            count_data[(ds_name, f"{model_name}")] = \
                ((ent.loc[(ds_name, model_name)] - ent.loc[(ds_name, 'natural')])).tolist() \
                +  ((ent.loc[(ds_name, model_name)] / ent.loc[(ds_name, 'natural')])).tolist()
ddff = pd.DataFrame.from_dict(count_data, orient="index")
ddff.index = pd.MultiIndex.from_tuples(ddff.index.tolist())

CIFAR10


In [16]:
ddff['better counts'] = (ddff[np.arange(2, 92)] > 0).sum(1)
ddff['mean'] = ddff[np.arange(2, 92)].mean(1)
ddff['std'] = ddff[np.arange(2, 92)].sem(1)
ddff['avg ratio'] = ddff[np.arange(183, 274)].mean(1)
ddff['std ratio'] = ddff[np.arange(183, 274)].sem(1)
text = ddff[['better counts', 'mean', 'std', 'avg ratio', 'std ratio']].to_latex(multirow=True, float_format="%.2f")
text = text.replace("cifar100coarsewo", "CIFAR100-wo")
text = text.replace("cifar10wo", "CIFAR10-wo")
print(text)

\begin{tabular}{llrrrrr}
\toprule
        &         &  better counts &  mean &  std &  avg ratio &  std ratio \\
\midrule
\multirow{3}{*}{CIFAR10} & AT(0.25) &             51 &  0.00 & 0.01 &       1.14 &       0.04 \\
        & AT(0.5) &             86 &  0.14 & 0.01 &       3.27 &       0.55 \\
        & AT(1.0) &             88 &  0.18 & 0.01 &       3.09 &       0.22 \\
\bottomrule
\end{tabular}



# Imgnet

In [None]:
import faiss
from faiss.contrib.ondisk import merge_ondisk
from tqdm.notebook import tqdm
import torchvision
from torchvision import transforms
import torch

In [None]:
image_folders = []
ood_names = []
for i in range(1, 4):
    image_folders += [
        f'/tmp2/ImageNet-c/noise/gaussian_noise/{i}/',
        f'/tmp2/ImageNet-c/noise/impulse_noise/{i}/',
        f'/tmp2/ImageNet-c/noise/shot_noise/{i}/',
        f'/tmp2/ImageNet-c/blur/defocus_blur/{i}/',
        f'/tmp2/ImageNet-c/blur/motion_blur/{i}/',
        f'/tmp2/ImageNet-c/blur/zoom_blur/{i}/',
        f'/tmp2/ImageNet-c/blur/glass_blur/{i}/',

        f'/tmp2/ImageNet-c/weather/snow/{i}/',
        f'/tmp2/ImageNet-c/weather/fog/{i}/',
        f'/tmp2/ImageNet-c/weather/frost/{i}/',
        f'/tmp2/ImageNet-c/weather/brightness/{i}/',
        f'/tmp2/ImageNet-c/digital/contrast/{i}/',
        f'/tmp2/ImageNet-c/digital/pixelate/{i}/',
        f'/tmp2/ImageNet-c/digital/jpeg_compression/{i}/',
        f'/tmp2/ImageNet-c/digital/elastic_transform/{i}/',
    ]
    ood_names += [f'gaussian_{i}', f'impulse_{i}', f'shot_{i}', f'defocus_{i}', f'motion_{i}', f'zoom_{i}', f'glass_{i}',
                 f'snow_{i}', f'fog_{i}', f'frost_{i}', f'brightness_{i}', f'contrast_{i}', f'pixelate_{i}', f'jpeg_{i}', f'elastic_{i}']

model_names = [
    'natural',
    'TRADES(2)',
    'TRADES(4)',
    'TRADES(8)',
    'AT',
    'ball',
]

ori_dset = torchvision.datasets.ImageFolder("/tmp2/imgnet/ILSVRC2012_img_train/")

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
def np_normalize(X):
    X = X.transpose(0, 3, 1, 2)
    return normalize(torch.from_numpy(X)).numpy().transpose(0, 2, 3, 1)

In [None]:
#preds, nnidxs, dists = joblib.load("nb_results/madry_imgnet.pkl")
preds, nnidxs, dists = {}, {}, {}

ds_name = f'aug10-imgnet100wo{ood_class}'
model_paths = [
    "./pretrained/madry_robustness/imagenet_l2_3_0.pt"
]

trnX, trny, tstX, tsty, (ood1X, ood2X) = auto_var.get_var_with_argument("dataset", ds_name)
oodX = np.concatenate((ood1X, ood2X), axis=0)

dset = torchvision.datasets.ImageFolder(
    "/tmp2/imgnet/ILSVRC2012_img_train/",
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
)

ood_classes = [ood_class]

index = faiss.IndexFlatL2(int(np.prod(trnX.shape[1:])))
index.add(np_normalize(trnX).reshape(len(trnX), -1).astype(np.float32))
oodD, oodI = index.search(np_normalize(oodX).reshape(len(oodX), -1), k=1)

oodXs = []
for image_folder in tqdm(image_folders, desc="[get oodXs]"):
    dset = torchvision.datasets.ImageFolder(
        image_folder,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
    )
    _loader = torch.utils.data.DataLoader(dset, batch_size=64, shuffle=False, num_workers=24)
    tX = np.concatenate([x.numpy() for (x, _) in _loader], axis=0).transpose(0, 2, 3, 1)
    ty = np.concatenate([y.numpy() for (_, y) in _loader])

    valid_classes = []
    for c in ori_dset.classes:
        valid_classes.append(dset.class_to_idx[c])
    assert len(valid_classes) == 100
    valid_classes = np.delete(np.array(valid_classes), ood_classes)        
    valid_idx = np.array([i for i, c in enumerate(ty) if c in valid_classes])
    oodXs.append(tX[valid_idx])

for model_name, model_path in zip(model_names, model_paths):
    if not os.path.exists(model_path):
        print(f"`{model_path}` does not exist. skipping...")
        continue
    key = (ds_name, model_name)
    #if key in preds:
    #    continue
    preds.setdefault(key, {})
    nnidxs.setdefault(key, {})
    dists.setdefault(key, {})

    arch_name = model_path.split("-")[model_path.split("-").index("vtor2") + 1]
    model = getattr(archs, arch_name)(n_classes=len(np.unique(trny)), n_channels=3)
    model.load_state_dict(torch.load(model_path)['model_state_dict'])

    tst_preds = get_preds(model, tstX)
    ood_preds = get_preds(model, oodX)

    preds[key]["tst"] = tst_preds
    preds[key]["ncg"] = ood_preds
    nnidxs[key]["ncg"] = oodI[:, 0]
    dists[key]["ncg"] = np.sqrt(oodD[:, 0])

    for i, tX in enumerate(oodXs):
        if ood_names[i] in preds[key]:
            continue
        D, I = index.search(np_normalize(tX).reshape(len(tX), -1), k=1)
        other_ood_pred = get_preds(model, tX)
        preds[key][ood_names[i]] = other_ood_pred
        nnidxs[key][ood_names[i]] = I[:, 0]
        dists[key][ood_names[i]] = np.sqrt(D[:, 0])

In [None]:
index = faiss.read_index("populated.index")

i = 42
index.make_direct_map()
index.reconstruct(i).reshape(1,-1).astype(np.float32)

xq = fvecs_read("./gist/gist_query.fvecs")

index.nprobe = 80
distances, neighbors = index.search(xq, k)

In [53]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
dset = torchvision.datasets.ImageFolder(
    "/tmp2/imgnet/ILSVRC2012_img_train/",
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
)

batch_size = 32768
index = None
loader = torch.utils.data.DataLoader(dset, batch_size=batch_size, shuffle=False, num_workers=36)
ys = []
for i, (x, y) in tqdm(enumerate(loader), total=len(loader)):
    x = x.numpy().transpose(0, 2, 3, 1).astype(np.float32).reshape(len(x), -1)
    ys.append(y.numpy())
    
    if i == 0:
        index = faiss.index_factory(x.shape[1], "IVF4096,Flat")
        index.train(xb[0:batch_size])
        faiss.write_index(index, "indexes/imgnet_trained.index")
    else:
        index = faiss.read_index("indexes/imgnet_trained.index")
        index.add_with_ids(x, np.arange(i * batch_size, (i + 1) * batch_size))
        faiss.write_index(index, f"indexes/imgnet_block_{i}.index")

def np_normalize(X):
    X = X.transpose(0, 3, 1, 2)
    return normalize(torch.from_numpy(X)).numpy().transpose(0, 2, 3, 1)

  0%|          | 0/40 [00:00<?, ?it/s]

KeyboardInterrupt: 