# Pair Programming: U-Net for Nuclei Segmentation (BBBC039)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thawn/ttt-workshop-cnn/blob/main/book/exercises/unet_bbbc039_pair_programming.ipynb)

We train a small U-Net on a tiny subset of the BBBC039 nuclei dataset so that training finishes in a few minutes.

In [None]:
# Deps: install any package that cannot be imported.
# The pip distribution name does not always match the importable module name
# (e.g. 'opencv-python' -> 'cv2', 'scikit-image' -> 'skimage'), so map them
# explicitly instead of trying to import the pip name directly.
import importlib, subprocess, sys
_PIP_TO_MODULE = {
    'torch': 'torch',
    'torchvision': 'torchvision',
    'opencv-python': 'cv2',
    'numpy': 'numpy',
    'matplotlib': 'matplotlib',
    'scikit-image': 'skimage',
    'tqdm': 'tqdm',
}
for _pip_name, _module_name in _PIP_TO_MODULE.items():
    try:
        __import__(_module_name)
    except ImportError:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', _pip_name, '-q'])
        # Packages installed while the interpreter is running may not be seen by
        # the import system until its finder caches are invalidated.
        importlib.invalidate_caches()

In [None]:
import os, zipfile, urllib.request
import numpy as np
import matplotlib.pyplot as plt
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from skimage.io import imread
from skimage.transform import resize
from tqdm import tqdm

# Fetch and unpack the dataset archive once; later runs reuse the local copies.
# NOTE(review): assumes the archive extracts to a top-level 'BBBC039_small'
# directory — confirm against the archive contents.
URL = 'https://raw.githubusercontent.com/aleju/cnn-datasets/master/BBBC039_small.zip'
zip_path = 'BBBC039_small.zip'
if not os.path.exists(zip_path):
    # Download under a temporary name and rename only on success, so an
    # interrupted transfer never leaves a truncated archive that the
    # exists() check above would accept on the next run.
    tmp_path = zip_path + '.part'
    with urllib.request.urlopen(URL, timeout=60) as r, open(tmp_path, 'wb') as f:
        f.write(r.read())
    os.replace(tmp_path, zip_path)
if not os.path.exists('BBBC039_small'):
    with zipfile.ZipFile(zip_path) as z:
        z.extractall('.')

class Nuclei(Dataset):
    """Paired nuclei images and binary foreground masks read from one directory.

    Every ``<stem>_img.png`` in ``root`` is paired with ``<stem>_mask.png`` in
    the same directory; pairs are ordered by image file name.

    Args:
        root: directory containing the image/mask PNG pairs.
        size: output height and width (pixels) of every returned item.
    """

    _IMG_SUFFIX = '_img.png'
    _MASK_SUFFIX = '_mask.png'

    def __init__(self, root='BBBC039_small', size=128):
        imgs = sorted(p for p in os.listdir(root) if p.endswith(self._IMG_SUFFIX))
        # Swap only the trailing suffix; str.replace('_img', '_mask') would also
        # rewrite any '_img' that happens to appear earlier in the file name.
        self.pairs = [
            (os.path.join(root, p),
             os.path.join(root, p[:-len(self._IMG_SUFFIX)] + self._MASK_SUFFIX))
            for p in imgs
        ]
        self.size = size

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        """Return ``(image, mask)``, each a float32 tensor of shape ``(1, size, size)``.

        The image is standardized per sample (zero mean, unit variance); the
        mask is 1.0 where the stored mask is non-zero and 0.0 elsewhere.
        """
        img_path, mask_path = self.pairs[i]
        shape = (self.size, self.size)
        x = imread(img_path, as_gray=True).astype(np.float32)
        y = imread(mask_path, as_gray=True).astype(np.float32)
        # Default (smooth, anti-aliased) resampling suits intensities. Cast back
        # to float32 because some scikit-image versions may return float64,
        # which would not match the float32 model weights.
        x = resize(x, shape, preserve_range=True).astype(np.float32)
        # Nearest-neighbour without anti-aliasing keeps the mask's original
        # values, so no blurred fringe appears around each nucleus.
        y = resize(y, shape, order=0, preserve_range=True, anti_aliasing=False)
        x = (x - x.mean()) / (x.std() + 1e-6)
        # Foreground = any non-zero value. This is correct whether masks are
        # stored as 0/1, 0/255, per-nucleus labels, or colours converted to gray
        # (where dark labels can fall below a fixed 0.5 threshold).
        mask = (y > 0).astype(np.float32)
        return torch.from_numpy(x)[None], torch.from_numpy(mask)[None]

class DoubleConv(nn.Module):
    """Apply (3x3 conv -> ReLU) twice; padding=1 keeps height and width unchanged.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First conv maps in_ch -> out_ch, second maps out_ch -> out_ch.
        for src_ch in (in_ch, out_ch):
            layers += [nn.Conv2d(src_ch, out_ch, 3, padding=1), nn.ReLU()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class UNet(nn.Module):
    """Two-level U-Net that outputs per-pixel logits at the input resolution.

    Encoder: DoubleConv + 2x2 max-pool, twice, then a DoubleConv bottleneck.
    Decoder: 2x2 stride-2 transposed conv, concatenation with the matching
    encoder feature map (skip connection), then DoubleConv.

    Args:
        in_ch: number of input channels (defaults to 1, a grayscale image).
        out_ch: number of output logit channels (defaults to 1, a binary mask).
        base: channel width of the first level; doubled at each level down.

    The defaults reproduce the original 1 -> 16/32/64 -> 1 architecture with
    the same submodule names, so existing state_dicts still load.
    """

    def __init__(self, in_ch=1, out_ch=1, base=16):
        super().__init__()
        w1, w2, w3 = base, 2 * base, 4 * base
        self.d1 = DoubleConv(in_ch, w1); self.p1 = nn.MaxPool2d(2)
        self.d2 = DoubleConv(w1, w2); self.p2 = nn.MaxPool2d(2)
        self.b = DoubleConv(w2, w3)
        # Decoder convs see upsampled features concatenated with the skip
        # connection, hence twice the level's width on input.
        self.u2 = nn.ConvTranspose2d(w3, w2, 2, 2); self.c2 = DoubleConv(2 * w2, w2)
        self.u1 = nn.ConvTranspose2d(w2, w1, 2, 2); self.c1 = DoubleConv(2 * w1, w1)
        self.out = nn.Conv2d(w1, out_ch, 1)

    def forward(self, x):
        """Map an input of shape (N, in_ch, H, W) to logits of shape (N, out_ch, H, W).

        Raises:
            ValueError: if H or W is not divisible by 4. The two pooling steps
                would otherwise produce maps that do not align with the skip
                connections.
        """
        h, w = x.shape[-2:]
        if h % 4 or w % 4:
            raise ValueError(f'UNet input height and width must be divisible by 4, got {h}x{w}')
        e1 = self.d1(x)
        e2 = self.d2(self.p1(e1))
        b = self.b(self.p2(e2))
        d2 = self.c2(torch.cat([self.u2(b), e2], 1))
        d1 = self.c1(torch.cat([self.u1(d2), e1], 1))
        return self.out(d1)

# Hold out a fixed, reproducible fraction of the images for validation, so the
# validation panel below shows images the model was NOT trained on.
# (Previously train_ds and val_ds were two copies of the same full dataset.)
full_ds = Nuclei()
n_total = len(full_ds)
if n_total < 2:
    raise RuntimeError(f'Need at least 2 image/mask pairs for a train/val split, found {n_total}')
n_val = max(1, n_total // 5)
train_ds, val_ds = torch.utils.data.random_split(
    full_ds, [n_total - n_val, n_val], generator=torch.Generator().manual_seed(0))
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=4)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = UNet().to(device)
# The model emits raw logits, so use the numerically stable sigmoid+BCE loss.
criterion = nn.BCEWithLogitsLoss()
opt = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(3):
    model.train()
    pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1}')
    for X, y in pbar:
        X = X.to(device)
        y = y.to(device)
        opt.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        opt.step()
        # Show the latest batch loss next to the progress bar.
        pbar.set_postfix(loss=loss.item())

# Predict on one validation batch and show input, ground truth and prediction
# for its first sample side by side.
model.eval()
with torch.no_grad():
    X, y = next(iter(val_loader))
    X = X.to(device)
    probs = torch.sigmoid(model(X))
    pred = probs.cpu().numpy()

panels = [
    (X[0, 0].cpu(), 'Input'),
    (y[0, 0], 'GT'),
    (pred[0, 0] > 0.5, 'Pred'),  # binarize probabilities at 0.5
]
fig, ax = plt.subplots(1, 3, figsize=(7, 3))
for a, (img, title) in zip(ax, panels):
    a.imshow(img, cmap='gray')
    a.set_title(title)
    a.axis('off')
plt.show()