# Unified Evidential Training Example

To add training functionality to the probly package, we want to provide a Unified Evidential Training function that makes it easy to train evidential models.

This notebook demonstrates how a Unified Evidential Training Function works.
It uses the CIFAR-10H dataset (`CIFAR10H`) and the `EvidentialCELoss` loss function, based on the evidential loss introduced by _Sensoy et al. (2018)_.
The function `unified_evidential_train()` sketches what the routine will look like later on.

This notebook can be divided into 5 sections:

1. Imports & Setup
2. Data Preparation
3. Model Definition
4. Unified Evidential Training Function
5. Starting Training Loop

### 1. Imports & Setup
- **torch** → building neural networks
- **torchvision** → converting images to tensors, normalizing them, and loading the CIFAR-10 base dataset
- **CIFAR10H** → the dataset class from the probly library
- **EvidentialCELoss** → evidential cross-entropy loss from Sensoy et al. (2018)

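This cell imports everything the Unified Evidential Training Function will also depend on: datasets, losses, and the core PyTorch tools.
It also defines a patched `CIFAR10H` class that falls back to the standard CIFAR-10 data, so the `cifar10h-counts.npy` annotation file does not need to be downloaded.
Right now we do this manually, but later on the unified function will handle it internally.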

```python
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

from probly.datasets.torch import CIFAR10H as OriginalCIFAR10H  # noqa: N811
from probly.train.evidential.torch import EvidentialCELoss


# --- Patch CIFAR10H to work without cifar10h-counts.npy ---
# This subclass avoids downloading the CIFAR-10H .npy annotation files.
# Note: it loads the 50,000 CIFAR-10 training images with hard integer labels,
# whereas the CIFAR-10H human annotations cover the 10,000-image CIFAR-10 test split.
class CIFAR10H(OriginalCIFAR10H):
    def __init__(self, root: str, transform=None, download: bool = False) -> None:  # noqa: ANN001, D107
        # skip the parent's __init__ and initialize plain CIFAR10 directly as a fallback
        CIFAR10.__init__(self, root=root, train=True, transform=transform, download=download)
        # targets are plain integer class labels instead of label distributions
        self.targets = torch.tensor(self.targets, dtype=torch.long)
```
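`EvidentialCELoss` is used as a black box in this notebook. For intuition, here is a minimal sketch of the Bayes-risk cross-entropy loss described by Sensoy et al. (2018): with non-negative evidence $e$, Dirichlet parameters $\alpha = e + 1$ and strength $S = \sum_k \alpha_k$, the loss for a one-hot target $y$ is $\sum_j y_j \left(\psi(S) - \psi(\alpha_j)\right)$, where $\psi$ is the digamma function. The name `evidential_ce_sketch` is hypothetical, and probly's `EvidentialCELoss` may differ in details such as the reduction or additional regularization terms.

```python
import torch
import torch.nn.functional as F


def evidential_ce_sketch(evidence: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of the Bayes-risk cross-entropy loss (Sensoy et al., 2018).

    Assumes `evidence` is non-negative with shape [batch, num_classes] and
    `targets` holds integer class labels with shape [batch].
    """
    alpha = evidence + 1.0  # Dirichlet parameters
    strength = alpha.sum(dim=1, keepdim=True)  # S = sum_k alpha_k, shape [batch, 1]
    one_hot = F.one_hot(targets, num_classes=evidence.shape[1]).to(evidence.dtype)
    # sum_j y_j * (digamma(S) - digamma(alpha_j)) for each sample
    per_sample = (one_hot * (torch.digamma(strength) - torch.digamma(alpha))).sum(dim=1)
    return per_sample.mean()  # average over the batch


targets = torch.tensor([0])
confident = evidential_ce_sketch(torch.tensor([[9.0, 0.0, 0.0]]), targets)
no_evidence = evidential_ce_sketch(torch.zeros(1, 3), targets)
print(confident.item(), no_evidence.item())  # approx. 0.1909 and 1.5
```

Strong evidence for the correct class yields a much smaller loss than a sample with no evidence at all, which is the behavior the training loop relies on.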

### 2. Data Preparation
This simulates how the function will handle datasets: it prepares them by, e.g., converting images to tensors and normalizing pixel values.

```python
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # image with values in [0, 255] -> CxHxW float tensor in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 on every channel -> [-1, 1]
    ],
)

train_data = CIFAR10H(root="./data", transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

print(f"Loaded CIFAR10H dataset with {len(train_data)} samples.")
```

```text
Loaded CIFAR10H dataset with 50000 samples.
```


### 3. Model Definition
This is an example of a small neural network that outputs non-negative evidence values instead of softmax probabilities. Despite its name, `SimpleCNN` contains no convolutional layers: it flattens each image and passes it through two fully connected layers. Later on we can also use models from `probly.layers`.
Our unified function will be able to train such a model with the corresponding evidential loss.

```python
class SimpleCNN(nn.Module):
    def __init__(self, num_classes: int = 10) -> None:  # noqa: D107
        super().__init__()
        self.fc1 = nn.Linear(32 * 32 * 3, 128)  # flattened 32x32 RGB image -> 128 hidden units
        self.fc2 = nn.Linear(128, num_classes)  # one evidence value per class

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D102
        x = torch.flatten(x, 1)  # [batch, 3, 32, 32] -> [batch, 3072]
        x = F.relu(self.fc1(x))
        return F.softplus(self.fc2(x))  # softplus keeps the evidence positive


model = SimpleCNN()
```
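The softplus outputs are interpreted as evidence. Following the subjective-logic view in Sensoy et al. (2018), evidence $e_k$ defines a Dirichlet distribution with $\alpha_k = e_k + 1$, expected class probabilities $\alpha_k / S$, belief masses $b_k = e_k / S$ and an uncertainty mass $u = K / S$, where $S = \sum_k \alpha_k$ and $K$ is the number of classes. The evidence values below are made up for illustration; they are not outputs of the trained model.

```python
import torch

# Made-up evidence for two samples and K = 3 classes (e.g. softplus outputs)
evidence = torch.tensor([[18.0, 1.0, 1.0],
                         [0.0, 0.0, 0.0]])

num_classes = evidence.shape[1]
alpha = evidence + 1.0                     # Dirichlet parameters
strength = alpha.sum(dim=1, keepdim=True)  # S = sum_k alpha_k
expected_probs = alpha / strength          # mean of the Dirichlet distribution
belief = evidence / strength               # belief mass per class
uncertainty = num_classes / strength       # u = K / S; beliefs and u sum to 1

print(expected_probs)
print(uncertainty.squeeze(1))  # approx. 0.1304 (confident) and 1.0 (no evidence)
```

A sample with no evidence gets a uniform prediction and maximal uncertainty $u = 1$, which is exactly what a softmax output cannot express.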

### 4. Unified Evidential Training Function
In this part, we create the heart of this notebook: the Unified Evidential Training Function.
It takes a model and a set of other parameters (data loader, loss function, number of epochs, learning rate, and device) that the user can customize before running.
When called, it runs a training loop for evidential deep learning in PyTorch based on these parameters.

```python
def unified_evidential_train(model, dataloader, loss_fn, epochs=5, lr=1e-3, device=None) -> None:  # noqa: ANN001
    """Demonstration of a unified evidential training function."""
    if device is None:
        # fall back to the CPU when no GPU is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)  # moves the model to the selected device (GPU or CPU)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # repeat the training pass for the given number of epochs
    for epoch in range(epochs):
        model.train()  # enables training-mode behavior of layers such as dropout and batch norm
        total_loss = 0.0  # accumulated loss, used to compute the average loss per epoch

        for x, y in dataloader:
            # the original CIFAR10H returns label distributions (y.shape = [batch, num_classes]),
            # the patched fallback returns integer labels (y.shape = [batch]); handle both
            x = x.to(device)  # noqa: PLW2901
            y = y.to(device)  # noqa: PLW2901
            if y.ndim > 1:
                y = y.argmax(dim=1)  # collapse a distribution to its most likely class  # noqa: PLW2901

            optimizer.zero_grad()  # clear gradients from the previous step
            outputs = model(x)  # compute the model outputs (evidence values)
            loss = loss_fn(outputs, y)  # compute the evidential loss
            loss.backward()  # backpropagation
            optimizer.step()  # update the model parameters

            total_loss += loss.item()  # add this batch's loss to the epoch total

        avg_loss = total_loss / len(dataloader)  # average loss over all batches of the epoch
        print(f"Epoch [{epoch + 1}/{epochs}] - Loss: {avg_loss:.4f}")
```
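The `y.ndim > 1` branch in the loop collapses a label distribution to its most likely class. The hypothetical helper `to_hard_labels` below isolates that step so it can be checked on made-up soft and hard labels. Note that taking the `argmax` discards the annotator disagreement that soft labels such as those of CIFAR-10H encode; with the patched fallback dataset, the labels are already integers and pass through unchanged.

```python
import torch


def to_hard_labels(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the `y.ndim > 1` branch of the training loop."""
    if y.ndim > 1:  # label distributions, shape [batch, num_classes]
        return y.argmax(dim=1)
    return y  # integer labels, shape [batch]


soft = torch.tensor([[0.1, 0.7, 0.2],  # e.g. normalized annotator votes
                     [0.6, 0.3, 0.1]])
hard = torch.tensor([2, 0])

print(to_hard_labels(soft))  # tensor([1, 0])
print(to_hard_labels(hard))  # tensor([2, 0])
```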

### 5. Starting Training Loop
In this part, we initialize the model and the loss function, then call the training function to start the training loop.

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SimpleCNN(num_classes=10)  # model to be trained
loss_fn = EvidentialCELoss()  # loss function; will be customizable later on

unified_evidential_train(
    model=model,
    dataloader=train_loader,
    loss_fn=loss_fn,
    epochs=5,
    lr=1e-3,
    device=device,
)
```

```text
Epoch [1/5] - Loss: 1.8176
Epoch [2/5] - Loss: 1.6396
Epoch [3/5] - Loss: 1.5612
Epoch [4/5] - Loss: 1.5054
Epoch [5/5] - Loss: 1.4614
```


With that, we have built our own "unified" evidential training function.
Later on, `unified_evidential_train()` will be implemented as a `singledispatch_traverser`, and the code written here would live in `torch.py`, since it only represents the implementation for PyTorch tensors.
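The `singledispatch_traverser` mentioned above belongs to probly and is not shown in this notebook. As a rough illustration of the underlying idea, assuming that dispatch happens on the type of the model, here is a sketch that uses Python's standard `functools.singledispatch` as a stand-in. `unified_train` is a hypothetical name, and the registered PyTorch branch only reports which backend was selected instead of training.

```python
from functools import singledispatch

from torch import nn


@singledispatch
def unified_train(model: object, *args: object, **kwargs: object) -> str:
    """Fallback for model types without a registered backend."""
    msg = f"No training backend registered for {type(model).__name__}"
    raise NotImplementedError(msg)


@unified_train.register
def _(model: nn.Module, *args: object, **kwargs: object) -> str:
    # a PyTorch-specific training loop (e.g. in torch.py) would run here;
    # this sketch only reports which backend was selected
    return "torch"


print(unified_train(nn.Linear(4, 2)))  # torch
```

Because `nn.Linear` subclasses `nn.Module`, it is routed to the PyTorch branch; other frameworks could register their own implementations in separate modules without touching this one.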