# Calibration using Temperature Scaling

In [40]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from tqdm import tqdm

from probly.calibration import Temperature
from probly.evaluation.metrics import expected_calibration_error

device = torch.device("cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")

Using device: mps


## Load data

In [41]:
transforms = T.Compose([T.ToTensor()])
train = torchvision.datasets.CIFAR10(root="~/datasets", train=True, download=True, transform=transforms)
train, cal = torch.utils.data.random_split(train, [0.8, 0.2])
test = torchvision.datasets.CIFAR10(root="~/datasets", train=False, download=True, transform=transforms)
train_loader = DataLoader(train, batch_size=256, shuffle=True)
cal_loader = DataLoader(cal, batch_size=256, shuffle=True)
test_loader = DataLoader(test, batch_size=256, shuffle=False)

## Load neural network

In [42]:
net = torchvision.models.resnet18(pretrained=True)
net.fc = nn.Linear(512, 10, device=device)
net = net.to(device)

## Train neural network

In [43]:
epochs = 5
optimizer = optim.Adam(net.parameters())
criterion = nn.CrossEntropyLoss()
for epoch in tqdm(range(epochs)):
    net.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = net(inputs.to(device))
        loss = criterion(outputs, targets.to(device))
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}, Running loss: {running_loss / len(train_loader)}")

# compute accuracy and expected calibration error on test set
net.eval()
with torch.no_grad():
    outputs = torch.empty(0, device=device)
    targets = torch.empty(0, device=device)
    for inpt, target in tqdm(test_loader):
        outputs = torch.cat((outputs, net(inpt.to(device))), dim=0)
        targets = torch.cat((targets, target.to(device)), dim=0)
outputs = F.softmax(outputs, dim=1)
correct = torch.sum(torch.argmax(outputs, dim=1) == targets).item()
total = targets.size(0)
ece = expected_calibration_error(outputs.cpu().numpy(), targets.cpu().numpy(), num_bins=10)
print(f"Accuracy: {correct / total}")
print(f"Expected Calibration Error: {ece}")

 20%|██        | 1/5 [00:15<01:00, 15.04s/it]

Epoch 1, Running loss: 0.9242285938019965


 40%|████      | 2/5 [00:30<00:45, 15.03s/it]

Epoch 2, Running loss: 0.562462279561219


 60%|██████    | 3/5 [00:45<00:30, 15.05s/it]

Epoch 3, Running loss: 0.4154979696699009


 80%|████████  | 4/5 [01:00<00:15, 15.03s/it]

Epoch 4, Running loss: 0.3231563491236632


100%|██████████| 5/5 [01:15<00:00, 15.03s/it]


Epoch 5, Running loss: 0.25008974437880666


100%|██████████| 40/40 [00:01<00:00, 26.32it/s]

Accuracy: 0.7667
Expected Calibration Error: 0.11227427605837584





## Use the temperature scaling class and fit temperature using the calibration set

In [44]:
model = Temperature(net).to(device)
model.train()
model.fit(cal_loader, learning_rate=0.01, max_iter=100)

100%|██████████| 40/40 [00:01<00:00, 24.88it/s]


In [45]:
# compute accuracy and expected calibration error on test set after temperature scaling
model.eval()
with torch.no_grad():
    outputs = torch.empty(0, device=device)
    targets = torch.empty(0, device=device)
    for inpt, target in tqdm(test_loader):
        outputs = torch.cat((outputs, model.predict_pointwise(inpt.to(device))), dim=0)
        targets = torch.cat((targets, target.to(device)), dim=0)
correct = torch.sum(torch.argmax(outputs, dim=1) == targets).item()
total = targets.size(0)
ece = expected_calibration_error(outputs.cpu().numpy(), targets.cpu().numpy(), num_bins=10)
print(f"Softmax temperature: {model.temperature.item()}")
print(f"Accuracy: {correct / total}")
print(f"Expected Calibration Error: {ece}")

100%|██████████| 40/40 [00:01<00:00, 25.39it/s]

Softmax temperature: 1.1454050540924072
Accuracy: 0.8066
Expected Calibration Error: 0.06147321143448353



