In [None]:
# PyTorch core, the layer/module API, and the functional API (F.relu etc.).
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ResidualBlock(nn.Module):
    """Two-convolution 3D residual unit.

    The main branch is ``conv1 -> ReLU -> conv2``; its result is added to the
    shortcut branch and the sum goes through a final ReLU. Both convolutions
    use a 3x3x3 kernel with ``padding=1`` and the default stride, so the
    spatial size of the input is preserved.

    Args:
        in_channels: Channel count of the incoming tensor.
        out_channels: Channel count produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Main branch.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)

        # Shortcut branch: a 1x1x1 projection is only needed when the channel
        # count changes, otherwise the input is passed through untouched.
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))) + shortcut(x))``."""
        skip = self.shortcut(x)
        main = self.conv2(F.relu(self.conv1(x)))
        return F.relu(main + skip)

In [None]:
class VNetEncoder(nn.Module):
    """Sequential stack of ``num_blocks`` residual blocks.

    The first block maps ``in_channels`` to ``out_channels``; every following
    block maps ``out_channels`` to ``out_channels``.

    Args:
        in_channels: Channel count fed to the first block.
        out_channels: Channel count produced by every block.
        num_blocks: How many residual blocks to chain.
    """

    def __init__(self, in_channels, out_channels, num_blocks):
        super().__init__()
        # Only the block at position 0 sees the external channel count.
        stages = [
            ResidualBlock(in_channels if position == 0 else out_channels, out_channels)
            for position in range(num_blocks)
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked blocks in order."""
        return self.encoder(x)

In [None]:
class VNetDecoder(nn.Module):
    """Sequential stack of ``num_blocks`` residual blocks for the decoder path.

    The first block reduces ``in_channels`` to ``out_channels``; the remaining
    blocks keep ``out_channels`` unchanged.

    Args:
        in_channels: Channel count fed to the first block.
        out_channels: Channel count produced by every block.
        num_blocks: How many residual blocks to chain.
    """

    def __init__(self, in_channels, out_channels, num_blocks):
        super().__init__()

        def make_block(position):
            # The leading block adapts the incoming width; later ones are square.
            source = in_channels if position == 0 else out_channels
            return ResidualBlock(source, out_channels)

        self.decoder = nn.Sequential(*map(make_block, range(num_blocks)))

    def forward(self, x):
        """Run ``x`` through the stacked blocks in order."""
        return self.decoder(x)

In [None]:
class VNetDownsample(nn.Module):
    """Learned downsampling step implemented as a strided 3D convolution.

    A 3x3x3 convolution with ``stride=2`` and ``padding=1`` maps a spatial
    extent ``n`` to ``ceil(n / 2)`` while changing the channel count from
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Channel count of the incoming tensor.
        out_channels: Channel count of the downsampled tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Apply the strided convolution to ``x``."""
        reduced = self.conv(x)
        return reduced

In [None]:
class VNetUpsample(nn.Module):
    """Learned upsampling step implemented as a transposed 3D convolution.

    A 2x2x2 transposed convolution with ``stride=2`` doubles every spatial
    extent while changing the channel count from ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: Channel count of the incoming tensor.
        out_channels: Channel count of the upsampled tensor.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.ConvTranspose3d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2,
        )

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        enlarged = self.conv(x)
        return enlarged

In [None]:
class VNet(nn.Module):
    """V-Net style 3D encoder-decoder with skip connections.

    The encoder has four resolution levels separated by three strided
    downsampling steps; channel width doubles at each step
    (``base_channels`` -> ``8 * base_channels``). The decoder mirrors it with
    three transposed-convolution upsampling steps. After each upsampling the
    result is concatenated on the channel axis with the encoder feature map
    of the same level and refined by a stack of residual blocks. A final
    1x1x1 convolution maps to ``out_channels``; no activation is applied to
    the returned tensor.

    Args:
        in_channels: Channel count of the input volume.
        out_channels: Channel count of the output (e.g. number of classes).
        base_channels: Width of the first encoder level.
        num_blocks: Residual blocks per encoder/decoder level.
    """

    def __init__(self, in_channels, out_channels, base_channels=16, num_blocks=2):
        super(VNet, self).__init__()

        # Encoder path
        self.encoder1 = VNetEncoder(in_channels, base_channels, num_blocks)
        self.down1 = VNetDownsample(base_channels, base_channels * 2)
        self.encoder2 = VNetEncoder(base_channels * 2, base_channels * 2, num_blocks)
        self.down2 = VNetDownsample(base_channels * 2, base_channels * 4)
        self.encoder3 = VNetEncoder(base_channels * 4, base_channels * 4, num_blocks)
        self.down3 = VNetDownsample(base_channels * 4, base_channels * 8)
        self.encoder4 = VNetEncoder(base_channels * 8, base_channels * 8, num_blocks)

        # Decoder path. Each decoder takes twice the upsampled width because
        # the skip feature map is concatenated on the channel axis.
        self.up3 = VNetUpsample(base_channels * 8, base_channels * 4)
        self.decoder3 = VNetDecoder(base_channels * 8, base_channels * 4, num_blocks)
        self.up2 = VNetUpsample(base_channels * 4, base_channels * 2)
        self.decoder2 = VNetDecoder(base_channels * 4, base_channels * 2, num_blocks)
        self.up1 = VNetUpsample(base_channels * 2, base_channels)
        self.decoder1 = VNetDecoder(base_channels * 2, base_channels, num_blocks)

        # Final output
        self.final_conv = nn.Conv3d(base_channels, out_channels, kernel_size=1)

    @staticmethod
    def _crop_to_skip(upsampled, skip):
        """Crop ``upsampled`` so its last three dims match those of ``skip``.

        The stride-2 downsampling maps an extent ``n`` to ``ceil(n / 2)`` and
        the transposed convolution doubles it again, so for odd ``n`` the
        upsampled map is one voxel larger than the skip feature map and
        cannot be concatenated. Trailing voxels are dropped in that case;
        when the shapes already agree the tensor is returned unchanged.
        """
        target = skip.shape[-3:]
        if upsampled.shape[-3:] == target:
            return upsampled
        depth, height, width = target
        return upsampled[..., :depth, :height, :width]

    def forward(self, x):
        """Run the network on a batched volume.

        Args:
            x: Tensor laid out as ``(N, in_channels, D, H, W)``; skip
                connections are concatenated on ``dim=1``, which is the
                channel axis for this layout. ``D``, ``H`` and ``W`` need
                not be divisible by 8.

        Returns:
            Tensor of shape ``(N, out_channels, D, H, W)``.
        """
        # Encoder path
        enc1 = self.encoder1(x)
        down1 = self.down1(enc1)
        enc2 = self.encoder2(down1)
        down2 = self.down2(enc2)
        enc3 = self.encoder3(down2)
        down3 = self.down3(enc3)
        enc4 = self.encoder4(down3)

        # Decoder path with skip connections
        up3 = self._crop_to_skip(self.up3(enc4), enc3)
        dec3 = self.decoder3(torch.cat([up3, enc3], dim=1))

        up2 = self._crop_to_skip(self.up2(dec3), enc2)
        dec2 = self.decoder2(torch.cat([up2, enc2], dim=1))

        up1 = self._crop_to_skip(self.up1(dec2), enc1)
        dec1 = self.decoder1(torch.cat([up1, enc1], dim=1))

        # Final segmentation output
        output = self.final_conv(dec1)
        return output