# Imports

In [1]:
%matplotlib widget

In [2]:
import pickle
import os
import pandas as pd
import numpy as np
import json

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import roc_auc_score, roc_curve

# Helper Functions

In [3]:
def compute_px(ppl, lls):
    """Turn per-sequence perplexity back into a sequence log-likelihood.

    Inverts ppl = exp(-log p(x) / len) as log p(x) = -len * log(ppl).

    Args:
        ppl: Perplexity of each sequence (array-like).
        lls: One entry per sequence; only ``len()`` of each entry is used, as the
            token count (presumably per-token log-likelihoods — TODO confirm).

    Returns:
        np.ndarray with one log p(x) value per sequence.
    """
    token_counts = np.array(list(map(len, lls)))
    return -np.log(ppl) * token_counts

def compute_auroc_all(id_msp, id_px, id_ppl, ood_msp, ood_px, ood_ppl, do_print=False):
    """AUROC (percent) of the three OOD detectors: p(x), p(y | x) (MSP) and perplexity.

    ``compute_auroc`` treats higher scores as "more OOD". p(x) and MSP are therefore
    negated, so that low likelihood / low confidence counts as OOD. Perplexity is
    passed through unchanged.

    Returns:
        dict with keys 'p_x', 'p_y', 'ppl'.
    """
    result = {
        'p_x': compute_auroc(-id_px, -ood_px),
        'p_y': compute_auroc(-id_msp, -ood_msp),
        'ppl': compute_auroc(id_ppl, ood_ppl),
    }
    if do_print:
        for label, key in (("P(x)", 'p_x'), ("P(y | x)", 'p_y'), ("Perplexity", 'ppl')):
            print(f"{label}: {result[key]:.3f}")
    return result

In [4]:
def compute_auroc(id_pps, ood_pps, normalize=False, return_curve=False):
    """AUROC (in percent) for separating OOD samples from in-domain samples by score.

    OOD samples are the positive class (label 1) and in-domain samples the negative
    class (label 0). A higher score must therefore mean "more OOD"; callers negate
    signals for which the opposite holds.

    Args:
        id_pps: Scores of the in-domain samples (1-D array-like).
        ood_pps: Scores of the OOD samples (1-D array-like).
        normalize: If True, min-max rescale the pooled scores to [0, 1] first.
            The rescaling is monotonic, so it does not change the AUROC itself.
        return_curve: If True, return ``roc_curve(y, scores)`` instead of the AUROC.

    Returns:
        ``100 * roc_auc_score(y, scores)``, or the ``(fpr, tpr, thresholds)`` tuple
        from ``roc_curve`` when ``return_curve`` is True.
    """
    id_pps = np.asarray(id_pps)
    ood_pps = np.asarray(ood_pps)
    y = np.concatenate((np.ones(len(ood_pps), dtype=int), np.zeros(len(id_pps), dtype=int)))
    scores = np.concatenate((ood_pps, id_pps))
    if normalize:
        lo, hi = scores.min(), scores.max()
        if hi > lo:
            scores = (scores - lo) / (hi - lo)
        else:
            # All scores are equal: min-max scaling would divide by zero and produce
            # NaNs, which scikit-learn rejects. Map everything to 0 (AUROC is then 50).
            scores = np.zeros(len(scores))
    if return_curve:
        return roc_curve(y, scores)
    return 100 * roc_auc_score(y, scores)

def compute_far(id_pps, ood_pps, rate=5, return_indices=False):
    """False-alarm rate at a fixed OOD detection rate (FAR95 for ``rate=5``).

    Scores are "higher = more OOD". The cut-off is the ``rate``-th percentile of
    the OOD scores, so ``100 - rate`` percent of OOD samples score above it.
    In-domain samples above the cut-off are false alarms.

    Args:
        id_pps: In-domain scores (array-like; plain lists are accepted).
        ood_pps: OOD scores (array-like).
        rate: Percentile of the OOD scores used as the cut-off (5 -> 95% detection).
        return_indices: If True, return sample indices instead of a rate.

    Returns:
        If ``return_indices`` is False: the percentage of in-domain samples scoring
        above the cut-off.
        Otherwise ``{'id': [...], 'ood': [...]}``: the indices of in-domain samples
        above the cut-off (false alarms) and of OOD samples above the cut-off.

    NOTE(review): the 'ood' indices are OOD samples *above* the cut-off, i.e. the
    ones the detector does flag, not the ~``rate``% it misses. The notebook stores
    these as "error indices"; confirm that is the intended definition.
    """
    # Coerce to arrays: the boolean-mask comparison below raises TypeError on
    # plain Python lists (e.g. subsampled in-domain signals).
    id_scores = np.asarray(id_pps)
    ood_scores = np.asarray(ood_pps)
    cut_off = np.percentile(ood_scores, rate)
    id_above = id_scores > cut_off

    if return_indices:
        return {
            'id': np.flatnonzero(id_above).tolist(),
            'ood': np.flatnonzero(ood_scores > cut_off).tolist(),
        }
    return 100 * np.count_nonzero(id_above) / len(id_scores)

In [5]:
def compute_metric_all(id_msp, id_px, id_ppl, ood_msp, ood_px, ood_ppl, metric='auroc', do_print=False):
    """Score the three OOD detectors with a single metric.

    The detectors are p(x), classifier confidence p(y | x) (MSP), and perplexity.
    The metric functions treat higher scores as "more OOD". p(x) and MSP are
    therefore negated, and perplexity is used as is.

    Args:
        id_msp, id_px, id_ppl: In-domain MSP, p(x) and perplexity scores (array-like;
            lists are accepted).
        ood_msp, ood_px, ood_ppl: The same signals for the OOD set.
        metric: 'auroc' (``compute_auroc``) or 'far' (``compute_far``, FAR95 by default).
        do_print: If True, print the three scores.

    Returns:
        dict with keys 'p_x', 'p_y', 'ppl'.

    Raises:
        ValueError: for an unknown ``metric`` (an ``Exception`` subclass).
    """
    if metric == 'auroc':
        scorer = compute_auroc
    elif metric == 'far':
        scorer = compute_far
    else:
        raise ValueError(f"Invalid metric name: {metric!r} (expected 'auroc' or 'far')")

    # Unary minus is not defined on Python lists, so coerce every signal to an array.
    id_msp, id_px, id_ppl = np.asarray(id_msp), np.asarray(id_px), np.asarray(id_ppl)
    ood_msp, ood_px, ood_ppl = np.asarray(ood_msp), np.asarray(ood_px), np.asarray(ood_ppl)

    score_px = scorer(-id_px, -ood_px)
    score_py = scorer(-id_msp, -ood_msp)
    score_ppl = scorer(id_ppl, ood_ppl)

    if do_print:
        print(f"Metric {metric}:")
        print(f"P(x): {score_px:.3f}")
        print(f"P(y | x): {score_py:.3f}")
        print(f"Perplexity: {score_ppl:.3f}\n")

    return {
        'p_x': score_px,
        'p_y': score_py,
        'ppl': score_ppl,
    }

In [6]:
def read_model_out(fname):
    """Load a saved model-output signal, choosing the loader by file extension.

    Args:
        fname: Path to a ``.pkl`` (pickled Python object) or ``.npy`` (NumPy array) file.

    Returns:
        The unpickled object for ``.pkl`` files, or ``np.load(fname)`` for ``.npy`` files.

    Raises:
        KeyError: if the extension is neither ``pkl`` nor ``npy``.
    """
    # Take the extension from the last path component only. Splitting the whole
    # path on '.' picks the wrong token as soon as a directory or the stem has a dot.
    ftype = os.path.splitext(fname)[1][1:]

    if ftype == 'pkl':
        # pickle.load can execute arbitrary code: only use on trusted, locally produced files.
        with open(fname, 'rb') as f:
            return pickle.load(f)
    elif ftype == 'npy':
        return np.load(fname)
    else:
        raise KeyError(f'{ftype} not supported')


# Summarize

## Settings

In [7]:
# Global verbosity switch. No cell in this part of the notebook reads it;
# presumably used further down — TODO confirm or remove.
verbose = False

In [8]:
# Repository root = parent of the notebook's working directory. A notebook kernel
# normally has no `__file__`, so the working directory is used directly. (This is
# equivalent to the previous abspath('__file__') string-literal trick.)
repo = os.path.dirname(os.getcwd())
print(repo)

C:\Users\Willi\Documents\NYU\2020_Fall\nlp\project\ood-detection


In [9]:
# Model-output and figure directories, both directly under the repository root.
output_dir, fig_dir = (os.path.join(repo, sub) for sub in ('output', 'figs'))

In [10]:
train_sets = ['imdb', 'sst2']
eval_sets = ['imdb', 'sst2', 'snli', 'counterfactual-imdb', 'rte']
methods = ['msp', 'lls', 'pps']

# One slot per (training set, evaluation set) pair. Each slot is a fresh dict
# {method: None} that the loading cells fill in.
signals = {
    (train_set, eval_set): dict.fromkeys(methods)
    for train_set in train_sets
    for eval_set in eval_sets
}

## Import Signals

In [11]:
# File extension each signal type was saved with; used to build the signal file names.
method2ftype = dict(msp='npy', lls='pkl', pps='npy')

### Subsampling Indices

In [12]:
def get_indices(fname):
    """Read one integer index per line from the text file ``fname``.

    Whitespace-only lines (e.g. an extra empty line at the end of the file) are
    skipped; previously they crashed with ``int('')``.

    Returns:
        list[int] in file order.
    """
    with open(fname, 'r') as f:
        return [int(line) for line in f if line.strip()]

In [13]:
roberta_dir = os.path.join(repo, 'roberta')

# Row indices per in-domain training set, read from roberta/<name>_indices.txt.
# The GPT-2 cell below uses them to subsample in-domain signals.
index_files = {name: os.path.join(roberta_dir, f'{name}_indices.txt') for name in train_sets}
subsample_indices = {name: get_indices(path) for name, path in index_files.items()}

### GPT2

In [14]:
# Learning rate of the GPT-2 run used per in-domain training set (part of the file name).
best_lr = {
    'imdb': '5e-5',
    'sst2': '5e-5',
}

methods = ['lls', 'pps']
not_readys = []

for (train_set, eval_set), signals_dict in signals.items():
    for method in methods:
        signal_fname = os.path.join(output_dir, 'gpt2', train_set, f'{eval_set}_{best_lr[train_set]}_{method}.{method2ftype[method]}')
        if not os.path.exists(signal_fname):
            # List missing outputs instead of failing, so partial results stay usable.
            not_readys.append((train_set, eval_set, method))
            continue

        signal = read_model_out(signal_fname)
        if train_set == eval_set:
            # In-domain signals are restricted to the indices from roberta/<name>_indices.txt
            # (presumably the subsample RoBERTa was evaluated on — TODO confirm).
            idxs = subsample_indices[train_set]
            if isinstance(signal, np.ndarray):
                # Keep arrays as arrays so in-domain signals have the same type as the
                # OOD ones. Negation and boolean masks work on arrays but not on lists.
                signal = signal[idxs]
            else:
                signal = [signal[idx] for idx in idxs]

        signals_dict[method] = signal

for not_ready in not_readys:
    print(not_ready)

### RoBERTa

In [15]:
# Load the RoBERTa MSP signal for every (train, eval) pair. Pairs whose output
# file does not exist yet are listed instead of loaded.
methods = ['msp']
not_readys = []

model_type = 'roberta-large'

for (train_set, eval_set), signals_dict in signals.items():
    for method in methods:
        signal_fname = os.path.join(
            output_dir, 'roberta', train_set,
            f'{model_type}_{eval_set}_{method}.{method2ftype[method]}',
        )
        if os.path.exists(signal_fname):
            signals_dict[method] = read_model_out(signal_fname)
        else:
            not_readys.append((train_set, eval_set, method))

for not_ready in not_readys:
    print(not_ready)

## Get Error Indices by FAR95

In [16]:
# Display names for scores, metrics and datasets in plots/tables.
score2plot = {
    'p_x': r'GPT2: $p(x)$',
    'ppl': 'GPT2: PPL',
    'p_y': 'RoBERTa: MSP',
}

metric2plot = {
    'auroc': 'AUROC',
    'far': 'FAR95'
}

dataset2plot = {
    'imdb': 'IMDB',
    'sst2': 'SST-2',
    'snli': 'SNLI',
    'counterfactual-imdb': 'c-IMDB',
    'rte': 'RTE',
}

# For every (in-domain, OOD) pair, record which samples each detector (GPT-2
# perplexity 'pps', RoBERTa MSP 'msp') places above its FAR95 cut-off.
error_indices = {}
not_ready = []
for train_set in train_sets:
    for eval_set in eval_sets:
        if train_set == eval_set:
            continue

        ood_signal_dict = signals[(train_set, eval_set)]
        id_signal_dict = signals[(train_set, train_set)]

        # Check both sides: a missing (None) in-domain signal would otherwise reach
        # compute_far and fail there.
        if any(value is None
               for value in (*ood_signal_dict.values(), *id_signal_dict.values())):
            not_ready.append((train_set, eval_set))
            continue

        # Flagged index sets are intersected across the two detectors later. That
        # only makes sense if both scored the same examples in the same order.
        for side, side_dict in (('in-domain', id_signal_dict), ('OOD', ood_signal_dict)):
            if len(side_dict['pps']) != len(side_dict['msp']):
                print(f"WARNING {(train_set, eval_set)}: {side} pps has "
                      f"{len(side_dict['pps'])} entries but msp has {len(side_dict['msp'])}")

        # MSP is negated so that, as for perplexity, higher means "more OOD".
        pps_errors = compute_far(id_signal_dict['pps'], ood_signal_dict['pps'], return_indices=True)
        msp_errors = compute_far(-np.asarray(id_signal_dict['msp']), -np.asarray(ood_signal_dict['msp']),
                                 return_indices=True)

        error_indices[(train_set, eval_set)] = {'pps': pps_errors, 'msp': msp_errors}

In [17]:
# (in-domain, OOD) pairs for which error indices were computed.
print([*error_indices])

[('imdb', 'sst2'), ('imdb', 'snli'), ('imdb', 'counterfactual-imdb'), ('imdb', 'rte'), ('sst2', 'imdb'), ('sst2', 'snli'), ('sst2', 'counterfactual-imdb'), ('sst2', 'rte')]


# Partition Indices

In [18]:
# Split each pair's flagged indices, separately for the 'id' and 'ood' samples, into:
# flagged by either detector ('union'), by both ('common'), by perplexity only
# ('pps') and by MSP only ('msp').
indices_parts = {}
for ood_key, errors_dict in error_indices.items():
    pps_sets = {domain: set(errors_dict['pps'][domain]) for domain in ('id', 'ood')}
    msp_sets = {domain: set(errors_dict['msp'][domain]) for domain in ('id', 'ood')}

    indices_parts[ood_key] = {
        'union': {d: pps_sets[d] | msp_sets[d] for d in ('id', 'ood')},
        'common': {d: pps_sets[d] & msp_sets[d] for d in ('id', 'ood')},
        'pps': {d: pps_sets[d] - msp_sets[d] for d in ('id', 'ood')},
        'msp': {d: msp_sets[d] - pps_sets[d] for d in ('id', 'ood')},
    }

# Partition Stats

In [19]:
# Per (in-domain, OOD) pair: how the flagged examples (ID and OOD pooled) split
# into "flagged by both detectors", "MSP only" and "perplexity only".
stats = []
for (indomain, ood), partitions in indices_parts.items():
    # Partition sizes, pooling the in-domain and OOD index sets.
    counts = {
        part: len(partitions[part]['id']) + len(partitions[part]['ood'])
        for part in ('union', 'common', 'msp', 'pps')
    }
    total_count = counts['union']
    # An empty union would make every ratio a ZeroDivisionError; report NaN instead.
    denom = total_count if total_count else float('nan')

    row = {
        'in domain': indomain,
        'ood': ood,
        'total count': total_count,
        'common count': counts['common'],
        'msp only count': counts['msp'],
        'pps only count': counts['pps'],
        'common ratio': counts['common'] / denom,
        'msp only ratio': counts['msp'] / denom,
        'pps only ratio': counts['pps'] / denom,
    }

    stats.append(row)

stats_df = pd.DataFrame(stats)
print(stats_df)
error_analysis_dir = os.path.join('.', 'error_analysis')
os.makedirs(error_analysis_dir, exist_ok=True)  # DataFrame.to_csv does not create missing directories
stats_df.to_csv(os.path.join(error_analysis_dir, 'error_counts_ratios.csv'))

  in domain                  ood  total count  common count  msp only count  \
0      imdb                 sst2        19413          9482            7527   
1      imdb                 snli        28368         11665             634   
2      imdb  counterfactual-imdb        21877         17398            1169   
3      imdb                  rte        20014          4278              43   
4      sst2                 imdb        25437         22813            1428   
5      sst2                 snli        10492          9157             575   
6      sst2  counterfactual-imdb         2951          2344             421   
7      sst2                  rte          813           414             108   

   pps only count  common ratio  msp only ratio  pps only ratio  
0            2404      0.488436        0.387730        0.123835  
1           16069      0.411203        0.022349        0.566448  
2            3310      0.795264        0.053435        0.151300  
3           15693      0

# Sample Examples

In [20]:
val_data_path = os.path.join('.', 'all_val_data.p')
# pickle.load can execute arbitrary code; only open this trusted, locally produced file.
with open(val_data_path, 'rb') as f:
    datasets = pickle.load(f)

In [21]:
# Seed NumPy's global RNG so the example sampling in the next cell is reproducible.
seed = 42
np.random.seed(seed=seed)

In [22]:
# Draw up to `nsample` random examples from each partition, skipping pairs whose
# OOD set is listed in `ignore_ood`.
nsample = 10
examples = []
ignore_ood = ['sst2', 'imdb', 'counterfactual-imdb']
parts = ['common', 'msp', 'pps']
domains = ['ood']

for (indomain, ood), partitions in indices_parts.items():
    if ood in ignore_ood:
        continue

    data = {
        'id': datasets[('id', 'val', indomain)]['text'],
        'ood': datasets[('ood', 'val', ood)]['text'],
    }

    for part in parts:
        for domain in domains:
            pool = list(partitions[part][domain])
            # replace=False raises ValueError when the pool is smaller than the sample
            # size; take such small partitions in full instead.
            sample = np.random.choice(pool, size=min(nsample, len(pool)), replace=False)
            print(indomain, ood, part, domain)
            for idx in sample:
                examples.append({
                    'in domain': indomain,
                    'ood': ood,
                    'text': data[domain][idx],
                    'domain': domain,
                    'dataset': indomain if domain == 'id' else ood,
                    'partition': part,
                })

error_analysis_dir = os.path.join('.', 'error_analysis')
os.makedirs(error_analysis_dir, exist_ok=True)  # DataFrame.to_csv does not create missing directories
pd.DataFrame(examples).to_csv(os.path.join(error_analysis_dir, 'error_examples.csv'))
    
    

imdb snli common ood
imdb snli msp ood
imdb snli pps ood
imdb rte common ood
imdb rte msp ood
imdb rte pps ood
sst2 snli common ood
sst2 snli msp ood
sst2 snli pps ood
sst2 rte common ood
sst2 rte msp ood
sst2 rte pps ood


# Consistent OOD Examples

In [23]:
# Re-seed NumPy's global RNG so the OOD example sampling below is reproducible on its own.
seed = 42
np.random.seed(seed=seed)

In [24]:
# Partition breakdown restricted to the OOD samples, for the OOD sets in `tasks`,
# plus up to `nsample` random OOD examples from each partition.
ood_stats, ood_examples = [], []

tasks = ['snli', 'rte']
parts = ['common', 'msp', 'pps']
domains = ['ood']

nsample = 10

for (indomain, ood), partitions in indices_parts.items():
    if ood not in tasks:
        continue

    total_count = len(partitions['union']['ood'])
    common_count = len(partitions['common']['ood'])
    msp_only_count = len(partitions['msp']['ood'])
    pps_only_count = len(partitions['pps']['ood'])
    # An empty union would make every ratio a ZeroDivisionError; report NaN instead.
    denom = total_count if total_count else float('nan')

    ood_stats.append({
        'in domain': indomain,
        'ood': ood,
        'total count': total_count,
        'common count': common_count,
        'msp only count': msp_only_count,
        'pps only count': pps_only_count,
        'common ratio': common_count / denom,
        'msp only ratio': msp_only_count / denom,
        'pps only ratio': pps_only_count / denom,
    })

    data = {'ood': datasets[('ood', 'val', ood)]['text']}
    for part in parts:
        for domain in domains:
            pool = list(partitions[part][domain])
            # replace=False raises ValueError when the pool is smaller than the sample
            # size; take such small partitions in full instead.
            sample = np.random.choice(pool, size=min(nsample, len(pool)), replace=False)
            print(indomain, ood, part, domain)
            for idx in sample:
                ood_examples.append({
                    'in domain': indomain,
                    'ood': ood,
                    'text': data[domain][idx],
                    'domain': domain,
                    'dataset': indomain if domain == 'id' else ood,
                    'partition': part,
                })

error_analysis_dir = os.path.join('.', 'error_analysis')
os.makedirs(error_analysis_dir, exist_ok=True)  # DataFrame.to_csv does not create missing directories
pd.DataFrame(ood_stats).to_csv(os.path.join(error_analysis_dir, 'ood_error_counts_ratios.csv'))
pd.DataFrame(ood_examples).to_csv(os.path.join(error_analysis_dir, 'ood_error_examples.csv'))

imdb snli common ood
imdb snli msp ood
imdb snli pps ood
imdb rte common ood
imdb rte msp ood
imdb rte pps ood
sst2 snli common ood
sst2 snli msp ood
sst2 snli pps ood
sst2 rte common ood
sst2 rte msp ood
sst2 rte pps ood


# Check Signals

In [25]:
# Signal lengths for each sentiment dataset when it is the in-domain set (model
# trained on it) vs. the OOD set (model trained on the other sentiment dataset).
# NOTE(review): the saved output shows 1821 GPT-2 `pps` entries but 872 RoBERTa
# `msp` entries for SST-2 as OOD. The two signals seem to cover different
# examples there — confirm before comparing them per example.
for check_pos, (check_ds, check_other) in enumerate([('sst2', 'imdb'), ('imdb', 'sst2')]):
    if check_pos > 0:
        print('='*45)
    for check_method in ('msp', 'pps'):
        print(f'{check_ds} in domain', check_method, len(signals[(check_ds, check_ds)][check_method]))
        print(f'{check_ds} out domain', check_method, len(signals[(check_other, check_ds)][check_method]))

sst2 in domain msp 698
sst2 out domain msp 872
sst2 in domain pps 698
sst2 out domain pps 1821
imdb in domain msp 20000
imdb out domain msp 25000
imdb in domain pps 20000
imdb out domain pps 25000


In [26]:
# MSP signal lengths for the SNLI and RTE OOD sets under both in-domain training sets.
for check_pos, check_ood in enumerate(('snli', 'rte')):
    if check_pos > 0:
        print('='*45)
    for check_id in ('sst2', 'imdb'):
        print(check_ood, f'{check_id} in domain', 'msp', len(signals[(check_id, check_ood)]['msp']))

snli sst2 in domain msp 10000
snli imdb in domain msp 10000
rte sst2 in domain msp 277
rte imdb in domain msp 277


In [27]:
# GPT-2 perplexity signal lengths for the SNLI and RTE OOD sets under both in-domain training sets.
for check_pos, check_ood in enumerate(('snli', 'rte')):
    if check_pos > 0:
        print('='*45)
    for check_id in ('sst2', 'imdb'):
        print(check_ood, f'{check_id} in domain', 'pps', len(signals[(check_id, check_ood)]['pps']))

snli sst2 in domain pps 10000
snli imdb in domain pps 10000
rte sst2 in domain pps 277
rte imdb in domain pps 277
