Imports:

In [5]:
import torch
import torch.nn as nn

Basic Building Blocks of U-Net:
* Double convolution blocks
* Downsampling blocks
* Upsampling blocks
* Skip connections

Double Convolution Block:
- Two 3×3 convolutions keep the spatial size
- BatchNorm stabilizes training
- ReLU gives non-linearity
- Padding=1 keeps height/width unchanged


In [6]:
class DoubleConv(nn.Module):
    """
    Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Both convolutions use padding=1 with the default stride of 1, so for an
    input of shape (N, in_channels, H, W) the output has shape
    (N, out_channels, H, W): only the channel count changes.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in_channels -> out_channels; second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Kept as a single Sequential named `net` so parameter names
        # (net.0.weight, net.1.running_mean, ...) stay stable for checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run both conv stages on `x` and return the activated feature map."""
        return self.net(x)

Down Block (Encoder):
- Each down block halves height/width with 2×2 max pooling, then applies a double convolution.

In [7]:
class Down(nn.Module):
    """
    Encoder stage: 2x2 max pooling (halves H and W, flooring odd sizes)
    followed by a DoubleConv mapping in_channels -> out_channels.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        refine = DoubleConv(in_channels, out_channels)
        # Same Sequential layout (net.0 = pool, net.1 = DoubleConv) as before,
        # so state_dict keys are unchanged.
        self.net = nn.Sequential(pool, refine)

    def forward(self, x):
        """Downsample `x` by 2 spatially, then apply the double convolution."""
        return self.net(x)

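Up Block (Decoder):

- The decoder upsamples feature maps 2× (bilinear) and fuses them with the matching encoder features
- Skip connections preserve fine structure (edges, tumor boundaries)
- Padding aligns the upsampled map with its skip connection when an input size is not divisible by 16 (each max-pool stage floors odd sizes) — rare but important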

In [8]:
class Up(nn.Module):
    """
    Decoder stage of the U-Net:
    Upsample -> 1x1 Conv -> pad to skip size -> concatenate skip -> DoubleConv

    Channel contract: ``in_channels`` is the channel count AFTER concatenation.
    The decoder input ``x1`` must have ``in_channels // 2`` channels (the 1x1
    conv is built for exactly that many), and the skip tensor ``x2`` must
    supply the remaining ``in_channels - in_channels // 2`` channels so the
    concatenated tensor matches what ``DoubleConv(in_channels, ...)`` expects.

    Args:
        in_channels: channels of cat([x2, upsampled x1]).
        out_channels: channels produced by the block.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Bilinear upsampling doubles H and W; it has no learnable weights and
        # leaves the channel count unchanged.
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        # Learnable 1x1 projection after upsampling. It maps
        # in_channels // 2 -> in_channels // 2, i.e. it mixes channels but does
        # NOT reduce their number.
        self.conv1x1 = nn.Conv2d(in_channels // 2, in_channels // 2, kernel_size=1)

        # Fuses the concatenated skip + decoder features into out_channels.
        self.double_conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        Args:
            x1: decoder feature map from the previous (coarser) stage, with
                in_channels // 2 channels.
            x2: encoder skip-connection feature map at the target resolution.

        Returns:
            Feature map with out_channels channels and the spatial size of x2.
        """
        x1 = self.conv1x1(self.up(x1))

        # If an encoder stage had an odd H or W, 2x2 max pooling floored it and
        # the 2x upsample comes back one pixel short; pad x1 up to x2's H and W,
        # splitting the difference between the two sides.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        # F.pad lists pads for the last dim first: [left, right, top, bottom].
        x1 = nn.functional.pad(
            x1, [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2]
        )

        # Skip features first, then decoder features, along the channel axis.
        x = torch.cat([x2, x1], dim=1)
        return self.double_conv(x)

Full U-Net Architecture:

- Output is NOT passed through a sigmoid here
- Reason: BCEWithLogitsLoss expects raw logits and applies the sigmoid internally
- The sigmoid is applied later, during inference and evaluation

In [9]:
class UNet(nn.Module):
    """
    U-Net for per-pixel prediction, using bilinear upsampling in the decoder.

    Four Down stages (each halving H and W) encode the input. Four Up stages
    decode it back to the input resolution, fusing the matching encoder
    feature map at each step through skip connections. Stage widths scale with
    ``base_channels`` (c) as c, 2c, 4c, 8c, 8c (bottleneck).

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels.
        base_channels: width of the first stage. The default of 64 reproduces
            the 64-128-256-512-512 layout (and parameter names) exactly.

    The output is raw logits (no sigmoid) of shape (N, n_classes, H, W).
    """
    def __init__(self, n_channels=3, n_classes=1, base_channels=64):
        super().__init__()
        c = base_channels

        # Encoder: channel width doubles while spatial size halves.
        self.inc = DoubleConv(n_channels, c)
        self.down1 = Down(c, 2 * c)
        self.down2 = Down(2 * c, 4 * c)
        self.down3 = Down(4 * c, 8 * c)
        self.down4 = Down(8 * c, 8 * c)  # bottleneck: width kept at 8c

        # Decoder: each Up's in_channels = decoder channels + skip channels,
        # and the decoder input always carries exactly half of that.
        self.up1 = Up(8 * c + 8 * c, 4 * c)
        self.up2 = Up(4 * c + 4 * c, 2 * c)
        self.up3 = Up(2 * c + 2 * c, c)
        self.up4 = Up(c + c, c)

        # 1x1 head mapping features to per-pixel logits.
        self.out = nn.Conv2d(c, n_classes, kernel_size=1)

    def forward(self, x):
        """Map an (N, n_channels, H, W) batch to (N, n_classes, H, W) logits."""
        # Encoder: keep every stage's output for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: each Up restores the resolution of its skip input.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        return self.out(x)

Instantiate Model + Output-Shape Sanity Check:

In [10]:
model = UNet(n_channels=3, n_classes=1)

# Shape smoke test. eval() stops the random dummy batch from updating the
# BatchNorm running statistics, and no_grad() avoids building an autograd
# graph for a full-resolution forward pass.
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 640, 640)  # dummy input (N, C, H, W)
    y = model(x)
model.train()  # restore training mode (the default for a fresh model) for later cells

assert y.shape == (1, 1, 640, 640), y.shape
y.shape

torch.Size([1, 1, 640, 640])