In [1]:
import torch
from torch.functional import F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
import copy
import numpy as np
from torchvision import datasets, transforms

from sad_nns.uncertainty import *
from neurops import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def squaredErrorBayesRisk(evidence, target):

    # calculate class probabilities
    alpha = evidence + 1.
    strength = alpha.sum(dim=-1)
    p = alpha / strength[:, None]

    # calculate error and variance
    err = (target - p) ** 2
    var = p * (1 - p) / (strength[:, None] + 1)

    # calculate loss
    loss = (err + var).sum(dim=-1)

    # mean loss over the batch
    return loss.mean()

In [5]:
model = ModSequential(
        ModConv2d(in_channels=1, out_channels=8, kernel_size=7, masked=True, padding=1, learnable_mask=True),
        ModConv2d(in_channels=8, out_channels=16, kernel_size=7, masked=True, padding=1, prebatchnorm=True, learnable_mask=True),
        ModConv2d(in_channels=16, out_channels=16, kernel_size=5, masked=True, prebatchnorm=True, learnable_mask=True),
        ModLinear(64, 32, masked=True, prebatchnorm=True, learnable_mask=True),
        ModLinear(32, 10, masked=True, prebatchnorm=True, nonlinearity=""),
        track_activations=True,
        track_auxiliary_gradients=True,
        input_shape = (1, 14, 14)
    ).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = squaredErrorBayesRisk

print("This model has {} effective parameters.".format(model.parameter_count(masked = True)))
print("The conversion factor of this model is {} after layer {}.".format(model.conversion_factor, model.conversion_layer))

This model has 15634 effective parameters.
The conversion factor of this model is 4 after layer 2.


In [6]:
def kl_divergence(evidence, target):
    # derive alpha from evidence
    alpha = evidence + 1.
    # obtain number of classes
    n_classes = evidence.shape[-1]
    # remove non-misleading evidence
    alpha_tilde = target + (1 - target) * alpha
    strength_tilde = alpha_tilde.sum(dim=-1)
    # first term
    first = (torch.lgamma(alpha_tilde.sum(dim=-1))
                - torch.lgamma(alpha_tilde.new_tensor(float(n_classes)))
                - (torch.lgamma(alpha_tilde)).sum(dim=-1))
    # second terms
    second = (
            (alpha_tilde - 1) *
            (torch.digamma(alpha_tilde) - torch.digamma(strength_tilde)[:, None])
    ).sum(dim=-1)

    loss = first + second

    return loss.mean()

In [7]:
edl_kl_divergence = KLDivergenceLoss()

In [8]:
dataset = datasets.MNIST('../data', train=True, download=True,
                     transform=transforms.Compose([ 
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,)),
                            transforms.Resize((14,14))
                        ]))
train_set, val_set = torch.utils.data.random_split(dataset, lengths=[int(0.9*len(dataset)), int(0.1*len(dataset))])
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,)),
                            transforms.Resize((14,14))
                        ])),
    batch_size=128, shuffle=True)

def train(model, train_loader, optimizer, criterion, epochs=10, num_classes=10, val_loader=None, verbose=True):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):

            data, target = data.to(device), target.to(device)
            target = F.one_hot(target, num_classes=num_classes)
            optimizer.zero_grad()
            output = model(data)
            # loss = criterion(output, target)
            

            # calculate uncertainty
            evidence = F.relu(output)
            alpha = evidence + 1
            u = num_classes / torch.sum(alpha, dim=1, keepdim=True)

            loss = criterion(evidence, target)

            # calculate KL Divergence
            kl_div_loss = kl_divergence(evidence, target)
            annealing_step = 10
            annealing_coef = torch.min(
                torch.tensor(1.0, dtype=torch.float32),
                torch.tensor(epoch / annealing_step, dtype=torch.float32),
            )
            
            loss = loss + annealing_coef * kl_div_loss

            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0 and verbose:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
        if val_loader is not None:
            print("Validation: ", end = "")
            test(model, val_loader, criterion)

def test(model, test_loader, criterion, num_classes=10):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            one_hot_target = F.one_hot(target, num_classes=num_classes)
            output = model(data)
            # test_loss += criterion(output, target).item() # sum up batch loss
            test_loss += criterion(output, one_hot_target).item()
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    
    print('Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [9]:
train(model, train_loader, optimizer, criterion, epochs=10, num_classes=10, val_loader=val_loader)



Validation: Average loss: 0.0051, Accuracy: 4904/6000 (81.73%)


KeyboardInterrupt: 