In [1]:
import torch
import torch.nn as nn

In [17]:
def doubleConv(in_channels, out_channels, kernels=3, pad=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernels, padding=pad),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=kernels, padding=pad),
        nn.ReLU(inplace=True)
    )

In [18]:
class UNet(nn.Module):
    
    def __init__(self, input_channels, n_classes):
        super().__init__()
        
        self.downConv1 = doubleConv(input_channels, 64)
        self.downConv2 = doubleConv(64, 128)
        self.downConv3 = doubleConv(128, 256)
        self.downConv4 = doubleConv(256, 512)
        
        self.maxPoolLayer = nn.MaxPool2d(2) # kernel size will reduce both dimension by half
        self.upsampleLayer = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
        self.upConv3 = doubleConv(256+512, 256)
        self.upConv2 = doubleConv(128+256, 128)
        self.upConv1 = doubleConv(64+128, 64)
        
        self.finalConv = doubleConv(64, n_classes, kernels=1, pad=0) # final 1x1 kernel conv layer with depth=n_classes
        
    def forward(self, x):
        
        conv_r1 = self.downConv1(x)      #Conv Layer 1,2
        x = self.maxPoolLayer(conv_r1)   #Pool Layer 3
        
        conv_r2 = self.downConv2(x)      #Conv Layer 4,5
        x = self.maxPoolLayer(conv_r2)   #Pool Layer 6
        
        conv_r3 = self.downConv3(x)      #Conv Layer 7,8
        x = self.maxPoolLayer(conv_r3)   #Conv Layer 9
        
        x = self.downConv4(x)            #Conv Layer 10,11
        
        x = self.upsampleLayer(x)        #Upsampling Layer 12
        
        x = torch.cat([conv_r3, x], dim=1) # concatenate
        x = self.upConv3(x)              # Conv Layer 13,14
        
        x = self.upsampleLayer(x)        # Upsampling Layer 15
        
        x = torch.cat([conv_r2, x], dim=1) 
        x = self.upConv2(x)              # Conv Layer 16,17
        
        x = self.upsampleLayer(x)        # Upsampling Layer 18
        
        x = torch.cat([conv_r1, x], dim=1) 
        x = self.upConv1(x)              # Conv Layer 19,20
        
        out = self.finalConv(x)
        return out

In [21]:
class CityscapesUnet(nn.Module): #Wrapper Unet class for dataset classes & softmax
    
    def __init__(self, input_channels):
        super.__init__()
        
        n_classes=34          #Fixed for cityscapes dataset
        self.Unet = UNet(input_channels, 34)
        self.softmaxLayer = nn.Softmax(dim=1)  # softmax on depth of tensor
        
    def forward(self, X):
        
        X = self.softmaxLayer(self.Unet(X))
        return X

In [5]:
test_tensor = torch.randn((1, 3, 128, 256))
test_tensor.shape

torch.Size([1, 3, 128, 256])

In [19]:
# input channels, num_classes
learner = UNet(3, 34)
# learner = learner.to(device)

In [5]:
learner.parameters()

<generator object Module.parameters at 0x0000016A50736148>

In [20]:
op = learner(test_tensor)
op.shape

torch.Size([1, 34, 128, 256])

In [9]:
test_tensor2 = torch.randn((1, 3, 4, 4))*1087
test_tensor2

tensor([[[[  4.2098,   9.6302,   4.9260,  11.6340],
          [ -5.3136,  13.0683,  -7.0615,  -9.5064],
          [  4.6390,   9.8135,   2.2820,  -5.2707],
          [ -5.9274,  -2.8761,   1.1233,  -1.4763]],

         [[  9.5644,  10.8344,  -2.6058,  14.3671],
          [-11.2367,  -3.7074,   2.1452,   4.0693],
          [-11.9275,  32.0497,   2.4363,  12.6367],
          [ -7.2436,   5.8231, -15.9325,   6.6507]],

         [[ -5.8009,  -0.8185,  -3.3179,   6.3987],
          [  7.2884,  -0.1063,   9.9496,  11.0945],
          [ -7.6305,   4.9629, -10.1745, -10.4062],
          [  4.2417,  -4.9165,   0.9865,  -0.9399]]]])

In [12]:
m = nn.Softmax(dim=1)
op = m(test_tensor2)
op

tensor([[[[4.7041e-03, 2.3073e-01, 9.9920e-01, 6.1031e-02],
          [3.3650e-06, 1.0000e+00, 4.0925e-08, 1.1291e-09],
          [1.0000e+00, 2.2026e-10, 4.6150e-01, 1.6708e-08],
          [3.8335e-05, 1.6669e-04, 5.3415e-01, 2.9521e-04]],

         [[9.9530e-01, 7.6926e-01, 5.3532e-04, 9.3864e-01],
          [9.0084e-09, 5.1807e-08, 4.0779e-04, 8.8841e-04],
          [6.3860e-08, 1.0000e+00, 5.3850e-01, 1.0000e+00],
          [1.0280e-05, 9.9981e-01, 2.0913e-08, 9.9920e-01]],

         [[2.1128e-07, 6.6878e-06, 2.6266e-04, 3.2500e-04],
          [1.0000e+00, 1.8981e-06, 9.9959e-01, 9.9911e-01],
          [4.6925e-06, 1.7231e-12, 1.7964e-06, 9.8306e-11],
          [9.9995e-01, 2.1665e-05, 4.6585e-01, 5.0476e-04]]]])

In [16]:
_, op_idx = torch.max(op, dim=1)
op_idx

tensor([[[1, 1, 0, 1],
         [2, 0, 2, 2],
         [0, 1, 1, 1],
         [2, 1, 0, 1]]])