# UNet from Scratch in PyTorch

A self-contained implementation of [UNet](https://arxiv.org/abs/1505.04597) (Ronneberger et al., 2015) for semantic segmentation.

**What you will build**:
1. `DoubleConv` — two 3x3 convolutions with BatchNorm + ReLU
2. `Down` — max-pool then DoubleConv (encoder)
3. `Up` — bilinear upsample + skip connection + DoubleConv (decoder)
4. `OutConv` — 1x1 convolution to output class logits
5. `UNet` — full encoder-decoder with skip connections matching the original paper
6. Training loop on FoodSeg103 (streamed from HuggingFace) with Dice + CE loss
7. Predicted mask visualisation vs ground truth

Dataset: [EduardoPacheco/FoodSeg103](https://huggingface.co/datasets/EduardoPacheco/FoodSeg103) streamed from Hugging Face — no local download required.

In [None]:
# Install dependencies (run once in container)
!pip install datasets --quiet

In [None]:
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from datasets import load_dataset
from torch.utils.data import IterableDataset, DataLoader

# Global configuration shared by the cells below.
IMG_SIZE = 256       # UNet input resolution (power of 2 for clean pooling)
NUM_CLASSES = 103    # FoodSeg103: 103 food semantic categories
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Per-channel normalisation statistics, shaped (3, 1, 1) so they broadcast
# over a (3, H, W) image tensor.
MEAN = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

print(f'Device: {DEVICE}')
if DEVICE.type == 'cuda':
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU : {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {gpu_props.total_memory / 1024**3:.1f} GB')

## 1. Dataset

We stream [FoodSeg103](https://huggingface.co/datasets/EduardoPacheco/FoodSeg103) directly from the Hugging Face Hub.
Each sample contains:
- `image` — RGB photograph of a food dish
- `label` — per-pixel class mask in [0, 102] (103 food categories)

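We resize both image and mask to `IMG_SIZE x IMG_SIZE`. The mask is resized with nearest-neighbour interpolation so that every pixel keeps a valid class index; the image uses PIL's default resampling filter.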

In [None]:
class FoodSeg103StreamDataset(IterableDataset):
    """Stream FoodSeg103 from HuggingFace and resize to IMG_SIZE x IMG_SIZE.

    Each iteration yields an ``(image, mask)`` pair:
      * ``image`` - float tensor (3, IMG_SIZE, IMG_SIZE), normalised with
        the module-level MEAN / STD.
      * ``mask``  - long tensor (IMG_SIZE, IMG_SIZE) of class indices,
        clamped to [0, NUM_CLASSES - 1].

    Args:
        split: Dataset split name forwarded to ``load_dataset``.
        max_samples: Maximum number of samples yielded per pass over the
            stream. ``None`` means no limit; ``0`` yields nothing.
    """

    def __init__(self, split: str = 'train', max_samples: Optional[int] = None):
        super().__init__()
        self.ds = load_dataset('EduardoPacheco/FoodSeg103', split=split, streaming=True)
        self.max_samples = max_samples

    def __iter__(self):
        for count, sample in enumerate(self.ds):
            # Explicit None check: a plain truthiness test would treat
            # max_samples=0 as "unlimited" instead of "yield nothing".
            if self.max_samples is not None and count >= self.max_samples:
                break
            img = sample['image'].convert('RGB').resize((IMG_SIZE, IMG_SIZE))
            # Nearest-neighbour keeps mask values as valid class indices.
            mask = sample['label'].resize((IMG_SIZE, IMG_SIZE), Image.NEAREST)
            img_t = TF.to_tensor(img)                          # (3, H, W)
            img_t = (img_t - MEAN) / STD                       # ImageNet normalisation
            mask_t = torch.from_numpy(np.array(mask)).long()   # (H, W)
            # NOTE(review): the clamp silently folds any label above
            # NUM_CLASSES - 1 into the last class instead of raising -
            # confirm the dataset's actual label range.
            mask_t = mask_t.clamp(0, NUM_CLASSES - 1)
            yield img_t, mask_t


# Sanity check: pull two samples from the stream and inspect them
ds_check = FoodSeg103StreamDataset(split='train', max_samples=2)
imgs_c, masks_c = [], []
for img, mask in ds_check:
    imgs_c.append(img)
    masks_c.append(mask)

first_img, first_mask = imgs_c[0], masks_c[0]
print(f'Image shape : {first_img.shape}')
print(f'Mask shape  : {first_mask.shape}')
print(f'Unique mask values (sample): {first_mask.unique().tolist()[:10]} ...')

## 2. UNet architecture

The original UNet paper (Ronneberger et al., 2015) uses:
- **Encoder**: 4 x (DoubleConv -> MaxPool), doubling channels each level: 64 -> 128 -> 256 -> 512
- **Bottleneck**: DoubleConv at 1024 channels
- **Decoder**: 4 x (Upsample + skip concat + DoubleConv), halving channels each level
- **Output**: 1x1 convolution to `NUM_CLASSES` channels

We use bilinear upsampling (rather than transposed convolutions) to avoid checkerboard artefacts.

In [None]:
class DoubleConv(nn.Module):
    """Apply (3x3 Conv -> BatchNorm -> ReLU) twice in a row.

    The first stage maps ``in_channels`` to ``out_channels``; the second
    stage keeps ``out_channels``. ``padding=1`` preserves the spatial size.
    The convolutions carry no bias; the BatchNorm that follows each one
    supplies its own learnable shift.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        layers = []
        # Two identical stages that differ only in their input width.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.block(x)
        return out


class Down(nn.Module):
    """Encoder stage: a 2x2 max-pool followed by a DoubleConv."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        conv = DoubleConv(in_channels, out_channels)
        self.pool_conv = nn.Sequential(pool, conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled_and_convolved = self.pool_conv(x)
        return pooled_and_convolved


class Up(nn.Module):
    """
    Decoder stage: upsample, align with the skip tensor, concatenate, DoubleConv.

    ``in_channels`` is the channel count AFTER concatenation, i.e.
    upsampled channels + skip channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        upsampled = self.up(x)
        # How much smaller the upsampled map is than the skip connection.
        gap_h = skip.size(2) - upsampled.size(2)
        gap_w = skip.size(3) - upsampled.size(3)
        if max(gap_h, gap_w) > 0:
            left, top = gap_w // 2, gap_h // 2
            # F.pad takes (left, right, top, bottom) for the last two dims.
            upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # Skip features come first along the channel axis.
        merged = torch.cat([skip, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final projection: a 1x1 convolution mapping features to per-pixel class logits."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=num_classes,
            kernel_size=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.conv(x)
        return logits


class UNet(nn.Module):
    """
    UNet encoder-decoder with skip connections (bilinear upsample variant).

    Channel flow:
        Input (3) -> enc1(64) -> enc2(128) -> enc3(256) -> enc4(512) -> bottleneck(1024)
        Decoder (input channels = upsampled + skip, after concatenation):
            dec4: 1024+512=1536 -> 512
            dec3:  512+256= 768 -> 256
            dec2:  256+128= 384 -> 128
            dec1:  128+ 64= 192 ->  64
        out: 64 -> num_classes

    Args:
        in_channels: Number of channels of the input image.
        num_classes: Number of output logit channels.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = NUM_CLASSES):
        super().__init__()
        # Encoder path (each Down halves the resolution).
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = Down(64, 128)
        self.enc3 = Down(128, 256)
        self.enc4 = Down(256, 512)
        self.bottleneck = Down(512, 1024)
        # Decoder path; first argument is the post-concatenation width.
        self.dec4 = Up(1024 + 512, 512)
        self.dec3 = Up(512 + 256, 256)
        self.dec2 = Up(256 + 128, 128)
        self.dec1 = Up(128 + 64, 64)
        self.out = OutConv(64, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder: remember the input of every downsampling stage as a skip.
        skips = []
        feat = self.enc1(x)                                   # (B, 64, H, W)
        for stage in (self.enc2, self.enc3, self.enc4, self.bottleneck):
            skips.append(feat)
            feat = stage(feat)                                # ends at (B, 1024, H/16, W/16)
        # Decoder: consume the skips deepest-first.
        for stage in (self.dec4, self.dec3, self.dec2, self.dec1):
            feat = stage(feat, skips.pop())
        return self.out(feat)                                 # (B, num_classes, H, W)


# Parameter count and shape check
model = UNet(in_channels=3, num_classes=NUM_CLASSES).to(DEVICE)
total = sum(p.numel() for p in model.parameters())
print(f'UNet parameters: {total:,}  (~{total/1e6:.1f}M)')

x_dummy = torch.randn(2, 3, IMG_SIZE, IMG_SIZE, device=DEVICE)
# This forward pass only checks shapes. Without no_grad the global `logits`
# would keep the whole autograd graph (every intermediate activation) alive
# until it is rebound in a later cell.
with torch.no_grad():
    logits = model(x_dummy)
print(f'Input  : {tuple(x_dummy.shape)}')
print(f'Output : {tuple(logits.shape)}  (B, C, H, W)')
assert logits.shape == (2, NUM_CLASSES, IMG_SIZE, IMG_SIZE), 'Shape mismatch!'
print('Shape check passed.')

## 3. Loss function

We combine **cross-entropy loss** (pixel-wise classification) with **Dice loss** (overlap metric that handles class imbalance):

$$\mathcal{L} = \mathcal{L}_{CE} + \mathcal{L}_{\text{Dice}}$$

where

$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2 \sum p_i \, g_i + \varepsilon}{\sum p_i + \sum g_i + \varepsilon}$$

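and, for a given class, $p_i$ is the softmax probability of that class at pixel $i$ and $g_i \in \{0, 1\}$ is the corresponding one-hot ground-truth label. The Dice term is computed separately for every class of every image in the batch and then averaged.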

In [None]:
def dice_loss(logits: torch.Tensor, targets: torch.Tensor,
              num_classes: Optional[int] = None, eps: float = 1e-6) -> torch.Tensor:
    """
    Soft multiclass Dice loss.

    A Dice score is computed for every (sample, class) pair from the softmax
    probabilities and the one-hot targets; the loss is one minus the mean of
    those scores.

    Args:
        logits      : (B, C, H, W) raw logits
        targets     : (B, H, W)    long class indices in [0, C)
        num_classes : number of classes used for the one-hot encoding.
                      Defaults to C, the channel dimension of ``logits``, so
                      the loss always agrees with the model's output width.
        eps         : smoothing constant added to numerator and denominator;
                      avoids division by zero for classes absent from both
                      prediction and target.
    Returns:
        Scalar tensor: 1 - mean Dice over batch and classes.
    """
    if num_classes is None:
        num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)                                # (B, C, H, W)
    targets_oh = F.one_hot(targets, num_classes)                 # (B, H, W, C)
    targets_oh = targets_oh.permute(0, 3, 1, 2).float()          # (B, C, H, W)
    inter = (probs * targets_oh).sum(dim=(2, 3))                 # (B, C)
    # Sum of both masses (the Dice denominator), not a set union.
    denom = probs.sum(dim=(2, 3)) + targets_oh.sum(dim=(2, 3))   # (B, C)
    dice = (2.0 * inter + eps) / (denom + eps)                   # (B, C)
    return 1.0 - dice.mean()


def combined_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Return the unweighted sum of cross-entropy and soft Dice loss."""
    ce_term = F.cross_entropy(logits, targets)
    dice_term = dice_loss(logits, targets)
    return ce_term + dice_term


# Sanity check: evaluate the combined loss on random logits and random labels
dummy_logits = torch.randn(2, NUM_CLASSES, IMG_SIZE, IMG_SIZE)
dummy_targets = torch.randint(0, NUM_CLASSES, (2, IMG_SIZE, IMG_SIZE))
dummy_loss = combined_loss(dummy_logits, dummy_targets)
print(f'Combined loss (random init): {dummy_loss.item():.4f}')

## 4. Training loop

We run a short training demo (10 gradient steps) to verify the full forward + backward pass.
A real training run would use hundreds of epochs; this demo confirms shapes and that loss decreases.

In [None]:
# Demo hyper-parameters
BATCH_SIZE = 4
TRAIN_STEPS = 10   # Demo: 10 steps to verify forward + backward pass
LR = 1e-3

# Stream just enough samples to fill TRAIN_STEPS batches.
train_ds = FoodSeg103StreamDataset(split='train', max_samples=BATCH_SIZE * TRAIN_STEPS)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

# Fresh model and optimiser; mixed precision is enabled only on CUDA.
model = UNet(in_channels=3, num_classes=NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
amp_on = DEVICE.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=amp_on)

model.train()
history = []

for step, (imgs, masks) in enumerate(train_loader):
    imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=amp_on):
        logits = model(imgs)
        loss = combined_loss(logits, masks)

    # Scaled backward pass, optimiser step, then scale-factor update.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    loss_value = loss.item()
    history.append(loss_value)
    print(f'Step {step+1:2d}/{TRAIN_STEPS}  loss={loss_value:.4f}')

# Loss curve over the demo steps
steps_axis = range(1, len(history) + 1)
plt.figure(figsize=(6, 3))
plt.plot(steps_axis, history, marker='o')
plt.xlabel('Step')
plt.ylabel('Loss (CE + Dice)')
plt.title('Training loss — UNet demo (FoodSeg103)')
plt.grid(True)
plt.tight_layout()
plt.savefig('training_loss.png', dpi=120)
plt.show()
print('Saved: training_loss.png')

## 5. Visualisation

We run inference on a few validation samples and plot, side by side, the input image, the ground-truth mask and the predicted mask.

In [None]:
VAL_SAMPLES = 4

# Collect up to VAL_SAMPLES (image, mask) pairs from the validation stream.
val_ds = FoodSeg103StreamDataset(split='validation', max_samples=VAL_SAMPLES)
val_imgs, val_masks = [], []
for img, mask in val_ds:
    val_imgs.append(img)
    val_masks.append(mask)

val_batch = torch.stack(val_imgs).to(DEVICE)
val_gt = torch.stack(val_masks).to(DEVICE)

# Inference only: eval mode and no autograd bookkeeping.
model.eval()
with torch.no_grad():
    val_logits = model(val_batch)
val_preds = val_logits.argmax(dim=1).cpu().numpy()   # per-pixel predicted class
val_gt_np = val_gt.cpu().numpy()

pixel_acc = (val_preds == val_gt_np).mean()
# Report the number of samples actually collected and the configured number
# of training steps instead of hard-coded values, so the label stays correct
# when VAL_SAMPLES / TRAIN_STEPS change or the stream yields fewer samples.
print(f'Pixel accuracy ({len(val_imgs)} val samples, {TRAIN_STEPS}-step model): {pixel_acc:.3f}')


def unnorm(t: torch.Tensor) -> np.ndarray:
    """Undo the MEAN/STD normalisation and return a channels-last array clipped to [0, 1]."""
    restored = t.cpu() * STD + MEAN
    channels_last = restored.permute(1, 2, 0).numpy()
    return channels_last.clip(0, 1)


# `plt.cm.get_cmap` (i.e. `matplotlib.cm.get_cmap`) was removed in
# Matplotlib 3.9; `plt.get_cmap` takes the same (name, lut) arguments.
CMAP = plt.get_cmap('tab20', NUM_CLASSES)

n_rows = len(val_imgs)
# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(n_rows, 3, figsize=(10, n_rows * 3), squeeze=False)
# Shared colour scaling so a class id maps to the same colour in both mask panels.
mask_style = dict(cmap=CMAP, vmin=0, vmax=NUM_CLASSES - 1)

for row_axes, img_t, gt_mask, pred_mask in zip(axes, val_imgs, val_gt_np, val_preds):
    row_axes[0].imshow(unnorm(img_t))
    row_axes[0].set_title('Input')
    row_axes[1].imshow(gt_mask, **mask_style)
    row_axes[1].set_title('Ground truth')
    row_axes[2].imshow(pred_mask, **mask_style)
    row_axes[2].set_title('Prediction')
    for ax in row_axes:
        ax.axis('off')

plt.suptitle('UNet — input / ground truth / predicted mask (FoodSeg103)', fontsize=13)
plt.tight_layout()
plt.savefig('unet_predictions.png', dpi=120)
plt.show()
print('Saved: unet_predictions.png')