In [None]:


import os
import random
from glob import glob
from pathlib import Path
import shutil
import zipfile
import cv2
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T, models as tv_models
from PIL import Image

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.loggers import CSVLogger

from torchmetrics.functional import structural_similarity_index_measure as ssim_fn

In [None]:
# --- Output locations under the Kaggle working dir; created up front so later cells can write.
WORK_DIR = Path("/kaggle/working/")
MODELS_DIR = WORK_DIR / "models"            # weights-only .pth snapshots (+ copy of best full ckpt)
CHECKPOINTS_DIR = WORK_DIR / "checkpoints"  # Lightning ModelCheckpoint output
OUTPUT_DIR = WORK_DIR / "outputs"           # inference grids, predictions, metrics
for out_dir in (MODELS_DIR, CHECKPOINTS_DIR, OUTPUT_DIR):
    out_dir.mkdir(parents=True, exist_ok=True)

# --- Dataset: input images (Icam) paired by file name with ground-truth images (Iclean).
DATASET_DIR = "/kaggle/input/trident/DATA"
ICAM_DIR = os.path.join(DATASET_DIR, "Icam")
ICLEAN_DIR = os.path.join(DATASET_DIR, "Iclean")

# --- Training hyper-parameters.
IMG_SIZE = 256
BATCH_SIZE = 8
NUM_WORKERS = 4
LR = 1e-4
WEIGHT_DECAY = 1e-6
MAX_EPOCHS = 80
STAGE1_EPOCHS = 10    # NOTE(review): not referenced by any visible cell
SAVE_PTH_PERIOD = 10  # write a weights-only snapshot every N epochs
SEED = 42
NUM_SAMPLES = 100     # pairs used for train/val/test; None or <= 0 means "all pairs"

# Seed the Python, NumPy and PyTorch RNGs so the data split and init are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class PairedCLAHEImageDataset(Dataset):
    """Paired (input, ground-truth) images returned as float tensors in [0, 1].

    Both images are converted to RGB and resized to ``img_size`` x ``img_size`` (bicubic).
    When ``apply_clahe`` is True, CLAHE is applied to the L channel (LAB space) of the
    *input* image only; the target is never modified.

    Args:
        in_files: paths of the input images.
        gt_files: paths of the ground-truth images; ``gt_files[i]`` pairs with ``in_files[i]``.
        img_size: output side length in pixels.
        apply_clahe: whether to CLAHE-enhance the input image.
        clahe_clip_limit: OpenCV CLAHE ``clipLimit`` (default matches the previous hard-coded value).
        clahe_tile_grid: OpenCV CLAHE ``tileGridSize`` (default matches the previous hard-coded value).
    """
    def __init__(self, in_files, gt_files, img_size=256, apply_clahe=True,
                 clahe_clip_limit=2.0, clahe_tile_grid=(8, 8)):
        assert len(in_files) == len(gt_files), (
            f"in_files ({len(in_files)}) and gt_files ({len(gt_files)}) must have the same length"
        )
        self.in_files = in_files
        self.gt_files = gt_files
        self.img_size = img_size
        self.apply_clahe = apply_clahe
        self.clahe_clip_limit = clahe_clip_limit
        self.clahe_tile_grid = tuple(clahe_tile_grid)
        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.in_files)

    def apply_clahe_rgb(self, pil_img):
        """Return a copy of an RGB PIL image with CLAHE applied to its LAB lightness channel."""
        img = np.array(pil_img)
        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=self.clahe_clip_limit, tileGridSize=self.clahe_tile_grid)
        lab = cv2.merge((clahe.apply(l), a, b))
        return Image.fromarray(cv2.cvtColor(lab, cv2.COLOR_LAB2RGB))

    def _load_rgb(self, path):
        """Open ``path``, convert to RGB and resize; the file handle is closed before returning."""
        with Image.open(path) as img:
            return img.convert("RGB").resize((self.img_size, self.img_size), Image.BICUBIC)

    def __getitem__(self, idx):
        inp_p = self._load_rgb(self.in_files[idx])
        tgt_p = self._load_rgb(self.gt_files[idx])

        if self.apply_clahe:
            try:
                inp_p = self.apply_clahe_rgb(inp_p)
            except cv2.error as exc:
                # Best effort: keep the un-enhanced input, but report it instead of hiding it.
                print(f"CLAHE failed for {self.in_files[idx]}: {exc}")

        return self.to_tensor(inp_p), self.to_tensor(tgt_p)


def get_paired_file_lists(icam_dir, iclean_dir):
    """Pair files from two directories by identical base name.

    Returns ``(icam_paths, iclean_paths)``: two equally long lists, ordered by the shared
    base name, where element ``i`` of each list refers to the same file name.
    """
    def _by_name(folder):
        return {os.path.basename(path): path for path in glob(os.path.join(folder, "*"))}

    cam_by_name = _by_name(icam_dir)
    clean_by_name = _by_name(iclean_dir)
    shared_names = sorted(cam_by_name.keys() & clean_by_name.keys())
    return ([cam_by_name[name] for name in shared_names],
            [clean_by_name[name] for name in shared_names])


icam_list, iclean_list = get_paired_file_lists(ICAM_DIR, ICLEAN_DIR)
total_pairs = len(icam_list)
print("Found total paired images:", total_pairs)
if total_pairs == 0:
    # Fail here with a clear message instead of an opaque DataLoader error later on.
    raise FileNotFoundError(
        f"No image pairs with matching file names in {ICAM_DIR!r} and {ICLEAN_DIR!r}"
    )

# None / non-positive / larger-than-available all mean "use every pair".
if NUM_SAMPLES is None or NUM_SAMPLES <= 0 or NUM_SAMPLES > total_pairs:
    NUM_SAMPLES = total_pairs

# `random` was seeded in the config cell, so this shuffle (and hence the split) is reproducible.
indices = list(range(total_pairs))
random.shuffle(indices)

selected = indices[:NUM_SAMPLES]   # train / val / test
leftover = indices[NUM_SAMPLES:]   # for inference/export demo

icam_sel = [icam_list[i] for i in selected]
iclean_sel = [iclean_list[i] for i in selected]

infer_icam = [icam_list[i] for i in leftover]
infer_iclean = [iclean_list[i] for i in leftover]

# Train / Val / Test Split (80 / 10 / remainder)
n_train = int(0.8 * NUM_SAMPLES)
n_val = int(0.1 * NUM_SAMPLES)
n_test = NUM_SAMPLES - n_train - n_val

train_in = icam_sel[:n_train]
train_gt = iclean_sel[:n_train]
val_in = icam_sel[n_train:n_train + n_val]
val_gt = iclean_sel[n_train:n_train + n_val]
test_in = icam_sel[n_train + n_val:]
test_gt = iclean_sel[n_train + n_val:]

print("Train / Val / Test =", len(train_in), len(val_in), len(test_in))
print("Leftover for inference:", len(infer_icam))
if n_val == 0:
    print("WARNING: validation split is empty (NUM_SAMPLES < 10); "
          "training monitors 'val/total' and will not work.")
if not infer_icam:
    print("WARNING: no leftover pairs; the inference cells (predict_dataloader) will see no images.")

In [None]:
class UnderWaterDataModule(pl.LightningDataModule):
    """LightningDataModule over pre-split ``(in_files, gt_files)`` pairs.

    Args:
        train_pairs, val_pairs, test_pairs, infer_pairs: ``(in_files, gt_files)`` tuples.
        img_size: forwarded to every ``PairedCLAHEImageDataset``.
        batch_size, num_workers: forwarded to every ``DataLoader``.
        clahe_on_eval: apply the training-time input CLAHE to the val/test/predict datasets
            as well. CLAHE is a deterministic pre-processing step (not a random augmentation),
            so leaving it off at evaluation time feeds the model a different input
            distribution than it was trained on. Pass False to get the previous behavior.
    """
    def __init__(self, train_pairs, val_pairs, test_pairs, infer_pairs,
                 img_size=256, batch_size=8, num_workers=4, clahe_on_eval=True):
        super().__init__()
        self.train_pairs = train_pairs
        self.val_pairs = val_pairs
        self.test_pairs = test_pairs
        self.infer_pairs = infer_pairs
        self.img_size = img_size
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.clahe_on_eval = clahe_on_eval

    def setup(self, stage=None):
        eval_clahe = self.clahe_on_eval
        self.train_ds = PairedCLAHEImageDataset(*self.train_pairs, self.img_size, apply_clahe=True)
        self.val_ds = PairedCLAHEImageDataset(*self.val_pairs, self.img_size, apply_clahe=eval_clahe)
        self.test_ds = PairedCLAHEImageDataset(*self.test_pairs, self.img_size, apply_clahe=eval_clahe)
        self.infer_ds = PairedCLAHEImageDataset(*self.infer_pairs, self.img_size, apply_clahe=eval_clahe)

    def _loader(self, dataset, shuffle=False):
        """Build a DataLoader with the shared batch size / worker / pinning settings."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers, pin_memory=True)

    def train_dataloader(self):
        return self._loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._loader(self.val_ds)

    def test_dataloader(self):
        return self._loader(self.test_ds)

    def predict_dataloader(self):
        return self._loader(self.infer_ds)


dm = UnderWaterDataModule(
    train_pairs=(train_in, train_gt),
    val_pairs=(val_in, val_gt),
    test_pairs=(test_in, test_gt),
    infer_pairs=(infer_icam, infer_iclean),
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS
)

dm.setup()
print("DataModule ready!")

In [None]:
class ChannelAttention(nn.Module):
    """CBAM channel attention: rescale each channel by a learned gate in (0, 1).

    The gate is ``sigmoid(MLP(avgpool(x)) + MLP(maxpool(x)))``, where the shared MLP is a
    pair of 1x1 convs with a ``channels // reduction`` bottleneck (at least 1 channel).
    """
    def __init__(self, channels, reduction=16):
        super().__init__()
        # Guard against channels < reduction, which would give a 0-width bottleneck and a
        # constant (unlearnable) gate.
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg = self.fc(self.avg_pool(x))
        mx = self.fc(self.max_pool(x))
        return x * self.sigmoid(avg + mx)


class SpatialAttention(nn.Module):
    """CBAM spatial attention: rescale each location by a learned gate in (0, 1).

    The gate is ``sigmoid(conv([mean_c(x), max_c(x)]))`` with a single ``kernel_size`` x
    ``kernel_size`` conv. Any positive odd kernel size is accepted; "same" padding
    (``kernel_size // 2``) keeps the spatial size unchanged.
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size > 0 and kernel_size % 2 == 1, (
            "kernel_size must be a positive odd integer so the output keeps the input's size"
        )
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # channel-wise avg and max descriptors, stacked as 2 channels
        avg = torch.mean(x, dim=1, keepdim=True)
        mx, _ = torch.max(x, dim=1, keepdim=True)
        gate = self.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))
        return x * gate


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Args:
        channels: number of input/output channels.
        reduction: bottleneck ratio of the channel-attention MLP.
        kernel_size: kernel size of the spatial-attention conv.
    """
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        # Re-weight channels first, then spatial locations of the channel-refined map.
        return self.sa(self.ca(x))


print("CBAM modules defined!")

In [None]:
class ResNet34_UNet_CBAM(nn.Module):
    """U-Net-style restoration network with a ResNet-34 encoder and CBAM at the bottleneck.

    The decoder upsamples bilinearly, concatenates the matching encoder feature map and
    applies Conv-BN-ReLU-Dropout2d blocks; the head is a 1x1 conv to 3 channels followed by
    a sigmoid, so outputs lie in (0, 1) and have the same spatial size as the input.

    Args:
        pretrained: load torchvision's ImageNet weights into the encoder.
        cbam_reduction: channel-reduction ratio of the CBAM channel-attention MLP.
        dropout: Dropout2d probability used in every decoder block.
    """
    def __init__(self, pretrained=True, cbam_reduction=16, dropout=0.1):
        super().__init__()
        resnet = tv_models.resnet34(weights=tv_models.ResNet34_Weights.IMAGENET1K_V1 if pretrained else None)

        # Encoder layers (we'll keep references for skip connections)
        self.conv1 = resnet.conv1  # out: 64, /2 (after conv)
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.layer1 = resnet.layer1  # 64  (after maxpool)
        self.layer2 = resnet.layer2  # 128
        self.layer3 = resnet.layer3  # 256
        self.layer4 = resnet.layer4  # 512

        # CBAM on deepest features
        self.cbam = CBAM(channels=512, reduction=cbam_reduction, kernel_size=7)

        # Decoder conv blocks helper
        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Dropout2d(p=dropout)
            )

        # Decoder (upsample & concat with corresponding skip)
        self.up3 = conv_block(512 + 256, 256)
        self.up2 = conv_block(256 + 128, 128)
        self.up1 = conv_block(128 + 64, 64)
        self.up0 = conv_block(64 + 64, 64)
        self.up_final = conv_block(64, 64)

        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to the spatial size of ``ref``.

        For inputs whose side is a multiple of 32 this equals the previous 2x upsampling;
        for other sizes it absorbs the rounding of the strided encoder so the skip
        concatenation still lines up.
        """
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=False)

    def forward(self, x):
        # Encoder (sizes shown for H, W divisible by 32)
        x0 = self.relu(self.bn1(self.conv1(x)))  # [B,64,H/2,W/2]
        x1 = self.layer1(self.maxpool(x0))  # [B,64,H/4,W/4]
        x2 = self.layer2(x1)  # [B,128,H/8,W/8]
        x3 = self.layer3(x2)  # [B,256,H/16,W/16]
        x4 = self.layer4(x3)  # [B,512,H/32,W/32]

        # CBAM applied to deep feature
        z = self.cbam(x4)

        # Decoder path - upsample to the skip's size + concat + conv block
        u3 = self.up3(torch.cat([self._upsample_to(z, x3), x3], dim=1))
        u2 = self.up2(torch.cat([self._upsample_to(u3, x2), x2], dim=1))
        u1 = self.up1(torch.cat([self._upsample_to(u2, x1), x1], dim=1))
        u0 = self.up0(torch.cat([self._upsample_to(u1, x0), x0], dim=1))
        uF = self.up_final(self._upsample_to(u0, x))

        return torch.sigmoid(self.final_conv(uF))  # ensure 0-1


print("Model architecture defined!")

In [None]:
class SSIMLoss(nn.Module):
    """Structural-dissimilarity loss: mean over all pixels of ``clamp((1 - SSIM) / 2, 0, 1)``.

    SSIM uses a ``window_size`` x ``window_size`` Gaussian window (sigma = 1.5) applied
    depthwise with zero padding, and the constants C1 = 0.01**2, C2 = 0.03**2, which assume
    inputs in [0, 1]. Identical images give 0.

    Args:
        window_size: side length of the Gaussian window.
        channel: channel count used to shape the registered ``window`` buffer (kept so
            existing checkpoints load); ``forward`` adapts to the inputs' actual channel count.
    """
    def __init__(self, window_size=11, channel=3):
        super().__init__()
        self.window_size = window_size
        self.channel = channel

        sigma = 1.5
        coords = torch.arange(window_size).float() - window_size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g = g / g.sum()

        window = g[:, None] * g[None, :]
        self.register_buffer("window", window.expand(channel, 1, window_size, window_size).clone())

    def forward(self, img1, img2):
        # Compute in at least float32 with autocast disabled: the variance terms
        # E[x^2] - E[x]^2 cancel badly in fp16 next to C2 = 9e-4.
        dtype = torch.promote_types(torch.promote_types(img1.dtype, img2.dtype), torch.float32)
        n_channels = img1.shape[1]
        pad = self.window_size // 2
        # Every row of the buffer holds the same 2-D kernel; re-expand it to the input's
        # channel count and dtype so any channel count / float dtype works.
        window = self.window[:1].to(device=img1.device, dtype=dtype).expand(
            n_channels, 1, self.window_size, self.window_size).contiguous()

        with torch.autocast(device_type=img1.device.type, enabled=False):
            x = img1.to(dtype)
            y = img2.to(dtype)

            def local_mean(t):
                return F.conv2d(t, window, padding=pad, groups=n_channels)

            mu1 = local_mean(x)
            mu2 = local_mean(y)

            mu1_sq = mu1 ** 2
            mu2_sq = mu2 ** 2
            mu12 = mu1 * mu2

            sigma1_sq = local_mean(x * x) - mu1_sq
            sigma2_sq = local_mean(y * y) - mu2_sq
            sigma12 = local_mean(x * y) - mu12

            C1 = 0.01 ** 2
            C2 = 0.03 ** 2

            ssim = ((2 * mu12 + C1) * (2 * sigma12 + C2)) / \
                   ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

            return torch.clamp((1 - ssim) / 2, 0, 1).mean()


print("SSIM Loss defined!")

In [None]:
class VGGPerceptualLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps (first 16 layers of ``vgg19.features``).

    Inputs are expected in [0, 1] (the restoration model ends in a sigmoid). torchvision's
    ImageNet weights expect ImageNet-normalized inputs, so the inputs are normalized first.

    Args:
        normalize: apply ImageNet mean/std normalization before the VGG slice. Pass False
            to reproduce the previous (un-normalized) behavior.
    """
    def __init__(self, normalize=True):
        super().__init__()
        vgg = tv_models.vgg19(weights=tv_models.VGG19_Weights.IMAGENET1K_V1).features
        self.slice = nn.Sequential(*list(vgg[:16])).eval()
        for p in self.slice.parameters():
            p.requires_grad = False
        self.normalize = normalize
        # Non-persistent buffers: they follow .to(device) but are not written to state_dict,
        # so checkpoints saved before this change still load.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
                             persistent=False)
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
                             persistent=False)

    def forward(self, x, y):
        if self.normalize:
            x = (x - self.mean) / self.std
            y = (y - self.mean) / self.std
        return F.l1_loss(self.slice(x), self.slice(y))


print("VGG Perceptual Loss defined!")

In [None]:
class RestorationLitModel(pl.LightningModule):
    """Lightning wrapper that trains ``ResNet34_UNet_CBAM`` for image restoration.

    Loss: ``L1 + 0.1 * VGG-perceptual + 0.1 * SSIM-loss`` between prediction and target.
    ``hparams`` is a dict; ``lr`` and ``weight_decay`` are read from it and fall back to
    the module-level ``LR`` / ``WEIGHT_DECAY`` constants.
    """
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = ResNet34_UNet_CBAM(pretrained=True, cbam_reduction=16, dropout=0.1)
        self.perc_loss = VGGPerceptualLoss()
        self.ssim_loss = SSIMLoss()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.get("lr", LR),
            weight_decay=self.hparams.get("weight_decay", WEIGHT_DECAY)
        )
        # Halve the LR after 4 epochs without val/total improvement. (`verbose` is
        # deprecated on PyTorch LR schedulers and therefore not passed.)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            opt, mode='min', factor=0.5, patience=4
        )
        return {"optimizer": opt, "lr_scheduler": {"scheduler": scheduler, "monitor": "val/total"}}

    def _shared_step(self, batch, stage):
        """Compute and log l1 / perc / ssim / total for one batch under ``stage/``; return total."""
        inp, tgt = batch
        pred = self(inp)

        l1 = F.l1_loss(pred, tgt)
        perc = self.perc_loss(pred, tgt)
        # SSIMLoss already returns a loss to minimize: mean of (1 - SSIM) / 2.
        ssim = self.ssim_loss(pred, tgt)
        total = l1 + 0.1 * perc + 0.1 * ssim

        batch_size = inp.size(0)
        for name, value in (("l1", l1), ("perc", perc), ("ssim", ssim)):
            self.log(f"{stage}/{name}", value, on_epoch=True, prog_bar=False, batch_size=batch_size)
        self.log(f"{stage}/total", total, on_epoch=True, prog_bar=True, batch_size=batch_size)

        return total

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        inp, tgt = batch
        pred = self(inp)
        loss = F.l1_loss(pred, tgt)
        self.log("test/loss", loss, batch_size=inp.size(0))
        return {"test_loss": loss}


print("Lightning Module defined!")

# Training cells start here (skip this section when only running inference)

In [None]:
class PeriodicWeightsCallback(Callback):
    """Save ``pl_module.model.state_dict()`` as weights-only ``.pth`` files in ``save_dir``.

    * ``weights_epoch_XXX.pth`` is written every ``period`` epochs (disabled when ``period <= 0``).
    * ``best_weights_epoch_XXX.pth`` is written whenever the logged ``val/total`` hits a new minimum.

    ``best_val`` is held only in this object, so after resuming from a checkpoint the first
    validation result counts as a new best again.
    """
    def __init__(self, save_dir=MODELS_DIR, period=SAVE_PTH_PERIOD):
        super().__init__()
        self.save_dir = Path(save_dir)
        self.period = period
        self.best_val = float("inf")
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def _save(self, pl_module, dest, message):
        """Best-effort save of the wrapped model's weights; report failures instead of hiding them."""
        try:
            torch.save(pl_module.model.state_dict(), str(dest))
        except OSError as exc:
            print(f"\nWARNING: could not save {dest}: {exc}")
            return
        print(f"\n{message}: {dest}")

    def on_validation_end(self, trainer, pl_module):
        # Skip Lightning's pre-training sanity-check pass (untrained model, few batches),
        # which would otherwise seed best_val and write a bogus "best" snapshot, and skip
        # non-zero ranks so files are written once.
        if trainer.sanity_checking or not trainer.is_global_zero:
            return

        epoch = trainer.current_epoch + 1

        # Save every period
        if self.period > 0 and epoch % self.period == 0:
            self._save(pl_module, self.save_dir / f"weights_epoch_{epoch:03d}.pth",
                       "Saved weights-only snapshot")

        # Save best by val loss
        val_loss = trainer.callback_metrics.get("val/total")
        if val_loss is None:
            return
        try:
            val_loss_val = float(val_loss)
        except (TypeError, ValueError, RuntimeError) as exc:
            print(f"\nWARNING: could not read val/total ({exc}); skipping best-weights check.")
            return

        if val_loss_val < self.best_val:
            self.best_val = val_loss_val
            self._save(pl_module, self.save_dir / f"best_weights_epoch_{epoch:03d}.pth",
                       "Saved NEW best weights-only snapshot")


class PrintEveryNEpochsCallback(Callback):
    """Print a train/val loss summary after epoch 1 and every ``n``-th epoch.

    The sanity-check validation pass is skipped. Metrics that have not been logged yet
    are shown as ``n/a`` instead of raising.
    """
    def __init__(self, n=5):
        super().__init__()
        self.n = n

    @staticmethod
    def _fmt(value):
        """Format a logged metric with 4 decimals, or 'n/a' when it is missing."""
        return "n/a" if value is None else f"{float(value):.4f}"

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:
            return
        epoch = trainer.current_epoch + 1
        if not (epoch == 1 or (self.n > 0 and epoch % self.n == 0)):
            return

        metrics = trainer.callback_metrics
        fields = (("Total", "total"), ("L1", "l1"), ("Perc", "perc"), ("SSIM", "ssim"))

        print(f"\n{'='*60}")
        print(f"Epoch {epoch:03d} SUMMARY:")
        for label, stage in (("Train", "train"), ("Val  ", "val")):
            if metrics.get(f"{stage}/total") is None:
                continue
            parts = "  ".join(
                f"{name}: {self._fmt(metrics.get(f'{stage}/{key}'))}" for name, key in fields
            )
            print(f"  {label} -> {parts}")
        print(f"{'='*60}\n")


print("Callbacks defined!")


In [None]:
hparams = {"lr": LR, "weight_decay": WEIGHT_DECAY}
model = RestorationLitModel(hparams)

# The monitored metric name contains "/", which Lightning would turn into sub-directories
# when auto-inserting "name=value" into the filename; spell the template out instead.
checkpoint_callback = ModelCheckpoint(
    dirpath=str(CHECKPOINTS_DIR),
    filename="best-epoch{epoch:02d}-val_total{val/total:.4f}",
    auto_insert_metric_name=False,
    monitor="val/total",
    mode="min",
    save_top_k=1,
    save_last=True
)

# A single EarlyStopping instance, created before the Trainer so it actually takes effect.
early_stop = EarlyStopping(
    monitor="val/total",
    patience=10,       # stop after 10 epochs without improvement
    mode="min",        # lower loss is better
    verbose=True,      # report when early stopping triggers
    min_delta=0.0001   # minimum change that counts as an improvement
)

# Explicit location so the loss-plot cell can find <WORK_DIR>/logs/uw_rest_cbam/version_*.
csv_logger = CSVLogger(str(WORK_DIR / "logs"), name="uw_rest_cbam")

periodic_cb = PeriodicWeightsCallback(
    save_dir=MODELS_DIR,
    period=SAVE_PTH_PERIOD
)

printer_cb = PrintEveryNEpochsCallback(n=5)

use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu" if use_gpu else "cpu",
    devices=1,
    precision="16-mixed" if use_gpu else 32,
    callbacks=[checkpoint_callback, early_stop, periodic_cb, printer_cb],
    logger=csv_logger,
    gradient_clip_val=1.0,
    log_every_n_steps=10
)

# Resume from the rolling "last.ckpt" if a previous run left one behind.
last_ckpt = CHECKPOINTS_DIR / "last.ckpt"
if last_ckpt.exists():
    print("Resuming from checkpoint:", last_ckpt)
    trainer.fit(model, datamodule=dm, ckpt_path=str(last_ckpt))
else:
    print("Training from scratch.")
    trainer.fit(model, datamodule=dm)

print("\nTraining completed!")

In [None]:
# Keep a copy of the best full Lightning checkpoint next to the weights-only snapshots.
best_ckpt = checkpoint_callback.best_model_path
if best_ckpt:
    best_full_dest = MODELS_DIR / "best_model_full.ckpt"
    shutil.copy(best_ckpt, best_full_dest)
    print("Copied best full ckpt to models dir:", best_full_dest)

# Bundle every weights-only .pth snapshot into one archive for download.
zipf = WORK_DIR / "weights_snapshots.zip"
with zipfile.ZipFile(zipf, 'w', zipfile.ZIP_DEFLATED) as archive:
    for snapshot in MODELS_DIR.glob("*.pth"):
        archive.write(snapshot, arcname=snapshot.name)
print("Zipped weights snapshots to:", zipf)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os

# The training cell's CSVLogger (name "uw_rest_cbam") writes one version_<k>/metrics.csv per
# run under <WORK_DIR>/logs. Use the newest run instead of a hard-coded version number.
log_root = WORK_DIR / "logs" / "uw_rest_cbam"
version_dirs = sorted(
    (d for d in log_root.glob("version_*") if d.name.rsplit("_", 1)[-1].isdigit()),
    key=lambda d: int(d.name.rsplit("_", 1)[-1]),
)
metrics_path = str(version_dirs[-1] / "metrics.csv") if version_dirs else ""

if not metrics_path or not os.path.exists(metrics_path):
    print("Metrics CSV not found. Training log may be incomplete.")
else:
    df = pd.read_csv(metrics_path)
    # RestorationLitModel logs "train/total" on step and epoch (Lightning then writes
    # "train/total_step" / "train/total_epoch") and "val/total" per epoch.
    candidates = ("train/total_epoch", "train/total", "val/total")
    loss_cols = [col for col in candidates if col in df.columns]
    if not loss_cols or "epoch" not in df.columns:
        print("No train/val loss columns found yet. Model may still be training, or logger needs updating.")
    else:
        fig, ax = plt.subplots(figsize=(8, 5))
        for col in loss_cols:
            # Rows mix step- and epoch-level entries; collapse to one value per epoch.
            per_epoch = df[["epoch", col]].dropna().groupby("epoch")[col].mean()
            ax.plot(per_epoch.index, per_epoch.values, label=col)
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_title("Training/Validation Loss Curves")
        ax.legend()
        ax.grid(True)
        plt.show()

# Training cells end here

# Load a trained weights-only checkpoint for inference

In [None]:
import torch
from pathlib import Path

# Weights-only snapshot (file name follows PeriodicWeightsCallback's "best_weights_epoch_XXX.pth"
# pattern), uploaded as a Kaggle input dataset.
checkpoint_path = Path("/kaggle/input/test-1/best_weights_epoch_041.pth")

if not checkpoint_path.is_file():
    raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

print(f"Loading checkpoint from: {checkpoint_path}")

# pretrained=False: the ImageNet encoder weights would be overwritten by the state dict anyway.
infer_model = ResNet34_UNet_CBAM(pretrained=False)

# The file is a bare state_dict, so restrict unpickling to tensors/containers
# (weights_only=True) instead of executing arbitrary pickled objects.
state = torch.load(str(checkpoint_path), map_location="cpu", weights_only=True)
infer_model.load_state_dict(state)
infer_model.eval()

# Move model to device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
infer_model.to(device)

print("Model loaded and ready for inference on device:", device)


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
import shutil
import os

# Assumes earlier cells defined:
# dm - the Lightning DataModule
# ResNet34_UNet_CBAM - the model class
# OUTPUT_DIR, MODELS_DIR, CHECKPOINTS_DIR, BATCH_SIZE, ssim_fn (torchmetrics)

# Weights-only .pth snapshot to load for inference.
weights_path = Path("/kaggle/input/test-1/best_weights_epoch_041.pth")

if not weights_path.is_file():
    raise FileNotFoundError(f"No .pth weights file found at: {weights_path}")

print(f"Loading weights-only for inference: {weights_path}")

infer_model = ResNet34_UNet_CBAM(pretrained=False)
# Bare state_dict on disk: weights_only=True refuses to unpickle arbitrary objects.
state = torch.load(weights_path, map_location="cpu", weights_only=True)
infer_model.load_state_dict(state)
infer_model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
infer_model.to(device)
print(f"Model loaded on {device}")

# Run inference on the predict dataloader, keeping CPU copies for plotting/metrics.
infer_loader = dm.predict_dataloader()
results = []
n_images = 0

print("Running inference...")
with torch.no_grad():
    for batch_idx, (inp, tgt) in enumerate(infer_loader):
        pred = infer_model(inp.to(device))

        results.append({
            'input': inp.cpu(),
            'pred': pred.cpu(),
            'target': tgt.cpu(),
            'batch_idx': batch_idx
        })
        n_images += inp.size(0)

# Count real images: the last batch may be smaller than BATCH_SIZE.
print(f"Ran inference on {len(results)} batches ({n_images} images).")

def tensor_to_np(t):
    """Turn a CHW tensor (or the first item of an NCHW batch) into an HWC float array in [0, 1]."""
    single = t[0] if t.dim() == 4 else t
    return single.clamp(0, 1).permute(1, 2, 0).numpy()

# Visualize the first image of each of the first 30 batches: Input | Restored | Ground Truth.
n_show = min(30, len(results))
rows = []
for batch_result in results[:n_show]:
    inp_img = tensor_to_np(batch_result['input'])
    pred_img = tensor_to_np(batch_result['pred'])
    tgt_img = tensor_to_np(batch_result['target'])
    rows.append(np.concatenate([inp_img, pred_img, tgt_img], axis=1))

if rows:
    grid = np.concatenate(rows, axis=0)
    out_file = OUTPUT_DIR / "inference_sample_grid.png"
    plt.figure(figsize=(15, 5 * n_show))
    plt.imshow(grid)
    plt.axis("off")
    plt.title("Input | Restored | Ground Truth", fontsize=16, pad=20)
    plt.tight_layout()
    plt.savefig(out_file, bbox_inches="tight", dpi=150)
    plt.show()
    print(f"✅ Saved inference sample grid to: {out_file}")

# Save individual predictions for first 20 batches
print("\nSaving individual predictions...")
pred_dir = OUTPUT_DIR / "predictions"
pred_dir.mkdir(parents=True, exist_ok=True)

saved_count = 0
for batch_idx, batch_result in enumerate(results[:20]):
    pred_batch = batch_result['pred']
    for img_idx in range(pred_batch.shape[0]):
        pred_img = tensor_to_np(pred_batch[img_idx:img_idx+1])
        pred_pil = Image.fromarray((pred_img * 255).astype(np.uint8))
        save_path = pred_dir / f"pred_batch{batch_idx:03d}_img{img_idx:02d}.png"
        pred_pil.save(save_path)
        saved_count += 1

print(f"✅ Saved {saved_count} individual predictions to: {pred_dir}")

# PSNR / SSIM over the inference split (dm.predict_dataloader(), i.e. the pairs left over
# after NUM_SAMPLES were drawn for train/val/test) -- not the held-out test split.
print("\nCalculating metrics on predictions...")
psnr_values = []
ssim_values = []
MSE_FLOOR = 1e-10  # caps PSNR at 100 dB instead of returning inf for a pixel-perfect match

for batch_result in results:
    pred_batch = batch_result['pred']
    tgt_batch = batch_result['target']

    for i in range(pred_batch.shape[0]):
        pred_img = pred_batch[i:i+1].to(device)
        tgt_img = tgt_batch[i:i+1].to(device)

        mse = F.mse_loss(pred_img, tgt_img).clamp_min(MSE_FLOOR)
        psnr_values.append((10 * torch.log10(1.0 / mse)).item())
        ssim_values.append(ssim_fn(pred_img, tgt_img, data_range=1.0).item())

if psnr_values:
    avg_psnr, avg_ssim = float(np.mean(psnr_values)), float(np.mean(ssim_values))
    std_psnr, std_ssim = float(np.std(psnr_values)), float(np.std(ssim_values))
else:
    print("No predictions to evaluate (the inference split is empty).")
    avg_psnr = avg_ssim = std_psnr = std_ssim = float("nan")

print(f"\n{'='*60}")
print(f"INFERENCE METRICS:")
print(f"  Average PSNR: {avg_psnr:.2f} dB")
print(f"  Average SSIM: {avg_ssim:.4f}")
print(f"  Total images evaluated: {len(psnr_values)}")
print(f"{'='*60}\n")

# Save metrics to text file
metrics_file = OUTPUT_DIR / "inference_metrics.txt"
with open(metrics_file, 'w') as f:
    f.write("Inference Metrics\n")
    f.write("="*60 + "\n")
    f.write(f"Model: {weights_path}\n")
    f.write(f"Number of images: {len(psnr_values)}\n")
    f.write(f"Average PSNR: {avg_psnr:.2f} dB\n")
    f.write(f"Average SSIM: {avg_ssim:.4f}\n")
    f.write(f"PSNR std: {std_psnr:.2f}\n")
    f.write(f"SSIM std: {std_ssim:.4f}\n")

print(f"✅ Saved metrics to: {metrics_file}")

print("\nALL DONE! ✅")
print(f"Check outputs in: {OUTPUT_DIR}")


In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from skimage import exposure
from torchmetrics.functional import structural_similarity_index_measure as ssim_fn

# User params
# BATCH_SIZE is intentionally not redefined here: rebinding it silently overrode the config
# cell's value (which the DataModule was built with) without affecting this cell.
PLOT_LIMIT = 32  # stop collecting after this many images and plot at most this many rows

def adaptive_red_boost(img):
    """Amplify the red channel of an HxWx3 image by a data-driven gain.

    The gain is ``(mean(G) + mean(B)) / (2 * mean(R))`` (2.0 when the red mean is below 1),
    clamped to [1.2, 3.0]. Boosted red values are clipped to [0, 255]; returns uint8.
    """
    red, green, blue = (img[..., ch].mean() for ch in range(3))
    gain = 2.0 if red < 1 else (green + blue) / (2 * red)
    gain = np.clip(gain, 1.2, 3.0)
    boosted = img.copy()
    boosted[..., 0] = np.clip(boosted[..., 0] * gain, 0, 255)
    return boosted.astype(np.uint8)

def gamma_correct(img, gamma=1.4):
    """Apply ``skimage.exposure.adjust_gamma`` to a 0..255 image scaled to [0, 1]; returns uint8."""
    scaled = img / 255.
    corrected = exposure.adjust_gamma(scaled, gamma)
    return (corrected * 255).astype(np.uint8)

def clahe_enhancement(img):
    """Apply OpenCV CLAHE (clipLimit 2.0, 8x8 tiles) independently to channels 0..2.

    NOTE(review): per-RGB-channel equalization can shift hues; the training dataset uses
    LAB L-channel CLAHE instead.
    """
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    result = np.zeros_like(img)
    for channel in range(3):
        result[..., channel] = equalizer.apply(img[..., channel])
    return result

def simple_dehaze(img, omega=0.95, win_size=15):
    """Simplified dark-channel-prior dehazing of a 0..255 HxWx3 image; returns uint8.

    Dark channel = per-pixel channel minimum eroded with a ``win_size`` square kernel;
    atmospheric light = per-channel image maximum; transmission = ``1 - omega * dark``
    clipped to [0.1, 1]; radiance ``J = (I - A) / t + A``.
    """
    hazy = img.astype(np.float32) / 255.0
    dark = cv2.erode(hazy.min(axis=2), np.ones((win_size, win_size)))
    airlight = hazy.max(axis=(0, 1))
    transmission = np.clip(1 - omega * dark[..., None], 0.1, 1)
    recovered = (hazy - airlight) / transmission + airlight
    return np.clip(recovered * 255, 0, 255).astype(np.uint8)

def tensor_to_img(t):
    """Convert a CHW float tensor to an HWC uint8 numpy image (clamped to [0, 1], then x255, truncated)."""
    hwc = torch.clamp(t, 0, 1).detach().cpu().permute(1, 2, 0)
    return (hwc.numpy() * 255).astype(np.uint8)

from pathlib import Path

checkpoint_path = Path("/kaggle/input/test-1/best_weights_epoch_041.pth")
if not checkpoint_path.is_file():
    raise RuntimeError(f"Checkpoint file not found at {checkpoint_path}")

print(f"Loading weights-only for inference: {checkpoint_path}")

infer_model = ResNet34_UNet_CBAM(pretrained=False)
# Bare state_dict on disk: weights_only=True refuses to unpickle arbitrary objects.
state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
infer_model.load_state_dict(state)
infer_model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
infer_model.to(device)
print(f"Model loaded on {device}")

full_loader = dm.predict_dataloader()

results = []
print("Running inference with metrics computation...")
psnr_values_batch = []
ssim_values_batch = []
MSE_FLOOR = 1e-10  # caps PSNR at 100 dB instead of inf for a pixel-perfect prediction

with torch.no_grad():
    for batch_idx, (inp, tgt) in enumerate(full_loader):
        inp = inp.to(device)
        tgt_dev = tgt.to(device)
        pred = infer_model(inp)
        pred_cpu = pred.cpu()

        # Per-image PSNR / SSIM, computed on the device.
        for i in range(inp.shape[0]):
            pred_img = pred[i:i+1]
            tgt_img = tgt_dev[i:i+1]
            mse = F.mse_loss(pred_img, tgt_img).clamp_min(MSE_FLOOR)
            psnr_values_batch.append((10 * torch.log10(1.0 / mse)).item())
            ssim_values_batch.append(ssim_fn(pred_img, tgt_img, data_range=1.0).item())

        # Running means over every image processed so far (not per-batch values).
        print(f"Batch {batch_idx}: running mean PSNR = {np.mean(psnr_values_batch):.2f} dB, "
              f"SSIM = {np.mean(ssim_values_batch):.4f}")

        # Classical post-processing chain on the network output:
        # restored -> CLAHE -> dehaze -> gamma -> red boost.
        for i in range(inp.shape[0]):
            pred_img = tensor_to_img(pred_cpu[i])
            clahe_img = clahe_enhancement(pred_img)
            dehazed_img = simple_dehaze(clahe_img)
            gamma_corrected = gamma_correct(dehazed_img, gamma=1.4)
            red_boosted = adaptive_red_boost(gamma_corrected)

            results.append({
                "original": tensor_to_img(inp[i]),
                "ground_truth": tensor_to_img(tgt[i]),
                "restored": pred_img,
                "red_boosted": red_boosted,
                "gamma_corrected": gamma_corrected,
                "clahe_corrected": clahe_img,
                "dehazed": dehazed_img,
            })

        # Only PLOT_LIMIT images are visualized, so stop early; the metrics above therefore
        # cover just the batches processed up to this point.
        if len(results) >= PLOT_LIMIT:
            break

print(f"Collected {len(results)} images for visualization.")

# Visualization grid: 7 images per row, one row per collected image (at most PLOT_LIMIT).
titles = ['Camera Input', 'Ground Truth', 'Restored',
          'Red Boosted', 'Gamma Corrected', 'CLAHE Contrast', 'Dehazed']
result_keys = ['original', 'ground_truth', 'restored', 'red_boosted',
               'gamma_corrected', 'clahe_corrected', 'dehazed']
n_rows = min(len(results), PLOT_LIMIT)

if n_rows == 0:
    print("No images to plot (the inference split is empty).")
else:
    plt.figure(figsize=(20, n_rows * 3))
    for idx, res in enumerate(results[:n_rows]):
        for col_idx, (key, title) in enumerate(zip(result_keys, titles)):
            plt.subplot(n_rows, 7, idx * 7 + col_idx + 1)
            plt.imshow(res[key])
            plt.axis('off')
            if idx == 0:
                plt.title(title, fontsize=10)

    plt.tight_layout()
    plt.show()


# Improved Inference Model

In [None]:
import torch
import torch.nn.functional as F
from torchmetrics.functional import structural_similarity_index_measure as ssim_fn
from torchvision.models import vgg16
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from skimage import exposure
from pathlib import Path

# Parameters
# Rebinds the PLOT_LIMIT of the previous inference cell (32); presumably caps how many
# samples the code below visualizes -- confirm against the (not visible) plotting loop.
PLOT_LIMIT = 5  # Show 5 samples to keep plots clear

# Define classical color correction functions
def adaptive_red_boost(img):
    """Gently amplify the red channel of an HxWx3 image (gain clamped to [1.1, 1.2]).

    NOTE(review): redefines the earlier cell's ``adaptive_red_boost`` (which clamps to
    [1.2, 3.0]); whichever cell ran last wins.
    """
    r_mean, g_mean, b_mean = img[..., 0].mean(), img[..., 1].mean(), img[..., 2].mean()
    if r_mean < 1:
        gain = 2.0
    else:
        gain = (g_mean + b_mean) / (2 * r_mean)
    gain = np.clip(gain, 1.1, 1.2)
    result = img.copy()
    result[..., 0] = np.clip(result[..., 0] * gain, 0, 255)
    return result.astype(np.uint8)

def gamma_correct(img, gamma=1.4):
    """Gamma-adjust a 0..255 image via ``skimage.exposure.adjust_gamma`` on [0, 1] values; returns uint8."""
    unit = img / 255.
    adjusted = exposure.adjust_gamma(unit, gamma)
    return (adjusted * 255).astype(np.uint8)

def clahe_enhancement(img):
    """Apply OpenCV CLAHE (clipLimit 1.5, 8x8 tiles) independently to channels 0..2.

    NOTE(review): redefines the earlier cell's version (clipLimit 2.0).
    """
    equalizer = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
    equalized = np.empty_like(img)
    for ch in range(3):
        equalized[..., ch] = equalizer.apply(img[..., ch])
    return equalized

def simple_dehaze(img, omega=0.95, win_size=15):
    """Simplified dark-channel-prior dehazing of a 0..255 HxWx3 image; returns uint8.

    Uses the per-channel image maximum as atmospheric light and ``1 - omega * dark``
    (clipped to [0.1, 1]) as transmission.
    """
    observed = img.astype(np.float32) / 255.0
    kernel = np.ones((win_size, win_size))
    dark = cv2.erode(observed.min(axis=2), kernel)
    atmosphere = observed.max(axis=(0, 1))
    trans = np.clip(1 - omega * dark[..., None], 0.1, 1)
    radiance = (observed - atmosphere) / trans + atmosphere
    return np.clip(radiance * 255, 0, 255).astype(np.uint8)

def gray_world_white_balance(img):
    """Gray-world white balance: scale each channel so its mean equals the mean of the
    three channel means.

    A channel whose mean is zero carries nothing to rescale and is left untouched.
    Dividing by its zero mean would otherwise produce NaN/inf, which turn into garbage
    values after the uint8 cast.

    Args:
        img: (H, W, 3) image with 0..255 values.

    Returns:
        uint8 image of the same shape, clipped to [0, 255].
    """
    img_float = img.astype(np.float32)
    means = [np.mean(img_float[..., c]) for c in range(3)]
    avg = sum(means) / 3
    balanced = img_float.copy()
    for c, mean in enumerate(means):
        if mean > 0:
            balanced[..., c] *= avg / mean
    return np.clip(balanced, 0, 255).astype(np.uint8)

def tensor_to_img(tensor):
    """Convert a (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 numpy image.

    Values are clamped to [0, 1] and rounded to the nearest 8-bit level. Truncation
    would bias every pixel downward, and can drop a level when an image is round-tripped
    through ``ToTensor`` (which divides by 255 in float32). The tensor may live on any
    device; the result is always on the host.
    """
    img = torch.clamp(tensor, 0, 1).detach().cpu().permute(1, 2, 0).numpy()
    return np.rint(img * 255).astype(np.uint8)

# Perceptual loss using VGG16 convolutional features as a frozen feature extractor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg = vgg16(pretrained=True).features.eval().to(device)
# Used only for inference; disable gradient tracking on its weights.
vgg.requires_grad_(False)
transform_vgg = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
# torchvision's ImageNet-pretrained weights expect inputs normalised with these statistics.
_VGG_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
_VGG_STD = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

def perceptual_loss(img1, img2):
    """Mean absolute difference between VGG16 feature maps of two images.

    Both images are resized to 224x224, scaled to [0, 1], ImageNet-normalised and
    passed through ``vgg`` together in a single batch.

    Args:
        img1, img2: images accepted by ``transforms.ToPILImage``; in this notebook,
            (H, W, 3) uint8 arrays produced by ``tensor_to_img``.

    Returns:
        The L1 feature distance as a Python float (lower means more similar).
    """
    batch = torch.stack([transform_vgg(img1), transform_vgg(img2)]).to(device)
    batch = (batch - _VGG_MEAN) / _VGG_STD
    with torch.no_grad():
        feats = vgg(batch)
    return F.l1_loss(feats[0], feats[1]).item()

# Load model checkpoint
checkpoint_path = Path("/kaggle/input/test-1/best_weights_epoch_041.pth")
if not checkpoint_path.is_file():
    raise RuntimeError(f"Checkpoint file not found at {checkpoint_path}")

infer_model = ResNet34_UNet_CBAM(pretrained=False)
state = torch.load(checkpoint_path, map_location="cpu")
infer_model.load_state_dict(state)
infer_model.eval()
infer_model.to(device)

# Use your datamodule to prepare dataloader
full_loader = dm.predict_dataloader()

results = []

metrics_before = {'psnr': [], 'ssim': [], 'perceptual': []}
metrics_after = {'psnr': [], 'ssim': [], 'perceptual': []}

import math

# Run inference and corrections
with torch.no_grad():
    for inp, tgt in full_loader:
        inp = inp.to(device)
        tgt = tgt.to(device)
        pred = infer_model(inp)

        for i in range(inp.size(0)):
            orig_img = tensor_to_img(inp[i])
            gt_img = tensor_to_img(tgt[i])
            pred_img = tensor_to_img(pred[i])

            # Metrics before correction (raw network output vs. ground truth)
            p_tensor, gt_tensor = pred[i:i+1], tgt[i:i+1]
            mse = F.mse_loss(p_tensor, gt_tensor)
            metrics_before['psnr'].append(10 * torch.log10(1.0 / mse).item())
            metrics_before['ssim'].append(ssim_fn(p_tensor, gt_tensor).item())
            metrics_before['perceptual'].append(perceptual_loss(pred_img, gt_img))

            # Classical corrections, all applied on top of the CLAHE output
            clahe_img = clahe_enhancement(pred_img)
            dehazed = simple_dehaze(clahe_img)
            gamma_corr_img = gamma_correct(clahe_img, gamma=1.4)
            wb_img = gray_world_white_balance(clahe_img)
            red_boost_img = adaptive_red_boost(clahe_img)

            # Metrics after correction. PSNR, SSIM and perceptual loss must all score the
            # same image, so they are all computed on `corrected_img`. The tensor is built
            # with the target's dtype so the comparison is not float64 vs float32.
            corrected_img = clahe_img
            corrected_tensor = (
                torch.from_numpy(corrected_img).permute(2, 0, 1).unsqueeze(0)
                .to(device=device, dtype=gt_tensor.dtype) / 255.0
            )
            mse_after = F.mse_loss(corrected_tensor, gt_tensor)
            metrics_after['psnr'].append(10 * torch.log10(1.0 / mse_after).item())
            metrics_after['ssim'].append(ssim_fn(corrected_tensor, gt_tensor).item())
            metrics_after['perceptual'].append(perceptual_loss(corrected_img, gt_img))

            # Append for visualization
            results.append({
                'original': orig_img,
                'ground_truth': gt_img,
                'restored': pred_img,
                'dehazed': dehazed,
                'clahe_corrected': clahe_img,
                'white_balanced': wb_img,
                'red_boosted': red_boost_img,
                'gamma_corrected': gamma_corr_img
            })

            if len(results) >= PLOT_LIMIT:
                break
        if len(results) >= PLOT_LIMIT:
            break

print(f"Avg PSNR before corrections: {np.mean(metrics_before['psnr']):.3f} dB")
print(f"Avg SSIM before corrections: {np.mean(metrics_before['ssim']):.4f}")
print(f"Avg Perceptual loss before corrections: {np.mean(metrics_before['perceptual']):.4f}")

print(f"Avg PSNR after corrections: {np.mean(metrics_after['psnr']):.3f} dB")
print(f"Avg SSIM after corrections: {np.mean(metrics_after['ssim']):.4f}")
print(f"Avg Perceptual loss after corrections: {np.mean(metrics_after['perceptual']):.4f}")

In [None]:
import torch
import torch.nn.functional as F
from torchmetrics.functional import structural_similarity_index_measure as ssim_fn
from torchvision.models import vgg16
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from skimage import exposure
from pathlib import Path

# Parameters
# Number of samples evaluated below; despite the name, this cell only computes metrics and plots nothing.
PLOT_LIMIT = 2000  # Max samples evaluated (not a plot count)

# Define classical color correction functions
def adaptive_red_boost(img):
    """Amplify channel 0 (red, assuming RGB order) by a gain clipped to [1.1, 1.2].

    The gain is (mean_G + mean_B) / (2 * mean_R), i.e. how far red lags behind the
    other two channels. If the red mean is below 1 the gain defaults to 2.0, which the
    clip then reduces to 1.2. Boosted values saturate at 255; channels 1 and 2 are
    untouched and the input array is not modified.

    Args:
        img: (H, W, 3) image, uint8 in this notebook.

    Returns:
        A new uint8 array of the same shape.
    """
    red, green, blue = (img[..., c].mean() for c in range(3))
    if red < 1:
        gain = 2.0
    else:
        gain = (green + blue) / (2 * red)
    gain = np.clip(gain, 1.1, 1.2)

    boosted = img.copy()
    boosted[..., 0] = np.clip(boosted[..., 0] * gain, 0, 255)
    return boosted.astype(np.uint8)

def gamma_correct(img, gamma=1.4):
    """Apply power-law (gamma) correction to an 8-bit image.

    Computes ``255 * (img / 255) ** gamma``, the same mapping that
    ``skimage.exposure.adjust_gamma`` applies to a [0, 1] float image. The result is
    rounded to the nearest 8-bit level rather than truncated, because truncation biases
    every pixel downward.

    Args:
        img: uint8 image of any shape, read as 0..255 intensities.
        gamma: exponent. Values > 1 darken mid-tones and values < 1 brighten them.
            Must be >= 0.

    Returns:
        uint8 array with the same shape as ``img``.

    Raises:
        ValueError: if ``gamma`` is negative (same contract as skimage).
    """
    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number.")
    normalized = np.asarray(img, dtype=np.float64) / 255.0
    corrected = np.power(normalized, gamma) * 255.0
    return np.clip(np.rint(corrected), 0, 255).astype(np.uint8)

def clahe_enhancement(img):
    """Contrast-limited adaptive histogram equalisation on each of the first three
    channels independently (clip limit 1.5, 8x8 tiles).

    Equalising R, G and B separately can shift hues, unlike CLAHE applied to a
    luminance channel.

    Args:
        img: (H, W, C) image whose channels are accepted by ``cv2.CLAHE.apply``
            (uint8 in this notebook).

    Returns:
        A new array with the same shape and dtype as ``img``. Only channels 0-2 are written.
    """
    equalizer = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
    equalized = np.empty_like(img)
    for channel in range(3):
        equalized[..., channel] = equalizer.apply(img[..., channel])
    return equalized

def simple_dehaze(img, omega=0.95, win_size=15):
    """Simplified dark-channel haze removal.

    Steps:
    - the dark channel is the per-pixel minimum over colour channels, eroded with a
      ``win_size`` x ``win_size`` window;
    - the airlight is the per-channel maximum over the whole image;
    - the transmission is ``1 - omega * dark``, clipped to [0.1, 1];
    - the scene is recovered as ``(I - A) / t + A``.

    Unlike the textbook formulation, the dark channel is not normalised by the airlight.

    Args:
        img: (H, W, 3) image with 0..255 values.
        omega: fraction of haze to remove (0..1).
        win_size: side length of the erosion window.

    Returns:
        uint8 image of the same shape.
    """
    scene = img.astype(np.float32) / 255.0
    kernel = np.ones((win_size, win_size))
    dark = cv2.erode(scene.min(axis=2), kernel)
    airlight = scene.max(axis=(0, 1))
    transmission = np.clip(1 - omega * dark[..., None], 0.1, 1)
    recovered = (scene - airlight) / transmission + airlight
    return np.clip(recovered * 255, 0, 255).astype(np.uint8)

def gray_world_white_balance(img):
    """Gray-world white balance: scale each channel so its mean equals the mean of the
    three channel means.

    A channel whose mean is zero carries nothing to rescale and is left untouched.
    Dividing by its zero mean would otherwise produce NaN/inf, which turn into garbage
    values after the uint8 cast.

    Args:
        img: (H, W, 3) image with 0..255 values.

    Returns:
        uint8 image of the same shape, clipped to [0, 255].
    """
    img_float = img.astype(np.float32)
    means = [np.mean(img_float[..., c]) for c in range(3)]
    avg = sum(means) / 3
    balanced = img_float.copy()
    for c, mean in enumerate(means):
        if mean > 0:
            balanced[..., c] *= avg / mean
    return np.clip(balanced, 0, 255).astype(np.uint8)

def tensor_to_img(tensor):
    """Convert a (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 numpy image.

    Values are clamped to [0, 1] and rounded to the nearest 8-bit level. Truncation
    would bias every pixel downward, and can drop a level when an image is round-tripped
    through ``ToTensor`` (which divides by 255 in float32). The tensor may live on any
    device; the result is always on the host.
    """
    img = torch.clamp(tensor, 0, 1).detach().cpu().permute(1, 2, 0).numpy()
    return np.rint(img * 255).astype(np.uint8)

# Perceptual loss using VGG16 convolutional features as a frozen feature extractor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg = vgg16(pretrained=True).features.eval().to(device)
# Used only for inference; disable gradient tracking on its weights.
vgg.requires_grad_(False)
transform_vgg = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
# torchvision's ImageNet-pretrained weights expect inputs normalised with these statistics.
_VGG_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
_VGG_STD = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

def perceptual_loss(img1, img2):
    """Mean absolute difference between VGG16 feature maps of two images.

    Both images are resized to 224x224, scaled to [0, 1], ImageNet-normalised and
    passed through ``vgg`` together in a single batch.

    Args:
        img1, img2: images accepted by ``transforms.ToPILImage``; in this notebook,
            (H, W, 3) uint8 arrays produced by ``tensor_to_img``.

    Returns:
        The L1 feature distance as a Python float (lower means more similar).
    """
    batch = torch.stack([transform_vgg(img1), transform_vgg(img2)]).to(device)
    batch = (batch - _VGG_MEAN) / _VGG_STD
    with torch.no_grad():
        feats = vgg(batch)
    return F.l1_loss(feats[0], feats[1]).item()

# Load model checkpoint
checkpoint_path = Path("/kaggle/input/test-1/best_weights_epoch_041.pth")
if not checkpoint_path.is_file():
    raise RuntimeError(f"Checkpoint file not found at {checkpoint_path}")

infer_model = ResNet34_UNet_CBAM(pretrained=False)
state = torch.load(checkpoint_path, map_location="cpu")
infer_model.load_state_dict(state)
infer_model.eval()
infer_model.to(device)

# Use your datamodule to prepare dataloader
full_loader = dm.predict_dataloader()

# Metrics are computed for up to PLOT_LIMIT samples. Full image sets are kept only for
# the first VIS_LIMIT of them, which is enough for the comparison grid (it shows at most
# 10 rows). Keeping eight images for thousands of samples would needlessly exhaust RAM.
VIS_LIMIT = 10
results = []

metrics_before = {'psnr': [], 'ssim': [], 'perceptual': []}
metrics_after = {'psnr': [], 'ssim': [], 'perceptual': []}

import math

# Run inference and corrections
n_evaluated = 0
with torch.no_grad():
    for inp, tgt in full_loader:
        inp = inp.to(device)
        tgt = tgt.to(device)
        pred = infer_model(inp)

        for i in range(inp.size(0)):
            gt_img = tensor_to_img(tgt[i])
            pred_img = tensor_to_img(pred[i])

            # Metrics before correction (raw network output vs. ground truth)
            p_tensor, gt_tensor = pred[i:i+1], tgt[i:i+1]
            mse = F.mse_loss(p_tensor, gt_tensor)
            metrics_before['psnr'].append(10 * torch.log10(1.0 / mse).item())
            metrics_before['ssim'].append(ssim_fn(p_tensor, gt_tensor).item())
            metrics_before['perceptual'].append(perceptual_loss(pred_img, gt_img))

            # Classical correction under evaluation: CLAHE only
            clahe_img = clahe_enhancement(pred_img)

            # Metrics after correction: all three score the CLAHE image computed in this
            # iteration, never a variable left over from another cell.
            clahe_tensor = (
                torch.from_numpy(clahe_img).permute(2, 0, 1).unsqueeze(0)
                .to(device=device, dtype=gt_tensor.dtype) / 255.0
            )
            mse_after = F.mse_loss(clahe_tensor, gt_tensor)
            metrics_after['psnr'].append(10 * torch.log10(1.0 / mse_after).item())
            metrics_after['ssim'].append(ssim_fn(clahe_tensor, gt_tensor).item())
            metrics_after['perceptual'].append(perceptual_loss(clahe_img, gt_img))

            if len(results) < VIS_LIMIT:
                # The other corrections only feed the plots, so they are computed here,
                # for the stored samples only.
                results.append({
                    'original': tensor_to_img(inp[i]),
                    'ground_truth': gt_img,
                    'restored': pred_img,
                    'dehazed': simple_dehaze(clahe_img),
                    'clahe_corrected': clahe_img,
                    'white_balanced': gray_world_white_balance(clahe_img),
                    'red_boosted': adaptive_red_boost(clahe_img),
                    'gamma_corrected': gamma_correct(clahe_img, gamma=1.4)
                })

            n_evaluated += 1
            if n_evaluated >= PLOT_LIMIT:
                break
        if n_evaluated >= PLOT_LIMIT:
            break

print(f"Avg PSNR before corrections: {np.mean(metrics_before['psnr']):.3f} dB")
print(f"Avg SSIM before corrections: {np.mean(metrics_before['ssim']):.4f}")
print(f"Avg Perceptual loss before corrections: {np.mean(metrics_before['perceptual']):.4f}")

print(f"Avg PSNR after corrections: {np.mean(metrics_after['psnr']):.3f} dB")
print(f"Avg SSIM after corrections: {np.mean(metrics_after['ssim']):.4f}")
print(f"Avg Perceptual loss after corrections: {np.mean(metrics_after['perceptual']):.4f}")

In [None]:
import torch
import torch.nn.functional as F
from torchmetrics.functional import structural_similarity_index_measure as ssim_fn
from torchvision.models import vgg16
from torchvision import transforms
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path

# Postprocessing functions (same as before)
def adaptive_red_boost(img):
    """Amplify channel 0 (red, assuming RGB order) by a gain clipped to [1.1, 1.2].

    The gain is (mean_G + mean_B) / (2 * mean_R), i.e. how far red lags behind the
    other two channels. If the red mean is below 1 the gain defaults to 2.0, which the
    clip then reduces to 1.2. Boosted values saturate at 255; channels 1 and 2 are
    untouched and the input array is not modified.

    Args:
        img: (H, W, 3) image, uint8 in this notebook.

    Returns:
        A new uint8 array of the same shape.
    """
    red, green, blue = (img[..., c].mean() for c in range(3))
    if red < 1:
        gain = 2.0
    else:
        gain = (green + blue) / (2 * red)
    gain = np.clip(gain, 1.1, 1.2)

    boosted = img.copy()
    boosted[..., 0] = np.clip(boosted[..., 0] * gain, 0, 255)
    return boosted.astype(np.uint8)

def gamma_correct(img, gamma=1.4):
    """Apply power-law (gamma) correction to an 8-bit image.

    Computes ``255 * (img / 255) ** gamma``, the same mapping that
    ``skimage.exposure.adjust_gamma`` applies to a [0, 1] float image, without needing
    skimage. The result is rounded to the nearest 8-bit level rather than truncated,
    because truncation biases every pixel downward.

    Args:
        img: uint8 image of any shape, read as 0..255 intensities.
        gamma: exponent. Values > 1 darken mid-tones and values < 1 brighten them.
            Must be >= 0.

    Returns:
        uint8 array with the same shape as ``img``.

    Raises:
        ValueError: if ``gamma`` is negative (same contract as skimage).
    """
    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number.")
    normalized = np.asarray(img, dtype=np.float64) / 255.0
    corrected = np.power(normalized, gamma) * 255.0
    return np.clip(np.rint(corrected), 0, 255).astype(np.uint8)

def clahe_enhancement(img):
    """Contrast-limited adaptive histogram equalisation on each of the first three
    channels independently (clip limit 1.5, 8x8 tiles).

    Equalising R, G and B separately can shift hues, unlike CLAHE applied to a
    luminance channel.

    Args:
        img: (H, W, C) image whose channels are accepted by ``cv2.CLAHE.apply``
            (uint8 in this notebook).

    Returns:
        A new array with the same shape and dtype as ``img``. Only channels 0-2 are written.
    """
    equalizer = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
    equalized = np.empty_like(img)
    for channel in range(3):
        equalized[..., channel] = equalizer.apply(img[..., channel])
    return equalized

def simple_dehaze(img, omega=0.95, win_size=15):
    """Simplified dark-channel haze removal.

    Steps:
    - the dark channel is the per-pixel minimum over colour channels, eroded with a
      ``win_size`` x ``win_size`` window;
    - the airlight is the per-channel maximum over the whole image;
    - the transmission is ``1 - omega * dark``, clipped to [0.1, 1];
    - the scene is recovered as ``(I - A) / t + A``.

    Unlike the textbook formulation, the dark channel is not normalised by the airlight.

    Args:
        img: (H, W, 3) image with 0..255 values.
        omega: fraction of haze to remove (0..1).
        win_size: side length of the erosion window.

    Returns:
        uint8 image of the same shape.
    """
    scene = img.astype(np.float32) / 255.0
    kernel = np.ones((win_size, win_size))
    dark = cv2.erode(scene.min(axis=2), kernel)
    airlight = scene.max(axis=(0, 1))
    transmission = np.clip(1 - omega * dark[..., None], 0.1, 1)
    recovered = (scene - airlight) / transmission + airlight
    return np.clip(recovered * 255, 0, 255).astype(np.uint8)

def tensor_to_img(tensor):
    """Convert a (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 numpy image.

    Values are clamped to [0, 1] and rounded to the nearest 8-bit level. Truncation
    would bias every pixel downward, and can drop a level when an image is round-tripped
    through ``ToTensor`` (which divides by 255 in float32). The tensor may live on any
    device; the result is always on the host.
    """
    img = torch.clamp(tensor, 0, 1).detach().cpu().permute(1, 2, 0).numpy()
    return np.rint(img * 255).astype(np.uint8)

# Perceptual loss using VGG16 convolutional features as a frozen feature extractor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg = vgg16(pretrained=True).features.eval().to(device)
# Used only for inference; disable gradient tracking on its weights.
vgg.requires_grad_(False)
transform_vgg = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
# torchvision's ImageNet-pretrained weights expect inputs normalised with these statistics.
_VGG_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
_VGG_STD = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

def perceptual_loss(img1, img2):
    """Mean absolute difference between VGG16 feature maps of two images.

    Both images are resized to 224x224, scaled to [0, 1], ImageNet-normalised and
    passed through ``vgg`` together in a single batch.

    Args:
        img1, img2: images accepted by ``transforms.ToPILImage``; in this notebook,
            (H, W, 3) uint8 arrays produced by ``tensor_to_img``.

    Returns:
        The L1 feature distance as a Python float (lower means more similar).
    """
    batch = torch.stack([transform_vgg(img1), transform_vgg(img2)]).to(device)
    batch = (batch - _VGG_MEAN) / _VGG_STD
    with torch.no_grad():
        feats = vgg(batch)
    return F.l1_loss(feats[0], feats[1]).item()

# Load model checkpoint and setup model
checkpoint_path = Path("/kaggle/input/test-1/best_weights_epoch_041.pth")
if not checkpoint_path.is_file():
    raise RuntimeError(f"Checkpoint file not found at {checkpoint_path}")

infer_model = ResNet34_UNet_CBAM(pretrained=False)
state = torch.load(checkpoint_path, map_location="cpu")
infer_model.load_state_dict(state)
infer_model.eval()
infer_model.to(device)

# Loader from your datamodule
full_loader = dm.predict_dataloader()

# Processing sequences for rows (matching 3 columns total):
row1_imgs = ['Camera Input', 'Ground Truth', 'Inference Output']  # original images
# Row 2: one- and two-step chains, each applied left to right to the network output.
row2_seqs = [
    [clahe_enhancement],
    [clahe_enhancement, adaptive_red_boost],
    [adaptive_red_boost, clahe_enhancement]
]
# Row 3: four-step chains in different orders.
row3_seqs = [
    [simple_dehaze, clahe_enhancement, gamma_correct, adaptive_red_boost],
    [gamma_correct, simple_dehaze, clahe_enhancement, adaptive_red_boost],
    [adaptive_red_boost, clahe_enhancement, simple_dehaze, gamma_correct]
]

save_dir = Path("/kaggle/working/result_images/iter_3")
# parents=True: `result_images/` may not exist yet on a fresh session, and mkdir
# would otherwise raise FileNotFoundError.
save_dir.mkdir(parents=True, exist_ok=True)

def apply_sequence(img, seq):
    """Feed ``img`` through each callable in ``seq``, left to right.

    Args:
        img: the initial value (an image array in this notebook).
        seq: iterable of one-argument callables; an empty ``seq`` returns ``img`` as is.

    Returns:
        The output of the last callable.
    """
    current = img
    for transform in seq:
        current = transform(current)
    return current

num_samples = 12
# A local generator seeded from SEED picks the same images on every run; the global
# RNG's state at this point depends on what earlier cells consumed. The count is capped
# so choice(replace=False) never asks for more indices than the dataset holds.
sample_rng = np.random.default_rng(SEED)
indices_to_sample = sample_rng.choice(len(dm.infer_ds), min(num_samples, len(dm.infer_ds)), replace=False)

# Titles for the 3x3 grid, one list per row (rows 2/3 describe row2_seqs/row3_seqs).
row_titles = [
    ['Camera Input', 'Ground Truth', 'Inference Output'],
    ['CLAHE only', 'CLAHE -> Red Boost', 'Red Boost -> CLAHE'],
    ['Dehaze -> CLAHE ->\nGamma -> Red Boost',
     'Gamma -> Dehaze ->\nCLAHE -> Red Boost',
     'Red Boost -> CLAHE ->\nDehaze -> Gamma'],
]

with torch.no_grad():
    for idx in indices_to_sample:
        inp, tgt = dm.infer_ds[idx]
        pred_tensor = infer_model(inp.unsqueeze(0).to(device)).cpu()[0]

        inp_img = tensor_to_img(inp)
        gt_img = tensor_to_img(tgt)
        pred_img = tensor_to_img(pred_tensor)

        # Row 1: raw images; rows 2-3: post-processing chains applied to the prediction.
        row_images = [
            [inp_img, gt_img, pred_img],
            [apply_sequence(pred_img, seq) for seq in row2_seqs],
            [apply_sequence(pred_img, seq) for seq in row3_seqs],
        ]

        fig, axs = plt.subplots(3, 3, figsize=(15, 10))
        fig.suptitle(f"Image index: {idx}", fontsize=16)
        for ax_row, images, titles_in_row in zip(axs, row_images, row_titles):
            for ax, img, title in zip(ax_row, images, titles_in_row):
                ax.imshow(img)
                ax.axis('off')
                ax.set_title(title)

        # Leave the top 5% of the figure for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        save_path = save_dir / f"image_{idx}_3x3_processed.png"
        fig.savefig(save_path)
        # Close each figure so the loop does not accumulate open figures.
        plt.close(fig)
        print(f"Saved 3x3 plot for image {idx} at {save_path}")


In [None]:
import matplotlib.pyplot as plt

PLOT_LIMIT = min(10, len(results))  # max 10 samples or less if fewer available
titles = ['Camera Input', 'Ground Truth', 'Restored', 'Dehazed', 'CLAHE Contrast', 'White Balanced', 'Red Boosted', 'Gamma Corrected']
keys = ["original", "ground_truth", "restored", "dehazed",
        "clahe_corrected", "white_balanced", "red_boosted", "gamma_corrected"]

if PLOT_LIMIT == 0:
    # plt.subplots(0, ...) raises; explain the real cause instead.
    print("No results to plot - run an inference cell first.")
else:
    # squeeze=False keeps `axs` 2-D even with a single row, so indexing is uniform.
    fig, axs = plt.subplots(PLOT_LIMIT, len(titles), figsize=(3*len(titles), 3*PLOT_LIMIT), squeeze=False)

    for sample_idx, res in enumerate(results[:PLOT_LIMIT]):
        for col_idx, (title, key) in enumerate(zip(titles, keys)):
            ax = axs[sample_idx, col_idx]
            ax.imshow(res[key])
            ax.axis("off")
            # Column headings only on the first row.
            if sample_idx == 0:
                ax.set_title(title, fontsize=12)

    plt.tight_layout()
    plt.show()


In [None]:
import torch
import torch.nn.functional as F
from torchmetrics.functional import structural_similarity_index_measure as ssim_fn
from torchvision.models import vgg16
from torchvision import transforms
import numpy as np
import cv2
from pathlib import Path
from itertools import combinations, permutations
from torch.utils.data import Subset

# Classical post-processing functions
def adaptive_red_boost(img):
    """Amplify channel 0 (red, assuming RGB order) by a gain clipped to [1.1, 1.2].

    The gain is (mean_G + mean_B) / (2 * mean_R), i.e. how far red lags behind the
    other two channels. If the red mean is below 1 the gain defaults to 2.0, which the
    clip then reduces to 1.2. Boosted values saturate at 255; channels 1 and 2 are
    untouched and the input array is not modified.

    Args:
        img: (H, W, 3) image, uint8 in this notebook.

    Returns:
        A new uint8 array of the same shape.
    """
    red, green, blue = (img[..., c].mean() for c in range(3))
    if red < 1:
        gain = 2.0
    else:
        gain = (green + blue) / (2 * red)
    gain = np.clip(gain, 1.1, 1.2)

    boosted = img.copy()
    boosted[..., 0] = np.clip(boosted[..., 0] * gain, 0, 255)
    return boosted.astype(np.uint8)

def gamma_correct(img, gamma=1.4):
    """Apply power-law (gamma) correction to an 8-bit image.

    Computes ``255 * (img / 255) ** gamma``, the same mapping that
    ``skimage.exposure.adjust_gamma`` applies to a [0, 1] float image, without needing
    skimage. The result is rounded to the nearest 8-bit level rather than truncated,
    because truncation biases every pixel downward.

    Args:
        img: uint8 image of any shape, read as 0..255 intensities.
        gamma: exponent. Values > 1 darken mid-tones and values < 1 brighten them.
            Must be >= 0.

    Returns:
        uint8 array with the same shape as ``img``.

    Raises:
        ValueError: if ``gamma`` is negative (same contract as skimage).
    """
    if gamma < 0:
        raise ValueError("Gamma should be a non-negative real number.")
    normalized = np.asarray(img, dtype=np.float64) / 255.0
    corrected = np.power(normalized, gamma) * 255.0
    return np.clip(np.rint(corrected), 0, 255).astype(np.uint8)

def clahe_enhancement(img):
    """Contrast-limited adaptive histogram equalisation on each of the first three
    channels independently (clip limit 1.5, 8x8 tiles).

    Equalising R, G and B separately can shift hues, unlike CLAHE applied to a
    luminance channel.

    Args:
        img: (H, W, C) image whose channels are accepted by ``cv2.CLAHE.apply``
            (uint8 in this notebook).

    Returns:
        A new array with the same shape and dtype as ``img``. Only channels 0-2 are written.
    """
    equalizer = cv2.createCLAHE(clipLimit=1.5, tileGridSize=(8, 8))
    equalized = np.empty_like(img)
    for channel in range(3):
        equalized[..., channel] = equalizer.apply(img[..., channel])
    return equalized

def simple_dehaze(img, omega=0.95, win_size=15):
    """Simplified dark-channel haze removal.

    Steps:
    - the dark channel is the per-pixel minimum over colour channels, eroded with a
      ``win_size`` x ``win_size`` window;
    - the airlight is the per-channel maximum over the whole image;
    - the transmission is ``1 - omega * dark``, clipped to [0.1, 1];
    - the scene is recovered as ``(I - A) / t + A``.

    Unlike the textbook formulation, the dark channel is not normalised by the airlight.

    Args:
        img: (H, W, 3) image with 0..255 values.
        omega: fraction of haze to remove (0..1).
        win_size: side length of the erosion window.

    Returns:
        uint8 image of the same shape.
    """
    scene = img.astype(np.float32) / 255.0
    kernel = np.ones((win_size, win_size))
    dark = cv2.erode(scene.min(axis=2), kernel)
    airlight = scene.max(axis=(0, 1))
    transmission = np.clip(1 - omega * dark[..., None], 0.1, 1)
    recovered = (scene - airlight) / transmission + airlight
    return np.clip(recovered * 255, 0, 255).astype(np.uint8)

def tensor_to_img(tensor):
    """Convert a (C, H, W) float tensor in [0, 1] to an (H, W, C) uint8 numpy image.

    Values are clamped to [0, 1] and rounded to the nearest 8-bit level. Truncation
    would bias every pixel downward, and can drop a level when an image is round-tripped
    through ``ToTensor`` (which divides by 255 in float32). The tensor may live on any
    device; the result is always on the host.
    """
    img = torch.clamp(tensor, 0, 1).detach().cpu().permute(1, 2, 0).numpy()
    return np.rint(img * 255).astype(np.uint8)

# Perceptual loss model: VGG16 convolutional features used as a frozen feature extractor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg = vgg16(pretrained=True).features.eval().to(device)
# Used only for inference; disable gradient tracking on its weights.
vgg.requires_grad_(False)
transform_vgg = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
# torchvision's ImageNet-pretrained weights expect inputs normalised with these statistics.
_VGG_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
_VGG_STD = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

def perceptual_loss(img1, img2):
    """Mean absolute difference between VGG16 feature maps of two images.

    Both images are resized to 224x224, scaled to [0, 1], ImageNet-normalised and
    passed through ``vgg`` together in a single batch.

    Args:
        img1, img2: images accepted by ``transforms.ToPILImage``; in this notebook,
            (H, W, 3) uint8 arrays produced by ``tensor_to_img``.

    Returns:
        The L1 feature distance as a Python float (lower means more similar).
    """
    batch = torch.stack([transform_vgg(img1), transform_vgg(img2)]).to(device)
    batch = (batch - _VGG_MEAN) / _VGG_STD
    with torch.no_grad():
        feats = vgg(batch)
    return F.l1_loss(feats[0], feats[1]).item()

# Load model checkpoint
checkpoint_path = Path("/kaggle/input/test-1/best_weights_epoch_041.pth")
if not checkpoint_path.is_file():
    raise RuntimeError(f"Checkpoint file not found at {checkpoint_path}")

infer_model = ResNet34_UNet_CBAM(pretrained=False)
state = torch.load(checkpoint_path, map_location="cpu")
infer_model.load_state_dict(state)
infer_model.eval()
infer_model.to(device)
print(f"Model loaded on {device}")

# Use datamodule's dataset for prediction
full_dataset = dm.infer_ds


# Sample a random subset for faster processing
num_samples = 200
# A private RandomState(42) draws exactly the indices that np.random.seed(42) followed by
# np.random.choice(...) did, without resetting the notebook-wide global RNG. The count
# is capped so a dataset smaller than num_samples does not make choice(replace=False) raise.
sample_rng = np.random.RandomState(42)
sample_indices = sample_rng.choice(len(full_dataset), min(num_samples, len(full_dataset)), replace=False)
subset_dataset = Subset(full_dataset, sample_indices)

# New DataLoader for subset, reusing the datamodule's predict-loader settings
predict_loader = dm.predict_dataloader()
subset_loader = torch.utils.data.DataLoader(subset_dataset,
                                            batch_size=predict_loader.batch_size,
                                            shuffle=False,
                                            num_workers=predict_loader.num_workers,
                                            # pinned host memory only helps host-to-GPU copies
                                            pin_memory=torch.cuda.is_available())

# Map names to functions
pp_funcs = {
    'clahe': clahe_enhancement,
    'red_boost': adaptive_red_boost,
    'gamma': gamma_correct,
    'dehaze': simple_dehaze
}

proc_names = list(pp_funcs)
# Every ordered sequence of distinct steps, grouped by length and then by step set;
# the empty tuple means "no post-processing". For 4 steps: 1 + 4 + 12 + 24 + 24 = 65.
proc_orders = [()] + [
    perm
    for r in range(1, len(proc_names) + 1)
    for comb in combinations(proc_names, r)
    for perm in permutations(comb)
]
print(f"Total permutations: {len(proc_orders)}")

# One bucket of per-sample metric values for each sequence.
metrics_summary = {seq: {'psnr': [], 'ssim': [], 'perceptual': []} for seq in proc_orders}

def apply_processing(img, seq):
    """Run ``img`` through the pp_funcs steps named in ``seq``, left to right."""
    result = img
    for step_name in seq:
        result = pp_funcs[step_name](result)
    return result

# The network output does not depend on the post-processing sequence, so run inference
# ONCE over the subset and cache the 8-bit images. The original re-ran the model (and the
# DataLoader) for each of the len(proc_orders) sequences.
pred_imgs, gt_imgs = [], []
with torch.no_grad():
    for inp, tgt in subset_loader:
        preds = infer_model(inp.to(device)).cpu()
        for pred, gt in zip(preds, tgt):
            pred_imgs.append(tensor_to_img(pred))
            gt_imgs.append(tensor_to_img(gt))

    # Ground-truth tensors are also sequence-independent; build them once.
    gt_tensors = [torch.tensor(gt_img/255.).permute(2,0,1).unsqueeze(0).to(device) for gt_img in gt_imgs]

    for perm_idx, seq in enumerate(proc_orders):
        print(f"\nProcessing permutation {perm_idx+1}/{len(proc_orders)}: {' -> '.join(seq) if seq else 'no_processing'}")

        for pred_img, gt_img, gt_tensor in zip(pred_imgs, gt_imgs, gt_tensors):
            processed_img = apply_processing(pred_img, seq)
            proc_tensor = torch.tensor(processed_img/255.).permute(2,0,1).unsqueeze(0).to(device)

            mse = F.mse_loss(proc_tensor, gt_tensor)
            psnr = 10 * torch.log10(1.0 / mse).item()
            ssim_val = ssim_fn(proc_tensor, gt_tensor).item()
            perc_loss = perceptual_loss(processed_img, gt_img)

            metrics_summary[seq]['psnr'].append(psnr)
            metrics_summary[seq]['ssim'].append(ssim_val)
            metrics_summary[seq]['perceptual'].append(perc_loss)

out_file = "/kaggle/working/post_processing_permutations_metrics.txt"


def _mean_or_nan(values):
    """Average of ``values``, or NaN when no samples were recorded."""
    return np.mean(values) if values else float('nan')


# One block per sequence: averaged PSNR / SSIM / perceptual loss, then a separator line.
with open(out_file, 'w') as f:
    for seq, metrics in metrics_summary.items():
        label = ' -> '.join(seq) if seq else 'no_processing'
        f.write(
            f"Sequence: {label}\n"
            f"  Avg PSNR  : {_mean_or_nan(metrics['psnr']):.4f} dB\n"
            f"  Avg SSIM  : {_mean_or_nan(metrics['ssim']):.4f}\n"
            f"  Avg Perceptual Loss : {_mean_or_nan(metrics['perceptual']):.4f}\n"
            + "-" * 50 + "\n"
        )

print(f"All done! Metrics saved to {out_file}")
