In [None]:
# Notebook setup: auto-reload edited project modules, configure matplotlib to
# render all text with LaTeX (sans-serif), and import the dataset/model helpers.
%load_ext autoreload
%autoreload 2

from functools import lru_cache
import sys
sys.path.append("../")  # make modules one directory up importable (spurious_ml, params, utils)
from os.path import join
from IPython.display import display
from functools import partial

import matplotlib
# 42 selects TrueType (Type 42) font embedding in PDF/PS output instead of Type 3.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.rc('text', usetex=True)

import matplotlib.pyplot as plt
plt.rcParams['text.usetex'] = True #Let TeX do the typesetting
# NOTE: with usetex enabled, TeX special characters in labels (e.g. '#') must be escaped.
plt.rcParams['text.latex.preamble'] = r"""
\usepackage{sansmath}
\sansmath
""" #Force sans-serif math mode (for axes labels)
plt.rcParams['font.family'] = 'sans-serif' # ... for regular text
plt.rcParams['font.sans-serif'] = 'Helvetica, Avant Garde, Computer Modern Sans serif' # Comma-separated font fallback list

from scipy.stats import ttest_ind, sem
from scipy.fft import fft2
import scipy  # `scipy.special.softmax` is used through this module name
import joblib
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from skimage.segmentation import mark_boundaries
from scipy.stats import pearsonr
import seaborn as sns
import pandas as pd
pd.options.display.float_format = '{:,.3f}'.format
pd.set_option('display.max_rows', 128)


from spurious_ml.datasets import add_spurious_correlation, add_colored_spurious_correlation
from spurious_ml.models.torch_utils import archs
from spurious_ml.variables import auto_var
# NOTE(review): star import — which names come from `params` is not visible here.
from params import *
from utils import params_to_dataframe

# Base font size shared by every figure below.
fontsize=16

In [None]:
def mlp_pred_fn(X, model, device="cuda"):
    """Run `model` over `X` in batches of 256 and return the stacked raw outputs.

    Every sample is flattened to a vector (``X.reshape(len(X), -1)``) before it
    is fed to the model, so this is intended for MLP-style architectures.

    Args:
        X: numpy array whose first axis indexes samples.
        model: torch module. It is moved to `device` and switched to eval mode,
            which mutates the caller's model.
        device: torch device string that the model and the batches are moved to.

    Returns:
        numpy array of model outputs (no softmax), concatenated along axis 0.
    """
    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.reshape(len(X), -1)).float())
    loader = torch.utils.data.DataLoader(dset, batch_size=256)

    model.to(device).eval()
    outputs = []
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for (x, ) in tqdm(loader, desc="[pred_fn]"):
            outputs.append(model(x.to(device)).cpu().numpy())
    return np.concatenate(outputs, axis=0)

def cnn_pred_fn(X, model, device="cuda"):
    """Run `model` over `X` in batches of 128 and return the stacked raw outputs.

    4-D inputs are transposed with axes (0, 3, 1, 2), i.e. channels-last
    (N, H, W, C) to channels-first (N, C, H, W). Inputs of any other rank are
    passed through unchanged.

    Args:
        X: numpy array whose first axis indexes samples.
        model: torch module. It is moved to `device` and switched to eval mode,
            which mutates the caller's model.
        device: torch device string that the model and the batches are moved to.

    Returns:
        numpy array of model outputs (no softmax), concatenated along axis 0.
    """
    if len(X.shape) == 4:
        X = X.transpose(0, 3, 1, 2)

    dset = torch.utils.data.TensorDataset(torch.from_numpy(X).float())
    loader = torch.utils.data.DataLoader(dset, batch_size=128)

    model.to(device).eval()
    outputs = []
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for (x, ) in tqdm(loader, desc="[pred_fn]"):
            outputs.append(model(x.to(device)).cpu().numpy())
    return np.concatenate(outputs, axis=0)

class CLF():
    """Adapter that gives a torch model a scikit-learn-like ``predict(X)`` method.

    Args:
        model: torch module to run.
        pred_fn: batched prediction function with signature
            ``pred_fn(X, model, device=...)``, e.g. ``mlp_pred_fn`` or
            ``cnn_pred_fn``. It is optional for backward compatibility (see ``predict``).
    """
    def __init__(self, model, pred_fn=None):
        self.model = model
        self.pred_fn = pred_fn

    def predict(self, X, device="cuda"):
        """Return ``pred_fn(X, self.model, device=device)``.

        If no ``pred_fn`` was given, this falls back to a global function named
        ``pred_fn`` (the only behavior before). In the visible notebook
        ``pred_fn`` exists only as a local inside ``evaluate``, so the fallback
        raises a descriptive NameError instead of failing obscurely.
        """
        fn = getattr(self, "pred_fn", None)
        if fn is None:
            fn = globals().get("pred_fn")
        if fn is None:
            raise NameError(
                "CLF.predict needs a prediction function: pass CLF(model, pred_fn=...) "
                "(e.g. mlp_pred_fn or cnn_pred_fn) or define a global `pred_fn`."
            )
        return fn(X, self.model, device=device)

In [None]:
@lru_cache(maxsize=None)
def evaluate(ds_name, model_path, arch, spurious_version, seed=0):
    """Predict with a saved model on the clean and the spuriously-modified test set.

    Args:
        ds_name: dataset name resolved through ``auto_var``. This is either a
            base dataset such as "mnist" or "<dset><version>-<n>-<tar_cls>-<seed>".
        model_path: checkpoint path. The loaded object must contain ``'model_state_dict'``.
        arch: name of an architecture class in ``archs``. Names containing "MLP"
            are run with ``mlp_pred_fn`` (flattened inputs); all others with ``cnn_pred_fn``.
        spurious_version: pattern id passed to the spurious-pattern injector.
        seed: seed passed to the injector. It defaults to 0 so that calls which
            omit it do not fail.

    Returns:
        ``(mod_tst_preds, tst_preds, tsty)``: softmax probabilities on the
        pattern-injected test set and on the clean test set, plus the test labels.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist. Callers catch this
            to skip missing models.

    Results are memoized by ``lru_cache``. Callers must not modify the returned
    arrays in place, otherwise later cache hits would see the modified values.
    """
    trnX, trny, tstX, tsty, _ = auto_var.get_var_with_argument("dataset", ds_name)
    n_classes = len(np.unique(trny))
    n_channels = trnX.shape[-1]
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only machines.
    # The prediction function moves the model to its device afterwards.
    res = torch.load(model_path, map_location="cpu")
    model = getattr(archs, arch)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)
    model.load_state_dict(res['model_state_dict'])

    pred_fn = mlp_pred_fn if 'MLP' in arch else cnn_pred_fn

    tst_preds = scipy.special.softmax(pred_fn(tstX, model), axis=1)

    # Inject the spurious pattern into a copy of the test set. 3-channel images
    # use the colored variant of the pattern.
    modified_tstX = np.copy(tstX)
    if n_channels == 3:
        modified_tstX = add_colored_spurious_correlation(modified_tstX, spurious_version, seed)
    else:
        modified_tstX = add_spurious_correlation(modified_tstX, spurious_version, seed)
    mod_tst_preds = scipy.special.softmax(pred_fn(modified_tstX, model), axis=1)
    return mod_tst_preds, tst_preds, tsty

In [None]:
# Sentinel meaning "grad_clip was not passed" (None is a real value: no clipping).
_GRAD_CLIP_FROM_GLOBALS = object()


def get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed, aug=None, bs=None, wd=0.0, folder="train_classifier", grad_clip=_GRAD_CLIP_FROM_GLOBALS):
    """Build the dataset name and checkpoint path of a model trained on poisoned data.

    Args:
        base_dset: base dataset name, e.g. "mnist", "fashion" or "cifar10".
        spurious_version, i, tar_cls, seed: pattern id, number of poisoned samples,
            target class and dataset seed. Together they form ``ds_name``.
        arch, momentum, optimizer, lr, model_seed, wd: training hyper-parameters
            encoded in the file name.
        aug: augmentation tag inserted before "-ce-tor-". It is ignored for
            "mnist" and "fashion".
        bs: batch size. Defaults to 128 for "mnist"/"fashion" and 64 otherwise.
        folder: sub-directory of ``../models``.
        grad_clip: gradient-clipping value inserted after "-70-" (None means no
            clipping). If it is not passed, the notebook-global ``grad_clip`` is
            used (None if that global does not exist), matching the earlier
            implicit behavior.

    Returns:
        ``(ds_name, model_path)``.
    """
    if grad_clip is _GRAD_CLIP_FROM_GLOBALS:
        grad_clip = globals().get("grad_clip")

    ds_name = f"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}"
    is_gray = base_dset in ['mnist', 'fashion']
    if bs is None:
        bs = 128 if is_gray else 64
    clip_part = "" if grad_clip is None else f"-{grad_clip}"
    model_path = f"../models/{folder}/{bs}-{ds_name}-70{clip_part}-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-{model_seed}-{wd}.pt"
    if not is_gray and aug is not None:
        model_path = model_path.replace("-ce-tor-", f"-{aug}-ce-tor-")
    return ds_name, model_path

# MNIST/Fashion

In [None]:
# Main MNIST/Fashion-MNIST sweep. For every (optimizer, target class, spurious
# pattern, architecture, grad-clip) setting, this compares a model trained on the
# clean dataset (baseline) with models trained on datasets in which `i` samples
# carry the spurious pattern (5 dataset seeds each).
# Per setting key:
#   all_results[key]   = spurious scores + clean accuracies + SEMs; each of the
#                        three lists starts with the baseline ("0" poisoned samples);
#   ttest_results[key] = one p-value per entry of `n_samples`.
all_results = {}
ttest_results = {}
# Dataset toggle: the second assignment wins, so this sweep runs on "fashion".
base_dset = "mnist"
base_dset = "fashion"

# A test point counts as affected when adding the spurious pattern raises its
# target-class softmax probability by more than `threshold`.
threshold = 1e-1

# Clean test accuracy of every per-seed model that was found on disk.
all_accs = []

# Alternative sweep configurations kept for reference (commented out).
#n_samples = [3, 5, 10, 20, 100, 2000, 5000]
#optimizers = ['sgd', 'adam']
##optimizers = ['adam']
#spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']
#architecures = ['LargeMLP']
##grad_clips = [None, 0.1]
#grad_clips = [None]
#tar_clses = [0, 1]

# Active configuration.
n_samples = [3, 5, 10, 20, 100, 2000, 5000]
optimizers = ['adam']
spurious_versions = ['v8', 'v19', 'v20']
architecures = ['MLP', 'LargeMLP', 'LargeMLPv2', 'CNN002']
grad_clips = [None]
tar_clses = [0, 1]

#n_samples = [3, 10]
#optimizers = ['sgd', 'adam']
#spurious_versions = ['v8', 'v20']
#architecures = ['LargeMLP']
#grad_clips = [None, 0.1]
#tar_clses = [0]

#optimizers = ['adam']
#spurious_versions = ['v8', 'v20']
#architecures = ['LargeMLP']
#grad_clips = [None, 0.1]

#n_samples = [20]
#optimizers = ['adam']
#spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']
#architecures = ['LargeMLP']
#grad_clips = [None]
#tar_clses = [0,]

for toptimizer in optimizers:
    for tar_cls in tar_clses:
        for spurious_version in spurious_versions:
            for arch in architecures:
                lr = 0.01
                # Vgg archs are always looked up with optimizer "sgd"; SGD runs use momentum 0.9.
                optimizer = toptimizer if 'Vgg' not in arch else "sgd"
                momentum = 0. if optimizer == 'adam' else 0.9

                # NOTE(review): `grad_clip` is a notebook-global loop variable here, and
                # get_model_path (called below without it) reads that global.
                for grad_clip in grad_clips:
                    # Baseline checkpoint: the model trained on the clean dataset `base_dset`.
                    if grad_clip is None:
                        model_path = f"../models/train_classifier/128-{base_dset}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt"
                    else:
                        model_path = f"../models/train_classifier/128-{base_dset}-70-{grad_clip}-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt"
                    mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)
                    #ind = (tsty != tar_cls)
                    # All test points are used (the "not already target class" filter is disabled).
                    ind = np.arange(len(tsty))
                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]
                    # Per test point: did the pattern raise the target-class probability by > threshold?
                    baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                    #baseline_rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))
                    baseline_nos = baseline_rv.mean()
                    baseline_sem = sem(baseline_rv)
                    baseline_acc = (tst_preds.argmax(1) == tsty).mean()
                    baseline_mod_tst_preds = mod_tst_preds

                    # Index 0 of each list holds the baseline (0 poisoned samples).
                    ret, sems, accs, test_res = [baseline_nos], [baseline_sem], [baseline_acc], []
                    for i in n_samples:
                        taccs, tret, ttest_res = [], [], []
                        for seed in range(5):
                            ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed=0)
                            try:
                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)
                                #ind = (tsty != tar_cls)
                                ind = np.arange(len(tsty))
                                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]
                                #mod_tst_preds, tst_preds = mod_tst_preds[tsty == tar_cls], tst_preds[tsty == tar_cls]
                                #tsty = tsty[tsty == tar_cls]
                                #rv = (mod_tst_preds[:, tar_cls] - baseline_mod_tst_preds[:, tar_cls]) > threshold
                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                                #rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))
                                # One-sided Welch t-test: does this model have more affected points than the baseline?
                                ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative="greater")[1])
                                #taccs.append((mod_tst_preds.argmax(1) == tar_cls).mean())
                                taccs.append((tst_preds.argmax(1) == tsty).mean())
                                tret.append(rv.mean())
                                #tret.append(rv.mean() / baseline_nos)
                            except FileNotFoundError:
                                # Missing checkpoints are reported and skipped.
                                print(f"missing {model_path}")
                                #ret.append(-1)
                                #accs.append(-1)
                        if taccs:
                            ret.append(np.mean(tret))
                            sems.append(sem(tret))
                            accs.append(np.mean(taccs))
                            # NOTE(review): only the first loaded seed's p-value is kept — confirm intended.
                            test_res.append(ttest_res[0])
                            all_accs += taccs
                        else:
                            # -1 marks "no model found"; downstream tables treat it as a number.
                            ret.append(-1)
                            sems.append(-1)
                            accs.append(-1)
                            test_res.append(-1)

                    # Key order (optimizer, tar_cls, pattern, arch, grad_clip) defines the index levels used below.
                    key = (optimizer, tar_cls, spurious_version, arch, grad_clip)
                    ttest_results[key] = test_res
                    all_results[key] = ret + accs + sems
                    #ttest_results[(tar_cls, spurious_version, arch)] = test_res
                    #all_results[(tar_cls, spurious_version, arch)] = ret + accs
# Min / max / mean / std of clean test accuracy over every per-seed model found.
# Guarded: min() and max() raise ValueError when no poisoned model was found.
if all_accs:
    print(min(all_accs), max(all_accs), np.mean(all_accs), np.std(all_accs))
else:
    print("all_accs is empty: no poisoned models were found")

In [None]:
# Min / max / mean / std of clean test accuracy over every per-seed model found.
# Guarded: min() and max() raise ValueError when no poisoned model was found.
if all_accs:
    print(min(all_accs), max(all_accs), np.mean(all_accs), np.std(all_accs))
else:
    print("all_accs is empty: no poisoned models were found")

## AUC

In [None]:
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])

opt_name_map = {
    "sgd": "SGD",
    "adam": "Adam",
}
# Short paper names of the spurious patterns, in the same order as `spu_names`
# in the score-figure cell.
pattern_short_names = {'v1': 'S1', 'v3': 'S2', 'v8': 'S3', 'v18': 'R1', 'v19': 'R2', 'v20': 'R3', 'v30': 'Co'}
bar_n_samples = [3, 10]      # poisoned-sample counts shown for each pattern
bar_arch = "LargeMLP"
bar_tar_cls = tar_clses[0]   # same target class as in the saved file names
bar_width = 0.3


def _bar_row(opt, pat, grad_clip):
    """Row of `df` for one setting. Rows trained without clipping are looked up with NaN."""
    clip_key = float("NaN") if grad_clip is None else grad_clip
    return df.loc[(opt, bar_tar_cls, pat, bar_arch, clip_key)]


def _bar_series(opt, grad_clip):
    """(avgs, sems) for every (pattern, n_samples) pair, ordered pattern-major."""
    avgs, sems = [], []
    for pat in spurious_versions:
        row = _bar_row(opt, pat, grad_clip)
        avgs += row[[("spu", n) for n in bar_n_samples]].tolist()
        sems += row[[("sem", n) for n in bar_n_samples]].tolist()
    return avgs, sems


def _grouped_bar_chart(series, labels, path):
    """Draw one bar group per (avgs, sems) entry of `series`, save the figure to `path`, then close it."""
    xnames = [f"{n} {pattern_short_names.get(pat, pat)}" for pat in spurious_versions for n in bar_n_samples]
    xs = np.arange(len(xnames))
    fig, ax = plt.subplots(figsize=(8, 3))
    for k, ((avgs, sems), label) in enumerate(zip(series, labels)):
        ax.bar(xs + bar_width * k, avgs, width=bar_width, yerr=sems, label=label)
    ax.legend(fontsize=fontsize)
    ax.tick_params(axis="y", labelsize=fontsize)
    ax.set_ylabel("spurious score", fontsize=fontsize)
    # Center each tick label under its group of bars.
    ax.set_xticks(xs + bar_width * (len(series) - 1) / 2)
    ax.set_xticklabels(xnames, fontsize=fontsize)
    # text.usetex is enabled, so TeX's special character '#' must be escaped.
    ax.set_xlabel(r"\# spurious examples and the spurious pattern", fontsize=fontsize)
    fig.tight_layout()
    fig.savefig(path)
    plt.show()
    plt.close(fig)


# Optimizer comparison without gradient clipping. Legend labels come from the
# data; a hard-coded ["SGD", "Adam"] would mislabel single-optimizer runs.
_grouped_bar_chart(
    [_bar_series(opt, None) for opt in optimizers],
    [opt_name_map[opt] for opt in optimizers],
    f"figs/bar_charts/difopt_{base_dset}_{tar_clses[0]}.png",
)

# Gradient-clipping comparison for Adam; a grad_clip of None means no clipping.
_grouped_bar_chart(
    [_bar_series("adam", clip_value) for clip_value in grad_clips],
    ["w/o clip" if clip_value is None else "w/ clip" for clip_value in grad_clips],
    f"figs/bar_charts/difclip_{base_dset}_{tar_clses[0]}.png",
)

In [None]:
import re

# Mean spurious score over all poisoned-sample counts, with one column per pattern.
td = df[[("spu", i) for i in n_samples]]
td.columns = td.columns.droplevel(0)
# Index levels: (optimizer, tar_cls, spurious_version, arch, grad_clip).
# Drop the arch level only when it is constant. With several archs, dropping it
# would create duplicate rows and make `unstack` fail.
if td.index.get_level_values(3).nunique() <= 1:
    td.index = td.index.droplevel(3)
# Level 2 is spurious_version in both cases, because arch sits after it.
td = td.mean(1).unstack(2)
text = td.to_latex(multirow=True, float_format="%.3f")
# Strip leading zeros ("0.123" -> ".123") without touching e.g. "10.5".
text = re.sub(r"(?<!\d)0\.", ".", text)
print(text)
td

In [None]:
import re

print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])
display(df)

# Rows become (tar_cls, spurious_version, grad_clip); columns become (optimizer, n_samples).
td = df[[("spu", i) for i in ["0"] + n_samples]].unstack(0)
td.columns = td.columns.droplevel(0)
td.columns = td.columns.swaplevel(0, 1)
td.index = td.index.droplevel(2)  # arch


def key_fn(xs):
    """Sort key for `sort_index`: patterns/archs in configured order, other levels unchanged."""
    values = list(xs)
    if values and all(x in spurious_versions for x in values):
        return [spurious_versions.index(x) for x in values]
    if values and all(x in architecures for x in values):
        return [architecures.index(x) for x in values]
    return xs
td = td.sort_index(axis=0, level=[0, 1], key=key_fn)


def _compact_latex(frame):
    """LaTeX table with 3 decimals, leading zeros stripped and "1.000" shortened to "1.00"."""
    text = frame.to_latex(multirow=True, float_format="%.3f")
    text = re.sub(r"(?<!\d)0\.", ".", text)
    return text.replace("1.000", "1.00")


# One table per optimizer that was actually run. This replaces a bare
# `except: pass` around the SGD table and an unconditional Adam table.
available_opts = set(td.columns.get_level_values(0))
for opt in ["sgd", "adam"]:
    if opt in available_opts:
        print(_compact_latex(td[[(opt, i) for i in n_samples]]))


## Different grad clip

In [None]:
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])
df = df[[("spu", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples]]

clip_xs = np.arange(len(n_samples) + 1)
for optimizer in ['adam']:
    for tar_cls in tar_clses:
        for arch in architecures:
            fig, ax = plt.subplots()
            for spu in spurious_versions:
                # Plot each configured clip value exactly once. Rows trained without
                # clipping are looked up with NaN, as elsewhere in this notebook.
                for grad_clip in grad_clips:
                    clip_key = float("NaN") if grad_clip is None else grad_clip
                    row = df.loc[(optimizer, tar_cls, spu, arch, clip_key)]
                    ax.errorbar(
                        x=clip_xs,
                        y=row[[("spu", i) for i in ["0"] + n_samples]].tolist(),
                        yerr=row[[("sem", i) for i in ["0"] + n_samples]].tolist(),
                        label=f"{spu} {grad_clip}",
                    )
            ax.set_xticks(clip_xs)
            ax.set_xticklabels([0] + n_samples, fontsize=fontsize-2)
            ax.tick_params(axis="y", labelsize=fontsize)
            ax.set_ylabel("spurious score", fontsize=fontsize)
            ax.set_xlabel(r"\# spurious examples", fontsize=fontsize)
            ax.legend(fontsize=fontsize-2)
            fig.tight_layout()
            # With several archs, include the arch in the file name so that
            # figures do not overwrite each other.
            arch_suffix = f"_{arch}" if len(architecures) > 1 else ""
            fig.savefig(f"figs/score_plot/difclip_{base_dset}_{optimizer}_{tar_cls}{arch_suffix}.png")
            plt.show()
            plt.close(fig)

## Different arch

In [None]:
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])
df = df[[("spu", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples]]

arch_name_map = {
    "MLP": "small MLP",
    "LargeMLP": "MLP",
    "LargeMLPv2": "large MLP",
    "CNN002": "CNN",
}
# Plot the mapped archs in the order above, restricted to those actually swept.
plotted_archs = [a for a in arch_name_map if a in architecures]


for optimizer in ['adam']:
    for tar_cls in tar_clses:
        for spu in spurious_versions:
            fig, ax = plt.subplots()
            for arch in plotted_archs:
                row = df.loc[(optimizer, tar_cls, spu, arch, float("NaN"))]
                ax.errorbar(
                    x=[0] + n_samples,
                    y=row[[("spu", i) for i in ["0"] + n_samples]].tolist(),
                    yerr=row[[("sem", i) for i in ["0"] + n_samples]].tolist(),
                    label=arch_name_map[arch],
                )
            # symlog keeps x=0 (the clean baseline) on the same axis as the log-spaced counts.
            ax.set_xscale("symlog")
            ax.set_xlim(-0.25, n_samples[-1]+2000)
            ax.tick_params(labelsize=fontsize)
            ax.set_ylabel("spurious score", fontsize=fontsize)
            ax.set_xlabel(r"\# spurious examples", fontsize=fontsize)
            ax.legend(fontsize=fontsize-2)
            fig.tight_layout()
            fig.savefig(f"figs/score_plot/difarchs_{base_dset}_{spu}_{optimizer}_{tar_cls}.png")
            plt.show()
            plt.close(fig)

## Spurious-score figures

In [None]:
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])
df = df[[("spu", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples]]
spu_names = ['\\textit{S1}', '\\textit{S2}', '\\textit{S3}', '\\textit{R1}', '\\textit{R2}', '\\textit{R3}', '\\textit{Co}']
# `spu_names` is ordered like the full pattern list. Map by pattern id, because
# indexing by position in a shorter `spurious_versions` would mislabel the curves.
spu_name_map = dict(zip(['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30'], spu_names))

# Only optimizers that were actually swept (a hard-coded 'sgd' raises KeyError otherwise).
for optimizer in optimizers:
    for tar_cls in tar_clses:
        fig, ax = plt.subplots()
        for spu in spurious_versions:
            row = df.loc[(optimizer, tar_cls, spu, "LargeMLP", float("NaN"))]
            ax.errorbar(
                x=[0] + n_samples,
                y=row[[("spu", i) for i in ["0"] + n_samples]].tolist(),
                yerr=row[[("sem", i) for i in ["0"] + n_samples]].tolist(),
                label=spu_name_map.get(spu, spu),
            )
        ax.set_xscale("symlog")
        ax.set_xlim(-0.25, n_samples[-1]+2000)
        ax.tick_params(labelsize=fontsize)
        ax.set_ylabel("spurious score", fontsize=fontsize)
        ax.set_xlabel(r"\# spurious examples", fontsize=fontsize)
        ax.legend(loc=(1.02, 0.27), fontsize=fontsize)
        fig.tight_layout()
        fig.savefig(f"figs/score_plot/{base_dset}_{optimizer}_{tar_cls}_log.png")
        plt.show()
        plt.close(fig)
            

In [None]:
# Probe one poisoned model: MNIST, pattern v10, 20 poisoned samples, target class 1, seed 0.
# It uses its own dataset variable so that the sweep's global `base_dset`, which
# later cells print as the label of the sweep results, is not clobbered.
probe_dset = "mnist"
spurious_version, i, tar_cls, seed = "v10", 20, 1, 0
folder, lr, arch, momentum, optimizer = "train_classifier", 0.01, "CNN002", 0., "adam"
ds_name = f"{probe_dset}{spurious_version}-{i}-{tar_cls}-{seed}"
model_path = f"../models/{folder}/128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt"
# `evaluate` needs the seed used to inject the spurious pattern.
mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)

In [None]:
# Spurious score of the probed model: the fraction of test points whose
# target-class probability rises by more than `threshold` once the pattern is added.
np.mean(mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls] > threshold)

In [None]:
print(base_dset)
# Rows are result keys (optimizer, tar_cls, pattern, arch, grad_clip); columns are
# the poisoned-sample counts. Entries are the p-values stored in `ttest_results`
# (-1 where no model was found).
df = pd.DataFrame.from_dict(ttest_results).T
df.columns = n_samples
print(df.to_latex(multirow=True))
df

In [None]:
# Probe one poisoned model: MNIST, pattern v3, 2000 poisoned samples, target class 1, seed 0.
tar_cls = 1
ds_name = f"mnistv3-2000-{tar_cls}-0"
arch_name = "CNN002"
spurious_version = "v3"


model_path = f"../models/train_classifier/128-{ds_name}-70-0.01-ce-tor-{arch_name}-0.0-adam-0-0.0.pt"
# Build the model from `arch_name`, matching the checkpoint above, rather than the
# leftover global `arch`. Seed 0 matches the "-0" suffix of ds_name.
mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch_name, spurious_version, 0)

In [None]:
# Spurious score of the probed model: the fraction of test points whose
# target-class probability rises by more than `threshold` once the pattern is added.
np.mean(mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls] > threshold)

### Model parameter counts

In [None]:
# Trainable-parameter count of each MNIST architecture.
architectures = ['MLP', 'LargeMLP', 'LargeMLPv2', 'CNN002']
ds_name = "mnist"
# The dataset does not depend on the architecture, so load it once outside the loop.
trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument("dataset", ds_name)
n_classes = len(np.unique(trny))
n_channels = trnX.shape[-1]
for arch in architectures:
    model = getattr(archs, arch)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)
    print(arch, sum(p.numel() for p in model.parameters() if p.requires_grad))

# Correlation between spurious-pattern norm and spurious score

## spurious pattern norm

In [None]:
spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']
#base_dsets = ["mnist"]
base_dsets = ["cifar10"]
#base_dsets = ["fashion"]

# Mean L2 norm of the per-image perturbation that each spurious pattern adds to
# the test set, averaged over the 5 dataset seeds.
results = {}
for base_dset in base_dsets:
    for spurious_version in spurious_versions:
        temp = []
        for seed in range(5):
            ds_name = f"{base_dset}{spurious_version}-3-0-{seed}"
            trnX, trny, tstX, tsty, _ = auto_var.get_var_with_argument("dataset", ds_name)
            n_channels = trnX.shape[-1]

            modified_tstX = np.copy(tstX)
            if n_channels == 3:
                modified_tstX = add_colored_spurious_correlation(modified_tstX, spurious_version, seed)
            else:
                modified_tstX = add_spurious_correlation(modified_tstX, spurious_version, seed)
            # One entry per seed. Previously this ran after the seed loop, so only
            # the last seed was used. The float cast guards against unsigned-integer
            # wraparound in the difference (the image dtype is not visible here).
            diff = np.asarray(tstX, dtype=np.float64) - modified_tstX
            temp.append(np.linalg.norm(diff.reshape(len(tstX), -1), ord=2, axis=1).mean())
        results[(base_dset, spurious_version, 'norm')] = np.mean(temp)

In [None]:
# Mean spurious-pattern perturbation norms, keyed by (base_dset, pattern, 'norm').
results

In [None]:
# Spurious-score sweep used for the norm/score correlation: the baseline and the
# poisoned models are evaluated on test points whose label is not the target class.
all_results = {}
ttest_results = {}

# MNIST configuration (uncomment to use it instead of the CIFAR-10 one below).
#base_dset = "mnist"
##base_dset = "fashion"
#spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']
#architecures = ['LargeMLP']
#n_samples =  [3, 5, 10, 20, 100, 2000, 5000]
#aug = None

# CIFAR-10 configuration (active).
base_dset = "cifar10"
spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']
architecures = ['altResNet20Norm02']
n_samples = [3, 5, 10, 20, 100, 500]
aug = "aug01"

# A test point counts as affected when the pattern raises its target-class
# probability by more than this.
threshold = 1e-1


#for optimizer in ['sgd', 'adam']:
for optimizer in ['adam']:
    for tar_cls in [0]:
        for spurious_version in spurious_versions:
            for arch in architecures:
                lr = 0.01
                momentum = 0. if optimizer == 'adam' else 0.9

                # `grad_clip` stays a notebook-global loop variable on purpose:
                # get_model_path (called below without it) reads the global.
                for grad_clip in [None]:
                    if "cifar" in base_dset:
                        if grad_clip is None:
                            model_path = f"../models/train_classifier/128-{base_dset}-70-{lr}-{aug}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt"
                        else:
                            model_path = f"../models/train_classifier/128-{base_dset}-70-{grad_clip}-{lr}-{aug}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt"
                    else:
                        if grad_clip is None:
                            model_path = f"../models/train_classifier/128-{base_dset}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt"
                        else:
                            model_path = f"../models/train_classifier/128-{base_dset}-70-{grad_clip}-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt"

                    # Baseline: the model trained on the clean dataset.
                    mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)
                    ind = (tsty != tar_cls)
                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]
                    baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                    #baseline_rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))
                    baseline_nos = [baseline_rv.mean()]
                    baseline_acc = (tst_preds.argmax(1) == tsty).mean()
                    baseline_mod_tst_preds = mod_tst_preds

                    ret, accs, test_res, avg_norm = [baseline_nos[0]], [baseline_acc], [], []
                    for i in n_samples:
                        taccs, tret, ttest_res, tavg_norm = [], [], [], []
                        for seed in range(5):
                            if "cifar" in base_dset:
                                ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, aug=aug,
                                                                     bs=128, wd=1e-4, model_seed=0)
                            else:
                                ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed=0)
                            try:
                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)
                                ind = (tsty != tar_cls)
                                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]
                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                                ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative="greater")[1])
                                # NOTE: unlike the MNIST sweep, this records the fraction of
                                # pattern-injected inputs predicted as `tar_cls`, not clean accuracy.
                                taccs.append((mod_tst_preds.argmax(1) == tar_cls).mean())
                                tret.append(rv.mean())

                                trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument("dataset", ds_name)
                                # Inject the pattern the same way `evaluate` does: the colored
                                # variant for 3-channel images (previously the grayscale one
                                # was always used). The float cast guards against
                                # unsigned-integer wraparound in the difference.
                                if trnX.shape[-1] == 3:
                                    modified_tstX = add_colored_spurious_correlation(np.copy(tstX), spurious_version, seed)
                                else:
                                    modified_tstX = add_spurious_correlation(np.copy(tstX), spurious_version, seed)
                                diff = np.asarray(tstX, dtype=np.float64) - modified_tstX
                                tavg_norm.append(np.linalg.norm(diff.reshape(len(tstX), -1), axis=1, ord=2))
                            except FileNotFoundError:
                                print(f"missing {model_path}")
                        if taccs:
                            ret.append(np.mean(tret))
                            accs.append(np.mean(taccs))
                            avg_norm.append(np.mean(tavg_norm))
                            test_res.append(ttest_res[0])
                        else:
                            ret.append(-1)
                            accs.append(-1)
                            # Pad every per-setting list so all rows have equal length;
                            # otherwise building the result DataFrame fails.
                            avg_norm.append(-1)
                            test_res.append(-1)

                    key = (optimizer, tar_cls, arch, spurious_version, grad_clip)
                    ttest_results[key] = test_res
                    all_results[key] = ret + accs + avg_norm


In [None]:
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples(
    [("spu", n) for n in [0] + n_samples]
    + [("acc", n) for n in [0] + n_samples]
    + [("norm", n) for n in n_samples]
)
# Keep only the poisoned-model spurious scores and index the rows by pattern alone.
# Key levels are (optimizer, tar_cls, arch, spurious_version, grad_clip), so every
# level except position 3 is dropped.
df = df[[("spu", n) for n in n_samples]]
df.index = df.index.droplevel([0, 1, 2, 4])
df = df.mean(axis=1)
for spurious_version in spurious_versions:
    results[(base_dset, spurious_version, 'score')] = df.loc[spurious_version]

In [None]:
# One row per (base_dset, pattern) with a 'norm' and a 'score' column, rows ordered like `spurious_versions`.
normdf = pd.DataFrame.from_dict(results, orient="index")
normdf.index = pd.MultiIndex.from_tuples(normdf.index)
normdf = normdf.unstack(2)
normdf.columns = normdf.columns.droplevel(0)
normdf = normdf.sort_index(axis=0, level=[1], key=lambda xs: [spurious_versions.index(x) for x in xs])

# Draw on a dedicated figure and close it, so pyplot state left open by other
# cells cannot leak into this plot (or this plot into later ones).
fig, ax = plt.subplots()
ax.scatter(normdf['norm'], normdf['score'], s=100)
ax.set_xlabel("empirical norm", fontsize=fontsize)
ax.set_ylabel("spurious score", fontsize=fontsize)
ax.tick_params(labelsize=fontsize)
fig.tight_layout()
fig.savefig(f"./figs/norm_scatter/{base_dset}_{architecures[0]}_avg_spuscore.png")
plt.show()
plt.close(fig)

In [None]:
# Pearson correlation (and p-value) between the pattern norm and its spurious score.
pearsonr(*[normdf[column].tolist() for column in ("norm", "score")])

# CIFAR10

In [None]:
# Evaluate the CIFAR10 models: for every (optimizer, target class, spurious pattern, arch) compare
# the model trained on the base dataset (baseline) with models trained on datasets containing
# `n` spurious examples (5 dataset seeds each). Results feed the plotting/table cells below.
all_results = {}
ttest_results = {}
base_dset = "cifar10"

# A test point counts as affected when the target-class output of the modified prediction exceeds
# the original prediction by more than this amount (see `baseline_rv` / `rv` below).
threshold = 1e-1

all_accs = []  # test accuracies of every evaluated spurious model, summarized at the end

n_samples = [3, 5, 10, 20, 100, 500]  # numbers of spurious training examples evaluated

aug = "aug01"  # augmentation tag in checkpoint names; None selects the non-augmented file names

# Active configuration (the name `architecures` is what the later cells read).
optimizers = ['adam']
tar_clses = [0, 1]
spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']
architecures = ['altResNet20Norm02']
grad_clips = [None]

# Alternative configurations used for other experiments:
#optimizers = ['sgd', 'adam']
#tar_clses = [0, 1]
#spurious_versions = ['v8', 'v20']
#architecures = ['altResNet20Norm02']
#grad_clips = [None, 0.1]

#aug = "aug01"
#optimizers = ['adam', 'sgd']
#tar_clses = [0, 1]
#spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']
#architecures = ['altResNet20Norm02']
#grad_clips = [None, 0.1]

#optimizers = ['adam']
#tar_clses = [0, 1]
#spurious_versions = ['v8', 'v20']
#architecures = ['ResNet50']
#grad_clips = [None, 0.1]

#optimizers = ['sgd']
#tar_clses = [0, 1]
#spurious_versions = ['v8', 'v20']
#architecures = ['altResNet20Norm02', 'altResNet32Norm02', 'altResNet110Norm02', 'Vgg16Norm02']
#grad_clips = [None]

for toptimizer in optimizers:
    for tar_cls in tar_clses:
        for spurious_version in spurious_versions:
            #for arch in ['Vgg16', 'ResNet50']:
            for arch in architecures:
                # Vgg models are always evaluated with their SGD checkpoints; Adam runs use no momentum.
                optimizer = toptimizer if 'Vgg' not in arch else "sgd"
                momentum = 0. if optimizer == 'adam' else 0.9
                if 'Vgg' in arch:
                    lr = 0.01
                else:
                    lr = 0.01 if optimizer == 'adam' else 0.1

                # Baseline checkpoint: trained on the base dataset without spurious examples.
                if aug is None:
                    model_path = f"../models/train_classifier/128-{base_dset}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt"
                else:
                    model_path = f"../models/train_classifier/128-{base_dset}-70-{lr}-{aug}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt"

                # presumably `mod_tst_preds` are predictions on the test set with the spurious pattern
                # applied and `tst_preds` on the unmodified test set — confirm in `evaluate`.
                mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, seed=0)
                #ind = (tsty != tar_cls)
                ind = np.arange(len(tsty))  # keep every test point (the non-target filter is disabled)
                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]
                baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                # baseline_rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))
                baseline_nos = baseline_rv.mean()  # baseline spurious score
                baseline_sem = sem(baseline_rv)
                baseline_acc = (tst_preds.argmax(1) == tsty).mean()
                baseline_mod_tst_preds = mod_tst_preds

                for grad_clip in grad_clips:
                    # Every list starts with the baseline entry (the "0 spurious examples" column).
                    ret, sems, accs, test_res = [baseline_nos], [baseline_sem], [baseline_acc], []
                    for i in n_samples:
                        taccs, tret, ttest_res = [], [], []
                        for seed in range(5):
                            # NOTE(review): `grad_clip` is not passed to get_model_path, so every
                            # grad_clip value evaluates the same checkpoints — confirm intended.
                            ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed,
                                                                 arch, momentum, optimizer, lr, model_seed=0,
                                                                 aug=aug, bs=128, wd=0.0001)

                            try:
                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed=seed)
                                # Flag models whose test accuracy is suspiciously low.
                                if (tst_preds.argmax(1) == tsty).mean() < 0.7:
                                    print(model_path, (tst_preds.argmax(1) == tsty).mean())
                                #ind = (tsty != tar_cls)
                                ind = np.arange(len(tsty))
                                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]
                                #rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))
                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                                # One-sided Welch t-test: is this model affected more often than the baseline?
                                ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative="greater")[1])
                                taccs.append((tst_preds.argmax(1) == tsty).mean())
                                #tret.append(rv.mean() / baseline_nos)
                                tret.append(rv.mean())
                            except FileNotFoundError:
                                print(f"missing {model_path}")

                        if taccs:
                            # NOTE(review): only the first found seed's p-value is kept.
                            test_res.append(ttest_res[0])
                            sems.append(sem(tret))
                            ret.append(np.mean(tret))
                            accs.append(np.mean(taccs))
                            all_accs += taccs
                        else:
                            # -1 marks settings for which no checkpoint was found.
                            test_res.append(-1)
                            sems.append(-1)
                            ret.append(-1)
                            accs.append(-1)

                    # NOTE(review): this key omits grad_clip, so with several grad_clips only the last survives.
                    ttest_results[(optimizer, tar_cls, spurious_version, arch)] = test_res
                    # Layout: (1 + len(n_samples)) scores, then accuracies, then SEMs.
                    all_results[(optimizer, tar_cls, spurious_version, grad_clip, arch)] = ret + accs + sems

# NOTE(review): raises ValueError if no spurious checkpoint was found (all_accs empty).
print(min(all_accs), max(all_accs), np.mean(all_accs), np.std(all_accs))

## AUC

In [None]:
print(base_dset)
# Summary frame: one row per (optimizer, tar_cls, spurious_version, grad_clip, arch) key of
# `all_results`; columns hold spurious score, accuracy and SEM for the baseline ("0") followed by
# each entry of `n_samples`.
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])

opt_name_map = {
    "sgd": "SGD",
    "adam": "Adam",
}
# Paper names of the spurious patterns (same pairing as `spu_names` in the score figures).
pattern_name_map = {
    "v1": "S1", "v3": "S2", "v8": "S3",
    "v18": "R1", "v19": "R2", "v20": "R3", "v30": "Co",
}
bar_arch = 'altResNet20Norm02'
bar_tar_cls = tar_clses[0]  # same class as used in the output file names
bar_n_samples = [3, 10]  # numbers of spurious examples shown for every pattern
width = 0.3
# One bar group per (pattern, #examples) pair, derived from the active configuration instead of
# assuming exactly two patterns.
xnames = [f"{n}, {pattern_name_map.get(pat, pat)}" for pat in spurious_versions for n in bar_n_samples]
xlocs = np.arange(len(xnames))
bar_figsize = (max(8, 2 * len(xnames)), 3)


def _bar_series(opt, grad_clip):
    """Return (scores, sems) over `bar_n_samples` for every pattern, for one optimizer / clip setting."""
    # A missing grad_clip (None) is stored as NaN in the frame's index.
    clip_key = float("NaN") if grad_clip is None else grad_clip
    avgs, errs = [], []
    for pat in spurious_versions:
        row = df.loc[(opt, bar_tar_cls, pat, clip_key, bar_arch)]
        avgs += row[[("spu", n) for n in bar_n_samples]].tolist()
        errs += row[[("sem", n) for n in bar_n_samples]].tolist()
    return avgs, errs


def _plot_bars(labels, series, path):
    """Draw side-by-side bars, one bar per (label, (scores, sems)) at every x location, and save."""
    plt.figure(figsize=bar_figsize)
    for k, (label, (avgs, errs)) in enumerate(zip(labels, series)):
        plt.bar(xlocs + width * k, avgs, width=width, yerr=errs, label=label)
    # The legend uses the bar labels, so it always matches the plotted series.
    plt.legend(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    plt.ylabel("spurious score", fontsize=fontsize)
    # Center each tick under its group of bars.
    plt.xticks(xlocs + width * (len(labels) - 1) / 2, xnames, fontsize=fontsize)
    # '#' must be escaped because text is rendered with usetex.
    plt.xlabel(r"\# spurious examples and the spurious pattern", fontsize=fontsize)
    plt.tight_layout()
    plt.savefig(path)
    plt.show()
    plt.close()


# Spurious score per optimizer (no gradient clipping).
_plot_bars(
    [opt_name_map[opt] for opt in optimizers],
    [_bar_series(opt, None) for opt in optimizers],
    f"figs/bar_charts/difopt_{base_dset}_{tar_clses[0]}.png",
)

# Spurious score with / without gradient clipping (Adam); labels follow the order of `grad_clips`.
_plot_bars(
    ["w/o clip" if gc is None else "w/ clip" for gc in grad_clips],
    [_bar_series("adam", gc) for gc in grad_clips],
    f"figs/bar_charts/difclip_{base_dset}_{tar_clses[0]}.png",
)

In [None]:

# Mean spurious score over all non-zero sample counts, one column per spurious pattern
# (the arch level of the row index is dropped).
td = (
    df[[("spu", i) for i in n_samples]]
    .droplevel(0, axis=1)
    .droplevel(4, axis=0)
    .mean(axis=1)
    .unstack(2)
)
# Strip leading zeros ("0.123" -> ".123") for the paper table.
text = td.to_latex(multirow=True, float_format="%.3f").replace("0.", ".")
print(text)
td

In [None]:
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])
display(df)

# Move the optimizer level into the columns, giving (optimizer, n_samples) column pairs, and drop
# the arch level from the rows.
td = (
    df[[("spu", i) for i in ["0"] + n_samples]]
    .unstack(0)
    .droplevel(0, axis=1)
    .swaplevel(0, 1, axis=1)
    .droplevel(3, axis=0)
)

# LaTeX table of the Adam scores (swap "adam" for "sgd" to export the SGD block instead);
# leading zeros are stripped for the paper.
adam_cols = [("adam", i) for i in ["0"] + n_samples]
text = td[adam_cols].to_latex(multirow=True, float_format="%.3f").replace("0.", ".")
print(text)

## Different grad clip

In [None]:
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])
df = df[[("spu", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples]]

spu_cols = [("spu", i) for i in ["0"] + n_samples]
sem_cols = [("sem", i) for i in ["0"] + n_samples]
# Evenly spaced x positions, labelled with the sample counts below.
clip_xs = np.arange(len(n_samples) + 1)

for optimizer in ['adam']:
    for tar_cls in [0, 1]:
        for arch in architecures:
            for spu in spurious_versions:
                for grad_clip in grad_clips:
                    # One curve per clip setting; a missing grad_clip (None) is stored as NaN in the index.
                    row = df.loc[(optimizer, tar_cls, spu, float("NaN") if grad_clip is None else grad_clip, arch)]
                    plt.errorbar(
                        x=clip_xs,
                        y=row[spu_cols].tolist(),
                        yerr=row[sem_cols].tolist(),
                        label=f"{spu} {grad_clip}",
                    )
            plt.yticks(fontsize=fontsize)
            plt.xticks(ticks=clip_xs, labels=[0] + n_samples, fontsize=fontsize-2)
            plt.ylabel("spurious score", fontsize=fontsize)
            # '#' must be escaped because text is rendered with usetex.
            plt.xlabel(r"\# spurious examples", fontsize=fontsize)
            plt.legend(fontsize=fontsize-2)
            plt.tight_layout()
            # NOTE(review): the file name has no arch, so with several architectures later figures
            # overwrite earlier ones.
            plt.savefig(f"figs/score_plot/difclip_{base_dset}_{optimizer}_{tar_cls}.png")
            plt.show()
            plt.close()

## Different arch

In [None]:
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])
df = df[[("spu", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples]]
arch_names = {
    "altResNet20Norm02": "ResNet20",
    "altResNet32Norm02": "ResNet32",
    "altResNet110Norm02": "ResNet110",
    "Vgg16Norm02": "Vgg16",
}

spu_cols = [("spu", i) for i in ["0"] + n_samples]
sem_cols = [("sem", i) for i in ["0"] + n_samples]
xs = [0] + n_samples  # x data are the actual numbers of spurious examples

for toptimizer in ['sgd']:
    for tar_cls in [0, 1]:
        for spu in spurious_versions:
            for arch in architecures:
                # Vgg checkpoints only exist for SGD.
                optimizer = toptimizer if 'Vgg' not in arch else "sgd"
                row = df.loc[(optimizer, tar_cls, spu, float("NaN"), arch)]
                plt.errorbar(
                    x=xs,
                    y=row[spu_cols].tolist(),
                    yerr=row[sem_cols].tolist(),
                    label=arch_names[arch],
                )
            plt.yticks(fontsize=fontsize)
            plt.xscale("symlog")
            # Set the ticks after the scale (set_xscale installs the scale's default locator) and
            # place them at the sample counts, matching the x data.
            plt.xticks(ticks=xs, labels=xs, fontsize=fontsize)
            plt.xlim(-0.2, n_samples[-1]+100)
            plt.ylabel("spurious score", fontsize=fontsize)
            # '#' must be escaped because text is rendered with usetex.
            plt.xlabel(r"\# spurious examples", fontsize=fontsize)
            plt.legend(fontsize=fontsize-2)
            plt.tight_layout()
            # Name the file after the requested optimizer, not the one leaked from the last arch.
            plt.savefig(f"figs/score_plot/difarchs_{base_dset}_{spu}_{toptimizer}_{tar_cls}.png")
            plt.show()
            plt.close()

## Score figs

In [None]:
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples])
df = df[[("spu", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples]]
spu_names = ['\\textit{S1}', '\\textit{S2}', '\\textit{S3}', '\\textit{R1}', '\\textit{R2}', '\\textit{R3}', '\\textit{Co}']
# Pattern id -> paper name, so labels stay correct even if `spurious_versions` is later
# reassigned to a subset (e.g. ['v20']).
spu_name_map = dict(zip(['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30'], spu_names))


arch = "altResNet20Norm02"
spu_cols = [("spu", i) for i in ["0"] + n_samples]
sem_cols = [("sem", i) for i in ["0"] + n_samples]

for optimizer in ['adam']:
    for tar_cls in [0, 1]:
        for spu in spurious_versions:
            row = df.loc[(optimizer, tar_cls, spu, float("NaN"), arch)]
            plt.errorbar(
                x=[0] + n_samples,
                y=row[spu_cols].tolist(),
                yerr=row[sem_cols].tolist(),
                label=spu_name_map.get(spu, spu),
            )
        plt.yticks(fontsize=fontsize)
        plt.xticks(fontsize=fontsize)
        plt.xscale("symlog")
        plt.xlim(-0.2, n_samples[-1]+100)
        plt.ylabel("spurious score", fontsize=fontsize)
        plt.xlabel(r"\# spurious examples", fontsize=fontsize)
        plt.legend(loc=(1.02, 0.27), fontsize=fontsize)
        plt.tight_layout()
        if aug is None:
            plt.savefig(f"figs/score_plot/{base_dset}_{arch}_{optimizer}_{tar_cls}_log.png")
        else:
            plt.savefig(f"figs/score_plot/{base_dset}_{aug}_{arch}_{optimizer}_{tar_cls}_log.png")
        plt.show()
        plt.close()

# incremental retraining

In [None]:
def get_retrain_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, folder="incremental_retraining", gi=None, depth=None):
    """Return ``(ds_name, model_path)`` for a retrained or group-influence checkpoint.

    The checkpoint file name embeds the name of the ``-70-`` model it started from, so both parts
    are built from the same settings.

    Args:
        base_dset: base dataset name; exactly "cifar10" selects the "aug01" / "0.0001" naming,
            anything else the no-augmentation / "0.0" naming.
        spurious_version, i, tar_cls, seed: identify the spurious dataset
            ``f"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}"``.
        arch, momentum, optimizer, lr: settings of the original training run.
        folder: "incremental_retraining" or "group_influence".
        gi, depth: group-influence settings encoded in the file name (folder="group_influence"
            only); ``depth`` is only included when ``gi`` is given.

    Raises:
        ValueError: if ``folder`` is not one of the two supported values.
    """
    ds_name = f"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}"
    is_cifar10 = "cifar10" == base_dset
    aug_part = "-aug01" if is_cifar10 else ""
    wd = "0.0001" if is_cifar10 else "0.0"

    def _trained_name(model_arch):
        # File name of the original checkpoint that retraining / influence starts from.
        return f"128-{ds_name}-70-{lr}{aug_part}-ce-tor-{model_arch}-{momentum}-{optimizer}-0-{wd}.pt"

    if folder == "incremental_retraining":
        name = f"128-{ds_name}-140-{lr}{aug_part}-ce-tor-{arch}-{_trained_name(arch)}-{momentum}-{optimizer}-0-{wd}.pt"
    elif folder == "group_influence":
        # NOTE: CIFAR10 group-influence checkpoints are always named with altResNet20Norm02,
        # whatever `arch` is passed.
        gi_arch = "altResNet20Norm02" if is_cifar10 else arch
        if gi is None:
            hparams = ""
        elif depth is None:
            hparams = f"-{gi}"
        else:
            hparams = f"-{depth}-{gi}"
        name = f"128-{ds_name}{hparams}{aug_part}-ce-tor-{gi_arch}-{_trained_name(gi_arch)}-sgd-0.pt"
    else:
        # Previously fell through and raised UnboundLocalError on `model_path`.
        raise ValueError(f"unknown folder: {folder!r}")
    return ds_name, f"../models/{folder}/{name}"


In [None]:
# Compare, per number of spurious examples, the original spurious model ("spu"), the incrementally
# retrained model ("inc") and the group-influence model ("gi") against the base-dataset baseline.
all_results = {}
ttest_results = {}

# NOTE: each block below overwrites the previous one; only the last (cifar10) is effective.
base_dset = "mnist"
gi = None
base_dset = "fashion"
gi = 150.
threshold = 1e-1
#n_samples = [3, 5, 10, 20, 100, 2000, 5000]
n_samples = [3, 5, 10, 20, 100]
architectures = ['LargeMLP']

base_dset = "cifar10"
gi = 1000.
threshold = 1e-1
n_samples = [3, 5, 10, 20, 100]
architectures = ['altResNet20Norm02']


for optimizer in ['adam']:
#for optimizer in ['adam']:
    for tar_cls in [0, 1]:
        for spurious_version in ['v20']:
        #for spurious_version in ['v1', 'v3', 'v8']:
            for arch in architectures:
                lr = 0.01
                momentum = 0. if optimizer == 'adam' else 0.9

                for grad_clip in [None]:
                    # Baseline checkpoint trained on the base dataset (CIFAR10 names carry aug01 / wd 0.0001).
                    if base_dset == "cifar10":
                        if grad_clip is None:
                            model_path = f"../models/train_classifier/128-{base_dset}-70-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt"
                        else:
                            model_path = f"../models/train_classifier/128-{base_dset}-70-{grad_clip}-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt"
                    else:
                        if grad_clip is None:
                            model_path = f"../models/train_classifier/128-{base_dset}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt"
                        else:
                            model_path = f"../models/train_classifier/128-{base_dset}-70-{grad_clip}-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt"

                    mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)
                    # Baseline uses only test points whose label is not the target class.
                    # NOTE(review): the per-seed scores below use all test points, so the baseline and
                    # the spurious/retrained/influence scores are computed on different subsets — confirm.
                    ind = (tsty != tar_cls)
                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]
                    baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                    #baseline_rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))
                    baseline_sem = sem(baseline_rv)
                    baseline_nos = [baseline_rv.mean()]
                    baseline_acc = (tst_preds.argmax(1) == tsty).mean()
                    baseline_mod_tst_preds = mod_tst_preds


                    # Every list starts with the baseline entry (the "0 spurious examples" column).
                    ret, inc_rets, gi_rets, inc_sems, gi_sems, sems, accs, test_res = [baseline_nos[0]], [baseline_nos[0]], [baseline_nos[0]], [baseline_sem], [baseline_sem], [baseline_sem], [baseline_acc], []
                    for i in n_samples:
                        taccs, tret, tinc_rets, tgi_rets, ttest_res = [], [], [], [], []
                        for seed in range(5):
                            ds_name, inc_model_path = get_retrain_model_path(base_dset, spurious_version, i, tar_cls, seed, arch,
                                                                             momentum, optimizer, lr)
                            wd = 0.0001 if "cifar10" in base_dset else 0.0
                            # Group-influence depth in the checkpoint name: 200 for CIFAR10, 500 for
                            # fashion with 100 examples, otherwise omitted.
                            depth = 200 if "cifar10" in base_dset else None
                            if i == 100 and "fashion" in base_dset:
                                depth = 500
                            _, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed=0,
                                                          wd=wd, bs=128)
                            _, gi_model_path = get_retrain_model_path(base_dset, spurious_version, i, tar_cls, seed, arch,
                                                                      momentum, optimizer, lr, folder="group_influence",
                                                                      gi=gi, depth=depth)

                            # Original model trained with the spurious examples.
                            try:
                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)
                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                                ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative="greater")[1])
                                # NOTE(review): this is the fraction of modified test points predicted as
                                # tar_cls, not a clean accuracy, despite being stored as "acc".
                                taccs.append((mod_tst_preds.argmax(1) == tar_cls).mean())
                                tret.append(rv.mean())
                            except FileNotFoundError:
                                print(f"missing {model_path}")

                            # Incrementally retrained model.
                            try:
                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, inc_model_path, arch, spurious_version, seed)
                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                                tinc_rets.append(rv.mean())
                            except FileNotFoundError:
                                print(f"missing {inc_model_path}")

                            # Group-influence model; warn when its test accuracy drops below 0.8.
                            try:
                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, gi_model_path, arch, spurious_version, seed)
                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold
                                tgi_rets.append(rv.mean())
                                gi_acc = (tst_preds.argmax(1) == tsty).mean()
                                if gi_acc < 0.8:
                                    print("Bad acc", gi_acc)
                                    print(gi_model_path)
                            except FileNotFoundError:
                                print(f"missing {gi_model_path}")
                                #ret.append(-1)
                                #accs.append(-1)
                        # NOTE(review): only `taccs` is checked; empty tinc_rets / tgi_rets give NaN means.
                        if taccs:
                            ret.append(np.mean(tret))
                            inc_rets.append(np.mean(tinc_rets))
                            gi_rets.append(np.mean(tgi_rets))
                            inc_sems.append(sem(tinc_rets))
                            gi_sems.append(sem(tgi_rets))
                            sems.append(sem(tret))
                            accs.append(np.mean(taccs))
                            # Only the first found seed's p-value is kept.
                            test_res.append(ttest_res[0])
                        else:
                            # -1 marks settings for which no original checkpoint was found.
                            ret.append(-1)
                            sems.append(-1)
                            inc_rets.append(-1)
                            inc_sems.append(-1)
                            gi_rets.append(-1)
                            gi_sems.append(-1)
                            accs.append(-1)
                            test_res.append(-1)

                    key = (optimizer, tar_cls, arch, spurious_version, grad_clip)
                    ttest_results[key] = test_res
                    # Layout: 7 groups of (1 + len(n_samples)) values, in this order.
                    all_results[key] = ret + accs + sems + inc_rets + inc_sems + gi_rets + gi_sems

In [None]:
print(base_dset)
sample_keys = ["0"] + n_samples
df = pd.DataFrame.from_dict(all_results).transpose()
# Matches the layout ret + accs + sems + inc_rets + inc_sems + gi_rets + gi_sems of the evaluation cell.
df.columns = pd.MultiIndex.from_tuples(
    [(group, i) for group in ["spu", "acc", "sem", "inc spu", "inc sem", "gi spu", "gi sem"] for i in sample_keys]
)
df = df[[(group, i) for group in ["spu", "sem", "inc spu", "inc sem", "gi spu", "gi sem"] for i in sample_keys]]

arch = architectures[0]  # the architecture evaluated in the cell above
spurious_versions = ['v20']
# Small vertical jitter keeps overlapping curves distinguishable; seeded so figures are reproducible.
jitter_rng = np.random.default_rng(0)
# (score group, sem group, legend label, jitter?)
retrain_curves = [
    ("spu", "sem", "original", True),
    ("inc spu", "inc sem", "retrained", True),
    ("gi spu", "gi sem", "influence", False),
]

for optimizer in ['adam']:
    for tar_cls in [0, 1]:
        for spu in spurious_versions:
            row = df.loc[(optimizer, tar_cls, arch, spu, float("NaN"))]
            for score_group, sem_group, label, jitter in retrain_curves:
                y = row[[(score_group, i) for i in sample_keys]].values
                if jitter:
                    y = y + 0.02 * jitter_rng.random(y.shape)
                yerr = row[[(sem_group, i) for i in sample_keys]].values
                plt.errorbar(x=[0] + n_samples, y=y, yerr=yerr, label=label)

        plt.yticks(fontsize=fontsize)
        plt.xscale("symlog")
        # Set the ticks after the scale (set_xscale installs the scale's default locator) and
        # place them at the sample counts, matching the x data.
        plt.xticks(ticks=[0] + n_samples, labels=[0] + n_samples, fontsize=fontsize)
        plt.xlim(-0.2, n_samples[-1]+100)
        plt.ylabel("spurious score", fontsize=fontsize)
        plt.xlabel(r"\# spurious examples", fontsize=fontsize)
        plt.legend(fontsize=fontsize)
        plt.tight_layout()
        plt.savefig(f"figs/score_plot/retrain_{base_dset}_{arch}_{optimizer}_{tar_cls}.png")
        plt.show()
        plt.close()

In [None]:
# Display the global `fontsize` used for all figure labels and ticks.
fontsize

In [None]:
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
# Must match the 7-group layout ret + accs + sems + inc_rets + inc_sems + gi_rets + gi_sems written
# by the evaluation cell (the old 5-group header raised a length mismatch).
df.columns = pd.MultiIndex.from_tuples(
    [(group, i) for group in ["spu", "acc", "sem", "inc spu", "inc sem", "gi spu", "gi sem"] for i in ["0"] + n_samples]
)
display(df)

# Optimizer level moves into the columns; rows keep (tar_cls, spurious_version).
td = df[[("spu", i) for i in ["0"] + n_samples]].unstack(0)
td.columns = td.columns.droplevel(0)
td.columns = td.columns.swaplevel(0, 1)
td.index = td.index.droplevel(3)  # grad_clip
td.index = td.index.droplevel(1)  # arch

# LaTeX table of the Adam scores (use "sgd" for the SGD block); leading zeros stripped for the paper.
text = td[[("adam", i) for i in ["0"] + n_samples]].to_latex(multirow=True, float_format="%.3f")
text = text.replace("0.", ".")
print(text)

## CIFAR10

In [None]:
# Quick spurious-score check of the ResNet50 CIFAR10 models (batch size 64, no augmentation),
# one value per number of spurious examples; -1 marks a missing checkpoint.
all_results = {}
#spurious_version = "v4"
n_samples = [3, 5, 10, 20, 500]
ds_seed = 0  # seed of the spurious datasets evaluated here (the "-0" suffix of ds_name)
score_threshold = 1e-9  # any increase of the target-class output counts
for spurious_version in ["v8"]:
    for tar_cls in [0, 1]:
        for arch in ['ResNet50']:
            for optimizer in ['adam', 'sgd']:
                momentum = 0. if optimizer == 'adam' else 0.9
                lr = 0.01
                for grad_clip in [None]:
                    scores = []
                    for i in n_samples:
                        ds_name = f"cifar10{spurious_version}-{i}-{tar_cls}-{ds_seed}"
                        clip_part = "" if grad_clip is None else f"{grad_clip}-"
                        model_path = f"../models/train_classifier/64-{ds_name}-70-{clip_part}{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt"
                        try:
                            # Evaluate with the dataset's own seed instead of a `seed` leaked from an earlier cell.
                            mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, ds_seed)
                            scores.append(((mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > score_threshold).mean())
                        except FileNotFoundError:
                            print(f"missing {model_path}")
                            scores.append(-1)
                    all_results[(spurious_version, tar_cls, optimizer, grad_clip, arch)] = scores
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = n_samples
df

In [None]:
# Plot original vs. incrementally retrained spurious scores for the ResNet50 models.
# NOTE(review): this header expects 5 groups x (1 + len(n_samples)) values per `all_results`
# entry, keyed (optimizer, tar_cls, arch, spurious_version, grad_clip). The preceding cell
# overwrites `all_results` with len(n_samples) scores keyed (spurious_version, tar_cls, optimizer,
# grad_clip, arch), so running this directly after it fails — confirm which results it is meant for.
print(base_dset)
df = pd.DataFrame.from_dict(all_results).transpose()
df.columns = pd.MultiIndex.from_tuples([("spu", i) for i in ["0"] + n_samples] + [("acc", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples] + [("inc spu", i) for i in ["0"] + n_samples] + [("inc sem", i) for i in ["0"] + n_samples])
df = df[[("spu", i) for i in ["0"] + n_samples] + [("sem", i) for i in ["0"] + n_samples] + [("inc spu", i) for i in ["0"] + n_samples] + [("inc sem", i) for i in ["0"] + n_samples]]

arch = "ResNet50"
spurious_versions = ['v8']

for optimizer in ['adam']:
    for tar_cls in [0, 1]:
        for spu in spurious_versions:
            plt.errorbar(
                x=np.arange(len(n_samples) + 1),
                y=df.loc[(optimizer, tar_cls, arch, spu, float("NaN"))][[("spu", i) for i in ["0"] + n_samples]].tolist(),
                yerr=df.loc[(optimizer, tar_cls, arch, spu, float("NaN"))][[("sem", i) for i in ["0"] + n_samples]].tolist(),
                label="original",
            )
            plt.errorbar(
                x=np.arange(len(n_samples) + 1),
                y=df.loc[(optimizer, tar_cls, arch, spu, float("NaN"))][[("inc spu", i) for i in ["0"] + n_samples]].tolist(),
                yerr=df.loc[(optimizer, tar_cls, arch, spu, float("NaN"))][[("inc sem", i) for i in ["0"] + n_samples]].tolist(),
                label="retrained",
            )
        plt.yticks(fontsize=fontsize)
        plt.xticks(ticks=np.arange(len(n_samples) + 1), labels=[0] + n_samples, fontsize=fontsize-2)
        #plt.xscale("linear")
        plt.ylabel("spurious score", fontsize=fontsize)
        # NOTE(review): usetex is enabled and '#' is unescaped here (other cells use "\#").
        plt.xlabel("# spurious examples", fontsize=fontsize)
        plt.legend(loc=(1.02, 0.27), fontsize=fontsize)
        plt.tight_layout()
        plt.savefig(f"figs/score_plot/retrain_cifar10_{arch}_{optimizer}_{tar_cls}.png")
        plt.show()
        plt.close()

In [None]:
# Mean L2 distance between the original training images and their spuriously modified versions,
# per (spurious pattern, target class) and number of spurious examples.
# The originals and the spurious datasets must come from the same base dataset (this cell used
# to load MNIST originals while reading CIFAR10 spurious datasets).
norm_base_dset = "cifar10"
ori_trnX, _, _, _, _ = auto_var.get_var_with_argument("dataset", norm_base_dset)

n_samples = [5, 10, 20, 100,]
norm_results = {}
for spurious_version in ["v1", 'v3', "v6"]:
    for tar_cls in [0, 1]:
        ret = []
        for i in n_samples:
            ds_name = f"{norm_base_dset}{spurious_version}-{i}-{tar_cls}-0"
            trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument("dataset", ds_name)
            oriX = ori_trnX[spurious_ind].reshape(len(spurious_ind), -1)
            spuriousX = trnX[spurious_ind].reshape(len(spurious_ind), -1)
            ret.append(np.linalg.norm(oriX - spuriousX, ord=2, axis=1).mean())
        norm_results[(spurious_version, tar_cls)] = ret
# Displayed as the cell output (the old trailing `ret` sat inside the loop and was never shown).
pd.DataFrame.from_dict(norm_results, orient="index", columns=n_samples)

# Group influence

In [None]:
def get_influence_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr):
    """Return ``(ds_name, model_path)`` of a group-influence checkpoint.

    CIFAR10 (exact name "cifar10") checkpoints use batch size 64; every other base dataset uses 128.
    The file name embeds the ``-70-`` model the influence run started from.
    """
    ds_name = "{}{}-{}-{}-{}".format(base_dset, spurious_version, i, tar_cls, seed)
    batch = 64 if base_dset == "cifar10" else 128
    source_ckpt = f"{batch}-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt"
    model_path = f"../models/group_influence/{batch}-{ds_name}-ce-tor-{arch}-{source_ckpt}-{optimizer}-0.pt"
    return ds_name, model_path

In [None]:
# Group-influence experiment. For each configuration this records how often the target-class
# score in `mod_tst_preds` exceeds the one in `tst_preds` by more than `threshold`
# (presumably predictions with vs. without the spurious pattern — see `evaluate`). It does so
# for the clean-data classifier (baseline), the classifiers from `get_model_path`, and the
# group-influence models from `get_influence_model_path`, over N_SEEDS dataset seeds.
all_results = {}
ttest_results = {}

# Active configuration.
base_dset = "mnist"
#base_dset = "fashion"
threshold = 1e-1
n_samples = [3, 5, 10]
architectures = ['LargeMLP']

# Alternative configuration for CIFAR-10 runs.
#base_dset = "cifar10"
#threshold = 1e-9
#n_samples = [3, 5, 10, 20, 500]
#architectures = ['ResNet50']

N_SEEDS = 5


def _spurious_hits(mod_preds, clean_preds, tar_cls, threshold):
    """Boolean mask: True where column `tar_cls` of `mod_preds` exceeds that of `clean_preds` by more than `threshold`."""
    return (mod_preds[:, tar_cls] - clean_preds[:, tar_cls]) > threshold


for optimizer in ['sgd', 'adam']:
    for tar_cls in [0, 1]:
        for spurious_version in ['v8']:
            for arch in architectures:
                lr = 0.01
                momentum = 0. if optimizer == 'adam' else 0.9

                for grad_clip in [None]:
                    # Clean-data classifier checkpoint; the name encodes batch size and optional grad clip.
                    batch_size = 64 if base_dset == "cifar10" else 128
                    clip_tag = "" if grad_clip is None else f"{grad_clip}-"
                    model_path = (f"../models/train_classifier/{batch_size}-{base_dset}-70-{clip_tag}{lr}"
                                  f"-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt")

                    # Baseline, restricted to test examples whose label is not the target class.
                    mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)
                    ind = (tsty != tar_cls)
                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]
                    baseline_rv = _spurious_hits(mod_tst_preds, tst_preds, tar_cls, threshold)
                    baseline_sem = sem(baseline_rv)
                    baseline_nos = [baseline_rv.mean()]
                    baseline_acc = (tst_preds.argmax(1) == tsty).mean()
                    baseline_mod_tst_preds = mod_tst_preds

                    # Entry 0 of every per-metric list is the baseline (0 spurious examples).
                    ret, inc_rets = [baseline_nos[0]], [baseline_nos[0]]
                    sems, inc_sems = [baseline_sem], [baseline_sem]
                    accs, test_res = [baseline_acc], []
                    for i in n_samples:
                        taccs, tret, tinc_rets, ttest_res = [], [], [], []
                        for seed in range(N_SEEDS):
                            ds_name, inc_model_path = get_influence_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr)
                            _, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed=0)

                            # Evaluate both checkpoints before recording anything, so a missing
                            # file for either one skips the whole seed and the per-seed lists
                            # (tret, taccs, ttest_res, tinc_rets) stay the same length.
                            try:
                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)
                                inc_mod_tst_preds, inc_tst_preds, _ = evaluate(ds_name, inc_model_path, arch, spurious_version, seed)
                            except FileNotFoundError as err:
                                missing = err.filename or f"{model_path} or {inc_model_path}"
                                print(f"missing {missing}")
                                continue

                            # NOTE(review): unlike baseline_rv, rv covers the whole test set (the
                            # tsty != tar_cls filter is not applied here) — confirm this is intended.
                            rv = _spurious_hits(mod_tst_preds, tst_preds, tar_cls, threshold)
                            # One-sided Welch t-test p-value: is the hit rate greater than the baseline's?
                            ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative="greater")[1])
                            # Fraction of mod_tst_preds whose argmax is the target class.
                            taccs.append((mod_tst_preds.argmax(1) == tar_cls).mean())
                            tret.append(rv.mean())
                            tinc_rets.append(_spurious_hits(inc_mod_tst_preds, inc_tst_preds, tar_cls, threshold).mean())

                        if taccs:
                            ret.append(np.mean(tret))
                            inc_rets.append(np.mean(tinc_rets))
                            inc_sems.append(sem(tinc_rets))
                            sems.append(sem(tret))
                            accs.append(np.mean(taccs))
                            # Only the p-value of the first available seed is kept.
                            test_res.append(ttest_res[0])
                        else:
                            # -1 marks a sample count for which no checkpoint pair was found.
                            for metric in (ret, sems, inc_rets, inc_sems, accs, test_res):
                                metric.append(-1)

                    key = (optimizer, tar_cls, arch, spurious_version, grad_clip)
                    ttest_results[key] = test_res
                    all_results[key] = ret + accs + sems + inc_rets + inc_sems

# Visualize spurious patterns

In [None]:
base_dset = "mnist"


def _plot_example(img, save_path):
    """Draw `img` on a new 6x6 figure (colour range [0, 1], axes hidden), save it to `save_path`, and return the figure."""
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(img, vmin=0, vmax=1)
    ax.axis('off')
    fig.tight_layout()
    fig.savefig(save_path, bbox_inches='tight')
    return fig


# Clean reference: the image at the first spurious index of the "v1" variant, taken from the
# unmodified base dataset.
_, _, _, _, spurious_ind = auto_var.get_var_with_argument("dataset", f"{base_dset}v1-5-0-0")
trnX, _, _, _, _ = auto_var.get_var_with_argument("dataset", base_dset)
fig = _plot_example(trnX[spurious_ind[0], :, :, 0], f"./figs/spu_examples/{base_dset}.png")
plt.show()
plt.close(fig)

# First spurious example of each pattern version, plus the magnitude of its 2-D FFT.
# Figures are closed explicitly: this loop makes ~25 figures, and under non-inline
# backends pyplot would otherwise keep them all open (and warn past 20).
for spurious_version in ["v1", 'v3', "v8", "v9", 'v10', 'v11', "vgau1", "vgau2", "v18", "v19", "v20", "v30"]:
    print(spurious_version)

    ds_name = f"{base_dset}{spurious_version}-5-0-0"
    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument("dataset", ds_name)
    spu_img = trnX[spurious_ind[0], :, :, 0]

    fig = _plot_example(spu_img, f"./figs/spu_examples/{ds_name}.png")
    plt.show()
    plt.close(fig)

    # Display only; not saved.
    fig, ax = plt.subplots()
    ax.imshow(np.abs(fft2(spu_img)))
    ax.set_title(f"FFT magnitude: {ds_name}")
    plt.show()
    plt.close(fig)


In [None]:
# Preview the "v20" spurious pattern applied to the first MNIST training image.
trnX, _, _, _, _ = auto_var.get_var_with_argument("dataset", "mnist")
# Work on a private copy — presumably so add_spurious_correlation cannot alter the array
# returned by auto_var (TODO confirm whether it modifies its input in place).
trnX = np.copy(trnX)
first_example = trnX[0:1]  # slice, not index: keeps the leading example axis
ttt = add_spurious_correlation(first_example, "v20", 0)[0, :, :]
plt.imshow(ttt)

In [None]:
# NOTE(review): `trnX2` is not defined anywhere in this part of the notebook, so this cell
# raises NameError on a fresh kernel; it presumably came from a deleted cell — TODO restore it.
# Zero pixel (0, 0) of the first spurious example in trnX2, then show trnX - trnX2 for that example.
trnX2[spurious_ind[0], 0, 0, 0] = 0
plt.imshow(trnX[spurious_ind[0], :, :, 0] - trnX2[spurious_ind[0], :, :, 0])
plt.colorbar()

In [None]:
# Value at pixel (0, 0) of trnX - trnX2 for the first spurious example.
# NOTE(review): depends on `trnX2`, which is not defined in this part of the notebook.
(trnX[spurious_ind[0], :, :, 0] - trnX2[spurious_ind[0], :, :, 0])[0, 0]

In [None]:
# select a set of background examples to take an expectation over
background = trnX[np.random.choice(trnX.shape[0], 200, replace=False)]
background = background.transpose(0, 3, 1, 2)

test_examples = torch.from_numpy(trnX[spurious_ind[:3]].transpose(0, 3, 1, 2)).float()

e = shap.GradientExplainer(model.to("cpu"), torch.from_numpy(background))
shap_values = e.shap_values(test_examples)
shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
test_numpy = test_examples.numpy().transpose(0, 2, 3, 1)
# plot the feature attributions
shap.image_plot(shap_numpy, -test_numpy)

In [None]:
# Shape of the first SHAP test image (channels-last, after the transpose above).
test_numpy[0].shape

In [None]:
# LIME explanation of training image 2, keeping only the top predicted label.
# NOTE(review): `pred_fn`, `lime_image` and `model` are not defined in this part of the notebook
# (the prediction helpers defined near the top are `mlp_pred_fn` / `cnn_pred_fn`);
# TODO pick the right helper for the model and import lime before running this cell.
batch_predict = partial(pred_fn, model=model)
explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(trnX[2].astype("double"), 
                                         batch_predict, # classification function
                                         top_labels=1, 
                                         hide_color=0, 
                                         num_samples=1000)

In [None]:
# Load the CIFAR-10 "v1" spurious variant; the name follows the
# <base><version>-<n_spurious>-<tar_cls>-<seed> pattern used throughout this notebook.
trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument("dataset", "cifar10v1-3-0-0")

In [None]:
# Show the first training image.
plt.imshow(trnX[0])

In [None]:
# Label of the first training image.
trny[0]

In [None]:
# NOTE(review): duplicate of an earlier `plt.imshow(trnX[0])` cell — TODO remove one of them.
plt.imshow(trnX[0])

In [None]:
# Fit an isolation-forest outlier detector on the training features (fixed random_state, 12 jobs).
# NOTE(review): `IsolationForest` (presumably sklearn.ensemble) and `trn_features` are not defined
# in this part of the notebook, so this cell fails on a fresh kernel — TODO import / compute them.
clf = IsolationForest(random_state=0, n_jobs=12).fit(trn_features)

In [None]:
# Per-example anomaly scores from the forest fitted above (if this is sklearn's IsolationForest,
# lower scores mean more abnormal — confirm).
pred = clf.score_samples(trn_features)

In [None]:
# Raw arrays of the spurious training examples (displays the full array; `.shape` may be enough).
trnX[spurious_ind]

# MLP weights

In [None]:
# For MLPs trained with 0, 10, 100 and 2000 spurious "v8" examples, plot the maximum (over
# hidden units) of each input pixel's weight in `model.hidden`, laid out as the input image.
ds_names = ["mnist", "mnistv8-10-0-0", "mnistv8-100-0-0", "mnistv8-2000-0-0"]
titles = ["(a) 0", "(b) 10", "(c) 100", "(d) 2000"]
arch = "LargeMLP"

fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(12,4))

for ds_name, ax, title in zip(ds_names, axes, titles):
    model_path = f"../models/train_classifier/128-{ds_name}-70-0.01-ce-tor-{arch}-0.0-adam-0-0.0.pt"

    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument("dataset", ds_name)
    n_classes = len(np.unique(trny))
    n_channels = trnX.shape[-1]
    # The weights are plotted from CPU; map_location lets GPU-saved checkpoints load on CPU-only machines.
    res = torch.load(model_path, map_location="cpu")
    model = getattr(archs, arch)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)
    model.load_state_dict(res['model_state_dict'])

    # Reshape to the dataset's image height/width instead of a hard-coded 28x28.
    max_weight = model.hidden.weight.max(0)[0].detach().cpu().numpy().reshape(trnX.shape[1], trnX.shape[2])

    ax.set_title(title, fontsize=fontsize, y=-0.24)
    # Shared fixed colour range so the four panels are comparable with one colour bar.
    im = ax.imshow(max_weight, vmin=-0.5, vmax=6.5, cmap='YlOrRd')
    ax.axis('off')
fig.subplots_adjust(bottom=0.3, right=0.8)
cbar_ax = fig.add_axes([0.82, 0.3, 0.04, 0.5])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.ax.tick_params(labelsize=fontsize)
fig.savefig("./figs/mlp_weights/mnist_mlp_weights.png", bbox_inches='tight')
plt.show()

In [None]:
# NOTE(review): `size` is not defined in this part of the notebook — leftover scratch cell; TODO remove.
size

In [None]:
# Largest per-pixel max weight of the last model loaded in the MLP-weights cell
# (presumably used to choose the colour-bar range there).
model.hidden.weight.max(0)[0].max()

# Density

## Gradient clipping

In [None]:
# Spurious score vs. number of spurious examples, with and without gradient clipping.
print(base_dset)
sample_cols = ["0"] + n_samples
# Each all_results row is a concatenation of len(sample_cols)-long metric groups in this order.
# Rows may hold only the first three groups or all five (the group-influence cell stores five).
metric_names = ["spu", "acc", "sem", "inc_spu", "inc_sem"]
df = pd.DataFrame.from_dict(all_results).transpose()
n_groups = df.shape[1] // len(sample_cols)
df.columns = pd.MultiIndex.from_tuples([(m, c) for m in metric_names[:n_groups] for c in sample_cols])
df = df[[("spu", i) for i in sample_cols] + [("sem", i) for i in sample_cols]]

# None = no clipping; pandas stores a None key level as NaN, so look it up with NaN.
grad_clips = [None, 0.1]

for optimizer in ['adam']:
    for tar_cls in [0, 1]:
        for arch in architectures:
            # NOTE(review): `spurious_versions` is not defined in this part of the notebook
            # (presumably from `params`) — confirm.
            for spu in spurious_versions:
                for grad_clip in grad_clips:
                    clip_key = float("NaN") if grad_clip is None else grad_clip
                    # NOTE(review): lookup order is (optimizer, tar_cls, spu, arch, clip), while the
                    # group-influence cell stores (optimizer, tar_cls, arch, spurious_version, grad_clip)
                    # — confirm which cell produced all_results.
                    row = df.loc[(optimizer, tar_cls, spu, arch, clip_key)]
                    plt.errorbar(
                        x=np.arange(len(sample_cols)),
                        y=row[[("spu", i) for i in sample_cols]].tolist(),
                        yerr=row[[("sem", i) for i in sample_cols]].tolist(),
                        label=f"{spu} {grad_clip}",
                    )
            plt.yticks(fontsize=fontsize)
            plt.xticks(ticks=np.arange(len(sample_cols)), labels=[0] + n_samples, fontsize=fontsize-2)
            plt.ylabel("spurious score", fontsize=fontsize)
            plt.xlabel("# spurious examples", fontsize=fontsize)
            plt.legend(fontsize=fontsize-2)
            plt.tight_layout()
            # NOTE(review): file name omits `arch`; with several architectures later ones overwrite earlier plots.
            plt.savefig(f"figs/score_plot/difclip_{base_dset}_{optimizer}_{tar_cls}.png")
            plt.show()
            plt.close()

In [None]:
import faiss

In [None]:
def get_dists(trnX, spu_ind=None, k=21):
    """Euclidean distances from each spurious example to its `k` nearest non-spurious examples.

    Parameters
    ----------
    trnX : array of shape (n_examples, n_features)
        Flattened training set.
    spu_ind : array-like of int, optional
        Indices of the spurious examples in `trnX`. Defaults to the notebook-global
        ``spurious_ind`` so existing calls ``get_dists(trnX)`` keep working.
    k : int, default 21
        Number of neighbours retrieved per spurious example.

    Returns
    -------
    np.ndarray
        Flat array of ``len(spu_ind) * k`` distances.
    """
    if spu_ind is None:
        spu_ind = spurious_ind
    index = faiss.IndexFlatL2(trnX.shape[1])
    # Index only the non-spurious examples so a spurious example never matches itself.
    nonspu_ind = np.delete(np.arange(len(trnX)), spu_ind)
    index.add(trnX[nonspu_ind])
    sq_dists, _ = index.search(trnX[spu_ind], k)
    # IndexFlatL2 reports squared L2 distances.
    return np.sqrt(sq_dists).reshape(-1)

In [None]:
# Mean nearest-neighbour distance from the spurious examples to the clean ones, for each
# (base dataset, target class, spurious-pattern version), with n_spu spurious examples.
n_spu = 10
seed = 0

data = {}
for base_dset in ['mnist', 'fashion']:
    for tar_cls in [0, 1]:
        for spurious_version in ['v3', 'v9', "vgau1", 'v10', 'v11']:
            keys = (base_dset, tar_cls, spurious_version)
            ds_name = f"{base_dset}{spurious_version}-{n_spu}-{tar_cls}-{seed}"
            trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument("dataset", ds_name)
            # get_dists(trnX) uses the module-level `spurious_ind` rebound on the line above.
            avg_dist = get_dists(trnX.reshape(len(trnX), -1)).mean()
            trnX = trnX.reshape(len(trnX), -1)
            data[keys] = avg_dist

In [None]:
# Mean neighbour distance keyed by (base_dset, tar_cls, spurious_version).
data

In [None]:
# `data` maps (base_dset, tar_cls, spurious_version) -> a scalar, and DataFrame.from_dict on a
# dict of scalars raises "If using all scalar values, you must pass an index". Build a Series
# with a 3-level index instead, and pivot the spurious-pattern versions into columns.
pd.Series(data).rename_axis(["base_dset", "tar_cls", "spurious_version"]).unstack("spurious_version")

In [None]:
# NOTE(review): `D` is a local variable inside get_dists, so this fails on a fresh kernel unless
# an older cell leaked a module-level `D`; get_dists(trnX).mean() computes the same quantity.
D.reshape(-1).mean()  # mnist

In [None]:
# Fit a Gaussian FFT-based KDE to the training array.
# NOTE(review): `FFTKDE` (presumably from KDEpy) is not imported in this part of the notebook.
kde = FFTKDE(kernel="gaussian").fit(trnX)

In [None]:
# Fit a Gaussian kernel density estimate (bandwidth 1.0) to the training array, replacing `kde`.
# NOTE(review): `KernelDensity` (presumably sklearn.neighbors) is not imported in this part of the notebook.
kde = KernelDensity(kernel='gaussian', bandwidth=1.0).fit(trnX)

In [None]:
# Density scores of the first 50 training examples under `kde` (for sklearn's KernelDensity
# these are log-densities — confirm which estimator `kde` currently holds).
scores = kde.score_samples(trnX[:50])