<a href="https://colab.research.google.com/github/uma-mahesh-24/CS-254-Lab/blob/main/CNN_FEA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ====== 0) Setup
!pip -q install tqdm

import os, json, random, math, glob
import numpy as np
from google.colab import drive
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import r2_score

# ====== 1) Mount Google Drive & paths
drive.mount("/content/drive", force_remount=True)

# Inputs/targets are read from <Drive>/corrected_dataset/{inputs,targets};
# run artifacts are written to <Drive>/cnn_unet_baseline (created if missing).
root_dir = "/content/drive/MyDrive"
data_dir = os.path.join(root_dir, "corrected_dataset")   # adjust if needed
inputs_dir, targets_dir = (os.path.join(data_dir, sub) for sub in ("inputs", "targets"))
out_dir = os.path.join(root_dir, "cnn_unet_baseline")
os.makedirs(out_dir, exist_ok=True)

# ====== 2) Repro
# Seed the Python, NumPy and PyTorch RNGs with one shared value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# ====== 3) Dataset
class NpyPairDataset(Dataset):
    """Dataset of same-named ``.npy`` pairs read from an inputs and a targets folder.

    Each item is a dict with:
      * ``"pixel_values"`` -- the input array as float32 with a leading axis of size 1,
      * ``"labels"``       -- the target array, converted the same way,
      * ``"file_name"``    -- the file name shared by both arrays.

    NOTE(review): arrays are expected to be 2-D ``(H, W)`` so items come out as
    ``(1, H, W)``; this is not checked.
    """

    def __init__(self, inputs_dir, targets_dir, files):
        self.inputs_dir, self.targets_dir, self.files = inputs_dir, targets_dir, files

    def __len__(self):
        return len(self.files)

    def _load(self, folder, name):
        """Load ``folder/name``, cast it to float32 and prepend a singleton axis."""
        arr = np.load(os.path.join(folder, name)).astype(np.float32)
        return torch.from_numpy(np.expand_dims(arr, 0))

    def __getitem__(self, idx):
        name = self.files[idx]
        return {
            "pixel_values": self._load(self.inputs_dir, name),
            "labels": self._load(self.targets_dir, name),
            "file_name": name,
        }

# ====== 4) Split (fixed)
# Deterministic 90/10 split: sort for a stable base order, then shuffle with a
# private RNG seeded by `seed`, so the split is independent of global RNG state.
all_files = sorted(name for name in os.listdir(inputs_dir) if name.endswith(".npy"))
random.Random(seed).shuffle(all_files)
split_idx = int(0.9 * len(all_files))
train_files, val_files = all_files[:split_idx], all_files[split_idx:]

train_ds, val_ds = (
    NpyPairDataset(inputs_dir, targets_dir, subset) for subset in (train_files, val_files)
)

# ====== 5) Dataloaders
batch_size = 4
num_workers = 2
pin_memory = True

# Settings shared by both loaders; only shuffling differs.
_loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                      pin_memory=pin_memory, drop_last=False)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)   # reshuffled each epoch
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)      # fixed order for eval


Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------- Utility ----------
def pad_to_multiple(x, multiple=4):
    """Zero-pad a 4-D tensor on the bottom and right so H and W divide by ``multiple``.

    Args:
        x: Tensor unpacked as ``(_, _, H, W)``; any other rank raises ValueError.
        multiple: Target divisor for both spatial sizes.

    Returns:
        ``(padded, pad_h, pad_w)``, where ``pad_h`` / ``pad_w`` are the rows and
        columns appended (0 when already divisible).
    """
    _, _, height, width = x.shape
    # -n % m is the distance from n up to the next multiple of m.
    pad_h, pad_w = (-size % multiple for size in (height, width))
    padded = F.pad(x, (0, pad_w, 0, pad_h))
    return padded, pad_h, pad_w


# ---------- Building Blocks ----------
class DoubleConv(nn.Module):
    """Two 3x3 conv units followed by dropout.

    Layout: Conv → BN → LeakyReLU → Conv → BN → LeakyReLU → Dropout.
    Convs use ``padding=1`` (spatial size unchanged) and ``bias=False`` because
    a BatchNorm follows each one.
    """

    def __init__(self, c_in, c_out, dropout=0.1):
        super().__init__()
        layers = []
        for in_ch in (c_in, c_out):
            layers.extend([
                nn.Conv2d(in_ch, c_out, 3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(inplace=True),
            ])
        layers.append(nn.Dropout(dropout))
        # Same flat index layout inside `net` (0..6), so state_dict keys stay stable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool downsampling followed by a DoubleConv."""

    def __init__(self, c_in, c_out, dropout=0.1):
        super().__init__()
        stages = [nn.MaxPool2d(2), DoubleConv(c_in, c_out, dropout)]
        # Kept as one Sequential named `pool_conv` so state_dict keys are unchanged.
        self.pool_conv = nn.Sequential(*stages)

    def forward(self, x):
        y = self.pool_conv(x)
        return y


class Up(nn.Module):
    """Decoder stage: upsample ``x1`` 2x, crop skip ``x2`` to match, concatenate, DoubleConv.

    ``c_in`` is the channel count *after* concatenation (skip + upsampled).

    NOTE(review): with ``bilinear=False`` the transposed conv is built with
    ``c_in // 2`` input and output channels, so that path presumes ``x1``
    carries exactly half of ``c_in``; asymmetric channel splits would fail.
    """

    def __init__(self, c_in, c_out, bilinear=True, dropout=0.1):
        super().__init__()
        if not bilinear:
            half = c_in // 2
            self.up = nn.ConvTranspose2d(half, half, 2, stride=2)
        else:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # The post-concatenation conv is the same for both upsampling modes.
        self.conv = DoubleConv(c_in, c_out, dropout)

    def crop(self, enc_feat, target):
        """Trim ``enc_feat`` on the bottom/right to ``target``'s H and W (never pads)."""
        _, _, height, width = target.shape
        return enc_feat[:, :, :height, :width]

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        skip = self.crop(x2, upsampled)
        # Skip features first, then upsampled features, along the channel axis.
        return self.conv(torch.cat((skip, upsampled), dim=1))


# ---------- Full U-Net ----------
class UNetImproved(nn.Module):
    """Three-level U-Net mapping a ``c_in``-channel image to ``c_out`` channels.

    Encoder widths are base, 2·base, 4·base, 8·base. The input is zero-padded
    to a multiple of 8 before the encoder, and the output is cropped back to
    the input's height and width.
    """

    def __init__(self, c_in=1, c_out=1, base=64, dropout=0.1):
        super().__init__()
        b = base
        # Encoder. Construction order is preserved so parameter init and
        # state_dict key order are unchanged.
        self.inc = DoubleConv(c_in, b, dropout)
        self.down1 = Down(b, 2 * b, dropout)
        self.down2 = Down(2 * b, 4 * b, dropout)
        self.down3 = Down(4 * b, 8 * b, dropout)  # deeper bottleneck

        # Decoder: each Up receives (upsampled channels + skip channels).
        self.up1 = Up(8 * b + 4 * b, 4 * b, dropout=dropout)
        self.up2 = Up(4 * b + 2 * b, 2 * b, dropout=dropout)
        self.up3 = Up(2 * b + b, b, dropout=dropout)
        self.outc = nn.Conv2d(b, c_out, 1)

    def forward(self, x):
        in_h, in_w = x.shape[-2:]
        # Three 2x downsampling stages need H and W divisible by 2**3 = 8.
        padded, _, _ = pad_to_multiple(x, 8)

        skips = [self.inc(padded)]
        for stage in (self.down1, self.down2):
            skips.append(stage(skips[-1]))
        feat = self.down3(skips[-1])

        # Deepest skip pairs with the first decoder stage.
        for stage, skip in zip((self.up1, self.up2, self.up3), reversed(skips)):
            feat = stage(feat, skip)

        return self.outc(feat)[:, :, :in_h, :in_w]  # crop back to original size




# Use the GPU when PyTorch can see one, otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = UNetImproved(base=64, dropout=0.1)
model = model.to(device)

# ====== 7) Losses & metrics
def masked_mse_mean(pred, target):
    """Mean squared error over the entries whose target is non-negative.

    Entries with ``target < 0`` are excluded from the loss.
    NOTE(review): presumably negative targets mark invalid / no-data pixels;
    confirm against the dataset.

    Args:
        pred: Prediction tensor with the same shape as ``target``.
        target: Ground-truth tensor; entries ``< 0`` are ignored.

    Returns:
        0-dim tensor. If no entry is valid, the result is ``0`` but still part
        of ``pred``'s autograd graph, so ``loss.backward()`` yields all-zero
        gradients instead of raising on a tensor without ``grad_fn``.
    """
    sq_err = ((pred - target) ** 2)[target >= 0]
    if sq_err.numel() == 0:
        # Summing an empty selection gives 0 while keeping the graph
        # connected (mean() would give NaN; a fresh torch.tensor(0.0) has no grad).
        return sq_err.sum()
    return sq_err.mean()

@torch.no_grad()
def eval_epoch(net=None, loader=None, dev=None):
    """Evaluate a model and return pixel-wise masked regression metrics.

    Only pixels whose label is ``>= 0`` are scored. All metrics are computed
    over the pooled valid pixels of every batch, so batch size (and a short
    final batch) does not skew them, and R² is measured against the global
    label mean rather than each batch's own mean.

    NOTE(review): a previous version averaged per-batch MAE/RMSE and used
    per-batch means for R², so values logged by it are not directly comparable.

    Args:
        net: Model to evaluate; defaults to the module-level ``model``.
        loader: Iterable of dicts with ``"pixel_values"`` and ``"labels"``
            tensors; defaults to the module-level ``val_loader``.
        dev: Device batches are moved to; defaults to the module-level ``device``.

    Returns:
        dict with ``"mae"``, ``"rmse"`` and ``"r2"``. With no valid pixels,
        MAE and RMSE are ``0.0`` and R² is ``nan``; R² is also ``nan`` when
        the valid labels have zero variance.
    """
    net = model if net is None else net
    loader = val_loader if loader is None else loader
    dev = device if dev is None else dev

    net.eval()
    n_pix = 0          # valid pixels seen so far
    abs_err_sum = 0.0  # sum |pred - y|
    sq_err_sum = 0.0   # sum (pred - y)^2, i.e. SSE (R² numerator)
    y_mean = 0.0       # running mean of valid labels
    y_m2 = 0.0         # running sum (y - mean)^2, i.e. SST (R² denominator)

    for batch in loader:
        x = batch["pixel_values"].to(dev, non_blocking=True)
        y = batch["labels"].to(dev, non_blocking=True)
        pred = net(x)

        mask = y >= 0
        y_valid = y[mask].double()
        n_b = y_valid.numel()
        if n_b == 0:
            continue
        # float64 accumulation avoids float32 round-off over many pixels.
        err = (pred.double() - y.double())[mask]
        abs_err_sum += err.abs().sum().item()
        sq_err_sum += err.pow(2).sum().item()

        # Merge this batch's mean / M2 into the running totals (Chan et al.
        # parallel variance update); stable, unlike sum(y^2) - n * mean^2.
        b_mean = y_valid.mean().item()
        b_m2 = (y_valid - b_mean).pow(2).sum().item()
        total = n_pix + n_b
        delta = b_mean - y_mean
        y_mean += delta * n_b / total
        y_m2 += b_m2 + delta * delta * n_pix * n_b / total
        n_pix = total

    if n_pix == 0:
        return {"mae": 0.0, "rmse": 0.0, "r2": float("nan")}

    mae = abs_err_sum / n_pix
    rmse = math.sqrt(sq_err_sum / n_pix)
    r2 = 1.0 - sq_err_sum / y_m2 if y_m2 > 0 else float("nan")
    return {"mae": mae, "rmse": rmse, "r2": r2}

# ====== 8) Optimizer & training loop
epochs = 50
lr = 5e-4
weight_decay = 1e-2
# torch.amp.GradScaler("cuda", ...) replaces the deprecated torch.cuda.amp.GradScaler;
# the loss scaler is only active on CUDA.
scaler = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

best_rmse = float("inf")
best_r2 = float("-inf")  # Initialize best_r2
best_path = os.path.join(out_dir, "unet_tiny_best.pt")
checkpoint_path = os.path.join(out_dir, "checkpoint.pt")

# Load checkpoint if exists
start_epoch = 1
if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    # map_location lets a checkpoint written on a GPU runtime be resumed on CPU.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Copy over only tensors whose name AND shape match the current model,
    # so an architecture change does not abort the resume.
    model_state_dict = model.state_dict()
    pretrained_state_dict = checkpoint['state_dict']
    filtered_state_dict = {
        k: v for k, v in pretrained_state_dict.items()
        if k in model_state_dict and v.shape == model_state_dict[k].shape
    }
    n_ignored = len(pretrained_state_dict) - len(filtered_state_dict)
    n_fresh = len(model_state_dict) - len(filtered_state_dict)
    if n_ignored or n_fresh:
        # Make a partial restore visible instead of silently training from a mix.
        print(f"Partial weight load: {len(filtered_state_dict)} tensors restored, "
              f"{n_ignored} checkpoint tensors ignored, {n_fresh} model tensors left at init.")

    model_state_dict.update(filtered_state_dict)
    model.load_state_dict(model_state_dict)

    # Optimizer/scaler state is best-effort: fall back to fresh state if it is
    # missing or incompatible.
    try:
        opt.load_state_dict(checkpoint['optimizer'])
    except (KeyError, ValueError) as e:
        print(f"Could not load optimizer state dict: {e}. Starting with a new optimizer state.")

    try:
        scaler.load_state_dict(checkpoint['scaler'])
    except (KeyError, ValueError, RuntimeError) as e:
        # RuntimeError: GradScaler rejects the empty state saved by a disabled (CPU) scaler.
        print(f"Could not load scaler state dict: {e}. Starting with a new scaler state.")

    start_epoch = checkpoint['epoch'] + 1
    # NOTE(review): after an architecture change these restored "best" values
    # belong to the old model and may block saving a new best checkpoint.
    best_rmse = checkpoint.get('best_rmse', float('inf'))
    best_r2 = checkpoint.get('best_r2', float('-inf'))

    print(f"Resumed training from epoch {start_epoch}")


for epoch in range(start_epoch, epochs + 1):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
    running = 0.0
    # Count steps explicitly: tqdm refreshes pbar.n only on its display
    # interval, so dividing by pbar.n over-reports the running mean loss.
    for step, batch in enumerate(pbar, start=1):
        x = batch["pixel_values"].to(device, non_blocking=True)
        y = batch["labels"].to(device, non_blocking=True)

        # Mixed precision on CUDA only; torch.autocast replaces the
        # deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type=device, enabled=(device == "cuda")):
            pred = model(x)
            loss = masked_mse_mean(pred, y)

        opt.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        running += loss.item()
        pbar.set_postfix(loss=f"{running / step:.4f}")

    metrics = eval_epoch()
    print(f"Epoch {epoch}: val_mae={metrics['mae']:.4f} "
          f"val_rmse={metrics['rmse']:.4f} val_r2={metrics['r2']:.4f}")

    # save best by RMSE and record R2
    if metrics["rmse"] < best_rmse:
        best_rmse = metrics["rmse"]
        best_r2 = metrics["r2"]  # R2 of the same epoch as the best RMSE
        torch.save({"state_dict": model.state_dict(),
                    "metrics": metrics,
                    "epoch": epoch,
                    'best_rmse': best_rmse,
                    'best_r2': best_r2}, best_path)
        print(f"  ↳ saved best to {best_path}")

    # Save a full resume checkpoint every epoch.
    torch.save({
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': opt.state_dict(),
        'scaler': scaler.state_dict(),
        'best_rmse': best_rmse,
        'best_r2': best_r2
    }, checkpoint_path)
    print(f"  ↳ saved checkpoint to {checkpoint_path}")


print("Best val RMSE:", best_rmse)
print("Corresponding val R2:", best_r2)  # R2 at the best-RMSE epoch

# # ====== 9) Inference helper (save predictions for the val set)
# @torch.no_grad()
# def save_val_predictions(save_dir):
#     os.makedirs(save_dir, exist_ok=True)
#     model.eval()
#     for batch in tqdm(val_loader, desc="Saving val preds"):
#         x = batch["pixel_values"].to(device, non_blocking=True)
#         fnames = batch["file_name"]
#         pred = model(x).cpu().numpy()  # [B,1,H,W]
#         for i, fname in enumerate(fnames):
#             np.save(os.path.join(save_dir, fname), pred[i,0])

# pred_dir = os.path.join(out_dir, "preds_val")
# save_val_predictions(pred_dir)
# print("Predictions saved to:", pred_dir)

  scaler = torch.cuda.amp.GradScaler(enabled=(device=="cuda"))


Loading checkpoint from /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt
Could not load optimizer state dict: loaded state dict contains a parameter group that doesn't match the size of optimizer's group. Starting with a new optimizer state.
Resumed training from epoch 19


  with torch.cuda.amp.autocast(enabled=(device=="cuda")):


Epoch 19: val_mae=0.1578 val_rmse=0.2622 val_r2=0.0799
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 20: val_mae=0.1573 val_rmse=0.2615 val_r2=0.0847
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 21: val_mae=0.1571 val_rmse=0.2591 val_r2=0.1032
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 22: val_mae=0.1741 val_rmse=0.2646 val_r2=0.0761
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 23: val_mae=0.1507 val_rmse=0.2607 val_r2=0.0880
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 24: val_mae=0.1692 val_rmse=0.2636 val_r2=0.0798
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 25: val_mae=0.1583 val_rmse=0.2577 val_r2=0.1153
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 26: val_mae=0.1569 val_rmse=0.2563 val_r2=0.1220
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 27: val_mae=0.1684 val_rmse=0.2594 val_r2=0.1102
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 28: val_mae=0.1500 val_rmse=0.2524 val_r2=0.1477
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 29: val_mae=0.1470 val_rmse=0.2480 val_r2=0.1766
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 30: val_mae=0.1476 val_rmse=0.2468 val_r2=0.1860
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 31: val_mae=0.1470 val_rmse=0.2456 val_r2=0.1933
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 32: val_mae=0.1548 val_rmse=0.2480 val_r2=0.1847
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 33: val_mae=0.1666 val_rmse=0.2548 val_r2=0.1492
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 34: val_mae=0.1473 val_rmse=0.2444 val_r2=0.2036
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 35: val_mae=0.1471 val_rmse=0.2430 val_r2=0.2125
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 36: val_mae=0.1431 val_rmse=0.2414 val_r2=0.2195
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 37: val_mae=0.1469 val_rmse=0.2480 val_r2=0.1777
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 38: val_mae=0.1540 val_rmse=0.2447 val_r2=0.2075
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 39: val_mae=0.1413 val_rmse=0.2397 val_r2=0.2309
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 40: val_mae=0.1416 val_rmse=0.2388 val_r2=0.2366
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 41: val_mae=0.1690 val_rmse=0.2551 val_r2=0.1506
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 42: val_mae=0.1419 val_rmse=0.2387 val_r2=0.2371
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 43: val_mae=0.1427 val_rmse=0.2393 val_r2=0.2332
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 44: val_mae=0.1415 val_rmse=0.2372 val_r2=0.2473
  ↳ saved best to /content/drive/MyDrive/cnn_unet_baseline/unet_tiny_best.pt
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 45: val_mae=0.1433 val_rmse=0.2381 val_r2=0.2422
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 46: val_mae=0.1427 val_rmse=0.2395 val_r2=0.2313
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 47: val_mae=0.1455 val_rmse=0.2384 val_r2=0.2435
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 48: val_mae=0.1418 val_rmse=0.2374 val_r2=0.2467
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 49: val_mae=0.1525 val_rmse=0.2429 val_r2=0.2191
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt




Epoch 50: val_mae=0.1439 val_rmse=0.2381 val_r2=0.2434
  ↳ saved checkpoint to /content/drive/MyDrive/cnn_unet_baseline/checkpoint.pt
Best val RMSE: 0.2372476402670145
Corresponding val R2: 0.2473417485400904
