In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
from scipy.optimize import linprog
from tqdm import tqdm

In [2]:
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=transform
)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

In [3]:
class LeNet5(nn.Module):
    def __init__(self, num_classes=10):
        super(LeNet5, self).__init__()

        # Original LeNet-5 uses tanh; we keep ReLU for better training stability.
        # If you need the exact classical version, I can give that too.

        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0)   # 28 → 24
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)  # 12 → 8
    
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        # Feature extractor
        x = F.relu(self.conv1(x))   # (batch, 6, 24, 24)
        x = F.max_pool2d(x, 2)      # → (batch, 6, 12, 12)
        
        x = F.relu(self.conv2(x))   # (batch, 16, 8, 8)
        x = F.max_pool2d(x, 2)      # → (batch, 16, 4, 4)

        # Flatten
        x = x.view(x.size(0), -1)

        # Classifier
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        logits = self.fc3(x)        # Raw logits → perfect for temp scaling
        return logits


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = LeNet5(num_classes=10).to(device)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

Using device: cpu


In [5]:
num_epochs = 10  # you can change this
best_acc = 0.0
best_model_path = "best_model_fashion_pth"

for epoch in range(num_epochs):
    model.train()  # set model to training mode
    running_loss = 0.0
    correct = 0
    total = 0


    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()          # reset gradients
        logits = model(images)         # forward pass
        loss = criterion(logits, labels)
        loss.backward()                # backward pass
        optimizer.step()               # update weights

        running_loss += loss.item() * images.size(0)
        _, predicted = logits.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = correct / total
    
    print(f"Epoch [{epoch+1}/{num_epochs}]  Loss: {epoch_loss:.4f}  Acc: {epoch_acc:.4f}")
    if epoch_acc > best_acc: 
        best_acc = epoch_acc
        torch.save(model.state_dict(), best_model_path)

Epoch [1/10]  Loss: 0.7348  Acc: 0.7150
Epoch [2/10]  Loss: 0.4787  Acc: 0.8206
Epoch [3/10]  Loss: 0.4008  Acc: 0.8522
Epoch [4/10]  Loss: 0.3594  Acc: 0.8677
Epoch [5/10]  Loss: 0.3320  Acc: 0.8781
Epoch [6/10]  Loss: 0.3100  Acc: 0.8871
Epoch [7/10]  Loss: 0.2931  Acc: 0.8932
Epoch [8/10]  Loss: 0.2804  Acc: 0.8963
Epoch [9/10]  Loss: 0.2685  Acc: 0.9009
Epoch [10/10]  Loss: 0.2564  Acc: 0.9044


In [6]:
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.to(device)
model.eval()

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [51]:
print(test_loader.dataset.classes)
for image, labels  in test_loader:
    print(labels)
    
    break

['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8, 3, 3, 8, 0, 7, 5, 7, 9, 6, 1, 3, 7, 6, 7, 2, 1,
        2, 2, 4, 4, 5, 8, 2, 2, 8, 4, 8, 0, 7, 7, 8, 5])


In [52]:
import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()  # evaluation mode

for images, labels in test_loader:
    images = images.to(device)

    with torch.no_grad():  # no gradients needed
        logits = model(images)            # raw outputs from the model
        probs = F.softmax(logits, dim=1)  # convert logits to probabilities

    val_probs = probs
    print("Probabilities for first batch:\n", probs)
    break  # only first batch

Probabilities for first batch:
 tensor([[1.1039e-06, 3.6344e-09, 2.6412e-08, 1.6688e-08, 4.6046e-11, 6.2828e-04,
         2.1235e-10, 1.5064e-03, 1.1998e-08, 9.9786e-01],
        [8.8461e-06, 4.6992e-09, 9.9929e-01, 1.5312e-07, 2.8559e-04, 5.5941e-12,
         4.1592e-04, 2.9225e-13, 8.4764e-07, 2.1216e-10],
        [5.1711e-08, 1.0000e+00, 8.1623e-09, 1.6715e-07, 7.6110e-09, 5.6345e-12,
         5.2706e-07, 3.1448e-15, 1.2978e-07, 2.4247e-16],
        [1.0091e-07, 9.9999e-01, 1.1655e-08, 2.0174e-06, 8.5603e-07, 1.0720e-09,
         5.9012e-06, 5.4706e-12, 1.6368e-07, 5.1195e-13],
        [3.9880e-02, 3.8795e-07, 6.7004e-03, 7.1930e-04, 7.9000e-04, 2.6957e-07,
         9.5181e-01, 3.8908e-09, 9.7684e-05, 3.1690e-06],
        [3.3558e-07, 1.0000e+00, 1.6691e-07, 3.6650e-08, 6.1572e-08, 4.0042e-12,
         2.4279e-07, 2.4434e-15, 1.3446e-07, 1.4237e-15],
        [4.4643e-05, 9.8564e-07, 5.8992e-02, 1.4096e-05, 8.4181e-01, 1.8687e-07,
         9.9108e-02, 4.6034e-09, 2.9046e-05, 7.9662e-

In [64]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# iterate over test_loader (or take one batch)
for images, labels in test_loader:
    images, labels = images.to(device), labels.to(device)  # move to device
    
    with torch.no_grad():  # no need to compute gradients for inference
        outputs = model(images)              # forward pass
        _, predicted = torch.max(outputs, 1)  # get predicted class index

    print("Predicted labels (indices):", predicted)
    print("True labels (indices)     :", labels)
    break

Predicted labels (indices): tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 5, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 5,
        1, 4, 6, 0, 9, 4, 8, 8, 3, 3, 8, 0, 7, 5, 7, 9, 0, 1, 6, 9, 6, 7, 2, 1,
        2, 6, 4, 2, 5, 8, 2, 2, 8, 4, 8, 0, 7, 7, 8, 5])
True labels (indices)     : tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8, 3, 3, 8, 0, 7, 5, 7, 9, 6, 1, 3, 7, 6, 7, 2, 1,
        2, 2, 4, 4, 5, 8, 2, 2, 8, 4, 8, 0, 7, 7, 8, 5])


In [65]:
save_test = []
for images, labels in test_loader: 
    save_test.append((images, labels))

In [55]:
to_test_images = save_test[2][0]
to_test_labels = save_test[2][1]
with torch.no_grad():  # no gradients needed
    logits = model(to_test_images)            # raw outputs from the model
    neu_probs = F.softmax(logits, dim=1)  

In [56]:
neu_probs

tensor([[2.7172e-06, 9.9997e-01, 9.6320e-08, 1.0118e-05, 3.9442e-07, 6.6034e-10,
         2.1184e-05, 3.4099e-12, 2.7518e-07, 7.0176e-13],
        [6.8104e-06, 8.3936e-08, 1.1518e-07, 4.3987e-09, 5.4663e-10, 9.9975e-01,
         1.2813e-07, 2.3185e-04, 2.2901e-06, 5.0626e-06],
        [1.4861e-04, 3.2108e-05, 3.6833e-02, 3.7079e-05, 9.6044e-01, 3.5373e-07,
         2.1980e-03, 3.7740e-08, 3.0238e-04, 6.1710e-06],
        [3.9315e-08, 9.9999e-01, 1.9230e-08, 3.3926e-07, 1.3222e-06, 3.5385e-10,
         2.9974e-06, 5.8609e-12, 4.8709e-07, 1.0130e-13],
        [1.2231e-06, 5.6793e-09, 6.7388e-09, 2.0214e-08, 1.0369e-11, 2.0731e-03,
         3.8526e-10, 8.5274e-04, 8.9060e-10, 9.9707e-01],
        [1.3435e-05, 9.9996e-01, 4.2168e-07, 7.1040e-06, 1.0235e-06, 1.4769e-09,
         1.9753e-05, 1.0577e-12, 1.6836e-06, 8.6818e-13],
        [4.6155e-04, 3.7513e-05, 8.8781e-05, 2.9047e-06, 2.4899e-05, 1.6834e-06,
         3.7662e-04, 2.4130e-06, 9.9897e-01, 2.8674e-05],
        [1.3226e-04, 2.5597

In [57]:
"""Implementation of temperature scaling in torch."""

# Suppose val_probs, val_labels

'Implementation of temperature scaling in torch.'

In [76]:
#val_logits = torch.log(val_probs + 1e-12)  # pseudo logits
val_logits = logits
val_labels = labels


class TempScaler(nn.Module):
    def __init__(self):
        super().__init__()
        self.temperature = nn.Parameter(torch.ones(1))
        
    def forward(self, logits):
        return logits / self.temperature

def calibrate_temperature_grid(logits, labels, temp_min=0.5, temp_max=5.0, num_steps=100):
    """
    Find optimal temperature T using grid search
    """
    scaler = TempScaler()  # use original module
    best_temp = 1.0
    best_loss = float('inf')
    labels = labels.long()
    
    temperatures = torch.linspace(temp_min, temp_max, num_steps)
    
    for T in temperatures:
        # improve the scaler's temperature
        scaler.temperature.data.fill_(T)
        scaled_logits = scaler.forward(logits)
        loss = F.cross_entropy(scaled_logits, labels)
        if loss < best_loss:
            best_loss = loss
            best_temp = T.item()
    
    # set final optimal temperature
    scaler.temperature.data.fill_(best_temp)
    print(best_temp)
    return scaler


def calibrate_temperature(logits, labels, max_iter=50, lr=0.01):
    scaler = TempScaler()
    optimizer = torch.optim.LBFGS([scaler.temperature], lr=lr, max_iter=max_iter)

    def closure():
        optimizer.zero_grad()
        scaled_logits = scaler(logits)
        loss = nn.functional.cross_entropy(scaled_logits, labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return scaler


def apply_temperature_scaling_probs(probs, scaler):
    pseudo_logits = torch.log(probs + 1e-12)       # convert probabilities to pseudo logits
    scaled_logits = scaler(pseudo_logits)          # divide by learned temperature
    calibrated_probs = torch.softmax(scaled_logits, dim=1)
    return calibrated_probs

In [78]:
test_scaler = calibrate_temperature_grid(val_logits, val_labels)
calibrated_probs = apply_temperature_scaling_probs(neu_probs, test_scaler)

1.0909091234207153


In [79]:
test_scaler_2 = calibrate_temperature(val_logits, val_labels)
calibrated_probs_2 = apply_temperature_scaling_probs(neu_probs, test_scaler_2)

In [80]:
ROUND_DECIMALS = 3  # Number of decimals to round probabilities to when computing coverage, efficiency, etc.


def expected_calibration_error(probs: np.ndarray, labels: np.ndarray, num_bins: int = 10) -> float:
    """Compute the expected calibration error (ECE) of the predicted probabilities :cite:`guoOnCalibration2017`.

    Args:
        probs: The predicted probabilities as an array of shape (n_instances, n_classes).
        labels: The true labels as an array of shape (n_instances,).
        num_bins: The number of bins to use for the calibration error calculation.

    Returns:
        ece: The expected calibration error.
    """
    confs = np.max(probs, axis=1)
    preds = np.argmax(probs, axis=1)
    bins = np.linspace(0, 1, num_bins + 1, endpoint=True)
    bin_indices = np.digitize(confs, bins, right=True) - 1
    num_instances = probs.shape[0]
    ece = 0
    for i in range(num_bins):
        _bin = np.where(bin_indices == i)[0]
        # check if bin is empty
        if _bin.shape[0] == 0:
            continue
        acc_bin = np.mean(preds[_bin] == labels[_bin])
        conf_bin = np.mean(confs[_bin])
        weight = _bin.shape[0] / num_instances
        ece += weight * np.abs(acc_bin - conf_bin)
    return float(ece)

In [81]:
print(expected_calibration_error(neu_probs.cpu().numpy(), val_labels.cpu().numpy()))
print(expected_calibration_error(calibrated_probs.detach().cpu().numpy(), val_labels.cpu().numpy()))

0.7252609780989587
0.7330209827050567


In [82]:
print(expected_calibration_error(neu_probs.cpu().numpy(), val_labels.cpu().numpy()))
print(expected_calibration_error(calibrated_probs_2.detach().cpu().numpy(), val_labels.cpu().numpy()))

0.7252609780989587
0.7225449196994305
