# Automatic Lens Correction

**Approach:** Parametric distortion prediction (k1, k2, k3, cx, cy) using EfficientNet-B3 + differentiable undistortion + test-time optimization.

**Setup:** Add the competition data via **Add Data → Competition → automatic-lens-correction**. Enable **GPU** in accelerator settings.

## Cell 1 — Install Dependencies & Explore Data

In [None]:
!pip install -q timm kornia pytorch-msssim albumentations

import os, glob, json, csv, zipfile
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from scipy.optimize import minimize
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from tqdm.auto import tqdm
from skimage.metrics import structural_similarity as ssim
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import matplotlib.pyplot as plt

# ── Environment report ──
_cuda_ok = torch.cuda.is_available()
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

DEVICE = torch.device('cuda') if _cuda_ok else torch.device('cpu')

# ── Data locations (competition input and writable working dir) ──
INPUT_DIR = Path('/kaggle/input') / 'automatic-lens-correction'
WORK_DIR = Path('/kaggle') / 'working'

# List every directory under the input root with its direct-child count.
print("\nData directory contents:")
for p in sorted(INPUT_DIR.rglob('*')):
    if not p.is_dir():
        continue
    n_files = sum(1 for _ in p.iterdir())
    print(f"  [DIR]  {p.relative_to(INPUT_DIR)}/ ({n_files} items)")

# Lower-case suffixes treated as images throughout the notebook.
img_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
all_images = [entry for entry in INPUT_DIR.rglob('*') if entry.suffix.lower() in img_exts]
print(f"\nTotal images found: {len(all_images)}")

print("\nSample filenames:")
for img in all_images[:5]:
    print(f"  {img.relative_to(INPUT_DIR)}")

# Report the resolution of the first image (width x height) if it can be decoded.
sample = cv2.imread(str(all_images[0])) if all_images else None
if sample is not None:
    print(f"\nSample resolution: {sample.shape[1]}x{sample.shape[0]}")

## Cell 2 — Auto-Detect Directory Structure

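Run this cell, then **check the printed directories and image counts**. If auto-detection picked the wrong folders, set `DIST_DIR`, `CORR_DIR` and `TEST_DIR` in the manual-override block at the bottom of the cell and re-run it.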

In [None]:
def find_dirs(data_dir, exts=None):
    """Auto-detect the distorted-training, corrected-training and test directories.

    Known layouts are tried first. If no training pair matches, any two top-level
    sub-directories sharing more than 10 image stems are paired. If no test layout
    matches, the first remaining top-level directory containing images is used.

    Args:
        data_dir: Root of the competition data.
        exts: Lower-case image suffixes to recognise. Defaults to the notebook-wide
            ``img_exts`` set.

    Returns:
        ``(dist_dir, corr_dir, test_dir)`` as ``Path`` objects; any entry is ``None``
        when it could not be detected (all three are ``None`` if ``data_dir`` is not
        a directory).
    """
    data_dir = Path(data_dir)
    exts = img_exts if exts is None else exts
    if not data_dir.is_dir():
        return None, None, None

    def _image_stems(d):
        """Stems of the image files directly inside ``d``."""
        return {f.stem for f in d.glob('*') if f.suffix.lower() in exts}

    def _looks_like_target(d):
        """Heuristic: does the folder name suggest ground-truth / corrected images?"""
        name = d.name.lower()
        return any(hint in name for hint in ('corr', 'target', 'gt', 'ground', 'clean'))

    patterns = [
        ('train/distorted', 'train/corrected'),
        ('train/input', 'train/target'),
        ('train_distorted', 'train_corrected'),
        ('distorted', 'corrected'),
        ('input', 'target'),
        ('train_input', 'train_target'),
        ('train/input', 'train/ground_truth'),
    ]

    dist_dir = corr_dir = test_dir = None

    for dd, cd in patterns:
        d, c = data_dir / dd, data_dir / cd
        if d.exists() and c.exists():
            dist_dir, corr_dir = d, c
            break

    # Nested layouts must be tried before their parent 'test': otherwise a 'test/'
    # folder that only holds an 'input/' sub-folder is returned, and the downstream
    # (non-recursive) image listing finds nothing.
    test_patterns = ['test/input', 'test/distorted', 'test', 'test_images', 'test_input']
    for tp in test_patterns:
        t = data_dir / tp
        if t.exists():
            test_dir = t
            break

    subdirs = sorted(d for d in data_dir.iterdir() if d.is_dir())

    # Fallback: pair two sibling folders that share many image stems.
    if dist_dir is None:
        stems = {d: _image_stems(d) for d in subdirs}  # glob each folder once
        for i, d1 in enumerate(subdirs):
            for d2 in subdirs[i + 1:]:
                if len(stems[d1] & stems[d2]) > 10:
                    # Sorted order often puts the target first ('corrected' < 'distorted',
                    # 'gt' < 'input'); orient the pair by folder name when possible.
                    if _looks_like_target(d1) and not _looks_like_target(d2):
                        dist_dir, corr_dir = d2, d1
                    else:
                        dist_dir, corr_dir = d1, d2
                    break
            if dist_dir is not None:
                break

    # Fallback: first top-level folder with images that is not (and does not contain)
    # a training folder. Without the ancestor check, 'train/' would be picked because
    # rglob finds the training images inside it.
    if test_dir is None:
        taken = [p for p in (dist_dir, corr_dir) if p is not None]
        for d in subdirs:
            if any(d == t or d in t.parents for t in taken):
                continue
            if any(f.suffix.lower() in exts for f in d.rglob('*')):
                test_dir = d
                break

    return dist_dir, corr_dir, test_dir

DIST_DIR, CORR_DIR, TEST_DIR = find_dirs(INPUT_DIR)

# Report what was detected (labels padded so the paths line up).
for _label, _found in (("Distorted dir: ", DIST_DIR),
                       ("Corrected dir: ", CORR_DIR),
                       ("Test dir:      ", TEST_DIR)):
    print(f"{_label}{_found}")


def _list_images(root):
    """All image files (by suffix) anywhere under ``root``."""
    return [entry for entry in root.rglob('*') if entry.suffix.lower() in img_exts]


if DIST_DIR:
    dist_imgs = _list_images(DIST_DIR)
    print(f"\nDistorted images: {len(dist_imgs)}")
if CORR_DIR:
    corr_imgs = _list_images(CORR_DIR)
    print(f"Corrected images: {len(corr_imgs)}")
if TEST_DIR:
    test_imgs = _list_images(TEST_DIR)
    print(f"Test images:      {len(test_imgs)}")

# =====================================================
# >>> MANUAL OVERRIDE: Uncomment and edit if auto-detect fails <<<
# DIST_DIR = INPUT_DIR / 'train' / 'distorted'
# CORR_DIR = INPUT_DIR / 'train' / 'corrected'
# TEST_DIR = INPUT_DIR / 'test'
# =====================================================

## Cell 3 — Visualize Training Pairs

In [None]:
def show_pairs(dist_dir, corr_dir, n=4):
    """Plot up to ``n`` distorted/corrected training pairs side by side.

    Pairs are matched by file stem (suffix compared case-insensitively). Distorted
    images without a readable corrected counterpart are skipped with a message instead
    of crashing on ``cv2.cvtColor(None)``.
    """
    corr_by_stem = {f.stem: f for f in Path(corr_dir).iterdir() if f.suffix.lower() in img_exts}
    dist_files = sorted(f for f in Path(dist_dir).iterdir() if f.suffix.lower() in img_exts)

    pairs = []
    for df in dist_files:
        if len(pairs) >= n:
            break
        cf = corr_by_stem.get(df.stem)
        d_img = cv2.imread(str(df))
        c_img = cv2.imread(str(cf)) if cf is not None else None
        if d_img is None or c_img is None:
            print(f"Skipping {df.name}: corrected pair missing or unreadable")
            continue
        pairs.append((df, cf, d_img, c_img))

    if not pairs:
        print("No readable training pairs to show.")
        return

    # squeeze=False keeps ``axes`` 2-D even for a single row.
    fig, axes = plt.subplots(len(pairs), 2, figsize=(12, 4 * len(pairs)), squeeze=False)
    for row, (df, cf, d_img, c_img) in zip(axes, pairs):
        row[0].imshow(cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB))
        row[0].set_title(f'Distorted: {df.name}')
        row[0].axis('off')
        row[1].imshow(cv2.cvtColor(c_img, cv2.COLOR_BGR2RGB))
        row[1].set_title(f'Corrected: {cf.name}')
        row[1].axis('off')
    plt.tight_layout()
    plt.show()


if DIST_DIR and CORR_DIR:
    show_pairs(DIST_DIR, CORR_DIR, n=4)

## Cell 4 — Extract Distortion Parameters from Training Pairs

In [None]:
PARAMS_CSV = WORK_DIR.joinpath('params.csv')  # fitted per-image distortion parameters (written below)

def undistort_image(img, k1, k2, k3, cx, cy):
    """Undistort ``img`` with OpenCV's radial model, crop to the valid region, resize back.

    The camera matrix uses focal length ``max(H, W)`` for both axes and principal point
    ``(cx * W, cy * H)``; tangential coefficients are fixed at 0.

    Args:
        img: Image array of shape (H, W) or (H, W, C).
        k1, k2, k3: Radial distortion coefficients.
        cx, cy: Principal point as fractions of image width / height.

    Returns:
        The undistorted image. When OpenCV reports a non-empty valid ROI, the image is
        cropped to it and resized back to (W, H); otherwise the uncropped result is returned.
    """
    height, width = img.shape[:2]
    focal = max(height, width)
    K = np.array([[focal, 0, cx * width],
                  [0, focal, cy * height],
                  [0, 0, 1]], dtype=np.float64)
    coeffs = np.array([k1, k2, 0, 0, k3], dtype=np.float64)
    # alpha=0: the new camera matrix keeps only valid (non-black) pixels in its ROI.
    K_new, (x0, y0, crop_w, crop_h) = cv2.getOptimalNewCameraMatrix(K, coeffs, (width, height), alpha=0)
    out = cv2.undistort(img, K, coeffs, None, K_new)
    if crop_w <= 0 or crop_h <= 0:
        return out
    cropped = out[y0:y0 + crop_h, x0:x0 + crop_w]
    return cv2.resize(cropped, (width, height), interpolation=cv2.INTER_LINEAR)

def objective(params, distorted, corrected):
    """Mean squared pixel error of ``undistort_image(distorted, *params)`` vs ``corrected``.

    ``params`` is ``(k1, k2, k3, cx, cy)``. Any exception raised while undistorting or
    comparing is turned into a large penalty (1e10) so the optimizer treats that point as
    very bad instead of aborting.
    """
    k1, k2, k3, cx, cy = params
    try:
        candidate = undistort_image(distorted, k1, k2, k3, cx, cy)
        residual = candidate.astype(np.float32) - corrected.astype(np.float32)
        return np.mean(np.square(residual))
    except Exception:
        return 1e10

def extract_single(args):
    """Fit ``(k1, k2, k3, cx, cy)`` for one distorted/corrected training pair.

    Args:
        args: ``(dist_path, corr_path, size)`` as one tuple (convenient for
            ``ProcessPoolExecutor.submit``); both images are downscaled to
            ``size x size`` before fitting.

    Returns:
        ``np.ndarray`` with the 5 fitted parameters, or ``None`` if either image
        could not be read.
    """
    dist_path, corr_path, size = args
    dist_img = cv2.imread(str(dist_path))
    corr_img = cv2.imread(str(corr_path))
    if dist_img is None or corr_img is None:
        return None
    # Fit on float32 images so the undistorted output is not rounded to uint8 levels.
    dist_small = cv2.resize(dist_img, (size, size), interpolation=cv2.INTER_AREA).astype(np.float32)
    corr_small = cv2.resize(corr_img, (size, size), interpolation=cv2.INTER_AREA).astype(np.float32)
    x0 = np.array([0.0, 0.0, 0.0, 0.5, 0.5])
    bounds = [(-1.0, 1.0), (-1.0, 1.0), (-1.0, 1.0), (0.3, 0.7), (0.3, 0.7)]
    # The objective goes through OpenCV's remap, whose sampling coordinates are
    # quantised, so it is piecewise constant at the scale of L-BFGS-B's default
    # finite-difference step (1e-8). The numerical gradient is then ~0 and L-BFGS-B
    # stops at x0. Powell is derivative-free and supports bounds.
    result = minimize(objective, x0, args=(dist_small, corr_small), method='Powell',
                      bounds=bounds, options={'maxfev': 1000, 'xtol': 1e-4, 'ftol': 1e-8})
    return result.x

# Pair each distorted training image with the corrected image sharing its stem
# (first matching extension from img_exts wins).
dist_files = sorted(f for f in DIST_DIR.iterdir() if f.suffix.lower() in img_exts)
pairs = []
for df in dist_files:
    cf = next((CORR_DIR / (df.stem + ext) for ext in img_exts
               if (CORR_DIR / (df.stem + ext)).exists()), None)
    if cf is not None:
        pairs.append((df, cf))

print(f"Matched {len(pairs)} training pairs")

# Extract params in parallel
OPT_SIZE = 256  # images are downscaled to OPT_SIZE x OPT_SIZE before fitting
N_WORKERS = 4

results = []
work_items = [(d, c, OPT_SIZE) for d, c in pairs]

# NOTE(review): extract_single is defined in the notebook's __main__, so this relies on
# worker processes inheriting it (the 'fork' start method) — confirm on non-Linux hosts.
with ProcessPoolExecutor(max_workers=N_WORKERS) as executor:
    futures = {executor.submit(extract_single, w): w[0].stem for w in work_items}
    for f in tqdm(as_completed(futures), total=len(futures), desc="Extracting params"):
        image_id = futures[f]
        try:
            params = f.result()
        except Exception as e:
            print(f"Error {image_id}: {e}")
            continue
        if params is None:
            print(f"Skipped {image_id}: image could not be read")
            continue
        results.append((image_id, *params))

# as_completed yields in completion order; sort so params.csv (and the seeded
# train/val split later built from its row order) is reproducible across runs.
results.sort(key=lambda row: row[0])

# Save CSV
with open(PARAMS_CSV, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['image_id', 'k1', 'k2', 'k3', 'cx', 'cy'])
    writer.writerows(results)

print(f"\nSaved {len(results)} parameter sets to {PARAMS_CSV}")

# Quick validation: re-apply the fitted params at full resolution and report PSNR
# against the corrected image for the first few pairs.
print("\nValidation (PSNR on full-res):")
param_dict = {r[0]: r[1:] for r in results}
psnrs = []
for dp, cp in pairs[:5]:
    if dp.stem not in param_dict:
        continue
    d = cv2.imread(str(dp))
    c = cv2.imread(str(cp))
    if d is None or c is None:
        print(f"  {dp.stem}: skipped (unreadable image)")
        continue
    k1, k2, k3, cx, cy = param_dict[dp.stem]
    u = undistort_image(d, k1, k2, k3, cx, cy)
    if u.shape != c.shape:
        # undistort_image keeps the distorted image's size; align to the target first.
        u = cv2.resize(u, (c.shape[1], c.shape[0]), interpolation=cv2.INTER_LINEAR)
    mse = np.mean((u.astype(np.float32) - c.astype(np.float32)) ** 2)
    psnr = 10 * np.log10(255**2 / max(mse, 1e-10))
    psnrs.append(psnr)
    print(f"  {dp.stem}: PSNR = {psnr:.2f} dB")
if psnrs:
    print(f"  Average: {np.mean(psnrs):.2f} dB (target: >30 dB)")

## Cell 5 — Dataset & Model Definition

In [None]:
class DistortionDataset(Dataset):
    """Distorted training images paired with their fitted distortion parameters.

    Each item is a dict with:
        ``image``     — the distorted image, resized to ``image_size`` x ``image_size``,
                        optionally augmented, ImageNet-normalised, as a CHW tensor;
        ``params``    — float32 tensor ``[k1, k2, k3, cx, cy]`` read from ``params_csv``;
        ``image_id``  — the file stem;
        ``corrected`` — only when ``corrected_dir`` is given and holds a file with the same
                        stem: the corrected image, resized/normalised like ``image`` but
                        never augmented.

    ``samples`` may be reassigned after construction to restrict the dataset to a subset;
    lookups go through ``image_id``, so the path indices stay valid.
    """

    def __init__(self, image_dir, params_csv, image_size=224, augment=True, corrected_dir=None):
        self.image_dir = Path(image_dir)
        self.corrected_dir = Path(corrected_dir) if corrected_dir else None
        self.image_size = image_size
        self.augment = augment

        self.samples = []
        with open(params_csv, 'r', newline='') as f:
            for row in csv.DictReader(f):
                params = np.array([float(row[k]) for k in ('k1', 'k2', 'k3', 'cx', 'cy')],
                                  dtype=np.float32)
                self.samples.append((row['image_id'], params))

        self._image_paths = self._index_images(self.image_dir)
        # Keep only parameter rows whose distorted image exists.
        self.samples = [(i, p) for i, p in self.samples if i in self._image_paths]
        # Resolve corrected targets once instead of probing every extension on each item.
        self._corrected_paths = (self._index_images(self.corrected_dir)
                                 if self.corrected_dir is not None else {})
        self._build_transforms()

    @staticmethod
    def _index_images(directory):
        """Map file stem -> path for the image files directly inside ``directory``."""
        return {f.stem: f for f in directory.iterdir() if f.suffix.lower() in img_exts}

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it as RGB; raise if it cannot be decoded."""
        img = cv2.imread(str(path))
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _build_transforms(self):
        """(Re)build the input and target pipelines for the current ``image_size``."""
        sz = self.image_size
        if self.augment:
            # Photometric / noise / compression augmentation only (no geometric transforms),
            # applied to the input image but not to the corrected target.
            self.transform = A.Compose([
                A.Resize(sz, sz), A.ColorJitter(0.2, 0.2, 0.1, 0.05, p=0.5),
                A.GaussNoise(std_range=(0.01, 0.03), p=0.3),
                A.ImageCompression(quality_range=(70, 95), p=0.3),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2()])
        else:
            self.transform = A.Compose([
                A.Resize(sz, sz),
                A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2()])
        self.target_transform = A.Compose([
            A.Resize(sz, sz),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ToTensorV2()])

    def update_image_size(self, new_size):
        """Switch to a new square input size (used for progressive resizing)."""
        self.image_size = new_size
        self._build_transforms()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_id, params = self.samples[idx]
        img = self._read_rgb(self._image_paths[image_id])
        result = {'image': self.transform(image=img)['image'],
                  'params': torch.from_numpy(params), 'image_id': image_id}
        corr_path = self._corrected_paths.get(image_id)
        if corr_path is not None:
            result['corrected'] = self.target_transform(image=self._read_rgb(corr_path))['image']
        return result


class TestDataset(Dataset):
    """Unlabelled test images from ``image_dir`` (top level only), sorted by path.

    Each item holds the resized, ImageNet-normalised CHW tensor plus the stem, path and
    original height/width, so predictions can later be applied to the full-resolution file.
    """

    def __init__(self, image_dir, image_size=384):
        self.image_dir = Path(image_dir)
        candidates = [entry for entry in self.image_dir.iterdir() if entry.suffix.lower() in img_exts]
        candidates.sort()
        self.image_files = candidates
        resize = A.Resize(image_size, image_size)
        normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.transform = A.Compose([resize, normalize, ToTensorV2()])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        path = self.image_files[idx]
        rgb = cv2.cvtColor(cv2.imread(str(path)), cv2.COLOR_BGR2RGB)
        height, width = rgb.shape[:2]
        tensor = self.transform(image=rgb)['image']
        return dict(image=tensor, image_id=path.stem, image_path=str(path),
                    orig_h=height, orig_w=width)


class DistortionNet(nn.Module):
    """CNN regressor for lens-distortion parameters.

    The forward pass returns a ``(B, 5)`` tensor ``[k1, k2, k3, cx, cy]``. The radial
    coefficients are squashed to (-1, 1) with ``tanh`` and the principal point to (0, 1)
    with ``sigmoid``.

    Args:
        backbone: timm model name, or an already-built ``nn.Module`` that maps images to
            ``(B, num_features)`` features and exposes a ``num_features`` attribute.
        pretrained: Load pretrained weights (only used when ``backbone`` is a name).
    """

    def __init__(self, backbone='efficientnet_b3', pretrained=True):
        super().__init__()
        if isinstance(backbone, nn.Module):
            self.backbone = backbone
        else:
            self.backbone = timm.create_model(backbone, pretrained=pretrained, num_classes=0)
        feat_dim = self.backbone.num_features
        self.head = nn.Sequential(
            nn.Linear(feat_dim, 256), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(256, 5))
        # Zero-init the output layer so an untrained model predicts "no distortion,
        # centred principal point": tanh(0) = 0 for k1..k3 and sigmoid(0) = 0.5 for cx, cy.
        # (A bias of 0.5 before the sigmoid would start at sigmoid(0.5) ~= 0.62, off-centre.)
        nn.init.zeros_(self.head[-1].weight)
        nn.init.zeros_(self.head[-1].bias)

    def forward(self, x):
        """Map a batch of images to ``(B, 5)`` parameters ``[k1, k2, k3, cx, cy]``."""
        features = self.backbone(x)
        raw = self.head(features)
        k = torch.tanh(raw[:, :3])          # radial coefficients in (-1, 1)
        center = torch.sigmoid(raw[:, 3:])  # principal point as fractions of W / H
        return torch.cat([k, center], dim=1)


def differentiable_undistort(image, params, output_size=None):
    """Resample ``image`` through a polynomial radial model, differentiably in ``params``.

    Args:
        image: ``(B, C, H, W)`` tensor.
        params: ``(B, 5)`` tensor ``[k1, k2, k3, cx, cy]``; ``cx``/``cy`` are fractions of
            width/height (0.5 = centre).
        output_size: Optional ``(out_H, out_W)``; defaults to the input's ``(H, W)``.

    Returns:
        ``(B, C, out_H, out_W)`` tensor. An output location at offset ``d`` from the
        centre (in ``grid_sample``'s [-1, 1] coordinates) samples the input at
        ``centre + d * (1 + k1 r^2 + k2 r^4 + k3 r^6)`` with ``r^2 = |d|^2``; samples
        falling outside the input read as 0.
    """
    batch = image.shape[0]
    in_h, in_w = image.shape[2], image.shape[3]
    out_h, out_w = output_size if output_size else (in_h, in_w)
    dev = image.device

    ys = torch.linspace(-1, 1, out_h, device=dev)
    xs = torch.linspace(-1, 1, out_w, device=dev)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    grid_x = grid_x.unsqueeze(0).expand(batch, -1, -1)
    grid_y = grid_y.unsqueeze(0).expand(batch, -1, -1)

    # Each coefficient becomes (B, 1, 1) so it broadcasts over the (B, out_h, out_w) grid.
    k1, k2, k3, cx, cy = (params[:, i:i + 1].unsqueeze(-1) for i in range(5))
    centre_x = cx * 2 - 1
    centre_y = cy * 2 - 1

    off_x = grid_x - centre_x
    off_y = grid_y - centre_y
    rr = off_x**2 + off_y**2
    scale = 1 + k1 * rr + k2 * rr**2 + k3 * rr**3
    sample_at = torch.stack([off_x * scale + centre_x, off_y * scale + centre_y], dim=-1)
    return F.grid_sample(image, sample_at, mode='bilinear', padding_mode='zeros', align_corners=True)


class DistortionLoss(nn.Module):
    """Weighted sum of a parameter-regression loss and an optional pixel loss.

    ``forward(pred, gt, dist_img=None, corr_img=None)`` returns
    ``(total, param_loss, pixel_loss)``:

    - ``param_loss`` is the MSE between predicted and target parameters.
    - ``pixel_loss`` is the L1 distance between ``differentiable_undistort(dist_img, pred)``
      and ``corr_img``, or a zero tensor when either image is ``None``.
    - ``total = param_weight * param_loss + pixel_weight * pixel_loss``.
    """

    def __init__(self, param_weight=1.0, pixel_weight=0.5):
        super().__init__()
        self.pw = param_weight
        self.xw = pixel_weight

    def forward(self, pred, gt, dist_img=None, corr_img=None):
        param_loss = F.mse_loss(pred, gt)
        if dist_img is None or corr_img is None:
            pixel_loss = torch.tensor(0.0, device=pred.device)
        else:
            pixel_loss = F.l1_loss(differentiable_undistort(dist_img, pred), corr_img)
        total = self.pw * param_loss + self.xw * pixel_loss
        return total, param_loss, pixel_loss


print("Dataset & model classes defined.")
# Sanity-check instantiation; the training cell builds a fresh model of its own.
model = DistortionNet(pretrained=True).to(DEVICE)
_n_params = sum(t.numel() for t in model.parameters())
print(f"Model parameters: {_n_params:,}")

## Cell 6 — Train

In [None]:
# ── Hyperparameters ──
EPOCHS = 30
BATCH_SIZE = 16
LR = 3e-3
WEIGHT_DECAY = 0.01
VAL_SPLIT = 0.2
NUM_WORKERS = 2
CKPT_DIR = WORK_DIR / 'checkpoints'
CKPT_DIR.mkdir(exist_ok=True)

# Progressive resizing schedule: {0-based epoch at which to switch: square image side}
SIZE_SCHEDULE = {0: 224, EPOCHS // 3: 384, 2 * EPOCHS // 3: 512}

# ── Dataset & Split ──
full_ds = DistortionDataset(DIST_DIR, PARAMS_CSV, image_size=224, augment=True, corrected_dir=CORR_DIR)
n_val = int(len(full_ds) * VAL_SPLIT)
n_train = len(full_ds) - n_val
train_ds, val_ds = random_split(full_ds, [n_train, n_val], generator=torch.Generator().manual_seed(42))

# Un-augmented copy for validation: both datasets are built from the same CSV and folder,
# so their `samples` lists line up and the split's indices can be reused.
val_ds_noaug = DistortionDataset(DIST_DIR, PARAMS_CSV, image_size=224, augment=False, corrected_dir=CORR_DIR)
val_ds_noaug.samples = [val_ds_noaug.samples[i] for i in val_ds.indices]

print(f"Train: {n_train}, Val: {n_val}")
if n_train < BATCH_SIZE:
    print(f"Warning: fewer training samples ({n_train}) than BATCH_SIZE ({BATCH_SIZE}); "
          f"with drop_last=True no training steps will run.")

# ── Model, Loss, Optimizer ──
model = DistortionNet(pretrained=True).to(DEVICE)
criterion = DistortionLoss(param_weight=1.0, pixel_weight=0.5)
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# The train loader uses drop_last=True, so exactly n_train // BATCH_SIZE optimizer steps
# run per epoch. OneCycleLR needs the real total; overestimating it cuts the final anneal
# short and training ends at a still-high LR. max(..., 1): OneCycleLR rejects 0 steps.
steps_per_epoch = max(n_train // BATCH_SIZE, 1)
scheduler = OneCycleLR(optimizer, max_lr=LR, total_steps=steps_per_epoch * EPOCHS,
                       pct_start=0.3, div_factor=25, final_div_factor=1000)

# ── Training Loop ──
best_val_loss = float('inf')
history = []

for epoch in range(1, EPOCHS + 1):
    # Progressive resize (SIZE_SCHEDULE is keyed by 0-based epoch). train_ds is a Subset
    # of full_ds, so resizing full_ds also resizes the training view.
    if epoch - 1 in SIZE_SCHEDULE:
        sz = SIZE_SCHEDULE[epoch - 1]
        print(f"\n>>> Resize to {sz}x{sz}")
        full_ds.update_image_size(sz)
        val_ds_noaug.update_image_size(sz)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds_noaug, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=True)

    # Train
    model.train()
    t_loss, t_pl, t_xl, nb = 0.0, 0.0, 0.0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
    for batch in pbar:
        imgs = batch['image'].to(DEVICE)
        params = batch['params'].to(DEVICE)
        corr = batch.get('corrected')
        corr = corr.to(DEVICE) if corr is not None else None

        pred = model(imgs)
        loss, lp, lx = criterion(pred, params, imgs, corr)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        t_loss += loss.item(); t_pl += lp.item(); t_xl += lx.item(); nb += 1
        pbar.set_postfix(loss=f"{t_loss/nb:.4f}", param=f"{t_pl/nb:.4f}", pixel=f"{t_xl/nb:.4f}")

    # drop_last=True yields zero batches when the training split is smaller than one batch.
    train_loss = t_loss / max(nb, 1)

    # Validate
    model.eval()
    v_loss, v_pl, errs = 0.0, 0.0, []
    with torch.no_grad():
        for batch in val_loader:
            imgs = batch['image'].to(DEVICE)
            params = batch['params'].to(DEVICE)
            corr = batch.get('corrected')
            corr = corr.to(DEVICE) if corr is not None else None
            pred = model(imgs)
            loss, lp, lx = criterion(pred, params, imgs, corr)
            v_loss += loss.item(); v_pl += lp.item()
            errs.append((pred - params).abs().cpu().numpy())

    n_vb = max(len(val_loader), 1)
    v_loss /= n_vb
    v_pl /= n_vb
    # Mean absolute error per parameter [k1, k2, k3, cx, cy] over the validation split.
    me = np.concatenate(errs).mean(axis=0) if errs else np.zeros(5)
    print(f"Epoch {epoch}: train={train_loss:.4f} val={v_loss:.4f} | "
          f"k1e={me[0]:.4f} k2e={me[1]:.4f} k3e={me[2]:.4f} cxe={me[3]:.4f} cye={me[4]:.4f}")

    history.append({'epoch': epoch, 'train_loss': train_loss, 'val_loss': v_loss,
                    'val_param_loss': v_pl})

    # With an empty validation split v_loss stays 0.0, which would freeze the "best"
    # checkpoint at epoch 1; select on the training loss in that case instead.
    select_loss = v_loss if len(val_loader) > 0 else train_loss
    if select_loss < best_val_loss:
        best_val_loss = select_loss
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                    'val_loss': best_val_loss, 'image_size': full_ds.image_size},
                   CKPT_DIR / 'best_model.pth')
        print(f"  >> Saved best model (val_loss={best_val_loss:.4f})")

print(f"\nDone. Best val_loss: {best_val_loss:.4f}")

## Cell 7 — Training Curves

In [None]:
# Plot per-epoch train vs. validation loss recorded in `history`.
epochs_axis = [h['epoch'] for h in history]
fig, ax = plt.subplots(1, 1, figsize=(10, 4))
for key, label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    ax.plot(epochs_axis, [h[key] for h in history], label=label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Training Curves')
plt.tight_layout()
plt.show()

## Cell 8 — Metrics (Local Scoring)

In [None]:
def edge_similarity(img1, img2, scales=(1.0, 0.5, 0.25)):
    """Mean multi-scale edge F1 score between two images (higher = more similar edges).

    At each scale both images are converted to grayscale (if 3-D) and resized to
    ``img1``'s scaled size. Canny then runs with thresholds at 0.67x / 1.33x of each
    image's median intensity. An edge pixel counts as matched if the other image's edge
    map, dilated by a 3x3 kernel, is set there. A scale where neither image has edges
    scores 1.0; a scale where only one does scores 0.0.
    """
    def _to_gray(im):
        return cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) if len(im.shape) == 3 else im

    def _auto_canny(im):
        med = np.median(im)
        return cv2.Canny(im, int(max(0, 0.67 * med)), int(min(255, 1.33 * med)))

    base1, base2 = _to_gray(img1), _to_gray(img2)
    kernel = np.ones((3, 3), np.uint8)
    per_scale = []
    for scale in scales:
        if scale == 1.0:
            a, b = base1, base2
        else:
            size = (int(base1.shape[1] * scale), int(base1.shape[0] * scale))
            a, b = cv2.resize(base1, size), cv2.resize(base2, size)
        edges1, edges2 = _auto_canny(a), _auto_canny(b)
        has1, has2 = edges1.sum() != 0, edges2.sum() != 0
        if not has1 and not has2:
            per_scale.append(1.0)
            continue
        if not (has1 and has2):
            per_scale.append(0.0)
            continue
        grown1 = cv2.dilate(edges1, kernel, iterations=1)
        grown2 = cv2.dilate(edges2, kernel, iterations=1)
        precision = (edges1 & grown2).sum() / max(edges1.sum(), 1)
        recall = (edges2 & grown1).sum() / max(edges2.sum(), 1)
        per_scale.append(2*precision*recall/(precision+recall) if precision + recall > 0 else 0.0)
    return np.mean(per_scale)

def line_straightness(img, ref=None):
    """Score the orientation of line segments found by Canny + probabilistic Hough.

    Without ``ref``: fraction of detected segments within 5 degrees of horizontal or
    vertical. With ``ref``: Bhattacharyya coefficient between the 36-bin orientation
    histograms of ``img`` and ``ref``. Returns 0.5 when ``img`` (or ``ref``, if given)
    yields no segments.
    """
    def _segment_angles(im):
        """Undirected segment angles in degrees, folded into [-90, 90); None if no segments."""
        g = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY) if len(im.shape) == 3 else im
        e = cv2.Canny(g, 50, 150, apertureSize=3)
        segs = cv2.HoughLinesP(e, 1, np.pi/180, 50, minLineLength=30, maxLineGap=10)
        if segs is None:
            return None
        s = segs[:, 0, :].astype(np.float64)
        angles = np.degrees(np.arctan2(s[:, 3] - s[:, 1], s[:, 2] - s[:, 0]))
        # arctan2 returns (-180, 180], which depends on endpoint order. Without folding,
        # segments reported right-to-left fall outside the [-90, 90] histogram (an all-out-
        # of-range histogram gives NaN), and a horizontal segment at 180 deg would not count
        # as axis-aligned.
        return (angles + 90.0) % 180.0 - 90.0

    angles = _segment_angles(img)
    if angles is None:
        return 0.5
    if ref is not None:
        ref_angles = _segment_angles(ref)
        if ref_angles is None:
            return 0.5
        bins = np.linspace(-90, 90, 37)
        h1, _ = np.histogram(angles, bins=bins)
        h2, _ = np.histogram(ref_angles, bins=bins)
        h1 = h1 / (h1.sum() + 1e-10)
        h2 = h2 / (h2.sum() + 1e-10)
        return float(np.sum(np.sqrt(h1 * h2)))
    near_axis = np.minimum(np.abs(angles), np.abs(np.abs(angles) - 90)) < 5
    return float(np.mean(near_axis))

def gradient_orientation_similarity(img1, img2, n_bins=36):
    """Compare the gradient-orientation distributions of two images.

    Each image is converted to float32 grayscale (if 3-D) and Sobel gradients are taken.
    Gradient angles are histogrammed into ``n_bins`` bins over [-pi, pi], weighted by
    gradient magnitude. Returns the Bhattacharyya coefficient of the two normalised
    histograms (1.0 = identical distributions).
    """
    bin_edges = np.linspace(-np.pi, np.pi, n_bins + 1)

    def _weighted_orientation_hist(image):
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32)
        else:
            gray = image.astype(np.float32)
        dx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        dy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
        hist, _ = np.histogram(np.arctan2(dy, dx), bins=bin_edges, weights=np.sqrt(dx**2 + dy**2))
        return hist / (hist.sum() + 1e-10)

    first = _weighted_orientation_hist(img1)
    second = _weighted_orientation_hist(img2)
    return float(np.sum(np.sqrt(first * second)))

def pixel_accuracy(img1, img2):
    """Return ``1 - mean|img1 - img2| / 255`` (1.0 means identical images).

    ``img2`` is first resized to ``img1``'s width/height when the shapes differ.
    """
    if img2.shape != img1.shape:
        img2 = cv2.resize(img2, (img1.shape[1], img1.shape[0]))
    abs_diff = np.abs(img1.astype(np.float32) - img2.astype(np.float32))
    return 1.0 - np.mean(abs_diff) / 255.0

def compute_score(corrected, gt):
    """Weighted local quality score of ``corrected`` against ``gt``.

    ``corrected`` is resized to ``gt``'s size if their shapes differ. The overall score is
    ``0.40*edge + 0.22*line + 0.18*grad + 0.15*ssim + 0.05*pixel``, built from the metric
    helpers defined in this cell.

    Returns:
        ``(overall, metrics)`` where ``metrics`` holds each component plus ``overall``.
    """
    if corrected.shape != gt.shape:
        corrected = cv2.resize(corrected, (gt.shape[1], gt.shape[0]))
    es = edge_similarity(corrected, gt)
    ls = line_straightness(corrected, gt)
    go = gradient_orientation_similarity(corrected, gt)
    if corrected.ndim == 3:
        ss = ssim(corrected, gt, channel_axis=2, data_range=255)
    else:
        ss = ssim(corrected, gt, data_range=255)
    pa = pixel_accuracy(corrected, gt)
    overall = 0.40*es + 0.22*ls + 0.18*go + 0.15*ss + 0.05*pa
    metrics = {'edge': es, 'line': ls, 'grad': go, 'ssim': ss, 'pixel': pa, 'overall': overall}
    return overall, metrics

print("Metrics defined.")

## Cell 9 — Validate on Held-Out Training Pairs (Local Metric Check)

In [None]:
# Score a few held-out validation images using our model + local metrics
ckpt = torch.load(CKPT_DIR / 'best_model.pth', map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f"Loaded best model from epoch {ckpt['epoch']}")


def _image_with_stem(directory, stem):
    """Return the image in ``directory`` whose file stem equals ``stem``, or None."""
    # Compare stems directly rather than glob(f"{stem}.*"): ids containing glob
    # metacharacters such as '[' would not match, and a miss must not raise StopIteration.
    for f in directory.iterdir():
        if f.stem == stem and f.suffix.lower() in img_exts:
            return f
    return None


N_VIS = min(8, len(val_ds_noaug))
scores = []
if N_VIS == 0:
    print("No validation images to score.")
else:
    val_loader_vis = DataLoader(val_ds_noaug, batch_size=1, shuffle=False)
    # squeeze=False keeps ``axes`` 2-D even when N_VIS == 1.
    fig, axes = plt.subplots(N_VIS, 3, figsize=(15, 4 * N_VIS), squeeze=False)

    for i, batch in enumerate(val_loader_vis):
        if i >= N_VIS:
            break
        image_id = batch['image_id'][0]

        # Predict params
        with torch.no_grad():
            pred_params = model(batch['image'].to(DEVICE)).cpu().numpy()[0]
        k1, k2, k3, cx, cy = pred_params

        # Load full-res images
        dist_path = _image_with_stem(DIST_DIR, image_id)
        corr_path = _image_with_stem(CORR_DIR, image_id)
        dist_img = cv2.imread(str(dist_path)) if dist_path is not None else None
        corr_img = cv2.imread(str(corr_path)) if corr_path is not None else None
        if dist_img is None or corr_img is None:
            print(f"Skipping {image_id}: distorted/corrected image missing or unreadable")
            for ax in axes[i]:
                ax.axis('off')
            continue

        # Undistort at full resolution and score against the ground truth
        pred_corr = undistort_image(dist_img, k1, k2, k3, cx, cy)
        overall, m = compute_score(pred_corr, corr_img)
        scores.append(overall)

        # Visualize
        axes[i][0].imshow(cv2.cvtColor(dist_img, cv2.COLOR_BGR2RGB))
        axes[i][0].set_title('Distorted'); axes[i][0].axis('off')
        axes[i][1].imshow(cv2.cvtColor(pred_corr, cv2.COLOR_BGR2RGB))
        axes[i][1].set_title(f'Our correction (score={overall:.3f})'); axes[i][1].axis('off')
        axes[i][2].imshow(cv2.cvtColor(corr_img, cv2.COLOR_BGR2RGB))
        axes[i][2].set_title('Ground truth'); axes[i][2].axis('off')

    plt.tight_layout(); plt.show()

if scores:
    print(f"\nAvg local score on {len(scores)} val images: {np.mean(scores):.4f}")

## Cell 10 — Predict Test Images + Test-Time Optimization

In [None]:
OUTPUT_DIR = WORK_DIR / 'output'
OUTPUT_DIR.mkdir(exist_ok=True)

# Load best model
ckpt = torch.load(CKPT_DIR / 'best_model.pth', map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()


def _inference_image_size(checkpoint, default=384):
    """Input side length for inference: the resolution the checkpoint was trained at.

    Uses ``checkpoint['image_size']`` when present. Otherwise it reconstructs the size
    from the progressive-resize ``SIZE_SCHEDULE`` (keyed by 0-based epoch) and the
    checkpoint's 1-based ``epoch``. Falls back to ``default`` when neither is available,
    e.g. after a kernel restart.
    """
    if 'image_size' in checkpoint:
        return int(checkpoint['image_size'])
    schedule = globals().get('SIZE_SCHEDULE')
    epoch = checkpoint.get('epoch')
    if not schedule or epoch is None:
        return default
    started = [e for e in schedule if e <= epoch - 1]
    return schedule[max(started)] if started else default


# Progressive resizing trains up to 512 px, so a fixed 384 px input can mismatch the
# resolution the selected checkpoint was last trained at.
TEST_IMAGE_SIZE = _inference_image_size(ckpt)
test_ds = TestDataset(TEST_DIR, image_size=TEST_IMAGE_SIZE)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=True)
print(f"Test images: {len(test_ds)} (inference size {TEST_IMAGE_SIZE}x{TEST_IMAGE_SIZE})")

# ── Stage 1: CNN predictions ──
# One entry per test image: raw predicted params plus what is needed to reload the file.
predictions = {}
with torch.no_grad():
    for batch in tqdm(test_loader, desc="CNN prediction"):
        batch_params = model(batch['image'].to(DEVICE)).cpu().numpy()
        per_image = zip(batch['image_id'], batch['image_path'], batch['orig_h'], batch['orig_w'])
        for row, (img_id, path, h, w) in enumerate(per_image):
            predictions[img_id] = {
                'params': batch_params[row],
                'image_path': path,
                'orig_h': h.item(),
                'orig_w': w.item(),
            }

# ── Stage 2: Test-Time Optimization ──
TTO_STEPS = 50   # Adam steps per test image
TTO_LR = 0.001


def tto_loss(undistorted):
    """Self-supervised loss: edge sharpness + border coverage.

    Args:
        undistorted: ``(1, 3, H, W)`` tensor. The leading ``squeeze(0)`` assumes batch
            size 1; values are presumably in [0, 1] (the border term clamps |x| to 1).

    Returns:
        Scalar ``0.5 * edge_term + 0.5 * border_term``.
        - ``edge_term`` is the negated mean Sobel gradient magnitude of the luma image,
          so more edge energy gives a lower loss.
        - ``border_term`` is the mean of ``1 - clamp(|x|, 0, 1)`` over border bands of
          width ``max(2, min(H, W) // 20)``. Dark / zero-padded borders are penalised.
    """
    frame = undistorted.squeeze(0)
    luma = (0.299*frame[0] + 0.587*frame[1] + 0.114*frame[2])[None, None]
    kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=frame.device)
    sobel_x = kernel_x.reshape(1, 1, 3, 3)
    sobel_y = kernel_x.t().reshape(1, 1, 3, 3)  # transpose of Sobel-x is Sobel-y
    grad_x = F.conv2d(luma, sobel_x, padding=1)
    grad_y = F.conv2d(luma, sobel_y, padding=1)
    edge_term = -torch.sqrt(grad_x**2 + grad_y**2 + 1e-6).mean()

    _, height, width = frame.shape
    band = max(2, min(height, width) // 20)
    strips = (frame[:, :band, :], frame[:, -band:, :], frame[:, :, :band], frame[:, :, -band:])
    border_px = torch.cat([strip.reshape(-1) for strip in strips])
    border_term = (1.0 - border_px.abs().clamp(0, 1)).mean()
    return 0.5*edge_term + 0.5*border_term

def diff_undistort_single(image, params):
    """Differentiable radial undistortion driven by the first parameter row.

    ``params`` is read as ``params[0, 0..4] = k1, k2, k3, cx, cy`` (``cx``/``cy`` as
    fractions of width/height). The sampling grid is built at ``image``'s spatial size with
    a batch dimension of 1, so ``image`` must be ``(1, C, H, W)`` for ``grid_sample``.
    Samples outside the image read as 0.
    """
    _, _, height, width = image.shape
    row = params[0]
    k1, k2, k3 = row[0], row[1], row[2]
    centre_x = row[3] * 2 - 1
    centre_y = row[4] * 2 - 1

    ys = torch.linspace(-1, 1, height, device=image.device)
    xs = torch.linspace(-1, 1, width, device=image.device)
    mesh_y, mesh_x = torch.meshgrid(ys, xs, indexing='ij')
    off_x = mesh_x - centre_x
    off_y = mesh_y - centre_y
    rr = off_x**2 + off_y**2
    scale = 1 + k1*rr + k2*rr**2 + k3*rr**3
    sample_at = torch.stack([off_x*scale + centre_x, off_y*scale + centre_y], dim=-1)[None]
    return F.grid_sample(image, sample_at, mode='bilinear', padding_mode='zeros', align_corners=True)

print(f"\nRunning TTO ({TTO_STEPS} steps per image)...")
for img_id in tqdm(predictions, desc="TTO"):
    pred = predictions[img_id]
    bgr = cv2.imread(pred['image_path'])
    if bgr is None:
        print(f"TTO skipped for {img_id}: could not read {pred['image_path']}")
        continue
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    img_small = cv2.resize(img, (256, 256))
    img_t = torch.from_numpy(img_small).float().permute(2, 0, 1).unsqueeze(0).to(DEVICE) / 255.0

    # Prior: keep the refined params close to the CNN prediction.
    init_p = torch.tensor(pred['params'], dtype=torch.float32, device=DEVICE)
    p = init_p.unsqueeze(0).clone().requires_grad_(True)
    opt = torch.optim.Adam([p], lr=TTO_LR)

    best_loss, best_p = float('inf'), pred['params'].copy()
    for _ in range(TTO_STEPS):
        # Project onto the allowed box before evaluating, so the loss below is computed
        # for exactly the parameters that may be recorded as best.
        with torch.no_grad():
            p.data[:, :3].clamp_(-1, 1)
            p.data[:, 3:].clamp_(0.1, 0.9)
        opt.zero_grad()
        undist = diff_undistort_single(img_t, p)
        loss = tto_loss(undist) + 0.1 * F.mse_loss(p.squeeze(0), init_p)
        loss_val = loss.item()
        # Record p *before* opt.step() moves it: recording after the step would store
        # parameters that did not produce this loss and may lie outside the clamp box.
        if loss_val < best_loss:
            best_loss = loss_val
            best_p = p.detach().squeeze(0).cpu().numpy().copy()
        loss.backward()
        opt.step()

    pred['params_tto'] = best_p

print("TTO complete.")

## Cell 11 — Apply Corrections & Create Submission ZIP

In [None]:
# Apply the (TTO-refined when available) parameters at full resolution and save.
written = []
for img_id, pred in tqdm(predictions.items(), desc="Saving corrected images"):
    img = cv2.imread(pred['image_path'])
    if img is None:
        print(f"Skipping {img_id}: could not read {pred['image_path']}")
        continue
    params = pred.get('params_tto', pred['params'])
    k1, k2, k3, cx, cy = params
    corrected = undistort_image(img, k1, k2, k3, cx, cy)
    out_path = OUTPUT_DIR / f"{img_id}.jpg"
    if cv2.imwrite(str(out_path), corrected, [cv2.IMWRITE_JPEG_QUALITY, 95]):
        written.append(out_path)
    else:
        print(f"Failed to write {out_path}")

# Create ZIP from this run's outputs only: OUTPUT_DIR may still contain .jpg files
# from earlier runs, which must not leak into the submission.
ZIP_PATH = WORK_DIR / 'submission.zip'
with zipfile.ZipFile(ZIP_PATH, 'w', zipfile.ZIP_DEFLATED) as zf:
    for img_file in sorted(written):
        zf.write(img_file, img_file.name)

n_output = len(written)
zip_mb = ZIP_PATH.stat().st_size / 1024 / 1024
print(f"\nSaved {n_output} corrected images")
print(f"Submission ZIP: {ZIP_PATH} ({zip_mb:.1f} MB)")
print("\nNext steps:")
print("  1. Download submission.zip from the Output tab (right sidebar)")
print("  2. Upload to https://bounty.autohdr.com")
print("  3. Download the scoring CSV")
print("  4. Submit CSV to Kaggle")

## Cell 12 — (Optional) Visualize Test Corrections

In [None]:
# Show up to 6 corrected test images next to their distorted inputs. Only files that
# belong to the current `predictions` are used (OUTPUT_DIR may hold stale outputs).
test_files = [f for f in sorted(OUTPUT_DIR.glob('*.jpg')) if f.stem in predictions][:6]
if not test_files:
    print("No corrected test images to show.")
else:
    # squeeze=False keeps ``axes`` 2-D even for a single row.
    fig, axes = plt.subplots(len(test_files), 2, figsize=(12, 4 * len(test_files)), squeeze=False)
    for i, cf in enumerate(test_files):
        img_id = cf.stem
        orig = cv2.cvtColor(cv2.imread(predictions[img_id]['image_path']), cv2.COLOR_BGR2RGB)
        corr = cv2.cvtColor(cv2.imread(str(cf)), cv2.COLOR_BGR2RGB)
        axes[i][0].imshow(orig); axes[i][0].set_title(f'Distorted: {img_id}'); axes[i][0].axis('off')
        axes[i][1].imshow(corr); axes[i][1].set_title('Corrected'); axes[i][1].axis('off')
    plt.tight_layout(); plt.show()