# Automatic Lens Correction

**Approach:** Parametric distortion prediction: an EfficientNet-B3 regresses three radial distortion coefficients (k1, k2, k3) and the distortion center (cx, cy), followed by differentiable undistortion and test-time optimization.

**Setup:** Add the competition data via **Add Data -> Competition -> automatic-lens-correction**. Enable **GPU** in accelerator settings. No internet required.

## Cell 1 — Imports & Explore Data

In [None]:
import os, glob, json, csv, zipfile, random
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as T
from scipy.optimize import minimize
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from tqdm.auto import tqdm
from skimage.metrics import structural_similarity as ssim
from PIL import Image
import matplotlib.pyplot as plt

# Report the runtime environment and choose the compute device once.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

DEVICE = torch.device('cuda' if cuda_ok else 'cpu')

# Competition data location and the output location that later cells write to.
INPUT_DIR = Path('/kaggle/input/automatic-lens-correction')
WORK_DIR = Path('/kaggle/working')

# List every sub-directory of the input tree with its direct item count.
print(f"\nData directory contents:")
for folder in (entry for entry in sorted(INPUT_DIR.rglob('*')) if entry.is_dir()):
    item_count = sum(1 for _ in folder.iterdir())
    print(f"  [DIR]  {folder.relative_to(INPUT_DIR)}/ ({item_count} items)")

# Extensions treated as images everywhere in this notebook (compared lower-cased).
img_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
all_images = []
for entry in INPUT_DIR.rglob('*'):
    if entry.suffix.lower() in img_exts:
        all_images.append(entry)
print(f"\nTotal images found: {len(all_images)}")

print("\nSample filenames:")
for img_path in all_images[:5]:
    print(f"  {img_path.relative_to(INPUT_DIR)}")

# Decode the first image (if any) just to report its resolution.
sample = cv2.imread(str(all_images[0])) if all_images else None
if sample is not None:
    height, width = sample.shape[:2]
    print(f"\nSample resolution: {width}x{height}")

## Cell 2 — Auto-Detect Directory Structure

Check the output. If auto-detect fails, uncomment the manual override at the bottom.

In [None]:
def find_dirs(data_dir):
    """Locate the distorted/corrected training folders and the test folder.

    Known layouts are tried first; if none matches, the function falls back
    to heuristics based on the image files found under ``data_dir``.

    Args:
        data_dir: Root directory of the dataset (str or Path).

    Returns:
        ``(dist_dir, corr_dir, test_dir)``. Each entry is a ``Path`` or
        ``None`` when it could not be determined. All three are ``None``
        when ``data_dir`` is not an existing directory.
    """
    data_dir = Path(data_dir)
    if not data_dir.is_dir():
        # Nothing to scan; the fallback heuristics would raise on iterdir().
        return None, None, None

    train_layouts = [
        ('train/distorted', 'train/corrected'),
        ('train/input', 'train/target'),
        ('train_distorted', 'train_corrected'),
        ('distorted', 'corrected'),
        ('input', 'target'),
        ('train_input', 'train_target'),
        ('train/input', 'train/ground_truth'),
    ]
    # Nested layouts come first: a plain 'test' entry would otherwise always
    # win and 'test/input' / 'test/distorted' could never be selected.
    test_layouts = ['test/input', 'test/distorted', 'test', 'test_images', 'test_input']

    dist_dir = corr_dir = test_dir = None
    for dist_rel, corr_rel in train_layouts:
        dist_candidate, corr_candidate = data_dir / dist_rel, data_dir / corr_rel
        if dist_candidate.exists() and corr_candidate.exists():
            dist_dir, corr_dir = dist_candidate, corr_candidate
            break
    for test_rel in test_layouts:
        test_candidate = data_dir / test_rel
        if test_candidate.exists():
            test_dir = test_candidate
            break

    if dist_dir is None:
        # Heuristic: two top-level folders sharing more than 10 image stems
        # are taken as the pair. Which one is "distorted" is decided purely
        # by sort order -- NOTE(review): verify the printed result.
        subdirs = sorted(d for d in data_dir.iterdir() if d.is_dir())
        stem_cache = {}

        def image_stems(folder):
            # Scan each folder at most once instead of once per pair.
            if folder not in stem_cache:
                stem_cache[folder] = {f.stem for f in folder.glob('*')
                                      if f.suffix.lower() in img_exts}
            return stem_cache[folder]

        for i, first in enumerate(subdirs):
            for second in subdirs[i + 1:]:
                if len(image_stems(first) & image_stems(second)) > 10:
                    dist_dir, corr_dir = first, second
                    break
            if dist_dir is not None:
                break

    if test_dir is None:
        train_dirs = [d for d in (dist_dir, corr_dir) if d is not None]
        for candidate in sorted(data_dir.iterdir()):
            if not candidate.is_dir():
                continue
            # Skip the training folders themselves and any folder containing
            # them (e.g. 'train' when the pair is 'train/distorted' etc.).
            if any(candidate == t or candidate in t.parents for t in train_dirs):
                continue
            if any(f.suffix.lower() in img_exts for f in candidate.rglob('*')):
                test_dir = candidate
                break
    return dist_dir, corr_dir, test_dir

def _count_images(folder):
    """Count files under ``folder`` (recursively) whose extension is in ``img_exts``."""
    return sum(1 for f in folder.rglob('*') if f.suffix.lower() in img_exts)


# Resolve the three working directories and report what was found.
DIST_DIR, CORR_DIR, TEST_DIR = find_dirs(INPUT_DIR)
print(f"Distorted dir: {DIST_DIR}")
print(f"Corrected dir: {CORR_DIR}")
print(f"Test dir:      {TEST_DIR}")
if DIST_DIR:
    print(f"\nDistorted images: {_count_images(DIST_DIR)}")
if CORR_DIR:
    print(f"Corrected images: {_count_images(CORR_DIR)}")
if TEST_DIR:
    print(f"Test images:      {_count_images(TEST_DIR)}")

# >>> MANUAL OVERRIDE: Uncomment and edit if auto-detect fails <<<
# DIST_DIR = INPUT_DIR / 'train' / 'distorted'
# CORR_DIR = INPUT_DIR / 'train' / 'corrected'
# TEST_DIR = INPUT_DIR / 'test'

## Cell 3 — Visualize Training Pairs

In [None]:
def find_match(stem, directory, exts=None):
    """Return the image in ``directory`` whose file stem equals ``stem``.

    Args:
        stem: File name without extension.
        directory: Folder to search (str or Path, not recursive).
        exts: Optional collection of lower-case extensions (with leading dot).
            Defaults to the notebook-wide ``img_exts``.

    Returns:
        The matching ``Path``, or ``None`` if there is no match or
        ``directory`` does not exist.
    """
    exts = img_exts if exts is None else exts
    directory = Path(directory)
    # Fast path: probe the lower-case extensions directly, in a fixed order
    # so the result is deterministic when several files share a stem.
    for ext in sorted(exts):
        candidate = directory / (stem + ext)
        if candidate.exists():
            return candidate
    # Slow path: other cells select files with `suffix.lower() in img_exts`,
    # so 'x.JPG' is a valid image and must be matched here as well.
    # This scans the folder, but only when the fast path found nothing.
    if directory.is_dir():
        for candidate in sorted(directory.iterdir()):
            if candidate.stem == stem and candidate.suffix.lower() in exts:
                return candidate
    return None

def show_pairs(n=4):
    """Plot up to ``n`` distorted/corrected training pairs side by side.

    The first ``n`` images of ``DIST_DIR`` (sorted by path) are considered.
    Images without a counterpart in ``CORR_DIR``, or that cannot be decoded,
    are reported and skipped instead of crashing the cell.

    Args:
        n: Maximum number of distorted images to consider.
    """
    dist_files = sorted(f for f in DIST_DIR.iterdir() if f.suffix.lower() in img_exts)[:n]

    loaded = []
    for df in dist_files:
        cf = find_match(df.stem, CORR_DIR)
        if cf is None:
            print(f"No corrected image for {df.name}; skipping")
            continue
        d_img, c_img = cv2.imread(str(df)), cv2.imread(str(cf))
        if d_img is None or c_img is None:
            # cv2.imread returns None on failure instead of raising.
            print(f"Could not read pair for {df.name}; skipping")
            continue
        loaded.append((df, cf, d_img, c_img))

    if not loaded:
        print("No training pairs to show.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(len(loaded), 2, figsize=(12, 4*len(loaded)), squeeze=False)
    for row, (df, cf, d_img, c_img) in zip(axes, loaded):
        row[0].imshow(cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB))
        row[0].set_title(f'Distorted: {df.name}')
        row[0].axis('off')
        row[1].imshow(cv2.cvtColor(c_img, cv2.COLOR_BGR2RGB))
        row[1].set_title(f'Corrected: {cf.name}')
        row[1].axis('off')
    plt.tight_layout()
    plt.show()

if DIST_DIR and CORR_DIR:
    show_pairs(4)

## Cell 4 — Extract Distortion Parameters from Training Pairs

In [None]:
PARAMS_CSV = WORK_DIR / 'params.csv'

def undistort_image(img, k1, k2, k3, cx, cy):
    """Undistort ``img`` with OpenCV using a radial-only distortion model.

    Args:
        img: Image array; only ``img.shape[:2]`` (height, width) is read here.
        k1, k2, k3: Radial coefficients, passed to OpenCV as (k1, k2, 0, 0, k3).
        cx, cy: Principal point as fractions of image width / height.

    Returns:
        The undistorted image. When OpenCV reports a non-empty valid region
        it is cropped out and resized back to the input width and height.
    """
    height, width = img.shape[:2]
    # Focal length is not estimated; the longer image side is used for both axes.
    focal = max(height, width)
    camera_matrix = np.array(
        [[focal, 0, cx * width],
         [0, focal, cy * height],
         [0, 0, 1]],
        dtype=np.float64)
    dist_coeffs = np.array([k1, k2, 0, 0, k3], dtype=np.float64)

    new_matrix, (x0, y0, roi_w, roi_h) = cv2.getOptimalNewCameraMatrix(
        camera_matrix, dist_coeffs, (width, height), alpha=0)
    undistorted = cv2.undistort(img, camera_matrix, dist_coeffs, None, new_matrix)

    if roi_w <= 0 or roi_h <= 0:
        # Empty region of interest: return the uncropped result.
        return undistorted
    valid = undistorted[y0:y0 + roi_h, x0:x0 + roi_w]
    return cv2.resize(valid, (width, height), interpolation=cv2.INTER_LINEAR)

def objective(params, distorted, corrected):
    """Mean squared error between the undistorted input and the reference.

    Args:
        params: Sequence unpacked into ``undistort_image`` as (k1, k2, k3, cx, cy).
        distorted: Distorted image array.
        corrected: Reference image array.

    Returns:
        The MSE, or ``1e10`` if undistortion or the comparison raises, so the
        optimizer treats that parameter set as very bad instead of aborting.
    """
    try:
        restored = undistort_image(distorted, *params)
        diff = restored.astype(np.float32) - corrected.astype(np.float32)
        return np.mean(diff**2)
    except Exception:
        return 1e10

def extract_single(args):
    """Fit (k1, k2, k3, cx, cy) for one distorted/corrected pair.

    Args:
        args: Tuple ``(distorted_path, corrected_path, size)``. Both images
            are resized to ``size`` x ``size`` before fitting.

    Returns:
        Array with the five fitted parameters, or ``None`` if either image
        cannot be read.
    """
    dp, cp, sz = args
    d = cv2.imread(str(dp))
    c = cv2.imread(str(cp))
    if d is None or c is None:
        return None
    # Work in float32 so the resampled image (and thus the objective) is not
    # quantised to integer grey levels.
    d = cv2.resize(d, (sz, sz), interpolation=cv2.INTER_AREA).astype(np.float32)
    c = cv2.resize(c, (sz, sz), interpolation=cv2.INTER_AREA).astype(np.float32)
    # No analytic gradient is supplied, so L-BFGS-B differentiates numerically.
    # Its default step (1e-8) moves the sampling grid by far less than the
    # interpolated image can resolve, which makes the gradient vanish and the
    # fit stop at the initial guess. 'eps' sets a step that produces a
    # measurable change in the resampled image.
    res = minimize(objective, [0, 0, 0, 0.5, 0.5], args=(d, c), method='L-BFGS-B',
                   bounds=[(-1, 1), (-1, 1), (-1, 1), (0.3, 0.7), (0.3, 0.7)],
                   options={'maxiter': 200, 'ftol': 1e-8, 'eps': 1e-3})
    return res.x

# Pair every distorted image with the corrected image sharing its stem.
dist_files = sorted(f for f in DIST_DIR.iterdir() if f.suffix.lower() in img_exts)
pairs = []
for src in dist_files:
    tgt = find_match(src.stem, CORR_DIR)
    if not tgt:
        continue
    pairs.append((src, tgt))
print(f"Matched {len(pairs)} training pairs")

# Fit the parameters of every pair in worker processes (256x256 working size).
results = []
with ProcessPoolExecutor(max_workers=4) as executor:
    futures = {}
    for src, tgt in pairs:
        futures[executor.submit(extract_single, (src, tgt, 256))] = src.stem
    for done in tqdm(as_completed(futures), total=len(futures), desc="Extracting params"):
        img_id = futures[done]
        try:
            fitted = done.result()
        except Exception as e:
            # Best effort: report the failing image and keep going.
            print(f"Error {img_id}: {e}")
            continue
        if fitted is not None:
            results.append((img_id, *fitted))

# Persist one row per image; this file provides the regression targets later.
with open(PARAMS_CSV, 'w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerow(['image_id','k1','k2','k3','cx','cy'])
    writer.writerows(results)
print(f"\nSaved {len(results)} params to {PARAMS_CSV}")

# Sanity check: PSNR of the first few pairs at their original resolution.
# NOTE(review): parameters were fitted on 256x256 resized copies but are applied
# to the full-size images here -- confirm that this transfer is intended.
print("\nPSNR validation:")
pdict = {image_id: tuple(values) for image_id, *values in results}
psnrs = []
for src, tgt in pairs[:5]:
    fitted = pdict.get(src.stem)
    if fitted is None:
        continue
    dist_full = cv2.imread(str(src))
    corr_full = cv2.imread(str(tgt))
    restored = undistort_image(dist_full, *fitted)
    mse = np.mean((restored.astype(np.float32) - corr_full.astype(np.float32))**2)
    psnr = 10*np.log10(255**2/max(mse,1e-10))
    psnrs.append(psnr)
    print(f"  {src.stem}: {psnr:.2f} dB")
if psnrs:
    print(f"  Average: {np.mean(psnrs):.2f} dB")

## Cell 5 — Dataset & Model Definition

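Uses `torchvision.models.efficientnet_b3` with ImageNet weights. The weights are expected to be pre-cached on Kaggle, so no download should be needed; if they are not cached, torchvision will try to download them, which fails when internet access is disabled.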

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class DistortionDataset(Dataset):
    """Distorted training images paired with their fitted parameters.

    Each item is a dict with 'image' (normalised tensor), 'params'
    (float32 tensor of k1, k2, k3, cx, cy) and 'image_id'. When
    ``corrected_dir`` is given and holds a matching file, 'corrected' (the
    normalised reference image) is added as well.
    """

    def __init__(self, image_dir, params_csv, image_size=224, augment=True, corrected_dir=None):
        """Read ``params_csv`` and keep the rows whose image exists in ``image_dir``.

        Args:
            image_dir: Folder with the distorted images (not recursive).
            params_csv: CSV with columns image_id, k1, k2, k3, cx, cy.
            image_size: Side length of the square network input.
            augment: If true, colour jitter is applied to the input image.
            corrected_dir: Optional folder with the reference images.
        """
        self.image_dir = Path(image_dir)
        self.corrected_dir = Path(corrected_dir) if corrected_dir else None
        self.image_size = image_size
        self.augment = augment

        columns = ['k1','k2','k3','cx','cy']
        rows = []
        with open(params_csv, 'r') as handle:
            for record in csv.DictReader(handle):
                values = np.array([float(record[name]) for name in columns], dtype=np.float32)
                rows.append((record['image_id'], values))

        # Index the available images by stem, then drop rows without a file.
        self._paths = {
            entry.stem: entry
            for entry in self.image_dir.iterdir()
            if entry.suffix.lower() in img_exts
        }
        self.samples = [(image_id, values) for image_id, values in rows
                        if image_id in self._paths]
        self._build_transforms()

    def _build_transforms(self):
        """(Re)create the input and target transforms for the current size."""
        side = (self.image_size, self.image_size)
        steps = [T.Resize(side)]
        if self.augment:
            # NOTE(review): the jitter touches only the input image, while the
            # 'corrected' target is left unjittered -- confirm this is intended
            # for losses that compare the two pixel-wise.
            steps.append(T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05))
        steps.append(T.ToTensor())
        steps.append(T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
        self.transform = T.Compose(steps)
        self.target_transform = T.Compose([
            T.Resize(side),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

    def update_image_size(self, new_size):
        """Switch to a new square input size (used for progressive resizing)."""
        self.image_size = new_size
        self._build_transforms()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_id, params = self.samples[idx]
        source = Image.open(self._paths[image_id]).convert('RGB')
        item = {
            'image': self.transform(source),
            'params': torch.from_numpy(params),
            'image_id': image_id,
        }
        if self.corrected_dir is None:
            return item
        target_path = find_match(image_id, self.corrected_dir)
        if target_path:
            item['corrected'] = self.target_transform(Image.open(target_path).convert('RGB'))
        return item


class TestDataset(Dataset):
    """Test images of one folder, resized and normalised for the network.

    Each item is a dict with 'image', 'image_id' (file stem), 'image_path'
    and the original size as 'orig_h' / 'orig_w'.
    """

    def __init__(self, image_dir, image_size=384):
        """Collect the image files directly inside ``image_dir`` (sorted)."""
        self.image_dir = Path(image_dir)
        found = (entry for entry in self.image_dir.iterdir()
                 if entry.suffix.lower() in img_exts)
        self.files = sorted(found)
        side = (image_size, image_size)
        self.transform = T.Compose([
            T.Resize(side),
            T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        rgb = Image.open(path).convert('RGB')
        # PIL reports size as (width, height).
        orig_w, orig_h = rgb.size
        return {
            'image': self.transform(rgb),
            'image_id': path.stem,
            'image_path': str(path),
            'orig_h': orig_h,
            'orig_w': orig_w,
        }


class DistortionNet(nn.Module):
    """EfficientNet-B3 backbone with a small head regressing five parameters.

    The output is ``[k1, k2, k3, cx, cy]`` per image: the first three values
    pass through tanh (range -1..1), the last two through sigmoid (range 0..1).
    """

    def __init__(self, pretrained=True):
        """Build the network.

        Args:
            pretrained: If true, load the default torchvision weights for the
                backbone; otherwise initialise it randomly.
        """
        super().__init__()
        weights = models.EfficientNet_B3_Weights.DEFAULT if pretrained else None
        backbone = models.efficientnet_b3(weights=weights)
        feat_dim = backbone.classifier[1].in_features  # 1536
        backbone.classifier = nn.Identity()
        self.backbone = backbone
        self.head = self._build_head(feat_dim)

    @staticmethod
    def _build_head(feat_dim):
        """Create the regression head with a neutral initial prediction."""
        head = nn.Sequential(
            nn.Linear(feat_dim, 256), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(256, 5))
        # Zero weights and zero bias make the untrained network predict "no
        # distortion, centred": tanh(0) = 0 for k1..k3 and sigmoid(0) = 0.5 for
        # cx, cy. (A bias of 0.5 on the centre outputs would be squashed by
        # the sigmoid to about 0.62, i.e. an off-centre start.)
        nn.init.zeros_(head[-1].weight)
        nn.init.zeros_(head[-1].bias)
        return head

    def forward(self, x):
        """Return a (batch, 5) tensor of bounded distortion parameters."""
        feat = self.backbone(x)
        p = self.head(feat)
        return torch.cat([torch.tanh(p[:, :3]), torch.sigmoid(p[:, 3:])], dim=1)


def differentiable_undistort(image, params):
    """Warp a batch of images with a radial polynomial, differentiably.

    Args:
        image: Tensor of shape (B, C, H, W).
        params: Tensor of shape (B, 5) holding k1, k2, k3, cx, cy per image,
            with cx, cy given as fractions (0..1) of width / height.

    Returns:
        Tensor of shape (B, C, H, W); samples falling outside the input are
        filled with zeros.

    NOTE(review): the radius here is measured in grid units (-1..1 per axis),
    whereas ``undistort_image`` measures it in pixels divided by max(h, w);
    the same k values therefore presumably describe different warps in the
    two functions -- confirm before mixing their parameters.
    """
    batch, _, height, width = image.shape
    rows = torch.linspace(-1, 1, height, device=image.device)
    cols = torch.linspace(-1, 1, width, device=image.device)
    grid_x = cols.view(1, 1, width).expand(batch, height, width)
    grid_y = rows.view(1, height, 1).expand(batch, height, width)

    # One (B, 1, 1) tensor per parameter so each broadcasts over H and W.
    k1, k2, k3, cx, cy = (params[:, i:i + 1].unsqueeze(-1) for i in range(5))
    # Centre expressed in the same -1..1 coordinates as the sampling grid.
    center_x = cx*2-1
    center_y = cy*2-1

    off_x = grid_x - center_x
    off_y = grid_y - center_y
    r2 = off_x**2 + off_y**2
    radial = 1 + k1*r2 + k2*r2**2 + k3*r2**3

    sample_grid = torch.stack(
        [off_x*radial + center_x, off_y*radial + center_y], dim=-1)
    return F.grid_sample(image, sample_grid, mode='bilinear',
                         padding_mode='zeros', align_corners=True)


class DistortionLoss(nn.Module):
    """Weighted sum of a parameter MSE and an optional pixel L1 term."""

    def __init__(self, param_weight=1.0, pixel_weight=0.5):
        """Store the weights of the parameter term and the pixel term."""
        super().__init__()
        self.pw = param_weight
        self.xw = pixel_weight

    def forward(self, pred, gt, dist_img=None, corr_img=None):
        """Compute the loss.

        Args:
            pred: Predicted parameters.
            gt: Target parameters.
            dist_img: Optional distorted images, warped with ``pred``.
            corr_img: Optional reference images compared with the warp.

        Returns:
            ``(total, param_loss, pixel_loss)``. The pixel loss is a zero
            tensor unless both images are supplied.
        """
        param_term = F.mse_loss(pred, gt)
        if dist_img is None or corr_img is None:
            pixel_term = torch.tensor(0.0, device=pred.device)
        else:
            warped = differentiable_undistort(dist_img, pred)
            pixel_term = F.l1_loss(warped, corr_img)
        total = self.pw*param_term + self.xw*pixel_term
        return total, param_term, pixel_term


# Smoke check: instantiate the network once and report its parameter count.
print("Classes defined.")
model = DistortionNet(pretrained=True)
model = model.to(DEVICE)
n_params = sum(weight.numel() for weight in model.parameters())
print(f"Model parameters: {n_params:,}")

## Cell 6 — Train

In [None]:
EPOCHS = 30
BATCH_SIZE = 16
LR = 3e-3
NUM_WORKERS = 2
CKPT_DIR = WORK_DIR / 'checkpoints'
CKPT_DIR.mkdir(exist_ok=True)
# Progressive resizing: 0-based epoch index -> square input size used from then on.
SIZE_SCHEDULE = {0: 224, EPOCHS//3: 384, 2*EPOCHS//3: 512}

# 80/20 split with a fixed seed. Validation reads the same rows through a
# second, non-augmenting dataset restricted to the validation indices.
full_ds = DistortionDataset(DIST_DIR, PARAMS_CSV, image_size=224, augment=True, corrected_dir=CORR_DIR)
n_val = int(len(full_ds) * 0.2)
n_train = len(full_ds) - n_val
train_ds, val_ds = random_split(full_ds, [n_train, n_val], generator=torch.Generator().manual_seed(42))
val_ds_noaug = DistortionDataset(DIST_DIR, PARAMS_CSV, image_size=224, augment=False, corrected_dir=CORR_DIR)
val_ds_noaug.samples = [val_ds_noaug.samples[i] for i in val_ds.indices]
print(f"Train: {n_train}, Val: {n_val}")

model = DistortionNet(pretrained=True).to(DEVICE)
criterion = DistortionLoss(param_weight=1.0, pixel_weight=0.5)
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
# The training loader drops the last partial batch, so one epoch performs
# exactly n_train // BATCH_SIZE optimizer steps. Matching total_steps to that
# lets the one-cycle schedule reach its final annealing phase.
steps_per_epoch = max(1, n_train // BATCH_SIZE)
scheduler = OneCycleLR(optimizer, max_lr=LR, total_steps=steps_per_epoch*EPOCHS,
                       pct_start=0.3, div_factor=25, final_div_factor=1000)

best_val_loss = float('inf')
history = []

for epoch in range(1, EPOCHS+1):
    if epoch-1 in SIZE_SCHEDULE:
        sz = SIZE_SCHEDULE[epoch-1]
        print(f"\n>>> Resize to {sz}x{sz}")
        # train_ds is a view on full_ds, so resizing full_ds also resizes training.
        full_ds.update_image_size(sz)
        val_ds_noaug.update_image_size(sz)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds_noaug, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=True)

    # ---- training pass ----
    model.train()
    t_loss, t_pl, t_xl, nb = 0, 0, 0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
    for batch in pbar:
        imgs = batch['image'].to(DEVICE)
        params = batch['params'].to(DEVICE)
        corr = batch.get('corrected')
        corr = corr.to(DEVICE) if corr is not None else None
        pred = model(imgs)
        loss, lp, lx = criterion(pred, params, imgs, corr)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        t_loss += loss.item(); t_pl += lp.item(); t_xl += lx.item(); nb += 1
        pbar.set_postfix(loss=f"{t_loss/nb:.4f}", param=f"{t_pl/nb:.4f}", pixel=f"{t_xl/nb:.4f}")
    # With fewer than BATCH_SIZE training samples no batch is produced at all;
    # avoid dividing by zero in the averages below.
    nb = max(nb, 1)

    # ---- validation pass ----
    model.eval()
    v_loss, errs = 0, []
    with torch.no_grad():
        for batch in val_loader:
            imgs = batch['image'].to(DEVICE)
            params = batch['params'].to(DEVICE)
            corr = batch.get('corrected')
            corr = corr.to(DEVICE) if corr is not None else None
            pred = model(imgs)
            loss, lp, lx = criterion(pred, params, imgs, corr)
            v_loss += loss.item()
            errs.append((pred - params).abs().cpu().numpy())
    n_vb = max(len(val_loader), 1)
    v_loss /= n_vb
    # Mean absolute error per parameter over the whole validation set.
    me = np.concatenate(errs).mean(axis=0) if errs else np.zeros(5)
    print(f"Epoch {epoch}: train={t_loss/nb:.4f} val={v_loss:.4f} | "
          f"k1e={me[0]:.4f} k2e={me[1]:.4f} k3e={me[2]:.4f} cxe={me[3]:.4f} cye={me[4]:.4f}")
    history.append({'epoch': epoch, 'train_loss': t_loss/nb, 'val_loss': v_loss})

    # Keep only the checkpoint with the lowest validation loss.
    if v_loss < best_val_loss:
        best_val_loss = v_loss
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                    'val_loss': best_val_loss}, CKPT_DIR / 'best_model.pth')
        print(f"  >> Saved best (val_loss={best_val_loss:.4f})")

print(f"\nDone. Best val_loss: {best_val_loss:.4f}")

## Cell 7 — Training Curves

In [None]:
# Plot the per-epoch training and validation loss recorded in `history`.
epoch_axis = [entry['epoch'] for entry in history]
fig, ax = plt.subplots(1, 1, figsize=(10, 4))
for key, label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    ax.plot(epoch_axis, [entry[key] for entry in history], label=label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Training Curves')
plt.tight_layout()
plt.show()

## Cell 8 — Local Scoring Metrics

In [None]:
def edge_similarity(img1, img2, scales=(1.0, 0.5, 0.25)):
    """Multi-scale F1 score between the Canny edge maps of two images.

    At each scale, an edge pixel counts as matched when it lies on the other
    image's edge map dilated with a 3x3 kernel.

    Args:
        img1, img2: Images; 3-dimensional arrays are converted from BGR to gray.
        scales: Resize factors at which the comparison is repeated.

    Returns:
        Mean of the per-scale F1 scores.
    """
    def to_gray(img):
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img

    def rescale(gray, factor):
        if factor == 1.0:
            return gray
        return cv2.resize(gray, (int(gray.shape[1]*factor), int(gray.shape[0]*factor)))

    def auto_canny(gray):
        # Thresholds are derived from the median intensity of the image itself.
        med = np.median(gray)
        return cv2.Canny(gray, int(max(0, 0.67*med)), int(min(255, 1.33*med)))

    gray1, gray2 = to_gray(img1), to_gray(img2)
    kernel = np.ones((3, 3), np.uint8)
    per_scale = []
    for factor in scales:
        edges1 = auto_canny(rescale(gray1, factor))
        edges2 = auto_canny(rescale(gray2, factor))
        total1, total2 = edges1.sum(), edges2.sum()
        if total1 == 0 and total2 == 0:
            # Neither image has edges: treated as a perfect match.
            per_scale.append(1.0)
            continue
        if total1 == 0 or total2 == 0:
            per_scale.append(0.0)
            continue
        near1 = cv2.dilate(edges1, kernel, iterations=1)
        near2 = cv2.dilate(edges2, kernel, iterations=1)
        precision = (edges1 & near2).sum()/max(total1, 1)
        recall = (edges2 & near1).sum()/max(total2, 1)
        if precision + recall > 0:
            per_scale.append(2*precision*recall/(precision+recall))
        else:
            per_scale.append(0.0)
    return np.mean(per_scale)

def line_straightness(img, ref=None):
    g = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape)==3 else img
    lines = cv2.HoughLinesP(cv2.Canny(g,50,150,apertureSize=3), 1, np.pi/180, 50, minLineLength=30, maxLineGap=10)
    if lines is None: return 0.5
    angles = np.array([np.arctan2(l[0][3]-l[0][1], l[0][2]-l[0][0])*180/np.pi for l in lines])
    if ref is not None:
        gr = cv2.cvtColor(ref, cv2.COLOR_BGR2GRAY) if len(ref.shape)==3 else ref
        lr = cv2.HoughLinesP(cv2.Canny(gr,50,150,apertureSize=3), 1, np.pi/180, 50, minLineLength=30, maxLineGap=10)
        if lr is None: return 0.5
        ar = np.array([np.arctan2(l[0][3]-l[0][1], l[0][2]-l[0][0])*180/np.pi for l in lr])
        bins = np.linspace(-90,90,37)
        h1,_ = np.histogram(angles,bins=bins,density=True)
        h2,_ = np.histogram(ar,bins=bins,density=True)
        h1, h2 = h1/(h1.sum()+1e-10), h2/(h2.sum()+1e-10)
        return float(np.sum(np.sqrt(h1*h2)))
    return float(np.mean(np.minimum(np.abs(angles), np.abs(np.abs(angles)-90))<5))

def gradient_orientation_sim(img1, img2, n_bins=36):
    """Compare the gradient-orientation histograms of two images.

    Each histogram is weighted by gradient magnitude and normalised to sum
    to one.

    Args:
        img1, img2: Images; 3-dimensional arrays are converted from BGR to gray.
        n_bins: Number of orientation bins over -pi..pi.

    Returns:
        Bhattacharyya coefficient of the two histograms as a float.
    """
    bin_edges = np.linspace(-np.pi, np.pi, n_bins+1)

    def orientation_hist(img):
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img
        gray = gray.astype(np.float32)
        grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
        magnitude = np.sqrt(grad_x**2+grad_y**2)
        hist, _ = np.histogram(np.arctan2(grad_y, grad_x), bins=bin_edges, weights=magnitude)
        return hist/(hist.sum()+1e-10)

    hist1 = orientation_hist(img1)
    hist2 = orientation_hist(img2)
    return float(np.sum(np.sqrt(hist1*hist2)))

def pixel_accuracy(img1, img2):
    """Return 1 minus the mean absolute pixel difference scaled by 255.

    ``img2`` is resized to the width and height of ``img1`` when the shapes
    differ.
    """
    if img1.shape != img2.shape:
        target_h, target_w = img1.shape[0], img1.shape[1]
        img2 = cv2.resize(img2, (target_w, target_h))
    abs_err = np.abs(img1.astype(np.float32) - img2.astype(np.float32))
    return 1.0 - np.mean(abs_err)/255.0

def compute_score(corrected, gt):
    """Combine the local metrics into one weighted score.

    Args:
        corrected: Predicted image; resized to the size of ``gt`` if needed.
        gt: Ground-truth image.

    Returns:
        ``(overall, metrics)`` where ``metrics`` holds the individual values
        under 'edge', 'line', 'grad', 'ssim', 'pixel' and 'overall'.
    """
    if corrected.shape != gt.shape:
        corrected = cv2.resize(corrected, (gt.shape[1], gt.shape[0]))

    edge_score = edge_similarity(corrected, gt)
    line_score = line_straightness(corrected, gt)
    grad_score = gradient_orientation_sim(corrected, gt)
    if len(corrected.shape) == 3:
        # Colour input: tell SSIM that the last axis holds the channels.
        ssim_score = ssim(corrected, gt, channel_axis=2, data_range=255)
    else:
        ssim_score = ssim(corrected, gt, data_range=255)
    pixel_score = pixel_accuracy(corrected, gt)

    overall = 0.40*edge_score + 0.22*line_score + 0.18*grad_score + 0.15*ssim_score + 0.05*pixel_score
    metrics = dict(edge=edge_score, line=line_score, grad=grad_score,
                   ssim=ssim_score, pixel=pixel_score, overall=overall)
    return overall, metrics

print("Metrics defined.")

## Cell 9 — Validate on Training Pairs

In [None]:
# Restore the best checkpoint written during training.
ckpt = torch.load(CKPT_DIR / 'best_model.pth', map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f"Loaded best model from epoch {ckpt['epoch']}")

N_VIS = min(8, len(val_ds_noaug))
val_vis_loader = DataLoader(val_ds_noaug, batch_size=1, shuffle=False)

scores = []
# squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
fig, axes = plt.subplots(N_VIS, 3, figsize=(15, 4*N_VIS), squeeze=False)

for row, batch in zip(range(N_VIS), val_vis_loader):
    image_id = batch['image_id'][0]
    with torch.no_grad():
        predicted = model(batch['image'].to(DEVICE))
    pp = predicted.cpu().numpy()[0]

    # Score at original resolution: reload both files and undistort with OpenCV.
    dist_img = cv2.imread(str(next(DIST_DIR.glob(f"{image_id}.*"))))
    corr_img = cv2.imread(str(next(CORR_DIR.glob(f"{image_id}.*"))))
    pred_corr = undistort_image(dist_img, *pp)
    overall, m = compute_score(pred_corr, corr_img)
    scores.append(overall)

    panels = (
        (dist_img, 'Distorted'),
        (pred_corr, f'Ours (score={overall:.3f})'),
        (corr_img, 'Ground truth'),
    )
    for ax, (bgr, title) in zip(axes[row], panels):
        ax.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        ax.set_title(title)
        ax.axis('off')

plt.tight_layout()
plt.show()
print(f"\nAvg local score: {np.mean(scores):.4f}")

## Cell 10 — Predict Test + Test-Time Optimization

In [None]:
OUTPUT_DIR = WORK_DIR / 'output'
OUTPUT_DIR.mkdir(exist_ok=True)

# Reload the best checkpoint so this cell does not depend on the state left by earlier cells.
ckpt = torch.load(CKPT_DIR / 'best_model.pth', map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

test_ds = TestDataset(TEST_DIR, image_size=384)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=True)
print(f"Test images: {len(test_ds)}")

# Stage 1: CNN predictions
# One record per image id: predicted parameters plus what is needed to
# reopen the original file later.
predictions = {}
with torch.no_grad():
    for batch in tqdm(test_loader, desc="CNN prediction"):
        batch_params = model(batch['image'].to(DEVICE)).cpu().numpy()
        fields = zip(batch['image_id'], batch_params, batch['image_path'],
                     batch['orig_h'], batch['orig_w'])
        for img_id, row, path, orig_h, orig_w in fields:
            predictions[img_id] = {
                'params': row.copy(),
                'image_path': path,
                'orig_h': orig_h.item(),
                'orig_w': orig_w.item()}

# Stage 2: Test-Time Optimization
TTO_STEPS = 50

def tto_loss(undistorted):
    """Reference-free loss used during test-time optimization.

    Two terms, each weighted 0.5: the negative mean Sobel gradient magnitude
    of the gray image, and a penalty that grows as the pixels in a border
    band approach zero.

    Args:
        undistorted: Tensor whose first dimension is squeezed away, leaving
            (3, H, W).

    Returns:
        Scalar loss tensor.
    """
    frame = undistorted.squeeze(0)
    red, green, blue = frame[0], frame[1], frame[2]
    luma = (0.299*red + 0.587*green + 0.114*blue)[None, None]

    sobel = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                         dtype=torch.float32, device=frame.device)
    kernel_x = sobel.view(1, 1, 3, 3)
    # The vertical Sobel kernel is the transpose of the horizontal one.
    kernel_y = sobel.t().contiguous().view(1, 1, 3, 3)
    grad_x = F.conv2d(luma, kernel_x, padding=1)
    grad_y = F.conv2d(luma, kernel_y, padding=1)
    # The small constant keeps the square root differentiable at zero gradient.
    sharpness = torch.sqrt(grad_x**2 + grad_y**2 + 1e-6).mean()

    _, height, width = frame.shape
    band = max(2, min(height, width)//20)
    strips = (frame[:, :band, :], frame[:, -band:, :],
              frame[:, :, :band], frame[:, :, -band:])
    border_px = torch.cat([strip.reshape(-1) for strip in strips])
    border_penalty = (1.0 - border_px.abs().clamp(0, 1)).mean()

    return 0.5*(-sharpness) + 0.5*border_penalty

def diff_undistort(image, params):
    """Differentiable radial warp using the first row of ``params``.

    Args:
        image: Tensor of shape (B, C, H, W). The sampling grid is built with
            a batch dimension of one.
        params: Tensor whose first row holds k1, k2, k3, cx, cy, with cx, cy
            given as fractions (0..1) of width / height.

    Returns:
        The warped image; samples outside the input are filled with zeros.
    """
    _, _, height, width = image.shape
    k1, k2, k3, cx, cy = params[0, :5].unbind()
    # Centre in the -1..1 coordinates used by grid_sample.
    center_x = cx*2-1
    center_y = cy*2-1

    rows = torch.linspace(-1, 1, height, device=image.device)
    cols = torch.linspace(-1, 1, width, device=image.device)
    off_x = cols.view(1, width).expand(height, width) - center_x
    off_y = rows.view(height, 1).expand(height, width) - center_y

    r2 = off_x**2+off_y**2
    rad = 1+k1*r2+k2*r2**2+k3*r2**3
    sample_grid = torch.stack([off_x*rad+center_x, off_y*rad+center_y], dim=-1)
    return F.grid_sample(image, sample_grid.unsqueeze(0), mode='bilinear',
                         padding_mode='zeros', align_corners=True)

print(f"\nRunning TTO ({TTO_STEPS} steps/image)...")
for img_id in tqdm(predictions, desc="TTO"):
    pred = predictions[img_id]
    bgr = cv2.imread(pred['image_path'])
    if bgr is None:
        # cv2.imread returns None on failure; keep the CNN prediction for this
        # image (no 'params_tto' entry is written).
        print(f"Could not read {pred['image_path']}; skipping TTO")
        continue
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Optimise on a 256x256 copy scaled to 0..1.
    img_t = torch.from_numpy(cv2.resize(img,(256,256))).float().permute(2,0,1).unsqueeze(0).to(DEVICE)/255.0
    init_p = torch.tensor(pred['params'],dtype=torch.float32,device=DEVICE)
    p = init_p.unsqueeze(0).clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([p], lr=0.001)
    best_loss, best_p = float('inf'), pred['params'].copy()
    for _ in range(TTO_STEPS):
        opt.zero_grad()
        # Project the parameters back into their allowed ranges before evaluating.
        with torch.no_grad(): p.data[:,:3].clamp_(-1,1); p.data[:,3:].clamp_(0.1,0.9)
        # Snapshot the parameters this loss is computed for: opt.step() below
        # moves `p`, so reading it afterwards would pair the loss with a
        # different (and possibly out-of-range) parameter set.
        evaluated = p.detach().clone()
        # The quadratic term keeps the solution close to the CNN prediction.
        loss = tto_loss(diff_undistort(img_t, p)) + 0.1*F.mse_loss(p.squeeze(), init_p)
        loss.backward(); opt.step()
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_p = evaluated.squeeze(0).cpu().numpy().copy()
    pred['params_tto'] = best_p

print("TTO complete.")

## Cell 11 — Apply Corrections & Create Submission ZIP

In [None]:
# Undistort every test image at full resolution and write it as JPEG.
for img_id, pred in tqdm(predictions.items(), desc="Saving corrected"):
    img = cv2.imread(pred['image_path'])
    if img is None:
        continue
    # Prefer the refined parameters; fall back to the raw CNN prediction.
    chosen = pred['params_tto'] if 'params_tto' in pred else pred['params']
    corrected = undistort_image(img, *chosen)
    out_path = OUTPUT_DIR / f"{img_id}.jpg"
    cv2.imwrite(str(out_path), corrected, [cv2.IMWRITE_JPEG_QUALITY, 95])

# Bundle all JPEGs of the output folder into a flat archive (no sub-folders).
ZIP_PATH = WORK_DIR / 'submission.zip'
jpg_files = sorted(OUTPUT_DIR.glob('*.jpg'))
with zipfile.ZipFile(ZIP_PATH, 'w', zipfile.ZIP_DEFLATED) as zf:
    for jpg in jpg_files:
        zf.write(jpg, jpg.name)

n_out = len(jpg_files)
print(f"\nSaved {n_out} corrected images")
print(f"ZIP: {ZIP_PATH} ({ZIP_PATH.stat().st_size/1024/1024:.1f} MB)")
print(f"\nNext steps:")
for step in (
    f"  1. Download submission.zip from the Output tab",
    f"  2. Upload to https://bounty.autohdr.com",
    f"  3. Download the scoring CSV",
    f"  4. Submit CSV to Kaggle",
):
    print(step)

## Cell 12 — (Optional) Visualize Test Corrections

In [None]:
# Show up to six corrected test images next to their distorted originals.
# Output files without a matching prediction (e.g. left over from an earlier
# run) are ignored instead of raising KeyError.
test_files = [f for f in sorted(OUTPUT_DIR.glob('*.jpg')) if f.stem in predictions][:6]
if not test_files:
    # plt.subplots cannot create a grid with zero rows.
    print("No corrected test images to display.")
else:
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(len(test_files), 2, figsize=(12, 4*len(test_files)), squeeze=False)
    for row, cf in zip(axes, test_files):
        img_id = cf.stem
        orig = cv2.cvtColor(cv2.imread(predictions[img_id]['image_path']), cv2.COLOR_BGR2RGB)
        corr = cv2.cvtColor(cv2.imread(str(cf)), cv2.COLOR_BGR2RGB)
        row[0].imshow(orig); row[0].set_title(f'Distorted: {img_id}'); row[0].axis('off')
        row[1].imshow(corr); row[1].set_title('Corrected'); row[1].axis('off')
    plt.tight_layout(); plt.show()