In [1]:
# import libraries
from dataloader import get_dataloader_vae
from dataloader import get_dataloader_OOD
from models import get_trained_model
from energy import ELOOD

# Get model, dataloaders, and OOD detectors

In [2]:
# Trained CIFAR10 model from get_trained_model; both detectors below wrap this same object.
# NOTE(review): the "CIFAR 100" results come from ELOOD(model) trained on CIFAR100
# data but built on this CIFAR10 model -- confirm that is intended. Also confirm that
# train_ood does not modify `model` in place; if it did, re-running the CIFAR 10 cells
# after the CIFAR 100 training cell would give different results.
model = get_trained_model('CIFAR10')

# In-distribution training loaders; the second value returned is not used in this notebook.
train_dl_c10, _ = get_dataloader_vae('CIFAR10')
train_dl_c100, _ = get_dataloader_vae('CIFAR100')

# OOD evaluation dataloaders, created in the same order as before: SVHN, then LSUN.
svhn_dl, lsun_dl = (get_dataloader_OOD(name) for name in ('SVHN', 'LSUN'))

# One ELOOD detector per in-distribution training set (CIFAR10, CIFAR100).
c10_ood, c100_ood = ELOOD(model), ELOOD(model)

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: data\test_32x32.mat


# Train OOD detector for CIFAR 10

In [3]:
# Train the CIFAR 10 detector on the CIFAR10 training loader from get_dataloader_vae.
c10_ood.train_ood(train_dl_c10)

# Get CIFAR 10 detector predictions for SVHN and LSUN

In [4]:
# Score both OOD sets with the CIFAR 10 detector (SVHN first, then LSUN).
# The results are predict_ood outputs (scores), not detectors, despite the `_ood` suffix.
svhn_c10_ood, lsun_c10_ood = (c10_ood.predict_ood(loader) for loader in (svhn_dl, lsun_dl))

# Train OOD detector for CIFAR 100

In [5]:
# Train the CIFAR 100 detector on the CIFAR100 training loader from get_dataloader_vae.
# NOTE(review): c100_ood wraps the same `model` object as c10_ood -- confirm that
# train_ood does not update that shared model in place.
c100_ood.train_ood(train_dl_c100)

# Get CIFAR 100 detector predictions for SVHN and LSUN

In [6]:
# Score both OOD sets with the CIFAR 100 detector (SVHN first, then LSUN).
# The results are predict_ood outputs (scores), not detectors, despite the `_ood` suffix.
svhn_c100_ood, lsun_c100_ood = (c100_ood.predict_ood(loader) for loader in (svhn_dl, lsun_dl))

# Compute FPR@TPR95, AUPR and AUROC for all predictions

In [7]:
def report_ood_metrics(detector, ood_scores, ood_name):
    """Print FPR@TPR95, AUPR and AUROC for one detector / OOD-dataset pair.

    Parameters
    ----------
    detector : ELOOD
        Trained detector whose ``get_metrics`` turns predictions into metrics.
    ood_scores
        The value returned by ``detector.predict_ood`` on the OOD dataset.
    ood_name : str
        Display name of the OOD dataset (used only in the printed header).

    Returns
    -------
    tuple
        ``(false_positive_rate_95, aupr, auroc)`` exactly as returned by ``get_metrics``.
    """
    print(f"For {ood_name} =>")
    false_positive_rate_95, aupr, auroc = detector.get_metrics(ood_scores)
    print("FPR@TPR95 = ", false_positive_rate_95, " AUPR = ", aupr, " AUROC = ", auroc)
    return false_positive_rate_95, aupr, auroc


# (training-set label, detector, {OOD name: predictions}) in the original report order.
experiments = [
    ("CIFAR 10", c10_ood, {"SVHN": svhn_c10_ood, "LSUN": lsun_c10_ood}),
    ("CIFAR 100", c100_ood, {"SVHN": svhn_c100_ood, "LSUN": lsun_c100_ood}),
]

# Metrics keyed by (training-set label, OOD name), kept for later inspection instead of
# being overwritten in a single set of loose variables.
ood_metrics = {}
for idx, (train_name, detector, scores_by_ood) in enumerate(experiments):
    if idx:
        print()  # blank line between training-set groups, none after the last
    print(f"For VAE trained on {train_name} =>")
    for ood_name, ood_scores in scores_by_ood.items():
        ood_metrics[(train_name, ood_name)] = report_ood_metrics(detector, ood_scores, ood_name)

For VAE trained on CIFAR 10 =>
For SVHN =>
FPR@TPR95 =  1.0  AUPR =  0.7306610288106328  AUROC =  0.6155885064535955
For LSUN =>
FPR@TPR95 =  0.5016  AUPR =  0.8320384589726484  AUROC =  0.7479

For VAE trained on CIFAR 100 =>
For SVHN =>
FPR@TPR95 =  1.0  AUPR =  0.7401213753785955  AUROC =  0.6271896127842655
For LSUN =>
FPR@TPR95 =  0.5018  AUPR =  0.8286607382550335  AUROC =  0.7432
