In [None]:
import torch
import torch.nn as nn

In [39]:
class Block1(nn.Module):
  """Full-resolution stem: (3x3x3 conv -> instance norm -> LeakyReLU) applied twice.

  Both convolutions use stride 1 and padding 1, so the spatial size (D, H, W)
  of the input is preserved; only the channel count changes.

  Args:
    input_channels: number of channels of the incoming tensor.
    output_channels: number of channels produced by both convolutions.
  """

  def __init__(self, input_channels, output_channels):
    super().__init__()
    # kernel 3 / stride 1 / padding 1 -> same D, H, W as the input.
    self.conv1 = nn.Conv3d(
      in_channels=input_channels,
      out_channels=output_channels,
      kernel_size=3,
      stride=1,
      padding=1,
    )
    # One norm and one activation are reused after both convs. InstanceNorm3d
    # defaults to affine=False (no learnable parameters) and LeakyReLU is
    # stateless, so sharing them is equivalent to keeping separate copies.
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.lrelu = nn.LeakyReLU()
    self.conv2 = nn.Conv3d(
      in_channels=output_channels,
      out_channels=output_channels,
      kernel_size=3,
      stride=1,
      padding=1,
    )

  def forward(self, x):
    """Return ``x`` after two rounds of conv -> instance norm -> LeakyReLU."""
    for conv in (self.conv1, self.conv2):
      x = self.lrelu(self.instnorm(conv(x)))
    return x

class ConvBlockEncoder(nn.Module):
  """Downsampling encoder stage: strided conv, then a stride-1 conv.

  Each convolution is followed by instance norm and LeakyReLU. The first conv
  uses stride 2 with padding 1, giving floor((S - 1) / 2) + 1 voxels per spatial
  dim (e.g. 128 -> 64; without the padding it would be 63). The second conv
  keeps the size.

  Args:
    input_channels: number of channels of the incoming tensor.
    output_channels: number of channels produced by both convolutions.
  """

  def __init__(self, input_channels, output_channels):
    super().__init__()
    # Downsampling conv: stride 2, padding 1 so even sizes halve exactly.
    self.conv1 = nn.Conv3d(
      in_channels=input_channels,
      out_channels=output_channels,
      kernel_size=3,
      stride=2,
      padding=1,
    )
    # Shared after both convs; affine=False by default, so no parameters to share.
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.lrelu = nn.LeakyReLU()
    # Size-preserving refinement conv.
    self.conv2 = nn.Conv3d(
      in_channels=output_channels,
      out_channels=output_channels,
      kernel_size=3,
      stride=1,
      padding=1,
    )

  def forward(self, x):
    """Downsample ``x`` by 2 per spatial dim, with conv -> norm -> LeakyReLU twice."""
    for conv in (self.conv1, self.conv2):
      x = self.lrelu(self.instnorm(conv(x)))
    return x

class ConvBlockDecoder(nn.Module):
  """Upsampling decoder stage: 2x transposed conv, then two 3x3x3 convs.

  Each 3x3x3 conv is followed by instance norm and LeakyReLU. The transposed
  conv keeps the channel count at ``input_channels``; the first 3x3x3 conv maps
  it to ``output_channels``. No skip connection is concatenated in this block.

  Shapes: ``(N, input_channels, D, H, W)`` -> ``(N, output_channels, 2D, 2H, 2W)``.

  Args:
    input_channels: number of channels of the incoming (coarser) tensor.
    output_channels: number of channels produced by both 3x3x3 convolutions.
  """

  def __init__(self, input_channels, output_channels):
    super().__init__()
    # kernel 2 / stride 2: out = (S - 1) * 2 + 2 = 2 * S, i.e. exact 2x upsampling.
    self.transconv = nn.ConvTranspose3d(in_channels=input_channels, out_channels=input_channels, kernel_size=2, stride=2)
    # padding=1 keeps the upsampled size. Without it each conv trims 2 voxels per
    # spatial dim (2S - 4 overall), so the output could not line up with the
    # matching encoder stage; the encoder blocks use the same padding=1 convention.
    self.conv1 = nn.Conv3d(in_channels=input_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1)
    # Shared after both convs; InstanceNorm3d defaults to affine=False (no parameters).
    self.instnorm = nn.InstanceNorm3d(output_channels)
    self.lrelu = nn.LeakyReLU()
    self.conv2 = nn.Conv3d(in_channels=output_channels, out_channels=output_channels, kernel_size=3, stride=1, padding=1)

  def forward(self, x):
    """Upsample ``x`` 2x per spatial dim, then apply conv -> norm -> LeakyReLU twice."""
    x = self.transconv(x)
    x = self.lrelu(self.instnorm(self.conv1(x)))
    x = self.lrelu(self.instnorm(self.conv2(x)))
    return x

class nnUNet(nn.Module):
  """Encoder half of a 3D nnU-Net-style network (no decoder yet).

  Stages, with the spatial size each produces for a 128^3 input:
    convblock1  Block1(4, 32)               stride 1 -> 128^3
    convblock2  ConvBlockEncoder(32, 64)    stride 2 ->  64^3
    convblock3  ConvBlockEncoder(64, 128)   stride 2 ->  32^3
    convblock4  ConvBlockEncoder(128, 256)  stride 2 ->  16^3
    convblock5  ConvBlockEncoder(256, 320)  stride 2 ->   8^3

  Args:
    input_shape: accepted for backward compatibility; currently unused.
    verbose: when True (the default, matching the original behaviour),
      ``forward`` prints the input shape and the output shape of every stage.
  """

  def __init__(self, input_shape, *, verbose=True):
    super().__init__()
    self.verbose = verbose
    self.convblock1 = Block1(4, 32)
    self.convblock2 = ConvBlockEncoder(32, 64)
    self.convblock3 = ConvBlockEncoder(64, 128)
    self.convblock4 = ConvBlockEncoder(128, 256)
    self.convblock5 = ConvBlockEncoder(256, 320)

  def forward(self, x):
    """Run every encoder stage and return the deepest feature map.

    Args:
      x: input volume with 4 channels (the in_channels of ``convblock1``).

    Returns:
      The output of ``convblock5`` (320 channels). Intermediate stage outputs,
      which a U-Net decoder would need as skip connections, are not returned.
    """
    if self.verbose:
      print(x.shape)
    for stage in (self.convblock1, self.convblock2, self.convblock3, self.convblock4, self.convblock5):
      x = stage(x)
      if self.verbose:
        print(x.shape)
    # Return the bottleneck so `out = model(x)` yields a tensor instead of None.
    return x

In [40]:
# Smoke test: build the model and push one random volume through it.
# The tensor has no batch axis: PyTorch's Conv3d / InstanceNorm3d accept a 4-D
# tensor as a single unbatched (C, D, H, W) volume, so C=4 matches Block1(4, 32).
# Use x.unsqueeze(0) for an explicit batch of one.
x = torch.rand(size=(4, 128, 128, 128), dtype=torch.float32)
# print(x.shape)

# The shape is passed as `input_shape`, which the nnUNet constructor shown in this
# notebook does not use.
model = nnUNet(x.shape)
print(model)
print()

# `out` holds whatever nnUNet.forward returns; only uncomment the print below
# if that return value is a tensor.
out = model(x)
# print(out.shape)

nnUNet(
  (convblock1): Block1(
    (conv1): Conv3d(4, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (instnorm): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (lrelu): LeakyReLU(negative_slope=0.01)
    (conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (convblock2): ConvBlockEncoder(
    (conv1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (instnorm): InstanceNorm3d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (lrelu): LeakyReLU(negative_slope=0.01)
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (convblock3): ConvBlockEncoder(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (instnorm): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (lrelu): LeakyReLU(negative_slope=0.01)
    (conv2): Con