In [1]:
!pip install pandas openpyxl scikit-image pytorch-ssim

Collecting pytorch-ssim
  Downloading pytorch_ssim-0.1.tar.gz (1.4 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytorch-ssim
  Building wheel for pytorch-ssim (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-ssim: filename=pytorch_ssim-0.1-py3-none-any.whl size=2006 sha256=79ad9781ab7ff6ecffa5b8026904d3becdccaad60c4b9caa08881e8d8bbbf95e
  Stored in directory: /root/.cache/pip/wheels/58/68/a2/68a41e8268a076c128bbc3988d243187fa4681828e648bf1ca
Successfully built pytorch-ssim
Installing collected packages: pytorch-ssim
Successfully installed pytorch-ssim-0.1


In [2]:
import numpy as np
import pandas as pd

import os
# Print the path of every file found under the Kaggle input directory.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        print(os.path.join(root, name))

/kaggle/input/deblur/pytorch/default/1/Generator.pth
/kaggle/input/low-light/pytorch/default/1/model_epoch_100.pth


In [3]:

# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# NOTE(review): this repeats the directory listing from the previous cell
# (Kaggle's default starter code) and prints the same paths again.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/deblur/pytorch/default/1/Generator.pth
/kaggle/input/low-light/pytorch/default/1/model_epoch_100.pth


In [4]:
import pytorch_ssim
import torch.nn as nn
class CGSformerLoss(nn.Module):
    """Weighted sum of a pixel MSE term and a structural (1 - SSIM) term.

    loss = alpha * MSE(output, target) + beta * (1 - SSIM(output, target))

    Args:
        alpha: weight of the MSE term.
        beta: weight of the SSIM term.
    """

    def __init__(self, alpha=0.7, beta=0.3):
        super().__init__()
        self.mse = nn.MSELoss()
        self.ssim = pytorch_ssim.SSIM(window_size=11)
        self.alpha = alpha
        self.beta = beta

    def forward(self, output, target):
        pixel_term = self.mse(output, target)
        structure_term = 1 - self.ssim(output, target)
        return pixel_term * self.alpha + structure_term * self.beta


In [5]:
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os
import random

class LOLDataset(Dataset):
    """Paired low-light / normal-light dataset that yields aligned random patches.

    Expects ``root_dir/low`` and ``root_dir/high``. Files are paired by their
    position in sorted order, so both folders should contain matching names.

    Each item is a random ``patch_size`` x ``patch_size`` crop taken at the same
    location in both images. The same random horizontal and vertical flips are
    applied to both. ``transform`` (if given) is then applied to each image.
    """

    def __init__(self, root_dir, transform=None, patch_size=128):
        self.low_light_dir = os.path.join(root_dir, 'low')
        self.high_light_dir = os.path.join(root_dir, 'high')
        self.low_light_images = sorted(os.listdir(self.low_light_dir))
        self.high_light_images = sorted(os.listdir(self.high_light_dir))
        self.transform = transform
        self.patch_size = patch_size

    def __len__(self):
        return len(self.low_light_images)

    def __getitem__(self, idx):
        low_img = Image.open(os.path.join(self.low_light_dir, self.low_light_images[idx])).convert('RGB')
        high_img = Image.open(os.path.join(self.high_light_dir, self.high_light_images[idx])).convert('RGB')

        # One crop window for both images keeps the pair pixel-aligned.
        size = (self.patch_size, self.patch_size)
        top, left, height, width = transforms.RandomCrop.get_params(low_img, output_size=size)
        pair = [transforms.functional.crop(im, top, left, height, width) for im in (low_img, high_img)]

        # Each flip decision is drawn once and applied to both images.
        for flip in (transforms.functional.hflip, transforms.functional.vflip):
            if random.random() > 0.5:
                pair = [flip(im) for im in pair]

        if self.transform:
            pair = [self.transform(im) for im in pair]

        low_img, high_img = pair
        return low_img, high_img

# Example transforms: PIL image -> float tensor (uint8 inputs are scaled to [0, 1])
train_transforms = transforms.Compose([transforms.ToTensor()])

Sparse Transformer (low-light enhancement model)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CFS(nn.Module): #Cross Feature Scrambling
    """Splits normalised features into "informative" and "non-informative" channels and re-weights them.

    Each channel is scored by its spatial variance after GroupNorm. Channels whose
    sigmoid-squashed score exceeds ``threshold`` are treated as informative. The two
    groups are blended with a per-channel gate ``beta``, computed as the sigmoid of
    the spatial mean of the normalised input.

    Args:
        channels: number of input channels (C).
        threshold: cut-off applied to the sigmoid of each channel's importance.
    """
    def __init__(self, channels, threshold=0.5):
        super(CFS, self).__init__()
        self.threshold = threshold
        self.sigmoid = nn.Sigmoid()
        # Single-group GroupNorm: normalises each sample over all of its channels and positions.
        self.gn = nn.GroupNorm(1, channels)

    def forward(self, x):
        # assumes x is (B, C, H, W): dims 2 and 3 are reduced as spatial dims below
        x_ln = self.gn(x)
        # Per-channel spatial variance, normalised so the channel scores sum to ~1.
        var = torch.var(x_ln, dim=[2,3], keepdim=True)
        importance = var / (torch.sum(var, dim=1, keepdim=True) + 1e-6)
        # NOTE(review): importance >= 0, so sigmoid(importance) >= 0.5. With the default
        # threshold of 0.5, essentially every channel with non-zero variance ends up in
        # mask_info. Confirm this is intended before changing it (trained weights depend on it).
        importance = self.sigmoid(importance)
        mask_info = (importance > self.threshold).float()
        mask_noninfo = (importance <= self.threshold).float()

        x_info = mask_info * x_ln
        x_noninfo = mask_noninfo * x_ln

        # The masks are complementary (for finite scores), so this sum equals x_ln.
        pooled = F.adaptive_avg_pool2d(x_info + x_noninfo, (1, 1))
        beta = self.sigmoid(pooled)

        # Informative channels are scaled by beta, the rest by (1 - beta).
        out = beta * x_info + (1 - beta) * x_noninfo
        return out

class ASA(nn.Module): #Adaptive Shift Attention
    """Sparse (top-k) self-attention over all spatial positions.

    Each position attends only to the keys whose scores are among its top
    ``topk_ratio`` fraction. All other scores are set to -inf before the softmax.

    Args:
        channels: number of input/output channels (C); also the q/k/v width.
        topk_ratio: fraction of keys kept per query. At least one key and at most
            all keys are kept, whatever the ratio.

    Note: the attention matrix is (B, H*W, H*W), so memory grows quadratically
    with the number of pixels.
    """
    def __init__(self, channels, topk_ratio=0.5):
        super(ASA, self).__init__()
        self.topk_ratio = topk_ratio
        self.query_conv = nn.Conv2d(channels, channels, 1)
        self.key_conv = nn.Conv2d(channels, channels, 1)
        self.value_conv = nn.Conv2d(channels, channels, 1)
        self.scale = channels ** -0.5

    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, C) token sequences
        q = self.query_conv(x).flatten(2).transpose(1, 2)
        k = self.key_conv(x).flatten(2).transpose(1, 2)
        v = self.value_conv(x).flatten(2).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        num_keys = attn.size(-1)
        # Clamp k to [1, N]: int(N * ratio) is 0 for tiny maps (e.g. 1x1), which
        # made topk empty and the threshold lookup raise IndexError; ratio > 1
        # made torch.topk itself raise.
        topk = min(num_keys, max(1, int(num_keys * self.topk_ratio)))
        topk_values, _ = torch.topk(attn, k=topk, dim=-1)
        # Per-query k-th largest score; scores tied with it are also kept.
        threshold = topk_values[:, :, -1].unsqueeze(-1)
        mask = attn >= threshold
        attn = attn.masked_fill(~mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)

        out = attn @ v
        # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W)
        out = out.transpose(1, 2).reshape(x.size())
        return out

class BGFF(nn.Module): #BIlateral Grid Feature Fusion
    """Gated depthwise feed-forward block with a residual connection.

    A 1x1 projection feeds two depthwise branches (3x3 and 7x7 kernels), each
    passed through swish. The branches are multiplied element-wise, projected
    back with a second 1x1 conv, and added to the input.

    Args:
        channels: number of input/output channels.
    """
    def __init__(self, channels):
        super(BGFF, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, 1)
        self.conv_dw3x3 = nn.Conv2d(channels, channels, 3, padding=1, groups=channels)
        self.conv_dw7x7 = nn.Conv2d(channels, channels, 7, padding=3, groups=channels)
        self.conv2 = nn.Conv2d(channels, channels, 1)
        self.swish = lambda t: t * torch.sigmoid(t)

    def forward(self, x):
        projected = self.conv1(x)
        gated = self.swish(self.conv_dw3x3(projected)) * self.swish(self.conv_dw7x7(projected))
        return self.conv2(gated) + x

class CGSformerBlock(nn.Module):
    """One CGSformer stage: CFS feature selection -> sparse attention -> gated FFN.

    Args:
        channels: number of feature channels (C).
        patch_size: spatial size (H == W) the block is built for. Both LayerNorms
            normalise over (C, patch_size, patch_size), so inputs must have exactly
            this spatial size. Defaults to 128, the value previously hard-coded.
    """
    def __init__(self, channels, patch_size=128):
        super(CGSformerBlock, self).__init__()
        self.cfs = CFS(channels)
        self.asa = ASA(channels)
        self.bgff = BGFF(channels)
        self.norm1 = nn.LayerNorm([channels, patch_size, patch_size])
        self.norm2 = nn.LayerNorm([channels, patch_size, patch_size])

    def forward(self, x):
        # assumes x is (B, C, patch_size, patch_size)
        x_cfs = self.cfs(x)
        # The attention residual adds the block input x (not x_cfs).
        x = self.asa(self.norm1(x_cfs)) + x
        # BGFF already adds its own input internally; this adds a second skip.
        x = self.bgff(self.norm2(x)) + x
        return x

class SparseTransformer(nn.Module):
    """Low-light enhancement network: 3x3 conv stem, seven CGSformer blocks, 3x3 conv head.

    Args:
        channels: width of the feature maps between the stem and the head.

    Inputs are expected to be (B, 3, 128, 128), because the CGSformerBlock
    LayerNorms are built for a 128x128 spatial size.
    """
    def __init__(self, channels=64):
        super(SparseTransformer, self).__init__()
        self.encoder = nn.Conv2d(3, channels, 3, padding=1)
        # Registered as block1..block7 so state_dict keys match saved checkpoints.
        for idx in range(1, 8):
            setattr(self, f"block{idx}", CGSformerBlock(channels))
        self.decoder = nn.Conv2d(channels, 3, 3, padding=1)

    def forward(self, x):
        features = self.encoder(x)
        for idx in range(1, 8):
            features = getattr(self, f"block{idx}")(features)
        return self.decoder(features)

In [7]:
import torch
import torch.nn.functional as F

def ssim(img1, img2, window_size=11):
    """Mean SSIM between two image batches, using a uniform (box) window.

    Args:
        img1, img2: tensors, presumably (B, C, H, W) on the same device.
        window_size: side length of the averaging window (odd sizes keep H and W).

    Returns:
        Scalar tensor: the mean of the SSIM map. Borders are zero-padded. The
        constants C1 = 0.01**2 and C2 = 0.03**2 correspond to a data range of 1.
    """
    channels = img1.shape[1]
    pad = window_size // 2
    kernel = torch.ones((channels, 1, window_size, window_size)).to(img1.device) / (window_size ** 2)

    def local_mean(t):
        # Depthwise box filter: one averaging kernel per channel.
        return F.conv2d(t, kernel, padding=pad, groups=channels)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)
    var1 = local_mean(img1 ** 2) - mu1 ** 2
    var2 = local_mean(img2 ** 2) - mu2 ** 2
    cov12 = local_mean(img1 * img2) - mu1 * mu2

    C1, C2 = 0.01**2, 0.03**2  # Stability constants
    numerator = (2 * mu1 * mu2 + C1) * (2 * cov12 + C2)
    denominator = (mu1 ** 2 + mu2 ** 2 + C1) * (var1 + var2 + C2)
    return (numerator / denominator).mean()

DGUNet (deblurring model)

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pdb import set_trace as stx

def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Conv2d with ``kernel_size // 2`` padding (size-preserving for odd kernels at stride 1); no bias by default."""
    same_pad = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=same_pad, bias=bias, stride=stride)


def conv_down(in_chn, out_chn, bias=False):
    """4x4, stride-2, padding-1 convolution: halves the height and width of even-sized inputs."""
    return nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)


def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
    """Conv2d with ``kernel_size // 2`` padding and bias enabled by default (the conv factory passed to ResBlock)."""
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=kernel_size // 2, bias=bias)


class ResBlock(nn.Module):
    """Residual block: conv(n_feats->mid) -> act -> conv(mid->n_feats), scaled, plus the input.

    Args:
        conv: factory ``conv(in_ch, out_ch, kernel_size, bias=...)`` returning a module.
        n_feats: channels of the input and output.
        kernel_size: kernel size passed to ``conv``.
        bias: passed to ``conv``.
        bn: insert BatchNorm after each conv.
        act: activation after the first conv. ``None`` (default) creates a fresh
            ``nn.PReLU()`` for this block.
        res_scale: multiplier applied to the residual branch before the skip add.
        mid_feats: width of the hidden layer (64, as before).
    """
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=None, res_scale=1, mid_feats=64):

        super(ResBlock, self).__init__()
        # The old default act=nn.PReLU() was evaluated once at definition time, so
        # every ResBlock built without an explicit act shared ONE learnable PReLU.
        # NOTE(review): a checkpoint saved from the shared version stores the same
        # value under each block's key, so loading it here should reproduce that
        # behaviour at inference; training now learns one slope per block.
        if act is None:
            act = nn.PReLU()
        m = [conv(n_feats, mid_feats, kernel_size, bias=bias)]
        if bn:
            # Normalise the hidden width; BatchNorm2d(n_feats) here failed whenever n_feats != mid_feats.
            m.append(nn.BatchNorm2d(mid_feats))
        m.append(act)
        m.append(conv(mid_feats, n_feats, kernel_size, bias=bias))
        if bn:
            m.append(nn.BatchNorm2d(n_feats))

        # Same Sequential indices as before (0 conv, 1 act, 2 conv when bn=False), so state_dict keys match.
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class CAB(nn.Module):
    """Channel Attention Block: conv -> act -> conv -> channel attention, plus an identity skip.

    Args:
        n_feat: channels in and out.
        kernel_size: kernel size of both convs.
        reduction: channel reduction factor of the attention bottleneck.
        bias: whether the convs use bias.
        act: activation module placed between the two convs.
    """
    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(CAB, self).__init__()
        # Same construction order as before so parameter initialisation is unchanged.
        first = conv(n_feat, n_feat, kernel_size, bias=bias)
        second = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.CA = CALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(first, act, second)

    def forward(self, x):
        attended = self.CA(self.body(x))
        attended += x
        return attended


class CALayer(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Global average pooling gives one value per channel. A 1x1 bottleneck
    (channel -> channel // reduction -> channel) with ReLU and Sigmoid turns
    these values into per-channel weights in (0, 1) that rescale the input.
    """
    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        hidden = channel // reduction
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        weights = self.conv_du(self.avg_pool(x))
        return x * weights


class SAM(nn.Module):
    """Stage output head.

    Returns ``(conv1(x) + x, conv2(x) + x_img)``: refined features and a
    3-channel image estimate that is residual on ``x_img``.
    """
    def __init__(self, n_feat, kernel_size, bias):
        super(SAM, self).__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)

    def forward(self, x, x_img):
        img = self.conv2(x) + x_img
        feats = self.conv1(x) + x
        return feats, img


class mergeblock(nn.Module):
    """Fuses ``x`` with a subspace-projected version of ``bridge``.

    A conv predicts ``subspace_dim`` basis maps from concat([x, bridge]). Each
    channel of ``bridge`` is projected onto the span of those basis vectors
    (least-squares projection over spatial positions). The result is
    concatenated with ``x``, fused by a conv, and added to ``x``.

    Args:
        n_feat: channels of ``x`` and ``bridge`` (assumed equal, along with their spatial sizes).
        kernel_size: kernel size of both convs.
        bias: conv bias flag.
        subspace_dim: number of basis vectors K.
    """
    def __init__(self, n_feat, kernel_size, bias, subspace_dim=16):
        super(mergeblock, self).__init__()
        self.conv_block = conv(n_feat * 2, n_feat, kernel_size, bias=bias)
        self.num_subspace = subspace_dim
        self.subnet = conv(n_feat * 2, self.num_subspace, kernel_size, bias=bias)

    def forward(self, x, bridge):
        out = torch.cat([x, bridge], 1)
        b_, c_, h_, w_ = bridge.shape
        sub = self.subnet(out)
        # V_t: (B, K, N) with N = H*W; each basis vector is L1-normalised over N.
        V_t = sub.view(b_, self.num_subspace, h_*w_)
        V_t = V_t / (1e-6 + torch.abs(V_t).sum(axis=2, keepdims=True))
        V = V_t.permute(0, 2, 1)
        # Projection V (V^T V)^-1 V^T, applied to each channel of bridge.
        mat = torch.matmul(V_t, V)
        # NOTE(review): inverting the K x K Gram matrix may fail or be unstable if it is singular.
        mat_inv = torch.inverse(mat)
        project_mat = torch.matmul(mat_inv, V_t)
        bridge_ = bridge.view(b_, c_, h_*w_)
        project_feature = torch.matmul(project_mat, bridge_.permute(0, 2, 1))
        bridge = torch.matmul(V, project_feature).permute(0, 2, 1).view(b_, c_, h_, w_)
        out = torch.cat([x, bridge], 1)
        out = self.conv_block(out)
        return out+x

class Encoder(nn.Module):
    """U-Net encoder: ``depth - 1`` downsampling blocks followed by one bottleneck block.

    Down block ``i`` maps ``n_feat + scale_unetfeats * i`` channels to
    ``n_feat + scale_unetfeats * (i + 1)``. The bottleneck keeps the last width.
    ``kernel_size``, ``reduction``, ``act`` and ``bias`` are accepted but unused.

    forward(x, encoder_outs=None, decoder_outs=None) returns ``(res, x)``:
    the full-resolution (pre-downsampling) output of each down block, shallowest
    first, and the bottleneck output. When both previous-stage lists are given,
    down block ``i`` also fuses ``encoder_outs[i]`` and ``decoder_outs[-i-1]``
    (csff, which must be enabled).
    """
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff,depth=5):
        super(Encoder, self).__init__()
        self.body=nn.ModuleList()#[]
        self.depth=depth
        for i in range(depth-1):
            self.body.append(UNetConvBlock(in_size=n_feat+scale_unetfeats*i, out_size=n_feat+scale_unetfeats*(i+1), downsample=True, relu_slope=0.2, use_csff=csff, use_HIN=True))
        # Bottleneck: no downsampling, same width in and out.
        self.body.append(UNetConvBlock(in_size=n_feat+scale_unetfeats*(depth-1), out_size=n_feat+scale_unetfeats*(depth-1), downsample=False, relu_slope=0.2, use_csff=csff, use_HIN=True))

    def forward(self, x, encoder_outs=None, decoder_outs=None):
        res=[]
        if encoder_outs is not None and decoder_outs is not None:
            for i,down in enumerate(self.body):
                if (i+1) < self.depth:
                    # Decoder outputs are ordered deepest-first, hence the reversed index.
                    x, x_up = down(x,encoder_outs[i],decoder_outs[-i-1])
                    res.append(x_up)
                else:
                    x = down(x)
        else:
            for i,down in enumerate(self.body):
                if (i+1) < self.depth:
                    x, x_up = down(x)
                    res.append(x_up)
                else:
                    x = down(x)
        return res,x


class UNetConvBlock(nn.Module):
    """Two 3x3 conv + LeakyReLU layers with a 1x1 shortcut and optional extras.

    The optional extras are half-instance norm, cross-stage feature fusion
    (csff) and 2x downsampling.

    Args:
        in_size, out_size: input/output channels.
        downsample: if True, forward returns ``(downsampled, full_res)``; otherwise just ``full_res``.
        relu_slope: negative slope of both LeakyReLUs.
        use_csff: build the fusion convs (only when ``downsample`` is also True).
        use_HIN: instance-normalise the first half of the channels after conv_1.
    """
    def __init__(self, in_size, out_size, downsample, relu_slope, use_csff=False, use_HIN=False):
        super(UNetConvBlock, self).__init__()
        # Holds the bool for now; replaced by the down-conv module below when True.
        self.downsample = downsample
        self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)
        self.use_csff = use_csff

        self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_1 = nn.LeakyReLU(relu_slope, inplace=False)
        self.conv_2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_2 = nn.LeakyReLU(relu_slope, inplace=False)

        if downsample and use_csff:
            # csff_enc expects out_size-channel encoder features; csff_dec expects in_size-channel decoder features.
            self.csff_enc = nn.Conv2d(out_size, out_size, 3, 1, 1)
            self.csff_dec = nn.Conv2d(in_size, out_size, 3, 1, 1)
            self.phi = nn.Conv2d(out_size, out_size, 3, 1, 1)
            self.gamma = nn.Conv2d(out_size, out_size, 3, 1, 1)

        if use_HIN:
            # NOTE(review): assumes out_size is even, so chunk(2) yields out_size//2 channels for this norm.
            self.norm = nn.InstanceNorm2d(out_size//2, affine=True)
        self.use_HIN = use_HIN

        if downsample:
            self.downsample = conv_down(out_size, out_size, bias=False)

    def forward(self, x, enc=None, dec=None):
        out = self.conv_1(x)

        if self.use_HIN:
            # Half-instance norm: normalise the first half of the channels, pass the rest through.
            out_1, out_2 = torch.chunk(out, 2, dim=1)
            out = torch.cat([self.norm(out_1), out_2], dim=1)
        out = self.relu_1(out)
        out = self.relu_2(self.conv_2(out))

        out += self.identity(x)
        if enc is not None and dec is not None:
            assert self.use_csff
            # Fuse previous-stage encoder/decoder features: sigmoid gate plus shift, plus the residual out.
            skip_ = F.leaky_relu(self.csff_enc(enc) + self.csff_dec(dec), 0.1, inplace=True)
            out = out*F.sigmoid(self.phi(skip_)) + self.gamma(skip_) + out
        if self.downsample:
            out_down = self.downsample(out)
            return out_down, out
        else:
            return out


class UNetUpBlock(nn.Module):
    """Decoder step: 2x transposed-conv upsampling, concatenation with the skip, then a UNetConvBlock."""
    def __init__(self, in_size, out_size, relu_slope):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)
        self.conv_block = UNetConvBlock(out_size*2, out_size, False, relu_slope)

    def forward(self, x, bridge):
        merged = torch.cat([self.up(x), bridge], 1)
        return self.conv_block(merged)


class Decoder(nn.Module):
    """U-Net decoder: ``depth - 1`` up-blocks that consume encoder skips deepest-first.

    Widths mirror the Encoder (``n_feat + scale_unetfeats * level``). Each skip
    passes through a 3x3 conv to the narrower width before being merged.
    ``kernel_size``, ``reduction``, ``act`` and ``bias`` are accepted but unused.

    forward(x, bridges) returns the output of every up-block, coarsest first.
    """
    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats,depth=5):
        super(Decoder, self).__init__()

        self.body = nn.ModuleList()
        self.skip_conv = nn.ModuleList()
        for i in range(depth-1):
            wide = n_feat + scale_unetfeats*(depth-i-1)
            narrow = n_feat + scale_unetfeats*(depth-i-2)
            self.body.append(UNetUpBlock(in_size=wide, out_size=narrow, relu_slope=0.2))
            self.skip_conv.append(nn.Conv2d(wide, narrow, 3, 1, 1))

    def forward(self, x, bridges):
        outputs = []
        for i, (up, skip) in enumerate(zip(self.body, self.skip_conv)):
            x = up(x, skip(bridges[-i - 1]))
            outputs.append(x)
        return outputs


class DownSample(nn.Module):
    """Bilinear 0.5x resize followed by a bias-free 1x1 conv that adds ``s_factor`` channels."""
    def __init__(self, in_channels, s_factor):
        super(DownSample, self).__init__()
        resize = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
        project = nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False)
        self.down = nn.Sequential(resize, project)

    def forward(self, x):
        return self.down(x)


class UpSample(nn.Module):
    """Bilinear 2x resize followed by a bias-free 1x1 conv to ``out_channels``."""
    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        resize = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        project = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
        self.up = nn.Sequential(resize, project)

    def forward(self, x):
        return self.up(x)


class Basic_block(nn.Module):
    """One unfolded stage: a learned gradient-style image update followed by U-Net refinement.

    Generator reuses a single instance of this block for all of its intermediate
    stages. ``in_c``, ``out_c``, ``scale_orsnetfeats`` and ``num_cab`` are
    accepted but not used here.
    """
    def __init__(self, in_c=3, out_c=3, n_feat=80, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias=False):
        super(Basic_block, self).__init__()
        act = nn.PReLU()
        # phi / phit: small ResBlocks on 3-channel images; presumably learned stand-ins
        # for the degradation operator and its transpose (from the names -- confirm).
        self.phi_1 = ResBlock(default_conv,3,3)
        self.phit_1 = ResBlock(default_conv,3,3)
        self.shallow_feat2 = nn.Sequential(conv(3, n_feat, kernel_size, bias=bias),
                                           CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        # csff=True: the encoder fuses the previous stage's encoder/decoder features.
        self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats,depth=4, csff=True)
        self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats,depth=4)
        self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)
        # Learnable step size of the image update.
        self.r1 = nn.Parameter(torch.Tensor([0.5]))
        # NOTE(review): concat12 is never used in forward; presumably kept so checkpoint keys match.
        self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias)

        self.merge12=mergeblock(n_feat,3,True)

    def forward(self, img,stage1_img,feat1,res1,x2_samfeats):
        # img: degraded input; stage1_img: previous stage's image estimate;
        # feat1 / res1: previous stage's encoder skips and decoder outputs;
        # x2_samfeats: previous stage's SAM features.
        # Update: x <- x - r1 * phit(phi(x) - img)
        phixsy_2 = self.phi_1(stage1_img) - img
        x2_img = stage1_img - self.r1*self.phit_1(phixsy_2)
        x2 = self.shallow_feat2(x2_img)
        # Fuse the new shallow features with the previous stage's SAM features.
        x2_cat = self.merge12(x2, x2_samfeats)
        feat2,feat_fin2 = self.stage2_encoder(x2_cat, feat1, res1)
        res2 = self.stage2_decoder(feat_fin2,feat2)
        x3_samfeats, stage2_img = self.sam23(res2[-1], x2_img)
        # Features/image for the next stage, plus this stage's skips for its csff.
        return x3_samfeats, stage2_img, feat2, res2

class Generator(nn.Module):
    """Deblurring network (DGUNet-style unfolding).

    Structure:
      * Stage 1: an image update on ``img`` followed by a U-Net refinement.
      * ``depth`` further stages, all using the same shared ``Basic_block``.
      * A final update whose output is added as a residual to ``img``.

    forward(img) returns the list of all stage images, final output first
    (``res[::-1]``).
    """
    def __init__(self, in_c=3, out_c=3, n_feat=80, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3,
                 reduction=4, bias=False, depth=5):
        super(Generator, self).__init__()

        act = nn.PReLU()
        # Number of times the shared Basic_block is applied.
        self.depth=depth
        self.basic=Basic_block(in_c, out_c, n_feat, scale_unetfeats, scale_orsnetfeats, num_cab, kernel_size, reduction, bias)
        self.shallow_feat1 = nn.Sequential(conv(3, n_feat, kernel_size, bias=bias),
                                           CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat7 = nn.Sequential(conv(3, n_feat, kernel_size, bias=bias),
                                           CAB(n_feat, kernel_size, reduction, bias=bias, act=act))

        # Stage-1 U-Net has no previous stage, so csff is off.
        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats,depth=4, csff=False)
        self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats,depth=4)

        self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)

        # phi/phit pairs and step sizes for the first (0) and last (6) image updates.
        self.phi_0 = ResBlock(default_conv,3,3)
        self.phit_0 = ResBlock(default_conv,3,3)
        self.phi_6 = ResBlock(default_conv,3,3)
        self.phit_6 = ResBlock(default_conv,3,3)
        self.r0 = nn.Parameter(torch.Tensor([0.5]))
        self.r6 = nn.Parameter(torch.Tensor([0.5]))

        self.concat67 = conv(n_feat * 2, n_feat + scale_orsnetfeats, kernel_size, bias=bias)
        self.tail = conv(n_feat + scale_orsnetfeats, 3, kernel_size, bias=bias)

    def forward(self, img):
        res=[]
        # Stage 1: image update starting from img itself, then U-Net refinement.
        phixsy_1 = self.phi_0(img) - img
        x1_img = img - self.r0*self.phit_0(phixsy_1)
        x1 = self.shallow_feat1(x1_img)
        feat1,feat_fin1 = self.stage1_encoder(x1)
        res1 = self.stage1_decoder(feat_fin1,feat1)
        x2_samfeats, stage1_img = self.sam12(res1[-1], x1_img)
        res.append(stage1_img)

        # Intermediate stages: one shared Basic_block applied self.depth times.
        for _ in range(self.depth):
            x2_samfeats, stage1_img, feat1, res1 = self.basic(img,stage1_img,feat1,res1,x2_samfeats)
            res.append(stage1_img)
        # Final stage: one more image update, fused with the last SAM features.
        phixsy_7 = self.phi_6(stage1_img) - img
        x7_img = stage1_img - self.r6*self.phit_6(phixsy_7)
        x7 = self.shallow_feat7(x7_img)
        x7_cat = self.concat67(torch.cat([x7, x2_samfeats], 1))
        stage7_img = self.tail(x7_cat)+ img
        res.append(stage7_img)

        # Final output first.
        return res[::-1]

In [9]:
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
from PIL import Image
import os

def pad_to_multiple(img, patch_size=128):
    """Pad a (C, H, W) tensor on the bottom/right so H and W become multiples of ``patch_size``.

    Reflect padding is used when possible. Reflect padding requires the pad to be
    smaller than the dimension being padded, so when the image is too small
    (e.g. a 50-pixel side padded up to 128), replicate padding is used instead of
    raising an error.

    Args:
        img: tensor shaped (C, H, W).
        patch_size: target multiple for both spatial dimensions.

    Returns:
        The padded tensor (never smaller than the input).
    """
    _, h, w = img.shape
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
    return F.pad(img, (0, pad_w, 0, pad_h), mode=mode)

def split_patche(img, patch_size=128):
    """Split a (C, H, W) tensor into non-overlapping tiles in row-major order.

    Tiles on the bottom/right edges are smaller when H or W is not a multiple of
    ``patch_size``.

    Returns:
        (patches, coords): the tile views and their (top, left) corners.
    """
    _, height, width = img.shape
    coords = [(top, left)
              for top in range(0, height, patch_size)
              for left in range(0, width, patch_size)]
    patches = [img[:, top:top + patch_size, left:left + patch_size] for top, left in coords]
    return patches, coords

def merge_patche(patches, coords, image_shape, patch_size=128):
    """Reassemble tiles into a (C, H, W) image, averaging wherever tiles overlap.

    Args:
        patches: non-empty list of (C, h, w) tensors on the same device.
        coords: matching list of (top, left) corners.
        image_shape: (C, H, W) of the output.
        patch_size: unused; kept for signature compatibility.

    Pixels that no tile covers stay 0.
    """
    c, h, w = image_shape
    device = patches[0].device
    total = torch.zeros((c, h, w)).to(device)
    hits = torch.zeros((c, h, w)).to(device)
    for tile, (top, left) in zip(patches, coords):
        rows = slice(top, top + tile.shape[1])
        cols = slice(left, left + tile.shape[2])
        total[:, rows, cols] += tile
        hits[:, rows, cols] += 1
    hits[hits == 0] = 1
    return total / hits

def enhance_image(model, img_path, save_path, device, patch_size=128):
    """Enhance an image tile-by-tile with ``model`` and save the result.

    The image is padded to a multiple of ``patch_size`` and split into
    ``patch_size`` x ``patch_size`` tiles. Each tile is run through the model,
    and the tiles are stitched back together and cropped to the original size.

    Args:
        model: network mapping (1, 3, P, P) -> (1, 3, P, P). If it returns a
            list/tuple (as Generator does), the first element is used.
        img_path: a PIL image, a path to an image file, or anything
            ``TF.to_tensor`` accepts.
        save_path: where the enhanced image is written.
        device: torch device the model lives on; tiles are moved there.
        patch_size: tile size; must match the spatial size the model accepts.
    """
    model.eval()

    # Load the image. PIL inputs are forced to RGB so RGBA/grayscale uploads do
    # not produce 4- or 1-channel tensors for the 3-channel model.
    source = img_path
    if isinstance(source, (str, os.PathLike)):
        with Image.open(source) as opened:
            source = opened.convert('RGB')
    elif isinstance(source, Image.Image):
        source = source.convert('RGB')
    img_tensor = TF.to_tensor(source).to(device)
    _, h, w = img_tensor.shape

    # Always tile. The former whole-image path for small inputs fed the model
    # non-PxP tensors (e.g. 100x300 -> 128x384), which a fixed-size LayerNorm
    # rejects. A small image now simply becomes a single tile.
    padded_img = pad_to_multiple(img_tensor, patch_size)
    patches, coords = split_patche(padded_img, patch_size)

    enhanced_patches = []
    with torch.no_grad():
        for patch in patches:
            out_patch = model(patch.unsqueeze(0))
            if isinstance(out_patch, (list, tuple)):
                out_patch = out_patch[0]
            enhanced_patches.append(out_patch.squeeze(0))

    merged = merge_patche(enhanced_patches, coords, tuple(padded_img.shape), patch_size)
    output = merged[:, :h, :w]  # Remove padding to original size

    output_img = TF.to_pil_image(torch.clamp(output, 0, 1).cpu())
    output_img.save(save_path)
    print(f"Saved enhanced image at {save_path}")

In [10]:
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image, ImageDraw
import numpy as np
import os
from collections import OrderedDict

# NOTE(review): no model is loaded in this cell. process_image below relies on a
# module-level `model` being assigned elsewhere before it runs.


# Side length of the square tiles that images are split into for inference.
PATCH_SIZE = 128

# PIL image -> float tensor (uint8 images are scaled to [0, 1]), and the reverse.
transform = transforms.Compose([
    transforms.ToTensor()
])
to_pil = transforms.ToPILImage()

# Function to pad image
def pad_image(image_tensor, patch_size):
    """Pad a (C, H, W) tensor on the bottom/right to multiples of ``patch_size``.

    Uses reflect padding, but falls back to replicate padding when the image is
    too small for reflect (reflect needs pad < dimension), e.g. a 50x60 image
    with patch_size=128, which previously raised a RuntimeError.

    Returns:
        (padded, h, w): the padded tensor plus the original height and width.
    """
    _, h, w = image_tensor.shape
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
    padded = F.pad(image_tensor, (0, pad_w, 0, pad_h), mode=mode)
    return padded, h, w

# Function to draw patch grid
def draw_patch_grid(image, patch_size):
    """Draw 1-px red lines on ``image`` (in place) at every multiple of ``patch_size``.

    Vertical lines are drawn first, then horizontal ones. Returns the same PIL image.
    """
    pen = ImageDraw.Draw(image)
    width, height = image.size
    segments = [[(x, 0), (x, height)] for x in range(0, width, patch_size)]
    segments += [[(0, y), (width, y)] for y in range(0, height, patch_size)]
    for seg in segments:
        pen.line(seg, fill='red', width=1)
    return image

# Function to split into patches
def split_into_patches(image_tensor, patch_size):
    """Tile a (C, H, W) tensor in row-major order into (patch, top, left) triples.

    Returns:
        (patches, H, W), where H and W are the size of the tensor that was split.
    """
    _, height, width = image_tensor.shape
    tiles = [
        (image_tensor[:, top:top + patch_size, left:left + patch_size], top, left)
        for top in range(0, height, patch_size)
        for left in range(0, width, patch_size)
    ]
    return tiles, height, width

# Function to merge patches
def merge_patches(patches, full_h, full_w):
    """Paste (patch, top, left) triples into a new zero-filled 3 x full_h x full_w tensor.

    Each patch fills the ``PATCH_SIZE`` x ``PATCH_SIZE`` window at its corner.
    Later patches overwrite earlier ones where windows overlap.
    """
    canvas = torch.zeros(3, full_h, full_w)
    for tile, top, left in patches:
        window = (slice(None), slice(top, top + PATCH_SIZE), slice(left, left + PATCH_SIZE))
        canvas[window] = tile
    return canvas

# Main inference function
def process_image(image, device, model=None):
    """Run a model over ``image`` tile-by-tile and return the result with a patch grid drawn on it.

    Args:
        image: PIL image (any mode; converted to RGB).
        device: torch device the model's weights live on; tiles are moved there.
        model: network to apply. If omitted, the module-level ``model`` variable
            is used (the original behaviour). If the model returns a list/tuple
            (as Generator does), its first element is used.

    Returns:
        PIL image of the original size, with red lines every PATCH_SIZE pixels.

    Raises:
        NameError: if no model is passed and no module-level ``model`` exists.
    """
    if model is None:
        # Previously this silently depended on a global `model` (never assigned at
        # module level in this notebook). Keep the fallback, but fail with a clear message.
        model = globals().get('model')
        if model is None:
            raise NameError("process_image needs a model: pass model=... or define a global `model`")

    image = image.convert('RGB')
    image_tensor = transform(image).to(device)
    padded_image, orig_h, orig_w = pad_image(image_tensor, PATCH_SIZE)
    patches, padded_h, padded_w = split_into_patches(padded_image, PATCH_SIZE)

    processed_patches = []
    with torch.no_grad():
        for patch, i, j in patches:
            output_patch = model(patch.unsqueeze(0))  # Add batch dimension
            if isinstance(output_patch, (list, tuple)):
                output_patch = output_patch[0]
            processed_patches.append((output_patch.squeeze(0), i, j))

    merged = merge_patches(processed_patches, padded_h, padded_w)
    final = merged[:, :orig_h, :orig_w]  # Remove padding
    output_img = to_pil(final.clamp(0, 1))

    # Draw grid
    return draw_patch_grid(output_img, PATCH_SIZE)

# Example usage:

In [11]:
!pip install gradio opencv-python numpy

Collecting gradio
  Downloading gradio-5.29.0-py3-none-any.whl.metadata (16 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.10.0 (from gradio)
  Downloading gradio_client-1.10.0-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.9.3 (from gradio)
  Downloading ruff-0.11.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)
Collecting safehttpx<0.2.0,>=0.1.6 (from gradio)
  Downloading safehttpx-0.1.6-py3-none-any.whl.metadata (4.2 kB)
Collecting semantic-version~=2.0 (from gradio)
  Downloading semantic_version-2.10.0-py2.py3-none-any.whl.meta

In [12]:
import gradio as gr
import cv2
import numpy as np

# Deblurring with the pretrained DGUNet Generator, using tiled inference
def deblur_image(image):
    """Deblur ``image`` with the pretrained Generator and save it to /kaggle/working/output_deblur.jpg.

    Args:
        image: PIL image from the Gradio input.
    """
    # process_image reads a module-level `model`; a purely local variable was
    # invisible to it and caused a NameError.
    global model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Generator()
    model_path = '/kaggle/input/deblur/pytorch/default/1/Generator.pth'
    checkpoint = torch.load(model_path, map_location='cpu')  # Ensure loaded to CPU
    # Accept both {"state_dict": ...} checkpoints and bare state dicts.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Checkpoints saved from nn.DataParallel prefix every key with "module.".
        # Strip it only where it is actually present.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)
    # The input tiles are moved to `device`, so the weights must be there too.
    model = model.to(device)
    model.eval()
    output = process_image(image,device)
    output.save('/kaggle/working/output_deblur.jpg')

# Low-light enhancement with the pretrained SparseTransformer, using tiled inference
def enhance_low_light(image):
    """Enhance a low-light ``image`` and save it to /kaggle/working/output_low.png."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = SparseTransformer().to(device)
    checkpoint_path = '/kaggle/input/low-light/pytorch/default/1/model_epoch_100.pth'
    net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # enhance_image switches the network to eval mode and handles tiling and saving.
    enhance_image(net, image, "/kaggle/working/output_low.png", device)

def run_image(image, enhancement_type):
    """Gradio callback: apply the selected enhancement and return the saved result image.

    If no option or an unknown option is selected, the input image is returned unchanged.
    """
    if enhancement_type == "Deblurring":
        enhance, result_path = deblur_image, '/kaggle/working/output_deblur.jpg'
    elif enhancement_type == "Low Light Enhancement":
        enhance, result_path = enhance_low_light, '/kaggle/working/output_low.png'
    else:
        return image
    enhance(image)
    return Image.open(result_path)
# Gradio UI
with gr.Blocks() as demo:
    # Layout: inputs in the left column, enhanced result in the right column.
    gr.Markdown("## 🖼️ Image Enhancement Tool")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="pil")
            enhancement_option = gr.Radio(["Deblurring", "Low Light Enhancement"], label="Select Enhancement Type")
            submit_btn = gr.Button("Enhance Image")

        with gr.Column():
            output_image = gr.Image(label="Enhanced Image")

    submit_btn.click(fn=run_image, inputs=[input_image, enhancement_option], outputs=output_image)

# Launch the app
demo.launch()

* Running on local URL:  http://127.0.0.1:7860
It looks like you are running Gradio on a hosted a Jupyter notebook. For the Gradio app to work, sharing must be enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

* Running on public URL: https://7518388dc2262d6fab.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


