In [1]:
import torch
import torch.nn as nn

In [2]:
class Block1(nn.Module):
  def __init__(self, input_channels, output_channels):
    super().__init__()
    self.conv1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1)
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.lrelu = nn.LeakyReLU()
    self.conv2 = nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1)
  def forward(self, x):
    x = self.conv1(x)
    x = self.instnorm(x)
    x = self.lrelu(x)
    x = self.conv2(x)
    x = self.instnorm(x)
    x = self.lrelu(x)
    return x

class ConvBlockEncoder(nn.Module): # need to add nn.Module here
  def __init__(self, input_channels, output_channels):
    super().__init__()
    # self.conv = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=1)
    self.conv1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=2, padding=1) # before padding it was 64 x 63 x 63 x 63. by adding 1, you add one to each dimension needed
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.lrelu = nn.LeakyReLU()
    self.conv2 = nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1)
  def forward(self, x):
    x = self.conv1(x)
    x = self.instnorm(x)
    x = self.lrelu(x)
    x = self.conv2(x)
    x = self.instnorm(x)
    x = self.lrelu(x)
    return x

class ConvBlockDecoder(nn.Module):
  def __init__(self, input_channels, output_channels):
    super().__init__()
    self.transconv = nn.ConvTranspose3d(in_channels=input_channels, out_channels=input_channels, kernel_size=2, stride=2)
    self.conv1 = nn.Conv3d(in_channels=input_channels+input_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1)
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.lrelu = nn.LeakyReLU()
    self.conv2 = nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1)
  def forward(self, x, concat_tensor):
    x = self.transconv(x)
    # print("after transpose: " + str(x.shape))
    x = torch.cat((x, concat_tensor), dim=0)
    # print("after concat: " + str(x.shape))
    x = self.conv1(x)
    # print("after conv1: " + str(x.shape))
    x = self.instnorm(x)
    x = self.lrelu(x)
    x = self.conv2(x)
    x = self.instnorm(x)
    x = self.lrelu(x)
    return x

class nnUNet(nn.Module):
  def __init__(self, input_shape):
    super().__init__()
    self.convblock1 = Block1(4, 32)
    self.convblock2 = ConvBlockEncoder(32, 64)
    self.convblock3 = ConvBlockEncoder(64, 128)
    self.convblock4 = ConvBlockEncoder(128, 256)
    self.convblock5 = ConvBlockEncoder(256, 320)
    self.convblock6 = ConvBlockEncoder(320, 320)

    self.convblock7 = ConvBlockDecoder(320, 256)
    self.convblock8 = ConvBlockDecoder(256, 128)
    self.convblock9 = ConvBlockDecoder(128, 64)
    self.convblock10 = ConvBlockDecoder(64, 32)
    self.convblock11 = ConvBlockDecoder(32, 16)
    self.convblock11b = nn.Conv3d(in_channels=16, out_channels=3, kernel_size=1, stride=1) # output is 3 channels for WT, ET, TC

    self.softmax = nn.Softmax(dim=1)

  def forward(self, x, ):
    print(x.shape) # [4, 128, 128, 128]

    # ENCODER
    x1 = self.convblock1(x)
    print(x1.shape) # [32, 128, 128, 128]

    x2 = self.convblock2(x1)
    print(x2.shape) # [64, 64, 64, 64]

    x3 = self.convblock3(x2)
    print(x3.shape) # [128, 32, 32, 32]

    x4 = self.convblock4(x3)
    print(x4.shape) # [256, 16, 16, 16]

    x5 = self.convblock5(x4)
    print(x5.shape) # [320, 8, 8, 8]

    x6 = self.convblock6(x5)
    print(x6.shape) # [320, 4, 4, 4]

    x7 = self.convblock7(x6, x5)
    print(x7.shape) # [256, 8, 8, 8]

    x8 = self.convblock8(x7, x4)
    print(x8.shape) # [128, 16, 16, 16]

    x9 = self.convblock9(x8, x3)
    print(x9.shape) # [64, 32, 32, 32]

    x10 = self.convblock10(x9, x2)
    print(x10.shape) # [32, 64, 64, 64]

    x11 = self.convblock11(x10, x1)
    print(x11.shape) # [16, 128, 128, 128]

    x12 = self.convblock11b(x11)
    print(x12.shape) # [3, 128, 128, 128]

    x13 = self.softmax(x12)
    print(x13.shape) # [3, 128, 128, 128] - i thought that the shape would change when using softmax, but only the values change to add up to 1

    return x13

In [None]:
"""
when dim = 1
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])
"""

In [4]:
x = torch.rand(size=(4, 128, 128, 128), dtype=torch.float32)
# print(x.shape)

model = nnUNet(x.shape)
print(model)
print()

out = model(x)
# print(out.shape)

nnUNet(
  (convblock1): Block1(
    (conv1): Conv3d(4, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (instnorm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (lrelu): LeakyReLU(negative_slope=0.01)
    (conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (convblock2): ConvBlockEncoder(
    (conv1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (instnorm): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (lrelu): LeakyReLU(negative_slope=0.01)
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (convblock3): ConvBlockEncoder(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (instnorm): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (lrelu): LeakyReLU(negative_slope=0.01)
    (conv2): Con

In [5]:
print(out.shape)
print(out)

torch.Size([3, 128, 128, 128])
tensor([[[[0.0114, 0.0073, 0.0071,  ..., 0.0091, 0.0054, 0.0066],
          [0.0065, 0.0085, 0.0094,  ..., 0.0079, 0.0119, 0.0092],
          [0.0077, 0.0083, 0.0112,  ..., 0.0074, 0.0071, 0.0087],
          ...,
          [0.0069, 0.0101, 0.0121,  ..., 0.0088, 0.0091, 0.0083],
          [0.0067, 0.0073, 0.0085,  ..., 0.0074, 0.0087, 0.0081],
          [0.0089, 0.0079, 0.0110,  ..., 0.0071, 0.0064, 0.0054]],

         [[0.0081, 0.0089, 0.0038,  ..., 0.0062, 0.0048, 0.0063],
          [0.0089, 0.0060, 0.0062,  ..., 0.0042, 0.0058, 0.0075],
          [0.0074, 0.0042, 0.0054,  ..., 0.0055, 0.0050, 0.0082],
          ...,
          [0.0079, 0.0058, 0.0054,  ..., 0.0153, 0.0074, 0.0101],
          [0.0053, 0.0071, 0.0053,  ..., 0.0074, 0.0062, 0.0078],
          [0.0109, 0.0070, 0.0090,  ..., 0.0098, 0.0104, 0.0074]],

         [[0.0119, 0.0068, 0.0087,  ..., 0.0061, 0.0090, 0.0081],
          [0.0084, 0.0068, 0.0065,  ..., 0.0096, 0.0095, 0.0075],
          [