In [None]:
%reload_ext autoreload
%autoreload 2

import os
import numpy as np
import torch
from types import SimpleNamespace
from torch.utils.data import ConcatDataset, Subset
from sklearn.linear_model import LogisticRegression
import seaborn as sns

import matplotlib.lines as mlines
from matplotlib import pyplot as plt
%matplotlib inline

# Global figure typography: text is rendered through LaTeX, the requested
# font family is Latin Modern Roman and the base font size is 20.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "Latin Modern Roman"
plt.rcParams["font.size"] = 20

from dataloaders import make_datasets
from grood import GROOD
from eval.utils import get_all_features_cached, cross_val_linprobe_C

In [None]:
cd to the code directory if the imports does not work
%cd "<code_directory>"

# 0 - a) mnist-6 C-0 vs mnist-4
# 1 - b) cifar10 C-0 vs tinyimagenet
# 2 - c) cifar10 C-1 vs SVHN
exp_id = 0

# set paths to the experiments, the presented combinations were used to generate the figure in the paper
figure_1 = [
{"exp_dir":"./_out/experiments/5r_6v4/<my_method>/mnist/<one of dirs>/", "test":"mnist", "c":0},  
{"exp_dir":"./_out/experiments/<my_method>/", "test":"places365", "c":2},   
{"exp_dir":"./_out/experiments/<my_method/", "test":"svhn", "c":1}]

args = SimpleNamespace(
            exp_dir=figure_1[exp_id]["exp_dir"],
            latest_checkpoint=True,
       )
extra_cfg = []
plt.close("all")

In [None]:
# Instantiate GROOD from the selected experiment directory and reuse its config.
use_latest_checkpoint = bool(args.latest_checkpoint)
grood = GROOD(exp_dir=args.exp_dir, eval_last_checkpoint=use_latest_checkpoint)

cfg = grood.cfg
# Apply the optional overrides collected in `extra_cfg`.
cfg.merge_from_list(extra_cfg)
# Results go to a "results" sub-directory of the experiment directory.
result_dir = os.path.join(args.exp_dir, "results")
cfg.EXPERIMENT.RESULT_DIR = result_dir

In [None]:
# Point the config at the test dataset chosen for this experiment.
cfg.DATASET.TEST = figure_1[exp_id]["test"]

train_set, val_set, test_set, inverse_class_map = make_datasets(cfg)

# Train and validation sets are served together by one unshuffled loader.
train_loader = torch.utils.data.DataLoader(
    ConcatDataset([train_set, val_set]),
    batch_size=cfg.INPUT.BATCH_SIZE,
    shuffle=False,
)

# The test set exposes its labels either as `targets` or as `labels`.
try:
    targets = test_set.targets
except AttributeError:
    targets = test_set.labels

# A test sample is in-distribution (ID) when its label is one of the selected
# labels; every other sample is out-of-distribution (OOD).
is_selected = [target in cfg.DATASET.SELECTED_LABELS for target in targets]
id_indices = [position for position, flag in enumerate(is_selected) if flag]
ood_indices = [position for position, flag in enumerate(is_selected) if not flag]

# Sanity checks
id_mask = np.zeros(len(targets), dtype=int)
ood_mask = np.zeros(len(targets), dtype=int)
id_mask[id_indices] = 1
ood_mask[ood_indices] = 1
n_overlapping = np.sum(id_mask*ood_mask)
n_covered = np.sum((id_mask+ood_mask) > 0)
assert n_overlapping == 0, "Data selection error, there is an overlap between ID and OOD classes!"
assert n_covered == len(targets), "Data selection error, not all data were selected!"

# Separate unshuffled loaders for the OOD and the ID part of the test set.
test_loader_ood = torch.utils.data.DataLoader(
    Subset(test_set, ood_indices),
    batch_size=cfg.INPUT.BATCH_SIZE,
    shuffle=False,
)
test_loader_id = torch.utils.data.DataLoader(
    Subset(test_set, id_indices),
    batch_size=cfg.INPUT.BATCH_SIZE,
    shuffle=False,
)

In [None]:
(train_features_all, train_labels_all,
 test_features_id, test_labels_id,
 test_features_ood, test_labels_ood) = get_all_features_cached(
    cfg, grood, train_loader, test_loader_id, test_loader_ood)

# For every class, shuffle its sample indices and hold out the first
# VAL_FRACTION share for validation; the remainder is used for fitting.
split_per_class = []
for class_id in range(cfg.MODEL.NUM_CLASSES):
    members = np.nonzero(train_labels_all == class_id)[0]
    n_held_out = int(cfg.DATASET.VAL_FRACTION * members.shape[0])
    # A new generator with the same seed is created for each class, which
    # keeps the split reproducible across runs.
    shuffled = np.random.RandomState(seed=cfg.SYSTEM.RNG_SEED).permutation(members)
    split_per_class.append((shuffled[n_held_out:], shuffled[:n_held_out]))

# Stack the per-class pieces in class order.
train_features = np.concatenate(
    [train_features_all[fit_idx, ...] for fit_idx, _ in split_per_class], axis=0)
train_labels = np.concatenate(
    [train_labels_all[fit_idx, ...] for fit_idx, _ in split_per_class], axis=0)
val_features = np.concatenate(
    [train_features_all[held_idx, ...] for _, held_idx in split_per_class], axis=0)
val_labels = np.concatenate(
    [train_labels_all[held_idx, ...] for _, held_idx in split_per_class], axis=0)

In [None]:
# `cross_val_linprobe_C` receives the fitting and held-out splits; presumably
# it selects the regularisation strength C on the held-out part - verify
# against eval.utils.
best_C = cross_val_linprobe_C(train_features, train_labels, val_features, val_labels)
print (f"Computing linear probe for C {best_C:0.3f}.")

# Linear probe: logistic regression fitted on the extracted training features.
probe_kwargs = dict(random_state=0, C=best_C, max_iter=1000, verbose=1, n_jobs=8)
classifier = LogisticRegression(**probe_kwargs)
classifier.fit(train_features, train_labels)

In [None]:
def l2_dist(x, y):
    """Return 1 / (1 + ||x - y||_2), with the norm taken over the last axis.

    Despite the name, the result is a similarity rather than a distance: it is
    1 for identical inputs and decreases towards 0 as the Euclidean distance
    between `x` and `y` grows. `x` and `y` are broadcast against each other.
    """
    difference = x - y
    euclidean = np.sqrt(np.sum(np.power(difference, 2), axis=-1))
    return 1./(1 + euclidean)

# Linear-probe decision values for the ID test, OOD test and held-out samples.
logits_id, logits_ood, logits_val = (
    classifier.decision_function(features)
    for features in (test_features_id, test_features_ood, val_features)
)

# Similarity of every sample to each class centre, stored with the same shape
# as the corresponding logits; column c belongs to class c.
mean_id_l2 = np.zeros_like(logits_id)
mean_ood_l2 = np.zeros_like(logits_ood)
mean_val_l2 = np.zeros_like(logits_val)

similarity_targets = (
    (mean_id_l2, test_features_id),
    (mean_ood_l2, test_features_ood),
    (mean_val_l2, val_features),
)
for c in range(cfg.MODEL.NUM_CLASSES):
    # Class centre: mean training feature of class c, with a leading axis so
    # it broadcasts against a batch of features.
    class_center = np.mean(train_features[train_labels == c, :], axis=0)[None, ...]
    for similarities, features in similarity_targets:
        similarities[:, c] = l2_dist(features, class_center)

In [None]:
# Class whose logit / similarity column is visualised.
c = figure_1[exp_id]["c"]

xs_s, ys_s, hues_s = [], [], []

# Legend labels (LaTeX bold). When the test data come from the training
# dataset, the labels carry the hard-coded "(6c)" / "(4c)" suffixes.
if figure_1[exp_id]["test"] != cfg.DATASET.TRAIN:
    class_name = f"\\textbf{{ID: {cfg.DATASET.TRAIN} C-{c}}}"
    ood_name = f"\\textbf{{OOD: {cfg.DATASET.TEST}}}"
else:
    class_name = f"\\textbf{{ID: {cfg.DATASET.TRAIN}(6c) C-{c}}}"
    ood_name = f"\\textbf{{OOD: {cfg.DATASET.TEST}(4c)}}"

# ID points: held-out samples whose label is class c.
mask = val_labels == c
# OOD points: every fifth OOD test sample.
ood_logits = logits_ood[::5, c]
ood_similarities = mean_ood_l2[::5, c]

# x = linear-probe logit, y = similarity to the class centre, hue = group label.
xs = np.concatenate([logits_val[mask, c], ood_logits], axis=0)
ys = np.concatenate([mean_val_l2[mask, c], ood_similarities], axis=0)
hues = np.concatenate(
    [[class_name]*np.sum(mask), [ood_name]*ood_similarities.shape[0]], axis=0)

color_p = sns.color_palette()

# Scatter plot of logit (x) against class-centre similarity (y) with one
# marginal density per group; the first two palette colours are used fully
# opaque for the ID and the OOD group.
a = sns.jointplot(
    x=xs, y=ys, hue=hues, kind="scatter",
    marginal_kws={"common_norm": False, "fill": True},
    legend=False,
    palette={
        class_name: list(color_p[0]) + [1],
        ood_name: list(color_p[1]) + [1],
    },
)

# Filled 2-D density estimates drawn on the joint axes: blue for the ID
# samples selected by `mask`, orange for all OOD samples.
# `fill=True` replaces the deprecated `shade=True` and matches the `fill`
# option already passed to the marginal plots above.
kde_ax = sns.kdeplot(x=logits_val[mask, c], y=mean_val_l2[mask, c], cmap="Blues",
                     ax=a.ax_joint, legend=False, alpha=0.75, thresh=0.05, levels=20, fill=True)
kde_ax = sns.kdeplot(x=logits_ood[:, c], y=mean_ood_l2[:, c], cmap="Oranges",
                     ax=a.ax_joint, legend=False, alpha=0.75, thresh=0.07, levels=20, fill=True)


a.ax_joint.set_xlabel("LP logits")
a.ax_joint.set_ylabel("NM L2 similarity")

# -------------- LEGEND -----------------
# The plots were drawn with legend=False, so the legend is built from empty
# proxy lines: one half-transparent round marker per group.
l_handles = [
    mlines.Line2D([], [], color=list(rgb) + [0.5], marker='o', linestyle="None",
                  markersize=10, label=text)
    for rgb, text in ((color_p[0], class_name), (color_p[1], ood_name))
]
id_points, ood_points = l_handles
a.ax_joint.legend(handles=l_handles, loc="upper left", prop=dict(weight='bold'))

# Axis ranges run from the smallest OOD value to the largest ID value.
y_low, y_high = np.min(mean_ood_l2[:, c]), np.max(mean_val_l2[mask, c])
x_low, x_high = np.min(logits_ood[:, c]), np.max(logits_val[mask, c])
a.ax_joint.set_ylim(y_low, y_high)
a.ax_joint.set_xlim(x_low, x_high)

# -------------- SAVING -----------------

fig = plt.gcf()
# Resize first: tight_layout fits the layout to the current figure size, so
# it has to run after the figure has its final dimensions.
fig.set_size_inches(10, 10)
fig.tight_layout()
# Make sure the output directory exists before writing the file.
os.makedirs("./_out", exist_ok=True)
fig.savefig(f"./_out/figure1_exp_id_{exp_id}.pdf", dpi=300)
plt.show()