In [None]:
import torch

from utils import (
    seed_everything,
    binary_ood_id_repr,
    process_partially,
    get_best_temperature,
    Loader,
)

from metrics import (
    ECELoss,
    TACELoss,
    get_auroc_auprc,
    fpr_at_k_recall,
)

from gmm import fit_gmm, evaluate_gmm

import sentencepiece as spm
from transformer import TransformerEncoder

In [None]:
seed_everything(1234)

# Run on the first CUDA device when one is available, otherwise on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# SentencePiece tokenizer; its vocabulary size and pad id parameterise every model below.
tokenizer = spm.SentencePieceProcessor('spm_wiki.model')

# In-Domain Dataset: 20News

In [None]:
# Data-loading configuration for the 20News experiments.
BATCH_SIZE = 64
MAX_LEN = 512
NUM_WORKERS = 16

loader = Loader(batch_size=BATCH_SIZE, max_len=MAX_LEN, num_workers=NUM_WORKERS)

# In-domain splits: train (GMM fitting), val (temperature search), test (evaluation).
train_loader, val_loader, id_test_loader = (
    loader.load('20news', split) for split in ('train', 'val', 'test')
)

# Out-of-domain test sets, each evaluated against the in-domain test split.
ood_names = ['snli', 'imdb', 'm30k', 'wmt16', 'yelp']
ood_loaders = [loader.load(ood_name, 'test') for ood_name in ood_names]
ood_snli_loader, ood_imdb_loader, ood_m30k_loader, ood_wmt16_loader, ood_yelp_loader = ood_loaders

### Normal Softmax Model (No GMM, No Spectral Norm)

In [None]:
params_20ng = {
    'vocab_size': len(tokenizer),
    'emb_dim': 1024,
    'n_layers': 1,
    'n_heads': 8,
    'forward_dim': 1024,
    'dropout': 0.3,
    'max_len': MAX_LEN,
    'pad_idx': tokenizer.pad_id(),
    'kind': 'sto_dual',
    'tau1': 1,
    'tau2': 40,
    'k_centroid': 3,
    'spectral': False,
    'n_classes': 20,
    'device': device
}

path = '../../../params/udl/sto_dual_20news/'
model_path = path + 'model.pt'

model = TransformerEncoder(**params_20ng).to(device)
# map_location restores the tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded when CUDA is unavailable and `device` is "cpu".
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();

In [None]:
# Calibration-error metrics, both computed with 15 bins.
ece_criterion, tace_criterion = ECELoss(num_bins=15), TACELoss(num_bins=15)

In [None]:
# Raw (unscaled) logits of the current model on the in-domain test split;
# the second return value is not needed here.
id_test_logits, _ = process_partially(
    model,
    id_test_loader,
    only_logits=True,
)  # shape: (test_set_size, num_classes)

In [None]:
# get best temperature
best_ce_temp, best_ece_temp, best_tace_temp = get_best_temperature(model, val_loader, num_tries=100, num_bins=15, log=True)
print('Best CE temperature:', best_ce_temp)
print('Best ECE temperature:', best_ece_temp)
print('Best TACE temperature:', best_tace_temp)

In [None]:
# In-domain test logits divided by each tuned temperature
# (each keeps the shape (test_set_size, num_classes)).
t_ce_test_logits, t_ece_test_logits, t_tace_test_logits = (
    id_test_logits / temperature
    for temperature in (best_ce_temp, best_ece_temp, best_tace_temp)
)

In [None]:
# For every OOD set, report OOD-detection metrics (AUROC, AUPRC, FPR90) and the
# calibration (ECE, TACE) of the binary ID-vs-OOD representation. Each metric is
# reported without temperature scaling and with each validation-tuned temperature.
#
# A temperature is a property of the model, so it must divide the OOD logits as
# well as the in-domain logits. Scaling only the in-domain side shifts ID and
# OOD scores by different amounts, which biases every metric below.
temperature_settings = [
    # (metric label, calibration label, in-domain logits already divided by T, T)
    ('no temperature', 'no temperature scaling', id_test_logits, 1.0),
    ('CE temperature', 'with CE temperature scaling', t_ce_test_logits, best_ce_temp),
    ('ECE temperature', 'with ECE temperature scaling', t_ece_test_logits, best_ece_temp),
    ('TACE temperature', 'with TACE temperature scaling', t_tace_test_logits, best_tace_temp),
]

for ood_loader, name in zip(ood_loaders, ood_names):
    print('CURRENT OOD SET:', name)
    ood_logits, _ = process_partially(model, ood_loader, only_logits=True)

    ece_results, tace_results = [], []
    for label, calib_label, id_logits, temp in temperature_settings:
        ood_scaled_logits = ood_logits / temp  # same temperature as the ID side

        # logsumexp score, plus FPR at 90% recall on the 1-D ID/OOD representation
        auroc, auprc = get_auroc_auprc(id_logits, ood_scaled_logits, measure='logsumexp', confidence=True)
        _logits, _lbls = binary_ood_id_repr(ood_scaled_logits, id_logits, return_1d=True)
        fpr90 = fpr_at_k_recall(_lbls.cpu().numpy(), _logits.cpu().numpy(), recall_level=0.9)
        print(f'{label}, logsumexp, confidence=True, AUROC: {auroc*100:.2f}%')
        print(f'{label}, logsumexp, confidence=True, AUPRC: {auprc*100:.2f}%')
        print(f'{label}, FPR90: {fpr90*100:.2f}%')

        # entropy score
        auroc, auprc = get_auroc_auprc(id_logits, ood_scaled_logits, measure='entropy')
        print(f'{label}, entropy, confidence=False, AUROC: {auroc*100:.2f}%')
        print(f'{label}, entropy, confidence=False, AUPRC: {auprc*100:.2f}%')

        # calibration of the concatenated OOD + ID representation
        cat_logits, bin_labels = binary_ood_id_repr(ood_scaled_logits, id_logits)
        ece_results.append((calib_label, ece_criterion(cat_logits, bin_labels)))
        tace_results.append((calib_label, tace_criterion(cat_logits, bin_labels)))

    for calib_label, ece in ece_results:
        print(f'ECE {calib_label}: {ece:.4f}')
    for calib_label, tace in tace_results:
        print(f'TACE {calib_label}: {tace:.4f}')
    print('------------------------------------------------')

### Normal Softmax Model (WITH Spectral Normalisation, NO GMM)

In [None]:
params_20ng_sn = {
    'vocab_size': len(tokenizer),
    'emb_dim': 1024,
    'n_layers': 1,
    'n_heads': 8,
    'forward_dim': 1024,
    'dropout': 0.4,
    'max_len': MAX_LEN,
    'pad_idx': tokenizer.pad_id(),
    'kind': 'sto_dual',
    'tau1': 1,
    'tau2': 40,
    'k_centroid': 3,
    'spectral': True,
    'n_classes': 20,
    'device': device
}

path = '../../../params/udl/sto_dual_20news_sn/'
model_path = path + 'model.pt'

model = TransformerEncoder(**params_20ng_sn).to(device)
# map_location restores the tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded when CUDA is unavailable and `device` is "cpu".
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();

In [None]:
# Calibration-error metrics, both computed with 15 bins.
ece_criterion, tace_criterion = ECELoss(num_bins=15), TACELoss(num_bins=15)

In [None]:
# Raw (unscaled) logits of the current model on the in-domain test split;
# the second return value is not needed here.
id_test_logits, _ = process_partially(
    model,
    id_test_loader,
    only_logits=True,
)  # shape: (test_set_size, num_classes)

In [None]:
# get best temperature
best_ce_temp, best_ece_temp, best_tace_temp = get_best_temperature(model, val_loader, num_tries=100, num_bins=15, log=True)
print('Best CE temperature:', best_ce_temp)
print('Best ECE temperature:', best_ece_temp)
print('Best TACE temperature:', best_tace_temp)

In [None]:
# In-domain test logits divided by each tuned temperature
# (each keeps the shape (test_set_size, num_classes)).
t_ce_test_logits, t_ece_test_logits, t_tace_test_logits = (
    id_test_logits / temperature
    for temperature in (best_ce_temp, best_ece_temp, best_tace_temp)
)

In [None]:
# For every OOD set, report OOD-detection metrics (AUROC, AUPRC, FPR90) and the
# calibration (ECE, TACE) of the binary ID-vs-OOD representation. Each metric is
# reported without temperature scaling and with each validation-tuned temperature.
#
# A temperature is a property of the model, so it must divide the OOD logits as
# well as the in-domain logits. Scaling only the in-domain side shifts ID and
# OOD scores by different amounts, which biases every metric below.
temperature_settings = [
    # (metric label, calibration label, in-domain logits already divided by T, T)
    ('no temperature', 'no temperature scaling', id_test_logits, 1.0),
    ('CE temperature', 'with CE temperature scaling', t_ce_test_logits, best_ce_temp),
    ('ECE temperature', 'with ECE temperature scaling', t_ece_test_logits, best_ece_temp),
    ('TACE temperature', 'with TACE temperature scaling', t_tace_test_logits, best_tace_temp),
]

for ood_loader, name in zip(ood_loaders, ood_names):
    print('CURRENT OOD SET:', name)
    ood_logits, _ = process_partially(model, ood_loader, only_logits=True)

    ece_results, tace_results = [], []
    for label, calib_label, id_logits, temp in temperature_settings:
        ood_scaled_logits = ood_logits / temp  # same temperature as the ID side

        # logsumexp score, plus FPR at 90% recall on the 1-D ID/OOD representation
        auroc, auprc = get_auroc_auprc(id_logits, ood_scaled_logits, measure='logsumexp', confidence=True)
        _logits, _lbls = binary_ood_id_repr(ood_scaled_logits, id_logits, return_1d=True)
        fpr90 = fpr_at_k_recall(_lbls.cpu().numpy(), _logits.cpu().numpy(), recall_level=0.9)
        print(f'{label}, logsumexp, confidence=True, AUROC: {auroc*100:.2f}%')
        print(f'{label}, logsumexp, confidence=True, AUPRC: {auprc*100:.2f}%')
        print(f'{label}, FPR90: {fpr90*100:.2f}%')

        # entropy score
        auroc, auprc = get_auroc_auprc(id_logits, ood_scaled_logits, measure='entropy')
        print(f'{label}, entropy, confidence=False, AUROC: {auroc*100:.2f}%')
        print(f'{label}, entropy, confidence=False, AUPRC: {auprc*100:.2f}%')

        # calibration of the concatenated OOD + ID representation
        cat_logits, bin_labels = binary_ood_id_repr(ood_scaled_logits, id_logits)
        ece_results.append((calib_label, ece_criterion(cat_logits, bin_labels)))
        tace_results.append((calib_label, tace_criterion(cat_logits, bin_labels)))

    for calib_label, ece in ece_results:
        print(f'ECE {calib_label}: {ece:.4f}')
    for calib_label, tace in tace_results:
        print(f'TACE {calib_label}: {tace:.4f}')
    print('------------------------------------------------')

### With Spectral Normalisation, With GMM

In [None]:
# Same spectral-normalised checkpoint as the softmax-only SN section above,
# reloaded so this GMM section can run on its own.
params_20ng_sn = {
    'vocab_size': len(tokenizer),
    'emb_dim': 1024,
    'n_layers': 1,
    'n_heads': 8,
    'forward_dim': 1024,
    'dropout': 0.4,
    'max_len': MAX_LEN,
    'pad_idx': tokenizer.pad_id(),
    'kind': 'sto_dual',
    'tau1': 1,
    'tau2': 40,
    'k_centroid': 3,
    'spectral': True,
    'n_classes': 20,
    'device': device
}

path = '../../../params/udl/sto_dual_20news_sn/'
model_path = path + 'model.pt'

model = TransformerEncoder(**params_20ng_sn).to(device)
# map_location restores the tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded when CUDA is unavailable and `device` is "cpu".
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();

In [None]:
# Calibration-error metrics, both computed with 15 bins.
ece_criterion, tace_criterion = ECELoss(num_bins=15), TACELoss(num_bins=15)

In [None]:
# Fit the GMM on training-set embeddings of the SN model, then score the
# in-domain test split with it.
train_embeds, train_lbls = process_partially(
    model, train_loader, return_embeddings=True
)
gmm = fit_gmm(
    train_embeds,
    train_lbls,
    params_20ng_sn['n_classes'],
)
id_gmm_logits, _ = evaluate_gmm(model, gmm, id_test_loader)

# Note: the GMM "logits" (per-class log-probabilities) are very large in magnitude, i.e. badly scaled.

# Possibly GMM density scoring is not well suited to text; it has only been evaluated on images.

In [None]:
for ood_loader, name in zip(ood_loaders, ood_names):
    print('CURRENT OOD SET:', name)
    ood_gmm_logits, _ = evaluate_gmm(model, gmm, ood_loader)

    # The GMM log-probabilities are scored with the same logsumexp / entropy
    # measures as softmax logits, mirroring the DDU reference implementation:
    # https://github.com/omegafragger/DDU/blob/main/evaluate.py#L201
    # No temperature is tuned for the GMM head.

    # --- logsumexp score, plus FPR90 on the 1-D ID/OOD representation ---
    auroc, auprc = get_auroc_auprc(id_gmm_logits, ood_gmm_logits, measure='logsumexp', confidence=True)
    _logits, _lbls = binary_ood_id_repr(ood_gmm_logits, id_gmm_logits, return_1d=True)
    fpr90 = fpr_at_k_recall(
        _lbls.cpu().numpy(),
        _logits.cpu().numpy(),
        recall_level=0.9,
    )
    for metric, value in (('AUROC', auroc), ('AUPRC', auprc)):
        print(f'no temperature, logsumexp, confidence=True, {metric}: {value*100:.2f}%')

    # --- entropy score ---
    auroc, auprc = get_auroc_auprc(id_gmm_logits, ood_gmm_logits, measure='entropy')
    for metric, value in (('AUROC', auroc), ('AUPRC', auprc)):
        print(f'no temperature, entropy, confidence=False, {metric}: {value*100:.2f}%')
    print(f'FPR90 (meaningless due to only log_probs available after GMM): {fpr90*100:.2f}%')

    # --- calibration of the concatenated OOD + ID representation ---
    cat_logits, bin_labels = binary_ood_id_repr(ood_gmm_logits, id_gmm_logits)
    ece = ece_criterion(cat_logits, bin_labels)
    print(f'ECE no temperature scaling: {ece:.4f}')
    tace = tace_criterion(cat_logits, bin_labels)
    print(f'TACE no temperature scaling: {tace:.4f}')
    print('------------------------------------------------')

# In-Domain Dataset: TREC

In [None]:
# Data-loading configuration for the TREC experiments.
BATCH_SIZE = 256
MAX_LEN = 512
NUM_WORKERS = 16

loader = Loader(batch_size=BATCH_SIZE, max_len=MAX_LEN, num_workers=NUM_WORKERS)

# In-domain splits: train (GMM fitting), val (temperature search), test (evaluation).
train_loader, val_loader, id_test_loader = (
    loader.load('trec', split) for split in ('train', 'val', 'test')
)

# Out-of-domain test sets, each evaluated against the in-domain test split.
ood_names = ['snli', 'imdb', 'm30k', 'wmt16', 'yelp']
ood_loaders = [loader.load(ood_name, 'test') for ood_name in ood_names]
ood_snli_loader, ood_imdb_loader, ood_m30k_loader, ood_wmt16_loader, ood_yelp_loader = ood_loaders

### NO Spectral Normalisation, NO GMM 

In [None]:
params_trec = {
    'vocab_size': len(tokenizer),
    'emb_dim': 1024,
    'n_layers': 1,
    'n_heads': 8,
    'forward_dim': 2048,
    'dropout': 0.5,
    'max_len': MAX_LEN,
    'pad_idx': tokenizer.pad_id(),
    'kind': 'sto_dual',
    'tau1': 1,
    'tau2': 20,
    'k_centroid': 4,
    'spectral': False,
    'n_classes': 50,
    'device': device
}

path = '../../../params/udl/sto_dual_trec/'
model_path = path + 'model.pt'

model = TransformerEncoder(**params_trec).to(device)
# map_location restores the tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded when CUDA is unavailable and `device` is "cpu".
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();

In [None]:
# Calibration-error metrics, both computed with 15 bins.
ece_criterion, tace_criterion = ECELoss(num_bins=15), TACELoss(num_bins=15)

In [None]:
# Raw (unscaled) logits of the current model on the in-domain test split;
# the second return value is not needed here.
id_test_logits, _ = process_partially(
    model,
    id_test_loader,
    only_logits=True,
)  # shape: (test_set_size, num_classes)

In [None]:
# get best temperature
best_ce_temp, best_ece_temp, best_tace_temp = get_best_temperature(model, val_loader, num_tries=100, num_bins=15, log=True)
print('Best CE temperature:', best_ce_temp)
print('Best ECE temperature:', best_ece_temp)
print('Best TACE temperature:', best_tace_temp)

In [None]:
# In-domain test logits divided by each tuned temperature
# (each keeps the shape (test_set_size, num_classes)).
t_ce_test_logits, t_ece_test_logits, t_tace_test_logits = (
    id_test_logits / temperature
    for temperature in (best_ce_temp, best_ece_temp, best_tace_temp)
)

In [None]:
# For every OOD set, report OOD-detection metrics (AUROC, AUPRC, FPR90) and the
# calibration (ECE, TACE) of the binary ID-vs-OOD representation. Each metric is
# reported without temperature scaling and with each validation-tuned temperature.
#
# A temperature is a property of the model, so it must divide the OOD logits as
# well as the in-domain logits. Scaling only the in-domain side shifts ID and
# OOD scores by different amounts, which biases every metric below.
temperature_settings = [
    # (metric label, calibration label, in-domain logits already divided by T, T)
    ('no temperature', 'no temperature scaling', id_test_logits, 1.0),
    ('CE temperature', 'with CE temperature scaling', t_ce_test_logits, best_ce_temp),
    ('ECE temperature', 'with ECE temperature scaling', t_ece_test_logits, best_ece_temp),
    ('TACE temperature', 'with TACE temperature scaling', t_tace_test_logits, best_tace_temp),
]

for ood_loader, name in zip(ood_loaders, ood_names):
    print('CURRENT OOD SET:', name)
    ood_logits, _ = process_partially(model, ood_loader, only_logits=True)

    ece_results, tace_results = [], []
    for label, calib_label, id_logits, temp in temperature_settings:
        ood_scaled_logits = ood_logits / temp  # same temperature as the ID side

        # logsumexp score, plus FPR at 90% recall on the 1-D ID/OOD representation
        auroc, auprc = get_auroc_auprc(id_logits, ood_scaled_logits, measure='logsumexp', confidence=True)
        _logits, _lbls = binary_ood_id_repr(ood_scaled_logits, id_logits, return_1d=True)
        fpr90 = fpr_at_k_recall(_lbls.cpu().numpy(), _logits.cpu().numpy(), recall_level=0.9)
        print(f'{label}, logsumexp, confidence=True, AUROC: {auroc*100:.2f}%')
        print(f'{label}, logsumexp, confidence=True, AUPRC: {auprc*100:.2f}%')
        print(f'{label}, FPR90: {fpr90*100:.2f}%')

        # entropy score
        auroc, auprc = get_auroc_auprc(id_logits, ood_scaled_logits, measure='entropy')
        print(f'{label}, entropy, confidence=False, AUROC: {auroc*100:.2f}%')
        print(f'{label}, entropy, confidence=False, AUPRC: {auprc*100:.2f}%')

        # calibration of the concatenated OOD + ID representation
        cat_logits, bin_labels = binary_ood_id_repr(ood_scaled_logits, id_logits)
        ece_results.append((calib_label, ece_criterion(cat_logits, bin_labels)))
        tace_results.append((calib_label, tace_criterion(cat_logits, bin_labels)))

    for calib_label, ece in ece_results:
        print(f'ECE {calib_label}: {ece:.4f}')
    for calib_label, tace in tace_results:
        print(f'TACE {calib_label}: {tace:.4f}')
    print('------------------------------------------------')

### With Spectral Normalisation, NO GMM

In [None]:
params_trec_sn = {
    'vocab_size': len(tokenizer),
    'emb_dim': 1024,
    'n_layers': 1,
    'n_heads': 8,
    'forward_dim': 2048,
    'dropout': 0.5,
    'max_len': MAX_LEN,
    'pad_idx': tokenizer.pad_id(),
    'kind': 'sto_dual',
    'tau1': 1,
    'tau2': 20,
    'k_centroid': 2,
    'spectral': True,
    'n_classes': 50,
    'device': device
}

path = '../../../params/udl/sto_dual_trec_sn/'
model_path = path + 'model.pt'

model = TransformerEncoder(**params_trec_sn).to(device)
# map_location restores the tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded when CUDA is unavailable and `device` is "cpu".
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();

In [None]:
# Calibration-error metrics, both computed with 15 bins.
ece_criterion, tace_criterion = ECELoss(num_bins=15), TACELoss(num_bins=15)

In [None]:
# Raw (unscaled) logits of the current model on the in-domain test split;
# the second return value is not needed here.
id_test_logits, _ = process_partially(
    model,
    id_test_loader,
    only_logits=True,
)  # shape: (test_set_size, num_classes)

In [None]:
# get best temperature
best_ce_temp, best_ece_temp, best_tace_temp = get_best_temperature(model, val_loader, num_tries=100, num_bins=15, log=True)
print('Best CE temperature:', best_ce_temp)
print('Best ECE temperature:', best_ece_temp)
print('Best TACE temperature:', best_tace_temp)

In [None]:
# In-domain test logits divided by each tuned temperature
# (each keeps the shape (test_set_size, num_classes)).
t_ce_test_logits, t_ece_test_logits, t_tace_test_logits = (
    id_test_logits / temperature
    for temperature in (best_ce_temp, best_ece_temp, best_tace_temp)
)

In [None]:
# For every OOD set, report OOD-detection metrics (AUROC, AUPRC, FPR90) and the
# calibration (ECE, TACE) of the binary ID-vs-OOD representation. Each metric is
# reported without temperature scaling and with each validation-tuned temperature.
#
# A temperature is a property of the model, so it must divide the OOD logits as
# well as the in-domain logits. Scaling only the in-domain side shifts ID and
# OOD scores by different amounts, which biases every metric below.
temperature_settings = [
    # (metric label, calibration label, in-domain logits already divided by T, T)
    ('no temperature', 'no temperature scaling', id_test_logits, 1.0),
    ('CE temperature', 'with CE temperature scaling', t_ce_test_logits, best_ce_temp),
    ('ECE temperature', 'with ECE temperature scaling', t_ece_test_logits, best_ece_temp),
    ('TACE temperature', 'with TACE temperature scaling', t_tace_test_logits, best_tace_temp),
]

for ood_loader, name in zip(ood_loaders, ood_names):
    print('CURRENT OOD SET:', name)
    ood_logits, _ = process_partially(model, ood_loader, only_logits=True)

    ece_results, tace_results = [], []
    for label, calib_label, id_logits, temp in temperature_settings:
        ood_scaled_logits = ood_logits / temp  # same temperature as the ID side

        # logsumexp score, plus FPR at 90% recall on the 1-D ID/OOD representation
        auroc, auprc = get_auroc_auprc(id_logits, ood_scaled_logits, measure='logsumexp', confidence=True)
        _logits, _lbls = binary_ood_id_repr(ood_scaled_logits, id_logits, return_1d=True)
        fpr90 = fpr_at_k_recall(_lbls.cpu().numpy(), _logits.cpu().numpy(), recall_level=0.9)
        print(f'{label}, logsumexp, confidence=True, AUROC: {auroc*100:.2f}%')
        print(f'{label}, logsumexp, confidence=True, AUPRC: {auprc*100:.2f}%')
        print(f'{label}, FPR90: {fpr90*100:.2f}%')

        # entropy score
        auroc, auprc = get_auroc_auprc(id_logits, ood_scaled_logits, measure='entropy')
        print(f'{label}, entropy, confidence=False, AUROC: {auroc*100:.2f}%')
        print(f'{label}, entropy, confidence=False, AUPRC: {auprc*100:.2f}%')

        # calibration of the concatenated OOD + ID representation
        cat_logits, bin_labels = binary_ood_id_repr(ood_scaled_logits, id_logits)
        ece_results.append((calib_label, ece_criterion(cat_logits, bin_labels)))
        tace_results.append((calib_label, tace_criterion(cat_logits, bin_labels)))

    for calib_label, ece in ece_results:
        print(f'ECE {calib_label}: {ece:.4f}')
    for calib_label, tace in tace_results:
        print(f'TACE {calib_label}: {tace:.4f}')
    print('------------------------------------------------')

### With Spectral Normalisation, with GMM

In [None]:
# Same spectral-normalised checkpoint as the softmax-only SN section above,
# reloaded so this GMM section can run on its own.
params_trec_sn = {
    'vocab_size': len(tokenizer),
    'emb_dim': 1024,
    'n_layers': 1,
    'n_heads': 8,
    'forward_dim': 2048,
    'dropout': 0.5,
    'max_len': MAX_LEN,
    'pad_idx': tokenizer.pad_id(),
    'kind': 'sto_dual',
    'tau1': 1,
    'tau2': 20,
    'k_centroid': 2,
    'spectral': True,
    'n_classes': 50,
    'device': device
}

path = '../../../params/udl/sto_dual_trec_sn/'
model_path = path + 'model.pt'

model = TransformerEncoder(**params_trec_sn).to(device)
# map_location restores the tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded when CUDA is unavailable and `device` is "cpu".
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();

In [None]:
# Calibration-error metrics, both computed with 15 bins.
ece_criterion, tace_criterion = ECELoss(num_bins=15), TACELoss(num_bins=15)

In [None]:
# Fit the GMM on training-set embeddings of the SN model, then score the
# in-domain test split with it.
train_embeds, train_lbls = process_partially(
    model, train_loader, return_embeddings=True
)
gmm = fit_gmm(
    train_embeds,
    train_lbls,
    params_trec_sn['n_classes'],
)
id_gmm_logits, _ = evaluate_gmm(model, gmm, id_test_loader)

# Possibly GMM density scoring is not well suited to text; it has only been evaluated on images.

In [None]:
for ood_loader, name in zip(ood_loaders, ood_names):
    print('CURRENT OOD SET:', name)
    ood_gmm_logits, _ = evaluate_gmm(model, gmm, ood_loader)

    # The GMM log-probabilities are scored with the same logsumexp / entropy
    # measures as softmax logits, mirroring the DDU reference implementation:
    # https://github.com/omegafragger/DDU/blob/main/evaluate.py#L201
    # No temperature is tuned for the GMM head.

    # --- logsumexp score, plus FPR90 on the 1-D ID/OOD representation ---
    auroc, auprc = get_auroc_auprc(id_gmm_logits, ood_gmm_logits, measure='logsumexp', confidence=True)
    _logits, _lbls = binary_ood_id_repr(ood_gmm_logits, id_gmm_logits, return_1d=True)
    fpr90 = fpr_at_k_recall(
        _lbls.cpu().numpy(),
        _logits.cpu().numpy(),
        recall_level=0.9,
    )
    for metric, value in (('AUROC', auroc), ('AUPRC', auprc)):
        print(f'no temperature, logsumexp, confidence=True, {metric}: {value*100:.2f}%')

    # --- entropy score ---
    auroc, auprc = get_auroc_auprc(id_gmm_logits, ood_gmm_logits, measure='entropy')
    for metric, value in (('AUROC', auroc), ('AUPRC', auprc)):
        print(f'no temperature, entropy, confidence=False, {metric}: {value*100:.2f}%')
    print(f'FPR90 (meaningless due to only log_probs available after GMM): {fpr90*100:.2f}%')

    # --- calibration of the concatenated OOD + ID representation ---
    cat_logits, bin_labels = binary_ood_id_repr(ood_gmm_logits, id_gmm_logits)
    ece = ece_criterion(cat_logits, bin_labels)
    print(f'ECE no temperature scaling: {ece:.4f}')
    tace = tace_criterion(cat_logits, bin_labels)
    print(f'TACE no temperature scaling: {tace:.4f}')
    print('------------------------------------------------')

# In-Domain Dataset: SST

In [None]:
# Data-loading configuration for the SST experiments.
BATCH_SIZE = 256
MAX_LEN = 512
NUM_WORKERS = 16

loader = Loader(batch_size=BATCH_SIZE, max_len=MAX_LEN, num_workers=NUM_WORKERS)

# In-domain splits: train, val (temperature search), test (evaluation).
train_loader, val_loader, id_test_loader = (
    loader.load('sst', split) for split in ('train', 'val', 'test')
)

# Out-of-domain test sets, each evaluated against the in-domain test split.
ood_names = ['snli', 'imdb', 'm30k', 'wmt16', 'yelp']
ood_loaders = [loader.load(ood_name, 'test') for ood_name in ood_names]
ood_snli_loader, ood_imdb_loader, ood_m30k_loader, ood_wmt16_loader, ood_yelp_loader = ood_loaders

### No Spectral Normalisation, No GMM (Plain Softmax Model)

In [None]:
# Hyper-parameters of the SST model trained WITHOUT spectral normalisation.
# The architecture-related entries must match the saved checkpoint, because
# load_state_dict (strict by default) rejects missing/unexpected keys and shape mismatches.
params_sst = {
    'vocab_size': len(tokenizer),
    'emb_dim': 1024,
    'n_layers': 1,
    'n_heads': 8,
    'forward_dim': 2048,
    'dropout': 0.4,
    'max_len': MAX_LEN,
    'pad_idx': tokenizer.pad_id(),
    'kind': 'sto_dual',
    'tau1': 1,
    'tau2': 20,
    'k_centroid': 2,
    'spectral': False,
    'n_classes': 2,
    'device': device
}

path = '../../../params/udl/sto_dual_sst/'
model_path = path + 'model.pt'

model = TransformerEncoder(**params_sst).to(device)
# map_location makes the checkpoint load onto `device`, including the CPU fallback
# selected at the top of the notebook; without it, tensors saved from a GPU are
# restored onto CUDA and torch.load fails on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();

In [None]:
# Calibration metrics with 15 bins each; recreated here so the SST section runs on its own.
ece_criterion, tace_criterion = ECELoss(num_bins=15), TACELoss(num_bins=15)

In [None]:
# Raw (un-scaled) logits of the in-domain test split, shape (test_set_size, num_classes).
id_test_logits, _ = process_partially(
    model,
    id_test_loader,
    only_logits=True,
)

In [None]:
# get best temperature
# Temperature search on the validation split; one temperature per objective (CE, ECE, TACE).
# NOTE(review): the search procedure (what num_tries controls) lives in
# utils.get_best_temperature — confirm there.
best_ce_temp, best_ece_temp, best_tace_temp = get_best_temperature(
    model,
    val_loader,
    num_tries=100,
    num_bins=15,
    log=True,
)
print('Best CE temperature:', best_ce_temp)
print('Best ECE temperature:', best_ece_temp)
print('Best TACE temperature:', best_tace_temp)

In [None]:
# Temperature-scaled copies of the in-domain test logits, one per calibration objective.
# Each keeps the (test_set_size, num_classes) shape of id_test_logits.
t_ce_test_logits, t_ece_test_logits, t_tace_test_logits = (
    id_test_logits / best_ce_temp,
    id_test_logits / best_ece_temp,
    id_test_logits / best_tace_temp,
)

In [None]:
# Score every OOD set against the in-domain test split, with and without temperature scaling.
#
# Temperature scaling is a post-hoc change to the *model*. At test time it divides the
# logits of every input, because we cannot know beforehand whether an input is OOD.
# The OOD logits must therefore be divided by the same temperature as the ID logits.
# Scaling only the ID side shifts the logsumexp/entropy scores of one population only,
# which biases AUROC/AUPRC/FPR90 and the OOD-vs-ID ECE/TACE numbers.
#
# Each entry: (prefix for detection metrics, suffix for calibration metrics,
#              ID test logits at that temperature, temperature to apply to the OOD logits).
temperature_settings = [
    ('no temperature', 'no temperature scaling', id_test_logits, 1.0),
    ('CE temperature', 'with CE temperature scaling', t_ce_test_logits, best_ce_temp),
    ('ECE temperature', 'with ECE temperature scaling', t_ece_test_logits, best_ece_temp),
    ('TACE temperature', 'with TACE temperature scaling', t_tace_test_logits, best_tace_temp),
]

for ood_loader, name in zip(ood_loaders, ood_names):
    print('CURRENT OOD SET:', name)
    ood_logits, _ = process_partially(model, ood_loader, only_logits=True)

    # Calibration results are printed after all detection metrics, grouped by metric.
    ece_results, tace_results = [], []

    for setting, calib_setting, id_logits, temperature in temperature_settings:
        ood_logits_t = ood_logits / temperature  # same temperature as the ID side

        # OOD detection, logsumexp score
        auroc, auprc = get_auroc_auprc(id_logits, ood_logits_t, measure='logsumexp', confidence=True)
        _logits, _lbls = binary_ood_id_repr(ood_logits_t, id_logits, return_1d=True)
        fpr90 = fpr_at_k_recall(_lbls.cpu().numpy(), _logits.cpu().numpy(), recall_level=0.9)
        print(f'{setting}, logsumexp, confidence=True, AUROC: {auroc*100:.2f}%')
        print(f'{setting}, logsumexp, confidence=True, AUPRC: {auprc*100:.2f}%')

        # OOD detection, entropy score
        auroc, auprc = get_auroc_auprc(id_logits, ood_logits_t, measure='entropy')
        print(f'{setting}, entropy, confidence=False, AUROC: {auroc*100:.2f}%')
        print(f'{setting}, entropy, confidence=False, AUPRC: {auprc*100:.2f}%')
        print(f'{setting}, FPR90: {fpr90*100:.2f}%')

        # ECE / TACE on the concatenated OOD + ID set
        cat_logits, bin_labels = binary_ood_id_repr(ood_logits_t, id_logits)
        ece_results.append((calib_setting, ece_criterion(cat_logits, bin_labels)))
        tace_results.append((calib_setting, tace_criterion(cat_logits, bin_labels)))

    for calib_setting, ece in ece_results:
        print(f'ECE {calib_setting}: {ece:.4f}')
    for calib_setting, tace in tace_results:
        print(f'TACE {calib_setting}: {tace:.4f}')
    print('------------------------------------------------')

### With Spectral Normalisation, No GMM

In [None]:
# Hyper-parameters of the SST model trained WITH spectral normalisation ('spectral': True).
# The architecture-related entries must match the saved checkpoint, because
# load_state_dict (strict by default) rejects missing/unexpected keys and shape mismatches.
params_sst_sn = {
    'vocab_size': len(tokenizer),
    'emb_dim': 1024,
    'n_layers': 1,
    'n_heads': 8,
    'forward_dim': 2048,
    'dropout': 0.3,
    'max_len': MAX_LEN,
    'pad_idx': tokenizer.pad_id(),
    'kind': 'sto_dual',
    'tau1': 1,
    'tau2': 20,
    'k_centroid': 3,
    'spectral': True,
    'n_classes': 2,
    'device': device
}

path = '../../../params/udl/sto_dual_sst_sn/'
model_path = path + 'model.pt'

model = TransformerEncoder(**params_sst_sn).to(device)
# map_location makes the checkpoint load onto `device`, including the CPU fallback
# selected at the top of the notebook; without it, tensors saved from a GPU are
# restored onto CUDA and torch.load fails on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();

In [None]:
# Fresh 15-bin calibration metrics for the spectral-normalised model section.
ece_criterion, tace_criterion = ECELoss(num_bins=15), TACELoss(num_bins=15)

In [None]:
# Raw (un-scaled) in-domain test logits of the SN model, shape (test_set_size, num_classes).
id_test_logits, _ = process_partially(
    model,
    id_test_loader,
    only_logits=True,
)

In [None]:
# get best temperature
# Re-run the validation temperature search for the SN model (CE, ECE and TACE objectives).
# NOTE(review): search details (what num_tries controls) live in utils.get_best_temperature.
best_ce_temp, best_ece_temp, best_tace_temp = get_best_temperature(
    model,
    val_loader,
    num_tries=100,
    num_bins=15,
    log=True,
)
print('Best CE temperature:', best_ce_temp)
print('Best ECE temperature:', best_ece_temp)
print('Best TACE temperature:', best_tace_temp)

In [None]:
# SN-model in-domain test logits divided by each calibrated temperature;
# shapes stay (test_set_size, num_classes).
t_ce_test_logits, t_ece_test_logits, t_tace_test_logits = (
    id_test_logits / best_ce_temp,
    id_test_logits / best_ece_temp,
    id_test_logits / best_tace_temp,
)

In [None]:
# Score every OOD set against the in-domain test split of the SN model,
# with and without temperature scaling.
#
# Temperature scaling is a post-hoc change to the *model*. At test time it divides the
# logits of every input, because we cannot know beforehand whether an input is OOD.
# The OOD logits must therefore be divided by the same temperature as the ID logits.
# Scaling only the ID side shifts the logsumexp/entropy scores of one population only,
# which biases AUROC/AUPRC/FPR90 and the OOD-vs-ID ECE/TACE numbers.
#
# Each entry: (prefix for detection metrics, suffix for calibration metrics,
#              ID test logits at that temperature, temperature to apply to the OOD logits).
temperature_settings = [
    ('no temperature', 'no temperature scaling', id_test_logits, 1.0),
    ('CE temperature', 'with CE temperature scaling', t_ce_test_logits, best_ce_temp),
    ('ECE temperature', 'with ECE temperature scaling', t_ece_test_logits, best_ece_temp),
    ('TACE temperature', 'with TACE temperature scaling', t_tace_test_logits, best_tace_temp),
]

for ood_loader, name in zip(ood_loaders, ood_names):
    print('CURRENT OOD SET:', name)
    ood_logits, _ = process_partially(model, ood_loader, only_logits=True)

    # Calibration results are printed after all detection metrics, grouped by metric.
    ece_results, tace_results = [], []

    for setting, calib_setting, id_logits, temperature in temperature_settings:
        ood_logits_t = ood_logits / temperature  # same temperature as the ID side

        # OOD detection, logsumexp score
        auroc, auprc = get_auroc_auprc(id_logits, ood_logits_t, measure='logsumexp', confidence=True)
        _logits, _lbls = binary_ood_id_repr(ood_logits_t, id_logits, return_1d=True)
        fpr90 = fpr_at_k_recall(_lbls.cpu().numpy(), _logits.cpu().numpy(), recall_level=0.9)
        print(f'{setting}, logsumexp, confidence=True, AUROC: {auroc*100:.2f}%')
        print(f'{setting}, logsumexp, confidence=True, AUPRC: {auprc*100:.2f}%')

        # OOD detection, entropy score
        auroc, auprc = get_auroc_auprc(id_logits, ood_logits_t, measure='entropy')
        print(f'{setting}, entropy, confidence=False, AUROC: {auroc*100:.2f}%')
        print(f'{setting}, entropy, confidence=False, AUPRC: {auprc*100:.2f}%')
        print(f'{setting}, FPR90: {fpr90*100:.2f}%')

        # ECE / TACE on the concatenated OOD + ID set
        cat_logits, bin_labels = binary_ood_id_repr(ood_logits_t, id_logits)
        ece_results.append((calib_setting, ece_criterion(cat_logits, bin_labels)))
        tace_results.append((calib_setting, tace_criterion(cat_logits, bin_labels)))

    for calib_setting, ece in ece_results:
        print(f'ECE {calib_setting}: {ece:.4f}')
    for calib_setting, tace in tace_results:
        print(f'TACE {calib_setting}: {tace:.4f}')
    print('------------------------------------------------')

### With Spectral Normalisation, with GMM

In [None]:
# Reload the spectral-normalised SST model (same checkpoint as the SN/no-GMM section)
# so the GMM section can be run on its own. The architecture-related entries must
# match the saved checkpoint, because load_state_dict (strict by default) rejects
# missing/unexpected keys and shape mismatches.
params_sst_sn = {
    'vocab_size': len(tokenizer),
    'emb_dim': 1024,
    'n_layers': 1,
    'n_heads': 8,
    'forward_dim': 2048,
    'dropout': 0.3,
    'max_len': MAX_LEN,
    'pad_idx': tokenizer.pad_id(),
    'kind': 'sto_dual',
    'tau1': 1,
    'tau2': 20,
    'k_centroid': 3,
    'spectral': True,
    'n_classes': 2,
    'device': device
}

path = '../../../params/udl/sto_dual_sst_sn/'
model_path = path + 'model.pt'

model = TransformerEncoder(**params_sst_sn).to(device)
# map_location makes the checkpoint load onto `device`, including the CPU fallback
# selected at the top of the notebook; without it, tensors saved from a GPU are
# restored onto CUDA and torch.load fails on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval();

In [None]:
# 15-bin calibration metrics for the GMM section.
ece_criterion, tace_criterion = ECELoss(num_bins=15), TACELoss(num_bins=15)

In [None]:
# DDU-style density estimation: fit a GMM on the SN model's training-set embeddings,
# then score the in-domain test split with it.
# NOTE(review): the GMM structure is defined in gmm.fit_gmm; given the labels and
# n_classes passed in, presumably one Gaussian per class — TODO confirm.
train_embeds, train_lbls = process_partially(
    model, train_loader, return_embeddings=True
)
gmm = fit_gmm(
    train_embeds,
    train_lbls,
    params_sst_sn['n_classes'],
)
id_gmm_logits, _ = evaluate_gmm(model, gmm, id_test_loader)

# TODO: the GMM approach has only been evaluated on images; it is an open question
# whether it transfers to text embeddings.

In [None]:
for ood_loader, name in zip(ood_loaders, ood_names):
    print('CURRENT OOD SET:', name)
    ood_gmm_logits, _ = evaluate_gmm(model, gmm, ood_loader)

    # The GMM outputs are fed to the metrics in place of logits, as in the DDU reference code:
    # https://github.com/omegafragger/DDU/blob/main/evaluate.py#L201
    # No temperature scaling is applied in this section.

    # Detection with the logsumexp score.
    auroc, auprc = get_auroc_auprc(
        id_gmm_logits, ood_gmm_logits,
        measure='logsumexp', confidence=True,
    )
    _logits, _lbls = binary_ood_id_repr(
        ood_gmm_logits, id_gmm_logits, return_1d=True
    )
    fpr90 = fpr_at_k_recall(
        _lbls.cpu().numpy(), _logits.cpu().numpy(), recall_level=0.9
    )
    print(f'no temperature, logsumexp, confidence=True, AUROC: {auroc*100:.2f}%')
    print(f'no temperature, logsumexp, confidence=True, AUPRC: {auprc*100:.2f}%')

    # Detection with the entropy score.
    auroc, auprc = get_auroc_auprc(
        id_gmm_logits, ood_gmm_logits, measure='entropy'
    )
    print(f'no temperature, entropy, confidence=False, AUROC: {auroc*100:.2f}%')
    print(f'no temperature, entropy, confidence=False, AUPRC: {auprc*100:.2f}%')
    print(f'FPR90 (meaningless due to only log_probs available after GMM): {fpr90*100:.2f}%')

    # Calibration of the concatenated OOD + ID set.
    cat_logits, bin_labels = binary_ood_id_repr(
        ood_gmm_logits, id_gmm_logits
    )

    ece = ece_criterion(cat_logits, bin_labels)
    print(f'ECE no temperature scaling: {ece:.4f}')

    tace = tace_criterion(cat_logits, bin_labels)
    print(f'TACE no temperature scaling: {tace:.4f}')
    print('------------------------------------------------')