# Fourier Phase Retrieval - Kaggle Master Notebook

This notebook consolidates the full pipeline:
- Dataset construction (MNIST -> diffraction intensity)
- U-Net model (DoubleConv blocks)
- Physics forward model (Bartlett window + FFT pipeline)
- Pre-training
- Test-time fine-tuning
- Evaluation and metrics
- Export results/models as a zip file

Use the `MODE` config cell to switch between:
- `TEST`   (500 train / 50 val, 1 epoch)
- `MEDIUM` (25% train / 25% val, 10 epochs)
- `FULL`   (100% train / 100% val, 30 epochs)

In [None]:
%matplotlib inline

import os
import random
import shutil
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import FileLink, display
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms
from tqdm.auto import tqdm

# Seed every RNG source the notebook draws from (Python, NumPy, PyTorch CPU/GPU)
# so that repeated runs are as reproducible as possible.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Use the GPU whenever PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Working directories, all relative to the current working directory.
BASE_DIR = Path('.')
DATA_DIR = BASE_DIR / 'data'
RESULTS_DIR = BASE_DIR / 'results'
MODELS_DIR = BASE_DIR / 'models'

for _directory in (DATA_DIR, RESULTS_DIR, MODELS_DIR):
    _directory.mkdir(parents=True, exist_ok=True)

: 

In [None]:
# =============================
# Configuration: TEST / MEDIUM / FULL
# =============================
# Run profile selector: CONFIGS[MODE] becomes `cfg`, which every later cell reads.
MODE = 'TEST'  # Change to 'MEDIUM' or 'FULL'

# Keys consumed by later cells:
#   train_size / val_size            - absolute subset sizes; build_subsets caps them at the split length
#   train_fraction / val_fraction    - fraction of each split, used when no explicit size is given
#   epochs, batch_size               - pre-training epochs and DataLoader batch size
#   lr_pretrain                      - Adam learning rate for pre-training
#   scheduler_patience               - ReduceLROnPlateau patience (epochs without val improvement)
#   early_stopping_patience          - epochs without val improvement before training stops
#   finetune_iterations, lr_finetune - steps and Adam LR for per-sample test-time fine-tuning
#   eval_samples                     - number of validation samples evaluated (TEST mode only)
#   eval_fraction / max_eval_samples - evaluated fraction of the validation subset, capped (MEDIUM/FULL)
CONFIGS = {
    'TEST': {
        'train_size': 500,
        'val_size': 50,
        'epochs': 1,
        'batch_size': 16,
        'finetune_iterations': 20,
        'eval_samples': 2,
        'lr_pretrain': 1e-3,
        'lr_finetune': 1e-4,
        'early_stopping_patience': 3,
        'scheduler_patience': 1,
    },
    'MEDIUM': {
        'train_fraction': 0.25,
        'val_fraction': 0.25,
        'epochs': 10,
        'batch_size': 16,
        'finetune_iterations': 50,
        'eval_fraction': 0.25,
        'max_eval_samples': 250,
        'lr_pretrain': 1e-3,
        'lr_finetune': 1e-4,
        'early_stopping_patience': 4,
        'scheduler_patience': 2,
    },
    'FULL': {
        'train_fraction': 1.0,
        'val_fraction': 1.0,
        'epochs': 30,
        'batch_size': 16,
        'finetune_iterations': 100,
        'eval_fraction': 1.0,
        'max_eval_samples': 500,
        'lr_pretrain': 1e-3,
        'lr_finetune': 1e-4,
        'early_stopping_patience': 6,
        'scheduler_patience': 3,
    },
}

cfg = CONFIGS[MODE]  # a MODE that is not a key above raises KeyError here
print('MODE:', MODE)
print('Config:', cfg)

In [None]:
# =============================
# Dataset: MNIST -> Diffraction Intensity (Windowed + Log-scaled)
# =============================
class FPRDataset(Dataset):
    """MNIST digits paired with their simulated diffraction intensities.

    Each item is ``(intensity, image)``. ``image`` is the MNIST digit after
    ``Resize(size)`` + ``ToTensor`` (``[1, size, size]`` for the square digits).
    ``intensity`` is the Bartlett-windowed digit's centred ``log1p(|FFT|^2)``,
    min-max scaled per sample to [0, 1], with a leading channel axis.
    """

    def __init__(self, mnist_root='./data', train=True, size=128):
        preprocessing = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
        ])
        self.mnist = datasets.MNIST(
            root=mnist_root,
            train=train,
            download=True,
            transform=preprocessing,
        )
        # Separable 2-D Bartlett window: product of a column and a row copy of the 1-D taper.
        taper = torch.bartlett_window(size)
        self.window = taper[:, None] * taper[None, :]

    def __len__(self):
        return len(self.mnist)

    def compute_diffraction(self, image):
        """Return the normalised, centred log-power spectrum of ``image[0]`` as ``[1, H, W]``."""
        spectrum = torch.fft.fft2(image[0] * self.window)
        centred = torch.fft.fftshift(spectrum, dim=(-2, -1))
        log_power = torch.log1p(centred.abs() ** 2)
        lo = log_power.min()
        hi = log_power.max()
        # Min-max scale to [0, 1]; the epsilon avoids division by zero for a flat spectrum.
        scaled = (log_power - lo) / (hi - lo + 1e-12)
        return scaled.unsqueeze(0)

    def __getitem__(self, idx):
        digit = self.mnist[idx][0]
        return self.compute_diffraction(digit), digit


def build_subsets(dataset_train, dataset_val, config):
    """Take leading slices of the train/val datasets sized according to ``config``.

    For each split an explicit ``*_size`` key wins (capped at the split length).
    Otherwise ``*_fraction`` of the split length is used (truncated to int).
    The first N indices are kept, with no shuffling.

    Raises:
        KeyError: if neither the size nor the fraction key is present for a split.
    """
    def _resolve(total, size_key, fraction_key):
        if size_key in config:
            return min(config[size_key], total)
        return int(total * config[fraction_key])

    n_train_keep = _resolve(len(dataset_train), 'train_size', 'train_fraction')
    n_val_keep = _resolve(len(dataset_val), 'val_size', 'val_fraction')
    return (
        Subset(dataset_train, list(range(n_train_keep))),
        Subset(dataset_val, list(range(n_val_keep))),
    )


# Build both MNIST-backed splits (train first), then trim them per `cfg`.
full_train, full_val = (
    FPRDataset(mnist_root=str(DATA_DIR), train=split_is_train, size=128)
    for split_is_train in (True, False)
)
train_subset, val_subset = build_subsets(full_train, full_val, cfg)

# The two loaders differ only in whether batches are shuffled.
_loader_opts = {'batch_size': cfg['batch_size'], 'num_workers': 0}
train_loader = DataLoader(train_subset, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_subset, shuffle=False, **_loader_opts)

print(f'Train samples: {len(train_subset):,}')
print(f'Val samples: {len(val_subset):,}')

# Sanity-check one training pair: shapes and absolute value ranges.
sample_inp, sample_tgt = train_subset[0]
print('Input shape:', sample_inp.shape, '| Target shape:', sample_tgt.shape)
for _label, _tensor in (('Input', sample_inp), ('Target', sample_tgt)):
    _magnitude = _tensor.abs()
    print(f'{_label} abs min/max: {_magnitude.min().item():.6f} / {_magnitude.max().item():.6f}')

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
_panels = (
    (sample_inp, 'viridis', 'Sample Input (Diffraction)'),
    (sample_tgt, 'gray', 'Sample Target (Image)'),
)
for _panel_ax, (_img, _cmap, _title) in zip(ax, _panels):
    _panel_ax.imshow(_img[0].cpu().numpy(), cmap=_cmap)
    _panel_ax.set_title(_title)
    _panel_ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# =============================
# U-Net Architecture
# =============================
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Padding 1 keeps H and W unchanged; only the channel count goes from
    ``in_channels`` to ``out_channels``. Layers live in ``self.block`` in the
    order conv, bn, relu, conv, bn, relu, so state-dict keys are
    ``block.0.*``, ``block.1.*``, ``block.3.*``, ``block.4.*``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in->out channels, second stage maps out->out.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        features = self.block(x)
        return features


class UNet(nn.Module):
    """Four-level U-Net mapping a 1-channel input to a 1-channel sigmoid output.

    Encoder widths are 32/64/128/256 with a 512-channel bottleneck. Each level
    halves H and W with 2x2 max-pooling, and each decoder level doubles them with
    a stride-2 transposed conv. H and W must therefore be divisible by 16 for the
    skip concatenations to line up. Submodules keep the names ``enc{k}``,
    ``pool{k}``, ``up{k}``, ``dec{k}`` (k = 1..4), ``bottleneck``, ``out_conv`` and
    ``out_act``, created in that original order, so checkpoints stay loadable.
    """

    _WIDTHS = (32, 64, 128, 256)

    def __init__(self):
        super().__init__()

        # Encoder: enc1, pool1, enc2, pool2, ... (registration order matters for init RNG).
        in_ch = 1
        for level, width in enumerate(self._WIDTHS, start=1):
            setattr(self, f'enc{level}', DoubleConv(in_ch, width))
            setattr(self, f'pool{level}', nn.MaxPool2d(2))
            in_ch = width

        self.bottleneck = DoubleConv(256, 512)

        # Decoder, deepest first: up4, dec4, up3, dec3, ...
        for level in range(len(self._WIDTHS), 0, -1):
            width = self._WIDTHS[level - 1]
            setattr(self, f'up{level}', nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2))
            # Input is the upsampled map concatenated with the matching encoder skip.
            setattr(self, f'dec{level}', DoubleConv(width * 2, width))

        self.out_conv = nn.Conv2d(32, 1, kernel_size=3, padding=1)
        self.out_act = nn.Sigmoid()

    def forward(self, x):
        skips = []
        features = x
        for level in range(1, len(self._WIDTHS) + 1):
            features = getattr(self, f'enc{level}')(features)
            skips.append(features)
            features = getattr(self, f'pool{level}')(features)

        features = self.bottleneck(features)

        for level in range(len(self._WIDTHS), 0, -1):
            upsampled = getattr(self, f'up{level}')(features)
            merged = torch.cat([upsampled, skips[level - 1]], dim=1)
            features = getattr(self, f'dec{level}')(merged)

        return self.out_act(self.out_conv(features))


model = UNet().to(device)

# Shape smoke test on a random probe batch. Run it in eval mode under no_grad so
# the probe neither updates BatchNorm running statistics before pre-training
# starts nor builds an autograd graph; then restore train mode.
model.eval()
with torch.no_grad():
    x = torch.randn(2, 1, 128, 128, device=device)
    y = model(x)
model.train()
print('Model OK. Output shape:', y.shape)
print('Parameters:', f"{sum(p.numel() for p in model.parameters()):,}")

In [None]:
# =============================
# Physics Forward Model + Fine-tuning
# =============================
def bartlett_window(size, device=device):
    """Return the separable 2-D Bartlett window of shape ``[size, size]`` on ``device``.

    The default ``device`` is the notebook-level ``device`` as bound when this
    function is defined.
    """
    taper = torch.bartlett_window(size, device=device)
    return taper[:, None] * taper[None, :]


def physics_forward(image, window):
    """Simulate the normalised diffraction intensity of ``image``.

    Mirrors ``FPRDataset.compute_diffraction``:
    1. multiply channel 0 by ``window``;
    2. take the 2-D FFT and centre it with ``fftshift`` over the two spatial axes;
    3. form ``log1p(|F|^2)``;
    4. min-max scale each sample independently to [0, 1].

    Args:
        image: ``[B, 1, H, W]`` batch or a single ``[1, H, W]`` image; only
            channel 0 is used.
        window: ``[H, W]`` window, broadcast over the batch.

    Returns:
        ``[B, 1, H, W]`` for batched input, ``[1, H, W]`` for a single image.

    Raises:
        ValueError: if ``image`` is neither 3-D nor 4-D.
    """
    if image.dim() == 4:
        batched = True
        img = image[:, 0]
    elif image.dim() == 3:
        batched = False
        img = image[0]
    else:
        raise ValueError('Unsupported image shape for physics_forward')

    fourier = torch.fft.fft2(img * window)
    # Shift only the spatial axes. A bare fftshift() also rolls the batch axis,
    # which would pair each prediction with another sample's intensity.
    shifted = torch.fft.fftshift(fourier, dim=(-2, -1))
    log_intensity = torch.log1p(torch.abs(shifted) ** 2)
    # Per-sample min/max; the epsilon avoids division by zero for a flat spectrum.
    min_v = log_intensity.amin(dim=(-2, -1), keepdim=True)
    max_v = log_intensity.amax(dim=(-2, -1), keepdim=True)
    norm = (log_intensity - min_v) / (max_v - min_v + 1e-12)
    return norm.unsqueeze(1) if batched else norm.unsqueeze(0)


def finetune_test_sample(model, measured_intensity, window, iterations=100, lr=1e-4, lambda_tv=1e-5):
    """Adapt ``model`` to a single measurement by physics-consistency fine-tuning.

    Each step passes the network's prediction through ``physics_forward`` and
    compares it (MSE) with ``measured_intensity``, plus an anisotropic
    total-variation penalty weighted by ``lambda_tv``. A fresh Adam optimiser is
    created per call, and gradients are clipped to norm 1.0 every step.

    The model stays in eval mode for the whole optimisation. BatchNorm then uses
    its stored running statistics rather than single-sample batch statistics.
    This has two effects:
    (a) the statistics learned during pre-training are not overwritten by one
        test image;
    (b) the forward pass being optimised is the same one that produces the
        returned prediction.
    Gradients still reach every parameter in eval mode.

    NOTE: the weights of ``model`` are updated in place; callers that need the
    pre-trained weights afterwards must restore them.

    Args:
        model: network mapping diffraction intensity -> image.
        measured_intensity: network input; presumably ``[1, 1, H, W]`` as built
            by the evaluation cell.
        window: window forwarded to ``physics_forward``.
        iterations: number of optimisation steps.
        lr: Adam learning rate.
        lambda_tv: weight of the TV penalty.

    Returns:
        ``(refined, losses)``: the detached prediction after fine-tuning and the
        per-iteration total loss (data MSE + weighted TV).
    """
    model.eval()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    losses = []

    for _ in range(iterations):
        optimizer.zero_grad()
        pred_image = model(measured_intensity)
        pred_intensity = physics_forward(pred_image, window)

        # Anisotropic total variation: L1 of vertical and horizontal neighbour differences.
        tv_h = torch.abs(pred_image[:, :, 1:, :] - pred_image[:, :, :-1, :]).sum()
        tv_w = torch.abs(pred_image[:, :, :, 1:] - pred_image[:, :, :, :-1]).sum()
        tv_loss = tv_h + tv_w

        data_loss = criterion(pred_intensity, measured_intensity)
        total_loss = data_loss + lambda_tv * tv_loss

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        losses.append(total_loss.item())

    with torch.no_grad():
        refined = model(measured_intensity).detach()
    return refined, losses

In [None]:
# =============================
# Pre-training (Train + Validation)
# =============================
# Supervised pre-training: diffraction intensity -> image, MSE loss.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=cfg['lr_pretrain'], weight_decay=1e-5)
# Halve the LR when validation loss stalls for `scheduler_patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=cfg['scheduler_patience'],
)

train_losses = []
val_losses = []

# Save policy:
# - best_model_so_far.pth is rewritten whenever validation loss improves;
# - every `checkpoint_every` epochs a periodic checkpoint is written, and only the
#   3 most recent periodic checkpoints from this run are kept on disk.
checkpoint_every = 5
recent_checkpoints = []
best_val = float('inf')
best_path = MODELS_DIR / 'best_model_so_far.pth'

patience = cfg['early_stopping_patience']
early_stop_counter = 0
best_for_patience = float('inf')  # best val loss seen, used only for early stopping

for epoch in range(1, cfg['epochs'] + 1):
    # ---- training pass ----
    model.train()
    train_sum = 0.0

    train_bar = tqdm(train_loader, desc=f'Epoch {epoch:02d} Train', leave=False)
    for inp, tgt in train_bar:
        inp = inp.to(device)
        tgt = tgt.to(device)

        pred = model(inp)
        loss = criterion(pred, tgt)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        train_sum += loss.item()
        train_bar.set_postfix({'loss': f'{loss.item():.5f}'})

    # Mean of per-batch mean losses (a short final batch counts like a full one).
    avg_train = train_sum / len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_sum = 0.0
    val_bar = tqdm(val_loader, desc=f'Epoch {epoch:02d} Val', leave=False)
    with torch.no_grad():
        for inp, tgt in val_bar:
            inp = inp.to(device)
            tgt = tgt.to(device)
            pred = model(inp)
            batch_val = criterion(pred, tgt).item()
            val_sum += batch_val
            val_bar.set_postfix({'val_loss': f'{batch_val:.5f}'})

    avg_val = val_sum / len(val_loader)

    train_losses.append(avg_train)
    val_losses.append(avg_val)

    scheduler.step(avg_val)
    current_lr = optimizer.param_groups[0]['lr']

    # Track and save best model so far (always up-to-date)
    if avg_val < best_val:
        best_val = avg_val
        torch.save(
            {
                'epoch': epoch,
                'val_loss': avg_val,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses,
                'mode': MODE,
            },
            best_path,
        )

    # Every `checkpoint_every` epochs: save a periodic checkpoint and keep only the 3 latest
    if epoch % checkpoint_every == 0:
        ckpt_path = MODELS_DIR / f'last_run_epoch_{epoch:02d}.pth'
        torch.save(
            {
                'epoch': epoch,
                'val_loss': avg_val,
                'best_val_so_far': best_val,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'val_losses': val_losses,
                'mode': MODE,
            },
            ckpt_path,
        )

        recent_checkpoints.append(ckpt_path)
        if len(recent_checkpoints) > 3:
            oldest = recent_checkpoints.pop(0)
            if oldest.exists():
                oldest.unlink()

    # Early stopping: count consecutive epochs without a new best val loss.
    if avg_val < best_for_patience:
        best_for_patience = avg_val
        early_stop_counter = 0
    else:
        early_stop_counter += 1

    print(
        f'Epoch {epoch:02d}/{cfg["epochs"]} | '
        f'Train: {avg_train:.6f} | Val: {avg_val:.6f} | '
        f'Best Val: {best_val:.6f} | LR: {current_lr:.2e} | '
        f'Patience: {early_stop_counter}/{patience}'
    )

    if early_stop_counter >= patience:
        print(f'Early stopping triggered at epoch {epoch}.')
        break

# Epochs are numbered from 1 in the loop and logs, so plot against 1..N rather
# than the list indices 0..N-1.
epoch_axis = list(range(1, len(train_losses) + 1))
plt.figure(figsize=(8, 4))
plt.plot(epoch_axis, train_losses, marker='o', label='Train Loss')
plt.plot(epoch_axis, val_losses, marker='s', label='Val Loss')
plt.title(f'Pre-training Curves ({MODE})')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.grid(alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(RESULTS_DIR / f'pretrain_curves_{MODE.lower()}.png', dpi=150)
plt.show()

print('Best model so far:', best_path)
print('Kept latest periodic checkpoints (max 3):')
for p in recent_checkpoints:
    print(' -', p)

In [None]:
# =============================
# Evaluation: Pre-trained vs Fine-tuned
# =============================
# Restore the best pre-trained weights. The pre-training cell saves a dict of
# metadata, and the weights themselves are stored under 'model_state_dict'.
checkpoint = torch.load(best_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# finetune_test_sample() updates `model` in place. Snapshot the pre-trained
# parameters and buffers so every sample starts from them. Otherwise each
# sample's "pre-trained" baseline and fine-tuning start would come from a model
# already fine-tuned on all previous samples.
pretrained_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

window = bartlett_window(128, device=device)

# Which validation samples to evaluate: a fixed count in TEST mode, otherwise a
# fraction of the validation subset capped at max_eval_samples (at least 1).
if MODE == 'TEST':
    eval_indices = list(range(min(cfg['eval_samples'], len(val_subset))))
else:
    fraction = cfg.get('eval_fraction', 0.25)
    max_eval = cfg.get('max_eval_samples', len(val_subset))
    n_eval = min(int(len(val_subset) * fraction), max_eval)
    n_eval = max(n_eval, 1)
    eval_indices = list(range(n_eval))

rows = []
# Plot at most the first 8 samples: input, ground truth, pre-trained, fine-tuned.
viz_cap = min(8, len(eval_indices))
fig, axes = plt.subplots(viz_cap, 4, figsize=(14, 3 * viz_cap))
if viz_cap == 1:
    # A single row comes back as a 1-D axes array; make it 2-D for axes[r, c].
    axes = np.expand_dims(axes, axis=0)

ft_loss_traces = []

for r, idx in enumerate(tqdm(eval_indices, desc='Fine-tune eval')):
    # Start every sample from the same pre-trained weights.
    model.load_state_dict(pretrained_state)
    model.eval()

    measured_intensity, target = val_subset[idx]
    measured_intensity = measured_intensity.unsqueeze(0).to(device)
    target = target.unsqueeze(0).to(device)

    with torch.no_grad():
        pretrained_pred = model(measured_intensity)

    mse_pre = nn.functional.mse_loss(pretrained_pred, target).item()

    finetuned_pred, ft_losses = finetune_test_sample(
        model=model,
        measured_intensity=measured_intensity,
        window=window,
        iterations=cfg['finetune_iterations'],
        lr=cfg['lr_finetune'],
    )

    mse_ft = nn.functional.mse_loss(finetuned_pred, target).item()
    improvement = (mse_pre - mse_ft) / (mse_pre + 1e-12) * 100.0

    ft_loss_traces.append({'sample_id': idx, 'losses': ft_losses})

    rows.append({
        'sample_id': idx,
        'mse_pretrained': mse_pre,
        'mse_finetuned': mse_ft,
        'improvement_percent': improvement,
    })

    if r < viz_cap:
        axes[r, 0].imshow(measured_intensity[0, 0].detach().cpu().numpy(), cmap='viridis')
        axes[r, 0].set_title(f'Sample {idx} Input')
        axes[r, 0].axis('off')

        axes[r, 1].imshow(target[0, 0].detach().cpu().numpy(), cmap='gray')
        axes[r, 1].set_title('Ground Truth')
        axes[r, 1].axis('off')

        axes[r, 2].imshow(pretrained_pred[0, 0].detach().cpu().numpy(), cmap='gray')
        axes[r, 2].set_title(f'Pre MSE={mse_pre:.4f}')
        axes[r, 2].axis('off')

        axes[r, 3].imshow(finetuned_pred[0, 0].detach().cpu().numpy(), cmap='gray')
        axes[r, 3].set_title(f'FT MSE={mse_ft:.4f}')
        axes[r, 3].axis('off')

# Leave the in-memory model at the pre-trained weights, not those tuned on the last sample.
model.load_state_dict(pretrained_state)
model.eval()

metrics_df = pd.DataFrame(rows)
display(metrics_df.head(20))

avg_pre = metrics_df['mse_pretrained'].mean()
avg_ft = metrics_df['mse_finetuned'].mean()
avg_imp = metrics_df['improvement_percent'].mean()

print(f'Average MSE (Pre-trained): {avg_pre:.6f}')
print(f'Average MSE (Fine-tuned):  {avg_ft:.6f}')
print(f'Average Improvement (%):   {avg_imp:.2f}')

plt.tight_layout()
plt.savefig(RESULTS_DIR / f'eval_grid_{MODE.lower()}.png', dpi=150)
plt.show()

if len(ft_loss_traces) > 0:
    plt.figure(figsize=(7, 4))
    for trace in ft_loss_traces[:3]:
        plt.plot(trace['losses'], label=f"Sample {trace['sample_id']}")
    plt.title('Fine-tuning Loss Traces')
    plt.xlabel('Iteration')
    # The recorded loss is the physics MSE plus the weighted TV penalty.
    plt.ylabel('Total Loss (MSE + TV)')
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(RESULTS_DIR / f'finetune_losses_{MODE.lower()}.png', dpi=150)
    plt.show()

metrics_df.to_csv(RESULTS_DIR / f'metrics_{MODE.lower()}.csv', index=False)

In [None]:
# =============================
# Export results and models as ZIP
# =============================
# Stage results/ and models/ in a clean bundle directory, then zip it.
bundle_root = BASE_DIR / 'FPR_Results_Bundle'
if bundle_root.exists():
    shutil.rmtree(bundle_root)  # drop leftovers from a previous export
bundle_root.mkdir(parents=True, exist_ok=True)

bundle_results = bundle_root / 'results'
bundle_models = bundle_root / 'models'

if RESULTS_DIR.exists():
    shutil.copytree(RESULTS_DIR, bundle_results, dirs_exist_ok=True)
if MODELS_DIR.exists():
    shutil.copytree(MODELS_DIR, bundle_models, dirs_exist_ok=True)

zip_base = BASE_DIR / 'FPR_Results'
zip_file = shutil.make_archive(str(zip_base), 'zip', root_dir=bundle_root)
print('Created zip:', zip_file)

# FileLink expects a path relative to the notebook's working directory. Derive it
# from the archive actually written rather than hard-coding 'FPR_Results.zip',
# so the link stays correct if BASE_DIR changes.
display(FileLink(os.path.relpath(zip_file)))