In [33]:
def compute_px(ppl, lls):
    """Turn per-example perplexities into length-scaled negative log-perplexities.

    Args:
        ppl: Array-like of perplexity values, one per example.
        lls: Sequence with one sized entry per example; only ``len()`` of each
            entry is used (presumably per-token log-likelihoods -- TODO confirm).

    Returns:
        NumPy array equal to ``-log(ppl) * len(entry)`` for each example. If
        ``ppl`` is the usual per-token perplexity, this is the total sequence
        log-likelihood log P(x) -- an assumption to verify against the producer
        of ``ppl``.
    """
    token_counts = np.array(list(map(len, lls)))
    return -(np.log(ppl) * token_counts)

def compute_auroc_all(id_msp, id_px, id_ppl, ood_msp, ood_px, ood_ppl, do_print=False):
    """Compute the AUROC of three OOD scores for one ID/OOD dataset pair.

    The ``*_px`` and ``*_msp`` arrays are negated before scoring, while the
    perplexity arrays are passed through unchanged.

    Args:
        id_msp, ood_msp: Score arrays for the ID / OOD sets (reported as "P(y | x)").
        id_px, ood_px: Score arrays for the ID / OOD sets (reported as "P(x)").
        id_ppl, ood_ppl: Perplexity arrays for the ID / OOD sets.
        do_print: When true, print the three scores with three decimals.

    Returns:
        Dict with keys ``'p_x'``, ``'p_y'`` and ``'ppl'`` holding the values
        returned by ``compute_auroc``.
    """
    scores = {
        'p_x': compute_auroc(-id_px, -ood_px),
        'p_y': compute_auroc(-id_msp, -ood_msp),
        'ppl': compute_auroc(id_ppl, ood_ppl),
    }
    if do_print:
        print(f"P(x): {scores['p_x']:.3f}")
        print(f"P(y | x): {scores['p_y']:.3f}")
        print(f"Perplexity: {scores['ppl']:.3f}")
    return scores

In [34]:
from sklearn.metrics import roc_auc_score, roc_curve

def compute_auroc(id_pps, ood_pps, normalize=False, return_curve=False):
    """AUROC (in percent) for separating OOD from ID examples by score.

    Every entry of ``ood_pps`` is labelled 1 and every entry of ``id_pps`` is
    labelled 0, so a score that is larger for OOD examples yields a high AUROC.

    Args:
        id_pps: Array of scores for in-distribution examples.
        ood_pps: Array of scores for out-of-distribution examples.
        normalize: If true, min-max scale the pooled scores before scoring.
            Note the denominator is zero when all pooled scores are equal.
        return_curve: If true, return ``roc_curve(labels, scores)`` instead of
            the scalar AUROC.

    Returns:
        ``100 * roc_auc_score(labels, scores)``, or the ``roc_curve`` result
        when ``return_curve`` is set.
    """
    labels = np.concatenate((np.ones_like(ood_pps), np.zeros_like(id_pps)))
    pooled = np.concatenate((ood_pps, id_pps))
    if normalize:
        lowest, highest = pooled.min(), pooled.max()
        pooled = (pooled - lowest) / (highest - lowest)
    if return_curve:
        return roc_curve(labels, pooled)
    return 100*roc_auc_score(labels, pooled)

def compute_far(id_pps, ood_pps, rate=5):
    """Percentage of ID scores lying strictly above an OOD-score percentile.

    The threshold is the ``rate``-th percentile of ``ood_pps``; with the default
    ``rate=5`` roughly 95% of OOD scores lie at or above it, so the returned
    value is the share of ID examples that would also be flagged at that
    threshold.

    Args:
        id_pps: NumPy array of scores for in-distribution examples.
        ood_pps: Array-like of scores for out-of-distribution examples.
        rate: Percentile (0-100) of the OOD scores used as the threshold.

    Returns:
        ``100 * (# ID scores > threshold) / len(id_pps)`` as a float.
    """
    threshold = np.percentile(ood_pps, rate)
    false_alarms = id_pps[id_pps > threshold]
    return 100 * len(false_alarms) / len(id_pps)

In [38]:
def compute_metric_all(id_msp, id_px, id_ppl, ood_msp, ood_px, ood_ppl, metric='auroc', do_print=False):
    """Evaluate three OOD scores for one ID/OOD dataset pair with a chosen metric.

    The ``*_px`` and ``*_msp`` arrays are negated before being scored, while the
    perplexity arrays are used as-is. (Both metric helpers in this notebook
    treat larger scores as more OOD-like, which is presumably why the
    likelihood/confidence scores are sign-flipped.)

    Args:
        id_msp, ood_msp: Score arrays for the ID / OOD sets (reported as "P(y | x)").
        id_px, ood_px: Score arrays for the ID / OOD sets (reported as "P(x)").
        id_ppl, ood_ppl: Perplexity arrays for the ID / OOD sets.
        metric: ``'auroc'`` (uses ``compute_auroc``) or ``'far'`` (uses
            ``compute_far`` with its default ``rate``).
        do_print: When true, print the metric name and the three scores.

    Returns:
        Dict with keys ``'p_x'``, ``'p_y'`` and ``'ppl'``.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    # Validate before doing any work so a typo fails fast and the message
    # names the offending value. ValueError is still an Exception subclass, so
    # callers that caught the previous bare Exception keep working.
    if metric not in ('auroc', 'far'):
        raise ValueError(
            f"Invalid metric name: {metric!r} (expected 'auroc' or 'far')"
        )
    metric_fn = compute_auroc if metric == 'auroc' else compute_far

    score_px = metric_fn(-id_px, -ood_px)
    score_py = metric_fn(-id_msp, -ood_msp)
    score_ppl = metric_fn(id_ppl, ood_ppl)

    if do_print:
        print(f"Metric {metric}:")
        print(f"P(x): {score_px:.3f}")
        print(f"P(y | x): {score_py:.3f}")
        print(f"Perplexity: {score_ppl:.3f}\n")

    scores = {
        'p_x': score_px,
        'p_y': score_py,
        'ppl': score_ppl
    }
    return scores

### SST-2 as ID vs IMDB as OOD

In [4]:
import numpy as np

# Directory holding the saved perplexity arrays (and, below, the pickled
# per-example lists). Relative to the notebook's working directory.
ppl_base_path = '../output/gpt2/sst2/'

# One perplexity value per example for each dataset (shape not verifiable
# from this notebook -- compute_px multiplies it elementwise by example lengths).
imdb_pps = np.load(ppl_base_path + 'imdb_5e-5_pps.npy')
sst2_pps = np.load(ppl_base_path + 'sst2_5e-5_pps.npy')

In [9]:
import pickle

# Load the pickled per-example sequences; compute_px only uses len() of each
# entry (presumably per-token log-likelihoods -- TODO confirm).
# CAUTION: pickle.load can execute arbitrary code; only load files you produced
# yourself or otherwise trust.
with open(ppl_base_path + 'imdb_5e-5_lls.pkl', 'rb') as f:
    imdb_lls = pickle.load(f)

with open(ppl_base_path + 'sst2_5e-5_lls.pkl', 'rb') as f:
    sst2_lls = pickle.load(f)

In [55]:
# Directory holding the saved "msp" score arrays (presumably maximum softmax
# probabilities from a classifier -- TODO confirm against the producing script).
msp_base_path = '../roberta/msp/'

imdb_msp = np.load(msp_base_path + 'large_imdb_msp.npy')
# Alternative IMDB score file, currently disabled:
# imdb_msp = np.load(msp_base_path + 'textattack_imdb_msp.npy')
sst2_msp = np.load(msp_base_path + 'large_sst2_msp.npy')

In [56]:
import pandas as pd

# Three parallel lookup tables keyed by dataset name. Later cells index all of
# them with the same names, so any dataset added to one must be added to all.
# The commented-out entries refer to arrays that are not loaded in the cells
# shown above; re-enable them only after loading the corresponding files.

# Perplexity arrays per dataset.
all_pps = {
    'imdb': imdb_pps,
#     'yelp': yelp_pps,
    'sst2': sst2_pps,
#     'snli': snli_pps,
#     'rte': rte_pps
}

# Pickled per-example sequences per dataset (only their lengths are used).
all_lls = {
    'imdb': imdb_lls,
#     'yelp': yelp_lls,
    'sst2': sst2_lls,
#     'snli': snli_lls,
#     'rte': rte_lls
}

# "msp" score arrays per dataset.
all_msp = {
    'imdb': imdb_msp,
#     'yelp': yelp_msp,
    'sst2': sst2_msp,
#     'snli': snli_msp,
#     'rte': rte_msp
}

In [57]:
# Derive the P(x) score of every dataset from its perplexities and sequence
# lengths. all_pxy is created empty here and is not filled by this cell.
all_px = {}
all_pxy = {}
for ds, ds_pps in all_pps.items():
    all_px[ds] = compute_px(ds_pps, all_lls[ds])

In [58]:
# Each tuple is (in-distribution name, out-of-distribution name); the names
# must be keys of all_msp / all_px / all_pps.
dataset_pairs = [('sst2', 'imdb')]
# Metric names accepted by compute_metric_all.
metrics = ['auroc', 'far']

In [59]:
# Print every metric for every (ID, OOD) pair; the returned dicts are discarded
# here because this cell is only for human-readable output.
for id_name, ood_name in dataset_pairs:
    print(f"-------{id_name} vs {ood_name}-------")
    for metric in metrics:
        compute_metric_all(
            id_msp=all_msp[id_name],
            id_px=all_px[id_name],
            id_ppl=all_pps[id_name],
            ood_msp=all_msp[ood_name],
            ood_px=all_px[ood_name],
            ood_ppl=all_pps[ood_name],
            metric=metric,
            do_print=True,
        )

-------sst2 vs imdb-------
Metric auroc:
P(x): 99.948
P(y | x): 64.283
Perplexity: 90.090

Metric far:
P(x): 0.000
P(y | x): 70.344
Perplexity: 34.487



In [60]:
# Collect the scores as results[metric]["<id>-<ood>"] -> {'p_x', 'p_y', 'ppl'}
# so each metric can be turned into one table with a row per dataset pair.
results = {}

for metric in metrics:
    per_pair = results[metric] = {}
    for id_name, ood_name in dataset_pairs:
        per_pair[f'{id_name}-{ood_name}'] = compute_metric_all(
            id_msp=all_msp[id_name],
            id_px=all_px[id_name],
            id_ppl=all_pps[id_name],
            ood_msp=all_msp[ood_name],
            ood_px=all_px[ood_name],
            ood_ppl=all_pps[ood_name],
            metric=metric,
        )

In [61]:
# One DataFrame per metric: rows are "<id>-<ood>" pairs, columns are the
# score names returned by compute_metric_all.
all_dfs = {
    metric: pd.DataFrame.from_dict(results[metric], orient='index')
    for metric in metrics
}

In [62]:
# Label for the table rendered below.
print("AUROC:")

# Bare last expression so the notebook renders the DataFrame as the cell output.
all_dfs['auroc']

AUROC:


Unnamed: 0,p_x,p_y,ppl
sst2-imdb,99.948033,64.283467,90.089902


In [63]:
# Label for the table rendered below. The 'far' scores come from compute_far
# with its default rate=5, i.e. the threshold is the 5th percentile of the OOD scores.
print("FAR95:")

# Bare last expression so the notebook renders the DataFrame as the cell output.
all_dfs['far']

FAR95:


Unnamed: 0,p_x,p_y,ppl
sst2-imdb,0.0,70.34384,34.486546
