In [None]:
# Install the third-party packages this notebook relies on into the kernel's
# environment (%pip targets the running kernel, unlike a bare `!pip`).
# NOTE(review): versions are unpinned, so a re-run may pull newer releases —
# consider pinning once a known-good set is established.
# coremltools: not used in the cells visible here — presumably for a later export step; confirm.
%pip install coremltools
# timm: presumably a dependency of models/network_swinir.py — confirm.
%pip install timm
# opencv-python provides the `cv2` module used to read the test images.
%pip install opencv-python
# thop provides `profile` / `clever_format` used for MAC and parameter counting.
%pip install thop

In [None]:
# Classical SR: run the repository test script with the x2 SwinIR-M checkpoint
# (training patch size 48) on the Set5 bicubic X2 inputs, scoring against Set5/HR.
# The relative paths assume the working directory is the SwinIR repository root.
!python main_test_swinir.py --task classical_sr --scale 2 --training_patch_size 48 --model_path model_zoo/swinir/001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth --folder_lq testsets/Set5/LR_bicubic/X2 --folder_gt testsets/Set5/HR

In [6]:
# Lightweight SR: run the repository test script with the x2 SwinIR-S checkpoint
# on the Set5 bicubic X2 inputs, scoring against Set5/HR.
# The relative paths assume the working directory is the SwinIR repository root.
!python main_test_swinir.py --task lightweight_sr --scale 2 --model_path model_zoo/swinir/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth --folder_lq testsets/Set5/LR_bicubic/X2 --folder_gt testsets/Set5/HR

Fail to import BlobReader from libmilstoragepython. No module named 'coremltools.libmilstoragepython'
Failed to load _MLModelProxy: No module named 'coremltools.libcoremlpython'
Fail to import BlobWriter from libmilstoragepython. No module named 'coremltools.libmilstoragepython'
downloading model model_zoo/swinir/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Testing 0 baby                 - PSNR: 37.40 dB; SSIM: 0.9603; PSNRB: 0.00 dB;PSNR_Y: 38.94 dB; SSIM_Y: 0.9681; PSNRB_Y: 0.00 dB.
Testing 1 bird                 - PSNR: 40.59 dB; SSIM: 0.9841; PSNRB: 0.00 dB;PSNR_Y: 43.28 dB; SSIM_Y: 0.9906; PSNRB_Y: 0.00 dB.
Testing 2 butterfly            - PSNR: 34.13 dB; SSIM: 0.9718; PSNRB: 0.00 dB;PSNR_Y: 35.75 dB; SSIM_Y: 0.9795; PSNRB_Y: 0.00 dB.
Testing 3 head                 - PSNR: 32.32 dB; SSIM: 0.8345; PSNRB: 0.00 dB;PSNR_Y: 36.07 dB; SSIM_Y: 0.8922; PSNRB_Y: 0.00 dB.
Testing 4 woman                - PSNR: 35.28 dB;

In [40]:
# Per-task settings, keyed by task name. Each entry holds:
#   args - keyword arguments passed to the SwinIR constructor
#   path - checkpoint file to load; its scale suffix (e.g. "_x2") has to match
#          the `upscale` value in `args`
config = dict(
    classical_sr=dict(
        args=dict(
            upscale=2,  # flexible
            in_chans=3,
            img_size=48,
            window_size=8,
            img_range=1.,
            depths=[6] * 6,
            embed_dim=180,
            num_heads=[6] * 6,
            mlp_ratio=2,
            upsampler='pixelshuffle',
            resi_connection='1conv',
        ),
        path='model_zoo/swinir/001_classicalSR_DIV2K_s48w8_SwinIR-M_x2.pth',  # match with upscale factor
    ),
    lightweight_sr=dict(
        args=dict(
            upscale=2,  # flexible
            in_chans=3,
            img_size=64,
            window_size=8,
            img_range=1.,
            depths=[6] * 4,
            embed_dim=60,
            num_heads=[6] * 4,
            mlp_ratio=2,
            upsampler='pixelshuffledirect',
            resi_connection='1conv',
        ),
        path='model_zoo/swinir/002_lightweightSR_DIV2K_s64w8_SwinIR-S_x2.pth',  # match with upscale factor
    ),
)

In [30]:
import cv2
import os
import numpy as np
import torch
# Device that get_image_pair moves the low-quality input tensor to.
device = 'cpu'

def _read_image_unit_float(path):
    """Read `path` with OpenCV as a colour image and return it as float32 scaled by 1/255."""
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals a missing or undecodable file by returning None
        # instead of raising, which would otherwise surface as an opaque
        # AttributeError on `.astype`.
        raise FileNotFoundError(f'could not read image: {path}')
    return img.astype(np.float32) / 255.


def get_image_pair(scale=2, folder_lq="testsets/Set5/LR_bicubic", folder_gt="testsets/Set5/HR", imgname="baby", imgext=".png", window_size=8):
    """Load one low-quality / ground-truth image pair and pad the LQ input.

    Paths are built as ``{folder_gt}/{imgname}{imgext}`` for the ground truth and
    ``{folder_lq}/X{scale}/{imgname}x{scale}{imgext}`` for the low-quality image.

    Args:
        scale: Upscaling factor; selects the ``X{scale}`` sub-folder and file suffix.
        folder_lq: Root folder of the low-quality images.
        folder_gt: Folder of the ground-truth images.
        imgname: Image base name without extension.
        imgext: File extension, including the leading dot.
        window_size: Padding granularity for height and width. Defaults to 8.

    Returns:
        Tuple ``(imgname, img_lq, img_gt)``:
        ``img_lq`` is a float tensor of shape (1, C, H, W) in RGB order on the
        module-level ``device``, mirror-padded at the bottom and right;
        ``img_gt`` is the float32 array exactly as read by OpenCV (H, W, C,
        OpenCV channel order) scaled by 1/255.

    Raises:
        FileNotFoundError: If either image cannot be read.
        ValueError: If ``window_size`` is not positive.
    """
    if window_size <= 0:
        raise ValueError(f'window_size must be positive, got {window_size}')

    # 001 classical image sr/ 002 lightweight image sr (load lq-gt image pairs)
    gt_img_path = f'{folder_gt}/{imgname}{imgext}'
    lq_img_path = f'{folder_lq}/X{scale}/{imgname}x{scale}{imgext}'

    print(f'gt: {gt_img_path} lq: {lq_img_path}')
    img_gt = _read_image_unit_float(gt_img_path)
    img_lq = _read_image_unit_float(lq_img_path)

    # HWC-BGR to CHW-RGB (the channel swap is skipped for a single-channel array)
    img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0).to(device)  # CHW-RGB to NCHW-RGB

    _, _, h_old, w_old = img_lq.size()
    # Round each side up to the next multiple of window_size. A side that is
    # already a multiple still gains one full extra window.
    h_pad = (h_old // window_size + 1) * window_size - h_old
    w_pad = (w_old // window_size + 1) * window_size - w_old
    # Mirror padding: append the flipped image and cut to the target size. The
    # slice cannot reach past the doubled image, so padding is capped at the
    # original extent of that side.
    img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
    img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]

    return imgname, img_lq, img_gt

In [41]:
import torch
from models.network_swinir import SwinIR as net

# Which entry of `config` to build and load.
TASK_NAME = 'lightweight_sr'


model_info = config[TASK_NAME]
model = net(**model_info['args'])
# Key under which the checkpoint may wrap the weights; when it is absent the
# loaded object itself is used as the state dict.
param_key_g = 'params'

# map_location='cpu' lets a checkpoint whose tensors were saved from a GPU be
# read on a machine without one; load_state_dict then copies the values into
# the model's own parameters.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
pretrained_model = torch.load(model_info['path'], map_location='cpu')
model.load_state_dict(pretrained_model.get(param_key_g, pretrained_model), strict=True)



<All keys matched successfully>

In [42]:
# Report the model's own flops() figure, scaled by 1e9 and shown with two decimals.
print("{:.2f} GFLOPS".format(model.flops() / 1e9))

4.19 GFLOPS


In [43]:
from thop import profile, clever_format
# Count multiply-accumulates and parameters with thop on a single 3x48x48 input.
# torch.zeros gives an initialised, reproducible tensor; the legacy
# torch.Tensor(1, 3, 48, 48) constructor returns uninitialised memory.
# NOTE(review): 48 matches the classical_sr `img_size`, while the lightweight_sr
# entry uses 64 — confirm which resolution is intended before comparing this
# figure with model.flops().
dummy_input = torch.zeros(1, 3, 48, 48)
macs, params = clever_format(profile(model, inputs=(dummy_input, )))

print(f"MACs: {macs} Params: {params}")

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pixelshuffle.PixelShuffle'>.
MACs: 2.08G Params: 877.75K


Visualizing

In [None]:
import numpy as np
import imageio.v2 as imageio

# Read one low-resolution Set5 image as an array and show it as the cell output.
# The path is relative to the SwinIR repository root, like every other path in
# this notebook, instead of an absolute location on one particular Drive mount.
im_arr = imageio.imread("testsets/Set5/LR_bicubic/X2/birdx2.png")
im_arr

In [None]:
# Show the kernel's current working directory; the relative paths used in this
# notebook (model_zoo/..., testsets/...) are resolved against it.
%pwd

In [None]:
# Runs `git init` in the current working directory.
# NOTE(review): looks like leftover scratch work unrelated to the analysis —
# confirm whether this cell should stay in the notebook.
!git init