# U-Net Concepts (Self-teaching)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/USER/teaching-cnn/blob/main/book/self_teach/unet_concepts.ipynb)

We explore the U-Net architecture: how tensor shapes change through the encoder and decoder, and how skip connections join matching encoder and decoder levels.

In [None]:
# Minimal U-Net block shapes with torchsummary
import sys, subprocess
# Install any missing dependency on the fly (e.g. a fresh Colab runtime).
# Assumes each pip distribution name equals its import name, which holds for
# both entries below. Only ImportError triggers an install; any other error
# raised while importing an already-installed package is left to surface.
for pkg in ['torch', 'torchsummary']:
    try:
        __import__(pkg)
    except ImportError:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', pkg, '-q'])

import torch
import torch.nn as nn
from torchsummary import summary

class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    With padding=1 the spatial size (H, W) is preserved; only the channel
    count changes, from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        # Stored as `net` so parameter names stay net.0.* and net.2.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
        out = self.net(x)
        return out

class MiniUNet(nn.Module):
    """Two-level U-Net: encoder, bottleneck, and decoder with skip connections.

    Each MaxPool2d(2) halves H and W (rounding down); each
    ConvTranspose2d(kernel=2, stride=2) doubles them. When H or W is not
    divisible by 4, an upsampled tensor ends up one pixel smaller than its
    encoder skip, so it is zero-padded to the skip's size before concatenation.
    For divisible sizes no padding is applied.

    Args:
        in_channels: channels of the input image (default 1).
        out_channels: channels of the output map (default 1).
        base: channel width of the first encoder level; the second level and
            the bottleneck use 2*base and 4*base (default 16).
    """

    def __init__(self, in_channels=1, out_channels=1, base=16):
        super().__init__()
        c1, c2, c3 = base, 2 * base, 4 * base
        # Encoder: DoubleConv, then 2x2 max-pool.
        self.down1 = DoubleConv(in_channels, c1)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(c1, c2)
        self.pool2 = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(c2, c3)
        # Decoder: transposed conv upsamples 2x; the following DoubleConv takes
        # twice the channels because the skip tensor is concatenated on dim 1.
        self.up2 = nn.ConvTranspose2d(c3, c2, 2, 2)
        self.dec2 = DoubleConv(2 * c2, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, 2)
        self.dec1 = DoubleConv(2 * c1, c1)
        # 1x1 conv maps features to output channels; no activation is applied.
        self.out = nn.Conv2d(c1, out_channels, 1)

    @staticmethod
    def _match(up, skip):
        """Zero-pad ``up`` so its last two dims (H, W) equal those of ``skip``."""
        dh = skip.size(-2) - up.size(-2)
        dw = skip.size(-1) - up.size(-1)
        if dh or dw:
            # pad order for the last two dims is (left, right, top, bottom).
            up = nn.functional.pad(
                up, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2)
            )
        return up

    def forward(self, x):
        """Run encoder -> bottleneck -> decoder and return the output map."""
        e1 = self.down1(x)
        e2 = self.down2(self.pool1(e1))
        b = self.bottleneck(self.pool2(e2))
        d2 = self.dec2(torch.cat([self._match(self.up2(b), e2), e2], dim=1))
        d1 = self.dec1(torch.cat([self._match(self.up1(d2), e1), e1], dim=1))
        return self.out(d1)

# torchsummary's summary() defaults to device="cuda" and builds its dummy input
# on the GPU whenever CUDA is available, so put the model on the same device
# and pass that device explicitly (works on both CPU and GPU runtimes).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
m = MiniUNet().to(device)
summary(m, (1, 128, 128), device=device)