In [None]:
%reload_ext autoreload
%autoreload 2

import os
import numpy as np
import torch
from types import SimpleNamespace
from einops import rearrange
from torch.utils.data import ConcatDataset, Subset
from sklearn.linear_model import LogisticRegression
import seaborn as sns
from scipy.stats import multivariate_normal, norm

import matplotlib.lines as mlines
from matplotlib import pyplot as plt
%matplotlib inline

# Global matplotlib style shared by every figure in this notebook:
# text is rendered through LaTeX (usetex) in Latin Modern Roman at size 20.
plt.rcParams["text.usetex"] = True
plt.rcParams["font.family"] = "Latin Modern Roman"
plt.rcParams["font.size"] = 20

from dataloaders import make_datasets
from grood import GROOD
from eval.utils import get_all_features_cached, cross_val_linprobe_C

In [None]:
# cd to the code directory if the imports do not work
%cd "<code_directory>"

# Experiment selection.
# NOTE(review): "<my_method>" looks like a placeholder - point exp_dir at a real
# experiment folder before running the cells below.
args = SimpleNamespace(
    exp_dir="./_out/experiments/<my_method>/",
    latest_checkpoint=True,
)

# Extra config overrides; handed to cfg.merge_from_list further down (empty = none).
extra_cfg = list()

# Close any figures still open from a previous run of the notebook.
plt.close("all")

In [None]:
# Instantiate the GROOD wrapper for the selected experiment and reuse its config.
grood = GROOD(
    exp_dir=args.exp_dir,
    eval_last_checkpoint=bool(args.latest_checkpoint),
)
cfg = grood.cfg

# Overrides are merged first, so the explicit assignments below take precedence.
cfg.merge_from_list(extra_cfg)
cfg.EXPERIMENT.RESULT_DIR = os.path.join(args.exp_dir, "results")

# Presumably selects the dataset served as the test set by make_datasets(cfg) - verify.
cfg.DATASET.TEST = "tinyimagenet"

In [None]:
# Project helper; only the arity of its return value is fixed by this unpacking.
train_set, val_set, test_set, inverse_class_map = make_datasets(cfg)

# Train and validation splits are served together, in a fixed (unshuffled) order.
train_loader = torch.utils.data.DataLoader(
    ConcatDataset([train_set, val_set]),
    batch_size=cfg.INPUT.BATCH_SIZE,
    shuffle=False,
)

# The test set stores its labels under either `targets` or `labels`.
if hasattr(test_set, "targets"):
    targets = test_set.targets
else:
    targets = test_set.labels

# Partition test sample indices: a label listed in SELECTED_LABELS counts as
# in-distribution (ID), every other label as out-of-distribution (OOD).
id_indices, ood_indices = [], []
for idx, target in enumerate(targets):
    if target in cfg.DATASET.SELECTED_LABELS:
        id_indices.append(idx)
    else:
        ood_indices.append(idx)

test_loader_ood = torch.utils.data.DataLoader(
    Subset(test_set, ood_indices),
    batch_size=cfg.INPUT.BATCH_SIZE,
    shuffle=False,
)
test_loader_id = torch.utils.data.DataLoader(
    Subset(test_set, id_indices),
    batch_size=cfg.INPUT.BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Feature extraction through the project helper (cached, judging by its name).
(
    train_features_all, train_labels_all,
    test_features_id, test_labels_id,
    test_features_ood, test_labels_ood,
) = get_all_features_cached(cfg, grood, train_loader, test_loader_id, test_loader_ood)

# Per-class hold-out split: the first int(VAL_FRACTION * n_c) samples of a seeded
# permutation of each class go to validation, the remainder to training.
train_features, train_labels = [], []
val_features, val_labels = [], []
for c in range(cfg.MODEL.NUM_CLASSES):
    class_idx = np.nonzero(train_labels_all == c)[0]
    split_size = int(cfg.DATASET.VAL_FRACTION * class_idx.shape[0])
    # A fresh generator with the same seed is created for every class, which
    # keeps the split reproducible across runs.
    indices = np.random.RandomState(seed=cfg.SYSTEM.RNG_SEED).permutation(class_idx)
    val_part, train_part = indices[:split_size], indices[split_size:]

    train_features.append(train_features_all[train_part, ...])
    train_labels.append(train_labels_all[train_part, ...])
    val_features.append(train_features_all[val_part, ...])
    val_labels.append(train_labels_all[val_part, ...])

# Stack the per-class pieces back into flat arrays.
train_features, train_labels, val_features, val_labels = (
    np.concatenate(parts, axis=0)
    for parts in (train_features, train_labels, val_features, val_labels)
)

In [None]:
# Pick the regularisation strength C using the hold-out split (project helper).
best_C = cross_val_linprobe_C(train_features, train_labels, val_features, val_labels)

# From here on both the "train" and the "val" names alias the FULL feature set:
# the final probe is fitted on everything, and later cells read val_* as that
# same full set.
# NOTE(review): this rebinding makes the cell non-idempotent - re-running it
# passes the full set as both train and val to cross_val_linprobe_C. Re-run the
# split cell first if best_C has to be recomputed.
train_features, train_labels = train_features_all, train_labels_all
val_features, val_labels = train_features_all, train_labels_all

print(f"Computing linear probe for C {best_C:0.3f}.")
classifier = LogisticRegression(
    random_state=0,
    C=best_C,
    max_iter=1000,
    verbose=1,
    n_jobs=8,
)
classifier.fit(train_features, train_labels)

In [None]:
# Linear-probe scores (decision_function output) for the ID test, OOD test and
# "validation" features, in that order.
logits_id, logits_ood, logits_val = (
    classifier.decision_function(features)
    for features in (test_features_id, test_features_ood, val_features)
)

In [None]:
# Cumulative KDE of every class's own logit over the samples of that class,
# all drawn on one axis.
fig, ax = plt.subplots()
for c in range(cfg.MODEL.NUM_CLASSES):
    sns.kdeplot(
        logits_val[val_labels == c, c],
        label=f"\\textbf{{{cfg.DATASET.TRAIN} C-{c}}}",
        lw=4,
        cumulative=True,
        ax=ax,
    )

# Vertical reference line at logit = 10.
ax.plot([10, 10], [0, 1.0], ls="--", color='k', lw=3)
ax.set_xlabel("logits")
ax.set_ylabel("cumulative density")

# The legend sits below the axes, so it is passed as an extra artist to keep it
# inside the tight bounding box of the saved file.
lgd = ax.legend(
    loc='lower center',
    bbox_to_anchor=(0.5, -0.40),
    ncol=5,
    fancybox=True,
    shadow=True,
    fontsize=14,
)
fig.set_size_inches(16, 7)
fig.savefig("./_out/miscalibration.pdf", dpi=300, bbox_extra_artists=(lgd,), bbox_inches='tight')
plt.show()

In [None]:
def l2_dist(x, y):
    """Return the similarity ``1 / (1 + ||x - y||_2)`` along the last axis.

    Despite the name, the result is a similarity rather than a distance: it is
    1 when ``x`` equals ``y`` and shrinks towards 0 as the Euclidean distance
    grows. ``x`` and ``y`` are combined with ``x - y``, so they must be
    broadcast-compatible; the last axis is reduced.
    """
    diff = x - y
    euclidean = np.sqrt(np.square(diff).sum(axis=-1))
    return 1. / (1 + euclidean)

# One similarity column per class, with the same shape and dtype as the
# corresponding logit arrays.
mean_id_l2, mean_ood_l2, mean_val_l2 = (
    np.zeros_like(logits) for logits in (logits_id, logits_ood, logits_val)
)

for c in range(cfg.MODEL.NUM_CLASSES):
    mask = train_labels == c
    # Class prototype: the mean feature vector over the samples labelled c.
    class_center = train_features[mask].mean(axis=0)
    center_row = class_center[np.newaxis, ...]
    # Fill column c with the similarity of every sample to that prototype.
    for features, similarities in (
        (test_features_id, mean_id_l2),
        (test_features_ood, mean_ood_l2),
        (val_features, mean_val_l2),
    ):
        similarities[:, c] = l2_dist(features, center_row)

In [None]:
# Per-class 2-D Gaussian models in the (linear-probe logit, nearest-mean
# similarity) plane. Three families are collected, each indexed by class:
#   class_*      - fitted to the samples of that class in val_*,
#   ood_true_*   - fitted to the OOD test samples,
#   ood_guest_*  - a data-free OOD model: zero mean, sampled covariance s_cov.
class_means = []
class_covs = []
ood_guest_means = []
ood_guest_covs = []
ood_true_means = []
ood_true_covs = []

# The sample grid that used to be built here (r = 2000) was never read: the
# plotting cell rebuilds r / grid_data / grid_data_rearrange before using them,
# so the several-million-point allocation was dropped.

# Covariance of the data-free OOD model, estimated from 10000 zero-centred
# draws per axis; the per-axis scale is a fraction of the 0.9 quantile of the
# ID scores. The draws are seeded so that s_cov - and with it the figure and
# the printed error rates - is identical on every run. One RandomState is
# shared by both calls on purpose: passing the same integer seed to each call
# would yield perfectly correlated columns and a degenerate covariance.
sim_rng = np.random.RandomState(seed=cfg.SYSTEM.RNG_SEED)
simulated_data = np.stack(
    [
        norm.rvs(loc=0, scale=np.quantile(logits_val, 0.9)/4.0, size=10000, random_state=sim_rng),
        norm.rvs(loc=0, scale=np.quantile(mean_val_l2, 0.9)/8.0, size=10000, random_state=sim_rng),
    ],
    axis=1,
)
s_cov = np.cov(simulated_data.T)

for c in range(0, cfg.MODEL.NUM_CLASSES):
    mask = val_labels == c
    # [B, 2] - (logit, similarity) pairs of the samples labelled c
    val_space = np.stack([logits_val[mask, c], mean_val_l2[mask, c]], axis=1)
    # [2]
    class_means.append(np.mean(val_space, axis=0))
    # [2, 2]
    class_covs.append(np.cov(val_space.T))

    # [B, 2] - the same two scores of class c, evaluated on all OOD samples
    ood_space = np.stack([logits_ood[:, c], mean_ood_l2[:, c]], axis=1)
    # [2]
    ood_true_means.append(np.mean(ood_space, axis=0))
    # [2, 2]
    ood_true_covs.append(np.cov(ood_space.T))

    # Data-free OOD model: centred at the origin, one copy of s_cov per class.
    ood_guest_means.append(np.array([0, 0]))
    ood_guest_covs.append(s_cov.copy())

In [None]:
# Class whose (logit, similarity) decision space is visualised.
# Must be a valid class index, i.e. smaller than cfg.MODEL.NUM_CLASSES.
c = 9

# Real samples (scatter) ...
xs = []
ys = []
hues = []

# ... and samples drawn from the fitted / assumed Gaussians (density contours).
xs_s = []
ys_s = []
hues_s = []

class_name = f"\\textbf{{{cfg.DATASET.TRAIN} C-{c}}}"
ood_name = f"\\textbf{{OOD {cfg.DATASET.TEST}}}"
# Raw strings: "\m" is not a valid Python escape sequence, so the non-raw
# spelling triggers a SyntaxWarning on recent Python versions. The resulting
# string value is unchanged.
simulated_id_name = r"$\mathbf{p(x|\mathcal{I})}$"
simulated_ood_name = r"$\mathbf{p(x|\mathcal{O})}$"

# Samples labelled c, thinned to every 3rd point.
mask = val_labels == c
xs.append(logits_val[mask, c][::3])
ys.append(mean_val_l2[mask, c][::3])
hues.append([class_name]*(mean_val_l2[mask, c][::3]).shape[0])

# OOD samples scored for class c, thinned to every 5th point.
xs.append(logits_ood[::5, c])
ys.append(mean_ood_l2[::5, c])
hues.append([ood_name]*mean_ood_l2[::5,c].shape[0])

# Seeded draws from the fitted class Gaussian and from the data-free OOD Gaussian.
id_simulated = multivariate_normal.rvs(mean=class_means[c], cov=class_covs[c], size=10000, random_state=42)
xs_s.append(id_simulated[:, 0])
ys_s.append(id_simulated[:, 1])
hues_s.append([simulated_id_name]*id_simulated.shape[0])

ood_simulated = multivariate_normal.rvs(mean=ood_guest_means[c], cov=ood_guest_covs[c], size=10000, random_state=42)
xs_s.append(ood_simulated[:, 0])
ys_s.append(ood_simulated[:, 1])
hues_s.append([simulated_ood_name]*ood_simulated.shape[0])

xs = np.concatenate(xs, axis=0)
ys = np.concatenate(ys, axis=0)
hues = np.concatenate(hues, axis=0)

xs_s = np.concatenate(xs_s, axis=0)
ys_s = np.concatenate(ys_s, axis=0)
hues_s = np.concatenate(hues_s, axis=0)

color_p = sns.color_palette()

# Scatter of the real samples with per-group marginals; the fourth palette
# component is the alpha value.
a = sns.jointplot(x=xs, y=ys, hue=hues, kind="scatter", marginal_kws={"common_norm": False, "fill": True}, legend=False,
                 palette={
                    class_name: list(color_p[2]) + [0.5],
                    ood_name: list(color_p[3]) + [0.5],
                 })

# Density contours of the simulated samples, drawn on top of the joint axis.
kde_ax = sns.kdeplot(x=xs_s, y=ys_s, hue=hues_s, ax=a.ax_joint, bw_adjust=2.5, legend=False, alpha=0.5,
                     thresh=0.05, levels=8, common_norm=False,
           palette={
                    simulated_ood_name: list(color_p[7]) + [1],
                    simulated_id_name: list(color_p[2]) + [1]
                 })

# -------------- ISOLINES -----------------
# Evaluate the three Gaussian densities on an r x r grid of the plotted region.
r = 1000
grid_data = np.meshgrid(np.linspace(-10, 19, num=r), np.linspace(0.035, np.max(mean_val_l2[mask, c]) - 0.002 , num=r))
grid_data_rearrange = rearrange(np.stack(grid_data, axis=-1), "r c d -> (r c) d")
p_c = rearrange(multivariate_normal.pdf(grid_data_rearrange, mean=class_means[c], cov=class_covs[c], allow_singular=True), "(r c) -> r c", r=r)
p_ood = rearrange(multivariate_normal.pdf(grid_data_rearrange, mean=ood_guest_means[c], cov=ood_guest_covs[c], allow_singular=True), "(r c) -> r c", r=r)
# NOTE(review): p_ood_true is computed but not drawn anywhere in this cell.
p_ood_true = rearrange(multivariate_normal.pdf(grid_data_rearrange, mean=ood_true_means[c], cov=ood_true_covs[c], allow_singular=True), "(r c) -> r c", r=r)

# Isolines of the likelihood ratio p_c / p_ood at the selected thresholds.
thrs = [1e1, 1e9, 1e22]
CS1 = a.ax_joint.contour(grid_data[0], grid_data[1], p_c/p_ood, levels=thrs,
                   colors='k', linestyles="dashed", linewidths=3, alpha=0.5)

def compute_sampled_error(px1, rx, thr):
    """Return the share of the total of ``px1`` that lies where ``rx <= thr``.

    ``px1`` and ``rx`` are arrays of matching shape; ``rx <= thr`` is used as a
    boolean index into ``px1``. Entries where the comparison is False (which
    includes NaN values of ``rx``) do not contribute to the numerator.
    """
    at_or_below = rx <= thr
    return px1[at_or_below].sum() / px1.sum()

# Text for each isoline: the percentage of the class density p_c (summed over
# the grid) that lies where the likelihood ratio p_c / p_ood is at or below the
# isoline's threshold. "\\%" emits a LaTeX-escaped percent sign; the previous
# "\%" spelling relied on an invalid Python escape sequence.
strs = [f"{compute_sampled_error(p_c, p_c/p_ood, thr)*100:0.2f}\\%" for thr in thrs]
fmt = dict(zip(CS1.levels, strs))

# Label every contour level with its precomputed string.
kde_ax.clabel(CS1, CS1.levels, fmt=fmt, fontsize=24, inline_spacing=-17, colors='k')

# -------------- LEGEND -----------------
a.ax_joint.set_ylabel("NM L2 similarity")
a.ax_joint.set_xlabel("LP logits")
a.ax_joint.set_ylim([0.02, np.max(mean_val_l2[mask, c])+0.005])
a.ax_joint.set_xlim([np.min(logits_ood[:, c])-3, np.max(logits_val[mask, c])+2])

# The plots were created with legend=False, so the legend is assembled from
# proxy artists that mirror the colours and styles used above.
id_points = mlines.Line2D([], [], color=list(color_p[2]) + [0.5], marker='o', linestyle="None",
                          markersize=10, label=class_name)
ood_points = mlines.Line2D([], [], color=list(color_p[3]) + [0.5], marker='o', linestyle="None",
                          markersize=10, label=ood_name)

id_line = mlines.Line2D([], [], color=color_p[2],
                          markersize=10, label=simulated_id_name)
ood_line = mlines.Line2D([], [], color=color_p[7],
                          markersize=10, label=simulated_ood_name)

ood_isoline = mlines.Line2D([], [], color=color_p[7], linestyle="--", linewidth=3,
                          markersize=10, label="\\textbf{{calibrated strategies}}")
# NOTE(review): this handle is created but not included in the legend below.
ood_true_isoline = mlines.Line2D([], [], color=color_p[3], linestyle="--",
                          markersize=10, label="\\textbf{{true isolines}}")

l_handles = [id_points, ood_points, id_line, ood_line, ood_isoline]
a.ax_joint.legend(handles=l_handles, loc="upper center", ncol=3, bbox_to_anchor=(0.515755, 1.042))

# -------------- SAVING -----------------
plt.gcf().set_size_inches(14, 10)
# Saved next to the other figure of this notebook instead of a user-specific
# absolute home directory, so the cell also runs on other machines.
plt.savefig("./_out/LP_NM_diagram.pdf", dpi=300)
plt.show()