In [None]:
#  Importing!

from __future__ import print_function  # __future__ imports must be the first statement in the cell

#%matplotlib inline
import argparse
import os
import random

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from IPython.display import HTML
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader


In [None]:
# custom weights initialization
def weights_init(m):
    """Re-initialise convolution and batch-norm parameters of ``m`` in place.

    Intended to be passed to ``nn.Module.apply`` so it is called once per
    sub-module. Dispatch is by class name:

    * name contains ``'Conv'``      -> weight ~ Normal(mean=0.0, std=0.01)
    * name contains ``'BatchNorm'`` -> weight ~ Normal(mean=1.0, std=0.01),
      bias set to 0
    * anything else                 -> left untouched

    Parameters that are absent or ``None`` are skipped, so the function is
    safe on ``BatchNorm*(affine=False)`` layers and on container modules whose
    class name happens to contain ``'Conv'`` / ``'BatchNorm'`` but which own no
    ``weight`` / ``bias`` themselves.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    # nn.Module raises AttributeError for unknown attributes, so getattr with a
    # default cleanly covers both "attribute missing" and "attribute is None".
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.01)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.01)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

Defining Depthwise Separable (DS) Convolution And Dilated Residual Dense Block (DRDB)

In [None]:
class DSconv(nn.Module):
    """Depthwise-separable 3D convolution block.

    Layout: depthwise ``Conv3d`` (``groups=in_channels``) -> ``BatchNorm3d`` ->
    ``ReLU`` -> pointwise 1x1x1 ``Conv3d`` -> ``BatchNorm3d`` -> ``ReLU``.

    Only the depthwise convolution carries the spatial hyper-parameters
    (kernel, stride, padding, dilation). The pointwise convolution just mixes
    channels, so it always runs with stride 1 and no padding; applying the
    stride/padding a second time would subsample twice and zero-pad a 1x1x1
    kernel.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the pointwise convolution.
        DSkernel: kernel size of the depthwise convolution.
        DSstride: stride of the depthwise convolution.
        DSpadding: zero-padding of the depthwise convolution.
        dilation: dilation of the depthwise convolution (default 1).
    """

    def __init__(self, in_channels, out_channels, DSkernel, DSstride, DSpadding, dilation=1):
        super().__init__()

        self.dsconv = nn.Sequential(
            # Depthwise: one spatial filter per input channel.
            nn.Conv3d(
                in_channels,
                in_channels,
                DSkernel,
                stride=DSstride,
                padding=DSpadding,
                dilation=dilation,
                groups=in_channels,
            ),
            nn.BatchNorm3d(in_channels),
            nn.ReLU(inplace=True),
            # Pointwise: channel mixing only, spatial size is left unchanged.
            nn.Conv3d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputx):
        """Apply the depthwise-separable block to ``inputx`` (N, C, D, H, W)."""
        return self.dsconv(inputx)


class DRDB(nn.Module):
    """Dilated residual dense block built from four ``DSconv`` branches.

    Each branch ``r1``..``r4`` uses dilation 1..4 and emits ``in_channels``
    feature maps. Its output is concatenated (channel axis) with everything
    the branch received, so branch ``k`` sees ``k * in_channels`` channels.
    A final 1x1x1 ``DSconv`` (``g0``) fuses the ``5 * in_channels`` stacked
    maps back to ``in_channels`` and the block input is added as a residual.

    Because of the concatenations and the residual sum, every branch must
    preserve the spatial size. Therefore:

    * ``DSstride`` must be 1 (a ``ValueError`` is raised otherwise);
    * ``DSpadding`` is interpreted as the padding for the *undilated* kernel
      and is multiplied by each branch's dilation. With an odd kernel and
      ``DSpadding == (DSkernel - 1) // 2`` every branch is size-preserving.

    Args:
        in_channels: channels of the input (and of the output).
        out_channels: accepted for signature compatibility only. The residual
            sum fixes the output width to ``in_channels``, so this value is
            not used.
        DSkernel: kernel size of the dilated depthwise convolutions.
        DSstride: stride of the branches; must be 1.
        DSpadding: padding for dilation 1 (int or per-dimension sequence).

    Raises:
        ValueError: if ``DSstride`` is not 1 in every dimension.
    """

    def __init__(self, in_channels, out_channels, DSkernel, DSstride, DSpadding):
        super().__init__()

        strides = DSstride if isinstance(DSstride, (tuple, list)) else (DSstride,)
        if any(s != 1 for s in strides):
            raise ValueError(
                "DRDB requires DSstride == 1 so that branch outputs can be "
                "concatenated with their inputs, got {!r}".format(DSstride)
            )

        growth = in_channels  # feature maps added by each dense branch

        self.r1 = DSconv(1 * in_channels, growth, DSkernel, DSstride,
                         self._scaled_padding(DSpadding, 1), dilation=1)
        self.r2 = DSconv(2 * in_channels, growth, DSkernel, DSstride,
                         self._scaled_padding(DSpadding, 2), dilation=2)
        self.r3 = DSconv(3 * in_channels, growth, DSkernel, DSstride,
                         self._scaled_padding(DSpadding, 3), dilation=3)
        self.r4 = DSconv(4 * in_channels, growth, DSkernel, DSstride,
                         self._scaled_padding(DSpadding, 4), dilation=4)
        # 1x1x1 fusion: no padding and no dilation, otherwise the spatial size
        # would change and the residual addition in forward() would fail.
        self.g0 = DSconv(5 * in_channels, in_channels, 1, 1, 0, dilation=1)

    @staticmethod
    def _scaled_padding(padding, dilation):
        """Return ``padding`` multiplied by ``dilation`` (int or per-dim sequence)."""
        if isinstance(padding, (tuple, list)):
            return tuple(p * dilation for p in padding)
        return padding * dilation

    def forward(self, x):
        """Run the dense dilated branches and add the residual connection."""
        x1 = self.r1(x)
        x11 = torch.cat([x1, x], dim=1)      # 2 * in_channels
        x2 = self.r2(x11)
        x22 = torch.cat([x2, x11], dim=1)    # 3 * in_channels
        x3 = self.r3(x22)
        x33 = torch.cat([x3, x22], dim=1)    # 4 * in_channels
        x4 = self.r4(x33)
        x44 = torch.cat([x4, x33], dim=1)    # 5 * in_channels
        xg0 = self.g0(x44)                   # fused back to in_channels

        return xg0 + x


Model

In [None]:
class PLSnet(nn.Module):
    """Encoder-decoder 3D segmentation network built from ``DSconv`` and ``DRDB``.

    Encoder (three resolution levels). At each level a ``DSconv`` with stride
    ``DSstride`` reduces the resolution, a trilinearly resized copy of the
    input image is concatenated on the channel axis, and a stack of ``DRDB``
    blocks (1, 2 and 4 blocks for levels 1, 2 and 3) refines the result.

    Decoder. A ``DSconv`` at the coarsest level, then twice: resize to the
    next finer encoder level, concatenate that level's encoder features, and
    apply a ``DSconv``. The result is resized to the input resolution, mapped
    to ``out_channels`` by a 1x1x1 convolution and passed through a softmax
    over the channel axis.

    All layers are created once in ``__init__`` so their parameters are
    registered with the module; ``forward`` only applies them.

    NOTE(review): the original author's comment said the paper's figure shows
    "DRDBx2" and was unsure of its meaning. It is implemented here as that
    many independent DRDB blocks applied in sequence - TODO confirm against
    the paper.
    NOTE(review): ``hidden_channels`` and the 1x / 2x / 4x width schedule are
    implementation choices; the original code did not specify any feature
    widths - TODO confirm against the paper.

    Args:
        in_channels: channels of the input image.
        out_channels: number of output classes (channels of the softmax map).
        DSkernel: kernel size of every ``DSconv`` / ``DRDB`` in the network.
        DSstride: stride of the three encoder ``DSconv`` layers (the
            down-sampling factor per level). All other layers use stride 1.
        DSpadding: padding passed to every ``DSconv`` / ``DRDB``.
        hidden_channels: width of the first encoder level (default 16);
            levels 2 and 3 use twice and four times this width.
    """

    def __init__(self, in_channels, out_channels, DSkernel, DSstride, DSpadding,
                 hidden_channels=16):
        super().__init__()

        # Channels produced by the encoder convolutions at each level ...
        w1, w2, w3 = hidden_channels, 2 * hidden_channels, 4 * hidden_channels
        # ... and after concatenating the resized input image.
        c1, c2, c3 = w1 + in_channels, w2 + in_channels, w3 + in_channels

        # ---- encoder -------------------------------------------------------
        self.enc1 = DSconv(in_channels, w1, DSkernel, DSstride, DSpadding, dilation=1)
        self.drdb1 = DRDB(c1, c1, DSkernel, 1, DSpadding)

        self.enc2 = DSconv(c1, w2, DSkernel, DSstride, DSpadding, dilation=1)
        self.drdb2 = nn.Sequential(
            *[DRDB(c2, c2, DSkernel, 1, DSpadding) for _ in range(2)]
        )

        self.enc3 = DSconv(c2, w3, DSkernel, DSstride, DSpadding, dilation=1)
        self.drdb3 = nn.Sequential(
            *[DRDB(c3, c3, DSkernel, 1, DSpadding) for _ in range(4)]
        )

        # ---- decoder -------------------------------------------------------
        self.bottleneck = DSconv(c3, c2, DSkernel, 1, DSpadding, dilation=1)
        self.dec2 = DSconv(2 * c2, c1, DSkernel, 1, DSpadding, dilation=1)
        self.dec1 = DSconv(2 * c1, c1, DSkernel, 1, DSpadding, dilation=1)

        # 1x1x1 classifier: stride 1 and no padding so the map keeps its size.
        self.oneconv = nn.Conv3d(c1, out_channels, kernel_size=1)
        # Softmax over the channel (class) axis of an (N, C, D, H, W) tensor.
        self.softmaxAF = nn.Softmax(dim=1)

    @staticmethod
    def _resize_like(tensor, reference):
        """Trilinearly resize ``tensor`` to the spatial size of ``reference``.

        Resizing to an explicit size (instead of a fixed scale factor) keeps
        the tensors concatenable even when a spatial dimension is odd.
        """
        return nn.functional.interpolate(
            tensor, size=reference.shape[2:], mode='trilinear', align_corners=False
        )

    def forward(self, InputImage):
        """Return per-voxel class probabilities for ``InputImage`` (N, C, D, H, W).

        The output has ``out_channels`` channels and the same spatial size as
        the input. Values are softmax probabilities, not logits.
        """
        # resolution = 1 in encoder
        g1 = self.enc1(InputImage)
        image = self._resize_like(InputImage, g1)
        g1 = self.drdb1(torch.cat([g1, image], dim=1))

        # resolution = 2 in encoder
        g2 = self.enc2(g1)
        image = self._resize_like(image, g2)
        g2 = self.drdb2(torch.cat([g2, image], dim=1))

        # resolution = 3 in encoder
        g3 = self.enc3(g2)
        image = self._resize_like(image, g3)
        g3 = self.drdb3(torch.cat([g3, image], dim=1))
        g3 = self.bottleneck(g3)

        # resolution = 2 in decoder
        g3d = self._resize_like(g3, g2)
        g2d = self.dec2(torch.cat([g3d, g2], dim=1))

        # resolution = 1 in decoder
        g2d = self._resize_like(g2d, g1)
        g1d = self.dec1(torch.cat([g2d, g1], dim=1))

        # resolution = 0 in decoder (back at the input resolution)
        g1d = self._resize_like(g1d, InputImage)
        g0d = self.oneconv(g1d)
        return self.softmaxAF(g0d)