In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
from functools import partial
import sys
import random
import pandas as pd

from scipy.stats import multivariate_normal

from sklearn.metrics import auc, roc_curve
import json
import torch
from tqdm import tqdm
import shutil

# Checkpoint directories. Each run folder holds params.json and
# models/model-best.pkl. "in" runs are checked below to be full-dataset
# training (params['differ_data'] is None). "out" runs record a differing-point
# index in params['differ_data'] (presumably the point left out -- confirm).
in_path = './models_in_relu_NTK'
out_path = './models_out_relu_NTK'
output_path = '../log_nngp_vs_nn/log_nn_landscape'

# Default to GPU 1 but respect a device chosen by the caller's environment.
# CUDA_VISIBLE_DEVICES is only honoured if set before CUDA is initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

# Seed every RNG the notebook draws from (random perturbation direction,
# eval/reference splits) so Restart & Run All reproduces the same results.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def to_cuda(x):
    """Move `x` to the default CUDA device when CUDA is available.

    Works on anything exposing `.cuda()` (tensors, modules). On a host
    without CUDA the argument is handed back untouched.
    """
    return x.cuda() if torch.cuda.is_available() else x

In [None]:
# Perturbation scales applied along `random_direction`: -50 .. 50 in steps of 1.
PERTURBATION_GRID = np.linspace(-50, 50, 101)
# Number of "out" models a differing point must have for its predictions to be
# stored; get_LOOD_auc_dict later checks for the same count.
EXPECTED_OUT_MODELS = 100


def _load_eval_model(model_path):
    """Load a pickled model checkpoint, place it with `to_cuda`, and set eval mode.

    The checkpoint is mapped to CPU first so GPU-saved models still load on a
    CPU-only host. `to_cuda` then puts the model on the same device as the
    inputs built in `_perturbed_predictions`.
    """
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. Recent torch releases default to weights_only=True, which can
    # reject full-model pickles; pass weights_only=False there if needed.
    model = to_cuda(torch.load(model_path, map_location='cpu'))
    model.eval()
    return model


def _perturbed_predictions(model, differ_data, direction):
    """Predict on `differ_data + direction * s` for every `s` in PERTURBATION_GRID.

    The perturbed inputs form the model's batch axis. The result is returned
    as a numpy array with an extra leading axis of size 1 (one model), ready
    to be concatenated across models along axis 0.
    """
    data_batch = np.array([differ_data + direction * scale for scale in PERTURBATION_GRID])
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        preds = model(to_cuda(torch.from_numpy(data_batch).float()))
    return preds.detach().cpu().unsqueeze(0).numpy()


if not os.path.exists(output_path):
    in_folder_list = sorted(os.listdir(in_path))
    out_folder_list = sorted(os.listdir(out_path))
    # os.listdir never returns None, so test for empty listings instead. Do it
    # before creating output_path: otherwise a failed run leaves the directory
    # behind and this cell silently skips all work on the next run.
    if not in_folder_list or not out_folder_list:
        raise ValueError(f'empty folder to plot: {in_path!r} or {out_path!r} has no runs')

    # makedirs also creates any missing parents of output_path.
    for subdir in ('data', 'in_predictions', 'out_predictions'):
        os.makedirs(f'{output_path}/{subdir}')

    # One shared random perturbation direction of shape (3, 32, 32), with
    # entries in [-1/255, 1/255].
    random_direction = np.random.uniform(-1, 1, (3, 32, 32)) / 255

    # "out" models: group one prediction block per model by the differing-point
    # index recorded in its params.json.
    print("logging out predictions")
    differ_data_dict = {}
    out_chunks = {}
    for folder in tqdm(out_folder_list):
        with open(f"{out_path}/{folder}/params.json", 'r') as f:
            params = json.load(f)
        differ_idx = params['differ_data']
        differ_data = np.load(f'{out_path}/{folder}/logs/differ_data.npy')
        model = _load_eval_model(f'{out_path}/{folder}/models/model-best.pkl')
        # Keep the first differing-data value seen for each index.
        differ_data_dict.setdefault(differ_idx, differ_data)
        out_chunks.setdefault(differ_idx, []).append(
            _perturbed_predictions(model, differ_data, random_direction))
    # Concatenate once per index instead of growing arrays inside the loop.
    out_dict = {idx: np.concatenate(chunks, axis=0) for idx, chunks in out_chunks.items()}

    # "in" models: evaluate every full-dataset model on every differing point.
    print("logging in predictions")
    in_chunks = {idx: [] for idx in out_dict}
    for folder in tqdm(in_folder_list):
        with open(f"{in_path}/{folder}/params.json", 'r') as f:
            params = json.load(f)
        # Full-dataset training has no differing point.
        assert params['differ_data'] is None, f'{folder} is not a full-dataset run'
        model = _load_eval_model(f'{in_path}/{folder}/models/model-best.pkl')
        for differ_idx in out_dict:
            in_chunks[differ_idx].append(
                _perturbed_predictions(model, differ_data_dict[differ_idx], random_direction))
    in_dict = {idx: np.concatenate(chunks, axis=0) for idx, chunks in in_chunks.items()}

    # Store only differing points with the expected number of out models and at
    # least as many in models (the analysis indexes both arrays with the same
    # out-model indices).
    print("storing data and predictions")
    for differ_idx in tqdm(in_dict):
        n_out = out_dict[differ_idx].shape[0]
        n_in = in_dict[differ_idx].shape[0]
        if n_out != EXPECTED_OUT_MODELS:
            print(f"Wrong number of out predictions for data {differ_idx}: {n_out}")
            continue
        if n_in < n_out:
            print(f"Too few in predictions for data {differ_idx}: {n_in} < {n_out}")
            continue
        np.save(f'{output_path}/data/differ_idx_{differ_idx}.npy', differ_data_dict[differ_idx])
        np.save(f'{output_path}/in_predictions/differ_idx_{differ_idx}.npy', in_dict[differ_idx])
        np.save(f'{output_path}/out_predictions/differ_idx_{differ_idx}.npy', out_dict[differ_idx])


In [None]:
# Differing-point indices that have stored results, parsed from file names of
# the form 'differ_idx_<idx>.npy'. Other entries in the directory (hidden files,
# notebook checkpoints, ...) are ignored instead of crashing int(). The indices
# are sorted so the processing order is deterministic.
differ_names = sorted(
    name for name in os.listdir(f'{output_path}/data')
    if name.startswith('differ_idx_') and name.endswith('.npy')
    and name[len('differ_idx_'):-len('.npy')].isdigit()
)
differ_indices = sorted(int(name[len('differ_idx_'):-len('.npy')]) for name in differ_names)

In [None]:
import numpy as np
from sklearn.metrics import auc, roc_curve
from scipy.stats import multivariate_normal
import scipy as sp

def _gaussian_kl(mean0, cov0, mean1, cov1):
    """Return KL(N(mean0, cov0) || N(mean1, cov1)) for multivariate Gaussians.

    Log-determinants come from `np.linalg.slogdet`, so very small or very
    large determinants do not under/overflow the way `log(det(a) / det(b))` can.
    """
    dim = mean0.shape[0]
    delta = mean1 - mean0
    _, logdet0 = np.linalg.slogdet(cov0)
    _, logdet1 = np.linalg.slogdet(cov1)
    trace_term = np.trace(np.linalg.solve(cov1, cov0))
    mahalanobis = delta @ np.linalg.solve(cov1, delta)
    return 0.5 * (trace_term + mahalanobis - dim + logdet1 - logdet0)


def _lood_metrics(in_ref, in_eval, out_ref, out_eval):
    """Compute (MIA AUC, mean-distance LOOD, KL LOOD) at a single query point.

    Each argument is a 2-D array with one row of predictions per model. The
    reference halves fit one Gaussian per side. The evaluation halves are
    scored by the Gaussian likelihood-ratio test (AUC) and compared by their
    means (diff).
    """
    in_mean = np.mean(in_ref, axis=0)
    out_mean = np.mean(out_ref, axis=0)
    # atleast_2d keeps a (1, 1) matrix when predictions are one-dimensional.
    in_cov = np.atleast_2d(np.cov(in_ref.T))
    out_cov = np.atleast_2d(np.cov(out_ref.T))

    # Membership score = log-likelihood ratio in vs. out. This ranks samples
    # exactly like the density ratio (so the AUC is the same), but does not
    # underflow to 0 when both densities are tiny.
    eval_points = np.concatenate((in_eval, out_eval), axis=0)
    labels = np.concatenate((np.ones(in_eval.shape[0]), np.zeros(out_eval.shape[0])), axis=0)
    scores = (multivariate_normal.logpdf(eval_points, in_mean, in_cov)
              - multivariate_normal.logpdf(eval_points, out_mean, out_cov))
    fpr, tpr, _ = roc_curve(labels, scores)
    auc_score = auc(fpr, tpr)

    mean_dist = np.mean((np.mean(in_eval, axis=0) - np.mean(out_eval, axis=0)) ** 2)
    kl = _gaussian_kl(out_mean, out_cov, in_mean, in_cov)
    return auc_score, mean_dist, kl


def get_LOOD_auc_dict(output_path, differ_indices, n_runs=50, n_models=100, rng=None):
    """Estimate per-query MIA AUC, mean-distance and KL LOOD for each differing point.

    For each index, `<output_path>/{in,out}_predictions/differ_idx_<idx>.npy`
    are loaded. They are indexed as (model, query point, output) -- shape
    presumably (n_models, 101, n_outputs) as written by the logging cell.
    Indices whose out array does not have exactly `n_models` rows are skipped,
    as are indices whose in array has fewer rows.

    Over `n_runs` random half splits of the model axis, `_lood_metrics` is
    evaluated at every query point. The per-query mean and std over runs are
    returned and also written to `<output_path>/landscape_<idx>.csv`.

    Args:
        output_path: directory holding the stored prediction arrays.
        differ_indices: differing-point indices to analyse.
        n_runs: number of random eval/reference splits (default 50).
        n_models: required number of out models per differing point (default 100).
        rng: object with a numpy-style `choice` method (e.g. a
            `np.random.Generator`). Defaults to the global `np.random` state.

    Returns:
        (auc_dict, auc_std_dict, diff_dict, diff_std_dict, kl_dict, kl_std_dict),
        each mapping differing index -> 1-D array over query points. kl is
        KL(N_out || N_in) of the Gaussians fitted on the reference halves.
    """
    sampler = np.random if rng is None else rng
    auc_dict, auc_std_dict, diff_dict, diff_std_dict, kl_dict, kl_std_dict = {}, {}, {}, {}, {}, {}
    for differ_idx in tqdm(differ_indices):
        out_preds = np.load(f'{output_path}/out_predictions/differ_idx_{differ_idx}.npy')
        if out_preds.shape[0] != n_models:
            continue
        in_preds = np.load(f'{output_path}/in_predictions/differ_idx_{differ_idx}.npy')
        if in_preds.shape[0] < n_models:
            print(f"Too few in predictions for data {differ_idx}: {in_preds.shape[0]} < {n_models}")
            continue
        n_query = out_preds.shape[1]

        auc_runs, diff_runs, kl_runs = [], [], []
        for _ in range(n_runs):
            # Half the models are scored; the other half fit the Gaussians.
            eval_indices = sampler.choice(n_models, n_models // 2, replace=False)
            ref_indices = np.setdiff1d(np.arange(n_models), eval_indices)
            # Hoisted out of the query loop: one fancy-index copy per split.
            in_ref, in_eval = in_preds[ref_indices], in_preds[eval_indices]
            out_ref, out_eval = out_preds[ref_indices], out_preds[eval_indices]
            per_query = [
                _lood_metrics(in_ref[:, q, :], in_eval[:, q, :], out_ref[:, q, :], out_eval[:, q, :])
                for q in range(n_query)
            ]
            auc_row, diff_row, kl_row = zip(*per_query)
            auc_runs.append(auc_row)
            diff_runs.append(diff_row)
            kl_runs.append(kl_row)

        auc_runs, diff_runs, kl_runs = np.array(auc_runs), np.array(diff_runs), np.array(kl_runs)
        auc_dict[differ_idx] = np.mean(auc_runs, axis=0)
        auc_std_dict[differ_idx] = np.std(auc_runs, axis=0)
        diff_dict[differ_idx] = np.mean(diff_runs, axis=0)
        diff_std_dict[differ_idx] = np.std(diff_runs, axis=0)
        kl_dict[differ_idx] = np.mean(kl_runs, axis=0)
        kl_std_dict[differ_idx] = np.std(kl_runs, axis=0)

        # Query points are centred integer offsets (-50 .. 50 for 101 points).
        half = n_query // 2
        df_temp = pd.DataFrame.from_dict({
            'Pert': list(range(-half, n_query - half)),
            'auc': list(auc_dict[differ_idx]),
            'auc_std': list(auc_std_dict[differ_idx]),
            'diff': list(diff_dict[differ_idx]),
            'diff_std': list(diff_std_dict[differ_idx]),
            'kl': list(kl_dict[differ_idx]),
            'kl_std': list(kl_std_dict[differ_idx]),
        })
        df_temp.to_csv(f'{output_path}/landscape_{differ_idx}.csv')
    return auc_dict, auc_std_dict, diff_dict, diff_std_dict, kl_dict, kl_std_dict

# Run the LOOD / MIA analysis over every stored differing point. Each of the
# six results maps a differing-point index to per-query-point statistics.
(auc_dict, auc_std_dict,
 diff_dict, diff_std_dict,
 kl_dict, kl_std_dict) = get_LOOD_auc_dict(output_path, differ_indices)

In [None]:
import matplotlib.pyplot as plt
# MIA AUC landscape. Plot one curve (mean +/- std over splits) per differing
# point that passes the >= 0.5 filter below, with a dashed line at the
# perturbation scale of maximal mean AUC.
fig, ax = plt.subplots(figsize=(4, 3))
# Public replacement for the private `ax._get_lines.prop_cycler`, which was
# removed in Matplotlib 3.8. Colours advance once per plotted point, as before.
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
pert_list = []
print(list(auc_dict.keys()))
for differ_idx, auc_mean in auc_dict.items():
    landscape_i_pert = np.linspace(-50, 50, len(auc_mean))
    upper = auc_mean + auc_std_dict[differ_idx]
    lower = auc_mean - auc_std_dict[differ_idx]
    max_idx = np.argmax(auc_mean)
    min_idx = np.argmin(auc_mean)
    # NOTE(review): this checks the lower band at the argmin of the *mean*
    # curve, not min(lower); confirm which filter is intended.
    if lower[min_idx] >= 0.5:
        print(differ_idx)
        color = cycle_colors[len(pert_list) % len(cycle_colors)]
        ax.plot(landscape_i_pert, auc_mean, label=f'differ_idx {differ_idx}', color=color)
        ax.fill_between(landscape_i_pert, lower, upper, alpha=0.2, color=color)
        ax.axvline(x=landscape_i_pert[max_idx], linestyle='dashed', color=color)
        pert_list.append(landscape_i_pert[max_idx])
ax.set_xlabel('perturbation scale x', fontsize=12)
ax.set_ylabel('MIA performance', fontsize=12)
# Avoid np.mean([]) -> nan plus RuntimeWarning when no point passes the filter.
if pert_list:
    summary = f'Perturbation scale for maximal MIA AUC: avg {np.mean(pert_list):.3f} (std {np.std(pert_list):.3f})'
else:
    summary = 'Perturbation scale for maximal MIA AUC: no differing point passed the AUC filter'
ax.text(0.5, 1.1, summary, horizontalalignment='center', verticalalignment='top', transform=ax.transAxes)
fig.savefig(f'{output_path}/auc_mia_landscape.png', bbox_inches='tight')
print(pert_list)
print(len(pert_list))
print(summary)