In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class ResidualBlock(nn.Module):
    """3D residual block: two 3x3x3 convolutions plus a skip connection.

    Computes ``relu(conv2(relu(conv1(x))) + shortcut(x))``. Both 3x3x3
    convolutions use the default stride of 1 and padding of 1. Spatial
    dimensions are therefore preserved and only the channel count changes.

    Args:
        in_channels: number of channels of the input volume.
        out_channels: number of channels of the output volume.

    The skip path is the identity when the channel counts match. Otherwise
    it is a learned 1x1x1 convolution, so that both branches can be summed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Submodules are registered in the order conv1, conv2, shortcut.
        # Changing this order would change the state_dict layout and which
        # random draws each layer receives under a fixed seed.
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            # 1x1x1 projection so the skip path matches the main path's channels.
            self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the block to a Conv3d-compatible tensor.

        The input is presumably (N, C_in, D, H, W); confirm against callers.
        """
        skip = self.shortcut(x)
        hidden = F.relu(self.conv1(x))
        return F.relu(self.conv2(hidden) + skip)

In [None]:
class VNetEncoder(nn.Module):
    """Encoder stage: ``num_blocks`` residual blocks applied in sequence.

    The first block maps ``in_channels`` to ``out_channels``. Every later block
    maps ``out_channels`` to ``out_channels``. No downsampling happens in this
    module (see ``VNetDownsample``).

    If ``num_blocks`` is 0 (or negative), the stage is an empty
    ``nn.Sequential`` and returns its input unchanged. In that case the output
    has ``in_channels`` channels, not ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, num_blocks):
        super().__init__()
        # Only the first block sees the incoming channel count; the rest are
        # out_channels -> out_channels. Blocks are built in order, so seeded
        # initialization matches a plain loop.
        stage_blocks = [
            ResidualBlock(in_channels if block_idx == 0 else out_channels, out_channels)
            for block_idx in range(num_blocks)
        ]
        # The attribute name "encoder" determines the state_dict key prefix.
        self.encoder = nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Run ``x`` through every residual block of the stage."""
        encoded = self.encoder(x)
        return encoded

In [None]:
class VNetDecoder(nn.Module):
    """Decoder stage: ``num_blocks`` residual blocks applied in sequence.

    The structure is identical to ``VNetEncoder``. The first block maps
    ``in_channels`` to ``out_channels`` and the remaining blocks keep
    ``out_channels``. No upsampling happens in this module.

    If ``num_blocks`` is 0 (or negative), the stage is an empty
    ``nn.Sequential`` and returns its input unchanged.
    """

    def __init__(self, in_channels, out_channels, num_blocks):
        super().__init__()
        # Build blocks in order so seeded initialization matches a plain loop;
        # only block 0 consumes the incoming channel count.
        stage_blocks = [
            ResidualBlock(in_channels if block_idx == 0 else out_channels, out_channels)
            for block_idx in range(num_blocks)
        ]
        # The attribute name "decoder" determines the state_dict key prefix.
        self.decoder = nn.Sequential(*stage_blocks)

    def forward(self, x):
        """Run ``x`` through every residual block of the stage."""
        decoded = self.decoder(x)
        return decoded

In [None]:
class VNetDownsample(nn.Module):
    """Strided 3x3x3 convolution that roughly halves each spatial dimension.

    With ``stride=2`` and ``padding=1``, each spatial axis of length ``L``
    becomes ``floor((L + 2 - 3) / 2) + 1``, which equals ``ceil(L / 2)`` for
    ``L >= 1``. Channels go from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Downsample ``x`` with the strided convolution."""
        downsampled = self.conv(x)
        return downsampled