In [15]:
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
from glob import glob
import numpy as np
from arch import HAT
from tqdm import tqdm

img_list = glob("HAT_official/datasets/data/test/lr/*.png")
img_list

['../HAT_official/datasets/data/test/lr/20007.png',
 '../HAT_official/datasets/data/test/lr/20010.png',
 '../HAT_official/datasets/data/test/lr/20004.png',
 '../HAT_official/datasets/data/test/lr/20014.png',
 '../HAT_official/datasets/data/test/lr/20002.png',
 '../HAT_official/datasets/data/test/lr/20012.png',
 '../HAT_official/datasets/data/test/lr/20001.png',
 '../HAT_official/datasets/data/test/lr/20013.png',
 '../HAT_official/datasets/data/test/lr/20008.png',
 '../HAT_official/datasets/data/test/lr/20005.png',
 '../HAT_official/datasets/data/test/lr/20009.png',
 '../HAT_official/datasets/data/test/lr/20011.png',
 '../HAT_official/datasets/data/test/lr/20015.png',
 '../HAT_official/datasets/data/test/lr/20000.png',
 '../HAT_official/datasets/data/test/lr/20017.png',
 '../HAT_official/datasets/data/test/lr/20006.png',
 '../HAT_official/datasets/data/test/lr/20016.png',
 '../HAT_official/datasets/data/test/lr/20003.png']

# 이미지 한장 통째로 Inference

In [76]:
input_imgs

['../HAT_official/datasets/data/test/lr/20007.png',
 '../HAT_official/datasets/data/test/lr/20010.png',
 '../HAT_official/datasets/data/test/lr/20004.png',
 '../HAT_official/datasets/data/test/lr/20014.png',
 '../HAT_official/datasets/data/test/lr/20002.png',
 '../HAT_official/datasets/data/test/lr/20012.png',
 '../HAT_official/datasets/data/test/lr/20001.png',
 '../HAT_official/datasets/data/test/lr/20013.png',
 '../HAT_official/datasets/data/test/lr/20008.png',
 '../HAT_official/datasets/data/test/lr/20005.png',
 '../HAT_official/datasets/data/test/lr/20009.png',
 '../HAT_official/datasets/data/test/lr/20011.png',
 '../HAT_official/datasets/data/test/lr/20015.png',
 '../HAT_official/datasets/data/test/lr/20000.png',
 '../HAT_official/datasets/data/test/lr/20017.png',
 '../HAT_official/datasets/data/test/lr/20006.png',
 '../HAT_official/datasets/data/test/lr/20016.png',
 '../HAT_official/datasets/data/test/lr/20003.png']

In [78]:


def load_img_to_tensor(img_path):
    img = cv2.imread(img_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() # BGR -> RGB
    return img

def test(model_path, input_img_paths, device='cuda'):
    
    metric = PSNR()
    model = HAT(upscale=4,
            in_chans=3,
            img_size=64,
            window_size=16,
            compress_ratio=3,
            squeeze_factor=30,
            conv_scale=0.01,
            overlap_ratio=0.5,
            img_range=1.,
            depths=(6,6,6,6,6,6,6,6,6,6,6,6),
            embed_dim=180,
            num_heads=(6,6,6,6,6,6,6,6,6,6,6,6),
            mlp_ratio=2,
            upsampler='pixelshuffle',
            resi_connection='1conv')
    
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict['params_ema'], strict=True)
    
    model.to(device)
    model.eval()
    scores = []
    save_path = os.path.join('HAT_official/datasets/data/test/inference', model_path.split('/')[-1][:-4])
    if not os.path.exists(save_path): os.makedirs(save_path)
    
    for img_path in tqdm(input_imgs, total=len(input_imgs), ncols=100):
        input_img = load_img_to_tensor(img_path)
        input_img = input_img.unsqueeze(0).to(device)
        
        # hr_img = img_path.split('/')[-1]
        # hr_img = cv2.imread(os.path.join('../HAT_official/datasets/data', hr_img), cv2.IMREAD_COLOR).astype(np.float32)
        # hr_img = torch.from_numpy(np.transpose(hr_img[:, :, [2, 1, 0]], (2, 0, 1))).float().to(device)
        
        with torch.no_grad():
            output = model(input_img)
            # scores.append(metric(hr_img, (output.squeeze(0)*255.0).round().type(torch.uint8)))
            
            # save
            output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            if output.ndim == 3:
                output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
            output = (output * 255.0).round().astype(np.uint8)
            cv2.imwrite(os.path.join(save_path, img_path.split('/')[-1]), output)
    # return scores
            
input_imgs = glob("HAT_official/datasets/data/test/lr/*.png")
model_paths = glob("HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/*.pth")
score_dict = {}
for model_path in model_paths[1:]:
    test(model_path, input_imgs, device='cuda')

    
score_dict

100%|███████████████████████████████████████████████████████████████| 18/18 [01:54<00:00,  6.38s/it]
100%|███████████████████████████████████████████████████████████████| 18/18 [01:55<00:00,  6.42s/it]
100%|███████████████████████████████████████████████████████████████| 18/18 [01:55<00:00,  6.42s/it]
100%|███████████████████████████████████████████████████████████████| 18/18 [01:55<00:00,  6.42s/it]
100%|███████████████████████████████████████████████████████████████| 18/18 [01:55<00:00,  6.42s/it]
100%|███████████████████████████████████████████████████████████████| 18/18 [01:55<00:00,  6.42s/it]
100%|███████████████████████████████████████████████████████████████| 18/18 [01:55<00:00,  6.41s/it]
100%|███████████████████████████████████████████████████████████████| 18/18 [01:55<00:00,  6.42s/it]
100%|███████████████████████████████████████████████████████████████| 18/18 [01:55<00:00,  6.42s/it]
100%|███████████████████████████████████████████████████████████████| 18/18 [01:55<00:00,  

{}

# sub image로 나눠서 inference 120x120

In [90]:


def load_img_to_tensor(img_path):
    img = cv2.imread(img_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() # BGR -> RGB
    return img

def test(model_path, input_img_paths, device='cuda'):
    
    metric = PSNR()
    model = HAT(upscale=4,
            in_chans=3,
            img_size=64,
            window_size=16,
            compress_ratio=3,
            squeeze_factor=30,
            conv_scale=0.01,
            overlap_ratio=0.5,
            img_range=1.,
            depths=(6,6,6,6,6,6,6,6,6,6,6,6),
            embed_dim=180,
            num_heads=(6,6,6,6,6,6,6,6,6,6,6,6),
            mlp_ratio=2,
            upsampler='pixelshuffle',
            resi_connection='1conv')
    
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict['params_ema'], strict=True)
    
    model.to(device)
    model.eval()
    scores = []
    save_path = os.path.join('HAT_official/datasets/data/test/inference/128', model_path.split('/')[-1][:-4])
    if not os.path.exists(save_path): os.makedirs(save_path)
    
    for img_path in tqdm(input_imgs, total=len(input_imgs), ncols=100):
        full_img = load_img_to_tensor(img_path)
        _, h, w = full_img.shape
        input_img = torch.zeros((h//128*w//128, 3, 128, 128), dtype=torch.float32)
        n = 0
        for i in range(h//128):
            for j in range(w//128):
                input_img[n] += full_img[:,128*i:128*(i+1), 128*j:128*(j+1)]
                n+=1
                
        input_img = input_img.to(device)
        
        with torch.no_grad():
            sub_output = model(input_img)
            # scores.append(metric(hr_img, (output.squeeze(0)*255.0).round().type(torch.uint8)))
            
            # save
            sub_output = sub_output.cpu()
            output = torch.zeros((3,2048,2048), dtype=torch.float32)
            for i in range(sub_output.shape[0]):
                h = i // 4
                w = i % 4
                output[:, 512*h:512*(h+1), 512*w:512*(w+1)] += sub_output[i]
            
            output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            if output.ndim == 3:
                output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
            output = (output * 255.0).round().astype(np.uint8)
            cv2.imwrite(os.path.join(save_path, img_path.split('/')[-1]), output)
    # return scores
            
input_imgs = glob("HAT_official/datasets/data/test/lr/*.png")
model_paths = glob("HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/*.pth")
for model_path in model_paths[1:]:
    print('start', model_path)
    test(model_path, input_imgs, device='cuda')

start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_60000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_40000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_75000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_50000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_45000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_10000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_5000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_65000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_70000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_30000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_85000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_25000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_20000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_55000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_35000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_80000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


start ../HAT_official/experiments/train_HAT-L_SRx4_finetune_from_ImageNet_pretrain/models/net_g_90000.pth


100%|███████████████████████████████████████████████████████████████| 18/18 [01:26<00:00,  4.81s/it]


In [53]:
inf_img = cv2.imread(os.path.join(output_path, inf_img_name), cv2.IMREAD_COLOR)
inf_img
inf_img_name

[ WARN:0@16478.447] global /io/opencv/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('../HAT_official/datasets/data/test/inference/../HAT_official/datasets/data/test/inference/test_net_g_60000.pth.png'): can't open/read file: check file path/integrity


'../HAT_official/datasets/data/test/inference/test_net_g_60000.pth.png'

In [54]:
import os
import sys
import math
import torch
import numpy as np
import cv2

class PSNR:
    """Peak Signal to Noise Ratio
    img1 and img2 have range [0, 255]"""

    def __init__(self):
        self.name = "PSNR"

    @staticmethod
    def __call__(img1, img2):
        mse = torch.mean((img1 - img2) ** 2)
        return 20 * torch.log10(255.0 / torch.sqrt(mse))
    
# test_img = cv2.imread(img_list[0], cv2.IMREAD_COLOR).astype(np.float32)
metric = PSNR()
test_img_name = img_list[0].split('/')[-1]
test_img = cv2.imread(os.path.join('HAT_official/datasets/data', test_img_name), cv2.IMREAD_COLOR).astype(np.float32)
test_img = torch.from_numpy(np.transpose(test_img[:, :, [2, 1, 0]], (2, 0, 1))).float()
for inf_img_name in glob(f"{output_path}/*.png"):
    inf_img = cv2.imread(os.path.join(inf_img_name), cv2.IMREAD_COLOR).astype(np.float32)
    inf_img = torch.from_numpy(np.transpose(inf_img[:, :, [2, 1, 0]], (2, 0, 1))).float() # BGR -> RGB

    print(inf_img_name.split('/')[-1], metric(test_img, inf_img))


test_net_g_60000.pth.png tensor(28.8036)
test_net_g_45000.pth.png tensor(29.1339)
test_net_g_55000.pth.png tensor(29.0801)
test_net_g_30000.pth.png tensor(29.2564)
test_net_g_40000.pth.png tensor(29.0355)
test_net_g_25000.pth.png tensor(29.1738)
test_net_g_85000.pth.png tensor(28.8511)
test_net_g_50000.pth.png tensor(29.0925)
test_net_g_70000.pth.png tensor(28.8634)
test_net_g_80000.pth.png tensor(29.0505)
test_net_g_15000.pth.png tensor(29.3265)
test_net_g_75000.pth.png tensor(28.9665)
test_net_g_20000.pth.png tensor(29.2145)
test_net_g_90000.pth.png tensor(28.8833)
test_net_g_5000.pth.png tensor(29.6976)
test_net_g_10000.pth.png tensor(29.5264)
test_net_g_35000.pth.png tensor(29.2765)
test_net_g_65000.pth.png tensor(28.7833)


In [21]:
path = img_list[0]
device = 'cuda'
img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() # BGR -> RGB
img = img.unsqueeze(0).to(device)
img.shape

torch.Size([1, 3, 512, 512])

In [22]:
model = model.to(device)
with torch.no_grad():
    output = model(img)

In [23]:
output.shape

torch.Size([1, 3, 2048, 2048])

In [26]:
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output.shape

(3, 2048, 2048)

In [27]:
if output.ndim == 3:
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output.shape

(2048, 2048, 3)

In [30]:
import os

In [31]:
output_path = 'HAT_official/datasets/data/test/inference'
output = (output * 255.0).round().astype(np.uint8)
cv2.imwrite(os.path.join(output_path, f'test.png'), output)

True

In [None]:
model.eval()

for idx, path in enumerate(img_list):
    # read image
    imgname = path.split('/')[-1]
    print('Testing', idx, imgname)
    # read image
    img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float() # BGR -> RGB
    img = img.unsqueeze(0).to(device)
    
    # inference
    with torch.no_grad():
        output = model(img)