In [None]:
class UNet(nn.Module):
    """Two-level 3D U-Net mapping a one-channel volume to a one-channel volume.

    Channel flow: encoder 1 -> 16 -> 64, bottleneck 128, decoder 64 -> 16 -> 1,
    with a skip concatenation at each decoder stage. Every plain convolution is
    3x3x3 with padding 1, so spatial size only changes at the max-pools
    (halving) and at the stride-2 transposed convolutions (doubling, plus
    ``output_padding``).

    ``forward`` inserts a channel axis at dim 1 and squeezes it from the
    output; presumably ``x`` is (N, D, H, W) — confirm against the caller.

    NOTE(review): ``upconv1`` hard-codes ``output_padding=(1, 0, 1)``, so the
    first skip concatenation only lines up for particular input sizes
    (D % 4 == 2 and W % 4 == 2 with D, W >= 6; H % 4 == 0 with H >= 4),
    e.g. spatial size (6, 4, 6).
    """

    def __init__(self):
        super().__init__()

        self.activation = nn.Mish()  # Activation function
        self.pool = nn.MaxPool3d(2, stride=2)  # Pooling (halves D/H/W)

        # Encoder layers

        self.e11 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.e12 = nn.Conv3d(16, 16, kernel_size=3, padding=1)

        self.e21 = nn.Conv3d(16, 64, kernel_size=3, padding=1)
        self.e22 = nn.Conv3d(64, 64, kernel_size=3, padding=1)

        # Middle layers

        self.m1 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.m2 = nn.Conv3d(128, 128, kernel_size=3, padding=1)

        # Decoder layers

        # 128 -> 64 channels; concatenated with the 64-channel x_e22 skip -> 128 into d11.
        self.upconv1 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2, output_padding=(1, 0, 1))  # Upscaling
        self.d11 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.d12 = nn.Conv3d(64, 64, kernel_size=3, padding=1)

        # 64 -> 48 channels; concatenated with the 16-channel x_e12 skip -> 64 into d21.
        self.upconv2 = nn.ConvTranspose3d(64, 48, kernel_size=2, stride=2)  # Upscaling
        self.d21 = nn.Conv3d(64, 16, kernel_size=3, padding=1)
        self.d22 = nn.Conv3d(16, 16, kernel_size=3, padding=1)

        self.output = nn.Conv3d(16, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the network on ``x`` and return the prediction without its channel axis."""
        x = x.unsqueeze(1)  # add the single input channel expected by e11

        # Encoder stage 1: 1 -> 16 channels.
        x_e11 = self.activation(self.e11(x))
        x_e12 = self.activation(self.e12(x_e11))
        x_pool1 = self.pool(x_e12)

        # Encoder stage 2: 16 -> 64 channels, through e21 then e22.
        x_e21 = self.activation(self.e21(x_pool1))
        x_e22 = self.activation(self.e22(x_e21))
        x_pool2 = self.pool(x_e22)

        # Bottleneck: 64 -> 128 channels.
        x_m1 = self.activation(self.m1(x_pool2))
        x_m2 = self.activation(self.m2(x_m1))

        x_upconv1 = self.upconv1(x_m2)
        x_upconv1 = torch.cat([x_upconv1, x_e22], dim=1)  # Concatenating
        x_d11 = self.activation(self.d11(x_upconv1))
        x_d12 = self.activation(self.d12(x_d11))

        x_upconv2 = self.upconv2(x_d12)
        x_upconv2 = torch.cat([x_upconv2, x_e12], dim=1)  # Concatenating
        x_d21 = self.activation(self.d21(x_upconv2))
        x_d22 = self.activation(self.d22(x_d21))

        x_out = self.output(x_d22)

        return x_out.squeeze(1)


# Instantiate the network and move its parameters/buffers to `device`
# (assumed to be defined in an earlier cell).
model = UNet().to(device=device)

In [None]:
class UNetV2(nn.Module):
    """Two-level 3D U-Net with channel dropout in the second encoder stage.

    Same topology as the plain two-level U-Net: encoder 1 -> 16 -> 64,
    bottleneck 128, decoder 64 -> 16 -> 1 with skip concatenations. In
    addition, ``nn.Dropout3d(0.2)`` is applied after both second-stage
    encoder activations.

    ``forward`` inserts a channel axis at dim 1 and squeezes it from the
    output; presumably ``x`` is (N, D, H, W) — confirm against the caller.

    NOTE(review): ``upconv1`` hard-codes ``output_padding=(1, 0, 1)``, so the
    first skip concatenation only lines up for particular input sizes,
    e.g. spatial size (6, 4, 6).
    """

    def __init__(self):
        super().__init__()

        self.activation = nn.Mish()  # Activation function
        self.pool = nn.MaxPool3d(2, stride=2)  # Pooling (halves D/H/W)
        self.dropout = nn.Dropout3d(0.2)  # Dropout

        # Encoder layers

        self.e11 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.e12 = nn.Conv3d(16, 16, kernel_size=3, padding=1)

        self.e21 = nn.Conv3d(16, 64, kernel_size=3, padding=1)
        self.e22 = nn.Conv3d(64, 64, kernel_size=3, padding=1)

        # Middle layers

        self.m1 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.m2 = nn.Conv3d(128, 128, kernel_size=3, padding=1)

        # Decoder layers

        # 128 -> 64 channels; concatenated with the 64-channel x_e22 skip -> 128 into d11.
        self.upconv1 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2, output_padding=(1, 0, 1))  # Upscaling
        self.d11 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.d12 = nn.Conv3d(64, 64, kernel_size=3, padding=1)

        # 64 -> 48 channels; concatenated with the 16-channel x_e12 skip -> 64 into d21.
        self.upconv2 = nn.ConvTranspose3d(64, 48, kernel_size=2, stride=2)  # Upscaling
        self.d21 = nn.Conv3d(64, 16, kernel_size=3, padding=1)
        self.d22 = nn.Conv3d(16, 16, kernel_size=3, padding=1)

        self.output = nn.Conv3d(16, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the network on ``x`` and return the prediction without its channel axis."""
        x = x.unsqueeze(1)  # add the single input channel expected by e11

        # Encoder stage 1: 1 -> 16 channels.
        x_e11 = self.activation(self.e11(x))
        x_e12 = self.activation(self.e12(x_e11))
        x_pool1 = self.pool(x_e12)

        # Encoder stage 2: 16 -> 64 channels through e21 then e22, with dropout.
        x_e21 = self.dropout(self.activation(self.e21(x_pool1)))
        x_e22 = self.dropout(self.activation(self.e22(x_e21)))
        x_pool2 = self.pool(x_e22)

        # Bottleneck: 64 -> 128 channels.
        x_m1 = self.activation(self.m1(x_pool2))
        x_m2 = self.activation(self.m2(x_m1))

        x_upconv1 = self.upconv1(x_m2)
        x_upconv1 = torch.cat([x_upconv1, x_e22], dim=1)  # Concatenating
        x_d11 = self.activation(self.d11(x_upconv1))
        x_d12 = self.activation(self.d12(x_d11))

        x_upconv2 = self.upconv2(x_d12)
        x_upconv2 = torch.cat([x_upconv2, x_e12], dim=1)  # Concatenating
        x_d21 = self.activation(self.d21(x_upconv2))
        x_d22 = self.activation(self.d22(x_d21))

        x_out = self.output(x_d22)

        return x_out.squeeze(1)


In [None]:
class UNetV3(nn.Module):
    """Three-level 3D U-Net: encoder 16/32/64 channels, 128-channel bottleneck.

    Each stage is two 3x3x3 "same" convolutions with Mish activations; the
    decoder upsamples with stride-2 transposed convolutions and concatenates
    the matching encoder output before its two convolutions.

    ``forward`` inserts a singleton channel axis at dim 1 and removes it from
    the one-channel prediction. NOTE(review): the hard-coded ``output_padding``
    values only line up with the skip tensors for specific input sizes, e.g.
    spatial size (14, 12, 14); confirm against the data.
    """

    def __init__(self):
        super().__init__()

        def conv3(cin, cout):
            # 3x3x3 convolution; padding=1 keeps D/H/W unchanged.
            return nn.Conv3d(cin, cout, kernel_size=3, padding=1)

        def up2(cin, cout, **kw):
            # Stride-2 transposed convolution: each spatial size s becomes 2*s + output_padding.
            return nn.ConvTranspose3d(cin, cout, kernel_size=2, stride=2, **kw)

        self.activation = nn.Mish()
        self.pool = nn.MaxPool3d(2, stride=2)
        # NOTE(review): registered but never applied in forward().
        self.dropout = nn.Dropout3d(0.2)

        # Encoder: 1 -> 16 -> 32 -> 64 channels.
        self.e11, self.e12 = conv3(1, 16), conv3(16, 16)
        self.e21, self.e22 = conv3(16, 32), conv3(32, 32)
        self.e31, self.e32 = conv3(32, 64), conv3(64, 64)

        # Bottleneck: 128 channels.
        self.m1, self.m2 = conv3(64, 128), conv3(128, 128)

        # Decoder: each upconv halves the channels and the skip concatenation restores them.
        self.upconv1 = up2(128, 64, output_padding=1)
        self.d11, self.d12 = conv3(128, 64), conv3(64, 64)
        self.upconv2 = up2(64, 32, output_padding=(1, 0, 1))
        self.d21, self.d22 = conv3(64, 32), conv3(32, 32)
        self.upconv3 = up2(32, 16)
        self.d31, self.d32 = conv3(32, 16), conv3(16, 16)

        self.output = conv3(16, 1)

    def _double_conv(self, x, first, second):
        """Apply ``first`` then ``second``, each followed by the activation."""
        return self.activation(second(self.activation(first(x))))

    def forward(self, x):
        """Predict a one-channel volume for ``x`` (channel axis added at dim 1 and squeezed back)."""
        out = x.unsqueeze(1)

        skips = []
        for first, second in ((self.e11, self.e12), (self.e21, self.e22), (self.e31, self.e32)):
            out = self._double_conv(out, first, second)
            skips.append(out)
            out = self.pool(out)

        out = self._double_conv(out, self.m1, self.m2)

        decoder = (
            (self.upconv1, self.d11, self.d12),
            (self.upconv2, self.d21, self.d22),
            (self.upconv3, self.d31, self.d32),
        )
        # Deepest decoder stage pairs with the deepest encoder output.
        for (upconv, first, second), skip in zip(decoder, reversed(skips)):
            out = torch.cat([upconv(out), skip], dim=1)
            out = self._double_conv(out, first, second)

        return self.output(out).squeeze(1)


# Instantiate the three-level network and move it to `device`
# (assumed to be defined in an earlier cell).
model = UNetV3().to(device=device)

In [None]:
class UNetV4(nn.Module):
    """Three-level 3D U-Net with channel dropout in the deeper encoder stages.

    Encoder 1 -> 16 -> 32 -> 64 channels, bottleneck 128, decoder mirrors it
    with skip concatenations. ``nn.Dropout3d(0.3)`` is applied after the e22,
    e31 and e32 activations.

    ``forward`` inserts a channel axis at dim 1 and squeezes it from the
    output; presumably ``x`` is (N, D, H, W) — confirm against the caller.

    NOTE(review): the hard-coded ``output_padding`` values only line up with
    the skip tensors for specific input sizes, e.g. spatial size (14, 12, 14).
    """

    def __init__(self):
        # Plain super(): the class must refer to itself here, not to a
        # different (possibly not-yet-defined) class.
        super().__init__()

        self.activation = nn.Mish()  # Activation function
        self.pool = nn.MaxPool3d(2, stride=2)  # Pooling (halves D/H/W)
        self.dropout = nn.Dropout3d(0.3)  # Dropout

        # Encoder layers

        self.e11 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.e12 = nn.Conv3d(16, 16, kernel_size=3, padding=1)

        self.e21 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.e22 = nn.Conv3d(32, 32, kernel_size=3, padding=1)

        self.e31 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.e32 = nn.Conv3d(64, 64, kernel_size=3, padding=1)

        # Middle layers

        self.m1 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.m2 = nn.Conv3d(128, 128, kernel_size=3, padding=1)

        # Decoder layers: each upconv halves the channels, the skip concatenation restores them.

        self.upconv1 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2, output_padding=1)  # Upscaling
        self.d11 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.d12 = nn.Conv3d(64, 64, kernel_size=3, padding=1)

        self.upconv2 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2, output_padding=(1, 0, 1))  # Upscaling
        self.d21 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
        self.d22 = nn.Conv3d(32, 32, kernel_size=3, padding=1)

        self.upconv3 = nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2)  # Upscaling
        self.d31 = nn.Conv3d(32, 16, kernel_size=3, padding=1)
        self.d32 = nn.Conv3d(16, 16, kernel_size=3, padding=1)


        self.output = nn.Conv3d(16, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the network on ``x`` and return the prediction without its channel axis."""
        x = x.unsqueeze(1)  # add the single input channel expected by e11

        x_e11 = self.activation(self.e11(x))
        x_e12 = self.activation(self.e12(x_e11))
        x_pool1 = self.pool(x_e12)

        x_e21 = self.activation(self.e21(x_pool1))
        x_e22 = self.dropout(self.activation(self.e22(x_e21)))
        x_pool2 = self.pool(x_e22)

        x_e31 = self.dropout(self.activation(self.e31(x_pool2)))
        x_e32 = self.dropout(self.activation(self.e32(x_e31)))
        x_pool3 = self.pool(x_e32)

        x_m1 = self.activation(self.m1(x_pool3))
        x_m2 = self.activation(self.m2(x_m1))

        x_upconv1 = self.upconv1(x_m2)
        x_upconv1 = torch.cat([x_upconv1, x_e32], dim=1)  # Concatenating
        x_d11 = self.activation(self.d11(x_upconv1))
        x_d12 = self.activation(self.d12(x_d11))

        x_upconv2 = self.upconv2(x_d12)
        x_upconv2 = torch.cat([x_upconv2, x_e22], dim=1)  # Concatenating
        x_d21 = self.activation(self.d21(x_upconv2))
        x_d22 = self.activation(self.d22(x_d21))

        x_upconv3 = self.upconv3(x_d22)
        x_upconv3 = torch.cat([x_upconv3, x_e12], dim=1)  # Concatenating
        x_d31 = self.activation(self.d31(x_upconv3))
        x_d32 = self.activation(self.d32(x_d31))

        x_out = self.output(x_d32)

        return x_out.squeeze(1)


# Instantiate the dropout-regularised three-level network and move it to
# `device` (assumed to be defined in an earlier cell).
model = UNetV4().to(device=device)

In [None]:
class UNetV5(nn.Module):
    """Three-level 3D U-Net without dropout: encoder 16/32/64, bottleneck 128.

    Each stage is two 3x3x3 "same" convolutions with Mish activations; the
    decoder upsamples with stride-2 transposed convolutions and concatenates
    the matching encoder output before its two convolutions.

    ``forward`` inserts a singleton channel axis at dim 1 and removes it from
    the one-channel prediction. NOTE(review): the hard-coded ``output_padding``
    values only line up with the skip tensors for specific input sizes, e.g.
    spatial size (14, 12, 14); confirm against the data.
    """

    def __init__(self):
        super().__init__()

        def conv3(cin, cout):
            # 3x3x3 convolution; padding=1 keeps D/H/W unchanged.
            return nn.Conv3d(cin, cout, kernel_size=3, padding=1)

        def up2(cin, cout, **kw):
            # Stride-2 transposed convolution: each spatial size s becomes 2*s + output_padding.
            return nn.ConvTranspose3d(cin, cout, kernel_size=2, stride=2, **kw)

        self.activation = nn.Mish()
        self.pool = nn.MaxPool3d(2, stride=2)

        # Encoder: 1 -> 16 -> 32 -> 64 channels.
        self.e11, self.e12 = conv3(1, 16), conv3(16, 16)
        self.e21, self.e22 = conv3(16, 32), conv3(32, 32)
        self.e31, self.e32 = conv3(32, 64), conv3(64, 64)

        # Bottleneck: 128 channels.
        self.m1, self.m2 = conv3(64, 128), conv3(128, 128)

        # Decoder: each upconv halves the channels and the skip concatenation restores them.
        self.upconv1 = up2(128, 64, output_padding=1)
        self.d11, self.d12 = conv3(128, 64), conv3(64, 64)
        self.upconv2 = up2(64, 32, output_padding=(1, 0, 1))
        self.d21, self.d22 = conv3(64, 32), conv3(32, 32)
        self.upconv3 = up2(32, 16)
        self.d31, self.d32 = conv3(32, 16), conv3(16, 16)

        self.output = conv3(16, 1)

    def _double_conv(self, x, first, second):
        """Apply ``first`` then ``second``, each followed by the activation."""
        return self.activation(second(self.activation(first(x))))

    def forward(self, x):
        """Predict a one-channel volume for ``x`` (channel axis added at dim 1 and squeezed back)."""
        out = x.unsqueeze(1)

        skips = []
        for first, second in ((self.e11, self.e12), (self.e21, self.e22), (self.e31, self.e32)):
            out = self._double_conv(out, first, second)
            skips.append(out)
            out = self.pool(out)

        out = self._double_conv(out, self.m1, self.m2)

        decoder = (
            (self.upconv1, self.d11, self.d12),
            (self.upconv2, self.d21, self.d22),
            (self.upconv3, self.d31, self.d32),
        )
        # Deepest decoder stage pairs with the deepest encoder output.
        for (upconv, first, second), skip in zip(decoder, reversed(skips)):
            out = torch.cat([upconv(out), skip], dim=1)
            out = self._double_conv(out, first, second)

        return self.output(out).squeeze(1)


# Instantiate the dropout-free three-level network and move it to `device`
# (assumed to be defined in an earlier cell).
model = UNetV5().to(device=device)

In [None]:
class UNetV6(nn.Module):
    """Four-level 3D U-Net: encoder 16/32/64/128 channels, 256-channel bottleneck.

    Each stage is two 3x3x3 "same" convolutions with Mish activations; the
    decoder upsamples with stride-2 transposed convolutions and concatenates
    the matching encoder output before its two convolutions.

    ``forward`` inserts a singleton channel axis at dim 1 and removes it from
    the one-channel prediction. NOTE(review): the hard-coded ``output_padding``
    values only line up with the skip tensors for specific input sizes, e.g.
    spatial size (22, 28, 22); confirm against the data.
    """

    def __init__(self):
        super().__init__()

        def conv3(cin, cout):
            # 3x3x3 convolution; padding=1 keeps D/H/W unchanged.
            return nn.Conv3d(cin, cout, kernel_size=3, padding=1)

        def up2(cin, cout, **kw):
            # Stride-2 transposed convolution: each spatial size s becomes 2*s + output_padding.
            return nn.ConvTranspose3d(cin, cout, kernel_size=2, stride=2, **kw)

        self.activation = nn.Mish()
        self.pool = nn.MaxPool3d(2, stride=2)

        # Encoder: 1 -> 16 -> 32 -> 64 -> 128 channels.
        self.e11, self.e12 = conv3(1, 16), conv3(16, 16)
        self.e21, self.e22 = conv3(16, 32), conv3(32, 32)
        self.e31, self.e32 = conv3(32, 64), conv3(64, 64)
        self.e41, self.e42 = conv3(64, 128), conv3(128, 128)

        # Bottleneck: 256 channels.
        self.m1, self.m2 = conv3(128, 256), conv3(256, 256)

        # Decoder: each upconv halves the channels and the skip concatenation restores them.
        self.upconv1 = up2(256, 128, output_padding=(0, 1, 0))
        self.d11, self.d12 = conv3(256, 128), conv3(128, 128)
        self.upconv2 = up2(128, 64, output_padding=1)
        self.d21, self.d22 = conv3(128, 64), conv3(64, 64)
        self.upconv3 = up2(64, 32, output_padding=(1, 0, 1))
        self.d31, self.d32 = conv3(64, 32), conv3(32, 32)
        self.upconv4 = up2(32, 16)
        self.d41, self.d42 = conv3(32, 16), conv3(16, 16)

        self.output = conv3(16, 1)

    def _double_conv(self, x, first, second):
        """Apply ``first`` then ``second``, each followed by the activation."""
        return self.activation(second(self.activation(first(x))))

    def forward(self, x):
        """Predict a one-channel volume for ``x`` (channel axis added at dim 1 and squeezed back)."""
        out = x.unsqueeze(1)

        encoder = (
            (self.e11, self.e12),
            (self.e21, self.e22),
            (self.e31, self.e32),
            (self.e41, self.e42),
        )
        skips = []
        for first, second in encoder:
            out = self._double_conv(out, first, second)
            skips.append(out)
            out = self.pool(out)

        out = self._double_conv(out, self.m1, self.m2)

        decoder = (
            (self.upconv1, self.d11, self.d12),
            (self.upconv2, self.d21, self.d22),
            (self.upconv3, self.d31, self.d32),
            (self.upconv4, self.d41, self.d42),
        )
        # Deepest decoder stage pairs with the deepest encoder output.
        for (upconv, first, second), skip in zip(decoder, reversed(skips)):
            out = torch.cat([upconv(out), skip], dim=1)
            out = self._double_conv(out, first, second)

        return self.output(out).squeeze(1)


# Instantiate the four-level network and move it to `device`
# (assumed to be defined in an earlier cell).
model = UNetV6().to(device=device)

In [None]:
class UNetV7(nn.Module):
    """Wide four-level 3D U-Net: encoder 32/64/128/256 channels, 512-channel bottleneck.

    Each stage is two 3x3x3 "same" convolutions with Mish activations; the
    decoder upsamples with stride-2 transposed convolutions and concatenates
    the matching encoder output before its two convolutions.

    ``forward`` inserts a singleton channel axis at dim 1 and removes it from
    the one-channel prediction. NOTE(review): the hard-coded ``output_padding``
    values only line up with the skip tensors for specific input sizes, e.g.
    spatial size (22, 28, 22); confirm against the data.
    """

    def __init__(self):
        super().__init__()

        def conv3(cin, cout):
            # 3x3x3 convolution; padding=1 keeps D/H/W unchanged.
            return nn.Conv3d(cin, cout, kernel_size=3, padding=1)

        def up2(cin, cout, **kw):
            # Stride-2 transposed convolution: each spatial size s becomes 2*s + output_padding.
            return nn.ConvTranspose3d(cin, cout, kernel_size=2, stride=2, **kw)

        self.activation = nn.Mish()
        self.pool = nn.MaxPool3d(2, stride=2)

        # Encoder: 1 -> 32 -> 64 -> 128 -> 256 channels.
        self.e11, self.e12 = conv3(1, 32), conv3(32, 32)
        self.e21, self.e22 = conv3(32, 64), conv3(64, 64)
        self.e31, self.e32 = conv3(64, 128), conv3(128, 128)
        self.e41, self.e42 = conv3(128, 256), conv3(256, 256)

        # Bottleneck: 512 channels.
        self.m1, self.m2 = conv3(256, 512), conv3(512, 512)

        # Decoder: each upconv halves the channels and the skip concatenation restores them.
        self.upconv1 = up2(512, 256, output_padding=(0, 1, 0))
        self.d11, self.d12 = conv3(512, 256), conv3(256, 256)
        self.upconv2 = up2(256, 128, output_padding=1)
        self.d21, self.d22 = conv3(256, 128), conv3(128, 128)
        self.upconv3 = up2(128, 64, output_padding=(1, 0, 1))
        self.d31, self.d32 = conv3(128, 64), conv3(64, 64)
        self.upconv4 = up2(64, 32)
        self.d41, self.d42 = conv3(64, 32), conv3(32, 32)

        self.output = conv3(32, 1)

    def _double_conv(self, x, first, second):
        """Apply ``first`` then ``second``, each followed by the activation."""
        return self.activation(second(self.activation(first(x))))

    def forward(self, x):
        """Predict a one-channel volume for ``x`` (channel axis added at dim 1 and squeezed back)."""
        out = x.unsqueeze(1)

        encoder = (
            (self.e11, self.e12),
            (self.e21, self.e22),
            (self.e31, self.e32),
            (self.e41, self.e42),
        )
        skips = []
        for first, second in encoder:
            out = self._double_conv(out, first, second)
            skips.append(out)
            out = self.pool(out)

        out = self._double_conv(out, self.m1, self.m2)

        decoder = (
            (self.upconv1, self.d11, self.d12),
            (self.upconv2, self.d21, self.d22),
            (self.upconv3, self.d31, self.d32),
            (self.upconv4, self.d41, self.d42),
        )
        # Deepest decoder stage pairs with the deepest encoder output.
        for (upconv, first, second), skip in zip(decoder, reversed(skips)):
            out = torch.cat([upconv(out), skip], dim=1)
            out = self._double_conv(out, first, second)

        return self.output(out).squeeze(1)


# Instantiate the wide four-level network and move it to `device`
# (assumed to be defined in an earlier cell).
model = UNetV7().to(device=device)

In [None]:
class UNetV9(nn.Module):
    """Four-level 3D U-Net with BatchNorm: encoder 16/32/64/128, bottleneck 256.

    Every 3x3x3 "same" convolution is followed by ``BatchNorm3d`` and a Mish
    activation. The decoder upsamples with stride-2 transposed convolutions
    (no normalisation) and concatenates the matching encoder output before
    its two conv/BN/activation steps.

    ``forward`` inserts a singleton channel axis at dim 1 and removes it from
    the one-channel prediction. NOTE(review): the hard-coded ``output_padding``
    values only line up with the skip tensors for specific input sizes, e.g.
    spatial size (22, 28, 22); confirm against the data.
    """

    def __init__(self):
        super().__init__()

        def conv3(cin, cout):
            # 3x3x3 convolution; padding=1 keeps D/H/W unchanged.
            return nn.Conv3d(cin, cout, kernel_size=3, padding=1)

        self.activation = nn.Mish()
        self.pool = nn.MaxPool3d(2, stride=2)

        # Encoder: (conv, BN) pairs, 1 -> 16 -> 32 -> 64 -> 128 channels.
        self.e11, self.e_bn11 = conv3(1, 16), nn.BatchNorm3d(16)
        self.e12, self.e_bn12 = conv3(16, 16), nn.BatchNorm3d(16)

        self.e21, self.e_bn21 = conv3(16, 32), nn.BatchNorm3d(32)
        self.e22, self.e_bn22 = conv3(32, 32), nn.BatchNorm3d(32)

        self.e31, self.e_bn31 = conv3(32, 64), nn.BatchNorm3d(64)
        self.e32, self.e_bn32 = conv3(64, 64), nn.BatchNorm3d(64)

        self.e41, self.e_bn41 = conv3(64, 128), nn.BatchNorm3d(128)
        self.e42, self.e_bn42 = conv3(128, 128), nn.BatchNorm3d(128)

        # Bottleneck: 256 channels.
        self.m1, self.m_bn1 = conv3(128, 256), nn.BatchNorm3d(256)
        self.m2, self.m_bn2 = conv3(256, 256), nn.BatchNorm3d(256)

        # Decoder: each upconv halves the channels and the skip concatenation restores them.
        self.upconv1 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2, output_padding=(0, 1, 0))
        self.d11, self.d_bn11 = conv3(256, 128), nn.BatchNorm3d(128)
        self.d12, self.d_bn12 = conv3(128, 128), nn.BatchNorm3d(128)

        self.upconv2 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2, output_padding=1)
        self.d21, self.d_bn21 = conv3(128, 64), nn.BatchNorm3d(64)
        self.d22, self.d_bn22 = conv3(64, 64), nn.BatchNorm3d(64)

        self.upconv3 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2, output_padding=(1, 0, 1))
        self.d31, self.d_bn31 = conv3(64, 32), nn.BatchNorm3d(32)
        self.d32, self.d_bn32 = conv3(32, 32), nn.BatchNorm3d(32)

        self.upconv4 = nn.ConvTranspose3d(32, 16, kernel_size=2, stride=2)
        self.d41, self.d_bn41 = conv3(32, 16), nn.BatchNorm3d(16)
        self.d42, self.d_bn42 = conv3(16, 16), nn.BatchNorm3d(16)

        self.output = conv3(16, 1)

    def _conv_bn_act(self, x, conv, norm):
        """Convolution, then batch norm, then the activation."""
        return self.activation(norm(conv(x)))

    def forward(self, x):
        """Predict a one-channel volume for ``x`` (channel axis added at dim 1 and squeezed back)."""
        out = x.unsqueeze(1)

        encoder = (
            (self.e11, self.e_bn11, self.e12, self.e_bn12),
            (self.e21, self.e_bn21, self.e22, self.e_bn22),
            (self.e31, self.e_bn31, self.e32, self.e_bn32),
            (self.e41, self.e_bn41, self.e42, self.e_bn42),
        )
        skips = []
        for conv_a, bn_a, conv_b, bn_b in encoder:
            out = self._conv_bn_act(out, conv_a, bn_a)
            out = self._conv_bn_act(out, conv_b, bn_b)
            skips.append(out)
            out = self.pool(out)

        out = self._conv_bn_act(out, self.m1, self.m_bn1)
        out = self._conv_bn_act(out, self.m2, self.m_bn2)

        decoder = (
            (self.upconv1, self.d11, self.d_bn11, self.d12, self.d_bn12),
            (self.upconv2, self.d21, self.d_bn21, self.d22, self.d_bn22),
            (self.upconv3, self.d31, self.d_bn31, self.d32, self.d_bn32),
            (self.upconv4, self.d41, self.d_bn41, self.d42, self.d_bn42),
        )
        # Deepest decoder stage pairs with the deepest encoder output.
        for (upconv, conv_a, bn_a, conv_b, bn_b), skip in zip(decoder, reversed(skips)):
            out = torch.cat([upconv(out), skip], dim=1)
            out = self._conv_bn_act(out, conv_a, bn_a)
            out = self._conv_bn_act(out, conv_b, bn_b)

        return self.output(out).squeeze(1)


# Instantiate the BatchNorm four-level network and move it to `device`
# (assumed to be defined in an earlier cell).
model = UNetV9().to(device=device)