In [46]:
import torch
import torchvision
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision import datasets, transforms
import general_models
import scipy
import numpy as np

In [2]:
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,)),
                              ])

train_dataset = datasets.MNIST('./mnist/MNIST_data/', download=True, train=True, transform=transform)
test_dataset = datasets.MNIST('./mnist/MNIST_data/', download=True, train=False, transform=transform)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=True)

In [3]:
def loss(y_hat, y):
    # y_hat has shape (N, 10) for 10 classes, y has shape (N,)
    f = torch.nn.CrossEntropyLoss()
    return f(y_hat, y)

In [4]:
def probabilties_from_scores(y):
    f = torch.nn.Softmax(dim=1)
    p = f(y)
    return p

In [12]:
scipy.special.erfinv(1-1/5)

0.9061938024368232

In [34]:
def evaluate(dataloader, model, device, log=False):
    correct_count, total_count = 0, 0
    for images_batch, labels_batch in dataloader:
        images = images_batch.reshape(images_batch.shape[0], 28*28)
        alphas, betas = model(images.to(device))
        probabilities = alphas/(alphas+betas)
        # probabilities = probabilties_from_scores(scores)

        pred_label = probabilities.max(1, keepdim=True)[1]
        correct_count += pred_label.eq(labels_batch.to(device).view_as(pred_label)).sum().item()
        total_count += labels_batch.shape[0]

    if log:
        print("Number Of Images Tested =", total_count)
        print("Model Accuracy =", (correct_count/total_count))

    return (correct_count/total_count)

In [27]:
class Net(torch.nn.Module):
    def __init__(self, layer_sizes, inference_type='regression'):
        super(Net, self).__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(layer_sizes[i], layer_sizes[i+1]) for i in range(len(layer_sizes)-2)])
        self.output_layer = torch.nn.Linear(layer_sizes[-2], layer_sizes[-1])
        self.inference_type = inference_type

    def forward(self, x):
        # x = self.input_instance_norm(x)
        for i in range(len(self.layers)):
            x = self.layers[i](x)
            x = torch.nn.functional.hardtanh_(x) + x
        x = self.output_layer(x)
        if self.inference_type == 'classification':
            x = torch.nn.functional.softmax(x,dim=-1)
        return x

class EvidentialNet(torch.nn.Module):
    def __init__(self, layer_sizes, inference_type='classification'):
        super(EvidentialNet, self).__init__()
        self.layers = torch.nn.ModuleList([torch.nn.Linear(layer_sizes[i], layer_sizes[i+1]) for i in range(len(layer_sizes)-2)])
        self.output_alpha_layer = torch.nn.Linear(layer_sizes[-2], layer_sizes[-1])
        self.output_beta_layer = torch.nn.Linear(layer_sizes[-2], layer_sizes[-1])
        self.inference_type = inference_type

    def forward(self, x):
        # x = self.input_instance_norm(x)
        for i in range(len(self.layers)):
            x = self.layers[i](x)
            x = torch.nn.functional.hardtanh_(x) + x
        alphas = torch.exp(self.output_alpha_layer(x))
        betas = torch.exp(self.output_beta_layer(x))
        return alphas, betas

In [28]:
def evidential_classification_loss(evidential_output, target):
    alphas, betas = evidential_output
    return loss(alphas/(alphas+betas), target)

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else "cpu"))

In [38]:
# model = general_models.FFNetwork([784, 100, 10], 'classification').to(device)
model = EvidentialNet([784, 200, 100, 10], 'classification').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [39]:
epochs = 100
for epoch in range(epochs):
    running_loss = 0

    for images_batch, labels_batch in tqdm(train_dataloader):
        images_batch = images_batch.reshape(images_batch.shape[0], 28*28)   # Flatten MNIST images into a 784 long vector

        # forward pass
        y_hat = model(images_batch.to(device))
        y = labels_batch.to(device)
        L = evidential_classification_loss(y_hat, y)

        # backward pass
        optimizer.zero_grad()
        L.backward()

        # update parameters
        optimizer.step()

        running_loss += L.item()

    training_loss = running_loss/len(train_dataloader)
    #train_accuracy = evaluate(train_dataloader, model)
    test_accuracy = evaluate(test_dataloader, model, device)

    if epoch % 1 == 0:
        print("Epoch {} - Training loss: {}  Test Accuracy: {}".format(epoch, training_loss, test_accuracy))

    if test_accuracy > .98:
        break

100%|██████████| 59/59 [00:02<00:00, 23.06it/s]


Epoch 0 - Training loss: 1.6420034998554294  Test Accuracy: 0.9278


100%|██████████| 59/59 [00:02<00:00, 24.24it/s]


Epoch 1 - Training loss: 1.5281803001791745  Test Accuracy: 0.9457


100%|██████████| 59/59 [00:02<00:00, 23.50it/s]


Epoch 2 - Training loss: 1.5077944711103277  Test Accuracy: 0.9578


100%|██████████| 59/59 [00:02<00:00, 24.54it/s]


Epoch 3 - Training loss: 1.4964462700536696  Test Accuracy: 0.9642


100%|██████████| 59/59 [00:02<00:00, 24.34it/s]


Epoch 4 - Training loss: 1.4891036405401714  Test Accuracy: 0.969


100%|██████████| 59/59 [00:02<00:00, 22.94it/s]


Epoch 5 - Training loss: 1.483880826982401  Test Accuracy: 0.9707


100%|██████████| 59/59 [00:02<00:00, 24.85it/s]


Epoch 6 - Training loss: 1.4801020945532848  Test Accuracy: 0.9722


100%|██████████| 59/59 [00:02<00:00, 24.79it/s]


Epoch 7 - Training loss: 1.477302545208042  Test Accuracy: 0.9751


100%|██████████| 59/59 [00:02<00:00, 23.52it/s]


Epoch 8 - Training loss: 1.4747352963787015  Test Accuracy: 0.9744


100%|██████████| 59/59 [00:02<00:00, 25.40it/s]


Epoch 9 - Training loss: 1.4729705402406597  Test Accuracy: 0.9765


100%|██████████| 59/59 [00:02<00:00, 25.24it/s]


Epoch 10 - Training loss: 1.4714028936321453  Test Accuracy: 0.977


100%|██████████| 59/59 [00:02<00:00, 24.08it/s]


Epoch 11 - Training loss: 1.4699572668237202  Test Accuracy: 0.9781


100%|██████████| 59/59 [00:02<00:00, 25.53it/s]


Epoch 12 - Training loss: 1.469052157159579  Test Accuracy: 0.9792


100%|██████████| 59/59 [00:02<00:00, 25.53it/s]


Epoch 13 - Training loss: 1.4680548785096508  Test Accuracy: 0.9782


100%|██████████| 59/59 [00:02<00:00, 25.24it/s]


Epoch 14 - Training loss: 1.4673969179896984  Test Accuracy: 0.9796


100%|██████████| 59/59 [00:02<00:00, 25.26it/s]


Epoch 15 - Training loss: 1.4668430190975383  Test Accuracy: 0.9793


100%|██████████| 59/59 [00:02<00:00, 24.83it/s]


Epoch 16 - Training loss: 1.4663685438996654  Test Accuracy: 0.9807


In [40]:
for images_batch, labels_batch in test_dataloader:
    images = images_batch.reshape(images_batch.shape[0], 28*28)
    alphas, betas = model(images.to(device))

In [42]:
labels_batch

tensor([1, 3, 6,  ..., 7, 1, 6])

In [43]:
alphas[labels_batch]

tensor([[3.8439e-03, 9.2269e-04, 7.8416e-03,  ..., 1.4908e-03, 5.9892e-04,
         6.5439e-02],
        [1.2096e-02, 4.2533e-03, 3.3414e-04,  ..., 3.4290e-02, 2.5534e-03,
         9.6897e+01],
        [8.5815e+02, 5.2137e-03, 2.5098e-03,  ..., 6.1372e-02, 1.0391e-03,
         3.8280e-02],
        ...,
        [3.7767e-03, 3.2389e-03, 1.9796e-03,  ..., 3.1685e-03, 9.6762e-03,
         6.3980e-03],
        [3.8439e-03, 9.2269e-04, 7.8416e-03,  ..., 1.4908e-03, 5.9892e-04,
         6.5439e-02],
        [8.5815e+02, 5.2137e-03, 2.5098e-03,  ..., 6.1372e-02, 1.0391e-03,
         3.8280e-02]], device='mps:0', grad_fn=<IndexBackward0>)

In [67]:
alphas.gather(1,torch.tensor(labels_batch.view(-1,1)).to(device))

  alphas.gather(1,torch.tensor(labels_batch.view(-1,1)).to(device))


tensor([[ 547.4056],
        [ 942.0099],
        [ 924.6964],
        ...,
        [5358.8926],
        [ 138.7685],
        [ 568.7893]], device='mps:0', grad_fn=<GatherBackward0>)