In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class DenseBlock(nn.Module):
    """Densely connected stack of 3x3 convolutions.

    Every layer receives the concatenation of the block input and all earlier
    layers' ReLU-activated outputs, and contributes ``growth_rate`` new
    channels. The result therefore has ``in_channels + growth_rate * num_layers``
    channels. Spatial size is unchanged (3x3 kernel, padding 1, stride 1).

    Args:
        in_channels: channels of the tensor passed to ``forward``.
        growth_rate: channels each layer appends.
        num_layers: number of convolution layers in the block.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Layer ``idx`` consumes the block input plus ``idx * growth_rate``
        # channels already appended by the layers before it.
        self.layers = nn.ModuleList(
            nn.Conv2d(in_channels + idx * growth_rate, growth_rate, kernel_size=3, padding=1)
            for idx in range(num_layers)
        )

    def forward(self, x):
        """Run every layer, appending its activated output along the channel dim."""
        stacked = x
        for conv in self.layers:
            stacked = torch.cat([stacked, F.relu(conv(stacked))], dim=1)
        return stacked

In [None]:
class TransitionLayer(nn.Module):
    """Channel-changing 1x1 convolution followed by 2x2 max pooling.

    The pooling uses kernel 2 with the default stride (equal to the kernel), so
    height and width are halved, rounding down for odd sizes.

    Args:
        in_channels: channels of the tensor passed to ``forward``.
        out_channels: channels produced by the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Project channels with the 1x1 conv, then downsample 2x."""
        return self.pool(self.conv(x))

In [None]:
class UpConvBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, 3x3 conv + ReLU.

    Args:
        in_channels: channels of the tensor passed as ``x`` to ``forward``.
        out_channels: channels produced by the transposed conv and returned by
            the block.
        skip_channels: channels of the ``skip`` tensor passed to ``forward``.
            Defaults to ``in_channels - out_channels``, so the concatenated
            tensor has ``in_channels`` channels (the original assumption).
    """

    def __init__(self, in_channels, out_channels, skip_channels=None):
        super(UpConvBlock, self).__init__()
        if skip_channels is None:
            skip_channels = in_channels - out_channels
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        # The fusing conv sees the upsampled features plus the skip features.
        self.conv = nn.Conv2d(out_channels + skip_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, skip):
        """Upsample ``x``, concatenate ``skip`` on dim 1, and fuse with conv + ReLU.

        Inputs are presumably batched NCHW tensors (dim 1 is treated as the
        channel dim) — confirm with callers. The result has ``out_channels``
        channels and the spatial size of ``skip``.
        """
        x = self.upconv(x)
        # If an encoder stage max-pooled an odd-sized map, the 2x upsample comes
        # back one pixel short of the skip and torch.cat would fail. Pad (or,
        # with negative amounts, crop) symmetrically so the spatial sizes match.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, skip], dim=1)  # Concatenate the skip connection
        return F.relu(self.conv(x))

In [None]:
class DenseUNet(nn.Module):
    """U-Net whose encoder stages and bottleneck are DenseBlocks.

    Four encoder stages (DenseBlock + TransitionLayer) each halve the spatial
    size. A DenseBlock bottleneck follows. Four UpConvBlocks then upsample back,
    each concatenating the matching encoder output, and a final 1x1 conv maps
    to ``out_channels``.

    Every DenseBlock appends ``growth_rate * num_layers_per_block`` channels to
    its input, so the bottleneck and encoder outputs are wider than the widths
    the decoder is built for. ``UpConvBlock(c_in, c_out)`` expects ``x`` with
    ``c_in`` channels and a skip with ``c_in - c_out`` channels. The 1x1
    ``bottleneck_proj`` and ``skip1``..``skip4`` convolutions adapt the widths
    accordingly. Without them the forward pass fails on a channel mismatch.

    Spatial sizes line up for the skip concatenations when input height and
    width are divisible by 16 (four 2x max-pools).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output map (e.g. number of segmentation
            classes).
        growth_rate: channels added per DenseBlock layer; also the base width
            ``g`` of the stage widths (2g, 4g, 8g, 16g).
        num_layers_per_block: convolution layers in every DenseBlock.
    """

    def __init__(self, in_channels, out_channels, growth_rate=32, num_layers_per_block=4):
        super(DenseUNet, self).__init__()
        # Channels every DenseBlock appends to its input.
        block_growth = growth_rate * num_layers_per_block

        # Encoder
        self.enc1 = DenseBlock(in_channels, growth_rate, num_layers_per_block)
        self.trans1 = TransitionLayer(in_channels + block_growth, growth_rate * 2)

        self.enc2 = DenseBlock(growth_rate * 2, growth_rate, num_layers_per_block)
        self.trans2 = TransitionLayer(growth_rate * 2 + block_growth, growth_rate * 4)

        self.enc3 = DenseBlock(growth_rate * 4, growth_rate, num_layers_per_block)
        self.trans3 = TransitionLayer(growth_rate * 4 + block_growth, growth_rate * 8)

        self.enc4 = DenseBlock(growth_rate * 8, growth_rate, num_layers_per_block)
        self.trans4 = TransitionLayer(growth_rate * 8 + block_growth, growth_rate * 16)

        # Bottleneck: the DenseBlock emits 16g + block_growth channels; project
        # back to the 16g channels upconv4's transposed conv expects.
        self.bottleneck = DenseBlock(growth_rate * 16, growth_rate, num_layers_per_block)
        self.bottleneck_proj = nn.Conv2d(growth_rate * 16 + block_growth, growth_rate * 16, kernel_size=1)

        # Skip projections: narrow each encoder output to the c_in - c_out
        # channels the corresponding UpConvBlock's fusing conv is built for.
        self.skip4 = nn.Conv2d(growth_rate * 8 + block_growth, growth_rate * 8, kernel_size=1)
        self.skip3 = nn.Conv2d(growth_rate * 4 + block_growth, growth_rate * 4, kernel_size=1)
        self.skip2 = nn.Conv2d(growth_rate * 2 + block_growth, growth_rate * 2, kernel_size=1)
        self.skip1 = nn.Conv2d(in_channels + block_growth, growth_rate, kernel_size=1)

        # Decoder
        self.upconv4 = UpConvBlock(growth_rate * 16, growth_rate * 8)
        self.upconv3 = UpConvBlock(growth_rate * 8, growth_rate * 4)
        self.upconv2 = UpConvBlock(growth_rate * 4, growth_rate * 2)
        self.upconv1 = UpConvBlock(growth_rate * 2, growth_rate)

        # Final output layer
        self.final_conv = nn.Conv2d(growth_rate, out_channels, kernel_size=1)

    def forward(self, x):
        """Encode, pass through the bottleneck, decode with projected skips.

        Returns a tensor with ``out_channels`` channels. Its spatial size is that
        of ``enc1`` (the input's size), given the spatial alignment noted in the
        class docstring.
        """
        # Encoder path
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.trans1(enc1))
        enc3 = self.enc3(self.trans2(enc2))
        enc4 = self.enc4(self.trans3(enc3))

        # Bottleneck, projected to the decoder's input width
        bottleneck = self.bottleneck_proj(self.bottleneck(self.trans4(enc4)))

        # Decoder path with width-matched skip connections
        dec4 = self.upconv4(bottleneck, self.skip4(enc4))
        dec3 = self.upconv3(dec4, self.skip3(enc3))
        dec2 = self.upconv2(dec3, self.skip2(enc2))
        dec1 = self.upconv1(dec2, self.skip1(enc1))

        # Final segmentation output
        return self.final_conv(dec1)
