# **Deblurring**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pdb import set_trace as stx


def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build a 2-D convolution padded by ``kernel_size // 2`` (size-preserving for odd kernels at stride 1)."""
    pad = kernel_size // 2
    layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=pad, bias=bias, stride=stride)
    return layer


def conv_down(in_chn, out_chn, bias=False):
    """4x4, stride-2, padding-1 convolution: maps a spatial size H to floor(H / 2)."""
    return nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias)


def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
    """Same as ``conv`` but with ``bias=True`` by default and ``stride`` before ``bias``."""
    half = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=half, stride=stride, bias=bias)


class ResBlock(nn.Module):
    """Two-convolution residual block computing ``x + res_scale * body(x)``.

    The body is ``conv(n_feats -> mid_feats) [-> BN] -> act -> conv(mid_feats -> n_feats) [-> BN]``.

    Args:
        conv: Factory ``conv(in_ch, out_ch, kernel_size, bias=...)`` returning a conv layer.
        n_feats: Channels of the block's input and output.
        kernel_size: Kernel size passed to ``conv``.
        bias: Whether the convolutions use a bias term.
        bn: If True, insert ``BatchNorm2d`` after each convolution.
        act: Activation module placed after the first convolution. When omitted,
            a fresh ``nn.PReLU()`` is created for this block.
        res_scale: Multiplier applied to the residual branch.
        mid_feats: Width of the hidden layer (64 in the original design).
    """

    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=None, res_scale=1, mid_feats=64):

        super(ResBlock, self).__init__()
        if act is None:
            # A module used as a default argument is built once, when the
            # ``def`` runs, and its learnable PReLU slope would be shared by
            # every ResBlock that relies on the default. Build one per block.
            act = nn.PReLU()

        m = [conv(n_feats, mid_feats, kernel_size, bias=bias)]
        if bn:
            # BatchNorm must match the first conv's *output* width.
            m.append(nn.BatchNorm2d(mid_feats))
        m.append(act)
        m.append(conv(mid_feats, n_feats, kernel_size, bias=bias))
        if bn:
            m.append(nn.BatchNorm2d(n_feats))

        # Layer order (and thus state_dict keys) matches the original layout.
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res


class CAB(nn.Module):
    """Channel-attention residual block: ``x + CA(conv(act(conv(x))))``.

    Args:
        n_feat: Channels of input and output.
        kernel_size: Kernel size of both convolutions.
        reduction: Bottleneck reduction ratio passed to ``CALayer``.
        bias: Whether the convolutions use a bias term.
        act: Activation module placed between the two convolutions.
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super(CAB, self).__init__()
        # Build the body layers before the attention layer so parameter
        # initialisation happens in the same order as before.
        layers = [
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
        ]

        self.CA = CALayer(n_feat, reduction, bias=bias)
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        attended = self.CA(self.body(x))
        return attended + x


class CALayer(nn.Module):
    """Channel attention: multiply each channel by a learned gate in (0, 1).

    The gate is computed from globally average-pooled features, passed through
    a 1x1-conv bottleneck (``channel -> channel // reduction -> channel``)
    followed by a sigmoid.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super(CALayer, self).__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, squeezed, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.conv_du(self.avg_pool(x))
        return x * gate


class SAM(nn.Module):
    """Stage head returning refined features and a 3-channel image.

    ``forward(x, x_img)`` returns ``(x + conv1(x), conv2(x) + x_img)``.
    """

    def __init__(self, n_feat, kernel_size, bias):
        super(SAM, self).__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)

    def forward(self, x, x_img):
        refined = self.conv1(x)
        image = self.conv2(x) + x_img
        return refined + x, image


class mergeblock(nn.Module):
    """Fuse ``x`` with ``bridge`` after projecting ``bridge`` onto a learned subspace.

    A conv over ``cat(x, bridge)`` predicts ``subspace_dim`` basis maps. ``bridge``
    is projected onto their span (least-squares projection ``V (V^T V)^-1 V^T``).
    The projection is concatenated with ``x``, reduced back to ``n_feat``
    channels, and added to ``x``.
    """

    def __init__(self, n_feat, kernel_size, bias, subspace_dim=16):
        super(mergeblock, self).__init__()
        self.conv_block = conv(n_feat * 2, n_feat, kernel_size, bias=bias)
        self.num_subspace = subspace_dim
        self.subnet = conv(n_feat * 2, self.num_subspace, kernel_size, bias=bias)

    def forward(self, x, bridge):
        b_, c_, h_, w_ = bridge.shape
        n_pix = h_ * w_

        # basis: (b, K, N) with K = num_subspace basis vectors over N pixels.
        basis = self.subnet(torch.cat([x, bridge], 1)).view(b_, self.num_subspace, n_pix)
        # L1-normalise each basis vector; the epsilon avoids division by zero.
        basis = basis / (1e-6 + torch.abs(basis).sum(axis=2, keepdims=True))
        basis_t = basis.permute(0, 2, 1)  # (b, N, K)

        # NOTE(review): torch.inverse fails if the Gram matrix is singular.
        gram_inv = torch.inverse(torch.matmul(basis, basis_t))      # (b, K, K)
        flat_bridge = bridge.view(b_, c_, n_pix).permute(0, 2, 1)   # (b, N, C)
        coeffs = torch.matmul(torch.matmul(gram_inv, basis), flat_bridge)  # (b, K, C)
        projected = torch.matmul(basis_t, coeffs).permute(0, 2, 1).view(b_, c_, h_, w_)

        fused = self.conv_block(torch.cat([x, projected], 1))
        return fused + x

class Encoder(nn.Module):
    """UNet encoder: ``depth - 1`` downsampling UNetConvBlocks plus one bottom block.

    Level ``i`` has ``n_feat + scale_unetfeats * i`` input channels. ``kernel_size``,
    ``reduction``, ``act`` and ``bias`` are accepted for interface compatibility
    but are unused here.

    ``forward`` returns ``(skips, bottom)``. ``skips`` holds the pre-downsampling
    feature of each downsampling level (shallowest first). When both
    ``encoder_outs`` and ``decoder_outs`` are given, level ``i`` also receives
    ``encoder_outs[i]`` and ``decoder_outs[-i-1]``.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff,depth=5):
        super(Encoder, self).__init__()
        self.body = nn.ModuleList()
        self.depth = depth
        for level in range(depth - 1):
            width_in = n_feat + scale_unetfeats * level
            width_out = n_feat + scale_unetfeats * (level + 1)
            self.body.append(UNetConvBlock(in_size=width_in, out_size=width_out, downsample=True,
                                           relu_slope=0.2, use_csff=csff, use_HIN=True))
        bottom = n_feat + scale_unetfeats * (depth - 1)
        self.body.append(UNetConvBlock(in_size=bottom, out_size=bottom, downsample=False,
                                       relu_slope=0.2, use_csff=csff, use_HIN=True))

    def forward(self, x, encoder_outs=None, decoder_outs=None):
        cross_stage = encoder_outs is not None and decoder_outs is not None
        skips = []
        for idx, block in enumerate(self.body):
            if idx + 1 >= self.depth:
                # Bottom block: no downsampling, so it returns a single tensor.
                x = block(x)
                continue
            if cross_stage:
                x, skip = block(x, encoder_outs[idx], decoder_outs[-idx - 1])
            else:
                x, skip = block(x)
            skips.append(skip)
        return skips, x


class UNetConvBlock(nn.Module):
    """Two 3x3 conv + LeakyReLU layers with a 1x1-conv identity shortcut.

    Args:
        in_size: Input channels.
        out_size: Output channels.
        downsample: If True, ``forward`` also returns a stride-2 downsampled copy.
        relu_slope: Negative slope of both LeakyReLUs.
        use_csff: If True and ``downsample`` is set, build layers that fuse the
            ``enc``/``dec`` tensors passed to ``forward`` into the output.
        use_HIN: If True, instance-normalise the first channel half after
            ``conv_1``; the other half is left unnormalised.

    ``forward(x, enc=None, dec=None)`` returns ``(downsampled, out)`` when
    ``downsample`` is set, otherwise ``out``.
    """

    def __init__(self, in_size, out_size, downsample, relu_slope, use_csff=False, use_HIN=False):
        super(UNetConvBlock, self).__init__()
        self.downsample = downsample
        self.identity = nn.Conv2d(in_size, out_size, 1, 1, 0)
        self.use_csff = use_csff

        self.conv_1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_1 = nn.LeakyReLU(relu_slope, inplace=False)
        self.conv_2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True)
        self.relu_2 = nn.LeakyReLU(relu_slope, inplace=False)

        if downsample and use_csff:
            self.csff_enc = nn.Conv2d(out_size, out_size, 3, 1, 1)
            self.csff_dec = nn.Conv2d(in_size, out_size, 3, 1, 1)
            self.phi = nn.Conv2d(out_size, out_size, 3, 1, 1)
            self.gamma = nn.Conv2d(out_size, out_size, 3, 1, 1)

        if use_HIN:
            # torch.chunk(out, 2) gives the first chunk ceil(out_size / 2)
            # channels. Size the norm to match so odd widths work. Even widths
            # are unchanged, so existing checkpoints still load.
            self.norm = nn.InstanceNorm2d((out_size + 1) // 2, affine=True)
        self.use_HIN = use_HIN

        if downsample:
            # Replaces the boolean flag with the downsampling layer; it stays truthy.
            self.downsample = conv_down(out_size, out_size, bias=False)

    def forward(self, x, enc=None, dec=None):
        out = self.conv_1(x)

        if self.use_HIN:
            out_1, out_2 = torch.chunk(out, 2, dim=1)
            out = torch.cat([self.norm(out_1), out_2], dim=1)
        out = self.relu_1(out)
        out = self.relu_2(self.conv_2(out))

        out += self.identity(x)
        if enc is not None and dec is not None:
            # The csff_* layers only exist when built with downsample and use_csff.
            assert self.use_csff
            skip_ = F.leaky_relu(self.csff_enc(enc) + self.csff_dec(dec), 0.1, inplace=True)
            # torch.sigmoid replaces the deprecated F.sigmoid (same values).
            out = out*torch.sigmoid(self.phi(skip_)) + self.gamma(skip_) + out
        if self.downsample:
            out_down = self.downsample(out)
            return out_down, out
        else:
            return out


class UNetUpBlock(nn.Module):
    """Upsample by 2 with a transposed conv, concatenate the skip, then refine.

    ``bridge`` must have ``out_size`` channels and twice the spatial size of ``x``.
    """

    def __init__(self, in_size, out_size, relu_slope):
        super(UNetUpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)
        self.conv_block = UNetConvBlock(out_size*2, out_size, False, relu_slope)

    def forward(self, x, bridge):
        merged = torch.cat([self.up(x), bridge], 1)
        return self.conv_block(merged)


class Decoder(nn.Module):
    """UNet decoder mirroring ``Encoder``: ``depth - 1`` up-blocks, deepest first.

    Each skip tensor ``bridges[-i-1]`` is first passed through a 3x3 conv
    (``skip_conv[i]``) that maps its width down to the up-block's output width.
    ``kernel_size``, ``reduction``, ``act`` and ``bias`` are unused here.
    Returns the list of each up-block's output, coarse to fine.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats,depth=5):
        super(Decoder, self).__init__()

        self.body = nn.ModuleList()
        self.skip_conv = nn.ModuleList()
        for level in range(depth - 1):
            wide = n_feat + scale_unetfeats * (depth - level - 1)
            narrow = n_feat + scale_unetfeats * (depth - level - 2)
            self.body.append(UNetUpBlock(in_size=wide, out_size=narrow, relu_slope=0.2))
            self.skip_conv.append(nn.Conv2d(wide, narrow, 3, 1, 1))

    def forward(self, x, bridges):
        outputs = []
        for idx, (up_block, squeeze) in enumerate(zip(self.body, self.skip_conv)):
            x = up_block(x, squeeze(bridges[-idx - 1]))
            outputs.append(x)

        return outputs


class DownSample(nn.Module):
    """Halve the resolution (bilinear) and widen channels by ``s_factor`` via a 1x1 conv."""

    def __init__(self, in_channels, s_factor):
        super(DownSample, self).__init__()
        resize = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
        project = nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False)
        self.down = nn.Sequential(resize, project)

    def forward(self, x):
        return self.down(x)


class UpSample(nn.Module):
    """Double the resolution (bilinear) and map channels with a bias-free 1x1 conv."""

    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        resize = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        project = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)
        self.up = nn.Sequential(resize, project)

    def forward(self, x):
        return self.up(x)


class Basic_block(nn.Module):
    """One unfolding stage: an image update step followed by a UNet refinement.

    ``forward(img, stage1_img, feat1, res1, x2_samfeats)`` does the following:
    1. Computes ``x2_img = stage1_img - r1 * phit_1(phi_1(stage1_img) - img)``.
       This looks like a gradient step with learned operators phi/phit; the
       interpretation is an assumption.
    2. Merges its features with the previous stage's SAM features.
    3. Runs the encoder with the previous stage's encoder/decoder outputs.
    4. Returns ``(sam_features, stage_image, encoder_skips, decoder_outputs)``.

    ``in_c``, ``out_c``, ``scale_orsnetfeats``, ``num_cab`` and ``concat12``
    are not used in ``forward``.
    """

    def __init__(self, in_c=3, out_c=3, n_feat=80, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias=False):
        super(Basic_block, self).__init__()
        prelu = nn.PReLU()
        # Registration order is kept as-is: it fixes parameter init order,
        # parameters() order and state_dict layout.
        self.phi_1 = ResBlock(default_conv, 3, 3)
        self.phit_1 = ResBlock(default_conv, 3, 3)
        self.shallow_feat2 = nn.Sequential(conv(3, n_feat, kernel_size, bias=bias),
                                           CAB(n_feat, kernel_size, reduction, bias=bias, act=prelu))
        self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, prelu, bias, scale_unetfeats, depth=4, csff=True)
        self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, prelu, bias, scale_unetfeats, depth=4)
        self.sam23 = SAM(n_feat, kernel_size=1, bias=bias)
        self.r1 = nn.Parameter(torch.Tensor([0.5]))
        # Unused in forward, but kept so existing checkpoints load strictly.
        self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias)

        self.merge12 = mergeblock(n_feat, 3, True)

    def forward(self, img,stage1_img,feat1,res1,x2_samfeats):
        residual = self.phi_1(stage1_img) - img
        x2_img = stage1_img - self.r1 * self.phit_1(residual)
        merged = self.merge12(self.shallow_feat2(x2_img), x2_samfeats)
        feat2, feat_fin2 = self.stage2_encoder(merged, feat1, res1)
        res2 = self.stage2_decoder(feat_fin2, feat2)
        x3_samfeats, stage2_img = self.sam23(res2[-1], x2_img)
        return x3_samfeats, stage2_img, feat2, res2

class Generator(nn.Module):
    """Multi-stage deblurring generator.

    An initial stage is followed by ``depth`` applications of a single, shared
    ``Basic_block`` and a final stage with a conv tail added to ``img``.
    ``forward`` returns every stage's image in reverse order: the final output
    comes first and the first-stage image last (``depth + 2`` tensors).
    """

    def __init__(self, in_c=3, out_c=3, n_feat=80, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3,
                 reduction=4, bias=False, depth=5):
        super(Generator, self).__init__()

        # This PReLU instance is shared by the CABs of shallow_feat1 and shallow_feat7.
        prelu = nn.PReLU()
        self.depth = depth
        self.basic = Basic_block(in_c, out_c, n_feat, scale_unetfeats, scale_orsnetfeats, num_cab, kernel_size, reduction, bias)
        self.shallow_feat1 = nn.Sequential(conv(3, n_feat, kernel_size, bias=bias),
                                           CAB(n_feat, kernel_size, reduction, bias=bias, act=prelu))
        self.shallow_feat7 = nn.Sequential(conv(3, n_feat, kernel_size, bias=bias),
                                           CAB(n_feat, kernel_size, reduction, bias=bias, act=prelu))

        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, prelu, bias, scale_unetfeats, depth=4, csff=False)
        self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, prelu, bias, scale_unetfeats, depth=4)

        self.sam12 = SAM(n_feat, kernel_size=1, bias=bias)

        self.phi_0 = ResBlock(default_conv, 3, 3)
        self.phit_0 = ResBlock(default_conv, 3, 3)
        self.phi_6 = ResBlock(default_conv, 3, 3)
        self.phit_6 = ResBlock(default_conv, 3, 3)
        self.r0 = nn.Parameter(torch.Tensor([0.5]))
        self.r6 = nn.Parameter(torch.Tensor([0.5]))

        self.concat67 = conv(n_feat * 2, n_feat + scale_orsnetfeats, kernel_size, bias=bias)
        self.tail = conv(n_feat + scale_orsnetfeats, 3, kernel_size, bias=bias)

    def forward(self, img):
        stage_outputs = []

        # First stage.
        x1_img = img - self.r0 * self.phit_0(self.phi_0(img) - img)
        feat1, feat_fin1 = self.stage1_encoder(self.shallow_feat1(x1_img))
        res1 = self.stage1_decoder(feat_fin1, feat1)
        samfeats, stage_img = self.sam12(res1[-1], x1_img)
        stage_outputs.append(stage_img)

        # Middle stages: the same Basic_block weights are reused each iteration.
        for _ in range(self.depth):
            samfeats, stage_img, feat1, res1 = self.basic(img, stage_img, feat1, res1, samfeats)
            stage_outputs.append(stage_img)

        # Final stage.
        x7_img = stage_img - self.r6 * self.phit_6(self.phi_6(stage_img) - img)
        fused = self.concat67(torch.cat([self.shallow_feat7(x7_img), samfeats], 1))
        stage_outputs.append(self.tail(fused) + img)

        return stage_outputs[::-1]

import torch
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    """Convolutional discriminator producing a single-channel spatial score map.

    Structure:
    - A stride-2 stem conv.
    - ``num_layers - 1`` conv-BN-LeakyReLU stages whose widths double up to
      ``8 * n_feat``. All use stride 2 except the last, which uses stride 1.
    - A 1-channel 4x4 conv head.

    Args:
        in_c: Input channels.
        n_feat: Width of the stem convolution.
        num_layers: Number of conv stages including the stem (>= 1).
    """

    def __init__(self, in_c=3, n_feat=64, num_layers=5):
        super(Discriminator, self).__init__()

        layers = [
            nn.Conv2d(in_c, n_feat, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]

        # Width of the most recent conv. Starting at the stem's width keeps the
        # head well-defined when num_layers == 1 (the loop then never runs).
        out_channels = n_feat
        for i in range(1, num_layers):
            in_channels = n_feat * min(2**(i-1), 8)
            out_channels = n_feat * min(2**i, 8)
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2 if i < num_layers - 1 else 1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            ]

        layers.append(nn.Conv2d(out_channels, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [None]:
!pip install opencv-python

In [None]:
import torch
import numpy as np
import cv2

def torchPSNR(tar_img, prd_img):
    """Return the PSNR in dB (peak value 1.0) after clamping both tensors to [0, 1]."""
    diff = prd_img.clamp(0, 1) - tar_img.clamp(0, 1)
    rmse = diff.pow(2).mean().sqrt()
    return 20 * torch.log10(1 / rmse)

def numpyPSNR(tar_img, prd_img):
    """Return the PSNR in dB of two arrays on the 8-bit scale (peak value 255)."""
    err = np.float32(prd_img) - np.float32(tar_img)
    return 20 * np.log10(255 / np.sqrt(np.mean(err ** 2)))

In [None]:
import os
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from torchvision.transforms import functional as TF
class GoProDataset(Dataset):
    """Inference dataset of blurred frames from ``<root_dir>/<sequence>/blur_gamma``.

    Each item is ``(image_tensor, path)``. The tensor comes from
    ``TF.to_tensor`` on the RGB-converted file and is passed through
    ``transform`` when one is given. The path lets the caller name the output
    after the input file.

    Args:
        root_dir: Directory whose immediate subdirectories are sequences.
        transform: Optional callable applied to each image tensor.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Holds single file paths (not pairs); the name is kept for compatibility.
        self.image_pairs = []

        # Sort listings: os.listdir order is arbitrary, and a deterministic
        # order keeps runs reproducible.
        for subdir in sorted(os.listdir(root_dir)):
            blur_dir = os.path.join(root_dir, subdir, 'blur_gamma')
            if os.path.isdir(blur_dir):
                for img_name in sorted(os.listdir(blur_dir)):
                    self.image_pairs.append(os.path.join(blur_dir, img_name))

        # Not used by __getitem__; kept for callers that reference it. It
        # converts to a tensor and keeps the first 3 channels (no resizing).
        self.default_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x[:3, :, :])
        ])

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        img_path = self.image_pairs[idx]
        blur_image = TF.to_tensor(Image.open(img_path).convert('RGB'))

        # Apply the optional transform the constructor accepts (previously
        # it was stored but ignored).
        if self.transform:
            blur_image = self.transform(blur_image)

        return blur_image, img_path

# **Testing code and score**

In [None]:

import numpy as np
import os
import argparse
from tqdm import tqdm
from collections import OrderedDict
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

from skimage import img_as_ubyte
from pdb import set_trace as stx
class Config:
    """Hard-coded inference settings, used in place of argparse options."""
    input_dir = '/content/drive/MyDrive/Major/ImageRestoration/deblur_input'
    result_dir = './drive/MyDrive/Major/ImageRestoration/output_deblur'
    weights = './drive/MyDrive/Major/ImageRestoration/m/Deblurring/Generator.pth'
    # A plain string; a trailing comma here would make it the tuple ('GoPro',).
    dataset = 'GoPro'

args = Config()

model = Generator()
# Load on CPU first; the model is moved to the GPU below.
checkpoint = torch.load(args.weights, map_location="cpu")
state_dict = checkpoint["state_dict"]
try:
    model.load_state_dict(state_dict)
except RuntimeError:
    # Checkpoints saved from nn.DataParallel prefix every key with "module.".
    # Strip it only where present.
    state_dict = OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(state_dict)
print("===>Testing using weights: ",args.weights)
model.cuda()
model = nn.DataParallel(model)
model.eval()
dataset = args.dataset
test_dataset = GoProDataset(root_dir=args.input_dir)
result_dir = args.result_dir
# cv2.imwrite fails silently if the directory is missing.
os.makedirs(result_dir, exist_ok=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)

# The UNet stages downsample 3 times (encoder depth 4), so H and W must be
# multiples of 8 for the skip connections to line up.
SIZE_FACTOR = 8

with torch.no_grad():
    for data_test in tqdm(test_loader):
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

        input_ = data_test[0].cuda()
        filenames = data_test[1]

        # Reflect-pad on the right/bottom so H and W are multiples of SIZE_FACTOR.
        h, w = input_.shape[2], input_.shape[3]
        padh = (-h) % SIZE_FACTOR
        padw = (-w) % SIZE_FACTOR
        if padh or padw:
            input_ = F.pad(input_, (0, padw, 0, padh), mode='reflect')

        # The model returns all stage outputs with the final one first.
        restored = model(input_)
        restored = torch.clamp(restored[0], 0, 1)

        # Crop back to the original size.
        restored = restored[:, :, :h, :w]

        print(filenames)
        restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
        restored_im = img_as_ubyte(restored[0])
        out_path = os.path.join(result_dir, os.path.basename(str(filenames[0])))
        if not cv2.imwrite(out_path, cv2.cvtColor(restored_im, cv2.COLOR_RGB2BGR)):
            print("Failed to write " + out_path)
        print(out_path)

In [None]:
#score evaluation
import os
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
datasets=["GOPRO"]

IMAGE_EXTS = ('.jpg', '.png')

for dataset in datasets:
    # file_path holds the ground-truth sharp frames. gt_path, despite its
    # name, holds the restored outputs being scored.
    file_path = "/content/drive/MyDrive/ImageRestoration/GOPRO_Large/test/GOPR0868_11_00/sharp"
    gt_path = "/content/drive/MyDrive/ImageRestoration/output/out"

    image_files = [f for f in os.listdir(file_path) if f.endswith(IMAGE_EXTS)]
    gt_files = [f for f in os.listdir(gt_path) if f.endswith(IMAGE_EXTS)]
    # Score only image files present in both folders, in a deterministic order.
    common_images = sorted(set(image_files).intersection(gt_files))
    image_pairs = [(os.path.join(file_path, name), os.path.join(gt_path, name)) for name in common_images]
    total_psnr = 0
    total_ssim = 0
    # Count only the pairs actually scored (dividing by len(image_files)
    # mis-averaged whenever the folders differ).
    img_num = 0

    for sharp_file, restored_file in image_pairs:
        input_image = cv2.imread(sharp_file)
        gt_image = cv2.imread(restored_file)
        if input_image is None or gt_image is None:
            print(f"Skipping unreadable pair: {sharp_file} / {restored_file}")
            continue

        input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2GRAY)
        gt_image = cv2.cvtColor(gt_image, cv2.COLOR_BGR2GRAY)

        # cv2.imread's default mode returns 8-bit images. Use the full 0-255
        # range rather than each image's max-min, which changes from image to image.
        ssim_val = ssim(input_image, gt_image, data_range=255)
        psnr_val = psnr(input_image, gt_image, data_range=255)

        total_ssim += ssim_val
        total_psnr += psnr_val
        img_num += 1

    qm_psnr = total_psnr / img_num if img_num > 0 else 0
    qm_ssim = total_ssim / img_num if img_num > 0 else 0

    print(f'For {dataset} dataset PSNR: {qm_psnr:.4f} SSIM: {qm_ssim:.4f}')

# **Training Code**

In [None]:
class GoProDataset(Dataset):
    """Paired GoPro training dataset: ``<root_dir>/<sequence>/{blur,sharp}``.

    Only file names present in both ``blur`` and ``sharp`` of a sequence are
    used. Each item is ``(blur_tensor, sharp_tensor)`` from ``TF.to_tensor`` on
    the RGB-converted images, each passed through ``transform`` when one is given.

    Args:
        root_dir: Directory whose immediate subdirectories are sequences.
        transform: Optional callable applied separately to each tensor.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_pairs = []

        # Sort everything: os.listdir order is arbitrary, and set iteration order
        # of strings changes between runs (hash randomization). An unsorted order
        # would make indices, and so experiments, non-reproducible.
        for subdir in sorted(os.listdir(root_dir)):
            blur_dir = os.path.join(root_dir, subdir, 'blur')
            sharp_dir = os.path.join(root_dir, subdir, 'sharp')

            if os.path.isdir(blur_dir) and os.path.isdir(sharp_dir):
                common_images = sorted(set(os.listdir(blur_dir)).intersection(os.listdir(sharp_dir)))
                for img_name in common_images:
                    self.image_pairs.append((os.path.join(blur_dir, img_name), os.path.join(sharp_dir, img_name)))

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        blur_path, sharp_path = self.image_pairs[idx]
        blur_image = Image.open(blur_path).convert('RGB')
        sharp_image = Image.open(sharp_path).convert('RGB')
        blur_image = TF.to_tensor(blur_image)
        sharp_image = TF.to_tensor(sharp_image)
        if self.transform:
            blur_image = self.transform(blur_image)
            sharp_image = self.transform(sharp_image)
        return blur_image, sharp_image

In [None]:
import os

def get_training_data(rgb_dir, img_options):
    """Build a ``DataLoaderTrain`` rooted at ``rgb_dir``, which must exist."""
    assert os.path.exists(rgb_dir)
    dataset = DataLoaderTrain(rgb_dir, img_options)
    return dataset

def get_training_data_all(rgb_dir, img_options):
    """Build a ``DataLoaderTrain_all`` from a sequence of root directories, all of which must exist."""
    assert all(os.path.exists(root) for root in rgb_dir)
    return DataLoaderTrain_all(rgb_dir, img_options)

def get_validation_data(rgb_dir, img_options):
    """Build a ``DataLoaderVal`` rooted at ``rgb_dir``, which must exist."""
    assert os.path.exists(rgb_dir)
    dataset = DataLoaderVal(rgb_dir, img_options)
    return dataset

def get_test_data(rgb_dir, img_options):
    """Build a ``DataLoaderTest`` over the images in ``rgb_dir``, which must exist."""
    assert os.path.exists(rgb_dir)
    dataset = DataLoaderTest(rgb_dir, img_options)
    return dataset

In [None]:
import os
import numpy as np
from torch.utils.data import Dataset
import torch
from PIL import Image
import torchvision.transforms.functional as TF
from pdb import set_trace as stx
import random

def is_image_file(filename):
    """Return True if ``filename`` ends in .jpeg/.jpg/.png/.gif (any letter case).

    The leading dot is required, so names such as ``"archive_png"`` are not
    treated as images.
    """
    return filename.lower().endswith(('.jpeg', '.jpg', '.png', '.gif'))

class DataLoaderTrain(Dataset):
    """Paired training dataset yielding randomly cropped, augmented patches.

    Degraded images are read from ``<rgb_dir>/input`` and clean images from
    ``<rgb_dir>/target``. Each listing is sorted and filtered with
    ``is_image_file``. The two lists are paired purely by position, so they must
    hold corresponding files in the same sorted order; this is not checked.

    Args:
        rgb_dir: Root directory containing ``input`` and ``target`` folders.
        img_options: Dict with key ``'patch_size'`` (side of the square crop).
            It is required in practice: with the ``None`` default, reading
            ``['patch_size']`` in ``__init__`` raises ``TypeError``.

    Each item is ``(target_patch, input_patch, filename)``. Note the target comes
    first. ``filename`` is the target file's name without its extension. The
    patches are ``TF.to_tensor`` crops of side ``patch_size``, assuming input and
    target have the same size. Images are not converted to RGB, so the channel
    count follows each file's mode.
    """
    def __init__(self, rgb_dir, img_options=None):
        super(DataLoaderTrain, self).__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))

        self.inp_filenames = [os.path.join(rgb_dir, 'input', x)  for x in inp_files if is_image_file(x)]
        self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]

        self.img_options = img_options
        self.sizex       = len(self.tar_filenames)  # get the size of target

        self.ps = self.img_options['patch_size']

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        # Wrap the index so it always falls inside the file lists.
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Padding amounts are computed from the target's size only.
        w,h = tar_img.size
        padw = ps-w if w<ps else 0
        padh = ps-h if h<ps else 0

        # Reflect Pad in case image is smaller than patch_size (right/bottom edges)
        if padw!=0 or padh!=0:
            inp_img = TF.pad(inp_img, (0,0,padw,padh), padding_mode='reflect')
            tar_img = TF.pad(tar_img, (0,0,padw,padh), padding_mode='reflect')

        # 1-in-3 chance of a gamma adjustment on both images.
        # NOTE(review): gamma=1 appears to be a no-op; confirm whether a random
        # gamma was intended. The random draw still affects the RNG sequence.
        aug    = random.randint(0, 2)
        if aug == 1:
            inp_img = TF.adjust_gamma(inp_img, 1)
            tar_img = TF.adjust_gamma(tar_img, 1)

        # 1-in-3 chance of the same saturation change on both images, with a
        # factor in (0.8, 1.2].
        aug    = random.randint(0, 2)
        if aug == 1:
            sat_factor = 1 + (0.2 - 0.4*np.random.rand())
            inp_img = TF.adjust_saturation(inp_img, sat_factor)
            tar_img = TF.adjust_saturation(tar_img, sat_factor)

        inp_img = TF.to_tensor(inp_img)
        tar_img = TF.to_tensor(tar_img)

        hh, ww = tar_img.shape[1], tar_img.shape[2]

        # Random top-left corner of the crop (shared by input and target), and
        # one of 9 geometric augmentation choices (0 and 8 mean none).
        rr     = random.randint(0, hh-ps)
        cc     = random.randint(0, ww-ps)
        aug    = random.randint(0, 8)

        # Crop patch
        inp_img = inp_img[:, rr:rr+ps, cc:cc+ps]
        tar_img = tar_img[:, rr:rr+ps, cc:cc+ps]

        # Data Augmentations: identical flip/rotation applied to both patches.
        if aug==1:
            inp_img = inp_img.flip(1)
            tar_img = tar_img.flip(1)
        elif aug==2:
            inp_img = inp_img.flip(2)
            tar_img = tar_img.flip(2)
        elif aug==3:
            inp_img = torch.rot90(inp_img,dims=(1,2))
            tar_img = torch.rot90(tar_img,dims=(1,2))
        elif aug==4:
            inp_img = torch.rot90(inp_img,dims=(1,2), k=2)
            tar_img = torch.rot90(tar_img,dims=(1,2), k=2)
        elif aug==5:
            inp_img = torch.rot90(inp_img,dims=(1,2), k=3)
            tar_img = torch.rot90(tar_img,dims=(1,2), k=3)
        elif aug==6:
            inp_img = torch.rot90(inp_img.flip(1),dims=(1,2))
            tar_img = torch.rot90(tar_img.flip(1),dims=(1,2))
        elif aug==7:
            inp_img = torch.rot90(inp_img.flip(2),dims=(1,2))
            tar_img = torch.rot90(tar_img.flip(2),dims=(1,2))

        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

        return tar_img, inp_img, filename

class DataLoaderTrain_all(Dataset):
    """Mixed training dataset drawing each item from one of three paired sources.

    ``rgb_dir`` is a sequence of three roots, used for the ``*_noise``,
    ``*_blur`` and ``*_rain`` file lists respectively. Each root has ``input``
    and ``target`` folders, paired by sorted position as in ``DataLoaderTrain``.

    Args:
        rgb_dir: Sequence of three root directories (noise, blur, rain).
        img_options: Dict with ``'patch_size_noise'``, ``'patch_size_blur'`` and
            ``'patch_size_rain'``. It is required despite the ``None`` default.

    ``__len__`` is the size of the blur target list. Each ``__getitem__`` picks
    a source uniformly at random, independently of ``index``, and wraps
    ``index`` into that source's list. Items are ``(target_patch, input_patch,
    filename)`` with the target first. The patch side depends on the chosen
    source, so a batch mixing sources may contain different patch sizes.
    """
    def __init__(self, rgb_dir, img_options=None):
        super(DataLoaderTrain_all, self).__init__()

        inp_files_noise = sorted(os.listdir(os.path.join(rgb_dir[0], 'input')))
        tar_files_noise = sorted(os.listdir(os.path.join(rgb_dir[0], 'target')))
        inp_files_blur = sorted(os.listdir(os.path.join(rgb_dir[1], 'input')))
        tar_files_blur = sorted(os.listdir(os.path.join(rgb_dir[1], 'target')))
        inp_files_rain = sorted(os.listdir(os.path.join(rgb_dir[2], 'input')))
        tar_files_rain = sorted(os.listdir(os.path.join(rgb_dir[2], 'target')))

        self.inp_filenames_noise = [os.path.join(rgb_dir[0], 'input', x)  for x in inp_files_noise if is_image_file(x)]
        self.tar_filenames_noise = [os.path.join(rgb_dir[0], 'target', x) for x in tar_files_noise if is_image_file(x)]
        self.inp_filenames_blur = [os.path.join(rgb_dir[1], 'input', x)  for x in inp_files_blur if is_image_file(x)]
        self.tar_filenames_blur = [os.path.join(rgb_dir[1], 'target', x) for x in tar_files_blur if is_image_file(x)]
        self.inp_filenames_rain = [os.path.join(rgb_dir[2], 'input', x)  for x in inp_files_rain if is_image_file(x)]
        self.tar_filenames_rain = [os.path.join(rgb_dir[2], 'target', x) for x in tar_files_rain if is_image_file(x)]

        self.img_options = img_options
        self.sizex_noise       = len(self.tar_filenames_noise)  # get the size of target
        self.sizex_blur       = len(self.tar_filenames_blur)
        self.sizex_rain       = len(self.tar_filenames_rain)

        self.ps_noise = self.img_options['patch_size_noise']
        self.ps_blur = self.img_options['patch_size_blur']
        self.ps_rain = self.img_options['patch_size_rain']

    def __len__(self):
        # One epoch is defined by the size of the blur set.
        return self.sizex_blur

    def __getitem__(self, index):
        # Choose the source (0 = noise, 1 = blur, 2 = rain lists) at random.
        # NOTE(review): this uses numpy's global RNG, while the rest uses
        # Python's `random`. Confirm numpy is seeded per DataLoader worker.
        id_ = np.random.randint(0,3)
        if id_==0:
            index_ = index % self.sizex_noise
            ps = self.ps_noise
            inp_path = self.inp_filenames_noise[index_]
            tar_path = self.tar_filenames_noise[index_]
        elif id_==1:
            index_ = index % self.sizex_blur
            ps = self.ps_blur
            inp_path = self.inp_filenames_blur[index_]
            tar_path = self.tar_filenames_blur[index_]
        else:
            index_ = index % self.sizex_rain
            ps = self.ps_rain
            inp_path = self.inp_filenames_rain[index_]
            tar_path = self.tar_filenames_rain[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Padding amounts are computed from the target's size only.
        w,h = tar_img.size
        padw = ps-w if w<ps else 0
        padh = ps-h if h<ps else 0

        # Reflect Pad in case image is smaller than patch_size (right/bottom edges)
        if padw!=0 or padh!=0:
            inp_img = TF.pad(inp_img, (0,0,padw,padh), padding_mode='reflect')
            tar_img = TF.pad(tar_img, (0,0,padw,padh), padding_mode='reflect')

        # 1-in-3 chance of a gamma adjustment. NOTE(review): gamma=1 appears to
        # be a no-op; confirm intent. The draw still advances the RNG.
        aug    = random.randint(0, 2)
        if aug == 1:
            inp_img = TF.adjust_gamma(inp_img, 1)
            tar_img = TF.adjust_gamma(tar_img, 1)

        # 1-in-3 chance of the same saturation change (factor in (0.8, 1.2]).
        aug    = random.randint(0, 2)
        if aug == 1:
            sat_factor = 1 + (0.2 - 0.4*np.random.rand())
            inp_img = TF.adjust_saturation(inp_img, sat_factor)
            tar_img = TF.adjust_saturation(tar_img, sat_factor)

        inp_img = TF.to_tensor(inp_img)
        tar_img = TF.to_tensor(tar_img)

        hh, ww = tar_img.shape[1], tar_img.shape[2]

        # Shared random crop corner and one of 9 geometric choices (0 and 8 mean none).
        rr     = random.randint(0, hh-ps)
        cc     = random.randint(0, ww-ps)
        aug    = random.randint(0, 8)

        # Crop patch
        inp_img = inp_img[:, rr:rr+ps, cc:cc+ps]
        tar_img = tar_img[:, rr:rr+ps, cc:cc+ps]

        # Data Augmentations: identical flip/rotation applied to both patches.
        if aug==1:
            inp_img = inp_img.flip(1)
            tar_img = tar_img.flip(1)
        elif aug==2:
            inp_img = inp_img.flip(2)
            tar_img = tar_img.flip(2)
        elif aug==3:
            inp_img = torch.rot90(inp_img,dims=(1,2))
            tar_img = torch.rot90(tar_img,dims=(1,2))
        elif aug==4:
            inp_img = torch.rot90(inp_img,dims=(1,2), k=2)
            tar_img = torch.rot90(tar_img,dims=(1,2), k=2)
        elif aug==5:
            inp_img = torch.rot90(inp_img,dims=(1,2), k=3)
            tar_img = torch.rot90(tar_img,dims=(1,2), k=3)
        elif aug==6:
            inp_img = torch.rot90(inp_img.flip(1),dims=(1,2))
            tar_img = torch.rot90(tar_img.flip(1),dims=(1,2))
        elif aug==7:
            inp_img = torch.rot90(inp_img.flip(2),dims=(1,2))
            tar_img = torch.rot90(tar_img.flip(2),dims=(1,2))

        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

        return tar_img, inp_img, filename

class DataLoaderVal(Dataset):
    """Paired validation dataset over ``<rgb_dir>/input`` and ``<rgb_dir>/target``.

    Files are filtered with ``is_image_file``, sorted, and paired by position.
    When a patch size is configured, both images are center-cropped to it.

    Args:
        rgb_dir: Root directory containing ``input`` and ``target`` folders.
        img_options: Optional dict. ``img_options['patch_size']`` sets the
            center-crop side. When ``img_options`` is None, full images are
            returned.
        rgb_dir2: Unused; kept for interface compatibility.

    Each item is ``(target, input, filename)``, where ``filename`` is the target
    file's name without its extension.
    """
    def __init__(self, rgb_dir, img_options=None, rgb_dir2=None):
        super(DataLoaderVal, self).__init__()

        inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
        tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))

        self.inp_filenames = [os.path.join(rgb_dir, 'input', x)  for x in inp_files if is_image_file(x)]
        self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]

        self.img_options = img_options
        self.sizex       = len(self.tar_filenames)  # get the size of target

        # With the default img_options=None, indexing it raised TypeError.
        # Treat missing options as "no crop"; __getitem__ already handles ps=None.
        self.ps = None if img_options is None else img_options['patch_size']

    def __len__(self):
        return self.sizex

    def __getitem__(self, index):
        index_ = index % self.sizex
        ps = self.ps

        inp_path = self.inp_filenames[index_]
        tar_path = self.tar_filenames[index_]

        inp_img = Image.open(inp_path)
        tar_img = Image.open(tar_path)

        # Validate on center crop
        if self.ps is not None:
            inp_img = TF.center_crop(inp_img, (ps,ps))
            tar_img = TF.center_crop(tar_img, (ps,ps))

        inp_img = TF.to_tensor(inp_img)
        tar_img = TF.to_tensor(tar_img)

        filename = os.path.splitext(os.path.split(tar_path)[-1])[0]

        return tar_img, inp_img, filename

class DataLoaderTest(Dataset):
    """Unpaired test dataset of the image files directly inside ``inp_dir``, in sorted order.

    Each item is ``(tensor, stem)``, where ``stem`` is the file name without its
    extension. ``img_options`` is stored but not used.
    """
    def __init__(self, inp_dir, img_options):
        super(DataLoaderTest, self).__init__()

        self.inp_filenames = [
            os.path.join(inp_dir, name)
            for name in sorted(os.listdir(inp_dir))
            if is_image_file(name)
        ]
        self.inp_size = len(self.inp_filenames)
        self.img_options = img_options

    def __len__(self):
        return self.inp_size

    def __getitem__(self, index):
        path_inp = self.inp_filenames[index]
        stem = os.path.splitext(os.path.basename(path_inp))[0]
        tensor = TF.to_tensor(Image.open(path_inp))
        return tensor, stem

In [None]:
import torch

class MixUp_AUG:
    """MixUp applied jointly to a batch of ground-truth and degraded images.

    Each sample ``i`` is blended with a randomly permuted partner:
    ``lam_i * x_i + (1 - lam_i) * x_perm(i)``, with ``lam_i ~ Beta(0.6, 0.6)``.
    The same weights and permutation are used for both tensors so pairs stay aligned.
    """

    def __init__(self):
        self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))

    def aug(self, rgb_gt, rgb_noisy):
        """Return ``(mixed_gt, mixed_noisy)``; both inputs share the batch dimension 0."""
        bs = rgb_gt.size(0)
        indices = torch.randperm(bs)
        rgb_gt2 = rgb_gt[indices]
        rgb_noisy2 = rgb_noisy[indices]

        # Put the mixing weights on the data's device instead of hard-coding
        # .cuda(). This is unchanged for GPU inputs and also works on CPU.
        lam = self.dist.rsample((bs,1)).view(-1,1,1,1).to(rgb_gt.device)

        rgb_gt    = lam * rgb_gt + (1-lam) * rgb_gt2
        rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2

        return rgb_gt, rgb_noisy

In [None]:
!pip install natsort

In [None]:
import os
from natsort import natsorted
from glob import glob

def mkdirs(paths):
    """Create each directory in ``paths`` when it is a list, else the single directory ``paths``.

    Every path is delegated to ``mkdir``.
    """
    targets = paths if isinstance(paths, list) else [paths]
    for target in targets:
        mkdir(target)

def mkdir(path):
    """Create directory ``path`` (and any missing parents); do nothing if it already exists.

    ``exist_ok=True`` replaces a separate existence check. That check raced with any
    other process creating the same directory between the check and the create.
    Raises ``FileExistsError`` if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

def get_last_path(path, session):
    """Return the last path in ``path`` whose name ends with ``session``, in natural sort order.

    Natural sorting places ``x_10`` after ``x_9``. Raises ``IndexError`` when nothing matches.
    """
    candidates = natsorted(glob(os.path.join(path, '*%s' % session)))
    return candidates[-1]

In [None]:
!pip install opencv-python

In [None]:
import torch
import numpy as np
import cv2

def torchPSNR(tar_img, prd_img):
    """PSNR in dB between two image tensors after clamping both to [0, 1] (peak value 1).

    Returns a 0-d tensor. Identical images give ``inf``.
    """
    diff = prd_img.clamp(0, 1) - tar_img.clamp(0, 1)
    rmse = diff.pow(2).mean().sqrt()
    return 20 * torch.log10(1 / rmse)

def save_img(filepath, img):
    """Write a channel-last RGB image array to ``filepath`` with OpenCV.

    The array is converted RGB -> BGR first, because ``cv2.imwrite`` expects BGR order.
    """
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filepath, bgr)

def numpyPSNR(tar_img, prd_img):
    """PSNR in dB between two images on a 0-255 scale (peak 255), computed in float32.

    No clamping is applied. Identical images divide by zero and yield ``inf``.
    """
    squared_err = (np.float32(prd_img) - np.float32(tar_img)) ** 2
    rmse = np.sqrt(np.mean(squared_err))
    return 20 * np.log10(255 / rmse)

In [None]:
import torch
import os
from collections import OrderedDict

def freeze(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for param in model.parameters():
        param.requires_grad_(False)

def unfreeze(model):
    """Turn gradient tracking back on for every parameter of ``model`` (in place)."""
    for param in model.parameters():
        param.requires_grad_(True)

def is_frozen(model):
    """Return True if at least one parameter of ``model`` has ``requires_grad`` disabled.

    A partially frozen model counts as frozen. A model with no parameters does not.
    """
    return any(not param.requires_grad for param in model.parameters())

def save_checkpoint(model_dir, state, session):
    """Save ``state`` to ``<model_dir>/model_epoch_<epoch>_<session>.pth``.

    ``epoch`` is read from ``state['epoch']``, so a missing key raises ``KeyError``
    before anything is written.
    """
    filename = f"model_epoch_{state['epoch']}_{session}.pth"
    torch.save(state, os.path.join(model_dir, filename))

def load_checkpoint(model, weights):
    """Load ``checkpoint["state_dict"]`` from the file ``weights`` into ``model``.

    A direct load is tried first. If ``load_state_dict`` reports a key mismatch
    (``RuntimeError``), the leading ``"module."`` prefix added by ``nn.DataParallel`` is
    removed from every key that has it, and the load is retried. Keys without the prefix
    are kept unchanged rather than having their first 7 characters cut off.
    """
    checkpoint = torch.load(weights)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        prefix = "module."
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(new_state_dict)


def load_checkpoint_multigpu(model, weights):
    """Load a checkpoint saved from an ``nn.DataParallel`` model into an unwrapped ``model``.

    The ``"module."`` prefix is removed from keys that start with it. Other keys are
    kept unchanged rather than having their first 7 characters sliced off, so checkpoints
    saved from an unwrapped model also load correctly.
    """
    checkpoint = torch.load(weights)
    state_dict = checkpoint["state_dict"]
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)

def load_start_epoch(weights):
    """Return the ``"epoch"`` entry stored in the checkpoint file ``weights``."""
    return torch.load(weights)["epoch"]

def load_optim(optimizer, weights):
    """Restore ``optimizer`` (state and param-group settings such as ``lr``) from ``weights``.

    The checkpoint must hold the optimizer's ``state_dict()`` under the ``'optimizer'`` key.
    Returns None.
    """
    saved = torch.load(weights)
    optimizer.load_state_dict(saved['optimizer'])

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a smooth L1 surrogate: ``mean(sqrt((x - y)**2 + eps**2))``."""

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        # eps is squared, so identical inputs give a loss of exactly eps.
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
        return loss


class EdgeLoss(nn.Module):
    """Charbonnier loss between one-level Laplacian (band-pass) responses of two images.

    The 5x5 Gaussian kernel is the outer product of ``[.05, .25, .4, .25, .05]`` with
    itself. It is repeated for 3 channels and applied depthwise, so inputs must have
    3 channels.
    """

    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)  # (3, 1, 5, 5)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        """Gaussian-blur ``img`` (B, 3, H, W) depthwise; replicate padding keeps H and W."""
        # Match the input's device and dtype on every call. The kernel is pre-moved to the
        # default GPU whenever CUDA exists, which otherwise breaks CPU inputs, tensors on
        # another GPU, and non-float32 inputs. .to() is a no-op when nothing changes.
        kernel = self.kernel.to(device=img.device, dtype=img.dtype)
        n_channels, _, kh, kw = kernel.shape
        # F.pad order is (left, right, top, bottom).
        img = F.pad(img, (kw // 2, kw // 2, kh // 2, kh // 2), mode='replicate')
        return F.conv2d(img, kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        """Return ``current`` minus its blur -> 2x decimate -> zero-insert upsample -> blur."""
        filtered = self.conv_gauss(current)
        down = filtered[:, :, ::2, ::2]
        new_filter = torch.zeros_like(filtered)
        # x4 compensates for the three zeros inserted in every 2x2 block.
        new_filter[:, :, ::2, ::2] = down * 4
        filtered = self.conv_gauss(new_filter)
        diff = current - filtered
        return diff

    def forward(self, x, y):
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss

In [None]:
!pip install warmup-scheduler

In [None]:
import os
import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import random
import time
import numpy as np
from tqdm import tqdm

from warmup_scheduler import GradualWarmupScheduler
######### Set Seeds ###########
# Python's `random` is seeded with 4321; NumPy and PyTorch (CPU and all GPUs) with 1234.
random.seed(4321)
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(1234)
del _seed_fn

start_epoch = 1  # first epoch to run; the resume loop below advances the scheduler to it


class Config:
    """Training hyper-parameters grouped into nested namespaces (MODEL / OPTIM / TRAINING)."""

    GPU = [0, 1, 2, 3]
    VERBOSE = True

    class MODEL:
        MODE = 'Deblurring'
        SESSION = 'DGUNet'

    class OPTIM:
        BATCH_SIZE = 1
        NUM_EPOCHS = 50
        LR_INITIAL = 1e-4  # Adam learning rate at the start of training
        LR_MIN = 1e-6      # eta_min of the cosine schedule

    class TRAINING:
        VAL_AFTER_EVERY = 10
        RESUME = False
        TRAIN_PS = 256
        VAL_PS = 256


opt = Config()

# NOTE(review): these paths appear to point into a mounted Google Drive; adjust per machine.
result_dir = "./drive/MyDrive/ImageRestoration/results"
trained_dir = "./drive/MyDrive/ImageRestoration/m/Generator.pth"  # a checkpoint *file*, loaded below
model_dir = "./drive/MyDrive/ImageRestoration/m"

train_dir = "/content/drive/MyDrive/Major/References/lol_dataset/our485"
val_dir = "./value"

######### Model ###########
model_restoration = Generator()
model_restoration.cuda()

checkpoint = torch.load(trained_dir)
try:
    model_restoration.load_state_dict(checkpoint["state_dict"])
except RuntimeError:
    # load_state_dict raises RuntimeError on a key mismatch. The usual cause is a
    # checkpoint saved from an nn.DataParallel wrapper, whose keys carry a "module."
    # prefix. Strip it only where present; slicing every key would corrupt the rest.
    state_dict = checkpoint["state_dict"]
    new_state_dict = OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in state_dict.items()
    )
    model_restoration.load_state_dict(new_state_dict)
######### Optimizer ###########

new_lr = opt.OPTIM.LR_INITIAL

optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999), eps=1e-8)


######### Scheduler ###########
# Warm up for `warmup_epochs`, then cosine-decay to LR_MIN.
warmup_epochs = 3
# GradualWarmupScheduler hands control to `after_scheduler` once warm-up ends, so the
# cosine period must span the remaining epochs. A T_max of only `warmup_epochs` makes the
# cosine curve bottom out after 3 epochs and then climb back toward LR_INITIAL, cycling
# for the rest of training.
cosine_epochs = max(1, opt.OPTIM.NUM_EPOCHS - warmup_epochs)
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, cosine_epochs, eta_min=opt.OPTIM.LR_MIN)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
scheduler.step()


# Fast-forward the schedule when resuming from a later epoch (no-op when start_epoch == 1).
for i in range(1, start_epoch):
    scheduler.step()
    new_lr = scheduler.get_lr()[0]
    print('------------------------------------------------------------------------------')
    print("==> Resuming Training with learning rate:", new_lr)
    print('------------------------------------------------------------------------------')

#Loss
criterion_char = CharbonnierLoss()
criterion_edge = EdgeLoss()

# DataLoaders
print('===> Loading datasets')
train_dataset = LOLDataset(train_dir)
train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)

print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1))

# NOTE(review): tracked but never updated; this loop has no validation pass.
best_psnr = 0
best_epoch = 0

os.makedirs(model_dir, exist_ok=True)

for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
    epoch_start_time = time.time()
    epoch_loss = 0

    model_restoration.train()
    for data in tqdm(train_loader):

        # zero_grad: set .grad to None instead of zero-filling it.
        for param in model_restoration.parameters():
            param.grad = None

        # NOTE(review): index 1 is used as the target and index 0 as the input. Confirm this
        # matches LOLDataset's return order (the val dataset in this file returns target first).
        target = data[1].cuda()
        input_ = data[0].cuda()

        restored = model_restoration(input_)

        # One loss term per element of `restored` (presumably one per generator stage), summed.
        loss_char = sum(criterion_char(out, target) for out in restored)
        loss_edge = sum(criterion_edge(out, target) for out in restored)
        loss = loss_char + (0.05 * loss_edge)

        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()  # Python float, so no autograd graph is kept alive

    # The file name depends only on `epoch`. Saving inside the batch loop rewrote the same
    # file on every iteration; one save after the last batch yields the same contents.
    torch.save({'epoch': epoch,
                'state_dict': model_restoration.state_dict(),
                'optimizer' : optimizer.state_dict()
                }, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))

    scheduler.step()

    print("------------------------------------------------------------------")
    print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
    print("------------------------------------------------------------------")

    torch.save({'epoch': epoch,
                'state_dict': model_restoration.state_dict(),
                'optimizer' : optimizer.state_dict()
                }, os.path.join(model_dir,"model_latest.pth"))

In [None]:
import numpy as np
import os
import argparse
import cv2
from tqdm import tqdm
from collections import OrderedDict
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from skimage import img_as_ubyte

class Config:
    """Inference settings for the deblurring test pass (used in place of argparse options)."""

    # Model weights; `dataset` is not read by the code below.
    weights = './drive/MyDrive/Major/ImageRestoration/m/Deblurring/Generator.pth'
    dataset = 'GoPro'

    # Input and output folders
    input_dir = '/content/drive/MyDrive/Major/ImageRestoration/deblur_input'
    result_dir = './drive/MyDrive/Major/ImageRestoration/output_deblur'

    # Tiling: patch edge length and step between patches (equal => non-overlapping tiles)
    patch_size = 256
    stride = 256


args = Config()

# Load Model
model = Generator()
checkpoint = torch.load(args.weights)
state_dict = checkpoint["state_dict"]
# Checkpoints saved from an nn.DataParallel-wrapped model prefix every key with "module.".
# Strip the prefix only where it exists. Slicing k[7:] unconditionally truncates the keys
# of checkpoints saved from an unwrapped model (e.g. the training loop in this notebook,
# which saves model_restoration.state_dict() without DataParallel).
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

print("===> Testing using weights:", args.weights)
model.cuda()
model = nn.DataParallel(model)
model.eval()

# Dataset Loader. batch_size=1: the processing loop saves only the first item of each batch.
test_dataset = GoProDataset(root_dir=args.input_dir)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)

# Padding Function (Handles Small Images)
def pad_image(img, patch_size=256):
    """Pad ``img`` (B, C, H, W) on the bottom/right so H and W are multiples of ``patch_size``.

    Returns ``(padded_img, pad_h, pad_w)``, where ``pad_h``/``pad_w`` are the rows/columns added.
    """
    _, _, h, w = img.shape
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    # Reflective padding minimises seams, but F.pad's 'reflect' mode requires each pad to be
    # smaller than the corresponding input size. Small images fall back to edge replication.
    mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
    padded_img = F.pad(img, (0, pad_w, 0, pad_h), mode=mode)
    return padded_img, pad_h, pad_w

# Patch Extraction Function
def extract_patches(img, patch_size=256, stride=256):
    """Split ``img`` (B, C, H, W) into sliding tiles.

    Returns ``(patches, h, w)``: ``patches`` has shape (B, C, patch_size, patch_size, L)
    with L the number of tiles, and ``h``/``w`` is the input's spatial size.
    """
    _, _, h, w = img.shape
    unfold = F.unfold(img, kernel_size=patch_size, stride=stride)  # (B, C*k*k, L)
    patches = unfold.view(img.size(0), img.size(1), patch_size, patch_size, -1)
    return patches, h, w

# Patch Merging Function
def merge_patches(patches, orig_h, orig_w, patch_size=256, stride=256):
    """Reassemble tiles into a (B, C, orig_h, orig_w) image; inverse of ``extract_patches``.

    ``patches`` is a sequence of L tensors, each (B, C, patch_size, patch_size), in the order
    of ``extract_patches``' last axis. Overlaps (stride < patch_size) are averaged, and
    pixels covered by no tile are 0.
    """
    stacked = torch.stack(list(patches), dim=-1)                     # (B, C, k, k, L)
    b, c = stacked.shape[:2]
    cols = stacked.reshape(b, c * patch_size * patch_size, -1)        # layout F.fold expects
    fold_kwargs = dict(output_size=(orig_h, orig_w), kernel_size=patch_size, stride=stride)
    summed = F.fold(cols, **fold_kwargs)
    # F.fold sums overlapping contributions; divide by per-pixel coverage to average them.
    coverage = F.fold(torch.ones_like(cols), **fold_kwargs)
    return summed / coverage.clamp(min=1)

# Processing Images
# cv2.imwrite returns False (no exception) when the output folder is missing.
os.makedirs(args.result_dir, exist_ok=True)
with torch.no_grad():
    for data_test in tqdm(test_loader):
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

        input_ = data_test[0].cuda()
        filenames = data_test[1]

        # Apply Padding if Needed
        input_, pad_h, pad_w = pad_image(input_, patch_size=args.patch_size)

        # Extract patches: (B, C, k, k, L)
        patches, orig_h, orig_w = extract_patches(input_, patch_size=args.patch_size, stride=args.stride)
        restored_patches = []

        for j in range(patches.shape[-1]):
            # patches[..., j] is already 4-D (B, C, k, k); adding unsqueeze(0) made it 5-D
            # and conv2d rejected it.
            restored_patch = model(patches[..., j])
            if isinstance(restored_patch, (list, tuple)):
                # NOTE(review): assumes a multi-stage generator returns its final output at
                # index 0 (MPRNet/DGUNet convention). Confirm against Generator.
                restored_patch = restored_patch[0]
            restored_patches.append(torch.clamp(restored_patch, 0, 1))

        # Merge patches back
        restored = merge_patches(restored_patches, orig_h, orig_w, patch_size=args.patch_size, stride=args.stride)

        # Remove Padding Before Saving (orig_h/orig_w are the padded sizes)
        restored = restored[:, :, :orig_h - pad_h, :orig_w - pad_w]

        # First batch item, (C, H, W) -> (H, W, C): OpenCV expects channel-last arrays.
        restored_im = img_as_ubyte(restored[0].permute(1, 2, 0).contiguous().cpu().numpy())
        # NOTE(review): cv2.imwrite picks the encoder from the file extension. Confirm that
        # GoProDataset's filenames keep theirs.
        out_path = os.path.join(args.result_dir, filenames[0].split("/")[-1])
        cv2.imwrite(out_path, cv2.cvtColor(restored_im, cv2.COLOR_RGB2BGR))
        print("Saved:", out_path)

===> Testing using weights: ./drive/MyDrive/Major/ImageRestoration/m/Deblurring/Generator.pth


  0%|          | 0/2 [00:00<?, ?it/s]


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 1, 3, 256, 256]

In [None]:

import numpy as np
import os
import argparse
import cv2
from tqdm import tqdm
from collections import OrderedDict
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from skimage import img_as_ubyte

class Config:
    """Inference settings for the deblurring test pass (used in place of argparse options)."""

    # Model weights; `dataset` is not read by the code below.
    weights = './drive/MyDrive/ImageRestoration/m/Deblurring/Generator.pth'
    dataset = 'GoPro'

    # Input and output folders
    input_dir = '/content/drive/MyDrive/ImageRestoration/deblur_input'
    result_dir = './drive/MyDrive/ImageRestoration/output_deblur'

    # Tiling: patch edge length and step between patches (equal => non-overlapping tiles)
    patch_size = 256
    stride = 256


args = Config()

# Load Model
model = Generator()
checkpoint = torch.load(args.weights)
state_dict = checkpoint["state_dict"]
# Checkpoints saved from an nn.DataParallel-wrapped model prefix every key with "module.".
# Strip the prefix only where it exists. Slicing k[7:] unconditionally truncates the keys
# of checkpoints saved from an unwrapped model (e.g. the training loop in this notebook,
# which saves model_restoration.state_dict() without DataParallel).
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[len("module."):] if k.startswith("module.") else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)

print("===> Testing using weights:", args.weights)
model.cuda()
model = nn.DataParallel(model)
model.eval()

# Dataset Loader. batch_size=1: the processing loop saves only the first item of each batch.
test_dataset = GoProDataset(root_dir=args.input_dir)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)

# Padding Function (Handles Small Images)
def pad_image(img, patch_size=256):
    """Pad ``img`` (B, C, H, W) on the bottom/right so H and W are multiples of ``patch_size``.

    Returns ``(padded_img, pad_h, pad_w)``, where ``pad_h``/``pad_w`` are the rows/columns added.
    """
    _, _, h, w = img.shape
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    # Reflective padding minimises seams, but F.pad's 'reflect' mode requires each pad to be
    # smaller than the corresponding input size. Small images fall back to edge replication.
    mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
    padded_img = F.pad(img, (0, pad_w, 0, pad_h), mode=mode)
    return padded_img, pad_h, pad_w

# Patch Extraction Function
def extract_patches(img, patch_size=256, stride=256):
    """Split ``img`` (B, C, H, W) into sliding tiles.

    Returns ``(patches, h, w)``: ``patches`` has shape (B, C, patch_size, patch_size, L)
    with L the number of tiles, and ``h``/``w`` is the input's spatial size.
    """
    _, _, h, w = img.shape
    unfold = F.unfold(img, kernel_size=patch_size, stride=stride)  # (B, C*k*k, L)
    patches = unfold.view(img.size(0), img.size(1), patch_size, patch_size, -1)
    return patches, h, w

# Patch Merging Function
def merge_patches(patches, orig_h, orig_w, patch_size=256, stride=256):
    """Reassemble tiles into a (B, C, orig_h, orig_w) image; inverse of ``extract_patches``.

    ``patches`` is a sequence of L tensors, each (B, C, patch_size, patch_size), in the order
    of ``extract_patches``' last axis. Overlaps (stride < patch_size) are averaged, and
    pixels covered by no tile are 0.
    """
    stacked = torch.stack(list(patches), dim=-1)                     # (B, C, k, k, L)
    b, c = stacked.shape[:2]
    cols = stacked.reshape(b, c * patch_size * patch_size, -1)        # layout F.fold expects
    fold_kwargs = dict(output_size=(orig_h, orig_w), kernel_size=patch_size, stride=stride)
    summed = F.fold(cols, **fold_kwargs)
    # F.fold sums overlapping contributions; divide by per-pixel coverage to average them.
    coverage = F.fold(torch.ones_like(cols), **fold_kwargs)
    return summed / coverage.clamp(min=1)

# Processing Images
# cv2.imwrite returns False (no exception) when the output folder is missing.
os.makedirs(args.result_dir, exist_ok=True)
with torch.no_grad():
    for data_test in tqdm(test_loader):
        torch.cuda.ipc_collect()
        torch.cuda.empty_cache()

        input_ = data_test[0].cuda()
        filenames = data_test[1]

        # Pad to a multiple of the patch size. Without this, F.unfold drops the right/bottom
        # remainder (left black after folding) and raises on images smaller than one patch.
        input_, pad_h, pad_w = pad_image(input_, patch_size=args.patch_size)

        # Extract patches: (B, C, k, k, L)
        patches, orig_h, orig_w = extract_patches(input_, patch_size=args.patch_size, stride=args.stride)
        restored_patches = []

        for j in range(patches.shape[-1]):
            # patches[..., j] is already 4-D (B, C, k, k); adding unsqueeze(0) made it 5-D
            # and conv2d rejected it.
            restored_patch = model(patches[..., j])
            if isinstance(restored_patch, (list, tuple)):
                # NOTE(review): assumes a multi-stage generator returns its final output at
                # index 0 (MPRNet/DGUNet convention). Confirm against Generator.
                restored_patch = restored_patch[0]
            restored_patches.append(torch.clamp(restored_patch, 0, 1))

        # Merge patches back, then crop away the padding (orig_h/orig_w are the padded sizes).
        restored = merge_patches(restored_patches, orig_h, orig_w, patch_size=args.patch_size, stride=args.stride)
        restored = restored[:, :, :orig_h - pad_h, :orig_w - pad_w]

        # First batch item, (C, H, W) -> (H, W, C): OpenCV expects channel-last arrays.
        restored_im = img_as_ubyte(restored[0].permute(1, 2, 0).contiguous().cpu().numpy())
        # NOTE(review): cv2.imwrite picks the encoder from the file extension. Confirm that
        # GoProDataset's filenames keep theirs.
        out_path = os.path.join(args.result_dir, filenames[0].split("/")[-1])
        cv2.imwrite(out_path, cv2.cvtColor(restored_im, cv2.COLOR_RGB2BGR))
        print("Saved:", out_path)