In [None]:
import torch
import torch.nn as nn

In [None]:
from torch.nn.modules import conv
class Layer1(nn.Module):
  """First encoder stage: full-resolution conv stem followed by residual blocks.

  The stem (Conv3d k=3, stride=1, padding=1 -> InstanceNorm3d -> ReLU) maps
  ``input_channels`` to ``output_channels`` without changing the spatial size.
  It is followed by ``num_blocks`` residual blocks, each computing
  ``relu(h + norm(conv_b(relu(norm(conv_a(h))))))``.

  Args:
    input_channels: number of channels of the input (dim 1).
    output_channels: number of channels produced by every conv in the stage.
    num_blocks: number of residual blocks after the stem (default 1).
  """
  def __init__(self, input_channels, output_channels, num_blocks=1):
    super().__init__()
    self.conv1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1)
    # InstanceNorm3d defaults to affine=False and no running stats, so it holds
    # no state and one instance can serve every normalisation in the stage.
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.relu = nn.ReLU()
    # Each residual block gets its own pair of convs. Calling one shared Conv3d
    # for both convs of a block would tie their weights together.
    self.res_blocks = nn.ModuleList([
      nn.ModuleList([
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
      ])
      for _ in range(num_blocks)
    ])

  def forward(self, x):
    """Apply the stem and then every residual block; returns ``output_channels`` channels."""
    out = self.relu(self.instnorm(self.conv1(x)))
    for conv_a, conv_b in self.res_blocks:
      residual = self.relu(self.instnorm(conv_a(out)))
      residual = self.instnorm(conv_b(residual))
      # The skip is added element-wise before the block's final ReLU.
      out = self.relu(out + residual)
    return out

In [None]:
class Layer2(nn.Module):
  """Encoder stage: stride-2 conv stem followed by residual blocks.

  The stem (Conv3d k=3, stride=2, padding=1 -> InstanceNorm3d -> ReLU) maps
  ``input_channels`` to ``output_channels`` and halves each spatial size
  (rounding up). It is followed by ``num_blocks`` residual blocks, each computing
  ``relu(h + norm(conv_b(relu(norm(conv_a(h))))))``.

  Args:
    input_channels: number of channels of the input (dim 1).
    output_channels: number of channels produced by every conv in the stage.
    num_blocks: number of residual blocks after the stem (default 1).
  """
  def __init__(self, input_channels, output_channels, num_blocks=1):
    super().__init__()
    self.conv1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=2, padding=1)
    # InstanceNorm3d defaults to affine=False and no running stats, so it holds
    # no state and one instance can serve every normalisation in the stage.
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.relu = nn.ReLU()
    # Each residual block gets its own pair of convs. Calling one shared Conv3d
    # for both convs of a block would tie their weights together.
    self.res_blocks = nn.ModuleList([
      nn.ModuleList([
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
      ])
      for _ in range(num_blocks)
    ])

  def forward(self, x):
    """Apply the downsampling stem and then every residual block."""
    out = self.relu(self.instnorm(self.conv1(x)))
    for conv_a, conv_b in self.res_blocks:
      residual = self.relu(self.instnorm(conv_a(out)))
      residual = self.instnorm(conv_b(residual))
      # The skip is added element-wise before the block's final ReLU.
      out = self.relu(out + residual)
    return out

In [None]:
class Layer3(nn.Module):
  """Encoder stage: stride-2 conv stem followed by residual blocks.

  The stem (Conv3d k=3, stride=2, padding=1 -> InstanceNorm3d -> ReLU) maps
  ``input_channels`` to ``output_channels`` and halves each spatial size
  (rounding up). It is followed by ``num_blocks`` residual blocks, each computing
  ``relu(h + norm(conv_b(relu(norm(conv_a(h))))))``.

  Args:
    input_channels: number of channels of the input (dim 1).
    output_channels: number of channels produced by every conv in the stage.
    num_blocks: number of residual blocks after the stem (default 2).
  """
  def __init__(self, input_channels, output_channels, num_blocks=2):
    super().__init__()
    self.conv1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=2, padding=1)
    # InstanceNorm3d defaults to affine=False and no running stats, so it holds
    # no state and one instance can serve every normalisation in the stage.
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.relu = nn.ReLU()
    # Each residual block gets its own pair of convs. Reusing one Conv3d for
    # every conv in the stage would tie all of their weights together.
    self.res_blocks = nn.ModuleList([
      nn.ModuleList([
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
      ])
      for _ in range(num_blocks)
    ])

  def forward(self, x):
    """Apply the downsampling stem and then every residual block in order."""
    out = self.relu(self.instnorm(self.conv1(x)))
    for conv_a, conv_b in self.res_blocks:
      residual = self.relu(self.instnorm(conv_a(out)))
      residual = self.instnorm(conv_b(residual))
      # The skip is added element-wise before the block's final ReLU.
      out = self.relu(out + residual)
    return out

In [None]:
class Layer4(nn.Module):
  """Encoder stage: stride-2 conv stem followed by residual blocks.

  The stem (Conv3d k=3, stride=2, padding=1 -> InstanceNorm3d -> ReLU) maps
  ``input_channels`` to ``output_channels`` and halves each spatial size
  (rounding up). It is followed by ``num_blocks`` residual blocks, each computing
  ``relu(h + norm(conv_b(relu(norm(conv_a(h))))))``.

  Args:
    input_channels: number of channels of the input (dim 1).
    output_channels: number of channels produced by every conv in the stage.
    num_blocks: number of residual blocks after the stem (default 3).
  """
  def __init__(self, input_channels, output_channels, num_blocks=3):
    super().__init__()
    self.conv1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=2, padding=1)
    # InstanceNorm3d defaults to affine=False and no running stats, so it holds
    # no state and one instance can serve every normalisation in the stage.
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.relu = nn.ReLU()
    # Each residual block gets its own pair of convs. Reusing one Conv3d for
    # every conv in the stage would tie all of their weights together.
    self.res_blocks = nn.ModuleList([
      nn.ModuleList([
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
      ])
      for _ in range(num_blocks)
    ])

  def forward(self, x):
    """Apply the downsampling stem and then every residual block in order."""
    out = self.relu(self.instnorm(self.conv1(x)))
    for conv_a, conv_b in self.res_blocks:
      residual = self.relu(self.instnorm(conv_a(out)))
      residual = self.instnorm(conv_b(residual))
      # The skip is added element-wise before the block's final ReLU.
      out = self.relu(out + residual)
    return out

In [None]:
class Layer5(nn.Module):
  """Encoder stage: stride-2 conv stem followed by residual blocks.

  The stem (Conv3d k=3, stride=2, padding=1 -> InstanceNorm3d -> ReLU) maps
  ``input_channels`` to ``output_channels`` and halves each spatial size
  (rounding up). It is followed by ``num_blocks`` residual blocks, each computing
  ``relu(h + norm(conv_b(relu(norm(conv_a(h))))))``.

  Args:
    input_channels: number of channels of the input (dim 1).
    output_channels: number of channels produced by every conv in the stage.
    num_blocks: number of residual blocks after the stem (default 4).
  """
  def __init__(self, input_channels, output_channels, num_blocks=4):
    super().__init__()
    self.conv1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=2, padding=1)
    # InstanceNorm3d defaults to affine=False and no running stats, so it holds
    # no state and one instance can serve every normalisation in the stage.
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.relu = nn.ReLU()
    # Each residual block gets its own pair of convs. Reusing one Conv3d for
    # every conv in the stage would tie all of their weights together.
    self.res_blocks = nn.ModuleList([
      nn.ModuleList([
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
      ])
      for _ in range(num_blocks)
    ])

  def forward(self, x):
    """Apply the downsampling stem and then every residual block in order."""
    out = self.relu(self.instnorm(self.conv1(x)))
    for conv_a, conv_b in self.res_blocks:
      residual = self.relu(self.instnorm(conv_a(out)))
      residual = self.instnorm(conv_b(residual))
      # The skip is added element-wise before the block's final ReLU.
      out = self.relu(out + residual)
    return out

In [None]:
class Layer6(nn.Module):
  """Bottleneck encoder stage: stride-2 conv stem followed by residual blocks.

  The stem (Conv3d k=3, stride=2, padding=1 -> InstanceNorm3d -> ReLU) maps
  ``input_channels`` to ``output_channels`` and halves each spatial size
  (rounding up). It is followed by ``num_blocks`` residual blocks, each computing
  ``relu(h + norm(conv_b(relu(norm(conv_a(h))))))``.

  Args:
    input_channels: number of channels of the input (dim 1).
    output_channels: number of channels produced by every conv in the stage.
    num_blocks: number of residual blocks after the stem (default 5).
  """
  def __init__(self, input_channels, output_channels, num_blocks=5):
    super().__init__()
    self.conv1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=2, padding=1)
    # InstanceNorm3d defaults to affine=False and no running stats, so it holds
    # no state and one instance can serve every normalisation in the stage.
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.relu = nn.ReLU()
    # Each residual block gets its own pair of convs. Reusing one Conv3d for
    # every conv in the stage would tie all of their weights together.
    self.res_blocks = nn.ModuleList([
      nn.ModuleList([
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
        nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1),
      ])
      for _ in range(num_blocks)
    ])

  def forward(self, x):
    """Apply the downsampling stem and then every residual block in order."""
    out = self.relu(self.instnorm(self.conv1(x)))
    for conv_a, conv_b in self.res_blocks:
      residual = self.relu(self.instnorm(conv_a(out)))
      residual = self.instnorm(conv_b(residual))
      # The skip is added element-wise before the block's final ReLU.
      out = self.relu(out + residual)
    return out

In [None]:
class ConvBlockDecoder(nn.Module):
  """Decoder stage: 2x transposed-conv upsampling, additive skip, then conv -> norm -> ReLU.

  Args:
    input_channels: channels of the decoder input ``x``, and also of the skip
      tensor (the two are merged by element-wise addition).
    output_channels: channels produced by the stage.
  """
  def __init__(self, input_channels, output_channels):
    super().__init__()
    # Attribute name "transcov" is kept unchanged so existing state_dict keys still match.
    self.transcov = nn.ConvTranspose3d(in_channels=input_channels, out_channels=input_channels, kernel_size=2, stride=2)
    self.conv = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1)
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.relu = nn.ReLU()

  def forward(self, x, concat_tensor):
    """Upsample ``x`` 2x, add ``concat_tensor`` and apply conv -> InstanceNorm -> ReLU.

    Despite its name, ``concat_tensor`` is merged by addition, not concatenation.
    Its spatial size sets the output's spatial size.
    """
    up = self.transcov(x)
    # A stride-2, k=3, p=1 encoder conv rounds odd sizes up (ceil(D/2)), so after
    # 2x upsampling the map can be one voxel larger than the skip on that axis.
    # Crop it to the skip's size; this is a no-op when the shapes already agree.
    if up.shape[2:] != concat_tensor.shape[2:]:
      d, h, w = concat_tensor.shape[2:]
      up = up[:, :, :d, :h, :w]
    merged = torch.add(up, concat_tensor)
    return self.relu(self.instnorm(self.conv(merged)))

In [None]:
class ResidualUNet(nn.Module):
  """3D residual U-Net: six residual encoder stages and five additive-skip decoder stages.

  Args:
    input_shape: unused; accepted so existing ``ResidualUNet(x.shape)`` calls keep working.
    in_channels: channels of the input volume (default 4).
    num_classes: channels of the output logits / probabilities (default 3).
    verbose: if True, ``forward`` prints the shape of every intermediate tensor.
  """
  def __init__(self, input_shape=None, in_channels=4, num_classes=3, verbose=False):
    super().__init__()
    self.verbose = verbose
    self.layer1 = Layer1(in_channels, 24)
    self.layer2 = Layer2(24, 48)
    self.layer3 = Layer3(48, 96)
    self.layer4 = Layer4(96, 192)
    self.layer5 = Layer5(192, 320)
    self.layer6 = Layer6(320, 320)
    self.layer7 = ConvBlockDecoder(320, 192)
    self.layer8 = ConvBlockDecoder(192, 96)
    self.layer9 = ConvBlockDecoder(96, 48)
    self.layer10 = ConvBlockDecoder(48, 24)
    self.layer11 = ConvBlockDecoder(24, 24)
    self.conv1x1x1 = nn.Conv3d(in_channels=24, out_channels=num_classes, kernel_size=1, stride=1)
    # Explicit dim=1 (the class channel): nn.Softmax() without dim is deprecated
    # and emits a warning on every call.
    self.softmax = nn.Softmax(dim=1)

  def _log_shape(self, tensor):
    # Debug aid: only prints when the model was built with verbose=True.
    if self.verbose:
      print(tensor.shape)

  def forward(self, x):
    """Run the network.

    Args:
      x: input volume, expected as (N, in_channels, D, H, W). The shape comments
        below are for a (1, 4, 128, 128, 128) input.

    Returns:
      (logits, prob): the 1x1x1-conv outputs and their softmax over dim 1,
      both shaped (N, num_classes, D, H, W).
    """
    self._log_shape(x)         # [1, 4, 128, 128, 128]

    # Encoder (each stage after the first halves D, H, W).
    x1 = self.layer1(x)
    self._log_shape(x1)        # [1, 24, 128, 128, 128]
    x2 = self.layer2(x1)
    self._log_shape(x2)        # [1, 48, 64, 64, 64]
    x3 = self.layer3(x2)
    self._log_shape(x3)        # [1, 96, 32, 32, 32]
    x4 = self.layer4(x3)
    self._log_shape(x4)        # [1, 192, 16, 16, 16]
    x5 = self.layer5(x4)
    self._log_shape(x5)        # [1, 320, 8, 8, 8]
    x6 = self.layer6(x5)
    self._log_shape(x6)        # [1, 320, 4, 4, 4]

    # Decoder: upsample 2x and add the matching encoder output as the skip.
    x7 = self.layer7(x6, x5)
    self._log_shape(x7)        # [1, 192, 8, 8, 8]
    x8 = self.layer8(x7, x4)
    self._log_shape(x8)        # [1, 96, 16, 16, 16]
    x9 = self.layer9(x8, x3)
    self._log_shape(x9)        # [1, 48, 32, 32, 32]
    x10 = self.layer10(x9, x2)
    self._log_shape(x10)       # [1, 24, 64, 64, 64]
    x11 = self.layer11(x10, x1)
    self._log_shape(x11)       # [1, 24, 128, 128, 128]

    x12 = self.conv1x1x1(x11)
    self._log_shape(x12)       # [1, 3, 128, 128, 128]

    prob = self.softmax(x12)   # [1, 3, 128, 128, 128]

    return x12, prob

In [None]:
# Seed first so both the random input and the weight initialisation are reproducible.
torch.manual_seed(0)
x = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)

model = ResidualUNet(x.shape)
print(model)
print()

# This cell only checks shapes and values, so skip autograd: without no_grad the
# 128^3 forward pass would keep every intermediate activation alive for backward.
with torch.no_grad():
  out = model(x)

ResidualUNet(
  (layer1): Layer1(
    (conv1): Conv3d(4, 24, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (instnorm): InstanceNorm3d(24, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (relu): ReLU()
    (conv): Conv3d(24, 24, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (layer2): Layer2(
    (conv1): Conv3d(24, 48, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (instnorm): InstanceNorm3d(48, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (relu): ReLU()
    (conv): Conv3d(48, 48, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (layer3): Layer3(
    (conv1): Conv3d(48, 96, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (instnorm): InstanceNorm3d(96, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (relu): ReLU()
    (conv): Conv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (layer4): Layer4(
    (conv1



In [None]:
# `out` is the (logits, softmax probabilities) pair returned by ResidualUNet.forward.
(output, probability) = out

In [None]:
# Raw logits: one channel per class at the input's spatial resolution.
assert output.shape == (x.shape[0], 3, *x.shape[2:]), output.shape
print(output.shape)
print(output)

torch.Size([1, 3, 128, 128, 128])
tensor([[[[[ 0.3602,  0.1844,  0.3652,  ...,  0.1221,  0.5244,  0.3558],
           [ 0.9086,  0.1166,  0.2256,  ...,  0.0929,  0.2228,  0.1085],
           [ 0.4354,  0.1410,  0.3573,  ..., -0.1633,  0.5540,  0.2071],
           ...,
           [ 0.5004,  0.2470, -0.1720,  ...,  0.0800, -0.0228,  0.1403],
           [ 0.0982,  0.5464,  0.3332,  ...,  0.6819,  0.0423,  0.2113],
           [ 0.6684,  0.2299,  0.0670,  ...,  0.2828,  0.2991,  0.2259]],

          [[ 0.2891,  0.1468,  0.7275,  ...,  0.4471,  0.4107,  0.1801],
           [ 0.1449,  0.0174,  0.4402,  ...,  0.1534,  0.0289,  0.2174],
           [-0.1998,  0.4861,  0.2981,  ..., -0.1011,  0.3637,  0.1635],
           ...,
           [-0.0254,  0.3631, -0.0889,  ...,  0.1321,  0.0576,  0.3282],
           [-0.2441, -0.1029,  0.6604,  ...,  0.2950,  0.6670,  0.2430],
           [ 0.1377, -0.2047,  0.1267,  ...,  0.1465,  0.0680, -0.0531]],

          [[ 0.4591,  0.0387,  0.2690,  ...,  0.3396, 

In [None]:
# Softmax must normalise over the class channel (dim 1): every voxel's class
# probabilities sum to 1.
assert probability.shape == output.shape, probability.shape
assert torch.allclose(probability.sum(dim=1), torch.ones_like(probability[:, 0])), "softmax is not over dim 1"
print(probability.shape)
print(probability)

torch.Size([1, 3, 128, 128, 128])
tensor([[[[[0.4394, 0.3912, 0.4485,  ..., 0.3844, 0.4863, 0.4737],
           [0.5468, 0.4729, 0.3564,  ..., 0.3305, 0.2961, 0.4057],
           [0.4196, 0.4901, 0.4663,  ..., 0.2168, 0.5437, 0.5569],
           ...,
           [0.4045, 0.4576, 0.2618,  ..., 0.4171, 0.3136, 0.3674],
           [0.3783, 0.5672, 0.4751,  ..., 0.5965, 0.3587, 0.4216],
           [0.5406, 0.4556, 0.3690,  ..., 0.4326, 0.4048, 0.4093]],

          [[0.3998, 0.3577, 0.4258,  ..., 0.4478, 0.4550, 0.3660],
           [0.4009, 0.2996, 0.3647,  ..., 0.2950, 0.2119, 0.3964],
           [0.2926, 0.5585, 0.4070,  ..., 0.3573, 0.3465, 0.3780],
           ...,
           [0.2225, 0.4917, 0.2370,  ..., 0.4528, 0.3326, 0.4006],
           [0.2188, 0.2468, 0.4511,  ..., 0.3013, 0.5532, 0.3261],
           [0.3736, 0.3169, 0.2652,  ..., 0.3083, 0.4622, 0.2872]],

          [[0.4422, 0.4310, 0.5001,  ..., 0.5214, 0.3731, 0.3252],
           [0.4179, 0.4393, 0.4926,  ..., 0.4139, 0.2335, 0