# **Local Excitation Network for Restoring a JPEG-Comp**

### **1) Environment check / input paths**

In [1]:
import os
import random
import math
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cpu


### **2) Install / import libraries**

In [2]:
# Correct DIV2K HR directory (your screenshot confirmed this path)
HR_DIR = "/kaggle/input/div2k-dataset/DIV2K_train_HR/DIV2K_train_HR"

# Show some images to confirm
print("Total HR images:", len(os.listdir(HR_DIR)))
print("Example files:", os.listdir(HR_DIR)[:10])


Total HR images: 800
Example files: ['0566.png', '0115.png', '0050.png', '0501.png', '0263.png', '0133.png', '0563.png', '0693.png', '0006.png', '0516.png']


### **3) Dataset class (on-the-fly JPEG compress & patch extraction)**

In [3]:
class DIV2KJPEGDataset(Dataset):
    def __init__(self, hr_dir, quality=10, patch_size=80, stride=79, augment=True, max_patches_per_image=40):
        self.hr_paths = sorted([
            os.path.join(hr_dir, f) for f in os.listdir(hr_dir)
            if f.lower().endswith(('.png','.jpg','.jpeg'))
        ])
        
        self.quality = quality
        self.patch_size = patch_size
        self.stride = stride
        self.augment = augment
        self.max_patches_per_image = max_patches_per_image
        
        self.index = []
        for i, p in enumerate(self.hr_paths):
            img = Image.open(p)
            w, h = img.size
            coords = []
            for x in range(0, max(1, w - patch_size), stride):
                for y in range(0, max(1, h - patch_size), stride):
                    coords.append((x, y))
            if len(coords) > max_patches_per_image:
                coords = random.sample(coords, max_patches_per_image)
            for (x, y) in coords:
                self.index.append((i, x, y))
        
        print(f"Dataset created: images={len(self.hr_paths)} patches={len(self.index)}")

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        img_idx, x, y = self.index[idx]
        path = self.hr_paths[img_idx]
        img = Image.open(path).convert('RGB')

        patch = img.crop((x, y, x+self.patch_size, y+self.patch_size))

        # JPEG compression
        from io import BytesIO
        buf = BytesIO()
        patch.save(buf, format='JPEG', quality=self.quality)
        buf.seek(0)
        jpeg_patch = Image.open(buf).convert('RGB')

        t_gt = TF.to_tensor(patch)
        t_jpeg = TF.to_tensor(jpeg_patch)

        if self.augment:
            if random.random()<0.5:
                t_gt = TF.hflip(t_gt); t_jpeg = TF.hflip(t_jpeg)
            if random.random()<0.5:
                t_gt = TF.vflip(t_gt); t_jpeg = TF.vflip(t_jpeg)

        return t_jpeg, t_gt


### **4) LEJR Compact model (PyTorch)**

In [4]:
def depthwise_conv(in_channels):
    return nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)

class LocalExcitationBlock(nn.Module):
    def __init__(self, ch=64):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)
        self.prelu = nn.PReLU()
        self.dw = depthwise_conv(ch)
        self.pw = nn.Conv2d(ch, ch, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        res = self.prelu(self.conv(x))
        w = self.sigmoid(self.pw(self.dw(res)))
        return res * w

class LEJR_Compact(nn.Module):
    def __init__(self, blocks=10, ch=64):
        super().__init__()
        self.entry = nn.Conv2d(3, ch, 3, padding=1)
        self.blocks = nn.ModuleList([LocalExcitationBlock(ch) for _ in range(blocks)])
        self.exit = nn.Conv2d(ch, 3, 3, padding=1)

    def forward(self, x):
        out = self.entry(x)
        for b in self.blocks:
            out = b(out)
        return self.exit(out) + x


### **5) Training loop (L1 then fine-tune L2). Small run configuration**

In [5]:
train_dataset = DIV2KJPEGDataset(
    HR_DIR,
    quality=10,
    patch_size=80,
    stride=79,
    augment=True,
    max_patches_per_image=40
)

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)


Dataset created: images=800 patches=32000


### **6) Test on LIVE1 and compute PSNR / save outputs**

In [6]:
model = LEJR_Compact().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.L1Loss()

epochs = 5  # increase to 40–100 for full training
for epoch in range(epochs):
    model.train()
    for jpeg, gt in tqdm(train_loader):
        jpeg = jpeg.to(device)
        gt = gt.to(device)

        optimizer.zero_grad()
        out = model(jpeg)
        loss = criterion(out, gt)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{epochs} | Loss: {loss.item():.6f}")


100%|██████████| 2000/2000 [1:35:39<00:00,  2.87s/it]


Epoch 1/5 | Loss: 0.035630


100%|██████████| 2000/2000 [1:34:00<00:00,  2.82s/it]


Epoch 2/5 | Loss: 0.035547


100%|██████████| 2000/2000 [1:35:00<00:00,  2.85s/it]


Epoch 3/5 | Loss: 0.030322


100%|██████████| 2000/2000 [1:34:40<00:00,  2.84s/it]


Epoch 4/5 | Loss: 0.025135


100%|██████████| 2000/2000 [1:34:01<00:00,  2.82s/it]

Epoch 5/5 | Loss: 0.030391





**7) Tips & Next steps**

If dataset paths differ, update detection in cell 1 and variables HR_DIR / TEST_DIR.

For full reproduction (paper):

Train longer (paper: many iterations; here we used a small number for demo).

Use full DIV2K (800 imgs), larger max_patches_per_image, more epochs.

Use the baseline LEJR (with down/up scaling & recursion depth) instead of compact for SOTA.

To use GAN (LEJR_GAN): add a discriminator model, VGG perceptual loss, alternate training. Ask me and I'll add the GAN cells.

To download outputs from Kaggle: go to Output panel or click the files in the left sidebar (/kaggle/working/restored).

In [7]:
torch.save(model.state_dict(), "/kaggle/working/lejr_compact.pth")
print("Model saved!")


Model saved!
