In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
# ==============================================
# 1. Poisson-aware Data Loading & Preprocessing
# ==============================================

class SPADDataset(Dataset):
    """Paired (Poisson-rate input, depth target) samples read from two PNG folders.

    Every ``.png`` file in ``img_dir`` (sorted by name) is paired with the file
    of the same name in ``depth_dir``. Both are read as 8-bit grayscale and
    scaled to [0, 1] float tensors of shape (1, H, W).

    Args:
        img_dir: directory of SPAD images. Each file is read as a single
            grayscale frame.
        depth_dir: directory of depth maps with filenames matching ``img_dir``.
        frame_count: stored on the instance but not used by this class.
            NOTE(review): presumably intended for multi-frame binary stacks — confirm.
        transform: optional callable applied to the input and to the depth map
            in two *separate* calls. A random transform will therefore not be
            synchronized between input and target.
    """

    def __init__(self, img_dir, depth_dir, frame_count=60, transform=None):
        self.img_dir = img_dir
        self.depth_dir = depth_dir
        self.transform = transform
        self.frame_count = frame_count
        self.img_files = sorted(f for f in os.listdir(img_dir) if f.endswith('.png'))

    def __len__(self):
        return len(self.img_files)

    def poisson_probability(self, binary_stack):
        """Convert binary frames to a Poisson rate map.

        Averages ``binary_stack`` over dim 0 (keeping that dim) to get the observed
        detection probability p, then returns lambda = -log(1 - p + 1e-6).
        This inverts p = 1 - exp(-lambda). The 1e-6 keeps the log finite when
        p == 1.
        """
        p_observed = binary_stack.mean(dim=0, keepdim=True)  # Temporal mean
        return -torch.log(1 - p_observed + 1e-6)

    @staticmethod
    def _read_gray(path):
        """Load ``path`` as a (1, H, W) float tensor in [0, 1]; raise if it cannot be read."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:  # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f'Could not read image: {path}')
        return torch.from_numpy(img).float().unsqueeze(0) / 255.0

    def __getitem__(self, idx):
        """Return ``(input_img, depth_map)``, both (1, H, W) before any transform."""
        name = self.img_files[idx]

        # With a single grayscale frame of shape (1, H, W), the temporal mean is the frame itself.
        binary_frames = self._read_gray(os.path.join(self.img_dir, name))
        input_img = self.poisson_probability(binary_frames)

        depth_map = self._read_gray(os.path.join(self.depth_dir, name))

        if self.transform:
            input_img = self.transform(input_img)
            depth_map = self.transform(depth_map)

        return input_img, depth_map


In [3]:
# ==============================================
# 2. Attention-Enhanced U-Net Architecture
# ==============================================

class CBAM(nn.Module):
    """Convolutional Block Attention Module (simplified variant).

    Rescales the input first per channel, then per spatial location:

    * channel attention: global average pool -> 1x1 conv down to the hidden
      width -> ReLU -> 1x1 conv back to ``channels`` -> sigmoid;
    * spatial attention: 7x7 conv (padding 3) from all channels to one map -> sigmoid.

    Unlike the CBAM paper (Woo et al., 2018), channel attention uses only
    average pooling. Spatial attention also convolves the full feature map
    rather than channel-wise avg/max pooled maps.

    Args:
        channels: number of input (and output) channels.
        reduction: bottleneck ratio for channel attention. The hidden width is
            ``channels // reduction`` clamped to at least 1. Without the clamp,
            ``channels < reduction`` gives a zero-width bottleneck and an
            attention output that ignores the input.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        hidden = max(channels // reduction, 1)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid()
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(channels, 1, 7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` reweighted by channel then spatial attention (same shape as ``x``)."""
        # (N, C, 1, 1) weights broadcast over H and W
        x_channel = x * self.channel_attention(x)
        # (N, 1, H, W) weights broadcast over channels
        return x_channel * self.spatial_attention(x_channel)

class UNetBlock(nn.Module):
    """Double 3x3 convolution stage used at every U-Net level.

    Runs (Conv3x3 -> BatchNorm -> ReLU) twice. The first conv maps ``in_ch``
    to ``out_ch`` and the second keeps ``out_ch``; padding=1 preserves H and W.
    If ``use_attn`` is set, a CBAM module follows. Otherwise an identity follows.
    """

    def __init__(self, in_ch, out_ch, use_attn=False):
        super().__init__()
        stages = []
        for source_ch in (in_ch, out_ch):
            stages += [
                nn.Conv2d(source_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]
        self.conv = nn.Sequential(*stages)
        if use_attn:
            self.attn = CBAM(out_ch)
        else:
            self.attn = nn.Identity()

    def forward(self, x):
        return self.attn(self.conv(x))

class SPAD_UNet(nn.Module):
    """Four-level U-Net mapping a Poisson-rate image to a depth map.

    Encoder widths are ``base_ch``, 2x, 4x and 8x, with a 16x bottleneck. Each
    level halves H and W with 2x2 max pooling. Only the first encoder block
    uses CBAM attention. The decoder mirrors the encoder: a 2x2 stride-2
    transposed convolution, concatenation ``[skip, upsampled]``, then a
    UNetBlock. A final 1x1 conv produces ``out_ch`` channels with no activation.

    H and W need not be divisible by 16. When max-pool flooring leaves an
    upsampled map smaller than its skip connection, it is zero-padded to
    match. H and W must still be at least 16 so the bottleneck is non-empty.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        base_ch: channel width of the first encoder level.
    """

    def __init__(self, in_ch=1, out_ch=1, base_ch=32):
        super().__init__()
        # Encoder (Downsampling)
        self.enc1 = UNetBlock(in_ch, base_ch, use_attn=True)
        self.enc2 = UNetBlock(base_ch, base_ch*2)
        self.enc3 = UNetBlock(base_ch*2, base_ch*4)
        self.enc4 = UNetBlock(base_ch*4, base_ch*8)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = UNetBlock(base_ch*8, base_ch*16)

        # Decoder (Upsampling)
        self.upconv4 = nn.ConvTranspose2d(base_ch*16, base_ch*8, 2, stride=2)
        self.dec4 = UNetBlock(base_ch*16, base_ch*8)

        self.upconv3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 2, stride=2)
        self.dec3 = UNetBlock(base_ch*8, base_ch*4)

        self.upconv2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, 2, stride=2)
        self.dec2 = UNetBlock(base_ch*4, base_ch*2)

        self.upconv1 = nn.ConvTranspose2d(base_ch*2, base_ch, 2, stride=2)
        self.dec1 = UNetBlock(base_ch*2, base_ch)

        # Output
        self.outconv = nn.Conv2d(base_ch, out_ch, 1)

    @staticmethod
    def _pad_to_match(up, skip):
        """Zero-pad ``up`` so its last two dims equal ``skip``'s.

        Returns ``up`` unchanged when the sizes already agree. Odd remainders
        go to the bottom/right edge.
        """
        diff_h = skip.size(-2) - up.size(-2)
        diff_w = skip.size(-1) - up.size(-1)
        if diff_h == 0 and diff_w == 0:
            return up
        # F.pad takes (left, right, top, bottom) for the last two dimensions
        return F.pad(up, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])

    def _up_block(self, upconv, dec, below, skip):
        """Upsample ``below``, align it to ``skip``, concatenate ``[skip, up]`` and refine with ``dec``."""
        up = self._pad_to_match(upconv(below), skip)
        return dec(torch.cat((skip, up), dim=1))

    def forward(self, x):
        """Map an (N, in_ch, H, W) batch to an (N, out_ch, H, W) prediction."""
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        bn = self.bottleneck(self.pool(e4))

        # Decoder with skip connections
        d4 = self._up_block(self.upconv4, self.dec4, bn, e4)
        d3 = self._up_block(self.upconv3, self.dec3, d4, e3)
        d2 = self._up_block(self.upconv2, self.dec2, d3, e2)
        d1 = self._up_block(self.upconv1, self.dec1, d2, e1)

        # Output (linear activation for depth)
        return self.outconv(d1)

In [4]:
# ==============================================
# 3. Hybrid Loss Function
# ==============================================

def _gaussian_window(window_size, sigma, channels, device, dtype):
    """Return a normalized 2-D Gaussian kernel shaped (channels, 1, k, k) for a depthwise conv."""
    coords = torch.arange(window_size, device=device, dtype=dtype) - (window_size - 1) / 2.0
    g = torch.exp(-coords.pow(2) / (2.0 * sigma ** 2))
    g = g / g.sum()
    kernel_2d = g[:, None] * g[None, :]
    return kernel_2d.expand(channels, 1, window_size, window_size).contiguous()


class _SSIM(nn.Module):
    """Mean structural similarity (SSIM) between two (N, C, H, W) batches.

    Local statistics use a Gaussian window (default 11x11, sigma 1.5) applied
    per channel. The stabilizers are C1 = 0.01**2 and C2 = 0.03**2, which
    assume a data range of 1 (values scaled to [0, 1]). Returns the mean of
    the SSIM map as a scalar tensor.
    """

    def __init__(self, window_size=11, sigma=1.5):
        super().__init__()
        self.window_size = window_size
        self.sigma = sigma
        self.c1 = 0.01 ** 2
        self.c2 = 0.03 ** 2

    def forward(self, x, y):
        channels = x.size(1)
        # Built per call on x's device/dtype, so the loss module never needs .to(device).
        window = _gaussian_window(self.window_size, self.sigma, channels, x.device, x.dtype)
        pad = self.window_size // 2

        def blur(t):
            return F.conv2d(t, window, padding=pad, groups=channels)

        mu_x, mu_y = blur(x), blur(y)
        mu_xx, mu_yy, mu_xy = mu_x * mu_x, mu_y * mu_y, mu_x * mu_y
        var_x = blur(x * x) - mu_xx
        var_y = blur(y * y) - mu_yy
        cov_xy = blur(x * y) - mu_xy
        ssim_map = ((2 * mu_xy + self.c1) * (2 * cov_xy + self.c2)) / (
            (mu_xx + mu_yy + self.c1) * (var_x + var_y + self.c2)
        )
        return ssim_map.mean()


class DepthLoss(nn.Module):
    """Hybrid depth loss: ``RMSE + mae_weight * MAE + ssim_weight * (1 - SSIM)``.

    Args:
        window_size: Gaussian window size for the SSIM term.
        eps: added inside the square root of the RMSE term. d sqrt(x)/dx is
            infinite at 0, so without it a perfect prediction yields NaN
            gradients.
        mae_weight: weight of the MAE term.
        ssim_weight: weight of the (1 - SSIM) term.
    """

    def __init__(self, window_size=11, eps=1e-8, mae_weight=0.5, ssim_weight=0.1):
        super().__init__()
        self.ssim = _SSIM(window_size=window_size)
        self.eps = eps
        self.mae_weight = mae_weight
        self.ssim_weight = ssim_weight

    def forward(self, pred, target):
        """Return the scalar loss for (N, C, H, W) ``pred`` and ``target``."""
        # RMSE
        rmse = torch.sqrt(F.mse_loss(pred, target) + self.eps)

        # MAE
        mae = F.l1_loss(pred, target)

        # SSIM (structural similarity)
        ssim_loss = 1 - self.ssim(pred, target)

        return rmse + self.mae_weight * mae + self.ssim_weight * ssim_loss

In [5]:
# ==============================================
# 4. Training Pipeline
# ==============================================

def train_model(img_dir='train/spad', depth_dir='train/depth', batch_size=8,
                epochs=50, lr=1e-4, save_path='spad_depth_model.pth', seed=None):
    """Train SPAD_UNet on paired SPAD/depth PNGs, save the final weights and return the model.

    Args:
        img_dir: directory of training SPAD images.
        depth_dir: directory of matching depth maps.
        batch_size: DataLoader batch size.
        epochs: number of training epochs.
        lr: AdamW learning rate.
        save_path: where the final ``state_dict`` is written.
        seed: if given, passed to ``torch.manual_seed`` for reproducible
            initialization and shuffling.

    Returns:
        The trained SPAD_UNet, left on the training device, so it can be passed
        straight to ``predict_depth`` / ``create_submission``.

    Raises:
        ValueError: if ``img_dir`` contains no ``.png`` files. Without this check,
            the per-epoch average would divide by zero.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data
    train_dataset = SPADDataset(img_dir=img_dir, depth_dir=depth_dir)
    if len(train_dataset) == 0:
        raise ValueError(f'No .png training images found in {img_dir!r}')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Model
    model = SPAD_UNet().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # Stepped on the *training* loss, since there is no validation split yet.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
    criterion = DepthLoss()

    # Training loop
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0

        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_loader)
        scheduler.step(avg_loss)

        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

        # TODO: validation split, best-checkpoint saving and early stopping.

    torch.save(model.state_dict(), save_path)
    return model

In [None]:
# ==============================================
# 5. Inference & Submission
# ==============================================

def predict_depth(model, img_path):
    """Generate depth map for a single SPAD image"""
    model.eval()
    with torch.no_grad():
        # Load and preprocess image
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        img_tensor = torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0) / 255.0
        
        # Convert to Poisson probability
        lambda_est = -torch.log(1 - img_tensor.mean(dim=1, keepdim=True) + 1e-6)
        
        # Predict
        depth = model(lambda_est.to(device))
        depth = depth.squeeze().cpu().numpy()
        
    return depth

def create_submission(model, test_dir, output_csv):
    """Write a Kaggle-style CSV with one row per predicted depth pixel.

    Every ``.png`` in ``test_dir`` is processed in sorted order. Each pixel of
    the flattened prediction becomes a row ``<file stem>_<flat index>,<depth>``
    with 6 decimal places, under the header ``id,depth``.
    """
    model.eval()
    png_names = sorted(name for name in os.listdir(test_dir) if name.endswith('.png'))

    with open(output_csv, 'w') as csv_file:
        csv_file.write('id,depth\n')

        for png_name in tqdm(png_names):
            stem = png_name[:-4]  # strip the '.png' suffix guaranteed by the filter above
            depth_map = predict_depth(model, os.path.join(test_dir, png_name))
            csv_file.writelines(
                f'{stem}_{pixel_idx},{depth:.6f}\n'
                for pixel_idx, depth in enumerate(depth_map.flatten())
            )

# Initialize and train.
# NOTE: inside a Jupyter kernel __name__ is "__main__", so this guard does not
# prevent training from starting whenever this cell is executed.
if __name__ == "__main__":
    train_model()