In [1]:
# ================== ONE-CELL DOWNLOADER: FIVES DATASET ONLY ==================

import os
import sys
import json
import shutil
import subprocess
from pathlib import Path

# ============================ ROOT PATH ============================
BASE = Path("/content/data")
BASE.mkdir(parents=True, exist_ok=True)

FIVES_SLUG = "nitishsingla0/fives-dataset"   # Kaggle dataset slug
FIVES_NAME = "FIVES"                        # Folder name under BASE

# ----------------- Utils -----------------
def sh(cmd, check=True, echo=True, capture=True):
    """Run a shell command with optional echo and capture."""
    if echo:
        print("$", cmd)
    if capture:
        p = subprocess.run(cmd, shell=True, text=True, capture_output=True)
        if p.stdout:
            print(p.stdout)
        if p.stderr:
            print(p.stderr, file=sys.stderr)
    else:
        p = subprocess.run(cmd, shell=True)
    if check and p.returncode != 0:
        raise subprocess.CalledProcessError(p.returncode, cmd)
    return p

def move_into(src_dir: Path, out_dir: Path):
    """Move/merge everything from src_dir into out_dir, then try to remove src_dir."""
    out_dir.mkdir(parents=True, exist_ok=True)
    for item in src_dir.iterdir():
        dst = out_dir / item.name
        if item.is_dir():
            dst.mkdir(parents=True, exist_ok=True)
            for s in item.rglob("*"):
                if s.is_file():
                    d = dst / s.relative_to(item)
                    d.parent.mkdir(parents=True, exist_ok=True)
                    shutil.copy2(s, d)
        else:
            shutil.copy2(item, dst)

    try:
        shutil.rmtree(src_dir)
    except Exception:
        pass

# ----------------- Kaggle auth -----------------
def ensure_kaggle_auth():
    """Ensure Kaggle API is installed and kaggle.json is in place."""
    kaggle_dir = Path("/root/.kaggle")
    kaggle_json = kaggle_dir / "kaggle.json"
    if not kaggle_json.exists():
        # Colab-style upload
        try:
            from google.colab import files
        except Exception:
            print("‚úó google.colab not available. If you are not in Colab,"
                  " place kaggle.json at /root/.kaggle/kaggle.json manually.")
            raise
        print("‚ñ∂ Kaggle API token not found. Upload kaggle.json (Kaggle ‚Üí Account ‚Üí Create New API Token).")
        kaggle_dir.mkdir(parents=True, exist_ok=True)
        uploaded = files.upload()
        fname = next(iter(uploaded))
        if fname != "kaggle.json":
            Path(f"/content/{fname}").rename("/content/kaggle.json")
            fname = "kaggle.json"
        shutil.move(f"/content/{fname}", kaggle_json)
        kaggle_json.chmod(0o600)

    # export env vars (for CLI)
    try:
        creds = json.loads(kaggle_json.read_text())
        os.environ["KAGGLE_USERNAME"] = creds.get("username", "")
        os.environ["KAGGLE_KEY"] = creds.get("key", "")
    except Exception:
        pass

    try:
        import kaggle  # noqa: F401
    except Exception:
        sh("pip -q install kaggle", check=True)

def kaggle_probe(slug: str) -> bool:
    p = sh(f'kaggle datasets files -d "{slug}"', check=False)
    return p.returncode == 0

def kaggle_download(slug: str, out_dir: Path):
    """Download a Kaggle dataset and place contents into out_dir."""
    out_dir.mkdir(parents=True, exist_ok=True)
    if not kaggle_probe(slug):
        print(f"‚úó Probe failed for {slug}. Skipping.")
        return False

    # Snapshot zips/dirs before
    before_zips = {p.name for p in Path("/content").glob("*.zip")}
    before_dirs = {p.name for p in Path("/content").iterdir() if p.is_dir()}

    # 1) Bulk download ZIP to /content (no --unzip)
    p = sh(f'kaggle datasets download -d "{slug}" -p /content', check=False)
    if p.returncode == 0:
        # Unzip any new zips into out_dir
        new_zips = [p for p in Path("/content").glob("*.zip") if p.name not in before_zips]
        if new_zips:
            for z in new_zips:
                sh(f'unzip -q -o "{z}" -d "{out_dir}"', check=False)
                z.unlink(missing_ok=True)
            print(f"‚úì {slug} ‚Üí {out_dir}")
            return True
        else:
            # Some datasets might appear as folders
            after_dirs = {p.name for p in Path("/content").iterdir() if p.is_dir()}
            created = sorted(list(after_dirs - before_dirs))
            moved_any = False
            for dname in created:
                src = Path("/content") / dname
                if any(src.iterdir()):
                    move_into(src, out_dir)
                    moved_any = True
            if moved_any:
                print(f"‚úì {slug} ‚Üí {out_dir}")
                return True

    # 2) Fallback: per-file download (rarely needed)
    lst = sh(f'kaggle datasets files -d "{slug}"', check=False)
    if lst.returncode != 0:
        return False
    names = []
    for line in lst.stdout.splitlines():
        s = line.strip()
        if (not s) or s.startswith("name") or s.startswith("---") or s.startswith("Next Page Token"):
            continue
        names.append(s.split()[0])
    success_any = False
    for fname in names:
        print(f"  ‚Üì {fname}")
        q = sh(f'kaggle datasets download -d "{slug}" -f "{fname}" -p /content --force', check=False)
        if q.returncode != 0:
            print(f"  ‚úó Failed: {fname}")
            continue
        z = Path("/content") / (Path(fname).name + ".zip")
        if z.exists():
            sh(f'unzip -q -o "{z}" -d "{out_dir}"', check=False)
            z.unlink(missing_ok=True)
            success_any = True
        else:
            # direct file case
            src = Path("/content") / Path(fname).name
            if src.exists():
                dst = out_dir / src.name
                dst.parent.mkdir(parents=True, exist_ok=True)
                shutil.move(str(src), str(dst))
                success_any = True
    if success_any:
        print(f"‚úì {slug} ‚Üí {out_dir}")
    return success_any

# ----------------- Run FIVES download -----------------
print("=== Ensuring Kaggle auth ===")
ensure_kaggle_auth()

target = BASE / FIVES_NAME
if target.exists() and any(target.iterdir()):
    print(f"‚úì Skip (exists): {FIVES_NAME} at {target}")
else:
    print(f"\n=== Kaggle: {FIVES_SLUG} ‚Üí {target} ===")
    ok = kaggle_download(FIVES_SLUG, target)
    if not ok:
        print("‚úó FIVES download failed; please check Kaggle access/slug.")
    else:
        print(f"‚úî FIVES dataset ready at: {target}")


=== Ensuring Kaggle auth ===
‚ñ∂ Kaggle API token not found. Upload kaggle.json (Kaggle ‚Üí Account ‚Üí Create New API Token).


Saving kaggle.json to kaggle.json

=== Kaggle: nitishsingla0/fives-dataset ‚Üí /content/data/FIVES ===
$ kaggle datasets files -d "nitishsingla0/fives-dataset"
Next Page Token = CfDJ8Ksq__M8KNdOsrtGDpOZ52VWSQ2JCG9_AYN1LLmR5s8EM1pEuBQ_BKwaYu9IirWH8C_ku9rdqO55BpwW3JJPPfH50tMC3keT-CfSLW1UciyhcenBMHRUhwA4u3holRdjm7GkG2Ye3h6dUxdreTZBhR7AIc-P3swnWmv3MZ8RHd-VGhqt3H5uT9H2njtyowLN2eYlWijueKqCVMUzuO2v6EgCVahSoEOe-8OLEL7Vbsfl55rmvntqBFilxWYjoUP-0TOhqp4NZsKcwZCwPCu9npFIN1QTkKcNviM4I1u9kBPQE0ghLOczCybi2ToEfnD4
name                                                                                                                                                      size  creationDate                
-------------------------------------------------------------------------------------------------------------------------------------------------------  -----  --------------------------  
FIVES A Fundus Image Dataset for AI-based Vessel Segmentation/FIVES A Fundus Image Dataset for AI-based Vessel Segmenta


  0%|          | 0.00/1.64G [00:00<?, ?B/s]
  8%|‚ñä         | 137M/1.64G [00:00<00:01, 1.43GB/s]
 16%|‚ñà‚ñã        | 274M/1.64G [00:00<00:02, 578MB/s] 
 21%|‚ñà‚ñà        | 353M/1.64G [00:00<00:03, 449MB/s]
 24%|‚ñà‚ñà‚ñç       | 410M/1.64G [00:00<00:03, 428MB/s]
 27%|‚ñà‚ñà‚ñã       | 459M/1.64G [00:01<00:03, 407MB/s]
 30%|‚ñà‚ñà‚ñâ       | 502M/1.64G [00:01<00:03, 393MB/s]
 32%|‚ñà‚ñà‚ñà‚ñè      | 542M/1.64G [00:01<00:03, 343MB/s]
 34%|‚ñà‚ñà‚ñà‚ñç      | 577M/1.64G [00:01<00:04, 284MB/s]
 36%|‚ñà‚ñà‚ñà‚ñå      | 606M/1.64G [00:01<00:04, 265MB/s]
 38%|‚ñà‚ñà‚ñà‚ñä      | 633M/1.64G [00:01<00:04, 249MB/s]
 39%|‚ñà‚ñà‚ñà‚ñâ      | 658M/1.64G [00:01<00:04, 238MB/s]
 40%|‚ñà‚ñà‚ñà‚ñà      | 681M/1.64G [00:02<00:04, 229MB/s]
 42%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 703M/1.64G [00:02<00:04, 223MB/s]
 43%|‚ñà‚ñà‚ñà‚ñà‚ñé     | 725M/1.64G [00:02<00:04, 218MB/s]
 44%|‚ñà‚ñà‚ñà‚ñà‚ñç     | 746M/1.64G [00:02<00:04, 214MB/s]
 47%|‚ñà‚ñà‚ñà‚ñà‚ñã     | 783M/1.64G [00:02<00:03, 259MB/s]
 48%|‚ñà‚ñà‚ñà‚ñà‚ñä 

‚úì nitishsingla0/fives-dataset ‚Üí /content/data/FIVES
‚úî FIVES dataset ready at: /content/data/FIVES


In [2]:
import os
import glob
import sys
import numpy as np
from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.semi_supervised import LabelPropagation
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, jaccard_score, recall_score, confusion_matrix, matthews_corrcoef
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from skimage.morphology import opening, closing, disk
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# ==================== CONFIGURATION ====================
CHECKPOINT_DIR = Path("/content/checkpoints/vessel_segmentation")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

BEST_MODEL_PATH = CHECKPOINT_DIR / "best_vessel_unet.pth"
EARLY_STOPPING_PATIENCE = 5

# -------- Dataset for FIVES --------
class FIVESDataset(Dataset):
    def __init__(self, img_paths, mask_paths, transform=None):
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = Image.open(self.img_paths[idx]).convert('RGB')
        mask = Image.open(self.mask_paths[idx]).convert('L')
        if self.transform:
            img = self.transform(img)
            mask = self.transform(mask)
        mask = (mask > 0).float()
        return img, mask

# -------- U-Net Model --------
class UNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()
        def CBR(in_channels, out_channels):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
        self.enc1 = CBR(in_ch, 64)
        self.enc2 = CBR(64, 128)
        self.enc3 = CBR(128, 256)
        self.pool = nn.MaxPool2d(2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dec3 = CBR(256+128, 128)
        self.dec2 = CBR(128+64, 64)
        self.final = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        d3 = self.up(e3)
        d3 = self.dec3(torch.cat([d3, e2], dim=1))
        d2 = self.up(d3)
        d2 = self.dec2(torch.cat([d2, e1], dim=1))
        return self.final(d2)

# -------- Validation Function --------
def validate(model, loader, loss_fn, device='cuda'):
    """Validate model and return loss and IoU metrics"""
    model.eval()
    total_loss = 0.0
    all_ious = []

    with torch.no_grad():
        for imgs, masks in loader:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            loss = loss_fn(preds, masks)
            total_loss += loss.item()

            # Calculate IoU for this batch
            pred_masks = (torch.sigmoid(preds) > 0.5).float()
            intersection = (pred_masks * masks).sum(dim=(1, 2, 3))
            union = (pred_masks + masks).clamp(0, 1).sum(dim=(1, 2, 3))
            iou = (intersection / (union + 1e-8)).cpu().numpy()
            all_ious.extend(iou)

    avg_loss = total_loss / len(loader)
    avg_iou = np.mean(all_ious)

    return avg_loss, avg_iou

# -------- Training with Checkpointing --------
def train_unet(model, train_loader, val_loader, epochs=20, lr=1e-3, device='cuda'):
    """Train U-Net with validation and checkpoint saving"""
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()

    best_val_iou = 0.0
    best_epoch = 0
    patience_counter = 0

    history = {
        'train_loss': [],
        'val_loss': [],
        'val_iou': []
    }

    print(f"\n{'='*60}")
    print("Starting Training with Model Checkpointing")
    print(f"{'='*60}")
    print(f"Checkpoint directory: {CHECKPOINT_DIR}")
    print(f"Early stopping patience: {EARLY_STOPPING_PATIENCE}")
    print(f"{'='*60}\n")

    for ep in range(epochs):
        # ==================== TRAINING ====================
        model.train()
        epoch_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {ep+1}/{epochs} [Train]")

        for imgs, masks in pbar:
            imgs, masks = imgs.to(device), masks.to(device)
            preds = model(imgs)
            loss = loss_fn(preds, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        avg_train_loss = epoch_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # ==================== VALIDATION ====================
        val_loss, val_iou = validate(model, val_loader, loss_fn, device)
        history['val_loss'].append(val_loss)
        history['val_iou'].append(val_iou)

        print(f"\nEpoch {ep+1}/{epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}")
        print(f"  Val IoU:    {val_iou:.4f}")

        # ==================== SAVE BEST MODEL ====================
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            best_epoch = ep + 1
            patience_counter = 0

            # Save checkpoint
            checkpoint = {
                'epoch': ep + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_iou': best_val_iou,
                'val_loss': val_loss,
                'train_loss': avg_train_loss,
                'history': history
            }
            torch.save(checkpoint, BEST_MODEL_PATH)
            print(f"  ‚úÖ New best Val IoU: {best_val_iou:.4f} (saved checkpoint)")
        else:
            patience_counter += 1
            print(f"  ‚è≥ No improvement ({patience_counter}/{EARLY_STOPPING_PATIENCE})")

        # ==================== EARLY STOPPING ====================
        if patience_counter >= EARLY_STOPPING_PATIENCE:
            print(f"\n‚ö†Ô∏è Early stopping triggered at epoch {ep+1}")
            print(f"Best model was at epoch {best_epoch} with Val IoU: {best_val_iou:.4f}")
            break

        sys.stdout.flush()

    print(f"\n{'='*60}")
    print("Training Complete!")
    print(f"{'='*60}")
    print(f"Best Val IoU: {best_val_iou:.4f} at epoch {best_epoch}")
    print(f"Best model saved at: {BEST_MODEL_PATH}")
    print(f"{'='*60}\n")

    return model, history

# -------- Load Best Model --------
def load_best_model(model, device='cuda'):
    """Load the best saved model checkpoint"""
    if not BEST_MODEL_PATH.exists():
        print(f"‚ö†Ô∏è No checkpoint found at {BEST_MODEL_PATH}")
        return model, None

    print(f"Loading best model from {BEST_MODEL_PATH}")
    checkpoint = torch.load(BEST_MODEL_PATH, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    print(f"‚úÖ Loaded checkpoint from epoch {checkpoint['epoch']}")
    print(f"   Best Val IoU: {checkpoint['best_val_iou']:.4f}")

    return model, checkpoint

# ==================== MAIN PIPELINE ====================
if __name__ == "__main__":
    # 1. Load FIVES paths
    base_path = '/content/data/FIVES/FIVES A Fundus Image Dataset for AI-based Vessel Segmentation/FIVES A Fundus Image Dataset for AI-based Vessel Segmentation'
    img_files = sorted(glob.glob(f'{base_path}/train/Original/*.png'))
    mask_files = sorted(glob.glob(f'{base_path}/train/Ground truth/*.png'))

    print(f"Found {len(img_files)} images and {len(mask_files)} masks")

    # 2. Train/Val/Test split (70/15/15)
    train_imgs, temp_imgs, train_masks, temp_masks = train_test_split(
        img_files, mask_files, test_size=0.3, random_state=42
    )
    val_imgs, test_imgs, val_masks, test_masks = train_test_split(
        temp_imgs, temp_masks, test_size=0.5, random_state=42
    )

    print(f"Train: {len(train_imgs)}, Val: {len(val_imgs)}, Test: {len(test_imgs)}")

    # 3. Dataset & Loaders with resizing
    input_size = (512, 512)
    tfm = T.Compose([
        T.Resize(input_size),
        T.ToTensor()
    ])

    train_ds = FIVESDataset(train_imgs, train_masks, transform=tfm)
    val_ds = FIVESDataset(val_imgs, val_masks, transform=tfm)
    test_ds = FIVESDataset(test_imgs, test_masks, transform=tfm)

    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0)

    # 4. Train U-Net with validation and checkpointing
    unet = UNet()
    unet, history = train_unet(unet, train_loader, val_loader, epochs=20, lr=1e-3)

    # 5. Load best model for testing
    print("\n" + "="*60)
    print("LOADING BEST MODEL FOR TESTING")
    print("="*60)
    unet, checkpoint = load_best_model(unet)

    # 6. Test: get probability maps
    print("\nRunning inference on test set...")
    unet.eval()
    probs_list, gts_list, pred_maps_list = [], [], []

    with torch.no_grad():
        for img, mask in tqdm(test_loader, desc="Testing"):
            img = img.to('cuda')
            logits = unet(img)
            prob = torch.sigmoid(logits).cpu().numpy().squeeze()
            gt = mask.numpy().squeeze()
            probs_list.append(prob)
            gts_list.append(gt)

    # 7. K-NN Graph + Label Propagation refinement
    print("\nApplying Label Propagation refinement...")
    metrics = {'AUC':[], 'F1':[], 'Acc':[], 'mIoU':[], 'Sens':[], 'Spec':[], 'MCC':[]}

    for prob, gt in tqdm(zip(probs_list, gts_list), total=len(probs_list), desc="Refining"):
        # seed selection
        thresh = 0.3
        fg_idx = np.where(prob > thresh)
        bg_idx = np.where(prob < (1 - thresh))
        pos = np.column_stack([fg_idx[0], fg_idx[1]])
        neg = np.column_stack([bg_idx[0], bg_idx[1]])

        if len(neg) > len(pos):
            neg = neg[np.random.choice(len(neg), size=len(pos), replace=False)]

        all_coords = np.vstack([pos, neg])
        labels = np.hstack([np.ones(len(pos)), np.zeros(len(neg))])
        feats = prob[all_coords[:,0], all_coords[:,1]][:, None]

        # Label Propagation
        lp = LabelPropagation(kernel='knn', n_neighbors=12, max_iter=50)
        lp.fit(feats, labels)
        preds = lp.predict(feats)

        # reconstruct map
        pred_map = np.zeros_like(prob, dtype=int)
        for (y,x), p in zip(all_coords, preds):
            pred_map[y, x] = p

        # morphology cleanup
        se = disk(1)
        pred_map = closing(opening(pred_map, se), se)
        pred_maps_list.append(pred_map)

        # metrics
        flat_gt = gt.flatten()
        flat_pred = pred_map.flatten()
        metrics['AUC'].append(roc_auc_score(flat_gt, prob.flatten()))
        metrics['F1'].append(f1_score(flat_gt, flat_pred))
        metrics['Acc'].append(accuracy_score(flat_gt, flat_pred))
        metrics['mIoU'].append(jaccard_score(flat_gt, flat_pred))
        metrics['Sens'].append(recall_score(flat_gt, flat_pred))
        tn, fp, fn, tp = confusion_matrix(flat_gt, flat_pred).ravel()
        metrics['Spec'].append(tn / (tn + fp))
        metrics['MCC'].append(matthews_corrcoef(flat_gt, flat_pred))

    # 8. Print final results
    print("\n" + "="*60)
    print("FINAL TEST RESULTS (with Label Propagation)")
    print("="*60)
    for metric, values in metrics.items():
        print(f"{metric:>6}: {np.mean(values):.4f} ¬± {np.std(values):.4f}")
    print("="*60)

    # 9. Save final results
    results_file = CHECKPOINT_DIR / "test_results.txt"
    with open(results_file, 'w') as f:
        f.write("="*60 + "\n")
        f.write("FINAL TEST RESULTS\n")
        f.write("="*60 + "\n")
        f.write(f"Best model from epoch: {checkpoint['epoch']}\n")
        f.write(f"Best Val IoU: {checkpoint['best_val_iou']:.4f}\n")
        f.write("\n" + "="*60 + "\n")
        f.write("Test Metrics (with Label Propagation):\n")
        f.write("="*60 + "\n")
        for metric, values in metrics.items():
            f.write(f"{metric:>6}: {np.mean(values):.4f} ¬± {np.std(values):.4f}\n")

    print(f"\n‚úÖ Results saved to {results_file}")

Found 600 images and 600 masks
Train: 420, Val: 90, Test: 90

Starting Training with Model Checkpointing
Checkpoint directory: /content/checkpoints/vessel_segmentation
Early stopping patience: 5



Epoch 1/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 1/20
  Train Loss: 0.2463
  Val Loss:   0.2256
  Val IoU:    0.2915
  ‚úÖ New best Val IoU: 0.2915 (saved checkpoint)


Epoch 2/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 2/20
  Train Loss: 0.1405
  Val Loss:   0.1852
  Val IoU:    0.4187
  ‚úÖ New best Val IoU: 0.4187 (saved checkpoint)


Epoch 3/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 3/20
  Train Loss: 0.1187
  Val Loss:   0.2008
  Val IoU:    0.3639
  ‚è≥ No improvement (1/5)


Epoch 4/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 4/20
  Train Loss: 0.1070
  Val Loss:   0.1193
  Val IoU:    0.6395
  ‚úÖ New best Val IoU: 0.6395 (saved checkpoint)


Epoch 5/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 5/20
  Train Loss: 0.0973
  Val Loss:   0.1108
  Val IoU:    0.6648
  ‚úÖ New best Val IoU: 0.6648 (saved checkpoint)


Epoch 6/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 6/20
  Train Loss: 0.0921
  Val Loss:   0.0976
  Val IoU:    0.6987
  ‚úÖ New best Val IoU: 0.6987 (saved checkpoint)


Epoch 7/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 7/20
  Train Loss: 0.0858
  Val Loss:   0.0945
  Val IoU:    0.7054
  ‚úÖ New best Val IoU: 0.7054 (saved checkpoint)


Epoch 8/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 8/20
  Train Loss: 0.0816
  Val Loss:   0.0959
  Val IoU:    0.6905
  ‚è≥ No improvement (1/5)


Epoch 9/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 9/20
  Train Loss: 0.0770
  Val Loss:   0.1275
  Val IoU:    0.6016
  ‚è≥ No improvement (2/5)


Epoch 10/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 10/20
  Train Loss: 0.0746
  Val Loss:   0.0826
  Val IoU:    0.7502
  ‚úÖ New best Val IoU: 0.7502 (saved checkpoint)


Epoch 11/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 11/20
  Train Loss: 0.0709
  Val Loss:   0.0786
  Val IoU:    0.7634
  ‚úÖ New best Val IoU: 0.7634 (saved checkpoint)


Epoch 12/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 12/20
  Train Loss: 0.0693
  Val Loss:   0.0824
  Val IoU:    0.7462
  ‚è≥ No improvement (1/5)


Epoch 13/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 13/20
  Train Loss: 0.0689
  Val Loss:   0.0845
  Val IoU:    0.7464
  ‚è≥ No improvement (2/5)


Epoch 14/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 14/20
  Train Loss: 0.0666
  Val Loss:   0.0824
  Val IoU:    0.7663
  ‚úÖ New best Val IoU: 0.7663 (saved checkpoint)


Epoch 15/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 15/20
  Train Loss: 0.0658
  Val Loss:   0.0719
  Val IoU:    0.7826
  ‚úÖ New best Val IoU: 0.7826 (saved checkpoint)


Epoch 16/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 16/20
  Train Loss: 0.0650
  Val Loss:   0.0825
  Val IoU:    0.7572
  ‚è≥ No improvement (1/5)


Epoch 17/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 17/20
  Train Loss: 0.0638
  Val Loss:   0.0765
  Val IoU:    0.7657
  ‚è≥ No improvement (2/5)


Epoch 18/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 18/20
  Train Loss: 0.0618
  Val Loss:   0.0820
  Val IoU:    0.7493
  ‚è≥ No improvement (3/5)


Epoch 19/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 19/20
  Train Loss: 0.0617
  Val Loss:   0.0743
  Val IoU:    0.7689
  ‚è≥ No improvement (4/5)


Epoch 20/20 [Train]:   0%|          | 0/210 [00:00<?, ?it/s]


Epoch 20/20
  Train Loss: 0.0614
  Val Loss:   0.1111
  Val IoU:    0.6715
  ‚è≥ No improvement (5/5)

‚ö†Ô∏è Early stopping triggered at epoch 20
Best model was at epoch 15 with Val IoU: 0.7826

Training Complete!
Best Val IoU: 0.7826 at epoch 15
Best model saved at: /content/checkpoints/vessel_segmentation/best_vessel_unet.pth


LOADING BEST MODEL FOR TESTING
Loading best model from /content/checkpoints/vessel_segmentation/best_vessel_unet.pth
‚úÖ Loaded checkpoint from epoch 15
   Best Val IoU: 0.7826

Running inference on test set...


Testing:   0%|          | 0/90 [00:00<?, ?it/s]


Applying Label Propagation refinement...


Refining:   0%|          | 0/90 [00:00<?, ?it/s]


FINAL TEST RESULTS (with Label Propagation)
   AUC: 0.9841 ¬± 0.0383
    F1: 0.8538 ¬± 0.1032
   Acc: 0.9679 ¬± 0.0211
  mIoU: 0.7560 ¬± 0.1248
  Sens: 0.8833 ¬± 0.1113
  Spec: 0.9771 ¬± 0.0149
   MCC: 0.8378 ¬± 0.1124

‚úÖ Results saved to /content/checkpoints/vessel_segmentation/test_results.txt


In [3]:
# ==================== FINAL CHECKPOINT DOWNLOAD ====================
# Run this cell AFTER training is 100% complete!

from google.colab import files
import shutil
from pathlib import Path

print("\n" + "="*70)
print("DOWNLOADING FINAL TRAINED MODEL")
print("="*70)

# The best model checkpoint path
checkpoint_path = Path('/content/checkpoints/vessel_segmentation/best_vessel_unet.pth')

if not checkpoint_path.exists():
    print("\n‚ö†Ô∏è  WARNING: Checkpoint not found!")
    print(f"   Expected: {checkpoint_path}")
    print("\n   Make sure training completed successfully!")
    print("   Check the training logs above.")
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

# Load and verify
import torch
ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

print(f"\n‚úì Checkpoint found!")
print(f"\nüìä Training Summary:")

if 'epoch' in ckpt:
    print(f"   Saved at Epoch: {ckpt['epoch']}")

    # Check if trained enough
    if ckpt['epoch'] < 10:
        print(f"\n‚ö†Ô∏è  WARNING: Model only trained for {ckpt['epoch']} epochs!")
        print(f"   Recommended: Train for at least 15-20 epochs for good results.")
        response = input("\n   Download anyway? (yes/no): ")
        if response.lower() != 'yes':
            print("   Download cancelled. Let training continue.")
            raise SystemExit()

if 'best_val_iou' in ckpt:
    print(f"   Best Validation IoU: {ckpt['best_val_iou']:.4f}")
if 'val_loss' in ckpt:
    print(f"   Validation Loss: {ckpt['val_loss']:.4f}")

# Show what's in the checkpoint
print(f"\n   Checkpoint keys: {list(ckpt.keys())}")

# Download
print(f"\nüíæ Downloading checkpoint...")

download_name = 'vessel.pth'
shutil.copy(checkpoint_path, f'/content/{download_name}')

files.download(f'/content/{download_name}')

print(f"\n‚úÖ Downloaded: {download_name}")
print(f"\nüìù Model Details:")
print(f"   Architecture: UNet")
print(f"   Task: Vessel segmentation (binary)")
print(f"   Input: Fundus image (512x512)")
print(f"   Output: Vessel probability map (512x512)")
print(f"   Dataset: FIVES")
print("="*70)



DOWNLOADING FINAL TRAINED MODEL

‚úì Checkpoint found!

üìä Training Summary:
   Saved at Epoch: 15
   Best Validation IoU: 0.7826
   Validation Loss: 0.0719

   Checkpoint keys: ['epoch', 'model_state_dict', 'optimizer_state_dict', 'best_val_iou', 'val_loss', 'train_loss', 'history']

üíæ Downloading checkpoint...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


‚úÖ Downloaded: vessel.pth

üìù Model Details:
   Architecture: UNet
   Task: Vessel segmentation (binary)
   Input: Fundus image (512x512)
   Output: Vessel probability map (512x512)
   Dataset: FIVES
