In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
# Fetch both Kaggle datasets through kagglehub, in this order.
# NOTE: the returned paths are not used further down; the dataset classes are
# constructed from hard-coded /kaggle/input/... directories instead.
anettvarghese_highrev_path, lei0331_highrev_testset_path = (
    kagglehub.dataset_download(handle)
    for handle in ('anettvarghese/highrev', 'lei0331/highrev-testset')
)

print('Data source import complete.')


In [None]:
# Install required packages
!pip install torch torchvision scikit-image
!pip install --upgrade matplotlib

# Set environment variable to reduce memory fragmentation
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as compute_psnr
from skimage.metrics import structural_similarity as compute_ssim
import cv2
import matplotlib.pyplot as plt
import gc
import time
from torch.optim.lr_scheduler import CosineAnnealingLR

# Collect unreachable Python objects first, so that empty_cache() can also
# hand back the cached CUDA blocks those objects were holding (matters when
# this cell is re-run in an existing kernel).
gc.collect()
torch.cuda.empty_cache()

def display_gpu_memory():
    """Print the current CUDA memory allocated/reserved by PyTorch, in GiB.

    Does nothing (and prints nothing) when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    bytes_per_gib = 1024 ** 3
    allocated_gib = torch.cuda.memory_allocated() / bytes_per_gib
    reserved_gib = torch.cuda.memory_reserved() / bytes_per_gib
    print(f"GPU memory allocated: {allocated_gib:.2f} GiB")
    print(f"GPU memory reserved: {reserved_gib:.2f} GiB")

# Function to convert raw events to voxel grid
def create_voxel_grid(x, y, timestamps, polarity, num_bins=5, height=128, width=128):
    """Accumulate events into a ``(num_bins, height, width)`` float32 voxel grid.

    Timestamps are min-max normalised over the given events and split into
    ``num_bins`` equal-width temporal bins. Each event adds its polarity value
    to ``grid[bin, int(y), int(x)]``. Events whose truncated coordinates fall
    outside the grid are dropped.

    Args:
        x, y: Event coordinates in grid pixel units (may be fractional;
            truncated toward zero like ``int()``).
        timestamps: Event timestamps (any numeric dtype).
        polarity: Per-event value added to the grid. NOTE(review): values are
            summed as-is; if the data encodes polarity as 0/1 rather than
            -1/+1, OFF events contribute nothing — confirm the encoding.
        num_bins, height, width: Output grid dimensions.

    Returns:
        float32 array of shape ``(num_bins, height, width)``; all zeros when
        there are no events.
    """
    grid = np.zeros((num_bins, height, width), dtype=np.float32)
    timestamps = np.asarray(timestamps)
    if timestamps.size == 0 or num_bins == 0:
        return grid

    t = timestamps.astype(np.float64)
    t_min, t_max = t.min(), t.max()
    if t_max > t_min:
        # Scale to [0, num_bins] and clamp the t_max events into the last bin.
        # Truncating (t / T) * (num_bins - 1) instead would leave the last bin
        # holding only events stamped exactly t_max.
        bins = np.floor((t - t_min) / (t_max - t_min) * num_bins).astype(np.int64)
        np.minimum(bins, num_bins - 1, out=bins)
    else:
        bins = np.zeros(t.shape, dtype=np.int64)

    # astype(int64) truncates toward zero, matching int() per coordinate.
    xs = np.asarray(x, dtype=np.float64).astype(np.int64)
    ys = np.asarray(y, dtype=np.float64).astype(np.int64)
    inside = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)

    # np.add.at is unbuffered: events that land in the same voxel all count
    # (fancy-index += would apply only one of them).
    values = np.asarray(polarity, dtype=np.float32)
    np.add.at(grid, (bins[inside], ys[inside], xs[inside]), values[inside])
    return grid

# Conv -> BatchNorm -> ReLU -> Dropout2d building block
class ConvLayer(nn.Module):
    """3x3 same-padding convolution followed by BatchNorm, ReLU and Dropout2d.

    Spatial size is preserved; channels change from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, dropout_rate=0.1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Dropout2d drops entire feature maps rather than single activations.
        self.dropout = nn.Dropout2d(dropout_rate)

    def forward(self, x):
        return self.dropout(self.relu(self.bn(self.conv(x))))

class AttentionModule(nn.Module):
    """Multi-head cross attention from image features to event features.

    Queries are 1x1 projections of ``img_features``; keys and values are 1x1
    projections of ``event_features``. Every spatial position attends to all
    ``H*W`` positions, so each head materialises an ``(H*W, H*W)`` score
    matrix. The projected result is added back onto ``img_features``.
    ``event_features`` must have the same shape as ``img_features``.
    """

    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        assert self.head_dim * num_heads == channels, "channels must be divisible by num_heads"

        self.query_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.key_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.value_conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)
        self.out_conv = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, img_features, event_features):
        b, c, h, w = img_features.size()
        tokens = h * w
        heads, dim = self.num_heads, self.head_dim

        # Split channels into heads: (b, heads, dim, tokens).
        q = self.query_conv(img_features).view(b, heads, dim, tokens).permute(0, 1, 3, 2)   # (b, heads, tokens, dim)
        k = self.key_conv(event_features).view(b, heads, dim, tokens)                      # (b, heads, dim, tokens)
        v = self.value_conv(event_features).view(b, heads, dim, tokens).permute(0, 1, 3, 2)  # (b, heads, tokens, dim)

        # Scaled dot-product attention over all spatial positions.
        weights = self.softmax(torch.matmul(q, k) / (dim ** 0.5))  # (b, heads, tokens, tokens)
        attended = torch.matmul(weights, v)                        # (b, heads, tokens, dim)
        attended = attended.permute(0, 1, 3, 2).contiguous().view(b, c, h, w)

        return self.out_conv(attended) + img_features

class DeblurNet(nn.Module):
    """Two-scale image/event fusion network for deblurring.

    Image and event-voxel inputs are each encoded to 128 channels at full
    resolution. At that scale image features attend to event features.
    The fused map is pooled, re-encoded to 256 channels, and fused again
    against pooled event features. The coarse result is upsampled, added to
    the full-resolution fused map, and projected to ``out_channels``.

    Input height and width must be even so that the pooled/upsampled map
    lines up with the full-resolution skip connection.
    """

    def __init__(self, in_channels=3, event_channels=5, out_channels=3):
        super().__init__()
        # Full-resolution encoders (128 channels).
        self.img_conv1 = ConvLayer(in_channels, 128)
        self.evt_conv1 = ConvLayer(event_channels, 128)
        self.pool = nn.MaxPool2d(2, 2)

        # Half-resolution encoders (256 channels).
        self.img_conv2 = ConvLayer(128, 256)
        self.evt_conv2 = ConvLayer(128, 256)

        # Cross-modal attention, one per scale.
        self.attention1 = AttentionModule(128, num_heads=4)
        self.attention2 = AttentionModule(256, num_heads=4)

        # Refinement stacks of three plain ConvLayers each. Despite the
        # "res_blocks" name there is no identity shortcut around them.
        self.res_blocks1 = nn.Sequential(*(ConvLayer(128, 128) for _ in range(3)))
        self.res_blocks2 = nn.Sequential(*(ConvLayer(256, 256) for _ in range(3)))

        # Decoder: 2x learned upsampling, then projection to output channels.
        self.upconv = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(128, out_channels, kernel_size=3, padding=1)

    def forward(self, blur_img, event_data):
        # Full resolution: encode both modalities, fuse, refine.
        img_feat = self.img_conv1(blur_img)
        evt_feat = self.evt_conv1(event_data)
        fine = self.res_blocks1(self.attention1(img_feat, evt_feat))

        # Half resolution: pooled fused features vs pooled event features.
        coarse_img = self.img_conv2(self.pool(fine))
        coarse_evt = self.evt_conv2(self.pool(evt_feat))
        coarse = self.res_blocks2(self.attention2(coarse_img, coarse_evt))

        # Upsample, add the full-resolution skip, project to the output image.
        return self.final_conv(self.upconv(coarse) + fine)

# Custom Dataset for HighREV (Training/Validation)
class HighREVDataset(Dataset):
    """Paired blurry/sharp HighREV images with event voxel grids.

    Expects ``<data_dir>/<phase>/{blur,sharp,event}``. A sharp image must
    share its blurry image's filename. Each blurry image is paired with one
    event file holding ``x``, ``y``, ``timestamp`` and ``polarity`` arrays.
    Items are ``(blur, events, sharp, blur_file, (orig_width, orig_height))``
    with images as RGB CHW float tensors in [0, 1] at ``target_size``.

    Raises:
        FileNotFoundError: at construction if some blurry image has no event
            file, or in ``__getitem__`` if an image cannot be decoded.
    """

    def __init__(self, data_dir, phase='train', target_size=(128, 128), num_bins=5):
        self.data_dir = os.path.join(data_dir, phase)
        self.blur_dir = os.path.join(self.data_dir, 'blur')
        self.sharp_dir = os.path.join(self.data_dir, 'sharp')
        self.event_dir = os.path.join(self.data_dir, 'event')
        self.target_size = target_size
        self.num_bins = num_bins

        self.blur_files = sorted(os.listdir(self.blur_dir))
        self.event_files = sorted(os.listdir(self.event_dir))
        self.file_map = self._match_event_files(self.blur_files, self.event_files)

        # Fail fast instead of a KeyError deep inside a training epoch.
        unmatched = [name for name in self.blur_files if name not in self.file_map]
        if unmatched:
            raise FileNotFoundError(
                f"No event file in {self.event_dir} for {len(unmatched)} blur image(s), "
                f"e.g. {unmatched[:5]}"
            )

    @staticmethod
    def _match_event_files(blur_files, event_files):
        """Map each blur filename to an event filename.

        Preference order: an event file with exactly the same stem, then the
        first event file containing the blur stem as a whole token (not
        followed/preceded by a letter or digit, so ``img_1`` does not pick up
        ``img_10_events.npz``), then the first plain substring match.
        """
        import re

        by_stem = {}
        for name in event_files:
            by_stem.setdefault(os.path.splitext(name)[0], name)

        file_map = {}
        for blur_file in blur_files:
            blur_base = os.path.splitext(blur_file)[0]
            match = by_stem.get(blur_base)
            if match is None:
                token = re.compile(r'(?<![0-9A-Za-z])' + re.escape(blur_base) + r'(?![0-9A-Za-z])')
                match = next((name for name in event_files if token.search(name)), None)
            if match is None:
                match = next((name for name in event_files if blur_base in name), None)
            if match is not None:
                file_map[blur_file] = match
        return file_map

    @staticmethod
    def _load_events(path):
        """Return ``(x, y, timestamp, polarity)`` arrays from an event file.

        The ``NpzFile`` returned by ``np.load`` holds the archive open, so it
        is closed once the arrays have been read into memory.
        """
        npz_data = np.load(path)
        try:
            return npz_data['x'], npz_data['y'], npz_data['timestamp'], npz_data['polarity']
        finally:
            close = getattr(npz_data, 'close', None)
            if close is not None:
                close()

    def _load_rgb(self, path, kind):
        """Load an image as an RGB CHW float tensor in [0, 1] at ``target_size``.

        Returns the tensor and the image's original ``(width, height)``.
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f"Could not load {kind} image: {path}")
        height, width = img.shape[:2]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.target_size[1], self.target_size[0]))
        img = img.transpose(2, 0, 1) / 255.0
        return torch.FloatTensor(img), (width, height)

    def __len__(self):
        return len(self.blur_files)

    def __getitem__(self, idx):
        blur_file = self.blur_files[idx]
        blur_path = os.path.join(self.blur_dir, blur_file)
        blur_img, (original_width, original_height) = self._load_rgb(blur_path, 'blur')
        sharp_img, _ = self._load_rgb(os.path.join(self.sharp_dir, blur_file), 'sharp')

        event_path = os.path.join(self.event_dir, self.file_map[blur_file])
        x, y, timestamps, polarity = self._load_events(event_path)

        # Rescale event coordinates from the original image grid to target_size.
        # NOTE(review): assumes events share the blurry image's pixel grid — confirm.
        x = (x / original_width) * self.target_size[1]
        y = (y / original_height) * self.target_size[0]

        event_data = create_voxel_grid(
            x, y, timestamps, polarity, self.num_bins, self.target_size[0], self.target_size[1]
        )
        if idx == 0:
            print(f"Event data shape: {event_data.shape}")
            print(f"Blurry image shape: {blur_img.shape}")
        event_data = torch.FloatTensor(event_data)

        return blur_img, event_data, sharp_img, blur_file, (int(original_width), int(original_height))

# Custom Dataset for Test (No Sharp Images)
class HighREVTestDataset(Dataset):
    """Blurry HighREV test images with event voxel grids (no ground truth).

    Reads ``<data_dir>/{blur,event}``. Items are
    ``(blur, events, blur_file, (orig_width, orig_height))`` with the image
    as an RGB CHW float tensor in [0, 1] at ``target_size``.

    Raises:
        FileNotFoundError: at construction if some blurry image has no event
            file, or in ``__getitem__`` if an image cannot be decoded.
    """

    def __init__(self, data_dir, target_size=(128, 128), num_bins=5):
        self.data_dir = data_dir
        self.blur_dir = os.path.join(self.data_dir, 'blur')
        self.event_dir = os.path.join(self.data_dir, 'event')
        self.target_size = target_size
        self.num_bins = num_bins

        self.blur_files = sorted(os.listdir(self.blur_dir))
        self.event_files = sorted(os.listdir(self.event_dir))
        self.file_map = self._match_event_files(self.blur_files, self.event_files)

        # Fail fast instead of a KeyError in the middle of inference.
        unmatched = [name for name in self.blur_files if name not in self.file_map]
        if unmatched:
            raise FileNotFoundError(
                f"No event file in {self.event_dir} for {len(unmatched)} blur image(s), "
                f"e.g. {unmatched[:5]}"
            )

    @staticmethod
    def _match_event_files(blur_files, event_files):
        """Map each blur filename to an event filename.

        Preference order: an event file with exactly the same stem, then the
        first event file containing the blur stem as a whole token (not
        followed/preceded by a letter or digit, so ``img_1`` does not pick up
        ``img_10_events.npz``), then the first plain substring match.
        """
        import re

        by_stem = {}
        for name in event_files:
            by_stem.setdefault(os.path.splitext(name)[0], name)

        file_map = {}
        for blur_file in blur_files:
            blur_base = os.path.splitext(blur_file)[0]
            match = by_stem.get(blur_base)
            if match is None:
                token = re.compile(r'(?<![0-9A-Za-z])' + re.escape(blur_base) + r'(?![0-9A-Za-z])')
                match = next((name for name in event_files if token.search(name)), None)
            if match is None:
                match = next((name for name in event_files if blur_base in name), None)
            if match is not None:
                file_map[blur_file] = match
        return file_map

    @staticmethod
    def _load_events(path):
        """Return ``(x, y, timestamp, polarity)`` arrays from an event file.

        The ``NpzFile`` returned by ``np.load`` holds the archive open, so it
        is closed once the arrays have been read into memory.
        """
        npz_data = np.load(path)
        try:
            return npz_data['x'], npz_data['y'], npz_data['timestamp'], npz_data['polarity']
        finally:
            close = getattr(npz_data, 'close', None)
            if close is not None:
                close()

    def __len__(self):
        return len(self.blur_files)

    def __getitem__(self, idx):
        blur_file = self.blur_files[idx]
        blur_path = os.path.join(self.blur_dir, blur_file)
        blur_img = cv2.imread(blur_path)
        if blur_img is None:
            raise FileNotFoundError(f"Could not load blur image: {blur_path}")
        # Size of the decoded image; reused below instead of decoding it again.
        original_height, original_width = blur_img.shape[:2]
        blur_img = cv2.cvtColor(blur_img, cv2.COLOR_BGR2RGB)
        blur_img = cv2.resize(blur_img, (self.target_size[1], self.target_size[0]))
        blur_img = blur_img.transpose(2, 0, 1) / 255.0
        blur_img = torch.FloatTensor(blur_img)

        event_path = os.path.join(self.event_dir, self.file_map[blur_file])
        x, y, timestamps, polarity = self._load_events(event_path)

        # Rescale event coordinates from the original image grid to target_size.
        # NOTE(review): assumes events share the blurry image's pixel grid — confirm.
        x = (x / original_width) * self.target_size[1]
        y = (y / original_height) * self.target_size[0]

        event_data = create_voxel_grid(
            x, y, timestamps, polarity, self.num_bins, self.target_size[0], self.target_size[1]
        )
        if idx == 0:
            print(f"Test event data shape: {event_data.shape}")
            print(f"Test blurry image shape: {blur_img.shape}")
        event_data = torch.FloatTensor(event_data)

        original_size = (int(original_width), int(original_height))
        print(f"Dataset item {idx}, blur_file: {blur_file}, original_size: {original_size}")
        return blur_img, event_data, blur_file, original_size

# Collate function that keeps each item's original size as a (width, height) tuple
def custom_collate_fn(batch):
    """Collate ``(blurry, events, blur_filename, original_size)`` items.

    Image and event tensors are stacked along a new batch dimension.
    Filenames and ``(width, height)`` sizes are returned as plain lists with
    one entry per item, instead of being converted to tensors.
    """
    blurry, events, names, sizes = zip(*batch)
    stacked_blurry = torch.stack(blurry)
    stacked_events = torch.stack(events)
    return stacked_blurry, stacked_events, list(names), list(sizes)

# Enable mixed precision training
from torch.cuda.amp import GradScaler, autocast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model and optimisation: MSE loss, Adam with weight decay, AMP loss scaling,
# and cosine annealing with T_max=12 scheduler steps.
model = DeblurNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
scaler = GradScaler()
scheduler = CosineAnnealingLR(optimizer, T_max=12)

# Datasets: HighREV train/val splits plus the test set, all resized to 128x128.
highrev_root = '/kaggle/input/highrev/HighREV'
image_size = (128, 128)
train_dataset = HighREVDataset(highrev_root, phase='train', target_size=image_size)
val_dataset = HighREVDataset(highrev_root, phase='val', target_size=image_size)
test_dataset = HighREVTestDataset('/kaggle/input/highrev-testset/HighREV_test/HighREV_test', target_size=image_size)

# All loaders use batch size 1; only the test loader uses custom_collate_fn,
# which keeps original (width, height) tuples intact.
loader_kwargs = dict(batch_size=1, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, collate_fn=custom_collate_fn, **loader_kwargs)

num_epochs = 6
accumulation_steps = 2
# Autocast only does anything on CUDA; disabling it explicitly on CPU avoids
# autocast's "CUDA is not available" warning on every step.
use_amp = device.type == 'cuda'
steps_per_epoch = len(train_loader)
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()
    for i, (blurry, events, sharp, _, _) in enumerate(train_loader):
        blurry, events, sharp = blurry.to(device), events.to(device), sharp.to(device)

        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model(blurry, events)
            # Divide so the accumulated gradient averages over micro-batches.
            loss = criterion(outputs, sharp) / accumulation_steps

        scaler.scale(loss).backward()
        # Also step on the epoch's last batch: otherwise a trailing partial
        # accumulation is silently wiped by the next epoch's zero_grad().
        if (i + 1) % accumulation_steps == 0 or (i + 1) == steps_per_epoch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            # Release cached blocks once per optimizer step (not every
            # micro-batch — a full gc.collect() per iteration is costly).
            torch.cuda.empty_cache()
            gc.collect()

        running_loss += loss.item() * accumulation_steps
        if i % 100 == 99:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')
            running_loss = 0.0

    scheduler.step()
    print(f"Epoch {epoch+1}, Learning Rate: {scheduler.get_last_lr()[0]:.6f}")
    display_gpu_memory()

# Save a checkpoint with the epoch count plus model, optimizer, LR-scheduler
# and AMP grad-scaler state dicts.
checkpoint_path = '/kaggle/working/checkpoint_epoch_6.pth'
checkpoint = dict(
    epoch=num_epochs,
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    scheduler_state_dict=scheduler.state_dict(),
    scaler_state_dict=scaler.state_dict(),
)
torch.save(checkpoint, checkpoint_path)
print(f"Checkpoint saved at: {checkpoint_path}")

# Evaluation on validation set for metrics
# Metrics are computed on the model's target_size output, not at the original
# image resolution.
model.eval()
psnr_scores = []
ssim_scores = []
val_time = 0.0
# Autocast only does anything on CUDA; disable it explicitly elsewhere.
use_amp = device.type == 'cuda'
print(f"Evaluating {len(val_loader)} validation images for metrics...")

with torch.no_grad():
    for i, (blurry, events, sharp, blur_filename, _) in enumerate(val_loader):
        start_time = time.time()
        blurry, events, sharp = blurry.to(device), events.to(device), sharp.to(device)
        with torch.amp.autocast('cuda', enabled=use_amp):
            deblurred = model(blurry, events)

        deblurred_np = deblurred.cpu().numpy()
        sharp_np = sharp.cpu().numpy()

        for j in range(deblurred_np.shape[0]):
            # CHW -> HWC float32 in [0, 1] (autocast may have produced float16).
            deblurred_img = np.clip(deblurred_np[j].transpose(1, 2, 0).astype(np.float32), 0, 1)
            sharp_img = np.clip(sharp_np[j].transpose(1, 2, 0).astype(np.float32), 0, 1)

            psnr_score = compute_psnr(sharp_img, deblurred_img, data_range=1.0)
            # SSIM's default 7x7 window needs both sides to be at least 7 px.
            if deblurred_img.shape[0] < 7 or deblurred_img.shape[1] < 7:
                print(f"Warning: Validation image {blur_filename[j]} too small for SSIM")
                ssim_score = 0
            else:
                # channel_axis alone selects per-channel SSIM; the old
                # `multichannel` flag is deprecated/removed in scikit-image.
                ssim_score = compute_ssim(sharp_img, deblurred_img, data_range=1.0, channel_axis=2)
            psnr_scores.append(psnr_score)
            ssim_scores.append(ssim_score)

        end_time = time.time()
        val_time += end_time - start_time

        del deblurred, blurry, events, sharp, deblurred_np, sharp_np
        torch.cuda.empty_cache()
        gc.collect()

# Process test set and save outputs
total_time = 0.0
output_dir = '/kaggle/working/test_output_epoch_6'
os.makedirs(output_dir, exist_ok=True)
# Autocast only does anything on CUDA; disable it explicitly elsewhere.
use_amp = device.type == 'cuda'
print(f"Processing {len(test_loader)} test images...")

with torch.no_grad():
    for i, (blurry, events, blur_filename, original_size) in enumerate(test_loader):
        start_time = time.time()
        blurry, events = blurry.to(device), events.to(device)
        with torch.amp.autocast('cuda', enabled=use_amp):
            deblurred = model(blurry, events)

        deblurred_np = deblurred.cpu().numpy()
        blurry_np = blurry.cpu().numpy()
        # Last post-processed image of this batch (float RGB in [0, 1]); used
        # for the preview. Stays None if every image in the batch is skipped.
        preview = None

        for j in range(deblurred_np.shape[0]):
            deblurred_img = deblurred_np[j].transpose(1, 2, 0).astype(np.float32)
            print(f"Processing test image {blur_filename[j]}")
            print(f"deblurred_img shape: {deblurred_img.shape}, dtype: {deblurred_img.dtype}, min: {deblurred_img.min()}, max: {deblurred_img.max()}")

            if deblurred_img.size == 0 or deblurred_img.shape[0] == 0 or deblurred_img.shape[1] == 0:
                print(f"Warning: Invalid deblurred_img for {blur_filename[j]}, skipping")
                continue

            deblurred_img = np.clip(deblurred_img, 0, 1)

            # original_size is a list of (width, height) tuples (custom_collate_fn).
            print(f"Batch {i}, j={j}, original_size[{j}]: {original_size[j]}")
            if not isinstance(original_size[j], (tuple, list)) or len(original_size[j]) != 2:
                raise ValueError(f"Invalid original_size for {blur_filename[j]}: {original_size[j]}. Expected a 2-element tuple (width, height).")

            original_width = int(original_size[j][0])
            original_height = int(original_size[j][1])

            deblurred_resized = cv2.resize(deblurred_img, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
            print(f"deblurred_resized shape: {deblurred_resized.shape}")

            # Unsharp masking: 1.5 * img - 0.5 * gaussian(img). This overshoots
            # [0, 1] around edges, so clip before converting to uint8 —
            # casting out-of-range floats to uint8 wraps/garbles the values
            # and shows up as speckles.
            blurred = cv2.GaussianBlur(deblurred_resized, (5, 5), 0)
            deblurred_resized = cv2.addWeighted(deblurred_resized, 1.5, blurred, -0.5, 0)
            deblurred_resized = np.clip(deblurred_resized, 0, 1)
            preview = deblurred_resized

            # Round (rather than truncate) to the nearest 8-bit level.
            deblurred_img = np.rint(deblurred_resized * 255).astype(np.uint8)
            print(f"deblurred_img shape before save: {deblurred_img.shape}")
            deblurred_img = cv2.cvtColor(deblurred_img, cv2.COLOR_RGB2BGR)
            output_path = os.path.join(output_dir, blur_filename[j])
            print(f"Saving deblurred test image to: {output_path}")
            success = cv2.imwrite(output_path, deblurred_img, [cv2.IMWRITE_PNG_COMPRESSION, 0])
            if not success:
                print(f"Failed to save {output_path}")
            else:
                print(f"Successfully saved {output_path}")

        end_time = time.time()
        processing_time = end_time - start_time
        total_time += processing_time
        print(f"Processing time for test batch {i+1}: {processing_time:.2f} seconds")

        if i == 0 and preview is not None:
            try:
                plt.figure(figsize=(10, 5))
                plt.subplot(1, 2, 1)
                plt.title("Blurry Test Image")
                plt.imshow(blurry_np[0].transpose(1, 2, 0))
                plt.axis('off')
                plt.subplot(1, 2, 2)
                plt.title("Deblurred Test Image")
                plt.imshow(preview)
                plt.axis('off')
                plt.show()
            except Exception as e:
                print(f"Error during visualization: {e}")

        # Only names that are bound on every path (the per-image locals may
        # be unset if every image in the batch was skipped).
        del deblurred, blurry, events, deblurred_np, blurry_np, preview
        torch.cuda.empty_cache()
        gc.collect()
        display_gpu_memory()

# Summary
avg_psnr = np.mean(psnr_scores) if psnr_scores else 0.0
avg_ssim = np.mean(ssim_scores) if ssim_scores else 0.0
avg_time_per_image = total_time / (len(test_loader.dataset)) if len(test_loader.dataset) > 0 else 0.0
avg_val_time_per_image = val_time / len(val_loader.dataset) if len(val_loader.dataset) > 0 else 0.0
print(f"Average PSNR (Validation) after 6 epochs: {avg_psnr:.2f}")
print(f"Average SSIM (Validation) after 6 epochs: {avg_ssim:.4f}")
print(f"Average runtime per test image: {avg_time_per_image:.2f} seconds")

device_type = 0 if torch.cuda.is_available() and device.type == 'cuda' else 1
extra_data = 0

# Create readme.txt next to the deblurred test images.
readme_content = f"""runtime per image [s] : {avg_time_per_image:.2f}
CPU[1] / GPU[0] : {device_type}
Extra Data [1] / No Extra Data [0] : {extra_data}
Average PSNR (Validation) : {avg_psnr:.2f}
Average SSIM (Validation) : {avg_ssim:.4f}
Other description : Intermediate results after 6 epochs. Enhanced EFNet model with U-Net-like architecture and multi-head CrossModalAttention for event-based deblurring. Trained on HighREV training dataset with CosineAnnealingLR scheduling, lower learning rate (0.0005), and weight decay (1e-5). Batch size 1 with accumulation_steps=2 to simulate batch_size=2. Reduced num_bins to 5 for event data. Post-processing includes unsharp masking. Checkpoint saved at /kaggle/working/checkpoint_epoch_6.pth. Deblurred test images saved in /kaggle/working/test_output_epoch_6 as full-quality RGB PNGs at original resolution (1632x1224) with no compression.
"""
readme_path = os.path.join(output_dir, 'readme.txt')
with open(readme_path, 'w', encoding='utf-8') as readme_file:
    readme_file.write(readme_content)
print(f"readme.txt saved at: {readme_path}")
print(f"Deblurred test images saved in: {output_dir}")
print("Check the Kaggle output section under '/kaggle/working/test_output_epoch_6' for images and readme.txt")