# In [None]:
import torch 

import numpy as np

from model_helper import *
from data_helper import *
from ATC_helper import *
from predict_acc_helper import *
from calibration import *

# Experiment configuration.
# data_dir: root directory of the datasets (see README.md for how to set it up).
data_dir = "path/to/data/directory"
data = "CIFAR"
net_type = "ResNet18"
batch_size = 200

# model_checkpoint: saved weights of a trained model (see README.md for how to train one).
model_checkpoint = "path/to/model/checkpoint"

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Load Dataset
# get_data returns four values; only the last one (testloaders) is kept.
_, _, _, testloaders = get_data(data_dir, data, batch_size, net_type)

## Build the network and restore the trained weights
net = get_net(net_type, data)
net = net.to(device)

# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads when `device` is "cpu" (no CUDA available). Without it,
# torch.load tries to restore CUDA tensors and fails on CPU-only hosts.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(model_checkpoint, map_location=device)
net.load_state_dict(state_dict)

# Inference only: switch dropout / batch-norm layers to evaluation behaviour.
# NOTE(review): save_probs may already call eval(); calling it here too is harmless.
net.eval()

## Get ID validation data probs and labels
val_probs, val_labels = save_probs(net, testloaders[0], device)

# Optional calibration: the calibrator is fitted on the validation split only
# and then applied to both splits. Set to False to use the raw model probs.
use_calibration = True
if use_calibration:
    calibrator = TempScaling()
    calibrator.fit(val_probs, val_labels)
    val_probs = calibrator.calibrate(val_probs)

## Get ID test data probs; the test labels are kept to measure the true accuracy
test_probs, test_labels = save_probs(net, testloaders[1], device)
if use_calibration:
    # Reuse the calibrator fitted on validation data; never fit on test data.
    test_probs = calibrator.calibrate(test_probs)

## score function, e.g., negative entropy or argmax confidence
# NOTE(review): confirm in ATC_helper whether get_entropy returns entropy or
# negative entropy; the threshold below treats these values as confidence scores.
val_scores = get_entropy(val_probs)
test_scores = get_entropy(test_probs)

# Predicted classes. Both splits use the same (last) axis as the class axis.
val_preds = np.argmax(val_probs, axis=-1)
test_preds = np.argmax(test_probs, axis=-1)

# Fit the ATC threshold on validation scores, with "prediction was correct"
# as the per-example target, then apply it to the test scores.
# NOTE(review): assumes find_ATC_threshold returns the threshold value itself
# (not e.g. a tuple) -- confirm against ATC_helper before trusting the estimate.
ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds)
ATC_accuracy = get_ATC_acc(ATC_thres, test_scores)

# Ground-truth test accuracy, in percent, for comparison with the ATC estimate.
true_accuracy = 100 * np.mean(test_preds == test_labels)

print(f"True Accuracy {true_accuracy}")
print(f"ATC predicted accuracy {ATC_accuracy}")


