In [1]:
import torch
from torch.functional import F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import copy
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

from sad_nns.uncertainty import *
from neurops import *

In [3]:
# set seed
torch.manual_seed(0)

<torch._C.Generator at 0x7efe63f81730>

In [15]:
# model = ModSequential(
#         ModConv2d(in_channels=3, out_channels=8, kernel_size=7, masked=True, padding=1, learnable_mask=True),
#         ModConv2d(in_channels=8, out_channels=26, kernel_size=7, masked=True, padding=1, prebatchnorm=True, learnable_mask=True),
#         ModConv2d(in_channels=26, out_channels=20, kernel_size=7, masked=True, prebatchnorm=True, learnable_mask=True),
#         ModConv2d(in_channels=20, out_channels=10, kernel_size=7, masked=True, prebatchnorm=True, learnable_mask=True),
#         ModLinear(1440, 256, masked=True, prebatchnorm=True, learnable_mask=True),
#         ModLinear(256, 10, masked=True, prebatchnorm=True, nonlinearity=""),
#         track_activations=True,
#         track_auxiliary_gradients=True,
#         input_shape=(3, 32, 32)  # Adjusted for CIFAR10
# ).to(device)
# torch.compile(model)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

model = ModSequential(
        ModConv2d(in_channels=3, out_channels=16, kernel_size=3, masked=True, padding=1, learnable_mask=True),
        ModConv2d(in_channels=16, out_channels=32, kernel_size=3, masked=True, padding=1, prebatchnorm=True, learnable_mask=True),
        ModConv2d(in_channels=32, out_channels=64, kernel_size=3, masked=True, padding=1, prebatchnorm=True, learnable_mask=True),
        ModLinear(1024, 120, masked=True, prebatchnorm=True, learnable_mask=True),
        ModLinear(120, 60, masked=True, prebatchnorm=True, nonlinearity=""),
        ModLinear(60, 10, masked=True, prebatchnorm=True, nonlinearity=""),
        track_activations=True,
        track_auxiliary_gradients=True,
        input_shape = (3, 32, 32)
    ).to(device)
torch.compile(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Add EDL Loss Function
# KLDivergenceLoss, MaximumLikelihoodLoss, CrossEntropyBayesRisk, SquaredErrorBayesRisk
criterion = SquaredErrorBayesRisk()
kl_divergence = KLDivergenceLoss()

In [5]:
# dataset = datasets.MNIST('../data', train=True, download=True,
#                      transform=transforms.Compose([ 
#                             transforms.ToTensor(),
#                             transforms.Normalize((0.1307,), (0.3081,)),
#                             transforms.Resize((14,14))
#                         ]))

dataset = datasets.CIFAR10('../data', train=True, download=True,
                     transform=transforms.Compose([ 
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            transforms.Resize((32,32))
                        ]))

train_set, val_set = torch.utils.data.random_split(dataset, lengths=[int(0.9*len(dataset)), int(0.1*len(dataset))])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=True)

# test_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=False, transform=transforms.Compose([
#                             transforms.ToTensor(),
#                             transforms.Normalize((0.1307,), (0.3081,)),
#                             transforms.Resize((14,14))
#                         ])),
#     batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                            transforms.Resize((32,32))
                        ])),
    batch_size=128, shuffle=True)

Files already downloaded and verified


In [6]:
def train(model, train_loader, optimizer, criterion, epochs=10, num_classes=10, val_loader=None, verbose=True, plot=False, return_vals=False):
    model.train()

    train_acc_vals = []
    train_u_vals = []
    test_acc_vals = []
    test_u_vals = []

    for epoch in range(epochs):
        correct = 0
        for batch_idx, (data, target) in enumerate(train_loader):

            data, target = data.to(device), target.to(device)
            one_hot_target = F.one_hot(target, num_classes=num_classes)
            optimizer.zero_grad()
            output = model(data)
            
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            # calculate uncertainty
            evidence = F.relu(output)
            alpha = evidence + 1
            u = num_classes / torch.sum(alpha, dim=1, keepdim=True)

            loss = criterion(evidence, one_hot_target)

            # calculate KL Divergence
            kl_div_loss = kl_divergence(evidence, one_hot_target)
            annealing_step = 10
            annealing_coef = torch.min(
                torch.tensor(1.0, dtype=torch.float32),
                torch.tensor(epoch / annealing_step, dtype=torch.float32),
            )
            
            loss = loss + annealing_coef * kl_div_loss

            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0 and verbose:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tUncertainty: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item(), u.mean()))
                
        train_acc, train_u = correct / len(train_loader.dataset), u.mean().detach().numpy().item()
        train_acc_vals.append(train_acc)
        train_u_vals.append(train_u)
        
        if val_loader is not None:
            print("Validation: ", end = "")
            test_acc, test_u = test(model, val_loader, criterion, return_vals=True)
            test_acc_vals.append(test_acc)
            test_u_vals.append(test_u)

    if plot:
        # Plotting the lines
        plt.plot(np.arange(epochs), train_acc_vals, label='Training Accuracy', color='blue')
        plt.plot(np.arange(epochs), train_u_vals, label='Training Uncertainty', color='purple')
        plt.plot(np.arange(epochs), test_acc_vals, label='Testing Accuracy', color='orange')
        plt.plot(np.arange(epochs), test_u_vals, label='Testing Uncertainty', color='red')
        plt.xticks(np.arange(0, epochs, 1))

        # Adding titles and labels
        plt.title('Accuracy vs. Uncertainty')
        plt.xlabel('Epoch')
        plt.ylabel('Value')
        plt.legend()

        # Show the plot
        plt.show()
    elif return_vals:
        return train_acc, train_u, test_acc, test_u

def test(model, test_loader, criterion, num_classes=10, return_vals=False):
    model.eval()
    test_loss = 0
    correct = 0
    uncertainties = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            one_hot_target = F.one_hot(target, num_classes=num_classes)
            output = model(data)

            evidence = F.relu(output)
            alpha = evidence + 1
            u = num_classes / torch.sum(alpha, dim=1, keepdim=True)
            uncertainties.append(u.mean())

            # test_loss += criterion(output, target).item() # sum up batch loss
            test_loss += criterion(output, one_hot_target).item()
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    avg_u = np.mean(uncertainties)
    
    print('Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), Average Uncertainty: {:.4f}'.format(test_loss, correct, len(test_loader.dataset),
        accuracy, avg_u))
    
    if return_vals:
        return accuracy / 100, avg_u

In [16]:
train(model, train_loader, optimizer, criterion, epochs=10, val_loader=val_loader)

RuntimeError: running_mean should contain 65536 elements not 1024