# Uncertainty Evaluation (ResNet trained from scratch on SVHN)

In [None]:
# for Google Colab
# %pip install torchinfo wandb

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import torchvision
import wandb
from tqdm.autonotebook import trange

In [None]:
# Log in to Weights & Biases through its command-line client
!wandb login

In [None]:
run = wandb.init(
    project="i2r-active-da-01_uncertainty",
    job_type="train",
    config={
        # parameters roughly similar to LIRR
        "optim": {
            "algorithm": "AdamW",
            "lr": 1e-3,
            "weight_decay": 5e-4,
        },
        "batch_size": 64,
        "num_epochs": 15,
        "cuda_device": 0,
        "architecture": "resnet34-cosc",
        "train_dataset": "svhn",
        "test_dataset": "mnist",
        "seed": 123456,
    },
)
config = run.config

run.tags += (config.architecture,)
if config.train_dataset == config.test_dataset:
    run.tags += (config.train_dataset,)
else:
    run.tags += (f"{config.train_dataset}-to-{config.test_dataset}",)

## SVHN-to-MNIST Dataset

In [None]:
from collections import Counter
import random

# Seed every global RNG, not only Python's `random`. `random` drives the
# balanced subsampling below, while torch's generator drives weight
# initialisation and DataLoader shuffling; without seeding it, two runs with the
# same `config.seed` still diverge. NumPy is seeded as well in case a library
# draws from its global generator.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)


def get_balanced_ids(dataset):
    """Yield the indices of a class-balanced random subset of ``dataset``.

    Every class is capped at the size of the rarest class, so each label occurs
    the same number of times among the yielded indices. The order is random and
    drawn from the global ``random`` state (seed it for reproducibility).

    Args:
        dataset: Iterable of ``(sample, label)`` pairs with hashable labels.

    Yields:
        int: Position within ``dataset`` of a selected item. Nothing is yielded
        for an empty dataset.
    """
    # Walk the dataset a single time: every pass produces each sample (and runs
    # whatever loading/transform the dataset applies), so collecting the labels
    # once avoids doing that work twice.
    labels = [label for _, label in dataset]
    if not labels:
        return

    # Size of the rarest class = quota for every class
    per_class = min(Counter(labels).values())

    ids_labels = list(enumerate(labels))
    random.shuffle(ids_labels)

    taken = Counter()
    for i, label in ids_labels:
        if taken[label] < per_class:
            taken[label] += 1
            yield i

In [None]:
from torch.utils.data import Subset
from torchvision.datasets import MNIST, SVHN
from torchvision.transforms import Compose, Grayscale, Resize, ToTensor

gs_to_tensor = Compose([Grayscale(3), Resize((32, 32)), ToTensor()])

match config.train_dataset:
    case "svhn":
        train_data = SVHN(
            "data/svhn", download=True, split="train", transform=ToTensor()
        )
    case "mnist":
        train_data = MNIST(
            "data/mnist", download=True, train=True, transform=gs_to_tensor
        )

train_loader = torch.utils.data.DataLoader(
    Subset(train_data, list(get_balanced_ids(train_data))),
    batch_size=config.batch_size,
    shuffle=True,
    pin_memory=True,
)

match config.test_dataset:
    case "svhn":
        test_data = SVHN("data/svhn", download=True, split="test", transform=ToTensor())
    case "mnist":
        test_data = MNIST(
            "data/mnist", download=True, train=False, transform=gs_to_tensor
        )

test_loader = torch.utils.data.DataLoader(
    Subset(test_data, list(get_balanced_ids(test_data))),
    batch_size=config.batch_size,
    pin_memory=True,
)

In [None]:
# Peek at one training batch: show its first image, then the batch shape and
# the label of that image
train_inputs, train_targets = next(iter(train_loader))
# Move the channel axis last, which is the layout imshow expects
plt.imshow(train_inputs[0].permute(1, 2, 0))

print(train_inputs.shape)
print(f"Label: {train_targets[0]}")

In [None]:
# Same preview for the test loader; both loaders must produce batches of the
# same shape for the model to be evaluated on the test data
test_inputs, test_targets = next(iter(test_loader))
# Move the channel axis last, which is the layout imshow expects
plt.imshow(test_inputs[0].permute(1, 2, 0))
assert train_inputs.size() == test_inputs.size()

print(test_inputs.shape)
print(f"Label: {test_targets[0]}")

In [None]:
from collections import Counter

# Count the labels served by each loader to confirm the subsets are balanced.
# The loaders are consumed batch by batch: wrapping them in `list(...)` would
# keep every image batch in memory at once just to read the labels.
# NOTE: this rebinds `train_targets` / `test_targets` to Counter objects.
train_targets = Counter()
for _, targets in train_loader:
    train_targets.update(targets.tolist())

test_targets = Counter()
for _, targets in test_loader:
    test_targets.update(targets.tolist())

print(f"Train: {sorted(train_targets.items())}")
print(f"Test: {sorted(test_targets.items())}")

## Training Setup

In [None]:
# Pick the compute device. Selecting a CUDA device is only attempted when CUDA
# is actually available; calling `torch.cuda.set_device` unconditionally would
# fail on a CPU-only machine before the CPU fallback could ever be reached.
if torch.cuda.is_available():
    torch.cuda.set_device(config.cuda_device)
    device = torch.device(f"cuda:{config.cuda_device}")
else:
    device = torch.device("cpu")

In [None]:
from torch.nn import CosineSimilarity, Linear, Module, Sequential
from torchinfo import summary
from torchvision.models import resnet18, resnet34


class CosC(Module):
    """Cosine-similarity classification head.

    Holds a bias-free ``Linear`` layer only for its weight matrix; the forward
    pass returns, for every input row, its cosine similarity with each weight
    row instead of the usual dot product.

    Args:
        **kwargs: Forwarded to ``torch.nn.Linear`` (e.g. ``in_features``,
            ``out_features``). ``bias`` is always set to ``False`` here.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.fc = Linear(**kwargs, bias=False)
        # Similarity is reduced over dim 1, the feature axis after broadcasting
        self.cos = CosineSimilarity(dim=1)

    def forward(self, x):
        """Return the (batch, out_features) matrix of cosine similarities."""
        # (batch, in) -> (batch, in, 1)
        features = x.unsqueeze(-1)
        # (out, in) -> (in, out) -> (1, in, out)
        class_weights = self.fc.weight.T.unsqueeze(0)
        # Broadcast to (batch, in, out) and reduce over the feature axis
        return self.cos(features, class_weights)


match config.architecture:
    case "resnet18":
        model = resnet18(weights=None).to(device)
        model.fc = Linear(in_features=model.fc.in_features, out_features=10, bias=True)
    case "resnet18-cosc":
        model = resnet18(weights=None).to(device)
        model.fc = CosC(in_features=model.fc.in_features, out_features=10)
    case "resnet18-cosc-dim_32":
        model = resnet18(weights=None).to(device)
        model.fc = Sequential(
            Linear(in_features=model.fc.in_features, out_features=32, bias=True),
            CosC(in_features=32, out_features=10),
        )
    case "resnet18-cosc-dim_10":
        model = resnet18(weights=None).to(device)
        model.fc = Sequential(
            Linear(in_features=model.fc.in_features, out_features=10, bias=True),
            CosC(in_features=10, out_features=10),
        )
    case "resnet34-cosc":
        model = resnet34(weights=None).to(device)
        model.fc = CosC(in_features=model.fc.in_features, out_features=10)

summary(model, input_size=(config.batch_size, 3, 32, 32), depth=1)

In [None]:
def save_model(epoch):
    if config.train_dataset == config.test_dataset:
        model_name = f"{config.architecture}-{config.train_dataset}"
    else:
        model_name = (
            f"{config.architecture}-{config.train_dataset}-to-{config.test_dataset}"
        )
    torch.save(model, f"{run.dir}/{model_name}-{epoch:03}.pt")
    wandb.save(f"{run.dir}/{model_name}-{epoch:03}.pt", base_path=run.dir)

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW

# Loss and optimiser; learning rate and weight decay come from the run config
criterion = CrossEntropyLoss().to(device)
optimizer = AdamW(
    params=model.parameters(),
    weight_decay=config.optim["weight_decay"],
    lr=config.optim["lr"],
)

In [None]:
from sklearn.metrics import top_k_accuracy_score
import time


def train_epoch():
    """Run one optimisation pass over ``train_loader``.

    Works on the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Returns:
        dict: ``"time"`` (elapsed seconds), ``"loss"`` (mean of the per-batch
        losses) and ``"top1"`` (top-1 accuracy over every sample seen, computed
        from the outputs produced during training).
    """
    model.train()

    start = time.monotonic()

    losses = []
    outputs = []
    targets = []
    for inputs, target in train_loader:
        inputs = inputs.to(device)
        target = target.to(device)

        output = model(inputs)
        loss = criterion(output, target)

        losses.append(loss.item())
        # Store one row of class scores per sample. The rows are taken as they
        # come: squeezing first would turn a batch holding a single sample into
        # a flat list of scores and corrupt `outputs`.
        outputs.extend(output.detach().tolist())
        targets.extend(target.tolist())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    end = time.monotonic()

    return {
        "time": end - start,
        "loss": sum(losses) / len(losses),
        "top1": top_k_accuracy_score(targets, outputs, k=1, labels=range(10)),
    }

In [None]:
def test_epoch():
    model.eval()

    start = time.monotonic()

    losses = []
    outputs = []
    targets = []
    for input, target in test_loader:
        if device is not None:
            input = input.to(device)
            target = target.to(device)

        output = model(input)
        loss = criterion(output, target)

        losses.append(loss.item())
        outputs.extend(output.squeeze().tolist())
        targets.extend(target.tolist())

    end = time.monotonic()

    return (
        outputs,
        targets,
        {
            "time": end - start,
            "loss": sum(losses) / len(losses),
            "top1": top_k_accuracy_score(targets, outputs, k=1, labels=range(10)),
        },
    )

### Uncertainty Metrics

In [None]:
def softmax(outputs):
    """Row-wise softmax of a 2-D array of scores.

    Args:
        outputs: Array-like of shape ``(n_samples, n_classes)``.

    Returns:
        np.ndarray: Same shape as ``outputs``; every row is non-negative and
        sums to 1.
    """
    outputs = np.asarray(outputs)
    # Subtracting the row maximum leaves the result unchanged mathematically
    # but keeps `exp` from overflowing (inf / inf = nan) on large scores
    exps = np.exp(outputs - outputs.max(axis=1, keepdims=True))
    return exps / exps.sum(axis=1, keepdims=True)


def predictive_entropy(pss):
    """Entropy (natural log) of every row of a probability matrix.

    Args:
        pss: Array-like of shape ``(n_samples, n_classes)`` holding
            probabilities.

    Returns:
        np.ndarray: Shape ``(n_samples,)``, ``-sum(p * log(p))`` per row.
    """
    pss = np.asarray(pss)
    # A zero probability contributes 0 to the entropy (limit of p * log p).
    # Taking the log of 1 in those places avoids 0 * -inf = nan.
    safe = np.where(pss > 0, pss, 1.0)
    return -(pss * np.log(safe)).sum(axis=1)


def margin(pss):
    """Gap between the largest and second-largest value of every row.

    Args:
        pss: 2-D array of shape ``(n_samples, n_classes)``.

    Returns:
        np.ndarray: Shape ``(n_samples,)``, largest minus runner-up per row.
    """
    # Ascending sort per row: the last two columns are runner-up and best
    top_two = np.sort(pss, axis=1)[:, -2:]
    runner_up, best = top_two[:, 0], top_two[:, 1]
    return best - runner_up

In [None]:
# Quick check of the three helpers on two hand-written rows of scores
pss = softmax(np.array([[1, 2, 3, 4], [2, 4, 3, 6]]))
print(pss, predictive_entropy(pss), margin(pss), sep="\n")

In [None]:
# Checkpoint the untrained model as epoch 0
save_model(0)

# One DataFrame of per-sample test results is collected per epoch
df_epochs = []
for epoch in trange(1, config.num_epochs + 1):
    # Train, checkpoint, then evaluate the freshly trained weights
    train_metrics = train_epoch()
    save_model(epoch)
    outputs, targets, test_metrics = test_epoch()

    # Per-sample quantities: probabilities, correctness and the two
    # uncertainty scores
    targets, outputs = np.array(targets), np.array(outputs)
    pss = softmax(outputs)
    cs = np.argmax(pss, axis=1) == targets
    pes = predictive_entropy(pss)
    ms = margin(pss)

    # One row per test sample; "output" and "probs" hold a whole vector per row
    df_epoch = pd.DataFrame(
        data=zip([epoch] * len(targets), targets, outputs, pss, cs, pes, ms),
        columns=[
            "epoch",
            "target",
            "output",
            "probs",
            "is_correct",
            "predictive_entropy",
            "margin",
        ],
    )
    df_epochs.append(df_epoch)

    # Log the aggregate metrics together with the per-sample table
    run.log(
        {
            "epoch": epoch,
            "train": train_metrics,
            "test": test_metrics,
            "uncertainty_metrics": wandb.Table(data=df_epoch),
        }
    )

In [None]:
# Stack the per-epoch frames into one frame, pickle it into the run directory
# and upload it to W&B. Each frame keeps its own row index, so index values
# repeat across epochs; use the "epoch" column to tell rows apart.
df = pd.concat(df_epochs)
df.to_pickle(f"{run.dir}/df_uncertainty_metrics.pickle")
wandb.save(f"{run.dir}/df_uncertainty_metrics.pickle", base_path=run.dir)
df

## Analysis

In [None]:
# Histogram of predictive entropy, coloured by whether the prediction was
# correct, with one animation frame per epoch. The figure is also written to
# the run directory as a standalone HTML file.
fig = (
    px.histogram(
        df,
        x="predictive_entropy",
        color="is_correct",
        barmode="overlay",
        histnorm="probability",
        animation_frame="epoch",
        range_y=(0, 1),
        # range_x=(2.15, 2.3),
    )
    .update_traces(xbins_size=0.005)
    .update_layout(width=600, height=600)
)
fig.write_html(f"{run.dir}/hist_predictive_entropy.html", auto_play=False)
fig.show()

In [None]:
# Histogram of the top-2 probability margin, coloured by whether the
# prediction was correct, with one animation frame per epoch. The figure is
# also written to the run directory as a standalone HTML file.
fig = (
    px.histogram(
        df,
        x="margin",
        color="is_correct",
        barmode="overlay",
        histnorm="probability",
        animation_frame="epoch",
        range_y=(0, 1),
        # range_x=(0, 0.2),
    )
    .update_traces(xbins_size=0.01)
    .update_layout(width=600, height=600)
)
fig.write_html(f"{run.dir}/hist_margin.html", auto_play=False)
fig.show()

In [None]:
# Attach the two HTML histograms written to the run directory to the run, as a
# table with a single row
table = wandb.Table(columns=["predictive_entropy", "margin"])
table.add_data(
    wandb.Html(f"{run.dir}/hist_predictive_entropy.html"),
    wandb.Html(f"{run.dir}/hist_margin.html"),
)
run.log({"uncertainty_figs": table})

# "[Y/n]" advertises yes as the default, so an empty answer (just pressing
# Enter) finishes the run too; "yes" is accepted alongside "y".
if input("Finish the WandB run? [Y/n] ").strip().lower() in ("", "y", "yes"):
    run.finish()