Enhanced version of Unet

In [None]:
'''
MONAI's UNet supports residual units and exposes its hyperparameters.
The UNet places some constraints on the input sizes.

For our use case we define the following UNet structure:

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH).to(device)


-->16x80x80                                       up
      |                                         |-->
      -->32x40x40                             up
          |                                |-->
          -->64x20x20                     up
                |                      |-->
                -->128x10x10          up
                      |               |
                      -->256x5x5 --->

Residual Unit:  (Conv2D + Stride) + (Instance Norm) + (PReLU) +
                (Conv2D) + (Instance Norm) + (PReLU)
UpSample:       (ConvTrans2D + Stride) + (Concat) +
                (Conv2D) + (Instance Norm) + (PReLU)

'''

In [None]:
!python -c "import monai" || pip install -q "monai-weekly"

In [2]:
import monai
from monai.networks.nets import UNet
import torch
import torch.nn as nn
'''
Define the structure with 0 residual units (no impact on the input size):
first down layer - intermediate skip connection - final up layer
'''
# Three channel levels give two strided levels plus a stride-1 bottom layer;
# the printed structure shows Conv3d with stride 2, then stride 3, then stride 1.
unet_model_0 = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(8, 16, 32),
    strides=(2, 3),
    kernel_size=3,
    up_kernel_size=3,
    num_res_units=0,  # layers print as plain Convolution blocks, not ResidualUnit
)
# Display the module tree as the cell output.
unet_model_0

UNet(
  (model): Sequential(
    (0): Convolution(
      (conv): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (adn): ADN(
        (N): InstanceNorm3d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (D): Dropout(p=0.0, inplace=False)
        (A): PReLU(num_parameters=1)
      )
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): Convolution(
          (conv): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(3, 3, 3), padding=(1, 1, 1))
          (adn): ADN(
            (N): InstanceNorm3d(16, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (1): SkipConnection(
          (submodule): Convolution(
            (conv): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (adn): ADN(
              (N): InstanceNorm3d(32, eps=1e-05, momentum=0.1

In [8]:
'''
For a deeper UNet, the intermediate block is expanded.
Two kinds of modules are needed, Convolutions and Skip Connections,
each with four layers: Activation (PReLU), Dropout, Norm (InstanceNorm3d),
and Convolution layers (Conv and ConvTranspose).

For Convolution layers, the output size depends on stride, kernel_size, dilation,
padding.
For our UNet, dilation = 1, and padding = (kernel_size -1)/2
The output size of Conv. is math.floor((input_size + stride -1) / stride)

Output size for ConvTranspose layer is input_size * stride

In UNet, SkipConnection is used via
nn.Sequential(down, SkipConnection(subblock), up), and its forward function
calls torch.cat([x, self.submodule(x)], dim=1)

Constraints of UNet
'''

In [3]:
''' Conv Layer '''
import math
def get_conv_output_size(input_tensor, stride):
    """Return the spatial output size of a strided convolution, per dimension.

    Uses the rule floor((n + stride - 1) / stride), which holds for
    dilation 1 and padding (kernel_size - 1) / 2.

    Args:
        input_tensor: object whose ``shape`` starts with two leading
            dimensions (batch, channels) followed by the spatial dimensions.
        stride: stride applied along every spatial dimension.

    Returns:
        A list with one output size per spatial dimension.
    """
    # Drop the two leading dimensions; only the remaining ones are strided.
    spatial_dims = list(input_tensor.shape)[2:]
    return [math.floor((dim + stride - 1) / stride) for dim in spatial_dims]
# Spatial dims (1, 15, 29) with stride 3 -> floor((n + 2) / 3) = (1, 5, 10).
stride_value = 3
input_tensor = torch.rand([1, 3, 1, 15, 29])
get_conv_output_size(input_tensor, stride_value)

[1, 5, 10]

In [4]:
# Cross-check the formula against a real Conv3d with kernel_size=3, padding=1
# (i.e. padding = (kernel_size - 1) / 2); only the spatial dims are displayed.
output = nn.Conv3d(in_channels=3, out_channels=1, stride=stride_value,
                   kernel_size=3, padding=1)(input_tensor)
output.shape[2:]

torch.Size([1, 5, 10])

In [5]:
''' ConvTranspose layer '''
# Expected ConvTranspose output: each spatial dim of input_tensor times the stride.
stride_value = 3
[i* stride_value for i in input_tensor.shape[2:]]

[3, 45, 87]

In [6]:
# With padding = (kernel_size - 1) / 2 and output_padding = stride - 1, the
# transposed conv yields exactly input_size * stride (dilation = 1):
# (n - 1) * s - 2 * p + (k - 1) + output_padding + 1 = n * s.
output = nn.ConvTranspose3d(
    in_channels=3,
    out_channels=1,
    stride=stride_value,
    kernel_size=3,
    padding=1,
    output_padding=stride_value -1,
)(input_tensor)
output.shape[2:]

torch.Size([3, 45, 87])

In [7]:
''' Normalization layer '''
# List the (name, factory) pairs registered in MONAI's Norm layer factory.
list(monai.networks.layers.factories.Norm)

[('INSTANCE',
  <function monai.networks.layers.factories.instance_factory(dim: 'int') -> 'type[nn.InstanceNorm1d | nn.InstanceNorm2d | nn.InstanceNorm3d]'>),
 ('BATCH',
  <function monai.networks.layers.factories.batch_factory(dim: 'int') -> 'type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d]'>),
 ('INSTANCE_NVFUSER',
  <function monai.networks.layers.factories.instance_nvfuser_factory(dim)>),
 ('GROUP',
  <function monai.networks.layers.factories.LayerFactory.add_factory_class.<locals>.<lambda>(x=None)>),
 ('LAYER',
  <function monai.networks.layers.factories.LayerFactory.add_factory_class.<locals>.<lambda>(x=None)>),
 ('LOCALRESPONSE',
  <function monai.networks.layers.factories.LayerFactory.add_factory_class.<locals>.<lambda>(x=None)>),
 ('SYNCBATCH',
  <function monai.networks.layers.factories.LayerFactory.add_factory_class.<locals>.<lambda>(x=None)>)]

In [35]:
''' Batch normalization '''
# BatchNorm3d over 3 channels, applied to two inputs that each provide two
# values per channel: (batch 1, depth 2) and (batch 2, depth 1).
batch = nn.BatchNorm3d(num_features=3)
for size in ([1, 3, 2, 1, 1], [2, 3, 1, 1, 1]):
    output = batch(torch.randn(*size))

In [18]:
''' Instance normalization '''
# InstanceNorm3d over 3 channels, applied to two single-sample inputs that
# each have two spatial elements per channel (along depth, then along height).
instance = nn.InstanceNorm3d(num_features=3)
for size in ([1, 3, 2, 1, 1], [1, 3, 1, 2, 1]):
    output = instance(torch.randn(*size))

In [None]:
''' Skip Connection '''

In [None]:
'''
Constraints of UNet

1 down layer - 1 mode skip connection - 1 up layer
- if len(channels) ==2, strides are single values
- if using batch normalization B>1
- if using local response normalization, no constraint
- if using instance normalization, for d = max(H, W, D), then
math.floor((d +s -1) / s ) >= 2, which means d >= s +1
'''

In [20]:
''' len(channels) = 2, batch norm '''
# Batch size 2 satisfies the B > 1 constraint noted for batch normalization,
# even though every spatial dim is 1.
# Spatial size: 1 -> floor((1 + 2) / 3) = 1 going down, then 1 * 3 = 3 going up.
network_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(8, 16),
    strides=(3,),
    kernel_size=3,
    up_kernel_size=3,
    num_res_units=0,
    norm=('batch')
)
input_tensor = torch.rand([2, 1, 1, 1, 1])
network_model(input_tensor).shape

torch.Size([2, 3, 3, 3, 3])

In [10]:
''' len(channels) = 2, localresponse '''
# Local response norm: a single voxel with batch size 1 is accepted here.
# Spatial size: 1 -> 1 going down (stride 3), then 1 * 3 = 3 going up.
network_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(8, 16),
    strides=(3,),
    kernel_size=1,
    up_kernel_size=1,
    num_res_units=1,
    norm=('localresponse', {'size': 1})
)
input_tensor = torch.rand([1, 1, 1, 1, 1])
network_model(input_tensor).shape

torch.Size([1, 3, 3, 3, 3])

In [11]:
network_model

UNet(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv3d(1, 8, kernel_size=(1, 1, 1), stride=(3, 3, 3))
          (adn): ADN(
            (N): LocalResponseNorm(1, alpha=0.0001, beta=0.75, k=1.0)
            (D): Dropout(p=0.0, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv3d(1, 8, kernel_size=(1, 1, 1), stride=(3, 3, 3))
    )
    (1): SkipConnection(
      (submodule): ResidualUnit(
        (conv): Sequential(
          (unit0): Convolution(
            (conv): Conv3d(8, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
            (adn): ADN(
              (N): LocalResponseNorm(1, alpha=0.0001, beta=0.75, k=1.0)
              (D): Dropout(p=0.0, inplace=False)
              (A): PReLU(num_parameters=1)
            )
          )
        )
        (residual): Conv3d(8, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      )
    )
    (2): Sequential(
     

In [21]:
''' len(channels) = 2, instance norm '''
# Instance norm: the largest spatial dim is 4 = stride + 1, so the stride-3
# down conv leaves floor((4 + 2) / 3) = 2 elements along it.
# Spatial size: (4, 1, 1) -> (2, 1, 1) going down, then (6, 3, 3) going up.
network_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(8, 16),
    strides=(3,),
    kernel_size=3,
    up_kernel_size=5,
    num_res_units=2,
    norm='instance'
)
input_tensor = torch.rand([1, 1, 4, 1, 1])
network_model(input_tensor).shape

torch.Size([1, 3, 6, 3, 3])

In [None]:
'''
Constraints of UNet - continued

- if len(channels) >2, for input size [B, C, H, W, D]
- if using instance normalization
  size = math.floor((v + s[0] - 1) / s[0])
'''

In [23]:
''' strides=(3,5), batch norm, batch_size >1 '''
# Spatial size: (13, 14, 15) -> stride 3 -> (5, 5, 5) -> stride 5 -> (1, 1, 1),
# then back up: 1 * 5 = 5 and 5 * 3 = 15 along every dim.
network_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(8, 16, 32),
    strides=(3, 5),
    kernel_size=3,
    up_kernel_size=3,
    num_res_units=0,
    norm='batch'
)
input_tensor = torch.rand([2, 1, 13, 14, 15])
network_model(input_tensor).shape

torch.Size([2, 3, 15, 15, 15])

In [25]:
'''
strides=(3, 2, 4), localresponse,
math.floor((v + 2) / 3) should be 8*k, v in [22, 23, 24]
'''
# Spatial size: (22, 23, 24) -> stride 3 -> (8, 8, 8) -> stride 2 -> (4, 4, 4)
# -> stride 4 -> (1, 1, 1), then back up: 1 * 4 = 4, 4 * 2 = 8, 8 * 3 = 24.
# 8 = 2 * 4 is the product of the remaining strides, so each upsampled size
# equals the size of the skip tensor it is concatenated with.
network_model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(8, 16, 32, 16),
    strides=(3, 2, 4),
    kernel_size=1,
    up_kernel_size=3,
    num_res_units=10,
    norm=('localresponse', {'size': 1})
)
input_tensor = torch.rand([1, 1, 22, 23, 24])
network_model(input_tensor).shape

torch.Size([1, 3, 24, 24, 24])