In [3]:
import torch
import torch.nn as nn

In [4]:
class ConvBlock(nn.Module):
  """Two (3x3 Conv2d -> BatchNorm2d -> ReLU) stages with explicit attributes.

  Maps `in_channels` -> `out_channels` and keeps H and W unchanged
  (kernel_size=3 with padding=1).

  NOTE: a later cell redefines ConvBlock with an nn.Sequential; on a
  top-to-bottom run that later definition shadows this one.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    # Registration order (conv1, bn1, conv2, bn2, relu) fixes parameter order,
    # state_dict keys, and the sequence of seeded weight initialisations.
    # bias=False: the BatchNorm2d right after each conv has its own learnable shift.
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
    self.bn2 = nn.BatchNorm2d(out_channels)
    # A single parameter-free ReLU instance is reused by both stages.
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
      x = self.relu(norm(conv(x)))
    return x

In [5]:
# Seed torch's global RNGs (CPU and all CUDA devices) so weight initialisation
# and the torch.randint dummy inputs are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
conv = ConvBlock(in_channels=1, out_channels=64).to(device)

In [6]:
# Dummy single-channel 256x256 input. torch.randint's `high` is exclusive, so
# high=256 makes 255 reachable (presumably the intended 8-bit range 0..255).
# Allocating directly on `device` avoids a separate host-to-device copy.
# NOTE: `input` shadows the Python builtin; the name is kept because later cells use it.
input = torch.randint(low=0, high=256, size=(1, 1, 256, 256), dtype=torch.float32, device=device)
input.shape

torch.Size([1, 1, 256, 256])

In [7]:
# Shape probe only: no_grad skips building the autograd graph, lowering peak memory.
# NOTE: `conv` was never switched to eval(), so its BatchNorm running stats still update.
with torch.no_grad():
  conv_out_shape = conv(input).shape
conv_out_shape

torch.Size([1, 64, 256, 256])

In [8]:
class ConvBlock(nn.Module):
  """Double convolution: (3x3 Conv2d -> BatchNorm2d -> ReLU) applied twice.

  Maps `in_channels` -> `out_channels`; padding=1 keeps H and W unchanged.
  Both convs omit the bias because the following BatchNorm2d provides a
  learnable shift.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    stages = []
    # First stage reads `in_channels`, second reads the first stage's output.
    for stage_in in (in_channels, out_channels):
      stages.extend([
          nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
          nn.BatchNorm2d(out_channels),
          nn.ReLU(inplace=True),
      ])
    # Sequential indices 0..5 are conv, bn, relu, conv, bn, relu, so parameter
    # keys are 'conv_block.0.weight', 'conv_block.1.*', 'conv_block.3.weight', 'conv_block.4.*'.
    self.conv_block = nn.Sequential(*stages)

  def forward(self, x):
    out = self.conv_block(x)
    return out

In [9]:
class Encoder(nn.Module):
  """U-Net down step: 2x2 max-pool (halves H and W), then a ConvBlock.

  Channels go `in_channels` -> `out_channels`; spatial size becomes
  floor(H/2) x floor(W/2).
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    downsample = nn.MaxPool2d(kernel_size=2)
    double_conv = ConvBlock(in_channels=in_channels, out_channels=out_channels)
    # Both live in one Sequential named `encoder` (parameter keys 'encoder.1.*').
    self.encoder = nn.Sequential(downsample, double_conv)

  def forward(self, x):
    pooled_features = self.encoder(x)
    return pooled_features

In [10]:
# One down-sampling stage: 64 -> 128 channels, H and W halved.
enc = Encoder(64, 128)
enc = enc.to(device)

In [11]:
# Dummy 64-channel 128x128 feature map for Encoder(in_channels=64).
# torch.randint's `high` is exclusive, so high=256 makes 255 reachable; the
# tensor is allocated directly on `device` instead of copied there afterwards.
# NOTE: `input` shadows the Python builtin; the name is kept because later cells use it.
input = torch.randint(low=0, high=256, size=(1, 64, 128, 128), dtype=torch.float32, device=device)
input.shape

torch.Size([1, 64, 128, 128])

In [12]:
# Shape probe only: no_grad skips building the autograd graph, lowering peak memory.
# NOTE: `enc` was never switched to eval(), so its BatchNorm running stats still update.
with torch.no_grad():
  enc_out_shape = enc(input).shape
enc_out_shape

torch.Size([1, 128, 64, 64])

In [44]:
class Decoder(nn.Module):
  """U-Net up step: upsample x1 by 2, concatenate skip tensor x2, apply ConvBlock.

  Args:
    in_channels: channels of x1 (the deeper feature map). The ConvBlock after
      the concatenation also expects `in_channels` channels, so x2 must carry
      `in_channels - out_channels` channels.
    out_channels: channels produced by the transposed conv and by the ConvBlock.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    # kernel 4 / stride 2 / padding 1 doubles H and W exactly:
    # (H - 1) * 2 - 2 * 1 + (4 - 1) + 1 = 2H.
    self.trans_conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=4, stride=2, padding=1)
    self.conv_block = ConvBlock(in_channels=in_channels, out_channels=out_channels)

  def forward(self, x1, x2):
    """Upsample x1, align it spatially with skip x2, concat on channels, convolve.

    Args:
      x1: deeper map, assumed (N, in_channels, H, W).
      x2: skip map from the encoder at the matching resolution.

    Returns:
      Tensor with `out_channels` channels and x2's spatial size.
    """
    x = self.trans_conv(x1)
    # If an encoder MaxPool2d floored an odd H or W, the upsampled map is smaller
    # than the skip tensor and torch.cat would fail; zero-pad it (split evenly
    # between the two sides) so both share x2's spatial size.
    diff_h = x2.size(2) - x.size(2)
    diff_w = x2.size(3) - x.size(3)
    if diff_h or diff_w:
      x = nn.functional.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                                diff_h // 2, diff_h - diff_h // 2])
    # Skip features first, upsampled features second: the ConvBlock weights
    # depend on this channel order.
    x = torch.cat([x2, x], dim=1)
    return self.conv_block(x)

In [24]:
# One up-sampling stage: 128-channel deep map + 64-channel skip -> 64 channels.
dec = Decoder(128, 64)
dec = dec.to(device)

In [17]:
# Small-integer dummy tensors: x1 stands in for the deeper map (128 ch @ 128x128),
# x2 for the skip connection (64 ch @ 256x256).
x1 = torch.randint(0, 8, (1, 128, 128, 128), dtype=torch.float32, device=device)
x2 = torch.randint(0, 8, (1, 64, 256, 256), dtype=torch.float32, device=device)

In [25]:
# Shape probe only: no_grad skips building the autograd graph, lowering peak memory.
# NOTE: `dec` was never switched to eval(), so its BatchNorm running stats still update.
with torch.no_grad():
  dec_out_shape = dec(x1, x2).shape
dec_out_shape

torch.Size([1, 128, 128, 128]) torch.Size([1, 64, 256, 256])
After transpose:  torch.Size([1, 64, 256, 256])
after concat:  torch.Size([1, 128, 256, 256])


torch.Size([1, 64, 256, 256])

In [45]:
class UNET(nn.Module):
  """Four-level U-Net: ConvBlock stem, 4 Encoders, 4 Decoders, 1x1 conv head.

  Channel widths go 64 -> 128 -> 256 -> 512 -> 1024 on the way down and back to
  64 on the way up. The 1x1 head maps 64 channels to `out_channels`. When H and W
  are divisible by 16 (four 2x poolings), the output has the input's spatial size.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    c1, c2, c3, c4, c5 = 64, 128, 256, 512, 1024
    # Assignment order (stem, enc_1..4, dec_1..4, head) determines parameter
    # order and the sequence of seeded weight initialisations.
    self.in_conv = ConvBlock(in_channels=in_channels, out_channels=c1)
    self.enc_1 = Encoder(in_channels=c1, out_channels=c2)
    self.enc_2 = Encoder(in_channels=c2, out_channels=c3)
    self.enc_3 = Encoder(in_channels=c3, out_channels=c4)
    self.enc_4 = Encoder(in_channels=c4, out_channels=c5)

    self.dec_1 = Decoder(in_channels=c5, out_channels=c4)
    self.dec_2 = Decoder(in_channels=c4, out_channels=c3)
    self.dec_3 = Decoder(in_channels=c3, out_channels=c2)
    self.dec_4 = Decoder(in_channels=c2, out_channels=c1)

    self.out_conv = nn.Conv2d(in_channels=c1, out_channels=out_channels, kernel_size=1)

  def forward(self, x):
    # Collect every encoder output; the deepest one seeds the decoder path and
    # the rest are consumed as skip connections in reverse (deepest-first) order.
    features = [self.in_conv(x)]
    for down in (self.enc_1, self.enc_2, self.enc_3, self.enc_4):
      features.append(down(features[-1]))

    out = features.pop()
    for up in (self.dec_1, self.dec_2, self.dec_3, self.dec_4):
      out = up(out, features.pop())
    return self.out_conv(out)



In [46]:
# Full network: single-channel input, two output channels from the 1x1 head.
model = UNET(1, 2)
model = model.to(device)

In [28]:
# Dummy single-channel 256x256 input for the full UNET. torch.randint's `high`
# is exclusive, so high=256 makes 255 reachable (presumably the 8-bit pixel range);
# the tensor is allocated directly on `device` instead of copied there afterwards.
# NOTE: `input` shadows the Python builtin; the name is kept because later cells use it.
input = torch.randint(low=0, high=256, size=(1, 1, 256, 256), dtype=torch.float32, device=device)
input.shape

torch.Size([1, 1, 256, 256])

In [42]:
# Shape probe only: no_grad skips building the autograd graph for the whole
# U-Net, which noticeably lowers peak memory at 256x256.
# NOTE: `model` was never switched to eval(), so its BatchNorm running stats still update.
with torch.no_grad():
  model_out_shape = model(input).shape
model_out_shape

torch.Size([1, 1024, 16, 16]) torch.Size([1, 512, 32, 32])
After transpose:  torch.Size([1, 512, 32, 32])
after concat:  torch.Size([1, 1024, 32, 32])
torch.Size([1, 512, 32, 32]) torch.Size([1, 256, 64, 64])
After transpose:  torch.Size([1, 256, 64, 64])
after concat:  torch.Size([1, 512, 64, 64])
torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 128, 128])
After transpose:  torch.Size([1, 128, 128, 128])
after concat:  torch.Size([1, 256, 128, 128])
torch.Size([1, 128, 128, 128]) torch.Size([1, 64, 256, 256])
After transpose:  torch.Size([1, 64, 256, 256])
after concat:  torch.Size([1, 128, 256, 256])


torch.Size([1, 2, 256, 256])

In [47]:
from torchsummary import summary
# input_size is (C, H, W), i.e. one sample's shape without the batch dimension.
summary(model, input_size=(1, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]             576
       BatchNorm2d-2         [-1, 64, 256, 256]             128
              ReLU-3         [-1, 64, 256, 256]               0
            Conv2d-4         [-1, 64, 256, 256]          36,864
       BatchNorm2d-5         [-1, 64, 256, 256]             128
              ReLU-6         [-1, 64, 256, 256]               0
         ConvBlock-7         [-1, 64, 256, 256]               0
         MaxPool2d-8         [-1, 64, 128, 128]               0
            Conv2d-9        [-1, 128, 128, 128]          73,728
      BatchNorm2d-10        [-1, 128, 128, 128]             256
             ReLU-11        [-1, 128, 128, 128]               0
           Conv2d-12        [-1, 128, 128, 128]         147,456
      BatchNorm2d-13        [-1, 128, 128, 128]             256
             ReLU-14        [-1, 128, 1