# Uncertainty Evaluation (ResNet-18 trained from scratch on SVHN)

In [None]:
# %pip install torchinfo wandb # for Google Colab

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import torchvision
import wandb
from tqdm.autonotebook import trange

In [None]:
!wandb login

In [None]:
# Start a W&B training run; the hyperparameters below are stored in the run
# config and read back from `config` by the rest of the notebook.
run = wandb.init(
    project="i2r-active-da",
    job_type="train",
    # parameters roughly similar to LIRR
    config=dict(
        optim=dict(algorithm="AdamW", lr=1e-3, weight_decay=5e-4),
        batch_size=64,
        num_epochs=20,
        num_workers=1,
        cuda_device=0,
    ),
)
config = run.config

## SVHN Dataset

In [None]:
from torchvision.datasets import SVHN
from torchvision.transforms import ToTensor

# Options shared by both loaders; only the training loader shuffles.
loader_options = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    pin_memory=True,
)

train_data = SVHN("data/svhn", split="train", transform=ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_options)

test_data = SVHN("data/svhn", split="test", transform=ToTensor(), download=True)
test_loader = torch.utils.data.DataLoader(test_data, **loader_options)

In [None]:
# Peek at one training batch: show its first image together with its label.
inputs, targets = next(iter(train_loader))
first_image = inputs[0].permute(1, 2, 0)  # move the channel axis last for imshow
plt.imshow(first_image)
print(f"Label: {targets[0]}")

In [None]:
# Class distribution of both splits. The labels are read straight from the
# datasets (torchvision's SVHN keeps them in `.labels`), which avoids decoding
# every image through the loaders just to count classes.
train_targets = train_data.labels.tolist()
test_targets = test_data.labels.tolist()

# Edges centred on the integers give exactly one bin per digit. With
# `range(10)` there are only nine bins and the closed last bin [8, 9] lumps
# the digits 8 and 9 together.
num_classes = 10
bins = np.arange(num_classes + 1) - 0.5
plt.xticks(range(num_classes))
plt.hist(train_targets, bins, alpha=0.5, label="train")
plt.hist(test_targets, bins, alpha=0.5, label="test")
plt.xlabel("digit")
plt.ylabel("number of samples")
plt.legend()
plt.show()

## Training Setup

In [None]:
# Select the configured GPU when CUDA is present, otherwise fall back to CPU.
# `torch.cuda.set_device` is only called in the CUDA case: calling it
# unconditionally fails on CPU-only machines before the fallback is reached.
use_cuda = torch.cuda.is_available()
device = torch.device(f"cuda:{config.cuda_device}" if use_cuda else "cpu")
if use_cuda:
    torch.cuda.set_device(device)

In [None]:
from torch.nn import Linear
from torchinfo import summary
from torchvision.models import resnet18

# Randomly initialised ResNet-18 with a 10-way classification head.
# The head is swapped in *before* the move to `device`; replacing `fc` after
# `.to(device)` would leave the new layer's parameters on the CPU while the
# rest of the network sits on the GPU.
model = resnet18(weights=None)
model.fc = Linear(in_features=model.fc.in_features, out_features=10, bias=True)
model = model.to(device)

summary(model, input_size=(config.batch_size, 3, 32, 32), depth=1)

In [None]:
from pathlib import Path


def save_model(epoch):
    """Write the current model to disk and log the file as a W&B model artifact.

    The checkpoint is stored as
    ``results/<run id>/resnet18-svhn-<run id>-<epoch>.pt`` (epoch zero-padded
    to three digits) and the artifact is logged with the alias "latest".

    Args:
        epoch: Epoch number used in the checkpoint file name.
    """
    out_dir = Path(f"results/{run.id}/")
    out_dir.mkdir(parents=True, exist_ok=True)

    model_name = f"resnet18-svhn-{run.id}"
    checkpoint = out_dir / f"{model_name}-{epoch:03}.pt"
    # Saves the whole module object, not only its state_dict.
    torch.save(model, checkpoint)

    artifact_model = wandb.Artifact(model_name, type="model")
    artifact_model.add_file(checkpoint)
    run.log_artifact(artifact_model, aliases=["latest"])

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW

# Loss and optimiser; learning rate and weight decay come from the run config.
optim_config = config.optim
criterion = CrossEntropyLoss().to(device)
optimizer = AdamW(
    model.parameters(),
    lr=optim_config["lr"],
    weight_decay=optim_config["weight_decay"],
)

In [None]:
from sklearn.metrics import top_k_accuracy_score
import time


def train_epoch(epoch):
    """Train the model for one pass over ``train_loader`` and checkpoint it.

    Args:
        epoch: Epoch number, forwarded to ``save_model``.

    Returns:
        dict with "time" (seconds spent in the training loop), "loss" (mean of
        the per-batch losses) and "top1" (top-1 accuracy computed from the
        logits collected during the pass).
    """
    model.train()

    start = time.monotonic()

    losses = []
    outputs = []
    targets = []
    for batch_inputs, batch_targets in train_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        logits = model(batch_inputs)
        loss = criterion(logits, batch_targets)

        losses.append(loss.item())
        # Keep the (batch, classes) shape untouched: squeeze() would drop the
        # batch axis of a batch holding a single sample, and extend() would
        # then append one scalar per class instead of one row.
        outputs.extend(logits.detach().tolist())
        targets.extend(batch_targets.tolist())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    end = time.monotonic()

    save_model(epoch)
    return {
        "time": end - start,
        "loss": sum(losses) / len(losses),
        "top1": top_k_accuracy_score(targets, outputs, k=1, labels=range(10)),
    }

In [None]:
def test_epoch():
    """Evaluate the model on ``test_loader`` without updating it.

    Returns:
        Tuple ``(outputs, targets, metrics)``: ``outputs`` is a list with one
        row of logits per test sample, ``targets`` the matching list of labels,
        and ``metrics`` a dict with "time" (seconds), "loss" (mean of the
        per-batch losses) and "top1" (top-1 accuracy).
    """
    model.eval()

    start = time.monotonic()

    losses = []
    outputs = []
    targets = []
    # No gradients are needed for evaluation; skipping autograd bookkeeping
    # saves memory and time.
    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_inputs)
            loss = criterion(logits, batch_targets)

            losses.append(loss.item())
            # No squeeze(): it would drop the batch axis of a one-sample batch
            # and break the one-row-per-sample layout of `outputs`.
            outputs.extend(logits.tolist())
            targets.extend(batch_targets.tolist())

    end = time.monotonic()

    return (
        outputs,
        targets,
        {
            "time": end - start,
            "loss": sum(losses) / len(losses),
            "top1": top_k_accuracy_score(targets, outputs, k=1, labels=range(10)),
        },
    )

### Uncertainty Metrics

In [None]:
def softmax(outputs):
    """Row-wise softmax of a 2-D array of logits.

    Args:
        outputs: Array-like of shape (n_samples, n_classes).

    Returns:
        Array of the same shape whose rows sum to 1.
    """
    outputs = np.asarray(outputs)
    # Subtracting the row maximum leaves the result unchanged mathematically
    # but keeps exp() from overflowing to inf (and the ratio from becoming NaN)
    # for large logits.
    shifted = outputs - outputs.max(axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=1, keepdims=True)


def predictive_entropy(pss):
    """Shannon entropy (in nats) of each row of a probability matrix.

    Args:
        pss: Array-like of shape (n_samples, n_classes) with class probabilities.

    Returns:
        1-D array with one entropy value per row.
    """
    pss = np.asarray(pss)
    # A probability of exactly 0 contributes 0 to the entropy (the limit of
    # p*log(p)); evaluating log(0) directly would give 0 * -inf = NaN.
    safe = np.where(pss == 0, 1, pss)
    return (-pss * np.log(safe)).sum(axis=1)


def margin(pss):
    """Gap between the largest and second-largest probability of every row.

    Args:
        pss: 2-D array of shape (n_samples, n_classes).

    Returns:
        1-D array with one margin per row.
    """
    ascending = np.sort(pss, axis=1)
    best = ascending[:, -1]
    runner_up = ascending[:, -2]
    return best - runner_up

In [None]:
# Quick sanity check of the three helpers on two hand-written logit rows.
demo_logits = np.array([[1, 2, 3, 4], [2, 4, 3, 6]])
pss = softmax(demo_logits)
for value in (pss, predictive_entropy(pss), margin(pss)):
    print(value)

In [None]:
# One row per epoch (indexed by epoch number); every cell holds an array with
# one entry per test sample.
uncertainty_columns = [
    "targets",
    "outputs",
    "probs",
    "is_correct",
    "predictive_entropy",
    "margin",
]
df_uncertainty = pd.DataFrame(columns=uncertainty_columns)

for epoch in trange(1, config.num_epochs + 1):
    train_metrics = train_epoch(epoch)
    outputs, targets, test_metrics = test_epoch()
    run.log({"epoch": epoch, "train": train_metrics, "test": test_metrics})

    # Uncertainty statistics of this epoch's test predictions.
    targets = np.array(targets)
    outputs = np.array(outputs)
    pss = softmax(outputs)
    df_uncertainty.loc[epoch] = [
        targets,
        outputs,
        pss,
        np.argmax(pss, axis=1) == targets,
        predictive_entropy(pss),
        margin(pss),
    ]

In [None]:
# Rows are labelled by epoch number, which starts at 1, so label 0 does not
# exist and `.loc[0]` raises KeyError; select the first recorded epoch by
# position instead.
df_uncertainty.iloc[0]

In [None]:
# Persist the per-epoch uncertainty table in this run's results directory.
path = Path(f"results/{run.id}/")
path.mkdir(exist_ok=True, parents=True)
uncertainty_file = path / "uncertainty.pickle"
df_uncertainty.to_pickle(uncertainty_file)

In [None]:
# Upload the pickled uncertainty table to W&B as a "dataframe" artifact.
artifact_df_uncertainty = wandb.Artifact("uncertainty_dataframe", type="dataframe")
# Plain string literal: the f-prefix had no placeholders to fill.
artifact_df_uncertainty.add_file(path / "uncertainty.pickle")
run.log_artifact(artifact_df_uncertainty)

In [None]:
# Close the W&B run now that training metrics and artifacts have been logged;
# the analysis cells below only use local data.
run.finish()

## Analysis

In [None]:
# Long-format view for plotting: one row per (epoch, test sample). The bulky
# per-sample logit and probability arrays are dropped before exploding, and
# the epoch index is copied into a regular column.
per_sample_columns = ["targets", "is_correct", "predictive_entropy", "margin"]
df = (
    df_uncertainty.drop(columns=["outputs", "probs"])
    .explode(per_sample_columns)
    .assign(epoch=lambda frame: frame.index)
)

In [None]:
# Density histogram of predictive entropy, coloured by prediction correctness
# and animated over epochs; also written out as a standalone HTML file.
fig = px.histogram(
    data_frame=df,
    x="predictive_entropy",
    color="is_correct",
    animation_frame="epoch",
    histnorm="density",
    opacity=0.5,
)
fig.write_html(path / "predictive_entropy.html")
fig.show()

In [None]:
# Density histogram of the top-2 probability margin, coloured by prediction
# correctness and animated over epochs; also written out as a standalone HTML
# file.
fig = px.histogram(
    data_frame=df,
    x="margin",
    color="is_correct",
    animation_frame="epoch",
    histnorm="density",
    opacity=0.5,
)
fig.write_html(path / "margin.html")
fig.show()