In [5]:
"""
class convBlock (because you have 6 so its more feasible)
  - conv3d 3x3 (increase channels)
  - relu
  - conv3d 3x3 (same channels)
  - relu
  - max pool = conv with same channels, but stride is 2 and kernel is 2 (just like max pool layer)

  ~~~~~~~~~~~~~~~~~~

  - convtranspose3d = deconv increased channels, but stride is 2 and kernel is 3 (combining steps) - this is approach one
 
  ~~~~~~~~~~~~~~~~~~

  approach 2:
  - convtranspose3d = deconv with same channels, but stride is 2 and kernel is 2 (to increase spatial size)
  - conv 3x3 (less channels)

  ~~~~~~~~~~~~~~~~~~

  torch.add - for element wise summation
"""

'\nclass convBlock (because you have 6 so its more feasible)\n  - conv3d 3x3 (increase channels)\n  - relu\n  - conv3d 3x3 (same channels)\n  - relu\n  - max pool = conv with same channels, but stride is 2 and kernel is 2 (just like max pool layer)\n\n  ~~~~~~~~~~~~~~~~~~\n\n  - convtranspose3d = deconv increased channels, but stride is 2 and kernel is 3 (combining steps) - this is approach one\n \n  ~~~~~~~~~~~~~~~~~~\n\n  approach 2:\n  - convtranspose3d = deconv with same channels, but stride is 2 and kernel is 2 (to increase spatial size)\n  - conv 3x3 (less channels)\n\n  ~~~~~~~~~~~~~~~~~~\n\n  torch.add - for element wise summation\n'

In [6]:
import torch
import torch.nn as nn

In [9]:
class ConvBlockEncoder(nn.Module):
  """Encoder stage: conv3d -> ReLU -> conv3d -> ReLU, plus a strided-conv downsample.

  Args:
    input_channels: number of channels in the incoming feature map.
    output_channels: number of channels produced by the first convolution
      and kept by the second convolution and by the downsampling layer.
  """

  def __init__(self, input_channels, output_channels):
    super().__init__()
    width = output_channels
    # kernel_size=3 with padding=1 leaves the spatial size unchanged.
    self.convblock1 = nn.Conv3d(input_channels, width, kernel_size=3, padding=1)
    self.relu = nn.ReLU()
    self.convblock2 = nn.Conv3d(width, width, kernel_size=3, padding=1)
    # Learned stand-in for max pooling: kernel 2 with stride 2, channels unchanged.
    self.downsample = nn.Conv3d(width, width, kernel_size=2, stride=2)

  def forward(self, x):
    """Return ``(features, downsampled)`` for the input tensor ``x``.

    ``features`` is the activation before downsampling and ``downsampled``
    is the same activation after the strided convolution.
    """
    features = self.relu(self.convblock2(self.relu(self.convblock1(x))))
    return features, self.downsample(features)

class ConvBlockDecoder(nn.Module):
  """Decoder stage: transposed-conv upsample followed by a 3x3x3 convolution.

  The upsampling layer keeps ``input_channels`` and uses kernel 2 / stride 2,
  which doubles every spatial dimension; the following convolution then maps
  ``input_channels`` to ``output_channels`` without changing the spatial size.

  A single ``ConvTranspose3d(kernel_size=3, stride=2, padding=1)`` that
  changes the channel count directly is not used: without ``output_padding``
  it yields ``2 * n - 1`` per dimension instead of ``2 * n``.

  Args:
    input_channels: number of channels in the incoming feature map.
    output_channels: number of channels in the returned feature map.
  """

  def __init__(self, input_channels, output_channels):
    super().__init__()
    self.upsample = nn.ConvTranspose3d(input_channels, input_channels, kernel_size=2, stride=2)
    self.convblock = nn.Conv3d(input_channels, output_channels, kernel_size=3, padding=1)

  def forward(self, x):
    """Upsample ``x`` and then convolve it to ``output_channels``."""
    return self.convblock(self.upsample(x))


class FCN(nn.Module):
  """Fully convolutional 3D encoder-decoder with two additive skip connections.

  Six ``ConvBlockEncoder`` stages widen the features from ``channels`` to
  2048; the downsampled output of each of the first five feeds the next
  stage. Five ``ConvBlockDecoder`` stages upsample again and narrow the
  features to 64 channels, and a final 1x1x1 convolution maps them to a
  single output channel (no activation is applied).

  Spatial extents divisible by 32 are sufficient for the skip additions to
  line up and for the output to have the same spatial size as the input.

  Args:
    input_shape: ``(channels, w, h, d)`` of one input volume, without a batch
      dimension. ``channels`` sizes the first convolution; the spatial
      extents are only stored on the instance.
    verbose: when True (the default, which keeps the original printing
      behaviour) the input shape and every intermediate shape are printed.
  """

  def __init__(self, input_shape, verbose=True):
    super().__init__()
    self.verbose = verbose
    # e.g. 4, 128, 128, 128
    self._log(input_shape)
    self.channels, self.w, self.h, self.d = input_shape

    # Encoder: the first stage takes its width from input_shape instead of a
    # hard-coded 4, so other channel counts work without editing the class.
    self.block1 = ConvBlockEncoder(self.channels, 64)
    self.block2 = ConvBlockEncoder(64, 128)
    self.block3 = ConvBlockEncoder(128, 256)
    self.block4 = ConvBlockEncoder(256, 512)
    self.block5 = ConvBlockEncoder(512, 1024)
    self.block6 = ConvBlockEncoder(1024, 2048)

    # Decoder: each stage doubles the spatial size and halves the channels.
    self.block7 = ConvBlockDecoder(2048, 1024)
    self.block8 = ConvBlockDecoder(1024, 512)
    self.block9 = ConvBlockDecoder(512, 256)
    self.block10 = ConvBlockDecoder(256, 128)
    self.block11 = ConvBlockDecoder(128, 64)

    # 1x1x1 convolution collapsing the 64 feature channels to one output channel.
    self.conv1by1 = nn.Conv3d(in_channels=64, out_channels=1, kernel_size=1, stride=1)

  def _log(self, *values):
    """Print ``values`` only when shape tracing is enabled."""
    if self.verbose:
      print(*values)

  def forward(self, x):
    """Run the network on ``x`` and return the single-channel output tensor.

    Args:
      x: input tensor whose channel dimension matches ``input_shape[0]``.

    Returns:
      The output of the final 1x1x1 convolution (one channel).
    """
    self._log("Before shape: ", str(x.shape))
    out1, down_out1 = self.block1(x)
    self._log(out1.shape, down_out1.shape)

    out2, down_out2 = self.block2(down_out1)
    self._log(out2.shape, down_out2.shape)

    out3, down_out3 = self.block3(down_out2)
    self._log(out3.shape, down_out3.shape)

    out4, down_out4 = self.block4(down_out3)
    self._log(out4.shape, down_out4.shape)

    out5, down_out5 = self.block5(down_out4)
    self._log(out5.shape, down_out5.shape)

    # block6 also returns a downsampled tensor; the bottleneck does not use it.
    out6, _unused_down = self.block6(down_out5)
    self._log(out6.shape)

    out7 = self.block7(out6)
    self._log(out7.shape)

    # Skip connection: element-wise sum of encoder block 5 and decoder block 7.
    out57 = torch.add(out5, out7)
    self._log(out57.shape)

    out8 = self.block8(out57)
    self._log(out8.shape)

    # Skip connection: element-wise sum of encoder block 4 and decoder block 8.
    out48 = torch.add(out4, out8)
    self._log(out48.shape)

    out9 = self.block9(out48)
    self._log(out9.shape)

    out10 = self.block10(out9)
    self._log(out10.shape)

    out11 = self.block11(out10)
    self._log(out11.shape)

    out12 = self.conv1by1(out11)
    self._log(out12.shape)

    # The original fell off the end here, so model(x) evaluated to None.
    return out12

In [10]:
# Smoke test: build one random volume laid out as (channels, w, h, d),
# construct the network from its shape, and run a single forward pass.
x = torch.rand(4, 128, 128, 128, dtype=torch.float32)
model = FCN(x.size())
out = model(x)

torch.Size([4, 128, 128, 128])
Before shape:  torch.Size([4, 128, 128, 128])
torch.Size([64, 128, 128, 128]) torch.Size([64, 64, 64, 64])
torch.Size([128, 64, 64, 64]) torch.Size([128, 32, 32, 32])
torch.Size([256, 32, 32, 32]) torch.Size([256, 16, 16, 16])
torch.Size([512, 16, 16, 16]) torch.Size([512, 8, 8, 8])
torch.Size([1024, 8, 8, 8]) torch.Size([1024, 4, 4, 4])
torch.Size([2048, 4, 4, 4])
torch.Size([1024, 8, 8, 8])
torch.Size([1024, 8, 8, 8])
torch.Size([512, 16, 16, 16])
torch.Size([512, 16, 16, 16])
torch.Size([256, 32, 32, 32])
torch.Size([128, 64, 64, 64])
torch.Size([64, 128, 128, 128])
torch.Size([1, 128, 128, 128])
