In [2]:
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two 3x3x3 convolutions, each followed by LeakyReLU(negative_slope=0.01).

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes from ``in_channels``
    to ``out_channels``.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First conv maps in->out channels, second keeps out->out; each is
        # followed by its own activation (Sequential indices 0..3).
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv3d(stage_in, out_channels, kernel_size=3, stride=1, padding=1)
            )
            stages.append(nn.LeakyReLU(negative_slope=0.01))
        self.conv_block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv -> LeakyReLU -> conv -> LeakyReLU to ``x``."""
        return self.conv_block(x)

class Encoder(nn.Module):
    """Contracting path: three ConvBlocks mapping 1 -> 16 -> 32 -> 64 channels.

    ``forward`` returns the output of every block, shallowest first, so the
    caller can use them as skip connections.

    NOTE(review): ``self.pool`` is MaxPool3d(kernel_size=1, stride=1), which is
    an identity operation, so no downsampling happens and every returned
    feature map has the input's spatial size. The Decoder in this file has no
    upsampling and relies on that, so both must change together if real
    pooling is wanted.
    """

    def __init__(self):
        super().__init__()
        self.enc_blocks = nn.ModuleList([
            ConvBlock(1, 16),
            ConvBlock(16, 32),
            ConvBlock(32, 64),
        ])
        self.pool = nn.MaxPool3d(kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run ``x`` through all encoder blocks.

        Returns:
            List with one tensor per block (16, 32 and 64 channels), in the
            order the blocks are applied.
        """
        enc_features = []
        for i, block in enumerate(self.enc_blocks):
            # Pool only *between* blocks: pooling after the last block would
            # produce a tensor that is never used.
            if i:
                x = self.pool(x)
            x = block(x)
            enc_features.append(x)
        return enc_features

class Decoder(nn.Module):
    """Expanding path: two stages of conv -> concat(skip) -> ConvBlock.

    Stage channel flow: 64 -> 32 (concat 32-channel skip -> 64) -> 32, then
    32 -> 16 (concat 16-channel skip -> 32) -> 16.

    NOTE(review): despite the name, ``upconvs`` are stride-1, padding-1
    Conv3d layers, so they change channels but do not upsample. The name is
    kept because it determines the ``state_dict`` keys.
    """

    def __init__(self):
        super().__init__()
        # Channel-reducing convs, created in the order 64->32 then 32->16.
        self.upconvs = nn.ModuleList(
            nn.Conv3d(c_in, c_out, kernel_size=3, stride=1, padding=1)
            for c_in, c_out in ((64, 32), (32, 16))
        )
        # After concatenation each stage sees twice its target channel count.
        self.dec_blocks = nn.ModuleList(
            ConvBlock(2 * c_out, c_out) for c_out in (32, 16)
        )

    def forward(self, x, enc_features):
        """Decode ``x`` using ``enc_features`` as skip connections.

        Args:
            x: Deepest feature map (64 channels).
            enc_features: Skip tensors, deepest first. Item ``i`` is
                concatenated with the output of ``upconvs[i]``, so it must
                match that output's spatial size.

        Returns:
            16-channel feature map from the last decoder block.
        """
        for stage, upconv in enumerate(self.upconvs):
            reduced = upconv(x)
            merged = torch.cat((reduced, enc_features[stage]), dim=1)  # channel axis
            x = self.dec_blocks[stage](merged)
        return x

class UNet(nn.Module):
    """Encoder/decoder network with skip connections and a 1x1x1 output head.

    Takes a 1-channel volume and returns a 2-channel volume. Every conv in
    the network uses stride 1 with size-preserving padding, and the pool is
    kernel 1 / stride 1, so the spatial size is unchanged end to end.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        # Projects the decoder's 16 channels to 2 output channels
        # (presumably per-voxel class logits; confirm against the loss used).
        self.c1x1 = nn.Conv3d(16, 2, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Map ``x`` (N, 1, D, H, W) to a 2-channel output of the same spatial size."""
        features = self.encoder(x)
        # The deepest map seeds the decoder. The remaining maps become skip
        # connections ordered deepest -> shallowest.
        deepest = features[-1]
        skips = features[-2::-1]
        decoded = self.decoder(deepest, skips)
        return self.c1x1(decoded)

# Example usage: one forward pass on a random volume to check the output shape.
model = UNet()
input_tensor = torch.randn(1, 1, 64, 64, 64)  # batch size of 1, 1 channel, 64x64x64 volume
# Inference only. Without no_grad, the autograd graph (every intermediate
# activation of the 64^3 volume) stays alive as long as the global `output`.
with torch.no_grad():
    output = model(input_tensor)
print(output.shape)  # should be (1, 2, 64, 64, 64)
assert output.shape == (1, 2, 64, 64, 64), output.shape


torch.Size([1, 2, 64, 64, 64])
