<a href="https://colab.research.google.com/github/rubykumari1/project_964/blob/main/Brat_2020_AI_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#Access to datasets uploaded to google Drive

from google.colab import drive
drive.mount('/content/drive', force_remount=True)
import pathlib


Mounted at /content/drive


In [2]:
#Environment setup - nibabel for reading NIFTI and tqdm for progress
!pip install -q nibabel tqdm

In [3]:
#Importing core libraries
import pathlib, torch, numpy as np, albumentations as A
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset, DataLoader
import time, json, pathlib, shutil, numpy as np, nibabel as nib
from tqdm import tqdm
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# torch.cuda.get_device_name(0) raises on a CPU-only runtime, so only query it when CUDA is usable.
if device.type == 'cuda':
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("GPU: none detected - running on CPU")

GPU: NVIDIA A100-SXM4-40GB


In [4]:
#Project path constants

# Everything belonging to the segmentation project lives under this Drive folder.
PROJ_ROOT = pathlib.Path("/content/drive/MyDrive/964_project/segmentation")

# Raw BraTS 2020 training volumes.
DATA_ROOT = PROJ_ROOT / "BraTS2020_TrainingData" / "MICCAI_BraTS2020_TrainingData"

# Destination of the 2-D slices extracted from the four modalities (and the masks).
SLICE_DIR = PROJ_ROOT / "4modslices"



KEEP_ONLY_TUMOR = True tells the script to discard every axial slice whose mask is entirely background, which cuts down disk usage and training time while preventing the model from being flooded with “all-zero” examples that encourage a trivial background-only prediction.
TUMOR_PIX_THRESH = 0.01 keeps a slice only if more than 1 % of its pixels are labelled as tumour; this keeps peripheral slices that still contain useful tumour context, but drops slices where a few stray mislabelled pixels would only add noise.
RESIZE_TO = None is a placeholder that lets you optionally downsample slices to a fixed resolution (for example 128×128) later on, so you can trade a little spatial detail for lower GPU memory use and more consistent input shapes without hard-coding it upfront.

In [None]:
KEEP_ONLY_TUMOR = True # when True, slice_indices() applies the tumour-pixel threshold below, so blank (all-background) slices are skipped
TUMOR_PIX_THRESH = 0.01 # keep a slice only if more than 1 % of its pixels are labelled as tumour
RESIZE_TO = None # placeholder for an optional fixed slice resolution (for example 128×128); not used by the code shown here

The loop creates five sub-directories inside SLICE_DIR—one each for the four MRI modalities (t1, t1ce, t2, flair) and one for the segmentation mask, making sure they exist before any files are written.
Next, the code defines a path called processed_flag that points to processed.json, a bookkeeping file stored in the same directory. If that JSON file is already present, it is read and parsed into a Python set named done_patients; each entry records a patient that has been sliced and saved earlier. Finally, the script prints how many patient folders will be skipped this run, so it won’t redo work that’s already finished.

In [None]:
# Make sure SLICE_DIR holds one folder per MRI modality (t1, t1ce, t2, flair)
# plus one for the segmentation mask before anything is written.
for sub in ("t1", "t1ce", "t2", "flair", "mask"):
    sub_dir = SLICE_DIR / sub
    sub_dir.mkdir(parents=True, exist_ok=True)

# Bookkeeping file listing the patients handled by earlier runs.
processed_flag = SLICE_DIR / "processed.json"

# Start from the recorded patients when the file exists, otherwise from an empty set.
if processed_flag.exists():
    done_patients = set(json.loads(processed_flag.read_text()))
else:
    done_patients = set()

# Each entry is a patient that was sliced and saved earlier and will be skipped.
print("Will skip", len(done_patients), "previously-processed patients")

In [None]:
# Sanity check on the raw NIfTI folders: there should be 369 of them.
patient_folders = sorted(DATA_ROOT.glob("BraTS20_Training_*"))
print("Total patients found :", len(patient_folders))

# For the first three patients: collect the .nii files of the four MRI
# modalities (t1, t1ce, t2, flair) and of the segmentation label, load each
# volume with nibabel, convert it to a NumPy array and print its shape.
for folder in patient_folders[:3]:

    vols = {}
    for m in ["t1", "t1ce", "t2", "flair", "seg"]:
        vols[m] = list(folder.glob(f"*_{m}.nii*"))

    shapes = []
    for m, paths in vols.items():
        first_match = paths[0]
        arr = nib.load(str(first_match)).get_fdata()
        shapes.append(arr.shape)
        print(m, arr.shape)
    # only the first recorded shape (t1) is echoed next to the patient name
    print(folder.name, "shapes ", shapes[0])

In [None]:
def zscore(volume):
    """Standardise *volume* to zero mean and unit standard deviation.

    Mean and standard deviation are computed over every element of the
    array.  A small constant (1e-8) is added to the standard deviation so
    that a constant-valued volume does not divide by zero.
    """
    centred = volume - volume.mean()
    spread = volume.std() + 1e-8
    return centred / spread

def save_slice(arr, out_path):
    """Write *arr* to disk as a float32 ``.npy`` file.

    *out_path* is a ``pathlib.Path``; its suffix is replaced by ``.npy``
    before the array is saved.
    """
    destination = out_path.with_suffix(".npy")
    as_float32 = arr.astype(np.float32)
    np.save(destination, as_float32)

def slice_indices(mask_3d):
    """Return the axial (third-axis) indices of *mask_3d* that should be kept.

    When ``KEEP_ONLY_TUMOR`` is falsy, every slice index is returned as a
    ``range``.  Otherwise a NumPy index array is returned holding the slices
    whose number of non-zero mask pixels is strictly greater than
    ``TUMOR_PIX_THRESH`` times the pixel count of a single slice.
    """
    if KEEP_ONLY_TUMOR:
        # number of tumour (non-zero) pixels in each axial slice
        tumour_px = (mask_3d > 0).sum(axis=(0, 1))
        # minimum pixel count a slice has to exceed to be kept
        min_px = TUMOR_PIX_THRESH * mask_3d[:, :, 0].size
        return np.where(tumour_px > min_px)[0]

    return range(mask_3d.shape[2])

In [None]:

# check one patient/one slice
"""
It loads all four MRI modalities and the segmentation mask for a chosen patient,
converts them to NumPy arrays, and counts how many tumour voxels appear in each axial slice.
If no slice contains tumour it simply picks the middle slice!
otherwise it selects the slice with the highest tumour-pixel count.
It then renders a six-panel figure: the four raw modalities, the mask alone, and an overlay of the mask on T1-CE,
so to confirm that the modalities align spatially and that the mask labels match the anatomy.
Finally it prints which slice index was displayed.

"""

import matplotlib.pyplot as plt

def show_sample(patient_folder):
    """Show the most tumour-rich axial slice of one patient as a six-panel figure.

    Loads the t1, t1ce, t2, flair and seg NIfTI volumes found in
    *patient_folder*, picks the axial slice with the largest number of
    non-zero mask voxels (the middle slice when the mask is empty), and plots
    the four modalities, the mask, and the mask overlaid on T1-CE.  Finally
    prints which slice index was displayed.

    Raises:
        FileNotFoundError: if one of the five volumes is missing from the folder.
    """
    vols = {}
    for m in ["t1", "t1ce", "t2", "flair", "seg"]:
        matches = list(patient_folder.glob(f"*_{m}.nii*"))
        if not matches:
            # a bare IndexError here would not say which file is missing
            raise FileNotFoundError(
                f"{patient_folder.name}: no file matching '*_{m}.nii*'")
        vols[m] = nib.load(str(matches[0])).get_fdata().astype(np.float32)

    # tumour voxel count per axial slice decides which slice is shown
    nz_per_slice = (vols["seg"] > 0).sum(axis=(0, 1))
    if nz_per_slice.max() == 0:
        z = vols["seg"].shape[2] // 2
    else:
        z = int(np.argmax(nz_per_slice))

    titles = ["T1", "T1-CE", "T2", "FLAIR", "Mask", "T1-CE + Mask"]
    imgs   = [vols["t1"][:, :, z], vols["t1ce"][:, :, z], vols["t2"][:, :, z],
              vols["flair"][:, :, z], vols["seg"][:, :, z]]

    plt.figure(figsize=(15, 3))
    for i, img in enumerate(imgs):
        plt.subplot(1, 6, i + 1)  # the original line was missing this closing parenthesis
        plt.imshow(img.T, cmap='gray', origin='lower')
        plt.axis('off')
        plt.title(titles[i])

    # sixth panel: mask drawn semi-transparently on top of T1-CE
    plt.subplot(1, 6, 6)
    plt.imshow(vols["t1ce"][:, :, z].T, cmap='gray', origin='lower')
    plt.imshow(vols["seg"][:, :, z].T, alpha=0.35, cmap='Reds', origin='lower')
    plt.axis('off')
    plt.title(titles[5])
    plt.tight_layout()
    plt.show()
    print(f"Displayed slice z={z} from {patient_folder.name}")

#check it on the first patient
show_sample(patient_folders[0])


In [None]:

#load → normalise → slice → save

# The loop walks through every BraTS patient folder, skips any already listed in processed.json,
# and only processes cases that contain all five required NIfTI files t1, t1ce, t2, flair, seg.
# For each complete case it z-scores the four modalities, keeps the axial slices selected by
# slice_indices() (tumour-pixel threshold), saves those slices (and their mask) as .npy files in
# modality-specific sub-folders, and logs the patient as finished so the run can resume safely
# if interrupted.
#
# Note kept from the original cell (file counts per modality, presumably for the incomplete case):
#   t1 1 | t1ce 1 | t2 1 | flair 1 | seg 0

start_all = time.time()
for folder in tqdm(patient_folders, desc="Patients"):
    pid = folder.name
    if pid in done_patients:
        continue

    # Locate the five NIfTI files and remember which ones are absent
    # (the original comment mentions BraTS case 355 here).
    paths = {}
    missing = []
    for m in ["t1", "t1ce", "t2", "flair", "seg"]:
        files = list(folder.glob(f"*_{m}.nii*"))
        if files:
            paths[m] = files[0]
        else:
            missing.append(m)
    if missing:
        print(f"[WARN] {pid}: missing {missing} → skipped")
        # incomplete cases are recorded as done so later runs do not retry them
        done_patients.add(pid)
        processed_flag.write_text(json.dumps(sorted(done_patients)))
        continue

    t0 = time.time()
    # `paths` is complete at this point, so the files are not globbed a second time
    vols = {m: nib.load(str(p)).get_fdata().astype(np.float32)
            for m, p in paths.items()}

    # z-score the four imaging modalities (leave mask unchanged)
    for m in ["t1", "t1ce", "t2", "flair"]:
        vols[m] = zscore(vols[m])

    # select slice indices
    keep_z = slice_indices(vols["seg"])
    if len(keep_z) == 0:                # should never happen
        # NOTE: such a patient is not recorded in processed.json and is re-read on every run
        print("No tumor slices in", pid)
        continue

    # The data become 2-D here: each 3-D volume is indexed [:, :, z] to grab a
    # single axial slice before it is saved.
    for z in keep_z:
        tag = f"{pid}_z{z:03d}"
        save_slice(vols["t1"]   [:, :, z], SLICE_DIR / "t1"   / tag)
        save_slice(vols["t1ce"] [:, :, z], SLICE_DIR / "t1ce" / tag)
        save_slice(vols["t2"]   [:, :, z], SLICE_DIR / "t2"   / tag)
        save_slice(vols["flair"][:, :, z], SLICE_DIR / "flair" / tag)
        # the mask keeps its integer labels, so it is stored as uint8 instead of float32
        np.save((SLICE_DIR / "mask" / tag).with_suffix(".npy"),
                vols["seg"][:, :, z].astype(np.uint8))

    # record the patient only after all of its slices were written
    done_patients.add(pid)
    processed_flag.write_text(json.dumps(sorted(done_patients)))

    print(f"{pid}: kept {len(keep_z)} slices | "
          f"time {time.time()-t0:.1f} s")

print("=== Finished all patients in",
      time.time()-start_all, "seconds ===")


In [None]:
#cached count files / modality

# how many .npy slice files it contains and prints the counts

from collections import Counter
import glob, os

# Count the cached .npy files in every modality folder and in the mask folder.
mod_counts = {}
for mod in ["t1", "t1ce", "t2", "flair", "mask"]:
    pattern = str(SLICE_DIR / mod / "*.npy")
    mod_counts[mod] = len(glob.glob(pattern))

print("Cached slice counts:")
for k, v in mod_counts.items():
    print(f"{k:>5}: {v:6,d}")

# The mask count is the reference: each image modality must have exactly as many files.
total_pairs = mod_counts["mask"]
image_counts = [v for k, v in mod_counts.items() if k != "mask"]

assert all(v == total_pairs for v in image_counts), \
    "Mismatch between  the image and mask counts!"

print(f"\n All modalities aligned — total tumour-containing slices: {total_pairs:,}")

"""
Cached slice counts:
   t1: 15,273
 t1ce: 15,273
   t2: 15,273
flair: 15,273
 mask: 15,272
"""

In [None]:
# a randome slice plot across all 5 arrfays

import matplotlib.pyplot as plt, random, numpy as np, pathlib

def show_cached_sample():
    """Plot one randomly chosen cached slice: four modalities, mask and overlay."""
    # choose a random mask file; its stem identifies the slice in every folder
    mask_files = list((SLICE_DIR / "mask").glob("*.npy"))
    tag_path = random.choice(mask_files)

    stem = tag_path.stem
    if stem.endswith("_mask"):
        tag = stem.replace("_mask", "")
    else:
        tag = stem
    pid, z = tag.split("_z")

    channel_names = ["t1", "t1ce", "t2", "flair", "mask"]
    imgs = {mod: np.load(SLICE_DIR / mod / f"{tag}.npy") for mod in channel_names}

    titles = ["T1", "T1-CE", "T2", "FLAIR", "Mask", "Overlay"]

    # first five panels: one per array; the mask uses its own colour map
    plt.figure(figsize=(15, 3))
    for panel, (mod, title) in enumerate(zip(channel_names, titles), start=1):
        plt.subplot(1, 6, panel)
        colour_map = 'viridis' if mod == "mask" else 'gray'
        plt.imshow(imgs[mod].T, cmap=colour_map, origin='lower')
        plt.axis('off')
        plt.title(title)

    # sixth panel: mask drawn semi-transparently over T1-CE
    plt.subplot(1, 6, 6)
    plt.imshow(imgs["t1ce"].T, cmap='gray', origin='lower')
    plt.imshow(imgs["mask"].T, alpha=0.35, cmap='Reds', origin='lower')
    plt.axis('off')
    plt.title(titles[5])
    plt.suptitle(f"{pid}  |  axial slice z={int(z):03d}")
    plt.tight_layout()
    plt.show()

show_cached_sample()

In [None]:
#Build slice list from intersection of all modalities
import pathlib, random

"""
gathers the filenames in each modality folder..
keeps only the slice IDs that appear in all five modalities
groups those IDs by patient, shuffles the patients
 and then splits them patient-wise into an 80 % training set and a 20 % validation set
 so that no slices from the same patient leak across splits.

"""

DATA_DIR = pathlib.Path("/content/drive/MyDrive/964_project/segmentation/4modslices")

mods = ["t1", "t1ce", "t2", "flair", "mask"]

# Slice IDs (file stems of the .npy files) found in each folder.
id_sets = {m: {f.stem for f in sorted((DATA_DIR / m).glob("*.npy"))} for m in mods}

# Keep only the IDs that are present in every modality folder and in the mask folder.
common_ids = set(id_sets[mods[0]])
for m in mods[1:]:
    common_ids = common_ids & id_sets[m]

print(f"Intersection slice count: {len(common_ids):,}  "
      f"(dropped {len(id_sets['mask']) - len(common_ids)} orphan masks)")

# Group IDs by patient for leakage-free split

def patient_id_from_slice(sid: str) -> str:
    """Return the patient part of a slice ID: its first three ``_``-separated fields.

    Example: ``"BraTS20_Training_033_z106"`` -> ``"BraTS20_Training_033"``.
    Adjust this if the slice naming scheme differs.
    """
    fields = sid.split("_", 3)
    return "_".join(fields[:3])

# Group the slice IDs by patient so that a patient's slices never end up on both sides of the split.
# common_ids is a set of strings, whose iteration order can change between interpreter sessions,
# so it is sorted first to give every patient a reproducible slice order.
patient_to_slices = {}
for sid in sorted(common_ids):
    pid = patient_id_from_slice(sid)
    patient_to_slices.setdefault(pid, []).append(sid)

# Deterministic patient shuffle: sort first, then shuffle with a fixed seed.
patients = sorted(patient_to_slices.keys())
random.seed(42)
random.shuffle(patients)

val_frac = 0.20 #20%
n_val   = int(len(patients) * val_frac)
val_patient_order   = patients[:n_val]
train_patient_order = patients[n_val:]
val_patients   = set(val_patient_order)
train_patients = set(train_patient_order)

# Build the slice lists from the ordered patient lists rather than from the sets, so that
# the lists are identical in every session (same members as before, reproducible order).
train_slice_ids = [sid for p in train_patient_order for sid in patient_to_slices[p]]
val_slice_ids   = [sid for p in val_patient_order   for sid in patient_to_slices[p]]

print(f"\nPatients → train: {len(train_patients)}, val: {len(val_patients)}")
print(f"Slices   → train: {len(train_slice_ids):,}, val: {len(val_slice_ids):,}")

"""
Intersection slice count: 15,272  (dropped 0 orphan masks)

Patients → train: 258, val: 64
Slices   → train: 12,128, val: 3,144
"""


In [None]:
# Albumentations augmentation pipeline for training: random horizontal and
# vertical flips, random 90-degree rotations and brightness/contrast jitter,
# each with probability 0.5.  Image and mask are passed through it together.
augmentation_steps = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
]
train_transform = A.Compose(augmentation_steps, additional_targets={"mask": "mask"})



In [None]:
class BraTSSliceDS(Dataset):
    """Dataset of cached 2-D slices: four modality channels plus a binary mask.

    Each item is read from ``<root>/<modality>/<slice id>.npy`` for the four
    modalities in ``mods`` and from ``<root>/<mask_folder>/<slice id>.npy``
    for the mask.

    Args:
        root: folder containing one sub-folder per modality and the mask folder.
        ids:  sequence of slice IDs (file stems).
        tf:   optional callable invoked as ``tf(image=..., mask=...)`` that
              returns a mapping with ``"image"`` and ``"mask"`` entries.

    ``__getitem__`` returns ``(img, mask)`` as float tensors laid out
    channel-first: ``img`` has four channels, ``mask`` has one.
    """
    # sub-folder names of the image channels, in channel order
    mods = ["t1", "t1ce", "t2", "flair"]
    # sub-folder name of the segmentation masks
    mask_folder = "mask"

    def __init__(self, root, ids, tf=None):
        self.root = pathlib.Path(root)
        self.ids = ids
        self.tf = tf

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        sid = self.ids[idx]
        # stack the four modality slices along a trailing channel axis (H, W, 4)
        img = np.stack([np.load(self.root / m / f"{sid}.npy", mmap_mode='r')
                        for m in self.mods], -1)
        # any non-zero label counts as tumour -> binary mask with a channel axis (H, W, 1)
        mask = (np.load(self.root / self.mask_folder / f"{sid}.npy", mmap_mode='r') > 0
                ).astype(np.float32)[..., None]
        if self.tf:
            aug = self.tf(image=img, mask=mask)
            img, mask = aug["image"], aug["mask"]
        # A transform may hand back flipped views with negative strides, which
        # torch.from_numpy rejects; make the channel-first arrays contiguous first.
        img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))).float()
        mask = torch.from_numpy(np.ascontiguousarray(mask.transpose(2, 0, 1))).float()
        return img, mask

batch = 8 #batch size=8

# Training slices go through the augmentation pipeline; validation slices do not.
train_ds = BraTSSliceDS(DATA_DIR, train_slice_ids, train_transform)
val_ds   = BraTSSliceDS(DATA_DIR, val_slice_ids)

# Single-process loading without pinned memory for this first check.
loader_opts = {"num_workers": 0, "pin_memory": False}
train_loader = DataLoader(train_ds, batch, shuffle=True, **loader_opts)
val_loader   = DataLoader(val_ds, batch, shuffle=False, **loader_opts)

# Sanity fetch: pull one training batch and show tensor shapes and mask values.
x, y = next(iter(train_loader))

print(" print batch :", x.shape, y.shape, torch.unique(y))

"""
torch.Size([8, 4, 240, 240]) torch.Size([8, 1, 240, 240]) tensor([0., 1.])
"""

In [None]:
#Model (U-Net + ResNet50 encoder)
"""
Builds a lightweight 2-D U-Net whose encoder re-uses an ImageNet-pretrained ResNet-50.
The first convolution is widened from 3 to 4 channels so it can ingest the four MRI modalities;
the original RGB weights are copied, while the new fourth-channel kernel is initialised with the RGB mean.
The decoder upsamples with bilinear interpolation and concatenates skip features from each encoder stage,
then a 1×1 conv maps the final 64-channel feature map to a single-channel probability mask, followed by a sigmoid.
"""

"""
ResNet-50- its first conv expects 3-channel RGB,
but our MRI input has four modalities, so we replace that layer with a 4-channel Conv2d.
To retain as much pretrained signal as possible, we copy the original RGB kernels into the first three input planes and
initialise the new fourth plane with the mean of those kernels; this lets the network start with sensible weights instead of random
values while still accommodating the extra modality.

"""

import torchvision
import torch.nn as nn
import torch.nn.functional as F

class UNetResNet50(nn.Module):
    """2-D U-Net with a ResNet-50 encoder for 4-channel input and a 1-channel mask.

    Args:
        pretrained: when True the encoder is initialised with torchvision's
            ``IMAGENET1K_V1`` ResNet-50 weights, otherwise with random weights.

    ``forward`` takes a ``(N, 4, H, W)`` tensor and returns sigmoid
    probabilities of shape ``(N, 1, H, W)``.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        enc = torchvision.models.resnet50(weights='IMAGENET1K_V1' if pretrained else None)

        # Adapt the first conv to 4 input channels: reuse the three RGB kernels
        # and initialise the fourth input plane with their mean.
        old_conv = enc.conv1
        new_conv = nn.Conv2d(4, old_conv.out_channels, 7, 2, 3, bias=False)
        with torch.no_grad():
            new_conv.weight[:, :3] = old_conv.weight
            new_conv.weight[:, 3:] = old_conv.weight.mean(dim=1, keepdim=True)
        enc.conv1 = new_conv

        # encoder stages; each one halves the spatial resolution
        self.enc1 = nn.Sequential(enc.conv1, enc.bn1, enc.relu)
        self.enc2 = nn.Sequential(enc.maxpool, enc.layer1)
        self.enc3 = enc.layer2
        self.enc4 = enc.layer3
        self.enc5 = enc.layer4

        def block(ic, oc):
            # two 3x3 conv + batch-norm + ReLU layers, spatial size preserved
            return nn.Sequential(
                nn.Conv2d(ic, oc, 3, padding=1), nn.BatchNorm2d(oc), nn.ReLU(inplace=True),
                nn.Conv2d(oc, oc, 3, padding=1), nn.BatchNorm2d(oc), nn.ReLU(inplace=True)
            )

        # decoder input channels = upsampled features + skip connection of the same stage
        self.dec4 = block(2048 + 1024, 1024)
        self.dec3 = block(1024 + 512, 512)
        self.dec2 = block(512 + 256, 256)
        self.dec1 = block(256 + 64, 64)
        self.final = nn.Conv2d(64, 1, 1)

    @staticmethod
    def _up(t, size):
        """Bilinearly resize *t* to the spatial *size*."""
        return F.interpolate(t, size, mode='bilinear', align_corners=False)

    def forward(self, x):
        # shapes in the comments are for a 240×240 input
        x1 = self.enc1(x)          #64 ×120×120
        x2 = self.enc2(x1)         #256×60×60
        x3 = self.enc3(x2)         #512×30×30
        x4 = self.enc4(x3)         #1024×15×15
        x5 = self.enc5(x4)         #2048×8×8
        d4 = self.dec4(torch.cat([self._up(x5, x4.shape[2:]), x4], 1))
        d3 = self.dec3(torch.cat([self._up(d4, x3.shape[2:]), x3], 1))
        d2 = self.dec2(torch.cat([self._up(d3, x2.shape[2:]), x2], 1))
        d1 = self.dec1(torch.cat([self._up(d2, x1.shape[2:]), x1], 1))
        # Resize to the input's own height/width instead of a fixed ×2 factor, so the
        # output matches the mask for odd input sizes too (same result when H and W are even).
        out = self.final(self._up(d1, x.shape[2:]))
        return torch.sigmoid(out)



In [None]:
#Loss, metric, optimiser
# Select the compute device again so this cell can run on its own.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Instantiate the network with pretrained encoder weights and move it to the device.
model = UNetResNet50(pretrained=True).to(device)

# Binary cross-entropy on probabilities; it is combined with dice_loss in the training loop.
bce = nn.BCELoss()
def dice_loss(p,t,eps=1e-6):
    """Soft Dice loss, averaged over the batch.

    Args:
        p:   predicted probabilities, first dimension is the batch.
        t:   target mask with the same number of elements per sample as *p*.
        eps: smoothing term that keeps the ratio defined when both are empty.

    Returns:
        Scalar tensor ``1 - mean(dice)``; differentiable with respect to *p*.
    """
    # Flatten each sample. The original assigned t.view(...) to p, so the prediction
    # was discarded and the loss was a constant with no gradient.
    p = p.view(p.size(0), -1)
    t = t.view(t.size(0), -1)
    inter = (p * t).sum(1)
    union = p.sum(1) + t.sum(1)
    return 1 - ((2 * inter + eps) / (union + eps)).mean()
def dice_score(p,t,eps=1e-6):
    """Dice coefficient between prediction and target, averaged over the batch.

    Args:
        p:   predictions (the training loop passes thresholded 0/1 values),
             first dimension is the batch.
        t:   target mask with the same number of elements per sample as *p*.
        eps: smoothing term; a sample where both are empty scores 1.

    Returns:
        The mean Dice coefficient as a Python float.
    """
    # Flatten each sample. The original assigned t.view(...) to p, which compared
    # the target with itself and therefore always reported a score of about 1.
    p = p.view(p.size(0), -1)
    t = t.view(t.size(0), -1)
    inter = (p * t).sum(1)
    union = p.sum(1) + t.sum(1)
    return ((2 * inter + eps) / (union + eps)).mean().item()

# AdamW over all model parameters with a learning rate of 3e-4.
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)


In [None]:
import torch, platform, os
# Report whether PyTorch can see a CUDA device and, if so, which one.
cuda_ready = torch.cuda.is_available()
print("torch.cuda.is_available():", cuda_ready)
if cuda_ready:
    print("CUDA device:", torch.cuda.get_device_name(0))

In [None]:
# Fetch one training batch and list the distinct values that occur in its masks.
imgs, masks = next(iter(train_loader))
print("Unique mask values:", torch.unique(masks))


In [None]:
# DATA_DIR stays on Drive
DATA_DIR = "/content/drive/MyDrive/964_project/segmentation/4modslices"

# Rebuild the datasets unchanged: augmentation for training, none for validation.
train_ds = BraTSSliceDS(DATA_DIR, train_slice_ids, train_transform)
val_ds   = BraTSSliceDS(DATA_DIR, val_slice_ids)

# New loaders: 2 worker processes and pinned memory for both; only the training
# loader shuffles and drops the last incomplete batch.
batch_size = 8
shared_opts = {"num_workers": 2, "pin_memory": True}
train_loader = DataLoader(train_ds, batch_size, shuffle=True, drop_last=True, **shared_opts)
val_loader = DataLoader(val_ds, batch_size, shuffle=False, **shared_opts)

print("DataLoaders rebuilt with 2 workers.")


In [None]:
%%bash
# Copy the slice cache from Google Drive (SRC) to the VM-local folder (DST).
SRC="/content/drive/MyDrive/964_project/segmentation/4modslices"
DST="/content/slice_cache"

echo "Copying slices to fast local storage …"
# trailing slashes: copy the contents of SRC into DST; `time` reports how long it took
time rsync -ah --info=progress2 "$SRC"/ "$DST"/
echo "Done — data now in $DST"

In [None]:
!du -sh /content/slice_cache           # total size copied so far
!find /content/slice_cache -type f | wc -l   # number of files copied


In [None]:
%%bash
# Copy the slices to the local /content folder to speed up training.
# The %%bash cell magic has to be the first line of the cell, so this comment sits below it.
SRC="/content/drive/MyDrive/964_project/segmentation/4modslices"
DST="/content/slice_cache"

echo "Resuming copy …"
rsync -ah --info=progress2 "$SRC"/ "$DST"/
echo "Copy complete"

In [None]:
import pathlib
DATA_DIR = pathlib.Path("/content/slice_cache")

In [None]:
import pathlib, torch
from torch.utils.data import DataLoader

# Read the slices from the local copy instead of Drive.
DATA_DIR = pathlib.Path("/content/slice_cache")

# Recreate the datasets on top of the local folder (no augmentation for validation).
train_ds = BraTSSliceDS(DATA_DIR, train_slice_ids, train_transform)
val_ds   = BraTSSliceDS(DATA_DIR, val_slice_ids)

# Loaders with 2 workers and pinned memory; the training loader also shuffles
# and drops the last incomplete batch.
batch_size = 8
worker_opts = {"num_workers": 2, "pin_memory": True}
train_loader = DataLoader(train_ds, batch_size, shuffle=True, drop_last=True, **worker_opts)
val_loader   = DataLoader(val_ds, batch_size, shuffle=False, **worker_opts)

print("Local-SSD DataLoaders ready!!!")


In [None]:
#training
import torch, time, datetime, glob, os
import pathlib

# Checkpoints are written next to DATA_DIR.
# NOTE(review): with DATA_DIR = /content/slice_cache this is /content/checkpoints, i.e. on the
# VM's local disk rather than on Drive — confirm that this is intended.
ckpt_dir = pathlib.Path(DATA_DIR).parent / "checkpoints"
ckpt_dir.mkdir(exist_ok=True)

num_epochs = 50

# -------- resume if a checkpoint exists --------
# File names carry a zero-padded epoch number, so the lexicographically last one is the newest.
ckpts = sorted(glob.glob(str(ckpt_dir/"ep*.pt")))
start_ep, best_dice = 1, 0.
if ckpts:
    last = torch.load(ckpts[-1], map_location=device)
    model.load_state_dict(last["model_state_dict"])
    opt.load_state_dict(last["optimizer_state_dict"])
    start_ep = last["epoch"] + 1
    # Continue from the best dice seen so far. Checkpoints written before this field
    # existed only hold that epoch's own dice, which is used as the fallback.
    best_dice = last.get("best_dice", last["dice"])
    print(f"Resumed from epoch {last['epoch']} (val dice {last['dice']:.3f})")

def hms(sec):
    """Format a duration in seconds as H:MM:SS."""
    return str(datetime.timedelta(seconds=int(sec)))

print(f"{'Epoch':>5} | {'Time':>8} | {'TrainL':>7} | {'ValL':>7} | "
      f"{'TrainD':>6} | {'ValD':>5}")

for epoch in range(start_ep, num_epochs+1):
    t0 = time.time()
    model.train()
    tl, td = 0., 0.
    for b, (img, msk) in enumerate(train_loader, 1):
        img, msk = img.to(device), msk.to(device)
        opt.zero_grad()
        pred = model(img)
        # combined objective: binary cross-entropy plus soft Dice loss
        loss = bce(pred, msk) + dice_loss(pred, msk)
        loss.backward()
        opt.step()
        tl += loss.item()
        # the metric is computed on predictions thresholded at 0.5
        td += dice_score((pred >= 0.5).float(), msk)
        if b % 100 == 0:
            print(f"   batch {b}/{len(train_loader)}")   # heartbeat
    tl /= len(train_loader)
    td /= len(train_loader)

    # ---- validation ----
    model.eval()
    vl, vd = 0., 0.
    with torch.no_grad():
        for img, msk in val_loader:
            img, msk = img.to(device), msk.to(device)
            pred = model(img)
            vl += (bce(pred, msk) + dice_loss(pred, msk)).item()
            vd += dice_score((pred >= 0.5).float(), msk)
    vl /= len(val_loader)
    vd /= len(val_loader)

    print(f"{epoch:5d} | {hms(time.time()-t0):>8} | {tl:7.4f} | {vl:7.4f} | "
          f"{td:6.4f} | {vd:5.4f}")

    # Update the running best before saving so that the checkpoint records it.
    is_best = vd > best_dice
    if is_best:
        best_dice = vd

    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt.state_dict(), 'dice': vd,
                'best_dice': best_dice},
               ckpt_dir/f"ep{epoch:02d}_dice{vd:.3f}.pt")
    if is_best:
        torch.save(model.state_dict(),
                   ckpt_dir/f"best_dice{best_dice:.3f}_ep{epoch:02d}.pt")
        print(f"   ✔︎ new best saved (dice={best_dice:.3f})")
"""

batch 1400/1516
   batch 1500/1516
   50 |  0:01:47 |  0.1988 |  0.1056 | 0.8423 | 0.9188
"""