# 02 — U-Net Segmentation Starter (fastai, ResNet-34)
**Goal:** Train a U-Net with a ResNet-34 encoder on paired `images/` + `masks/`.

### Expected dataset layout
```
/kaggle/input/your-seg-dataset/
  images/
    img_001.jpg
    img_002.jpg
    ...
  masks/
    img_001.png   # same stem as the image (plus optional suffix); masks are expected to be .png
    img_002.png
    ...
```
Update the config cell with your Kaggle input path and (optionally) a mask suffix such as `_mask`.

In [None]:
# Imports
from pathlib import Path
import random, json, numpy as np, yaml
from fastai.vision.all import *

# Reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random generators with one value.

    Args:
        seed: Integer passed unchanged to every generator (default 42).

    The CUDA generators are seeded too, but only when
    `torch.cuda.is_available()` reports a usable device.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

## Config

In [None]:
# Loads ../configs/config.yaml (relative to the current working directory) for shared defaults.
# You can override segmentation-specific options below.
cfg_path = Path('../configs/config.yaml').resolve()
with open(cfg_path, encoding='utf-8') as f:
    # An empty YAML file parses to None; fall back to {} so the .get() lookups below still work.
    CFG = yaml.safe_load(f) or {}

# --- Segmentation-specific fields (override here if needed) ---
# Every CFG.get(...) supplies a default, so keys missing from the YAML file are fine.
SEG = {
    'dataset_root': CFG.get('dataset_root', '/kaggle/input/YOUR_SEG_DATASET'),
    'images_dir': 'images',
    'masks_dir': 'masks',
    'mask_suffix': '',           # e.g., '_mask' if your masks are named foo_mask.png
    'image_size': CFG.get('image_size', 224),
    'valid_pct': CFG.get('valid_pct', 0.2),
    'bs': CFG.get('bs', 8),
    'epochs': max(5, CFG.get('epochs', 5)),  # floored at 5: a smaller value in the config is raised to 5
    'random_seed': CFG.get('random_seed', 42)
}

SEG

## Data & Dataloaders

In [None]:
# Seed the random generators with the configured seed before building the data pipeline.
set_seed(SEG['random_seed'])

# Resolve the image and mask folders underneath the dataset root.
root = Path(SEG['dataset_root'])
images = root/SEG['images_dir']
masks  = root/SEG['masks_dir']
# Raise explicitly instead of using `assert`, which is skipped when Python runs with -O.
if not (images.exists() and masks.exists()):
    raise FileNotFoundError(f"Couldn't find images/masks at: {images} / {masks}")

# Utility: map image path -> mask path
def get_mask_path(img_path:Path):
    """Build the mask path that belongs to `img_path`.

    The mask file name is the image stem followed by SEG['mask_suffix'] and a
    '.png' extension (PNG masks are expected by default), located inside the
    `masks` directory. The path is only constructed; its existence is not checked.
    """
    mask_stem = img_path.stem + SEG['mask_suffix']
    return masks.joinpath(mask_stem + '.png')

def label_func(p):
    """Labelling function handed to the DataBlock: return the mask path for image `p`."""
    mask_path = get_mask_path(p)
    return mask_path

# Collect codes automatically (unique values from ONE mask) or define manually here.
# NOTE(review): a single mask may not contain every class; set `unique_vals` by hand if so.
sample_mask = first(get_image_files(masks))
if sample_mask is None:
    raise FileNotFoundError("No mask files found.")
sample = PILMask.create(sample_mask)
unique_vals = np.unique(np.array(sample))
print("Unique values in sample mask:", unique_vals)

# IntToFloatTensor integer-divides mask values by `div_mask`.
# - Binary masks stored as {0,255} need div_mask=255 so they become class indices {0,1}.
# - Masks that already hold class indices (e.g. {0,1,2}) must use div_mask=1; dividing
#   them by 255 would collapse every class to 0.
div_mask = 255 if set(unique_vals.tolist()) <= {0, 255} else 1
print("Using div_mask:", div_mask)

# Build DataBlock
# NOTE(review): images are resized to 2x image_size and no later transform shrinks them,
# so training runs at that larger size.
item_tfms = [Resize(SEG['image_size']*2)]
batch_tfms = [Normalize.from_stats(*imagenet_stats), IntToFloatTensor(div_mask=div_mask)]

seg_block = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes=np.array(unique_vals))),
    get_items=get_image_files,
    get_y=label_func,
    splitter=RandomSplitter(valid_pct=SEG['valid_pct'], seed=SEG['random_seed']),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms
)

dls = seg_block.dataloaders(images, bs=SEG['bs'])
dls.show_batch(max_n=6)

## Model: U-Net with ResNet-34 encoder

In [None]:
# U-Net learner with an explicit resnet34 backbone, scored with DiceMulti
dice_metric = DiceMulti()
learn = unet_learner(dls, arch=resnet34, metrics=dice_metric)

# Training schedule: a single fine_tune call for the configured number of epochs
learn.fine_tune(epochs=SEG['epochs'])

## Evaluation: Mean Dice and visual examples

In [None]:
# Folder (relative to the working directory) that receives the metrics file
reports = Path('../reports').resolve()
reports.mkdir(parents=True, exist_ok=True)

# Validate to get mean Dice; the result is unpacked as (validation loss, metric value)
val_loss, mean_dice = learn.validate()
metrics = {'mean_dice': float(mean_dice)}
metrics_path = reports/'seg_metrics.json'
with metrics_path.open('w') as f:
    f.write(json.dumps(metrics, indent=2))
print("Saved:", metrics_path)

# Show & save a grid of predictions
import matplotlib.pyplot as plt
ex_path = reports.joinpath('seg_examples.png')
learn.show_results(max_n=8, figsize=(8,8))
# savefig writes matplotlib's current figure, expected to be the grid drawn just above
plt.savefig(ex_path, bbox_inches='tight')
print("Saved:", ex_path)

## Export model

In [None]:
# Export the trained learner into the reports folder and report where it went
export_path = reports/'seg_unet_resnet34.pkl'
learn.export(export_path)
print("Saved:", export_path)