# Automatic Lens Correction

**Approach:** Parametric distortion prediction (k1, k2, k3, cx, cy) using a Swin-T backbone (`torchvision.models.swin_t`) + differentiable undistortion + test-time optimization.

**Works on:** Kaggle (offline), Google Colab, Paperspace, Lambda, any GPU box. Run Cell 0 first — it auto-detects your environment and sets everything up.

## Cell 0 — Environment Setup (run this first)

In [None]:
# Cell 0: detect the hosting environment, install missing packages, define
# INPUT_DIR / WORK_DIR and (outside Kaggle) download the competition data.
import os
import subprocess
import sys
from pathlib import Path

# ── Detect environment ──
if os.path.exists("/kaggle/input"):
    ENV = "kaggle"
elif "COLAB_RELEASE_TAG" in os.environ or os.path.exists("/content"):
    ENV = "colab"
elif os.path.exists("/notebooks"):  # Paperspace Gradient
    ENV = "paperspace"
else:
    ENV = "generic"
print(f"Detected environment: {ENV}")

# ── Install missing deps (Kaggle has everything, others may not) ──
if ENV != "kaggle":
    print("Installing dependencies...")
    subprocess.check_call([
        sys.executable, "-m", "pip", "install", "-q",
        "torch", "torchvision", "opencv-python-headless", "scipy",
        "scikit-image", "tqdm", "pandas", "Pillow", "matplotlib", "kaggle",
    ])

# ── Set up paths ──
if ENV == "kaggle":
    INPUT_DIR = Path("/kaggle/input/automatic-lens-correction")
    WORK_DIR = Path("/kaggle/working")
elif ENV == "colab":
    WORK_DIR = Path("/content/autohdr")
    INPUT_DIR = WORK_DIR / "data"
elif ENV == "paperspace":
    WORK_DIR = Path("/notebooks/autohdr")
    INPUT_DIR = WORK_DIR / "data"
else:
    WORK_DIR = Path(".")
    INPUT_DIR = WORK_DIR / "data"
WORK_DIR.mkdir(parents=True, exist_ok=True)
INPUT_DIR.mkdir(parents=True, exist_ok=True)

# ── Download data if not on Kaggle ──
if ENV != "kaggle":
    # Empty marker file written after a successful download + extraction,
    # so re-running the cell does not download again.
    marker = INPUT_DIR / ".downloaded"
    if not marker.exists():
        print("\nDownloading competition data...")
        print("Make sure ~/.kaggle/kaggle.json exists with your API credentials.")
        if ENV == "colab":
            # Colab: prompt for upload if missing
            kaggle_json = Path.home() / ".kaggle" / "kaggle.json"
            if not kaggle_json.exists():
                print("Upload your kaggle.json:")
                from google.colab import files
                uploaded = files.upload()
                if not uploaded:
                    # Previously a cancelled upload surfaced as a confusing
                    # FileNotFoundError from chmod(); fail with a clear message.
                    raise RuntimeError(
                        "No file was uploaded; kaggle.json is required to download the data.")
                # Prefer the file actually named kaggle.json; otherwise fall
                # back to the last uploaded file (what the old loop ended up writing).
                content = uploaded.get("kaggle.json", list(uploaded.values())[-1])
                kaggle_json.parent.mkdir(parents=True, exist_ok=True)
                kaggle_json.write_bytes(content)
                kaggle_json.chmod(0o600)  # owner-only permissions on the credentials file

        subprocess.check_call(["kaggle", "competitions", "download",
                               "-c", "automatic-lens-correction",
                               "-p", str(INPUT_DIR)])
        # Unzip all zip files, deleting each archive once extracted
        import zipfile
        for zf in INPUT_DIR.glob("*.zip"):
            print(f"Extracting {zf.name}...")
            with zipfile.ZipFile(zf, 'r') as z:
                z.extractall(INPUT_DIR)
            zf.unlink()
        marker.touch()
        print("Data ready.")
    else:
        print("Data already downloaded.")

print(f"\nINPUT_DIR: {INPUT_DIR}")
print(f"WORK_DIR:  {WORK_DIR}")

## Cell 1 — Imports & Dependency Fix

In [None]:
# Fix numpy/skimage compatibility issue
import subprocess
import sys

try:
    from skimage.metrics import structural_similarity as ssim
except (ImportError, ValueError) as e:
    # A binary mismatch between numpy and scikit-image can surface either as
    # ValueError ("numpy.dtype size changed ...") or as ImportError, and a
    # missing package is an ImportError too; handle all of them the same way.
    print(f"Skimage import error: {e}")
    print("Installing compatible versions...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                           "--upgrade", "numpy<2.0", "scikit-image"])
    from skimage.metrics import structural_similarity as ssim

import csv
import zipfile
from concurrent.futures import ProcessPoolExecutor, as_completed

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from scipy.optimize import minimize
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import models, transforms as T
from tqdm import tqdm

# Use the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {DEVICE}")

## Cell 2 — Discover Data Layout

The data uses `_original.jpg` / `_generated.jpg` pairs in the same folder.

In [None]:
img_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}


def _list_images(folder):
    """Return the entries directly inside *folder* whose extension is in img_exts, sorted."""
    return sorted(f for f in folder.iterdir() if f.suffix.lower() in img_exts)


def discover_data(data_dir, max_listed=20):
    """Find train pairs and test images based on actual data structure.

    Args:
        data_dir: Directory expected to contain ``lens-correction-train-cleaned``
            and/or ``test-originals``.
        max_listed: Maximum number of image paths printed in the diagnostic
            listing when nothing was found.

    Returns:
        ``(train_pairs, train_dir, test_files, test_dir)`` where
        ``train_pairs`` is a list of ``(original_path, generated_path, image_id)``,
        ``test_files`` is a list of ``(path, image_id)``, and each ``*_dir`` is
        the directory that was found or ``None``.
    """
    data_dir = Path(data_dir)
    train_pairs = []  # list of (original_path, generated_path, image_id)
    test_files = []   # list of (path, image_id)
    train_dir = None
    test_dir = None
    orig_suffix, gen_suffix = '_original', '_generated'

    # Training directory: lens-correction-train-cleaned
    train_dir_candidate = data_dir / 'lens-correction-train-cleaned'
    if train_dir_candidate.exists():
        train_dir = train_dir_candidate
        print(f"Found train dir: {train_dir}")

        all_images = _list_images(train_dir)

        # Index the generated images by base name so a pair is found whatever
        # extension the generated file uses (it used to be hard-coded to .jpg).
        generated_by_base = {}
        for f in all_images:
            if f.stem.endswith(gen_suffix):
                base = f.stem[:-len(gen_suffix)]
                # Keep preferring the historical ".jpg" name if several exist.
                if base not in generated_by_base or f.suffix == '.jpg':
                    generated_by_base[base] = f

        # Originals are visited in sorted order, so the pair order is stable.
        for orig in all_images:
            if not orig.stem.endswith(orig_suffix):
                continue
            # orig.stem is like "uuid_g0_original"; the image id is "uuid_g0"
            base = orig.stem[:-len(orig_suffix)]
            gen_path = generated_by_base.get(base)
            if gen_path is not None:
                train_pairs.append((orig, gen_path, base))

        print(f"  Found {len(train_pairs)} training pairs")

    # Test directory: test-originals
    test_dir_candidate = data_dir / 'test-originals'
    if test_dir_candidate.exists():
        test_dir = test_dir_candidate
        print(f"Found test dir: {test_dir}")

        # Test images carry no _original suffix: the whole stem is the id
        # (e.g., "uuid_g0.jpg" -> "uuid_g0").
        test_files = [(img, img.stem) for img in _list_images(test_dir)]

        print(f"  Found {len(test_files)} test images")

    if not train_pairs and not test_files:
        print("Warning: No data found. Data structure:")
        found_images = (p for p in sorted(data_dir.rglob('*'))
                        if p.is_file() and p.suffix.lower() in img_exts)
        # Print at most max_listed paths, then a single truncation marker.
        for n_shown, p in enumerate(found_images):
            if n_shown >= max_listed:
                print("  ... (more files)")
                break
            print(f"  {p.relative_to(data_dir)}")

    return train_pairs, train_dir, test_files, test_dir

# Scan INPUT_DIR once; the results are shared by the following cells.
TRAIN_PAIRS, TRAIN_DIR, TEST_FILES, TEST_DIR = discover_data(INPUT_DIR)

# Lookup tables keyed by image id:
#   PAIR_LOOKUP[image_id] -> (original_path, generated_path)
#   TEST_LOOKUP[image_id] -> test image path
PAIR_LOOKUP = dict((pair_id, (orig, gen)) for orig, gen, pair_id in TRAIN_PAIRS)
TEST_LOOKUP = dict((test_id, path) for path, test_id in TEST_FILES)

print(f"\nSummary:")
print(f"  Train pairs: {len(TRAIN_PAIRS)} | Test images: {len(TEST_FILES)}")
print(f"  PAIR_LOOKUP entries: {len(PAIR_LOOKUP)}")
# Show one example id from each split when available.
if TRAIN_PAIRS:
    first_train = TRAIN_PAIRS[0]
    print(f"  Example train: {first_train[2]}")
if TEST_FILES:
    first_test = TEST_FILES[0]
    print(f"  Example test: {first_test[1]}")


## Cell 3 — Visualize Training Pairs

In [None]:
def show_pairs(n=4):
    """Plot up to *n* training pairs side by side (original left, generated right).

    Only as many rows as there are pairs in TRAIN_PAIRS are drawn; nothing is
    plotted when there are none. A pair whose images cannot be read is shown
    as an empty, labelled row instead of aborting the whole figure.
    """
    n_rows = min(n, len(TRAIN_PAIRS))
    if n_rows <= 0:
        print("No training pairs to display.")
        return

    # squeeze=False always returns a 2-D axes array, so a single row needs no special case.
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 4 * n_rows), squeeze=False)
    for (ax_orig, ax_gen), (orig_path, gen_path, img_id) in zip(axes, TRAIN_PAIRS):
        ax_orig.axis('off')
        ax_gen.axis('off')
        d_bgr = cv2.imread(str(orig_path))
        c_bgr = cv2.imread(str(gen_path))
        if d_bgr is None or c_bgr is None:
            # cv2.imread returns None for unreadable files; cvtColor would raise on it.
            ax_orig.set_title(f'Unreadable pair: {img_id}')
            continue
        # OpenCV loads BGR; matplotlib expects RGB.
        ax_orig.imshow(cv2.cvtColor(d_bgr, cv2.COLOR_BGR2RGB))
        ax_orig.set_title(f'Original (distorted): {img_id}')
        ax_gen.imshow(cv2.cvtColor(c_bgr, cv2.COLOR_BGR2RGB))
        ax_gen.set_title('Generated (corrected)')
    plt.tight_layout()
    plt.show()


if TRAIN_PAIRS:
    show_pairs(4)

## Cell 4 — Extract Distortion Parameters from Training Pairs

In [None]:
# Set to True to force recompute all params from scratch
# (when False, rows already present in PARAMS_CSV are reused).
FORCE_RECOMPUTE = False

# CSV cache of per-image distortion parameters; columns are
# image_id, k1, k2, k3, cx, cy (written further down in this cell).
PARAMS_CSV = WORK_DIR / 'params.csv'

def undistort_image(img, k1, k2, k3, cx, cy):
    """Undistort *img* with OpenCV and rescale the valid region to the input size.

    Args:
        img: Image array; only ``img.shape[:2]`` (height, width) is inspected here.
        k1, k2, k3: Radial coefficients placed at positions 0, 1 and 4 of the
            5-element coefficient vector; positions 2 and 3 are fixed to 0
            (presumably OpenCV's tangential terms p1, p2 — confirm against the
            calib3d documentation).
        cx, cy: Principal point as fractions of width and height.

    Returns:
        The undistorted image. When the ROI reported by OpenCV has positive
        size it is cropped and resized back to (width, height); otherwise the
        uncropped undistorted image is returned.
    """
    height, width = img.shape[:2]
    # Same focal length on both axes, taken as the longer image side.
    focal = max(height, width)
    camera_matrix = np.array(
        [[focal, 0, cx * width],
         [0, focal, cy * height],
         [0, 0, 1]],
        dtype=np.float64,
    )
    dist_coeffs = np.array([k1, k2, 0, 0, k3], dtype=np.float64)

    new_camera, valid_roi = cv2.getOptimalNewCameraMatrix(
        camera_matrix, dist_coeffs, (width, height), alpha=0)
    undistorted = cv2.undistort(img, camera_matrix, dist_coeffs, None, new_camera)

    x0, y0, roi_w, roi_h = valid_roi
    if roi_w <= 0 or roi_h <= 0:
        # Degenerate ROI: nothing sensible to crop to.
        return undistorted
    cropped = undistorted[y0:y0 + roi_h, x0:x0 + roi_w]
    return cv2.resize(cropped, (width, height), interpolation=cv2.INTER_LINEAR)

def objective(params, distorted, corrected):
    """Return the mean squared error between the undistorted input and the target.

    Args:
        params: Sequence unpacked into ``undistort_image`` as (k1, k2, k3, cx, cy).
        distorted: Image to undistort.
        corrected: Reference image the result is compared against.

    Returns:
        The MSE computed in float32, or ``1e10`` if anything raises, so that
        the optimiser treats a failing parameter set as very bad instead of aborting.
    """
    try:
        candidate = undistort_image(distorted, *params)
        residual = candidate.astype(np.float32) - corrected.astype(np.float32)
        return np.mean(residual ** 2)
    except Exception:
        return 1e10

def extract_single(args):
    """Fit (k1, k2, k3, cx, cy) for one training pair by minimising ``objective``.

    Args:
        args: Tuple ``(distorted_path, corrected_path, size)``; both images are
            resized to ``size x size`` before fitting. A single tuple is used
            so the function can be submitted to an executor with one argument.

    Returns:
        Array of the five fitted parameters, or ``None`` if either image
        cannot be read.
    """
    dp, cp, sz = args
    d = cv2.imread(str(dp))
    c = cv2.imread(str(cp))
    if d is None or c is None:
        return None
    d = cv2.resize(d, (sz, sz), interpolation=cv2.INTER_AREA)
    c = cv2.resize(c, (sz, sz), interpolation=cv2.INTER_AREA)

    # No analytic gradient is supplied, so L-BFGS-B uses forward differences.
    # Its default step (1e-8) moves the resampling grid by far less than the
    # resolution of the interpolated image, which makes the objective look
    # flat and lets the optimiser stop at the initial guess. A step of 1e-2
    # changes the warp by a measurable fraction of a pixel at this image size.
    res = minimize(objective, [0, 0, 0, 0.5, 0.5], args=(d, c), method='L-BFGS-B',
                   bounds=[(-1, 1), (-1, 1), (-1, 1), (0.3, 0.7), (0.3, 0.7)],
                   options={'maxiter': 200, 'ftol': 1e-8, 'eps': 1e-2})
    return res.x

# Load existing cache if not forcing recompute
cached_params = {}
if not FORCE_RECOMPUTE and PARAMS_CSV.exists():
    try:
        with open(PARAMS_CSV, 'r') as f:
            for row in csv.DictReader(f):
                cached_params[row['image_id']] = [
                    float(row[col]) for col in ('k1', 'k2', 'k3', 'cx', 'cy')
                ]
        print(f"Loaded {len(cached_params)} cached params from {PARAMS_CSV}")
    except Exception as e:
        # Any unreadable / malformed cache is discarded as a whole.
        print(f"Error loading cache: {e}, will recompute all")
        cached_params = {}

# Find which pairs need computation. Lists (not sets) keep TRAIN_PAIRS order,
# so the CSV written below is identical from run to run.
train_ids = [img_id for _, _, img_id in TRAIN_PAIRS]
ids_to_compute = [img_id for img_id in train_ids if img_id not in cached_params]

if not ids_to_compute:
    print(f"All {len(train_ids)} pairs already cached. Set FORCE_RECOMPUTE = True to recompute.")
    results = [(img_id, *cached_params[img_id]) for img_id in train_ids]
else:
    print(f"Extracting params for {len(ids_to_compute)} new pairs (cached: {len(cached_params)})...")

    computed = {}  # image_id -> freshly fitted parameters (successful fits only)
    with ProcessPoolExecutor(max_workers=4) as executor:
        # Pairs are fitted at 256x256.
        futures = {
            executor.submit(extract_single, (orig_path, gen_path, 256)): img_id
            for orig_path, gen_path, img_id in TRAIN_PAIRS
            if img_id not in cached_params
        }

        for fut in tqdm(as_completed(futures), total=len(futures), desc="Extracting params"):
            img_id = futures[fut]
            try:
                p = fut.result()
            except Exception as e:
                print(f"Error {img_id}: {e}")
                continue
            if p is not None:  # None means an image of the pair was unreadable
                computed[img_id] = [float(v) for v in p]

    # Merge cached and new values, emitting rows in TRAIN_PAIRS order rather
    # than in worker-completion order.
    known = {**cached_params, **computed}
    results = [(img_id, *known[img_id]) for img_id in train_ids if img_id in known]

    # Report the fits that actually succeeded, not the number of submitted jobs.
    print(f"Total params: {len(results)} (newly computed: {len(computed)})")

# Save all results back to CSV (ensures cache is complete)
with open(PARAMS_CSV, 'w', newline='') as f:
    w = csv.writer(f)
    w.writerow(['image_id', 'k1', 'k2', 'k3', 'cx', 'cy'])
    w.writerows(results)

print(f"Saved {len(results)} params to {PARAMS_CSV}")

# Quick sanity check: undistort the first few training samples with their
# fitted parameters and report the PSNR against the generated image.
print("\nSanity check (sample undistortion):")
psnrs = []
for orig_path, gen_path, img_id in TRAIN_PAIRS[:5]:
    distorted = cv2.imread(str(orig_path))
    target = cv2.imread(str(gen_path))
    if distorted is None or target is None:
        continue
    # First row in `results` whose id matches; rows are (image_id, k1, k2, k3, cx, cy).
    params = next((row[1:] for row in results if row[0] == img_id), None)
    if params is None:
        continue
    target_h, target_w = target.shape[:2]
    restored = cv2.resize(undistort_image(distorted, *params), (target_w, target_h))
    residual = restored.astype(np.float32) - target.astype(np.float32)
    mse = np.mean(residual ** 2)
    # The floor on mse avoids log10(inf) for identical images.
    psnr = 10 * np.log10(255**2 / max(mse, 1e-10))
    psnrs.append(psnr)
    print(f"  {img_id}: {psnr:.2f} dB")
if psnrs:
    print(f"  Average: {np.mean(psnrs):.2f} dB")


## Cell 5 — Dataset & Model Definition

Uses `torchvision.models.swin_t` (Swin Transformer) - better for spatial/geometric understanding than EfficientNet.

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class DistortionDataset(Dataset):
    """Loads (distorted_image, params) pairs. Uses PAIR_LOOKUP for corrected images.

    Samples are kept in the row order of ``params_csv``; rows whose image id is
    not in ``pair_lookup`` are ignored.
    """

    def __init__(self, pair_lookup, params_csv, image_size=224, augment=True):
        self.pair_lookup = pair_lookup  # {image_id: (original_path, generated_path)}
        self.image_size = image_size
        self.augment = augment
        self.samples = []  # list of (image_id, float32 array [k1, k2, k3, cx, cy])
        with open(params_csv, 'r') as f:
            for row in csv.DictReader(f):
                img_id = row['image_id']
                if img_id in self.pair_lookup:
                    p = np.array([float(row[k]) for k in ['k1', 'k2', 'k3', 'cx', 'cy']],
                                 dtype=np.float32)
                    self.samples.append((img_id, p))
        self._build_transforms()

    def _build_transforms(self):
        """(Re)create the input and target transforms for the current image size."""
        sz = self.image_size
        if self.augment:
            # Colour jitter is applied to the input image only.
            self.transform = T.Compose([
                T.Resize((sz, sz)),
                T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
                T.ToTensor(),
                T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])
        else:
            self.transform = T.Compose([
                T.Resize((sz, sz)), T.ToTensor(),
                T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])
        self.target_transform = T.Compose([
            T.Resize((sz, sz)), T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])

    def update_image_size(self, new_size):
        """Switch to a new square resolution (used for progressive resizing)."""
        self.image_size = new_size
        self._build_transforms()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_id, params = self.samples[idx]
        orig_path, gen_path = self.pair_lookup[image_id]
        img = Image.open(orig_path).convert('RGB')
        result = {'image': self.transform(img),
                  'params': torch.from_numpy(params),
                  'image_id': image_id}
        # 'corrected' is only present when the generated file exists on disk.
        if gen_path.exists():
            result['corrected'] = self.target_transform(Image.open(gen_path).convert('RGB'))
        return result


class TestDataset(Dataset):
    """Loads test images. Uses TEST_FILES list of (path, image_id)."""

    def __init__(self, test_files, image_size=384):
        self.test_files = test_files  # list of (path, image_id)
        self.transform = T.Compose([
            T.Resize((image_size, image_size)), T.ToTensor(),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])

    def __len__(self):
        return len(self.test_files)

    def __getitem__(self, idx):
        p, img_id = self.test_files[idx]
        img = Image.open(p).convert('RGB')
        w, h = img.size  # PIL reports (width, height)
        return {'image': self.transform(img), 'image_id': img_id,
                'image_path': str(p), 'orig_h': h, 'orig_w': w}


class DistortionNet(nn.Module):
    """Swin-T backbone with a small MLP head predicting (k1, k2, k3, cx, cy).

    The forward pass squashes k1..k3 with tanh (range (-1, 1)) and cx, cy with
    sigmoid (range (0, 1)).
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Swin Transformer - better for spatial/geometric tasks than EfficientNet
        # Uses hierarchical attention to capture multi-scale edges and lines
        weights = models.Swin_T_Weights.DEFAULT if pretrained else None
        backbone = models.swin_t(weights=weights)
        feat_dim = backbone.head.in_features  # 768 for swin_t
        backbone.head = nn.Identity()  # Remove classification head
        self.backbone = backbone
        self.head = nn.Sequential(
            nn.Linear(feat_dim, 256), nn.ReLU(inplace=True),
            nn.Dropout(0.2), nn.Linear(256, 5))
        # Zero-initialise the last layer so the untrained network outputs the
        # identity correction: tanh(0) = 0 for k1..k3 and sigmoid(0) = 0.5 for
        # cx, cy. (Setting the cx/cy biases to 0.5, as before, produced
        # sigmoid(0.5) ~= 0.62 instead of the intended image centre.)
        nn.init.zeros_(self.head[-1].weight)
        nn.init.zeros_(self.head[-1].bias)

    def forward(self, x):
        feat = self.backbone(x)
        p = self.head(feat)
        return torch.cat([torch.tanh(p[:, :3]), torch.sigmoid(p[:, 3:])], dim=1)


def differentiable_undistort(image, params):
    """Resample a batch of images through a radial model using ``F.grid_sample``.

    Args:
        image: Tensor of shape (B, C, H, W).
        params: Tensor of shape (B, 5) holding k1, k2, k3, cx, cy per sample,
            with cx, cy as fractions in [0, 1].

    Returns:
        Tensor of shape (B, C, H, W); samples falling outside the input are zero.

    NOTE(review): radii here are measured in grid_sample's [-1, 1] coordinates,
    whereas ``undistort_image`` normalises by a focal length of max(h, w) and
    crops to the valid ROI, so the same coefficients do not appear to describe
    the same warp in both functions — confirm before relying on the pixel loss.
    """
    B, C, H, W = image.shape
    k1, k2, k3 = params[:, 0:1], params[:, 1:2], params[:, 2:3]
    cx, cy = params[:, 3:4], params[:, 4:5]
    gy, gx = torch.meshgrid(
        torch.linspace(-1, 1, H, device=image.device),
        torch.linspace(-1, 1, W, device=image.device), indexing='ij')
    gx = gx.unsqueeze(0).expand(B, -1, -1)
    gy = gy.unsqueeze(0).expand(B, -1, -1)
    # Map the [0, 1] centre to the [-1, 1] grid and make it broadcast over (H, W).
    cx_n = (cx * 2 - 1).unsqueeze(-1)
    cy_n = (cy * 2 - 1).unsqueeze(-1)
    dx, dy = gx - cx_n, gy - cy_n
    r2 = dx**2 + dy**2
    k1, k2, k3 = k1.unsqueeze(-1), k2.unsqueeze(-1), k3.unsqueeze(-1)
    radial = 1 + k1 * r2 + k2 * r2**2 + k3 * r2**3
    grid = torch.stack([dx * radial + cx_n, dy * radial + cy_n], dim=-1)
    return F.grid_sample(image, grid, mode='bilinear', padding_mode='zeros', align_corners=True)


class DistortionLoss(nn.Module):
    """Weighted sum of a parameter MSE and an optional pixel L1 term."""

    def __init__(self, param_weight=1.0, pixel_weight=0.5):
        super().__init__()
        self.pw, self.xw = param_weight, pixel_weight

    def forward(self, pred, gt, dist_img=None, corr_img=None):
        """Return ``(total, param_loss, pixel_loss)``; pixel_loss is 0 without both images."""
        lp = F.mse_loss(pred, gt)
        lx = torch.tensor(0.0, device=pred.device)
        if dist_img is not None and corr_img is not None:
            lx = F.l1_loss(differentiable_undistort(dist_img, pred), corr_img)
        return self.pw * lp + self.xw * lx, lp, lx


print("Classes defined.")
model = DistortionNet(pretrained=True).to(DEVICE)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## Cell 6 — Train

In [None]:
# ── Hyper-parameters and checkpoint location ──
EPOCHS = 30
BATCH_SIZE = 16
LR = 3e-3
NUM_WORKERS = 2
CKPT_DIR = WORK_DIR / 'checkpoints'
CKPT_DIR.mkdir(exist_ok=True)
# Progressive resizing: key is the zero-based epoch index at which the
# square input resolution switches to the given value.
SIZE_SCHEDULE = {0: 224, EPOCHS//3: 384, 2*EPOCHS//3: 512}

# ── 80/20 split with a fixed seed ──
full_ds = DistortionDataset(PAIR_LOOKUP, PARAMS_CSV, image_size=224, augment=True)
n_val = int(len(full_ds) * 0.2)
n_train = len(full_ds) - n_val
split_rng = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(full_ds, [n_train, n_val], generator=split_rng)
# Non-augmented copy of the dataset restricted to the held-out indices.
val_ds_noaug = DistortionDataset(PAIR_LOOKUP, PARAMS_CSV, image_size=224, augment=False)
val_ds_noaug.samples = [val_ds_noaug.samples[i] for i in val_ds.indices]
print(f"Train: {n_train}, Val: {n_val}")

# ── Model, loss, optimiser, LR schedule ──
model = DistortionNet(pretrained=True).to(DEVICE)
criterion = DistortionLoss(param_weight=1.0, pixel_weight=0.5)
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
steps_per_epoch = n_train // BATCH_SIZE + 1
scheduler = OneCycleLR(optimizer, max_lr=LR, total_steps=steps_per_epoch * EPOCHS,
                       pct_start=0.3, div_factor=25, final_div_factor=1000)


def _batch_to_device(batch):
    """Return (images, params, corrected) on DEVICE; corrected is None if absent."""
    corrected = batch.get('corrected')
    if corrected is not None:
        corrected = corrected.to(DEVICE)
    return batch['image'].to(DEVICE), batch['params'].to(DEVICE), corrected


best_val_loss = float('inf')
history = []

for epoch in range(1, EPOCHS + 1):
    # Switch resolution at the scheduled epochs (both datasets together).
    new_size = SIZE_SCHEDULE.get(epoch - 1)
    if new_size is not None:
        print(f"\n>>> Resize to {new_size}x{new_size}")
        full_ds.update_image_size(new_size)
        val_ds_noaug.update_image_size(new_size)

    # Loaders are rebuilt every epoch.
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds_noaug, batch_size=BATCH_SIZE, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=True)

    # ── Train ──
    model.train()
    t_loss, t_pl, t_xl, nb = 0, 0, 0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
    for batch in pbar:
        imgs, params, corr = _batch_to_device(batch)
        pred = model(imgs)
        loss, lp, lx = criterion(pred, params, imgs, corr)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # OneCycleLR advances once per optimiser step

        t_loss += loss.item()
        t_pl += lp.item()
        t_xl += lx.item()
        nb += 1
        pbar.set_postfix(loss=f"{t_loss/nb:.4f}", param=f"{t_pl/nb:.4f}", pixel=f"{t_xl/nb:.4f}")

    # ── Validate ──
    model.eval()
    v_loss, errs = 0, []
    with torch.no_grad():
        for batch in val_loader:
            imgs, params, corr = _batch_to_device(batch)
            pred = model(imgs)
            loss, lp, lx = criterion(pred, params, imgs, corr)
            v_loss += loss.item()
            errs.append((pred - params).abs().cpu().numpy())
    n_vb = max(len(val_loader), 1)
    v_loss /= n_vb
    # Mean absolute error per parameter over the whole validation set.
    if errs:
        me = np.concatenate(errs).mean(axis=0)
    else:
        me = np.zeros(5)

    train_loss = t_loss / nb
    print(f"Epoch {epoch}: train={train_loss:.4f} val={v_loss:.4f} | "
          f"k1e={me[0]:.4f} k2e={me[1]:.4f} k3e={me[2]:.4f} cxe={me[3]:.4f} cye={me[4]:.4f}")
    history.append({'epoch': epoch, 'train_loss': train_loss, 'val_loss': v_loss})

    # Keep only the checkpoint with the lowest validation loss.
    if v_loss < best_val_loss:
        best_val_loss = v_loss
        torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                    'val_loss': best_val_loss}, CKPT_DIR / 'best_model.pth')
        print(f"  >> Saved best (val_loss={best_val_loss:.4f})")

print(f"\nDone. Best val_loss: {best_val_loss:.4f}")

## Cell 7 — Training Curves

In [None]:
# Plot train and validation loss per epoch from the `history` records.
fig, ax = plt.subplots(1, 1, figsize=(10, 4))
for curve_key, curve_label in (('train_loss', 'Train'), ('val_loss', 'Val')):
    ax.plot([rec['epoch'] for rec in history],
            [rec[curve_key] for rec in history],
            label=curve_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Training Curves')
plt.tight_layout()
plt.show()

## Cell 8 — Local Scoring Metrics

In [None]:
def edge_similarity(img1, img2, scales=(1.0, 0.5, 0.25)):
    """Mean F1 score between the Canny edge maps of two images over several scales.

    At each scale, precision is the share of img1 edge pixels that hit the
    3x3-dilated edges of img2, recall is the converse. A scale where neither
    image has edges scores 1.0; a scale where only one has edges scores 0.0.
    Three-dimensional inputs are converted from BGR to grayscale first.
    """
    def _to_gray(img):
        return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img

    def _rescale(gray, s):
        if s == 1.0:
            return gray
        return cv2.resize(gray, (int(gray.shape[1] * s), int(gray.shape[0] * s)))

    def _median_canny(gray):
        # Thresholds are taken at 0.67x and 1.33x the median intensity, clipped to [0, 255].
        lo = int(max(0, 0.67 * np.median(gray)))
        hi = int(min(255, 1.33 * np.median(gray)))
        return cv2.Canny(gray, lo, hi)

    gray1, gray2 = _to_gray(img1), _to_gray(img2)
    kernel = np.ones((3, 3), np.uint8)
    scores = []
    for s in scales:
        edges1 = _median_canny(_rescale(gray1, s))
        edges2 = _median_canny(_rescale(gray2, s))
        total1, total2 = edges1.sum(), edges2.sum()
        if total1 == 0 and total2 == 0:
            scores.append(1.0)
            continue
        if total1 == 0 or total2 == 0:
            scores.append(0.0)
            continue
        # One dilation step gives a small positional tolerance when matching edges.
        dilated1 = cv2.dilate(edges1, kernel, iterations=1)
        dilated2 = cv2.dilate(edges2, kernel, iterations=1)
        precision = (edges1 & dilated2).sum() / max(total1, 1)
        recall = (edges2 & dilated1).sum() / max(total2, 1)
        if precision + recall > 0:
            scores.append(2 * precision * recall / (precision + recall))
        else:
            scores.append(0.0)
    return np.mean(scores)



def _segment_orientations(img):
    """Return Hough segment orientations in degrees folded into [-90, 90), or None if no segment."""
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if len(img.shape) == 3 else img
    edges = cv2.Canny(gray, 50, 150, apertureSize=3)
    lines = cv2.HoughLinesP(edges, 1, np.pi / 180, 50, minLineLength=30, maxLineGap=10)
    if lines is None:
        return None
    # Each segment is (x1, y1, x2, y2).
    seg = np.asarray(lines, dtype=np.float64).reshape(-1, 4)
    angles = np.degrees(np.arctan2(seg[:, 3] - seg[:, 1], seg[:, 2] - seg[:, 0]))
    # A segment has no direction: (x1,y1)->(x2,y2) and the reverse are the same
    # line, so fold the (-180, 180] arctan2 range onto one half-turn. Without
    # this, segments stored right-to-left fall outside the histogram range and
    # are never counted as horizontal.
    return (angles + 90.0) % 180.0 - 90.0


def line_straightness(img, ref=None):
    """Score the line structure of *img*, optionally relative to *ref*.

    Args:
        img: Image to score (BGR or grayscale).
        ref: Optional reference image.

    Returns:
        With *ref*: the Bhattacharyya coefficient between the two images'
        segment-orientation histograms (36 bins of 5 degrees).
        Without *ref*: the fraction of segments within 5 degrees of
        horizontal or vertical.
        In either mode 0.5 is returned when no segment is detected in an
        image that is needed for the score.
    """
    angles = _segment_orientations(img)
    if angles is None:
        return 0.5

    if ref is not None:
        ref_angles = _segment_orientations(ref)
        if ref_angles is None:
            return 0.5
        bins = np.linspace(-90, 90, 37)
        h1, _ = np.histogram(angles, bins=bins)
        h2, _ = np.histogram(ref_angles, bins=bins)
        # Normalise to unit mass; the epsilon guards against an empty histogram.
        h1, h2 = h1 / (h1.sum() + 1e-10), h2 / (h2.sum() + 1e-10)
        return float(np.sum(np.sqrt(h1 * h2)))

    # Distance to the nearest axis: |a| for horizontal, ||a| - 90| for vertical.
    off_axis = np.minimum(np.abs(angles), np.abs(np.abs(angles) - 90))
    return float(np.mean(off_axis < 5))



def gradient_orientation_sim(img1, img2, n_bins=36):
    """Bhattacharyya coefficient between magnitude-weighted gradient-orientation histograms.

    Each image is converted to float32 grayscale (BGR input is converted,
    2-D input is used as is), its Sobel gradients are binned by direction
    over [-pi, pi] with *n_bins* bins, weighted by gradient magnitude, and
    normalised to unit mass.
    """
    bin_edges = np.linspace(-np.pi, np.pi, n_bins + 1)

    def _orientation_hist(img):
        if len(img.shape) == 3:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).astype(np.float32)
        else:
            gray = img.astype(np.float32)
        grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
        direction = np.arctan2(grad_y, grad_x)
        magnitude = np.sqrt(grad_x**2 + grad_y**2)
        hist, _ = np.histogram(direction, bins=bin_edges, weights=magnitude)
        # Epsilon keeps a gradient-free image from dividing by zero.
        return hist / (hist.sum() + 1e-10)

    hist1 = _orientation_hist(img1)
    hist2 = _orientation_hist(img2)
    return float(np.sum(np.sqrt(hist1 * hist2)))



def pixel_accuracy(img1, img2):
    """Return 1 minus the mean absolute pixel difference scaled by 255.

    If the shapes differ, *img2* is first resized to the width and height of
    *img1*. The difference is computed in float32.
    """
    if img1.shape != img2.shape:
        target_h, target_w = img1.shape[0], img1.shape[1]
        img2 = cv2.resize(img2, (target_w, target_h))
    abs_diff = np.abs(img1.astype(np.float32) - img2.astype(np.float32))
    return 1.0 - np.mean(abs_diff) / 255.0



def compute_score(corrected, gt):
    """Combine the local metrics into one weighted score.

    *corrected* is resized to the shape of *gt* when the shapes differ. The
    weights are 0.40 edge similarity, 0.22 line straightness, 0.18 gradient
    orientation, 0.15 SSIM and 0.05 pixel accuracy.

    Returns:
        ``(overall, metrics)`` where ``metrics`` is a dict with keys
        edge, line, grad, ssim, pixel and overall.
    """
    if corrected.shape != gt.shape:
        gt_h, gt_w = gt.shape[0], gt.shape[1]
        corrected = cv2.resize(corrected, (gt_w, gt_h))

    es = edge_similarity(corrected, gt)
    ls = line_straightness(corrected, gt)
    go = gradient_orientation_sim(corrected, gt)

    # Colour images need the channel axis spelled out for SSIM.
    ssim_kwargs = {'data_range': 255}
    if len(corrected.shape) == 3:
        ssim_kwargs['channel_axis'] = 2
    ss = ssim(corrected, gt, **ssim_kwargs)

    pa = pixel_accuracy(corrected, gt)

    overall = 0.40*es + 0.22*ls + 0.18*go + 0.15*ss + 0.05*pa
    metrics = {'edge': es, 'line': ls, 'grad': go, 'ssim': ss, 'pixel': pa, 'overall': overall}
    return overall, metrics


print("Metrics defined.")

## Cell 9 — Validate on Training Pairs

In [None]:
# Reload the best checkpoint written by the training cell.
# NOTE(review): weights_only=False unpickles arbitrary objects; only load
# checkpoints produced by this notebook.
ckpt = torch.load(CKPT_DIR / 'best_model.pth', map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()
print(f"Loaded best model from epoch {ckpt['epoch']}")

N_VIS = min(8, len(val_ds_noaug))
scores = []

if N_VIS == 0:
    # plt.subplots(0, 3) raises, so bail out explicitly on an empty split.
    print("No validation samples available; skipping visual validation.")
else:
    val_vis_loader = DataLoader(val_ds_noaug, batch_size=1, shuffle=False)
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(N_VIS, 3, figsize=(15, 4 * N_VIS), squeeze=False)

    # zip stops after N_VIS rows, so no extra batches are loaded.
    for row, batch in zip(axes, val_vis_loader):
        for ax in row:
            ax.axis('off')
        image_id = batch['image_id'][0]
        with torch.no_grad():
            pp = model(batch['image'].to(DEVICE)).cpu().numpy()[0]

        # Score on the full-resolution files, not on the resized network input.
        orig_path, gen_path = PAIR_LOOKUP[image_id]
        dist_img = cv2.imread(str(orig_path))
        corr_img = cv2.imread(str(gen_path))
        if dist_img is None or corr_img is None:
            row[0].set_title(f'Unreadable pair: {image_id}')
            continue

        pred_corr = undistort_image(dist_img, *pp)
        overall, m = compute_score(pred_corr, corr_img)
        scores.append(overall)

        row[0].imshow(cv2.cvtColor(dist_img, cv2.COLOR_BGR2RGB))
        row[0].set_title('Distorted')
        row[1].imshow(cv2.cvtColor(pred_corr, cv2.COLOR_BGR2RGB))
        row[1].set_title(f'Ours (score={overall:.3f})')
        row[2].imshow(cv2.cvtColor(corr_img, cv2.COLOR_BGR2RGB))
        row[2].set_title('Ground truth')

    plt.tight_layout()
    plt.show()
    if scores:
        print(f"\nAvg local score: {np.mean(scores):.4f}")
    else:
        print("\nNo validation pair could be scored.")

## Cell 10 — Predict Test + Test-Time Optimization

In [None]:
OUTPUT_DIR = WORK_DIR / 'output'
OUTPUT_DIR.mkdir(exist_ok=True)

ckpt = torch.load(CKPT_DIR / 'best_model.pth', map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

test_ds = TestDataset(TEST_FILES, image_size=384)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=True)
print(f"Test images: {len(test_ds)}")

# Stage 1: CNN predictions
# predictions[image_id] -> {'params', 'image_path', 'orig_h', 'orig_w'}
predictions = {}
with torch.no_grad():
    for batch in tqdm(test_loader, desc="CNN prediction"):
        preds = model(batch['image'].to(DEVICE))
        for i, img_id in enumerate(batch['image_id']):
            predictions[img_id] = {
                'params': preds[i].cpu().numpy(),
                'image_path': batch['image_path'][i],
                'orig_h': batch['orig_h'][i].item(),
                'orig_w': batch['orig_w'][i].item()}

# Stage 2: Test-Time Optimization
TTO_STEPS = 50


def tto_loss(undistorted):
    """Reference-free loss for one undistorted image of shape (1, 3, H, W).

    Half of the loss is the negated mean Sobel gradient magnitude of the
    grayscale image (rewards strong edges); the other half penalises dark
    pixels in a border band whose width is max(2, min(H, W) // 20).
    """
    img = undistorted.squeeze(0)
    gray = (0.299*img[0] + 0.587*img[1] + 0.114*img[2]).unsqueeze(0).unsqueeze(0)
    sx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                      dtype=torch.float32, device=img.device).view(1, 1, 3, 3)
    sy = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                      dtype=torch.float32, device=img.device).view(1, 1, 3, 3)
    gx, gy = F.conv2d(gray, sx, padding=1), F.conv2d(gray, sy, padding=1)
    edge_loss = -torch.sqrt(gx**2 + gy**2 + 1e-6).mean()  # epsilon keeps sqrt differentiable at 0
    C, H, W = img.shape
    bs = max(2, min(H, W) // 20)
    borders = torch.cat([img[:, :bs, :].reshape(-1), img[:, -bs:, :].reshape(-1),
                         img[:, :, :bs].reshape(-1), img[:, :, -bs:].reshape(-1)])
    return 0.5*edge_loss + 0.5*(1.0 - borders.abs().clamp(0, 1)).mean()


def diff_undistort(image, params):
    """Differentiable radial resampling of a single image.

    Args:
        image: Tensor of shape (1, C, H, W).
        params: Tensor of shape (1, 5): k1, k2, k3, cx, cy (only row 0 is read).

    Returns:
        The resampled image; samples outside the input are zero.
    """
    B, C, H, W = image.shape
    k1, k2, k3, cx, cy = params[0, 0], params[0, 1], params[0, 2], params[0, 3], params[0, 4]
    gy, gx = torch.meshgrid(torch.linspace(-1, 1, H, device=image.device),
                            torch.linspace(-1, 1, W, device=image.device), indexing='ij')
    # Centre is given in [0, 1]; grid_sample works in [-1, 1].
    dx, dy = gx - (cx*2 - 1), gy - (cy*2 - 1)
    r2 = dx**2 + dy**2
    rad = 1 + k1*r2 + k2*r2**2 + k3*r2**3
    grid = torch.stack([dx*rad + (cx*2 - 1), dy*rad + (cy*2 - 1)], dim=-1).unsqueeze(0)
    return F.grid_sample(image, grid, mode='bilinear', padding_mode='zeros', align_corners=True)


print(f"\nRunning TTO ({TTO_STEPS} steps/image)...")
for img_id in tqdm(predictions, desc="TTO"):
    pred = predictions[img_id]
    bgr = cv2.imread(pred['image_path'])
    if bgr is None:
        # No 'params_tto' is stored, so later cells fall back to the CNN parameters.
        print(f"Warning: could not read {pred['image_path']}; skipping TTO for {img_id}")
        continue
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # TTO runs on a 256x256 copy scaled to [0, 1].
    img_t = torch.from_numpy(cv2.resize(img, (256, 256))).float().permute(2, 0, 1).unsqueeze(0).to(DEVICE) / 255.0
    p = torch.tensor(pred['params'], dtype=torch.float32, device=DEVICE).unsqueeze(0).clone().detach().requires_grad_(True)
    init_p = torch.tensor(pred['params'], dtype=torch.float32, device=DEVICE)
    opt = torch.optim.Adam([p], lr=0.001)
    best_loss, best_p = float('inf'), pred['params'].copy()
    for _ in range(TTO_STEPS):
        opt.zero_grad()
        # Project back into the valid parameter box before evaluating.
        with torch.no_grad():
            p.data[:, :3].clamp_(-1, 1)
            p.data[:, 3:].clamp_(0.1, 0.9)
        # The MSE term keeps the solution close to the CNN prediction.
        loss = tto_loss(diff_undistort(img_t, p)) + 0.1*F.mse_loss(p.squeeze(), init_p)
        loss_value = loss.item()
        if loss_value < best_loss:
            # Snapshot BEFORE the optimiser step so that best_p is exactly the
            # (clamped) parameter vector that produced best_loss.
            best_loss = loss_value
            best_p = p.detach().squeeze().cpu().numpy().copy()
        loss.backward()
        opt.step()
    pred['params_tto'] = best_p
print("TTO complete.")

## Cell 11 — Apply Corrections & Create Submission ZIP

In [None]:
# Apply the final parameters to the full-resolution test images.
failed_ids = []  # images that could not be read or written
for img_id, pred in tqdm(predictions.items(), desc="Saving corrected"):
    img = cv2.imread(pred['image_path'])
    if img is None:
        failed_ids.append(img_id)
        continue
    # Prefer the test-time-optimised parameters when they exist.
    final_params = pred.get('params_tto', pred['params'])
    corrected = undistort_image(img, *final_params)
    out_path = OUTPUT_DIR / f"{img_id}.jpg"
    # cv2.imwrite reports failure through its return value instead of raising.
    if not cv2.imwrite(str(out_path), corrected, [cv2.IMWRITE_JPEG_QUALITY, 95]):
        failed_ids.append(img_id)

if failed_ids:
    # A missing image silently shrinks the submission, so make it visible.
    print(f"Warning: {len(failed_ids)} image(s) could not be read or written, e.g. {failed_ids[:5]}")

# Bundle every JPEG in OUTPUT_DIR, stored flat (file name only) in the archive.
ZIP_PATH = WORK_DIR / 'submission.zip'
output_files = sorted(OUTPUT_DIR.glob('*.jpg'))
with zipfile.ZipFile(ZIP_PATH, 'w', zipfile.ZIP_DEFLATED) as zf:
    for f in output_files:
        zf.write(f, f.name)

n_out = len(output_files)
print(f"\nSaved {n_out} corrected images")
print(f"ZIP: {ZIP_PATH} ({ZIP_PATH.stat().st_size/1024/1024:.1f} MB)")

# Auto-download on Colab
if ENV == "colab":
    from google.colab import files
    files.download(str(ZIP_PATH))
    print("Download started automatically.")
elif ENV == "kaggle":
    print("Download submission.zip from the Output tab (right sidebar).")
else:
    print(f"submission.zip is at: {ZIP_PATH}")

print(f"\nNext steps:")
print(f"  1. Upload submission.zip to https://bounty.autohdr.com")
print(f"  2. Download the scoring CSV")
print(f"  3. Submit CSV to Kaggle")

## Cell 12 — (Optional) Visualize Test Corrections

In [None]:
# Show the first few corrected test images next to their distorted sources.
test_files = sorted(OUTPUT_DIR.glob('*.jpg'))[:6]

if not test_files:
    # plt.subplots(0, 2) raises, so report the empty case instead.
    print("No corrected images found; nothing to visualize.")
else:
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(len(test_files), 2, figsize=(12, 4 * len(test_files)), squeeze=False)
    for (ax_src, ax_out), cf in zip(axes, test_files):
        img_id = cf.stem  # output files are named "<image_id>.jpg"
        orig = cv2.cvtColor(cv2.imread(predictions[img_id]['image_path']), cv2.COLOR_BGR2RGB)
        corr = cv2.cvtColor(cv2.imread(str(cf)), cv2.COLOR_BGR2RGB)
        ax_src.imshow(orig)
        ax_src.set_title(f'Distorted: {img_id}')
        ax_src.axis('off')
        ax_out.imshow(corr)
        ax_out.set_title('Corrected')
        ax_out.axis('off')
    plt.tight_layout()
    plt.show()