In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class UpSampleBlock(nn.Module):
    """One decoder stage: learned 2x upsampling followed by conv + GroupNorm + ReLU.

    The ``ConvTranspose3d`` (kernel 2, stride 2) doubles every spatial
    dimension and maps ``in_channels`` -> ``out_channels``. The following
    3x3x3 convolution (padding 1) keeps the spatial size unchanged.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by this stage. Must be divisible by 16
            because of the ``GroupNorm(16, ...)`` layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Registration order matters: it fixes parameter-init RNG consumption,
        # state_dict key order and the printed repr.
        self.upsample = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv = nn.Conv3d(out_channels, out_channels, 3, padding=1)
        self.gn = nn.GroupNorm(16, out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply upsample -> conv -> GroupNorm -> ReLU to ``x`` and return the result."""
        for stage in (self.upsample, self.conv, self.gn, self.relu):
            x = stage(x)
        return x

class ResidualBlock(nn.Module):
    """3D residual block: two conv+GroupNorm layers plus a shortcut connection.

    The main path is conv(3x3x3, ``stride``) -> GroupNorm -> ReLU ->
    conv(3x3x3) -> GroupNorm. The shortcut is the identity when shape is
    preserved. Otherwise it is a 1x1x1 strided projection followed by
    GroupNorm. The sum of the two paths goes through a final ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output. Must be divisible by 16 because
            of the ``GroupNorm(16, ...)`` layers.
        stride: Stride of the first convolution and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.gn1 = nn.GroupNorm(16, out_channels)
        # One stateless ReLU module, reused for both activations.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.gn2 = nn.GroupNorm(16, out_channels)

        # A projection is needed only when the shortcut must change shape.
        shape_changes = in_channels != out_channels or stride != 1
        self.downsample = (
            nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.GroupNorm(16, out_channels),
            )
            if shape_changes
            else None
        )

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.gn2(self.conv2(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)

class Encoder(nn.Module):
    """3D convolutional encoder that downsamples each spatial axis by 16x.

    Layout:
    - a stride-2 stem convolution (``in_channels`` -> 64), followed by channel
      dropout, GroupNorm and ReLU;
    - three stride-2 ``ResidualBlock`` stages (64 -> 128 -> 256 -> 256 channels).

    Args:
        in_channels: Number of input channels. Defaults to 4, the value that
            was previously hard-coded. Presumably this is one channel per input
            modality; confirm against the data loader.
        dropout_p: Probability for the ``nn.Dropout3d`` after the stem
            convolution. Defaults to 0.2, the previously hard-coded value.

    Input: a 5D tensor (N, in_channels, D, H, W).
    Output: (N, 256, D', H', W'). Each spatial size is halved, rounding up,
    four times, so sizes divisible by 16 come out as exactly D/16, H/16, W/16.
    """

    def __init__(self, in_channels=4, dropout_p=0.2):
        super(Encoder, self).__init__()
        self.initial_conv = nn.Conv3d(in_channels, 64, kernel_size=3, stride=2, padding=1)
        # Dropout3d zeroes entire feature channels and is only active in train mode.
        # NOTE(review): dropout is applied *before* GroupNorm, so the norm
        # re-standardises the surviving channels. Kept as-is to preserve behaviour.
        self.initial_dropout = nn.Dropout3d(p=dropout_p)
        self.gn_initial = nn.GroupNorm(16, 64)
        self.relu_initial = nn.ReLU(inplace=True)
        self.resblock1 = ResidualBlock(64, 128, stride=2)
        self.resblock2 = ResidualBlock(128, 256, stride=2)
        self.resblock3 = ResidualBlock(256, 256, stride=2)

    def forward(self, x):
        """Encode ``x``; see the class docstring for the shape contract."""
        x = self.initial_conv(x)
        x = self.initial_dropout(x)
        x = self.gn_initial(x)
        x = self.relu_initial(x)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)
        return x

class Decoder(nn.Module):
    """3D decoder that upsamples encoder features back to per-voxel class logits.

    Four ``UpSampleBlock`` stages each double every spatial dimension (16x in
    total), mirroring the 16x downsampling of ``Encoder``. The channel counts
    are 256 -> 256 -> 128 -> 64 -> 64. A final 1x1x1 convolution projects the
    64 channels to ``num_classes``.

    Args:
        num_classes: Number of output channels, one per segmentation class.
            Defaults to 3, the value previously hard-coded here.

    Input: assumed (N, 256, D, H, W). The 256 channels are fixed by
    ``upblock1``; the spatial size is whatever the encoder produced.
    Output: (N, num_classes, 16*D, 16*H, 16*W) raw logits. No softmax or
    sigmoid is applied, so the caller's loss function must expect logits.
    """

    def __init__(self, num_classes=3):
        super(Decoder, self).__init__()
        self.upblock1 = UpSampleBlock(256, 256)
        self.upblock2 = UpSampleBlock(256, 128)
        self.upblock3 = UpSampleBlock(128, 64)
        self.upblock4 = UpSampleBlock(64, 64)
        self.final_conv = nn.Conv3d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Upsample ``x`` through the four stages and return class logits."""
        x = self.upblock1(x)
        x = self.upblock2(x)
        x = self.upblock3(x)
        x = self.upblock4(x)
        x = self.final_conv(x)
        return x

class SegModel(nn.Module):
    """Encoder-decoder 3D segmentation network without skip connections.

    ``Encoder`` compresses the input 16x per spatial axis. ``Decoder`` expands
    it back 16x and emits per-voxel class logits.
    """

    def __init__(self):
        super().__init__()
        # The encoder is built first so parameter-init RNG order is unchanged.
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Return ``decoder(encoder(x))``."""
        return self.decoder(self.encoder(x))

# Seed the global torch RNG before building the model so that weight
# initialisation is reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Initialize model
model = SegModel()

# Bare last expression: Jupyter renders the module repr (the same text print() gives).
model


SegModel(
  (encoder): Encoder(
    (initial_conv): Conv3d(4, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (initial_dropout): Dropout3d(p=0.2, inplace=False)
    (gn_initial): GroupNorm(16, 64, eps=1e-05, affine=True)
    (relu_initial): ReLU(inplace=True)
    (resblock1): ResidualBlock(
      (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (gn1): GroupNorm(16, 128, eps=1e-05, affine=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (gn2): GroupNorm(16, 128, eps=1e-05, affine=True)
      (downsample): Sequential(
        (0): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(2, 2, 2))
        (1): GroupNorm(16, 128, eps=1e-05, affine=True)
      )
    )
    (resblock2): ResidualBlock(
      (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      (gn1): GroupNorm(16, 256, eps=1e-05, affine=True)
      (r

In [None]:

# Use the first GPU when one is present. Otherwise fall back to CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# `.to(device)` already moves every parameter and buffer. A separate `.cuda()`
# call is unnecessary and would raise on a CPU-only host.
model = SegModel().to(device)