# LGG Brain MRI Segmentation — Transfer Learning + Self-Attention (ResNet-U-Net)

This notebook trains a segmentation model on the LGG FLAIR abnormality dataset:
- Patient-level split to prevent leakage
- Transfer learning (ImageNet pretrained ResNet34 encoder)
- Bottleneck self-attention (multi-head attention on deep features)
- Dice + BCE loss, Dice/IoU metrics
- Paper-style plots (global matplotlib styling)

Outputs saved to `/kaggle/working/`.


In [None]:
# If Kaggle already has these, this cell is harmless.
!pip -q install albumentations==1.3.1 opencv-python-headless==4.9.0.80


In [None]:
import os, re, glob, random, math, time
from pathlib import Path

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models import resnet34, ResNet34_Weights

import albumentations as A
from albumentations.pytorch import ToTensorV2


In [None]:
# -------------------------------
# Global Matplotlib styling for paper-style figures
# -------------------------------
# Each call sets the rcParams "<group>.<name>" entries for one settings group.

# Text: serif family with a preference list of serif fonts.
plt.rc("font", family="serif", serif=["Times New Roman", "Times", "DejaVu Serif"], size=12)

# Axes labels, titles and frame width.
plt.rc("axes", labelsize=13, titlesize=13, linewidth=1.2)

# Tick label size plus major/minor tick lengths, identical on both axes.
for tick_axis in ("xtick", "ytick"):
    plt.rc(tick_axis, labelsize=11)
    plt.rc(f"{tick_axis}.major", size=6)
    plt.rc(f"{tick_axis}.minor", size=3)

# Framed legend with a grey edge.
plt.rc("legend", fontsize=11, frameon=True, edgecolor="0.4")

# Dotted, slightly transparent grid lines.
plt.rc("grid", linestyle=":", linewidth=0.7, alpha=0.85)

def paper_axes(ax):
    """Apply the shared paper-style look to a Matplotlib axes, in place.

    Enables minor ticks, draws dotted major and minor grids, sets every spine
    to a line width of 1.2, and draws inward-pointing ticks that also appear
    on the top and right edges.

    Args:
        ax: The axes object to style. It is modified in place.

    Returns:
        None.
    """
    ax.minorticks_on()

    # Dotted grid at both tick levels; the minor grid is thinner and lighter.
    grid_styles = (
        {"which": "major", "linestyle": ":", "linewidth": 0.8},
        {"which": "minor", "linestyle": ":", "linewidth": 0.5, "alpha": 0.7},
    )
    for grid_kwargs in grid_styles:
        ax.grid(True, **grid_kwargs)

    # Uniform frame thickness on all spines.
    for border in ax.spines.values():
        border.set_linewidth(1.2)

    ax.tick_params(which="both", direction="in", top=True, right=True)

# -------------------------------
# Reproducibility
# -------------------------------
SEED = 7


def seed_everything(seed, deterministic=False):
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Integer seed applied to `random`, `numpy.random`, and PyTorch
            (CPU and all CUDA devices).
        deterministic: When False (default), cuDNN autotuning is enabled for
            speed. In that mode a fixed seed does NOT guarantee bit-identical
            GPU results between runs. When True, cuDNN is restricted to
            deterministic algorithms and autotuning is disabled, trading
            speed for repeatability.

    Returns:
        None.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = deterministic
    # Benchmark mode lets cuDNN pick kernels by timing them, so it is only
    # enabled when exact repeatability was not requested.
    torch.backends.cudnn.benchmark = not deterministic


# Default keeps the fast, non-deterministic cuDNN configuration.
seed_everything(SEED)


In [None]:
# Experiment hyperparameters gathered in one place.
CFG = dict(
    # Input resolution and data loading
    img_size=256,
    batch_size=16,
    num_workers=2,
    # Optimisation schedule
    epochs=12,
    lr=3e-4,
    weight_decay=1e-4,
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU
    device="cuda" if torch.cuda.is_available() else "cpu",
    # Bottleneck attention: toggle for ablations; 512 channels / 8 heads divides cleanly
    use_attention=True,
    attn_heads=8,
    # Threshold value
    threshold=0.5,
)

CFG


In [None]:
## Load dataset index (with safe fallback)

Preferred:
- Load `lgg_master_slices.csv` (looked up under `/kaggle/working/eda_outputs/`), created by the EDA notebook.

Fallback:
- Rebuild the dataframe directly from `/kaggle/input/**/kaggle_3m`.


In [None]:
KAGGLE_INPUT = Path("/kaggle/input")

# Locate every directory named 'kaggle_3m' below the input mount. Glob order is
# not guaranteed and more than one match is possible, so keep directories only
# and order them deterministically: shallowest path first, then alphabetically.
candidates = sorted(
    (p for p in KAGGLE_INPUT.glob("**/kaggle_3m") if p.is_dir()),
    key=lambda p: (len(p.parts), str(p)),
)
print("Found candidates:", [str(p) for p in candidates[:10]])

if not candidates:
    raise FileNotFoundError("Could not find 'kaggle_3m' under /kaggle/input. Attach the dataset to this notebook.")

# The first entry after sorting is used as the dataset root.
DATA_ROOT = candidates[0]
DATA_ROOT


In [None]:
def to_key(path):
    """Return the key shared by an image file and its mask file.

    The key is the file name with a trailing ``_mask.tif`` or ``.tif``
    removed. Names with neither suffix are returned unchanged.

    Args:
        path: File path (str or path-like); only the final component is used.

    Returns:
        str: The file name without its trailing suffix.
    """
    base = Path(path).name
    # Strip only a trailing suffix: str.replace would also delete matches in
    # the middle of the name. '_mask.tif' is checked first because such names
    # also end with '.tif'.
    for suffix in ("_mask.tif", ".tif"):
        if base.endswith(suffix):
            return base[: -len(suffix)]
    return base

def patient_id_from_path(path):
    """Return the name of the directory that directly contains `path`.

    Args:
        path: File path (str or path-like).

    Returns:
        str: Name of the parent directory, used as the patient identifier.
    """
    containing_dir = Path(path).parent
    return containing_dir.name

def slice_index_from_key(k):
    """Extract the trailing ``_<digits>`` number from a slice key.

    Args:
        k: Slice key string.

    Returns:
        int: The number after the final underscore when the key ends in
        ``_<digits>``; otherwise ``np.nan``.
    """
    trailing_number = re.search(r"_(\d+)$", k)
    if trailing_number is None:
        return np.nan
    return int(trailing_number.group(1))

EDA_CSV = Path("/kaggle/working/eda_outputs/lgg_master_slices.csv")

if EDA_CSV.exists():
    # Preferred path: reuse the slice index that was exported to CSV.
    df = pd.read_csv(EDA_CSV)
    print("Loaded:", EDA_CSV, "| rows:", len(df), "| patients:", df["patient_id"].nunique())
else:
    # Fallback: scan <DATA_ROOT>/<patient dir>/*.tif and pair images with masks.
    # Sorting once is enough; the filtered lists below keep that order.
    all_tifs = sorted(glob.glob(str(DATA_ROOT / "*" / "*.tif")))
    mask_tifs = [p for p in all_tifs if p.endswith("_mask.tif")]
    img_tifs  = [p for p in all_tifs if not p.endswith("_mask.tif")]

    # Key each file by its name without the suffix, then keep only keys that
    # have both an image and a mask.
    img_map = {to_key(p): p for p in img_tifs}
    msk_map = {to_key(p): p for p in mask_tifs}
    keys = sorted(img_map.keys() & msk_map.keys())

    df = pd.DataFrame({
        "key": keys,
        "image_path": [img_map[k] for k in keys],
        "mask_path":  [msk_map[k] for k in keys],
    })
    df["patient_id"] = df["image_path"].apply(patient_id_from_path)
    df["slice_idx"] = df["key"].apply(slice_index_from_key)

    print("Built df | rows:", len(df), "| patients:", df["patient_id"].nunique())

# Fail fast with a clear message instead of carrying an empty index forward.
if df.empty:
    raise ValueError(
        "Dataset index is empty: no rows in the EDA CSV and/or no image/mask "
        f"pairs found under {DATA_ROOT}."
    )

df.head()


In [None]:
# Create an 80/10/10 train/val/test split at patient level, unless the index
# already carries a complete `split` column.
if "split" not in df.columns or df["split"].isna().any():
    # Sort before shuffling so the assignment depends only on SEED and the set
    # of patients, not on the row order of df.
    patients = df["patient_id"].drop_duplicates().sort_values().tolist()
    rng = np.random.default_rng(SEED)
    rng.shuffle(patients)

    n = len(patients)
    n_train, n_train_val = int(0.8 * n), int(0.9 * n)
    train_pat = set(patients[:n_train])
    val_pat   = set(patients[n_train:n_train_val])
    test_pat  = set(patients[n_train_val:])

    def assign_split(pid):
        """Map a patient id to its split name; anything not in train/val is 'test'."""
        if pid in train_pat: return "train"
        if pid in val_pat: return "val"
        return "test"

    df["split"] = df["patient_id"].apply(assign_split)

# Guard against leakage: every patient must belong to exactly one split. This
# also validates a `split` column that was loaded rather than created here.
splits_per_patient = df.groupby("patient_id")["split"].nunique()
leaking_patients = splits_per_patient[splits_per_patient > 1]
if len(leaking_patients) > 0:
    raise ValueError(
        f"Patient-level leakage: {len(leaking_patients)} patient(s) appear in more than "
        f"one split, e.g. {leaking_patients.index[:5].tolist()}"
    )

# Slice counts per split, and number of distinct patients per split.
df["split"].value_counts(), df.groupby("split")["patient_id"].nunique()
