In [1]:
"""

Notes on this model: https://shizacharania.notion.site/FCN-Fully-Convolutional-Network-bbc70d47ad92491cb534e2987ac7eb48

class convBlock (because you have 6 so its more feasible)
  - conv3d 3x3 (increase channels)
  - relu
  - conv3d 3x3 (same channels)
  - relu
  - max pool = conv with same channels, but stride is 2 and kernel is 2 (just like max pool layer)

  ~~~~~~~~~~~~~~~~~~

  - convtranspose3d = deconv increased channels, but stride is 2 and kernel is 3 (combining steps) - this is approach one
 
  ~~~~~~~~~~~~~~~~~~

  approach 2:
  - convtranspose3d = deconv with same channels, but stride is 2 and kernel is 2 (to increase spatial size)
  - conv 3x3 (less channels)

  ~~~~~~~~~~~~~~~~~~

  torch.add - for element wise summation
"""

'\n\nNotes on this model: https://shizacharania.notion.site/FCN-Fully-Convolutional-Network-bbc70d47ad92491cb534e2987ac7eb48\n\nclass convBlock (because you have 6 so its more feasible)\n  - conv3d 3x3 (increase channels)\n  - relu\n  - conv3d 3x3 (same channels)\n  - relu\n  - max pool = conv with same channels, but stride is 2 and kernel is 2 (just like max pool layer)\n\n  ~~~~~~~~~~~~~~~~~~\n\n  - convtranspose3d = deconv increased channels, but stride is 2 and kernel is 3 (combining steps) - this is approach one\n \n  ~~~~~~~~~~~~~~~~~~\n\n  approach 2:\n  - convtranspose3d = deconv with same channels, but stride is 2 and kernel is 2 (to increase spatial size)\n  - conv 3x3 (less channels)\n\n  ~~~~~~~~~~~~~~~~~~\n\n  torch.add - for element wise summation\n'

In [2]:
import torch
import torch.nn as nn

In [3]:
class ConvBlockEncoder(nn.Module):
  """One encoder stage: two 3x3x3 conv + ReLU layers, then a learned 2x downsample.

  Args:
    input_channels: number of channels in the incoming feature map.
    output_channels: channel count produced by both 3x3x3 convolutions; the
      downsampling conv keeps this same channel count.

  forward(x) returns a pair ``(features, downsampled)``:
    features    -- output of conv -> ReLU -> conv -> ReLU. padding=1 keeps the
                   spatial size equal to that of ``x``.
    downsampled -- ``features`` passed through a kernel-2 / stride-2 conv, which
                   halves each spatial dimension (rounding down).
  """

  def __init__(self, input_channels, output_channels):
    super().__init__()
    self.convblock1 = nn.Conv3d(input_channels, output_channels, kernel_size=3, padding=1)
    self.relu = nn.ReLU()
    self.convblock2 = nn.Conv3d(output_channels, output_channels, kernel_size=3, padding=1)
    # A kernel-2, stride-2 conv is used as a learnable replacement for a 2x max-pool.
    self.downsample = nn.Conv3d(output_channels, output_channels, kernel_size=2, stride=2)

  def forward(self, x):
    features = self.relu(self.convblock2(self.relu(self.convblock1(x))))
    return features, self.downsample(features)

class ConvBlockDecoder(nn.Module):
  """One decoder stage: a learned 2x upsample followed by a 3x3x3 conv.

  Args:
    input_channels: channels of the incoming feature map. The transposed conv
      keeps this channel count while doubling each spatial dimension.
    output_channels: channels produced by the final 3x3x3 conv. Its padding=1
      leaves the spatial size unchanged.

  NOTE(review): no activation is applied after the conv. Stacked decoder blocks
  therefore compose affinely unless something between them adds a nonlinearity.
  Confirm this is intended.
  """

  def __init__(self, input_channels, output_channels):
    super().__init__()
    # kernel 2 / stride 2 maps a spatial size n to exactly 2n.
    # A single ConvTranspose3d(kernel_size=3, stride=2, padding=1) was tried first
    # to upsample and change channels in one step. It yields 2n - 1 (2n + 1 with
    # padding=0) instead of 2n, so upsampling and the channel change are split.
    # output_padding=1 on that layer would also have produced 2n.
    self.upsample = nn.ConvTranspose3d(input_channels, input_channels, kernel_size=2, stride=2)
    self.convblock = nn.Conv3d(input_channels, output_channels, kernel_size=3, padding=1)

  def forward(self, x):
    upsampled = self.upsample(x)
    return self.convblock(upsampled)


class FCN(nn.Module):
  """3D fully convolutional encoder-decoder for volumetric segmentation.

  Architecture:
    - Six ConvBlockEncoder stages with channel widths base_channels * 2**k, k = 0..5.
    - Five ConvBlockDecoder stages that mirror those widths back down to base_channels.
    - A 1x1x1 conv to ``num_classes`` channels, then a softmax over dim=1.
    - Skip connections: the outputs of block7 and block8 are summed element-wise
      with the pre-downsampling features of block5 and block4.

  Args:
    input_shape: 5-element shape (N, C, *spatial) of the expected input, e.g.
      ``x.shape``. Its C entry sets block1's input channels. All five entries
      are stored as attributes (nimages, channels, w, h, d).
    base_channels: width of the first encoder stage. The default of 64 gives the
      64..2048 channel ladder; smaller values give a lighter model.
    num_classes: output channels of the final 1x1x1 conv (default 3).
    verbose: when True (the default), print the input shape at construction and
      intermediate tensor shapes during forward.

  forward(x) returns ``(logits, probabilities)``.
    - For a 5-D input both have shape (N, num_classes, *x.shape[-3:]).
    - Every spatial dimension of ``x`` must be a multiple of 32 (five stride-2
      downsamples).
    - Otherwise a ValueError is raised. Without this check the model would either
      silently return a smaller volume or fail inside a skip-connection addition.
  """

  # Five stride-2 encoder downsamples (block1..block5) must be undone exactly
  # by the five 2x decoder upsamples (block7..block11).
  _SPATIAL_MULTIPLE = 2 ** 5

  def __init__(self, input_shape, base_channels=64, num_classes=3, verbose=True):
    super().__init__()
    self.verbose = verbose
    self._log(input_shape)
    # NOTE(review): PyTorch's Conv3d layout is (N, C, D, H, W). The w/h/d
    # attribute names are kept for compatibility but may not match that order.
    self.nimages, self.channels, self.w, self.h, self.d = input_shape

    widths = [base_channels * 2 ** k for k in range(6)]

    # Encoder. block1 takes the channel count from input_shape instead of a
    # hard-coded 4, so inputs with a different number of modalities also work.
    self.block1 = ConvBlockEncoder(self.channels, widths[0])
    self.block2 = ConvBlockEncoder(widths[0], widths[1])
    self.block3 = ConvBlockEncoder(widths[1], widths[2])
    self.block4 = ConvBlockEncoder(widths[2], widths[3])
    self.block5 = ConvBlockEncoder(widths[3], widths[4])
    self.block6 = ConvBlockEncoder(widths[4], widths[5])

    # Decoder (mirrors the encoder widths back down).
    self.block7 = ConvBlockDecoder(widths[5], widths[4])
    self.block8 = ConvBlockDecoder(widths[4], widths[3])
    self.block9 = ConvBlockDecoder(widths[3], widths[2])
    self.block10 = ConvBlockDecoder(widths[2], widths[1])
    self.block11 = ConvBlockDecoder(widths[1], widths[0])

    # Output head. With the default 3 classes the channels are WT, ET, TC.
    self.conv1by1 = nn.Conv3d(in_channels=widths[0], out_channels=num_classes, kernel_size=1, stride=1)
    # NOTE(review): if WT/ET/TC are overlapping (nested) regions, independent
    # per-channel sigmoids may be intended rather than a softmax. Confirm.
    self.softmax = nn.Softmax(dim=1)

  def _log(self, *values):
    """Print ``values`` (same formatting as ``print``) when verbose is enabled."""
    # getattr keeps instances pickled before `verbose` existed working.
    if getattr(self, "verbose", True):
      print(*values)

  def forward(self, x):
    self._log("Before shape: ", str(x.shape))

    spatial = tuple(x.shape[-3:])
    if any(size % self._SPATIAL_MULTIPLE for size in spatial):
      raise ValueError(
          f"FCN needs every spatial dimension to be a multiple of "
          f"{self._SPATIAL_MULTIPLE}, got {spatial}"
      )

    # Encoder: keep each stage's pre-downsampling features for skip connections.
    skip_features = []
    h = x
    for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
      features, h = block(h)
      self._log(features.shape, h.shape)
      skip_features.append(features)
    out4, out5 = skip_features[3], skip_features[4]

    # block6's downsampled output is discarded, so any parameters that only feed
    # it receive no gradient.
    out6, _ = self.block6(h)
    self._log(out6.shape)

    out7 = self.block7(out6)
    self._log(out7.shape)
    out57 = torch.add(out5, out7)  # skip connection: block5 + block7
    self._log(out57.shape)

    out8 = self.block8(out57)
    self._log(out8.shape)
    out48 = torch.add(out4, out8)  # skip connection: block4 + block8
    self._log(out48.shape)

    # NOTE(review): block1..block3 features are computed but not used as skips.
    out9 = self.block9(out48)
    self._log(out9.shape)
    out10 = self.block10(out9)
    self._log(out10.shape)
    out11 = self.block11(out10)
    self._log(out11.shape)

    logits = self.conv1by1(out11)
    self._log(logits.shape)
    # Softmax only rescales values along dim=1 so they sum to 1; the shape is unchanged.
    probabilities = self.softmax(logits)
    self._log(probabilities.shape)

    return logits, probabilities

In [4]:
# Smoke test: build the model for one 4-channel 128^3 volume and run a single forward pass.
torch.manual_seed(0)  # reproducible weights and input
x = torch.rand(size=(1, 4, 128, 128, 128), dtype=torch.float32)

model = FCN(x.shape)
print(model)
print()

# Only output shapes/values are inspected, never gradients. Without no_grad,
# autograd keeps every intermediate activation of this large 3D network alive
# for a backward pass that never happens.
with torch.no_grad():
  out = model(x)

torch.Size([1, 4, 128, 128, 128])
FCN(
  (block1): ConvBlockEncoder(
    (convblock1): Conv3d(4, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (relu): ReLU()
    (convblock2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (downsample): Conv3d(64, 64, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  )
  (block2): ConvBlockEncoder(
    (convblock1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (relu): ReLU()
    (convblock2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (downsample): Conv3d(128, 128, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  )
  (block3): ConvBlockEncoder(
    (convblock1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (relu): ReLU()
    (convblock2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (downsample): Conv3d(256, 256, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  )
  (block4): 

In [5]:
# The model's forward returns (pre-softmax scores, softmax probabilities).
[output, probability] = out

In [6]:
# Pre-softmax scores: shape first, then the (summarised) tensor.
print(output.shape, output, sep="\n")

torch.Size([1, 3, 128, 128, 128])
tensor([[[[[-0.1080, -0.1059, -0.1059,  ..., -0.1059, -0.1057, -0.1061],
           [-0.1092, -0.1057, -0.1061,  ..., -0.1058, -0.1062, -0.1047],
           [-0.1093, -0.1055, -0.1058,  ..., -0.1055, -0.1056, -0.1048],
           ...,
           [-0.1091, -0.1056, -0.1059,  ..., -0.1058, -0.1060, -0.1047],
           [-0.1093, -0.1056, -0.1058,  ..., -0.1057, -0.1058, -0.1049],
           [-0.1099, -0.1080, -0.1081,  ..., -0.1081, -0.1083, -0.1075]],

          [[-0.1051, -0.1035, -0.1029,  ..., -0.1034, -0.1027, -0.1049],
           [-0.1038, -0.0994, -0.1001,  ..., -0.0994, -0.0997, -0.1025],
           [-0.1040, -0.1008, -0.1001,  ..., -0.1007, -0.1000, -0.1030],
           ...,
           [-0.1040, -0.0992, -0.1004,  ..., -0.0994, -0.0998, -0.1028],
           [-0.1039, -0.1007, -0.1003,  ..., -0.1009, -0.1003, -0.1033],
           [-0.1057, -0.1011, -0.1019,  ..., -0.1012, -0.1017, -0.1048]],

          [[-0.1049, -0.1031, -0.1029,  ..., -0.1032, 

In [7]:
print(probability.shape)
print(probability)

# Softmax over dim=1 changes values, not shape. Check that the class
# probabilities at every voxel sum to 1.
assert torch.allclose(probability.sum(dim=1), torch.ones_like(probability[:, 0])), \
    "class probabilities should sum to 1 at every voxel"

torch.Size([1, 3, 128, 128, 128])
tensor([[[[[0.3173, 0.3176, 0.3177,  ..., 0.3176, 0.3177, 0.3173],
           [0.3172, 0.3176, 0.3177,  ..., 0.3176, 0.3177, 0.3175],
           [0.3172, 0.3177, 0.3178,  ..., 0.3178, 0.3178, 0.3175],
           ...,
           [0.3172, 0.3177, 0.3178,  ..., 0.3176, 0.3177, 0.3175],
           [0.3172, 0.3177, 0.3178,  ..., 0.3177, 0.3178, 0.3175],
           [0.3169, 0.3170, 0.3171,  ..., 0.3169, 0.3170, 0.3169]],

          [[0.3180, 0.3179, 0.3178,  ..., 0.3180, 0.3178, 0.3174],
           [0.3187, 0.3189, 0.3187,  ..., 0.3189, 0.3188, 0.3177],
           [0.3185, 0.3188, 0.3186,  ..., 0.3188, 0.3186, 0.3178],
           ...,
           [0.3186, 0.3189, 0.3186,  ..., 0.3189, 0.3189, 0.3176],
           [0.3185, 0.3188, 0.3185,  ..., 0.3188, 0.3185, 0.3178],
           [0.3177, 0.3182, 0.3180,  ..., 0.3182, 0.3181, 0.3171]],

          [[0.3181, 0.3178, 0.3180,  ..., 0.3177, 0.3181, 0.3173],
           [0.3184, 0.3187, 0.3186,  ..., 0.3187, 0.3185, 0