In [1]:
import os
import cv2
import torch
import torch.nn as nn
import numpy as np
from collections import OrderedDict
from PIL import Image
from skimage import img_as_ubyte
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm
import torchvision.transforms.functional as TF

def load_checkpoint(model, weights):
    """Load ``checkpoint["state_dict"]`` from the file *weights* into *model*.

    The checkpoint is mapped to CPU on load. If the keys do not match as stored
    (``load_state_dict`` raises ``RuntimeError``), the load is retried after
    stripping a leading ``'module.'`` prefix from every key, the prefix added when
    a model is wrapped in ``nn.DataParallel`` before saving.

    Raises:
        KeyError: if the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: if the keys still do not match after prefix stripping.
    """
    checkpoint = torch.load(weights, map_location='cpu')
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Only a key/shape mismatch is retried; any other error propagates unchanged.
        prefix = 'module.'
        stripped = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in state_dict.items()
        )
        model.load_state_dict(stripped)

def load_start_epoch(weights):
    """Return the ``"epoch"`` entry stored in the checkpoint file *weights*.

    The checkpoint is mapped to CPU (as in ``load_checkpoint``) so that a file
    saved from a GPU run can be read on a machine without CUDA.
    """
    checkpoint = torch.load(weights, map_location='cpu')
    return checkpoint["epoch"]

def mkdir(path):
    """Create directory *path*, including missing parents; no-op if it already exists.

    ``exist_ok=True`` avoids the race where another process creates the directory
    between an existence check and ``makedirs``. Raises ``FileExistsError`` if
    *path* exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

In [2]:
def torchPSNR(tar_img, prd_img):
    """Peak signal-to-noise ratio (dB) between two tensors, with peak value 1.

    Both tensors are clamped to [0, 1] first. Returns a 0-dim tensor; the result
    is ``inf`` when the clamped tensors are identical.
    """
    diff = prd_img.clamp(0, 1) - tar_img.clamp(0, 1)
    mse = diff.pow(2).mean()
    return 20 * torch.log10(1 / mse.sqrt())

def save_img(filepath, img):
    """Write an RGB image array to *filepath* with OpenCV.

    OpenCV writes BGR channel order, so the array is converted from RGB first
    (the conversion requires a 3-channel image). ``cv2.imwrite`` signals failure
    by returning False rather than raising; that is turned into an ``OSError`` so a
    result image is never silently lost.
    """
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(filepath, bgr):
        raise OSError(f"cv2.imwrite failed to write {filepath!r}")

def numpyPSNR(tar_img, prd_img):
    """Peak signal-to-noise ratio (dB) between two arrays, with peak value 255.

    Inputs are cast to float32 before differencing. Identical inputs give ``inf``
    (numpy emits a divide-by-zero warning in that case).
    """
    error = np.float32(prd_img) - np.float32(tar_img)
    rmse = np.sqrt(np.mean(np.square(error)))
    return 20 * np.log10(255 / rmse)

def is_image_file(filename):
    """Return True if *filename* ends with a recognised image extension.

    The check is case-insensitive and requires the extension to follow a dot, so
    ``'photo.Jpg'`` matches but ``'notes_png'`` does not.
    """
    return filename.lower().endswith(('.jpeg', '.jpg', '.png', '.gif'))

In [3]:
class DataLoaderTest(torch.utils.data.Dataset):
    """Dataset over every image file in a directory, yielding ``(tensor, stem)`` pairs.

    Files are taken in sorted order and filtered with ``is_image_file``. Each item
    is the image converted to 3-channel RGB and turned into a float tensor by
    ``TF.to_tensor``, plus the file name without directory or extension.
    """

    def __init__(self, inp_dir, img_options):
        super(DataLoaderTest, self).__init__()
        inp_files = sorted(os.listdir(inp_dir))
        self.inp_filenames = [os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)]
        self.inp_size = len(self.inp_filenames)
        # Stored for API compatibility; not used by this class.
        self.img_options = img_options

    def __len__(self):
        return self.inp_size

    def __getitem__(self, index):
        path_inp = self.inp_filenames[index]
        filename = os.path.splitext(os.path.basename(path_inp))[0]
        # The context manager closes the file handle once the pixels are copied;
        # convert('RGB') guarantees 3 channels for grayscale, RGBA and palette inputs.
        with Image.open(path_inp) as img:
            inp = TF.to_tensor(img.convert('RGB'))
        return inp, filename

In [4]:
def window_partitions(x, window_size):
    """Split a (B, C, H, W) tensor into non-overlapping windows.

    ``window_size`` is an int (square windows) or a 2-sequence (height, width).
    H and W must be divisible by the window size. Returns a tensor of shape
    (B * nH * nW, C, wh, ww), windows ordered batch-major, then row, then column.
    """
    if isinstance(window_size, int):
        wh, ww = window_size, window_size
    else:
        wh, ww = window_size[0], window_size[1]
    batch, channels, height, width = x.shape
    tiles = x.view(batch, channels, height // wh, wh, width // ww, ww)
    tiles = tiles.permute(0, 2, 4, 1, 3, 5).contiguous()
    return tiles.view(-1, channels, wh, ww)

def window_reverses(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partitions`` into (B, C, H, W).

    ``window_size`` is an int or a 2-sequence (height, width); H and W are the
    target spatial size and must be divisible by it.
    """
    if isinstance(window_size, int):
        wh, ww = window_size, window_size
    else:
        wh, ww = window_size[0], window_size[1]
    channels = windows.shape[1]
    grid = windows.view(-1, H // wh, W // ww, channels, wh, ww)
    grid = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
    return grid.view(-1, channels, H, W)

def window_partitionx(x, window_size):
    """Partition a (B, C, H, W) tensor into square windows, covering ragged borders.

    The largest window-aligned region (h, w) is split with ``window_partitions``.
    If W is not a multiple of ``window_size``, a right strip of the last
    ``window_size`` columns (rows :h) is appended. If H is not, a bottom strip of
    the last ``window_size`` rows (cols :w) is appended. If both are ragged, the
    bottom-right ``window_size`` x ``window_size`` corner is appended as well.
    Border windows overlap the main grid.

    Returns:
        (windows, batch_list): windows stacked along dim 0 in the order main,
        right, bottom, corner, and the cumulative window counts after each part.
    """
    _, _, H, W = x.shape
    h = window_size * (H // window_size)
    w = window_size * (W // window_size)

    parts = [window_partitions(x[:, :, :h, :w], window_size)]
    if w != W:
        parts.append(window_partitions(x[:, :, :h, -window_size:], window_size))
    if h != H:
        parts.append(window_partitions(x[:, :, -window_size:, :w], window_size))
    if h != H and w != W:
        parts.append(x[:, :, -window_size:, -window_size:])

    offsets = []
    running = 0
    for part in parts:
        running += part.shape[0]
        offsets.append(running)

    if len(parts) == 1:
        return parts[0], offsets
    return torch.cat(parts, dim=0), offsets

def window_reversex(windows, window_size, H, W, batch_list):
    """Inverse of ``window_partitionx``: reassemble windows into a (B, C, H, W) tensor.

    Args:
        windows: windows stacked along dim 0 in the order produced by
            ``window_partitionx`` (main grid, right strip, bottom strip, corner).
        window_size: square window side length.
        H, W: target spatial size.
        batch_list: cumulative window counts returned by ``window_partitionx``.

    Border windows overlap the main grid when H or W is not a multiple of
    ``window_size``; only their trailing H - h rows / W - w columns are written.
    The result has the same dtype and device as ``windows``.
    """
    h, w = window_size * (H // window_size), window_size * (W // window_size)
    x_main = window_reverses(windows[:batch_list[0], ...], window_size, h, w)
    B, C, _, _ = x_main.shape
    # Allocate in the windows' own dtype (real or complex) so half, double and
    # complex128 inputs are not silently cast to float32 / complex64.
    res = torch.zeros([B, C, H, W], dtype=windows.dtype, device=windows.device)
    res[:, :, :h, :w] = x_main
    if h == H and w == W:
        return res
    if h != H and w != W and len(batch_list) == 4:
        # Corner window spans the last window_size rows/cols; keep only its overhang.
        x_dd = window_reverses(windows[batch_list[2]:, ...], window_size, window_size, window_size)
        res[:, :, h:, w:] = x_dd[:, :, h - H:, w - W:]
        x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
        res[:, :, :h, w:] = x_r[:, :, :, w - W:]
        x_d = window_reverses(windows[batch_list[1]:batch_list[2], ...], window_size, window_size, w)
        res[:, :, h:, :w] = x_d[:, :, h - H:, :]
        return res
    # Exactly one border is ragged: the second part is either the right or the bottom strip.
    if w != W and len(batch_list) == 2:
        x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
        res[:, :, :h, w:] = x_r[:, :, :, w - W:]
    if h != H and len(batch_list) == 2:
        x_d = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, window_size, w)
        res[:, :, h:, :w] = x_d[:, :, h - H:, :]
    return res

In [5]:
import torch.nn as nn
import torch
import numpy as np
from torch.nn import functional as F

# Hidden-layer widths of the per-query MLP decoder built in INR.__init__.
hidden_list = [256, 256, 256]
# Number of sinusoid frequencies used by INR.positional_encoding. Each encoded
# coordinate value gains 2 * L features (sin and cos), hence the `4 * L` term
# for 2-D coordinates in INR.__init__.
L = 4

def make_coord(shape, ranges=None, flatten=True):
    """Return pixel-centre coordinates of a regular grid.

    Args:
        shape: number of cells per dimension, e.g. (h, w).
        ranges: optional per-dimension (v0, v1) bounds; (-1, 1) for every
            dimension when None.
        flatten: if True the result has shape (prod(shape), len(shape));
            otherwise (*shape, len(shape)).

    Each axis of length n is divided into n cells of width (v1 - v0) / n, and the
    cell centres are returned. The last dimension holds one coordinate per axis,
    in the order of ``shape``.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        v0, v1 = (-1, 1) if ranges is None else ranges[i]
        r = (v1 - v0) / (2 * n)  # half a cell width
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    # indexing='ij' is meshgrid's historical default; passing it explicitly keeps
    # the same axis order and avoids the deprecation warning / future default change.
    ret = torch.stack(torch.meshgrid(*coord_seqs, indexing='ij'), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret

class MLP(nn.Module):
    """Fully-connected network: Linear+ReLU for each hidden width, then a final Linear.

    ``forward`` applies the network to the last dimension of its input and keeps
    all leading dimensions.
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        widths = [in_dim] + list(hidden_list)
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend([nn.Linear(fan_in, fan_out), nn.ReLU()])
        modules.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        lead_dims = x.shape[:-1]
        flat = x.view(-1, x.shape[-1])
        return self.layers(flat).view(*lead_dims, -1)

class INR(nn.Module):
    """Implicit-neural-representation decoder from a feature map to a 3-channel output.

    The structure (local_ensemble / feat_unfold / cell_decode flags and the
    area-weighted blending in query_rgb) appears adapted from LIIF. forward()
    queries exactly one coordinate per feature pixel, so the output has the same
    spatial size as the input feature map.
    """

    def __init__(self, dim, local_ensemble=True, feat_unfold=True, cell_decode=True):
        # dim: channel count of the feature map later passed to forward()/query_rgb().
        super().__init__()
        self.local_ensemble = local_ensemble
        self.feat_unfold = feat_unfold
        self.cell_decode = cell_decode
        imnet_in_dim = dim

        if self.feat_unfold:
            # query_rgb concatenates each pixel's 3x3 neighbourhood along channels.
            imnet_in_dim *= 9
        # Relative coordinate: 2 raw coords + 4 * L positional-encoding features
        # (sin and cos of each of the 2 coords at L frequencies).
        imnet_in_dim += 2 + 4 * L
        if self.cell_decode:
            # Cell size (2 values) appended to every query.
            imnet_in_dim += 2

        # Per-query MLP producing 3 output values.
        self.imnet = MLP(imnet_in_dim, 3, hidden_list)

    def query_rgb(self, inp, coord, cell=None):
        """Predict 3 output channels at each query coordinate.

        Args:
            inp: feature map, indexed below as (batch, channels, h, w).
            coord: query coordinates of shape (batch, n_queries, 2 + 4 * L), i.e. raw
                coordinates followed by their positional encoding, as built by forward().
            cell: per-query cell size of shape (batch, n_queries, 2); must not be None
                when cell_decode is True.

        Returns:
            Tensor of shape (batch, 3, h, w). The final reshape requires
            n_queries == h * w, i.e. one query per feature pixel.
        """
        feat = inp
        if self.feat_unfold:
            # 3x3 unfold with padding 1 keeps h, w and multiplies channels by 9.
            feat = F.unfold(feat, 3, padding=1).view(
                feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])

        if self.local_ensemble:
            vx_lst = [-1, 1]
            vy_lst = [-1, 1]
            eps_shift = 1e-6
        else:
            vx_lst, vy_lst, eps_shift = [0], [0], 0

        # Half a feature pixel in normalised [-1, 1] coordinates along each axis.
        rx = 2 / feat.shape[-2] / 2
        ry = 2 / feat.shape[-1] / 2

        device = inp.device
        # Pixel-centre coordinates of the feature grid, shape (batch, 2, h, w).
        feat_coord = make_coord(feat.shape[-2:], flatten=False).to(device) \
            .permute(2, 0, 1) \
            .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])

        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                # NOTE(review): coord_ is shifted and clamped but never read afterwards;
                # rel_coord below uses the unshifted `coord`. Because vx/vy only affect
                # coord_, every iteration of this double loop computes the same pred and area.
                coord_ = coord.clone()
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)

                # Features flattened to (batch, h*w, channels); `q` here is the channel count.
                bs, q, h, w = feat.shape
                q_feat = feat.view(bs, q, -1).permute(0, 2, 1)

                # Feature-pixel coordinates flattened to (batch, h*w, 2), then extended
                # with their positional encoding to (batch, h*w, 2 + 4 * L).
                bs, q, h, w = feat_coord.shape
                q_coord = feat_coord.view(bs, q, -1).permute(0, 2, 1)

                points_enc = self.positional_encoding(q_coord, L=L)
                q_coord = torch.cat([q_coord, points_enc], dim=-1)

                # Offset of each query from its feature pixel; the two raw-coordinate
                # channels are rescaled by the feature height / width.
                # NOTE(review): when called via forward(), coord and feat_coord both come
                # from make_coord over the same (h, w) grid, so rel_coord appears to be all zeros.
                rel_coord = coord - q_coord
                rel_coord[:, :, 0] *= feat.shape[-2]
                rel_coord[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([q_feat, rel_coord], dim=-1)

                if self.cell_decode:
                    # Cell size rescaled to feature-pixel units and appended.
                    rel_cell = cell.clone()
                    rel_cell[:, :, 0] *= feat.shape[-2]
                    rel_cell[:, :, 1] *= feat.shape[-1]
                    inp = torch.cat([inp, rel_cell], dim=-1)

                bs, q = coord.shape[:2]
                pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
                preds.append(pred)

                # Blending weight from the rescaled offset; + 1e-9 keeps it strictly positive.
                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9)

        tot_area = torch.stack(areas).sum(dim=0)
        if self.local_ensemble:
            # Loop order is (-1,-1), (-1,1), (1,-1), (1,1): each prediction is weighted by
            # the area of the diagonally opposite neighbour, so swap 0<->3 and 1<->2.
            t = areas[0];
            areas[0] = areas[3];
            areas[3] = t
            t = areas[1];
            areas[1] = areas[2];
            areas[2] = t
        ret = 0
        # Normalised area-weighted sum of the predictions.
        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)

        # (batch, h*w, 3) -> (batch, 3, h, w), using the feature map's spatial size.
        bs, q, h, w = feat.shape
        ret = ret.view(bs, h, w, -1).permute(0, 3, 1, 2)
        return ret

    def forward(self, inp):
        """Decode feature map `inp` (indexed as (B, C, h, w)) into a (B, 3, h, w) output.

        Builds one query per feature pixel: pixel-centre coordinates from make_coord
        followed by their positional encoding, with a constant cell size of
        (2 / h, 2 / w), i.e. one pixel in normalised [-1, 1] coordinates.
        """
        h, w = inp.shape[2], inp.shape[3]
        B = inp.shape[0]
        # coord = make_coord((h, w)).cuda()  # old version (hard-coded CUDA device)
        device = inp.device
        coord = make_coord((h, w)).to(device)
        cell = torch.ones_like(coord)
        cell[:, 0] *= 2 / h
        cell[:, 1] *= 2 / w
        cell = cell.unsqueeze(0).repeat(B, 1, 1)
        coord = coord.unsqueeze(0).repeat(B, 1, 1)
        points_enc = self.positional_encoding(coord, L=L)
        coord = torch.cat([coord, points_enc], dim=-1)

        return self.query_rgb(inp, coord, cell)

    def positional_encoding(self, input, L):
        """Sinusoidal encoding of the last dimension of `input`.

        Each of the D values in the last dimension is multiplied by the L frequencies
        pi * 2**k (k = 0 .. L-1) and mapped to sin and cos. The output has shape
        input.shape[:-1] + (2 * L * D,), ordered per value as
        [sin at all L frequencies, cos at all L frequencies].
        """
        shape = input.shape
        device = input.device
        freq = 2 ** torch.arange(L, dtype=torch.float32).to(device) * np.pi
        spectrum = input[..., None] * freq  # (..., D, L)
        sin, cos = spectrum.sin(), spectrum.cos()
        input_enc = torch.stack([sin, cos], dim=-2)  # (..., D, 2, L)
        input_enc = input_enc.view(*shape[:-1], -1)  # (..., 2 * L * D)

        return input_enc

In [6]:
import torch
from torch import nn
from torch.nn import functional as F
import numbers
from einops import rearrange

def to_3d(x):
    """Flatten a (b, c, h, w) feature map into a (b, h*w, c) token sequence."""
    tokens = rearrange(x, 'b c h w -> b (h w) c')
    return tokens


def to_4d(x, h, w):
    """Inverse of to_3d: reshape (b, h*w, c) tokens back into a (b, c, h, w) map."""
    feature_map = rearrange(x, 'b (h w) c -> b c h w', w=w, h=h)
    return feature_map


class BasicConv(nn.Module):
    """Conv2d (or ConvTranspose2d) followed by an optional norm layer or ReLU.

    ``norm`` takes precedence over ``relu``: with ``norm`` a ``norm_method`` layer is
    appended and no activation is added, and the convolution bias is disabled.
    Regular convs use padding ``kernel_size // 2``; transposed convs use
    ``kernel_size // 2 - 1``. ``channel_shuffle_g`` is stored but not used here.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=False, norm=False, relu=True, transpose=False,
                 channel_shuffle_g=0, norm_method=nn.BatchNorm2d, groups=1):
        super(BasicConv, self).__init__()
        self.channel_shuffle_g = channel_shuffle_g
        self.norm = norm
        # A following norm layer makes the conv bias redundant.
        use_bias = False if (bias and norm) else bias

        if transpose:
            conv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=kernel_size // 2 - 1,
                                      stride=stride, bias=use_bias, groups=groups)
        else:
            conv = nn.Conv2d(in_channel, out_channel, kernel_size, padding=kernel_size // 2,
                             stride=stride, bias=use_bias, groups=groups)

        tail = []
        if norm:
            tail.append(norm_method(out_channel))
        elif relu:
            tail.append(nn.ReLU(inplace=True))

        self.main = nn.Sequential(conv, *tail)

    def forward(self, x):
        return self.main(x)


class BiasFree_LayerNorm(nn.Module):
    """Layer norm over the last dimension with a learnable scale and no bias or mean subtraction.

    Computes x / sqrt(var(x) + 1e-5) * weight using the biased variance.
    ``normalized_shape`` must describe exactly one dimension.
    """

    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            shape = torch.Size((normalized_shape,))
        else:
            shape = torch.Size(normalized_shape)

        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.normalized_shape = shape

    def forward(self, x):
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        scale = torch.sqrt(variance + 1e-5)
        return x / scale * self.weight


class WithBias_LayerNorm(nn.Module):
    """Standard layer norm over the last dimension with learnable scale and bias.

    Computes (x - mean) / sqrt(var + 1e-5) * weight + bias using the biased variance.
    ``normalized_shape`` must describe exactly one dimension.
    """

    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            shape = torch.Size((normalized_shape,))
        else:
            shape = torch.Size(normalized_shape)

        assert len(shape) == 1

        self.weight = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape))
        self.normalized_shape = shape

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normed = (x - mean) / torch.sqrt(variance + 1e-5)
        return normed * self.weight + self.bias


class LayerNorm(nn.Module):
    """Channel-wise layer norm for (b, c, h, w) feature maps.

    The map is flattened to (b, h*w, c) tokens, normalised over c, and reshaped back.
    'BiasFree' selects BiasFree_LayerNorm; any other value selects WithBias_LayerNorm.
    """

    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        height, width = x.shape[-2:]
        tokens = to_3d(x)
        return to_4d(self.body(tokens), height, width)


class FeedForward(nn.Module):
    """Gated depthwise-convolution feed-forward layer.

    A 1x1 conv expands to 2 * hidden channels (hidden = int(dim * ffn_expansion_factor)),
    a depthwise 3x3 BasicConv mixes spatially, the result is split into halves (a, b)
    combined as gelu(a) * b, and a 1x1 conv projects back to ``dim`` channels.
    """

    def __init__(self, dim, ffn_expansion_factor, bias, BasicConv=BasicConv):
        super(FeedForward, self).__init__()
        hidden = int(dim * ffn_expansion_factor)
        doubled = hidden * 2

        self.project_in = nn.Conv2d(dim, doubled, kernel_size=1, bias=bias)
        self.dwconv = BasicConv(doubled, doubled, kernel_size=3, stride=1, bias=bias,
                                relu=False, groups=doubled)
        self.project_out = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        gate, value = self.dwconv(self.project_in(x)).chunk(2, dim=1)
        return self.project_out(F.gelu(gate) * value)


class Attention(nn.Module):
    """Multi-head self-attention computed across channels.

    q, k and v come from a 1x1 conv followed by a depthwise 3x3 BasicConv. Within
    each head, q and k are L2-normalised along the flattened spatial axis; their
    (c x c) dot products are scaled by a learnable per-head temperature, softmaxed,
    and applied to v. A final 1x1 conv projects the result.
    """

    def __init__(self, dim, num_heads, bias, BasicConv=BasicConv):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        triple = dim * 3
        self.qkv = nn.Conv2d(dim, triple, kernel_size=1, bias=bias)
        self.qkv_dwconv = BasicConv(triple, triple, kernel_size=3, stride=1, bias=bias, relu=False, groups=triple)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        _, _, h, w = x.shape
        heads = self.num_heads

        q, k, v = (rearrange(t, 'b (head c) h w -> b head c (h w)', head=heads)
                   for t in self.qkv_dwconv(self.qkv(x)).chunk(3, dim=1))

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        weights = (q @ k.transpose(-2, -1)) * self.temperature
        weights = weights.softmax(dim=-1)

        out = rearrange(weights @ v, 'b head c (h w) -> b (head c) h w', head=heads, h=h, w=w)
        return self.project_out(out)


class TransformerBlock(nn.Module):
    """Pre-norm residual block: x + Attention(LN(x)), followed by x + FeedForward(LN(x))."""

    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, BasicConv=BasicConv):
        super(TransformerBlock, self).__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias, BasicConv=BasicConv)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias, BasicConv=BasicConv)

    def forward(self, x):
        attended = x + self.attn(self.norm1(x))
        return attended + self.ffn(self.norm2(attended))


class OverlapPatchEmbed(nn.Module):
    """Embed an image into ``embed_dim`` channels with one 3x3, stride-1, padding-1 conv.

    Spatial size is preserved.
    """

    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        return self.proj(x)

class Downsample(nn.Module):
    """Halve the spatial size: 3x3 conv to n_feat // 2 channels, then PixelUnshuffle(2).

    For even n_feat the output has 2 * n_feat channels.
    """

    def __init__(self, n_feat):
        super(Downsample, self).__init__()
        reduce = nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(reduce, nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)


class Upsample(nn.Module):
    """Double the spatial size: 3x3 conv to 2 * n_feat channels, then PixelShuffle(2).

    The output has n_feat // 2 channels.
    """

    def __init__(self, n_feat):
        super(Upsample, self).__init__()
        expand = nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.body = nn.Sequential(expand, nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)


class Fusion(nn.Module):
    """Fuse two same-shaped feature maps with a shared sigmoid mask and learned gates.

    mask = sigmoid(query_conv(x) * key_conv(y)) is applied to both inputs. For each
    input, a 3x3 conv over [input, masked input] predicts two gate maps that blend
    the two. The two blended maps are summed.
    """

    def __init__(self, in_dim=32):
        super(Fusion, self).__init__()
        self.chanel_in = in_dim  # attribute name (sic) kept as-is

        self.query_conv = nn.Conv2d(in_dim, in_dim, 3, 1, 1, bias=True)
        self.key_conv = nn.Conv2d(in_dim, in_dim, 3, 1, 1, bias=True)

        self.gamma1 = nn.Conv2d(in_dim * 2, 2, 3, 1, 1, bias=True)
        self.gamma2 = nn.Conv2d(in_dim * 2, 2, 3, 1, 1, bias=True)
        self.sig = nn.Sigmoid()

    @staticmethod
    def _gate(feat, masked, gamma):
        # Gate channel 0 weights the raw input, channel 1 the masked input.
        gates = gamma(torch.cat((feat, masked), dim=1))
        return feat * gates[:, [0], :, :] + masked * gates[:, [1], :, :]

    def forward(self, x, y):
        mask = self.sig(self.query_conv(x) * self.key_conv(y))
        x_out = self._gate(x, x * mask, self.gamma1)
        y_out = self._gate(y, y * mask, self.gamma2)
        return x_out + y_out

In [7]:
class MultiscaleNet(nn.Module):
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=48,
                 num_blocks=[2, 3, 3],
                 heads=[1, 2, 4],
                 ffn_expansion_factor=2.66,
                 bias=False,
                 LayerNorm_type='WithBias',
                 ):
        """Build the three-scale (small / mid / max) network.

        Args:
            inp_channels: channels of the input image.
            out_channels: channels of the output convolutions.
            dim: base feature width; level 2 uses 2 * dim and the latent level 4 * dim.
            num_blocks: TransformerBlock counts for (level 1, level 2, latent).
            heads: attention heads for (level 1, level 2, latent).
            ffn_expansion_factor, bias, LayerNorm_type: forwarded to every
                TransformerBlock; ``bias`` is also used by the plain 1x1 / 3x3 convs.

        Submodule attribute names form the state-dict keys, so they are kept
        unchanged, and modules are created in the same order. The INR heads are
        created like every other submodule (no explicit device), so the whole model
        moves together with ``.to()`` / ``.cuda()`` instead of the INR heads being
        pinned to CUDA while the rest stays on CPU.
        """
        super(MultiscaleNet, self).__init__()

        c1, c2, c4 = int(dim * 1 ** 1), int(dim * 2 ** 1), int(dim * 2 ** 2)

        def stage(level):
            # TransformerBlock stack at level 0 (c1), 1 (c2) or 2 (c4, latent).
            return nn.Sequential(*[
                TransformerBlock(dim=int(dim * 2 ** level), num_heads=heads[level],
                                 ffn_expansion_factor=ffn_expansion_factor, bias=bias,
                                 LayerNorm_type=LayerNorm_type)
                for _ in range(num_blocks[level])])

        def conv1x1(cin, cout):
            # Channel reduction after a skip-connection concatenation.
            return nn.Conv2d(cin, cout, kernel_size=1, bias=bias)

        def conv3x3(cin, cout):
            # Size-preserving 3x3 conv for output / context heads.
            return nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=bias)

        # ---- small scale ----
        self.patch_embed_small = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_level1_small = stage(0)
        self.down1_2_small = Downsample(dim)
        self.encoder_level2_small = stage(1)
        self.down2_3_small = Downsample(c2)
        self.latent_small = stage(2)
        self.up3_2_small = Upsample(c4)
        self.reduce_chan_level2_small = conv1x1(c4, c2)
        self.decoder_level2_small = stage(1)
        self.up2_1_small = Upsample(c2)
        self.reduce_chan_level1_small = conv1x1(c2, c1)
        self.decoder_level1_small = stage(0)
        self.output_small = conv3x3(c1, out_channels)

        self.INR = INR(dim)

        # ---- mid scale ----
        self.patch_embed_mid = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_level1_mid1 = stage(0)
        self.encoder_level1_mid2 = stage(0)
        self.down1_2_mid = Downsample(dim)
        self.down1_2_mid2 = Downsample(dim)
        self.encoder_level2_mid1 = stage(1)
        self.encoder_level2_mid2 = stage(1)
        self.down2_3_mid = Downsample(c2)
        self.down2_3_mid2 = Downsample(c2)
        self.latent_mid1 = stage(2)
        self.latent_mid2 = stage(2)
        self.up3_2_mid = Upsample(c4)
        self.up3_2_mid2 = Upsample(c4)
        self.reduce_chan_level2_mid1 = conv1x1(c4, c2)
        self.reduce_chan_level2_mid2 = conv1x1(c4, c2)
        self.decoder_level2_mid1 = stage(1)
        self.decoder_level2_mid2 = stage(1)
        self.up2_1_mid = Upsample(c2)
        self.up2_1_mid2 = Upsample(c2)
        self.reduce_chan_level1_mid1 = conv1x1(c2, c1)
        self.reduce_chan_level1_mid2 = conv1x1(c2, c1)
        self.decoder_level1_mid1 = stage(0)
        self.decoder_level1_mid2 = stage(0)
        self.output_mid = conv3x3(c1, out_channels)
        self.output_mid_context = conv3x3(c1, dim)

        self.INR2 = INR(dim)

        # ---- max (full) scale ----
        self.patch_embed_max = OverlapPatchEmbed(inp_channels, dim)
        self.encoder_level1_max1 = stage(0)
        self.encoder_level1_max2 = stage(0)
        self.encoder_level1_max3 = stage(0)
        self.down1_2_max = Downsample(dim)
        self.down1_2_max2 = Downsample(dim)
        self.down1_2_max3 = Downsample(dim)
        self.encoder_level2_max1 = stage(1)
        self.encoder_level2_max2 = stage(1)
        self.encoder_level2_max3 = stage(1)
        self.down2_3_max = Downsample(c2)
        self.down2_3_max2 = Downsample(c2)
        self.down2_3_max3 = Downsample(c2)
        self.latent_max1 = stage(2)
        self.latent_max2 = stage(2)
        self.latent_max3 = stage(2)
        self.up3_2_max = Upsample(c4)
        self.up3_2_max2 = Upsample(c4)
        self.up3_2_max3 = Upsample(c4)
        self.reduce_chan_level2_max1 = conv1x1(c4, c2)
        self.reduce_chan_level2_max2 = conv1x1(c4, c2)
        self.reduce_chan_level2_max3 = conv1x1(c4, c2)
        self.decoder_level2_max1 = stage(1)
        self.decoder_level2_max2 = stage(1)
        self.decoder_level2_max3 = stage(1)
        self.up2_1_max = Upsample(c2)
        self.up2_1_max2 = Upsample(c2)
        self.up2_1_max3 = Upsample(c2)
        self.reduce_chan_level1_max1 = conv1x1(c2, c1)
        self.reduce_chan_level1_max2 = conv1x1(c2, c1)
        self.reduce_chan_level1_max3 = conv1x1(c2, c1)
        self.decoder_level1_max1 = stage(0)
        self.decoder_level1_max2 = stage(0)
        self.decoder_level1_max3 = stage(0)
        self.output_max = conv3x3(c1, out_channels)
        self.output_max_context1 = conv3x3(c1, dim)
        self.output_max_context2 = conv3x3(c1, dim)

        # ---- cross-scale fusion and latent upsamplers ----
        self.BF1 = Fusion(dim * 4)
        self.BF2 = Fusion(dim * 4)
        self.BF3 = Fusion(dim * 4)

        self.upsmall2mid1 = Upsample(c4)
        self.upsmall2mid2 = Upsample(c2)

        self.upmid2max1 = Upsample(c4)
        self.upmid2max2 = Upsample(c2)

    def forward(self, inp_img):
        """Run the three-scale (full / half / quarter resolution) restoration pass.

        The input is resized with ``F.interpolate`` to 0.5x ("mid") and 0.25x
        ("small"). At each coarser scale a residual (``self.INR`` / ``self.INR2``,
        computed from an upsampled latent) is added to that scale's image. The
        result is upsampled 2x and added to the next finer scale's input.

        The full-resolution ("max") branch runs three stacked encoder stages.
        The first two decode immediately; the third stage's decoder runs last.
        Latents from these stages are fused (``BF1``/``BF2``), downsampled 0.5x
        and added to the mid-branch latents. The two mid latents are fused
        (``BF3``), downsampled 0.5x and added to the small latent.

        Final decoding proceeds small -> mid -> max. Each finer decoder's output
        gets the 2x-upsampled decoder features of the coarser branch added in.

        Args:
            inp_img: image tensor, presumably (N, C, H, W) — TODO confirm. H and W
                must be compatible with the 0.25x downsample plus the encoders'
                internal downsampling so the additions/concats below line up
                (the test loop pads inputs to multiples of 256).

        Returns:
            list of 7 tensors, final/full-resolution output first:
            [max output, mid output, small output, mid_img_ (mid input + INR2
            residual), mid_img / 2, inp_img_small_ (small input + INR residual),
            inp_img_small]. ``outputs[0]`` is the full-resolution restoration.
        """
        outputs = list()

        # Build the image pyramid: full, half and quarter resolution.
        inp_img_max = inp_img
        inp_img_mid = F.interpolate(inp_img, scale_factor=0.5)
        inp_img_small = F.interpolate(inp_img, scale_factor=0.25)

        # ---- Small branch: encoder ----
        inp_enc_level1_small = self.patch_embed_small(inp_img_small)
        out_enc_level1_small = self.encoder_level1_small(inp_enc_level1_small)

        inp_enc_level2_small = self.down1_2_small(out_enc_level1_small)
        out_enc_level2_small = self.encoder_level2_small(inp_enc_level2_small)

        inp_enc_level4_small = self.down2_3_small(out_enc_level2_small)
        latent_small = self.latent_small(inp_enc_level4_small)
        # Upsample the small latent twice before predicting the image-space residual.
        latent_small_mid = self.upsmall2mid1(latent_small)
        latent_small_mid = self.upsmall2mid2(latent_small_mid)

        outputs.append(inp_img_small)
        INR = self.INR(latent_small_mid)
        inp_img_small_ = INR + inp_img_small
        outputs.append(inp_img_small_)

        inp_img_small_ = F.interpolate(inp_img_small_, scale_factor=2)

        # Mid-scale input = half-res image + upsampled corrected small image.
        mid_img = inp_img_mid + inp_img_small_

        # ---- Mid branch: first encoder ----
        inp_enc_level1_mid = self.patch_embed_mid(mid_img)
        out_enc_level1_mid = self.encoder_level1_mid1(inp_enc_level1_mid)

        inp_enc_level2_mid = self.down1_2_mid(out_enc_level1_mid)
        out_enc_level2_mid = self.encoder_level2_mid1(inp_enc_level2_mid)

        inp_enc_level4_mid = self.down2_3_mid(out_enc_level2_mid)
        latent_mid = self.latent_mid1(inp_enc_level4_mid)
        latent_mid_INR_max = self.upmid2max1(latent_mid)
        latent_mid_INR_max = self.upmid2max2(latent_mid_INR_max)

        # mid_img is a sum of two images; halving presumably brings it back to image range.
        outputs.append(mid_img / 2)
        INR2 = self.INR2(latent_mid_INR_max)
        mid_img_ = INR2 + mid_img
        outputs.append(mid_img_)

        mid_img_ = F.interpolate(mid_img_, scale_factor=2)

        # Full-scale input = full-res image + upsampled corrected mid image.
        max_img = inp_img_max + mid_img_

        # ---- Max branch, stage 1: encoder + decoder ----
        inp_enc_level1_max = self.patch_embed_max(max_img)
        out_enc_level1_max = self.encoder_level1_max1(inp_enc_level1_max)

        inp_enc_level2_max = self.down1_2_max(out_enc_level1_max)
        out_enc_level2_max = self.encoder_level2_max1(inp_enc_level2_max)

        inp_enc_level4_max = self.down2_3_max(out_enc_level2_max)
        latent_max = self.latent_max1(inp_enc_level4_max)
        BFF_max_1 = latent_max

        inp_dec_level2_max = self.up3_2_max(latent_max)
        inp_dec_level2_max = torch.cat([inp_dec_level2_max, out_enc_level2_max], 1)
        inp_dec_level2_max = self.reduce_chan_level2_max1(inp_dec_level2_max)
        out_dec_level2_max = self.decoder_level2_max1(inp_dec_level2_max)

        inp_dec_level1_max = self.up2_1_max(out_dec_level2_max)
        inp_dec_level1_max = torch.cat([inp_dec_level1_max, out_enc_level1_max], 1)
        inp_dec_level1_max = self.reduce_chan_level1_max1(inp_dec_level1_max)
        out_dec_level1_max = self.decoder_level1_max1(inp_dec_level1_max)

        # ---- Max branch, stage 2: re-encode stage-1 decoder output ----
        out_dec_level1_max = self.output_max_context1(out_dec_level1_max)
        out_enc_level1_max = self.encoder_level1_max2(out_dec_level1_max)

        inp_enc_level2_max = self.down1_2_max2(out_enc_level1_max)
        out_enc_level2_max = self.encoder_level2_max2(inp_enc_level2_max)

        inp_enc_level4_max = self.down2_3_max2(out_enc_level2_max)
        latent_max = self.latent_max2(inp_enc_level4_max)
        BFF_max_2 = latent_max

        inp_dec_level2_max = self.up3_2_max2(latent_max)
        inp_dec_level2_max = torch.cat([inp_dec_level2_max, out_enc_level2_max], 1)
        inp_dec_level2_max = self.reduce_chan_level2_max2(inp_dec_level2_max)
        out_dec_level2_max = self.decoder_level2_max2(inp_dec_level2_max)

        inp_dec_level1_max = self.up2_1_max2(out_dec_level2_max)
        inp_dec_level1_max = torch.cat([inp_dec_level1_max, out_enc_level1_max], 1)
        inp_dec_level1_max = self.reduce_chan_level1_max2(inp_dec_level1_max)
        out_dec_level1_max = self.decoder_level1_max2(inp_dec_level1_max)

        # ---- Max branch, stage 3: encoder only here; its decoder runs at the end ----
        out_dec_level1_max = self.output_max_context2(out_dec_level1_max)
        out_enc_level1_max = self.encoder_level1_max3(out_dec_level1_max)

        inp_enc_level2_max = self.down1_2_max3(out_enc_level1_max)
        out_enc_level2_max = self.encoder_level2_max3(inp_enc_level2_max)

        inp_enc_level4_max = self.down2_3_max3(out_enc_level2_max)
        latent_max = self.latent_max3(inp_enc_level4_max)
        BFF_max_3 = latent_max

        # Fuse consecutive max-stage latents, then halve them to the mid latent's resolution.
        BFF1 = self.BF1(BFF_max_1, BFF_max_2)
        BFF2 = self.BF2(BFF_max_2, BFF_max_3)

        BFF1 = F.interpolate(BFF1, scale_factor=0.5)
        BFF2 = F.interpolate(BFF2, scale_factor=0.5)

        # Computed now, consumed by the stage-3 max decoder further below.
        inp_dec_level2_max = self.up3_2_max3(latent_max)

        # ---- Mid branch: first decoder, with BFF1 injected into the latent ----
        BFF3_1 = latent_mid
        latent_mid = latent_mid + BFF1

        inp_dec_level2_mid = self.up3_2_mid(latent_mid)
        inp_dec_level2_mid = torch.cat([inp_dec_level2_mid, out_enc_level2_mid], 1)
        inp_dec_level2_mid = self.reduce_chan_level2_mid1(inp_dec_level2_mid)
        out_dec_level2_mid = self.decoder_level2_mid1(inp_dec_level2_mid)

        inp_dec_level1_mid = self.up2_1_mid(out_dec_level2_mid)
        inp_dec_level1_mid = torch.cat([inp_dec_level1_mid, out_enc_level1_mid], 1)
        inp_dec_level1_mid = self.reduce_chan_level1_mid1(inp_dec_level1_mid)
        out_dec_level1_mid = self.decoder_level1_mid1(inp_dec_level1_mid)

        # ---- Mid branch: second encoder ----
        out_dec_level1_mid = self.output_mid_context(out_dec_level1_mid)
        out_enc_level1_mid = self.encoder_level1_mid2(out_dec_level1_mid)

        inp_enc_level2_mid = self.down1_2_mid2(out_enc_level1_mid)
        out_enc_level2_mid = self.encoder_level2_mid2(inp_enc_level2_mid)

        inp_enc_level4_mid = self.down2_3_mid2(out_enc_level2_mid)
        latent_mid = self.latent_mid2(inp_enc_level4_mid)
        BFF3_2 = latent_mid
        # Fuse the two mid latents and halve to the small latent's resolution.
        BFF3 = self.BF3(BFF3_1, BFF3_2)
        BFF3 = F.interpolate(BFF3, scale_factor=0.5)

        latent_mid = latent_mid + BFF2

        # Computed now, consumed by the second mid decoder after the small branch finishes.
        inp_dec_level2_mid = self.up3_2_mid2(latent_mid)

        # ---- Small branch: decoder, with BFF3 injected into the latent ----
        latent_small = latent_small + BFF3

        inp_dec_level2_small = self.up3_2_small(latent_small)
        inp_dec_level2_small = torch.cat([inp_dec_level2_small, out_enc_level2_small], 1)
        inp_dec_level2_small = self.reduce_chan_level2_small(inp_dec_level2_small)
        out_dec_level2_small = self.decoder_level2_small(inp_dec_level2_small)

        inp_dec_level1_small = self.up2_1_small(out_dec_level2_small)
        inp_dec_level1_small = torch.cat([inp_dec_level1_small, out_enc_level1_small], 1)
        inp_dec_level1_small = self.reduce_chan_level1_small(inp_dec_level1_small)
        out_dec_level1_small = self.decoder_level1_small(inp_dec_level1_small)

        # Keep small decoder features to feed the mid decoder.
        small_2_mid = out_dec_level1_small

        out_dec_level1_small = self.output_small(out_dec_level1_small) + inp_img_small

        outputs.append(out_dec_level1_small)
        # NOTE(review): `small` is never used below.
        small = F.interpolate(out_dec_level1_small, scale_factor=2)

        # ---- Mid branch: second decoder ----
        inp_dec_level2_mid = torch.cat([inp_dec_level2_mid, out_enc_level2_mid], 1)
        inp_dec_level2_mid = self.reduce_chan_level2_mid2(inp_dec_level2_mid)
        out_dec_level2_mid = self.decoder_level2_mid2(inp_dec_level2_mid)

        inp_dec_level1_mid = self.up2_1_mid2(out_dec_level2_mid)
        inp_dec_level1_mid = torch.cat([inp_dec_level1_mid, out_enc_level1_mid], 1)
        inp_dec_level1_mid = self.reduce_chan_level1_mid2(inp_dec_level1_mid)
        out_dec_level1_mid = self.decoder_level1_mid2(inp_dec_level1_mid)

        small_2_mid = F.interpolate(small_2_mid, scale_factor=2)
        out_dec_level1_mid = out_dec_level1_mid + small_2_mid

        # Keep mid decoder features to feed the max decoder.
        mid_2_max = out_dec_level1_mid

        out_dec_level1_mid = self.output_mid(out_dec_level1_mid) + inp_img_mid

        outputs.append(out_dec_level1_mid)
        # NOTE(review): `mid` is never used below.
        mid = F.interpolate(out_dec_level1_mid, scale_factor=2)

        # ---- Max branch, stage 3: decoder ----
        inp_dec_level2_max = torch.cat([inp_dec_level2_max, out_enc_level2_max], 1)
        inp_dec_level2_max = self.reduce_chan_level2_max3(inp_dec_level2_max)
        out_dec_level2_max = self.decoder_level2_max3(inp_dec_level2_max)

        inp_dec_level1_max = self.up2_1_max3(out_dec_level2_max)
        inp_dec_level1_max = torch.cat([inp_dec_level1_max, out_enc_level1_max], 1)
        # NOTE(review): every other stage-3 layer uses the *_max3 module, but this reuses
        # reduce_chan_level1_max2 (shared with stage 2). Looks like a copy-paste slip;
        # do not change it without retraining, since the checkpoint was presumably
        # trained with this wiring — confirm against __init__ / the reference code.
        inp_dec_level1_max = self.reduce_chan_level1_max2(inp_dec_level1_max)
        mid_2_max = F.interpolate(mid_2_max, scale_factor=2)
        out_dec_level1_max = self.decoder_level1_max3(inp_dec_level1_max) + mid_2_max

        # Final full-resolution prediction is a residual on top of the input image.
        out_dec_level1_max = self.output_max(out_dec_level1_max) + inp_img_max

        outputs.append(out_dec_level1_max)

        # Reverse so the full-resolution result comes first.
        return outputs[::-1]

## Testing — restore every test set with the trained model and save the derained images

In [8]:
from torch.utils.data import DataLoader

# Test configuration: dataset folders under the Rain13K test split, the trained
# checkpoint, and the tile size used by window_partitionx / window_reversex.
datasets = ["Rain100H", "Rain100L", "Test100", "Test1200", "Test2800"]
weights = "/kaggle/input/checkpoints/Deraining/models/Multiscale/model_best.pth"
win_size = 256


def _pad_to_window(x, multiple):
    """Pad an (N, C, H, W) tensor on the bottom/right so H and W become multiples of `multiple`.

    Returns ``(padded, pad_h, pad_w)``. Reflect padding is used whenever
    possible. PyTorch's reflect mode requires each pad amount to be strictly
    smaller than the corresponding input dimension, so very small images fall
    back to replicate padding instead of raising a RuntimeError. For all inputs
    the reflect path already handled, the result is unchanged.
    """
    _, _, h, w = x.shape
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
    return F.pad(x, (0, pad_w, 0, pad_h), mode=mode), pad_h, pad_w


model_restoration_test = MultiscaleNet()
load_checkpoint(model_restoration_test, weights)
model_restoration_test.cuda()
model_restoration_test = nn.DataParallel(model_restoration_test)
model_restoration_test.eval()

with torch.no_grad():
    for dataset in datasets:
        print(f"\n===> Testing {dataset}")

        input_dir = f"/kaggle/input/rain13kdataset/test/test/{dataset}/input/"
        output_dir = f"/kaggle/working/results/{dataset}/"
        mkdir(output_dir)

        test_dataset = DataLoaderTest(input_dir, img_options={})
        test_loader = DataLoader(
            dataset=test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            pin_memory=True
        )

        for data_test in tqdm(test_loader):
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

            input_ = data_test[0].cuda()
            filenames = data_test[1]

            _, _, Hx, Wx = input_.shape
            input_pad, pad_h, pad_w = _pad_to_window(input_, win_size)
            input_re, batch_list = window_partitionx(input_pad, win_size)

            # The model returns its outputs finest-first; index 0 is the full-resolution result.
            restored = model_restoration_test(input_re)
            restored = window_reversex(
                restored[0], win_size, Hx + pad_h, Wx + pad_w, batch_list
            )

            # Crop the padding off, clamp to [0, 1] and convert to NHWC numpy for saving.
            restored = restored[:, :, :Hx, :Wx]
            restored = torch.clamp(restored, 0, 1)
            restored = restored.permute(0, 2, 3, 1).cpu().numpy()

            for b in range(len(restored)):
                restored_img = img_as_ubyte(restored[b])
                save_img(os.path.join(output_dir, filenames[b] + ".png"), restored_img)

print("=== Testing completed ===")



===> Testing Rain100H


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 100/100 [01:30<00:00,  1.10it/s]



===> Testing Rain100L


100%|██████████| 100/100 [01:29<00:00,  1.11it/s]



===> Testing Test100


100%|██████████| 98/98 [01:28<00:00,  1.11it/s]



===> Testing Test1200


100%|██████████| 1200/1200 [18:07<00:00,  1.10it/s]



===> Testing Test2800


100%|██████████| 2800/2800 [47:11<00:00,  1.01s/it]

=== Testing completed ===





## Evaluation — score the saved outputs against ground truth (PSNR / SSIM) and write a CSV summary

In [9]:
def find_gt_file(gt_dir, pred_name, extensions=("png", "jpg", "jpeg", "PNG", "JPG", "JPEG")):
    """Return the path of the ground-truth image whose base name matches `pred_name`.

    The extension of `pred_name` is ignored. Candidates ``<base>.<ext>`` are tried
    in `extensions` order and the first existing one is returned. Lowercase
    extensions come first, preserving the original lookup order. Uppercase
    variants are included because the notebook's ``is_image_file`` also accepts
    them, and filesystems are typically case-sensitive.

    Args:
        gt_dir: directory containing the ground-truth images.
        pred_name: file name of the prediction (e.g. ``"1.png"``).
        extensions: extensions to try, without the leading dot.

    Returns:
        The matching path, or ``None`` if no candidate exists.
    """
    base = os.path.splitext(pred_name)[0]
    for ext in extensions:
        candidate = os.path.join(gt_dir, f"{base}.{ext}")
        if os.path.exists(candidate):
            return candidate
    return None

def append_results(save_file, dataset_name, psnr, ssim):
    """Append a ``Dataset,PSNR,SSIM`` row to `save_file`, writing the CSV header first if the file is new.

    Metrics are formatted with four decimals. The `ssim` parameter shadows the
    module-level ``ssim`` function only inside this function; its name is kept
    for call compatibility.
    """
    is_new_file = not os.path.exists(save_file)
    with open(save_file, "a") as out:
        if is_new_file:
            out.write("Dataset,PSNR,SSIM\n")
        out.write(f"{dataset_name},{psnr:.4f},{ssim:.4f}\n")
        

In [10]:
from skimage.metrics import structural_similarity as ssim

datasets = ["Rain100H", "Rain100L", "Test100", "Test1200", "Test2800"]
save_file = "/kaggle/working/results.csv"

# append_results() only ever appends, so start from a fresh CSV; otherwise
# re-running this cell would add a duplicate row per dataset.
if os.path.exists(save_file):
    os.remove(save_file)

for dataset in datasets:
    print(f"\n===> Evaluating {dataset}")

    result_dir = f"/kaggle/working/results/{dataset}/"
    gt_dir = f"/kaggle/input/rain13kdataset/test/test/{dataset}/target/"

    # Score only image files (ignores e.g. `.ipynb_checkpoints` or stray files).
    result_files = sorted(f for f in os.listdir(result_dir) if is_image_file(f))
    psnr_list, ssim_list = [], []

    for name in tqdm(result_files):
        pred_path = os.path.join(result_dir, name)
        gt_path = find_gt_file(gt_dir, name)

        if gt_path is None:
            print(f"GT not found for {name}")
            continue

        # Metrics use PIL's "L" luma (ITU-R 601-2: 0.299R + 0.587G + 0.114B).
        # NOTE(review): many deraining papers report on the MATLAB rgb2ycbcr Y
        # channel instead; confirm the protocol before comparing to published numbers.
        with Image.open(pred_path) as pred_im, Image.open(gt_path) as gt_im:
            pred_np = np.array(pred_im.convert("L"), dtype=np.uint8)
            gt_np = np.array(gt_im.convert("L"), dtype=np.uint8)

        if pred_np.shape != gt_np.shape:
            print(f"Size mismatch for {name}: pred {pred_np.shape} vs GT {gt_np.shape}")
            continue

        psnr_list.append(numpyPSNR(gt_np, pred_np))
        ssim_list.append(ssim(gt_np, pred_np, data_range=255))

    # Avoid averaging an empty list (nan + RuntimeWarning) and writing it to the CSV.
    if not psnr_list:
        print(f"{dataset} | no prediction/GT pairs evaluated; skipping")
        continue

    avg_psnr = np.mean(psnr_list)
    avg_ssim = np.mean(ssim_list)

    print(f"{dataset} | PSNR: {avg_psnr:.4f} | SSIM: {avg_ssim:.4f}")
    append_results(save_file, dataset, avg_psnr, avg_ssim)

print("\nSaved results to:", save_file)


===> Evaluating Rain100H


100%|██████████| 100/100 [00:03<00:00, 31.34it/s]


Rain100H | PSNR: 26.0931 | SSIM: 0.7885

===> Evaluating Rain100L


100%|██████████| 100/100 [00:03<00:00, 33.11it/s]


Rain100L | PSNR: 30.5683 | SSIM: 0.9176

===> Evaluating Test100


100%|██████████| 98/98 [00:03<00:00, 28.83it/s]


Test100 | PSNR: 24.5943 | SSIM: 0.8226

===> Evaluating Test1200


100%|██████████| 1200/1200 [01:19<00:00, 15.18it/s]


Test1200 | PSNR: 30.9124 | SSIM: 0.8868

===> Evaluating Test2800


100%|██████████| 2800/2800 [01:52<00:00, 24.94it/s]

Test2800 | PSNR: 29.9599 | SSIM: 0.8998

Saved results to: /kaggle/working/results.csv



