In [45]:
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributions import Normal
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## Evidential Convolutional Layer 
https://github.com/Dariusrussellkish/evidential-deep-learning/blob/pytorch_implementation/evidential_deep_learning/pytorch/layers/conv2d.py

In [115]:
class Conv2DNormalGamma(nn.Module):
    def __init__(self, in_channels, out_tasks=1, kernel_size=(1, 1), **kwargs):
        super(Conv2DNormalGamma, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_tasks
        self.conv = nn.Conv2d(in_channels, 4 * out_tasks, kernel_size, **kwargs)

    def forward(self, x):
        output = self.conv(x)
        print(output.size())
        if len(x.shape) == 3:
            gamma, lognu, logalpha, logbeta = torch.split(output, self.out_channels, dim=0)
        else:
            gamma, lognu, logalpha, logbeta = torch.split(output, self.out_channels, dim=1)

        nu = F.softplus(lognu)
        alpha = F.softplus(logalpha) + 1.
        beta = F.softplus(logbeta)
        return torch.stack([gamma, nu, alpha, beta], dim=1).squeeze().to(x.device)
    
#layer = Conv2DNormalGamma(1)
#x = torch.rand(64, 1, 28, 28)
#layer(x).shape

## Net

In [16]:
def double_conv(in_c, out_c, kernel_size, padding):
    conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True)
            )
    return conv

In [61]:
def crop_tensor(tensor, target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta1 = math.ceil(delta / 2)
    delta2 = math.floor(delta / 2)
    return tensor[:, :, delta1:tensor_size-delta2, delta1:tensor_size-delta2]

In [87]:
def zero_pad_to_match(tensor, target_tensor):
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = target_size - tensor_size
    delta1 = math.ceil(delta / 2)
    delta2 = math.floor(delta / 2)
    return F.pad(tensor, (delta1, delta2, delta1, delta2))

In [118]:
class EvidentialRegression(nn.Module):
    def __init__(self, num_class=1):
        super(EvidentialRegression, self).__init__()
        
        kernel_size = 3
        padding = 1 # To maintain size using (3, 3) filters
        
        self.max_pool_2x2 = nn.MaxPool2d((2, 2))
        self.upsample_2x2 = nn.Upsample(scale_factor=2)
        self.down_conv_1 = double_conv(1, 32, kernel_size, padding)
        self.down_conv_2 = double_conv(32, 64, kernel_size, padding)
        self.down_conv_3 = double_conv(64, 128, kernel_size, padding)
        self.down_conv_4 = double_conv(128, 256, kernel_size, padding)
        self.down_conv_5 = double_conv(256, 512, kernel_size, padding)
        
        self.up_conv_1 = double_conv(768, 256, kernel_size, padding)
        self.up_conv_2 = double_conv(384, 128, kernel_size, padding)
        self.up_conv_3 = double_conv(192, 64, kernel_size, padding)
        self.up_conv_4 = double_conv(96, 32, kernel_size, padding)
        self.up_conv_5 = nn.Conv2d(32, 4*num_class, kernel_size=1)
        
        self.evidential_layer = Conv2DNormalGamma(4, out_tasks=num_class)
        
    def forward(self, x0):
        # Shape of x should be (Batch, Channels, H, W)
        x1 = self.down_conv_1(x0)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)
        
        x = self.upsample_2x2(x9)
        y = crop_tensor(x7, x)
        x = self.up_conv_1(torch.cat([x, y], 1))
        
        x = self.upsample_2x2(x)
        y = crop_tensor(x5, x)
        x = self.up_conv_2(torch.cat([x, y], 1))
        
        x = self.upsample_2x2(x)
        y = crop_tensor(x3, x)
        x = self.up_conv_3(torch.cat([x, y], 1))
        
        x = self.upsample_2x2(x)
        y = crop_tensor(x1, x)
        x = self.up_conv_4(torch.cat([x, y], 1))

        x = zero_pad_to_match(x, x0)
        x = F.relu(self.up_conv_5(x))
        x = self.evidential_layer(x)
        return x
        
model = EvidentialRegression()

In [117]:
summary(model, (1, 28, 28))

torch.Size([2, 4, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 1, 28, 28])
torch.Size([2, 1, 28, 28])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 28, 28]             320
              ReLU-2           [-1, 32, 28, 28]               0
            Conv2d-3           [-1, 32, 28, 28]           9,248
              ReLU-4           [-1, 32, 28, 28]               0
         MaxPool2d-5           [-1, 32, 14, 14]               0
            Conv2d-6           [-1, 64, 14, 14]          18,496
              ReLU-7           [-1, 64, 14, 14]               0
            Conv2d-8           [-1, 64, 14, 14]          36,928
              ReLU-9           [-1, 64, 14, 14]               0
        MaxPool2d-10             [-1, 64, 7, 7]               0
           Conv2d-11            [-1, 128, 7, 7]          73,856
             ReLU-12            

In [12]:
conv = nn.Conv2d(3, 2, kernel_size=(3, 3), padding=1)
x = torch.rand(32, 3, 28, 28)
conv(x).shape

torch.Size([32, 2, 28, 28])