# **Low-Light Image Enhancement with a Sparse Transformer (CGSformer)**

In [None]:
!pip install pandas openpyxl scikit-image pytorch-ssim

Collecting pytorch-ssim
  Downloading pytorch_ssim-0.1.tar.gz (1.4 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytorch-ssim
  Building wheel for pytorch-ssim (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-ssim: filename=pytorch_ssim-0.1-py3-none-any.whl size=2006 sha256=d82a67a2007bcf93f9c99038b4aa3b71bebf72cd362c5be81a3af504523dc007
  Stored in directory: /root/.cache/pip/wheels/58/68/a2/68a41e8268a076c128bbc3988d243187fa4681828e648bf1ca
Successfully built pytorch-ssim
Installing collected packages: pytorch-ssim
Successfully installed pytorch-ssim-0.1


In [None]:
import pytorch_ssim
import torch.nn as nn
class CGSformerLoss(nn.Module):
    """Weighted pixel + structural reconstruction loss.

    loss = alpha * MSE(output, target) + beta * (1 - SSIM(output, target))

    SSIM comes from ``pytorch_ssim.SSIM`` with an 11x11 window.

    NOTE(review): the training cell below does not use this class; it computes
    the same 0.7 / 0.3 weighting inline with the notebook's own ``ssim()``.

    Args:
        alpha: Weight of the MSE term.
        beta: Weight of the (1 - SSIM) term.
    """

    def __init__(self, alpha=0.7, beta=0.3):
        super().__init__()
        self.mse = nn.MSELoss()
        self.ssim = pytorch_ssim.SSIM(window_size=11)
        self.alpha, self.beta = alpha, beta

    def forward(self, output, target):
        weighted_mse = self.alpha * self.mse(output, target)
        weighted_dssim = self.beta * (1 - self.ssim(output, target))
        return weighted_mse + weighted_dssim


In [None]:
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os
import random

class LOLDataset(Dataset):
    """Paired low/normal-light dataset (LOL folder layout) yielding aligned random patches.

    Expects ``root_dir/low`` and ``root_dir/high``. Images are paired by position
    after sorting each folder's file names, so both folders must hold the same
    number of images whose sorted orders correspond (LOL uses identical names).

    Each item is the same random ``patch_size`` crop of both images, with shared
    random horizontal/vertical flips, followed by ``transform`` if given.

    Args:
        root_dir: Split directory containing ``low`` and ``high`` sub-folders.
        transform: Optional callable applied to each PIL patch
            (e.g. ``transforms.ToTensor()``); without it PIL images are returned.
        patch_size: Side length of the square random crop.

    Raises:
        ValueError: In ``__init__`` if the folders hold different numbers of
            images; in ``__getitem__`` if a pair has different image sizes.
    """

    def __init__(self, root_dir, transform=None, patch_size=128):
        self.low_light_dir = os.path.join(root_dir, 'low')
        self.high_light_dir = os.path.join(root_dir, 'high')
        self.low_light_images = self._list_images(self.low_light_dir)
        self.high_light_images = self._list_images(self.high_light_dir)
        # Index-based pairing is only meaningful when both sides have equal length.
        if len(self.low_light_images) != len(self.high_light_images):
            raise ValueError(
                f"Unpaired dataset: {len(self.low_light_images)} files in {self.low_light_dir} "
                f"vs {len(self.high_light_images)} files in {self.high_light_dir}"
            )
        self.transform = transform
        self.patch_size = patch_size

    @staticmethod
    def _list_images(directory):
        """Return the sorted names of regular, non-hidden files in ``directory``."""
        # Hidden entries and sub-directories (e.g. `.ipynb_checkpoints`, `.DS_Store`)
        # would otherwise shift the positional low/high pairing.
        return sorted(
            name for name in os.listdir(directory)
            if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.low_light_images)

    def __getitem__(self, idx):
        low_image_path = os.path.join(self.low_light_dir, self.low_light_images[idx])
        high_image_path = os.path.join(self.high_light_dir, self.high_light_images[idx])

        # `with` releases the file handles deterministically.
        with Image.open(low_image_path) as img:
            low_img = img.convert('RGB')
        with Image.open(high_image_path) as img:
            high_img = img.convert('RGB')

        # The crop window is computed on the low image and reused for the high one,
        # so the pair must have identical dimensions.
        if low_img.size != high_img.size:
            raise ValueError(
                f"Size mismatch for pair {idx}: {low_image_path} is {low_img.size}, "
                f"{high_image_path} is {high_img.size}"
            )

        # Random crop, identical window for both images (torchvision raises if the
        # image is smaller than patch_size).
        i, j, h, w = transforms.RandomCrop.get_params(low_img, output_size=(self.patch_size, self.patch_size))
        low_img = transforms.functional.crop(low_img, i, j, h, w)
        high_img = transforms.functional.crop(high_img, i, j, h, w)

        # Random flips, applied to both images together.
        if random.random() > 0.5:
            low_img = transforms.functional.hflip(low_img)
            high_img = transforms.functional.hflip(high_img)
        if random.random() > 0.5:
            low_img = transforms.functional.vflip(low_img)
            high_img = transforms.functional.vflip(high_img)

        if self.transform:
            low_img = self.transform(low_img)
            high_img = self.transform(high_img)

        return low_img, high_img

# Example transforms: PIL image -> float tensor (ToTensor only; no resize or augmentation).
# NOTE(review): the visible training cell passes transforms.ToTensor() to LOLDataset
# directly rather than using this object.
train_transforms = transforms.Compose(
    [transforms.ToTensor()]
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CFS(nn.Module):
    """Cross Feature Scrambling: variance-scored channel split with a global gate.

    Steps:
      1. Normalize the input with a single-group GroupNorm.
      2. Score each channel by its share of the total spatial variance, passed
         through a sigmoid.
      3. Split channels into an "informative" set (score > threshold) and the rest.
      4. Blend the two sets with a per-channel gate: the sigmoid of the global
         average of the normalized map.

    NOTE(review): the variance share lies in [0, 1), so the sigmoid score is
    always >= 0.5. With the default threshold 0.5, essentially every channel
    with non-zero variance counts as informative, and the output reduces to
    ``sigmoid(avgpool(gn(x))) * gn(x)``. This is kept as-is because trained
    checkpoints were produced with this behaviour.

    Args:
        channels: Number of input channels (GroupNorm size).
        threshold: Score cut-off between informative and non-informative channels.
    """

    def __init__(self, channels, threshold=0.5):
        super().__init__()
        self.threshold = threshold
        self.sigmoid = nn.Sigmoid()
        self.gn = nn.GroupNorm(1, channels)

    def forward(self, x):
        normed = self.gn(x)

        # Per-channel spatial variance as a fraction of the summed variance.
        chan_var = normed.var(dim=[2, 3], keepdim=True)
        share = chan_var / (chan_var.sum(dim=1, keepdim=True) + 1e-6)
        score = self.sigmoid(share)

        info_part = (score > self.threshold).float() * normed
        rest_part = (score <= self.threshold).float() * normed

        gate = self.sigmoid(F.adaptive_avg_pool2d(info_part + rest_part, (1, 1)))
        return gate * info_part + (1 - gate) * rest_part

class ASA(nn.Module):
    """Adaptive Shift Attention: top-k sparse self-attention over spatial positions.

    Expects ``x`` as (B, C, H, W). The H*W positions become N tokens of dimension C.
    Each query keeps only the keys whose score is >= its k-th largest score
    (k = int(N * topk_ratio), clamped to [1, N]). All other scores are set to
    -inf before the softmax. The output has the same shape as ``x``.

    Note: the full (B, N, N) score matrix is materialized, so memory grows
    quadratically with H*W (16384 x 16384 scores per image at 128x128).

    Args:
        channels: Channel count C (query/key/value dimension).
        topk_ratio: Fraction of keys kept per query.
    """

    def __init__(self, channels, topk_ratio=0.5):
        super(ASA, self).__init__()
        self.topk_ratio = topk_ratio
        self.query_conv = nn.Conv2d(channels, channels, 1)
        self.key_conv = nn.Conv2d(channels, channels, 1)
        self.value_conv = nn.Conv2d(channels, channels, 1)
        self.scale = channels ** -0.5

    def forward(self, x):
        # (B, C, H, W) -> (B, N, C) with N = H * W
        q = self.query_conv(x).flatten(2).transpose(1, 2)
        k = self.key_conv(x).flatten(2).transpose(1, 2)
        v = self.value_conv(x).flatten(2).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, N, N)
        num_keys = attn.size(-1)
        # int() truncation yields 0 for tiny inputs (e.g. N == 1), and a ratio
        # above 1 exceeds N; either made topk / topk_values[..., -1] fail.
        topk = min(num_keys, max(1, int(num_keys * self.topk_ratio)))
        topk_values, _ = torch.topk(attn, k=topk, dim=-1)
        threshold = topk_values[:, :, -1].unsqueeze(-1)
        # ">=" keeps ties with the k-th score, so a row may keep more than k keys.
        mask = attn >= threshold
        attn = attn.masked_fill(~mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)

        out = attn @ v  # (B, N, C)
        out = out.transpose(1, 2).reshape(x.size())
        return out

class BGFF(nn.Module):
    """Bilateral Grid Feature Fusion (name from the original notes): residual gated depth-wise conv block.

    out = conv2(swish(dw3x3(h)) * swish(dw7x7(h))) + x, with h = conv1(x).

    The two depth-wise branches (3x3 and 7x7 receptive fields) gate each other
    multiplicatively. Shape is preserved: (B, C, H, W) -> (B, C, H, W).

    NOTE(review): despite the name, the visible code performs no bilateral-grid
    operations; it is a pure convolutional block.

    Args:
        channels: Channel count C.
    """

    def __init__(self, channels):
        super(BGFF, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, 1)
        self.conv_dw3x3 = nn.Conv2d(channels, channels, 3, padding=1, groups=channels)
        self.conv_dw7x7 = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.conv2 = nn.Conv2d(channels, channels, 1)
        # nn.SiLU computes x * sigmoid(x) (Swish). Unlike the former lambda attribute,
        # it keeps the module picklable; it has no parameters, so state_dict is unchanged.
        self.swish = nn.SiLU()

    def forward(self, x):
        out = self.conv1(x)
        path1 = self.swish(self.conv_dw3x3(out))
        path2 = self.swish(self.conv_dw7x7(out))
        out = path1 * path2
        out = self.conv2(out)
        return out + x

class CGSformerBlock(nn.Module):
    """One CGSformer stage: CFS gating -> sparse attention -> gated conv fusion.

    forward(x):
        x_cfs = CFS(x)
        x = ASA(norm1(x_cfs)) + x    # residual adds the un-gated input x
        x = BGFF(norm2(x)) + x

    NOTE(review): BGFF already returns its input plus the fused features, so the
    second step effectively adds both norm2(x) and x.

    The LayerNorms normalize over (channels, H, W) with per-element affine
    weights. The block therefore only accepts inputs whose spatial size equals
    ``spatial_size``.

    Args:
        channels: Channel count C.
        spatial_size: Expected input (H, W), or an int for square inputs.
            The default of 128 matches the previously hard-coded size and the
            LayerNorm parameter shapes stored in existing checkpoints.
    """

    def __init__(self, channels, spatial_size=128):
        super(CGSformerBlock, self).__init__()
        if isinstance(spatial_size, int):
            spatial_size = (spatial_size, spatial_size)
        height, width = spatial_size
        self.cfs = CFS(channels)
        self.asa = ASA(channels)
        self.bgff = BGFF(channels)
        self.norm1 = nn.LayerNorm([channels, height, width])
        self.norm2 = nn.LayerNorm([channels, height, width])

    def forward(self, x):
        x_cfs = self.cfs(x)
        x = self.asa(self.norm1(x_cfs)) + x
        x = self.bgff(self.norm2(x)) + x
        return x

class SparseTransformer(nn.Module):
    """Low-light enhancement network: conv encoder, seven CGSformer blocks, conv decoder.

    (B, 3, H, W) --3x3 conv--> (B, channels, H, W) --block1..block7--> --3x3 conv--> (B, 3, H, W)

    H and W must equal the spatial size CGSformerBlock's LayerNorms are built
    for (128 x 128 with its default arguments). There is no global skip
    connection, and the output is not clamped.

    Args:
        channels: Width of the feature maps inside the network.
    """

    def __init__(self, channels=64):
        super().__init__()
        self.encoder = nn.Conv2d(3, channels, 3, padding=1)
        # Registered as attributes block1..block7 (in that order) because saved
        # checkpoints use these names as state_dict prefixes.
        for stage in range(1, 8):
            setattr(self, f"block{stage}", CGSformerBlock(channels))
        self.decoder = nn.Conv2d(channels, 3, 3, padding=1)

    def forward(self, x):
        feats = self.encoder(x)
        for stage in range(1, 8):
            feats = getattr(self, f"block{stage}")(feats)
        return self.decoder(feats)


In [None]:
import torch
import torch.nn.functional as F

def ssim(img1, img2, window_size=11):
    """Mean SSIM between two image batches, using a uniform (box) window.

    Local means, variances and covariance are box-filtered per channel
    (depth-wise conv2d, zero padding of ``window_size // 2``). They are combined
    with the SSIM formula and averaged over every pixel, channel and batch element.

    Note: standard SSIM (Wang et al., 2004) uses an 11x11 Gaussian window
    (sigma 1.5); this box-window variant gives somewhat different values.

    Args:
        img1, img2: Tensors of shape (B, C, H, W) with the same dtype and device.
            The constants C1 = 0.01**2 and C2 = 0.03**2 assume a data range of 1
            (images in [0, 1]).
        window_size: Side length of the square averaging window.

    Returns:
        0-dim tensor holding the mean SSIM. It is differentiable, so
        ``1 - ssim(a, b)`` can serve as a loss.
    """
    channel = img1.shape[1]
    pad = window_size // 2
    # Build the window in the input's dtype as well as on its device: conv2d rejects
    # mismatched dtypes, so a fixed float32 window broke float64/float16 inputs.
    window = torch.ones(
        (channel, 1, window_size, window_size), dtype=img1.dtype, device=img1.device
    ) / (window_size ** 2)

    def box_filter(t):
        """Per-channel local mean over the window."""
        return F.conv2d(t, window, padding=pad, groups=channel)

    mu1 = box_filter(img1)
    mu2 = box_filter(img2)

    sigma1_sq = box_filter(img1 ** 2) - mu1 ** 2
    sigma2_sq = box_filter(img2 ** 2) - mu2 ** 2
    sigma12 = box_filter(img1 * img2) - mu1 * mu2

    C1, C2 = 0.01 ** 2, 0.03 ** 2  # Stability constants (data range 1)
    ssim_map = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2))

    return ssim_map.mean()



# **Testing: tiled inference on a folder of low-light images**

In [None]:
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
import os

def pad_to_multiple(img, patch_size=128):
    """Pad a (C, H, W) image on the bottom/right so H and W become multiples of patch_size.

    Reflection padding is used where possible. PyTorch's reflect mode requires
    the pad to be smaller than the padded dimension, so an axis whose pad is at
    least its size (images of at most patch_size / 2 pixels on that axis) falls
    back to edge replication instead of raising.

    Args:
        img: Tensor of shape (C, H, W).
        patch_size: Target multiple for H and W.

    Returns:
        Tensor of shape (C, ceil(H / p) * p, ceil(W / p) * p), with the original
        image in the top-left corner.
    """
    _, h, w = img.shape
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size

    # Work on a 4-D view: every F.pad mode accepts (N, C, H, W).
    # Padding W then H separately equals a single 2-D reflect pad when both axes
    # can reflect, and lets each axis choose its own mode otherwise.
    batched = img.unsqueeze(0)
    if pad_w:
        batched = F.pad(batched, (0, pad_w, 0, 0), mode='reflect' if pad_w < w else 'replicate')
    if pad_h:
        batched = F.pad(batched, (0, 0, 0, pad_h), mode='reflect' if pad_h < h else 'replicate')
    return batched.squeeze(0)

def split_patches(img, patch_size=128):
    """Tile a (C, H, W) tensor into non-overlapping patch_size x patch_size slices.

    Tiles are produced in row-major order. When H or W is not a multiple of
    patch_size, the last row/column of tiles is smaller (pad first with
    pad_to_multiple to get uniform tiles).

    Returns:
        (patches, coords): parallel lists of tensor slices (views into ``img``)
        and their top-left (row, col) offsets.
    """
    _, height, width = img.shape
    coords = [
        (top, left)
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
    patches = [img[:, top:top + patch_size, left:left + patch_size] for top, left in coords]
    return patches, coords

def merge_patches(patches, coords, image_shape, patch_size=128):
    """Stitch (C, h, w) tiles back into a (C, H, W) float32 canvas.

    Each tile is placed at its top-left (row, col) offset. Pixels covered by
    several tiles are averaged, and pixels covered by none stay 0. The tile
    extent is taken from each tile's own shape, so ``patch_size`` is unused and
    kept only for signature compatibility.

    Args:
        patches: Non-empty list of (C, h, w) tensors on a common device.
        coords: Matching list of (row, col) offsets.
        image_shape: (C, H, W) of the output canvas.
        patch_size: Unused.
    """
    channels, height, width = image_shape
    canvas = torch.zeros((channels, height, width), device=patches[0].device)
    hits = torch.zeros_like(canvas)

    for tile, (top, left) in zip(patches, coords):
        rows = slice(top, top + tile.shape[1])
        cols = slice(left, left + tile.shape[2])
        canvas[:, rows, cols] += tile
        hits[:, rows, cols] += 1

    # Avoid 0/0 where no tile landed; those pixels remain 0.
    hits[hits == 0] = 1
    return canvas / hits

def enhance_image(model, img_path, save_path, device, patch_size=128):
    """Enhance one image file with ``model`` using non-overlapping tiles, then save it.

    Pipeline:
      1. Pad the RGB image to a multiple of ``patch_size`` (pad_to_multiple).
      2. Cut it into ``patch_size`` x ``patch_size`` tiles (split_patches).
      3. Enhance each tile independently under ``torch.no_grad()``.
      4. Stitch the tiles back (merge_patches) and crop to the original size.
      5. Clamp to [0, 1] and write to ``save_path``, creating the parent
         directory if it is missing.

    Args:
        model: Module expected to map (1, 3, patch_size, patch_size) to the same
            shape. It is switched to eval mode as a side effect.
        img_path: Path of the input image (converted to RGB).
        save_path: Output file path; PIL picks the format from its extension.
        device: Device for the image tensor (should match the model's device).
        patch_size: Tile side length.
    """
    model.eval()

    # Load image; `with` closes the file handle deterministically.
    with Image.open(img_path) as img:
        img_tensor = TF.to_tensor(img.convert('RGB')).to(device)

    _, h, w = img_tensor.shape

    # Always tile, even when a side is below patch_size. Feeding the whole padded
    # image (e.g. 100x300 -> 128x384) would not match the fixed 128x128 LayerNorm
    # shape of this notebook's SparseTransformer. An image below patch_size in both
    # dimensions pads to a single tile, so it still gets one forward pass.
    padded_img = pad_to_multiple(img_tensor, patch_size)
    patches, coords = split_patches(padded_img, patch_size)

    with torch.no_grad():
        enhanced_patches = [model(patch.unsqueeze(0)).squeeze(0) for patch in patches]

    merged = merge_patches(enhanced_patches, coords, tuple(padded_img.shape), patch_size)
    output = merged[:, :h, :w]  # Remove padding to original size

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    output_img = TF.to_pil_image(torch.clamp(output, 0, 1).cpu())
    output_img.save(save_path)
    print(f"Saved enhanced image at {save_path}")


In [None]:
# Inference: enhance every image in test_images_dir with the epoch-100 checkpoint.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

test_images_dir = "./drive/MyDrive/Major/ImageRestoration/low_input"
save_dir = "./drive/MyDrive/Major/ImageRestoration/m/Sparse_transform"
out_dir = "./drive/MyDrive/Major/ImageRestoration/output_low"

model = SparseTransformer().to(device)
model.load_state_dict(torch.load(os.path.join(save_dir, "model_epoch_100.pth"), map_location=device))

os.makedirs(out_dir, exist_ok=True)
# Sorted for a deterministic processing order.
for img_name in sorted(os.listdir(test_images_dir)):
    img_path = os.path.join(test_images_dir, img_name)
    # Skip hidden entries and sub-directories (e.g. `.ipynb_checkpoints`) that PIL cannot open.
    if img_name.startswith('.') or not os.path.isfile(img_path):
        continue
    save_path = os.path.join(out_dir, img_name)
    enhance_image(model, img_path, save_path, device)

Saved enhanced image at ./drive/MyDrive/Major/ImageRestoration/output_low/669.png
Saved enhanced image at ./drive/MyDrive/Major/ImageRestoration/output_low/1.png
Saved enhanced image at ./drive/MyDrive/Major/ImageRestoration/output_low/778.png
Saved enhanced image at ./drive/MyDrive/Major/ImageRestoration/output_low/179.png


# **Training**

In [None]:
import os
import torch
import torch.optim as optim
import pandas as pd
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pytorch_ssim
from collections import OrderedDict
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): `transform` (resize to 128x128) is not used below. The dataset
# receives transforms.ToTensor() and random-crops patch_size patches instead.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

num_epochs = 150
start_epoch = 100  # resume from model_epoch_{start_epoch}.pth; 0 trains from scratch
batch_size = 1
patch_size = 128
initial_lr = 3e-4

save_dir = "./drive/MyDrive/Major/ImageRestoration/m/Sparse_transform"
log_dir = "./drive/MyDrive/Major/ImageRestoration/training_log"
os.makedirs(save_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

train_dataset = LOLDataset(root_dir="./drive/MyDrive/Major/References/lol_dataset/our485", transform=transforms.ToTensor(), patch_size=patch_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Initialize model, optionally resuming from a saved state_dict
model = SparseTransformer().to(device)
if start_epoch != 0:
    checkpoint_path = os.path.join(save_dir, f"model_epoch_{start_epoch}.pth")
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Checkpoints saved from an nn.DataParallel wrapper prefix every key with
    # "module."; strip it only where present so any other load error still surfaces.
    prefix = "module."
    if any(key.startswith(prefix) for key in state_dict):
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
    model.load_state_dict(state_dict)

criterion_mse = torch.nn.MSELoss()
# NOTE(review): optimizer state is not checkpointed, so AdamW moments restart on resume.
optimizer = optim.AdamW(model.parameters(), lr=initial_lr)

# Training loop
log_results = []
log_path = os.path.join(log_dir, f"training_log_{start_epoch}.xlsx")

def compute_psnr(target, output):
    """PSNR in dB assuming a peak signal value of 1.0 (images in [0, 1])."""
    mse = F.mse_loss(target, output)
    psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
    return psnr.item()

for epoch in range(start_epoch, num_epochs):
    model.train()
    epoch_loss = 0
    epoch_psnr = 0
    epoch_ssim = 0

    for i_img, o_img in train_loader:
        i_img = i_img.to(device)
        o_img = o_img.to(device)
        optimizer.zero_grad()
        enhanced_img = model(i_img)

        loss_mse = criterion_mse(enhanced_img, o_img)
        ssim_val = ssim(enhanced_img, o_img)  # similarity (1 = identical); the loss uses 1 - ssim
        loss = 0.7 * loss_mse + 0.3 * (1 - ssim_val)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        # Metric only: keep it out of the autograd graph.
        with torch.no_grad():
            epoch_psnr += compute_psnr(o_img, enhanced_img)
        epoch_ssim += ssim_val.item()

    avg_loss = epoch_loss / len(train_loader)
    avg_psnr = epoch_psnr / len(train_loader)
    avg_ssim = epoch_ssim / len(train_loader)

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim:.4f}")

    # Save model every epoch
    model_path = os.path.join(save_dir, f"model_epoch_{epoch+1}.pth")
    torch.save(model.state_dict(), model_path)

    # Logging to Excel (rewritten each epoch so partial runs keep their log)
    log_results.append({"Epoch": epoch+1, "Loss": avg_loss, "PSNR": avg_psnr, "SSIM": avg_ssim})

    df = pd.DataFrame(log_results)
    df.to_excel(log_path, index=False)


Epoch 101/150, Loss: 0.0657, PSNR: 19.6254, SSIM: 0.8280
Epoch 102/150, Loss: 0.0670, PSNR: 19.3574, SSIM: 0.8244
Epoch 103/150, Loss: 0.0651, PSNR: 19.5050, SSIM: 0.8291
Epoch 104/150, Loss: 0.0640, PSNR: 19.6134, SSIM: 0.8310
Epoch 105/150, Loss: 0.0661, PSNR: 19.6908, SSIM: 0.8240
Epoch 106/150, Loss: 0.0685, PSNR: 19.3740, SSIM: 0.8196
Epoch 107/150, Loss: 0.0705, PSNR: 19.1405, SSIM: 0.8130
Epoch 108/150, Loss: 0.0641, PSNR: 19.8280, SSIM: 0.8281
Epoch 109/150, Loss: 0.0659, PSNR: 19.5201, SSIM: 0.8263
Epoch 110/150, Loss: 0.0630, PSNR: 19.6569, SSIM: 0.8350
Epoch 111/150, Loss: 0.0660, PSNR: 19.6793, SSIM: 0.8240
Epoch 112/150, Loss: 0.0696, PSNR: 19.2604, SSIM: 0.8166
Epoch 113/150, Loss: 0.0659, PSNR: 19.5152, SSIM: 0.8265
Epoch 114/150, Loss: 0.0639, PSNR: 19.6099, SSIM: 0.8290
Epoch 115/150, Loss: 0.0634, PSNR: 19.7042, SSIM: 0.8319
Epoch 116/150, Loss: 0.0677, PSNR: 19.3560, SSIM: 0.8215
Epoch 117/150, Loss: 0.0646, PSNR: 19.6451, SSIM: 0.8292
Epoch 118/150, Loss: 0.0615, PS