In [120]:
import torch
import torch.cuda as cuda
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader

import numpy as np
from PIL import Image
from pathlib import Path
from tqdm import tqdm

from model import DeepRecursiveTransformer

In [121]:
class RainData(Dataset):
    """Flat directory of rainy test images for single-image inference.

    Each item is ``(image, file_name)``: ``image`` is the file decoded as RGB and
    converted by ``transforms.ToTensor()`` to a float tensor of shape (3, H, W)
    with values in [0, 1]; ``file_name`` is the file's base name, which the
    inference loop reuses as the output file name.

    Args:
        dataset_dir: directory containing the images (not searched recursively).
        extensions: file suffixes to collect, without the dot. Matching is done
            with ``Path.glob``, which is case-sensitive on POSIX systems.

    Raises:
        FileNotFoundError: if ``dataset_dir`` is not an existing directory.
    """

    def __init__(self, dataset_dir, extensions=('png', 'jpg')):
        super().__init__()
        self.img_transforms = self.build_transform()
        self.datasets = self._collect_paths(dataset_dir, extensions)

    @staticmethod
    def _collect_paths(dataset_dir, extensions=('png', 'jpg')):
        """Return the sorted paths of files in ``dataset_dir`` ending in one of ``extensions``."""
        root = Path(dataset_dir)
        if not root.is_dir():
            # glob() on a missing directory silently yields nothing, which would turn
            # the whole inference run into a no-op; fail loudly instead.
            raise FileNotFoundError(f'Image directory not found: {root}')
        paths = []
        for ext in extensions:
            paths.extend(root.glob(f'*.{ext}'))
        # glob order is filesystem-dependent; sort so item order is reproducible.
        return sorted(paths)

    def build_transform(self):
        """Build the per-image transform: PIL image -> float tensor.

        ``ToTensor`` turns an (H, W, C) image with values in [0, 255] into a
        (C, H, W) float tensor in [0.0, 1.0]; the batch dimension is added later
        by the DataLoader.
        """
        return transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, idx):
        path = self.datasets[idx]
        # The context manager closes the file handle that Image.open keeps open (it
        # decodes lazily); convert('RGB') forces the decode and normalises palette /
        # RGBA / grayscale files to 3 channels (the network presumably expects RGB
        # input - TODO confirm against the model definition).
        with Image.open(path) as raw:
            rgb = raw.convert('RGB')
        return self.img_transforms(rgb), path.name

In [122]:
# Model
dim = 96
patch_size = 1
local_window_dim = patch_size * 7
residual_depth = 3
recursive_depth = 6
ckp_path = './pretrained/best_model.pt'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

cuda


In [123]:
# Dataset: rainy inputs are read from in_path, derained outputs are written to out_path.
in_path = './datasets/Test100/input/'
out_path = './datasets/Test100/output/'

# Fail fast on a wrong input path: globbing a missing directory yields no files,
# so the inference loop would otherwise finish instantly with nothing written.
if not Path(in_path).is_dir():
    raise FileNotFoundError(f'Input directory not found: {in_path}')
# parents=True lets out_path point at a directory whose parents do not exist yet.
Path(out_path).mkdir(parents=True, exist_ok=True)

# The inference loop pads each image individually and indexes img_name[0] /
# output_data[0], so it expects exactly one image per batch.
test_data_loader = DataLoader(RainData(in_path), batch_size=1, shuffle=False)

In [124]:
def mirror_pad_to_square(x, window):
    """Mirror-pad a (B, C, H, W) tensor on the bottom and right into a square.

    The target side is ``(max(H, W) // window + 1) * window``: a multiple of
    ``window`` strictly larger than the longer side (the ``+ 1`` keeps the padding
    identical to what this notebook used before). Padding is built by appending a
    flipped copy of the tensor along each spatial axis, repeated until the axis is
    long enough. For the usual case (one copy suffices) this matches the former
    single-pass padding exactly. It also handles extreme aspect ratios, where the
    old two-pass scheme recomputed the target and could end up non-square
    (e.g. 10x30 -> 40x42).

    Args:
        x: image batch indexed as (B, C, H, W).
        window: the padded side is a multiple of this.

    Returns:
        Tensor of shape (B, C, side, side) whose top-left (H, W) region equals ``x``.
    """
    _, _, h, w = x.size()
    side = (max(h, w) // window + 1) * window
    for axis in (2, 3):
        while x.size(axis) < side:
            x = torch.cat([x, torch.flip(x, [axis])], axis)
    return x[:, :, :side, :side]


# Read the checkpoint once (onto the CPU, so a CUDA-saved file also loads on a
# CPU-only machine) instead of re-reading it from disk for every image.
state_dict = torch.load(ckp_path, map_location='cpu')['state_dict']
eval_net, eval_net_size = None, None

for net_input, img_name in tqdm(test_data_loader):
    ### pad the image so H == W and both are multiples of the local window
    _, _, h_original, w_original = net_input.size()
    net_input = mirror_pad_to_square(net_input, local_window_dim)
    _, _, new_h, new_w = net_input.size()
    assert new_h == new_w, "Input image should have square dimension"

    ### (re)build the network only when the padded size changes: the input
    ### resolution is a constructor argument of DeepRecursiveTransformer.
    ### Reuse assumes forward() in eval mode does not mutate the network - TODO confirm.
    if eval_net_size != (new_h, new_w):
        eval_net = None  # drop the previous network before allocating a new one
        eval_net = DeepRecursiveTransformer(dim, (new_h, new_w), patch_size, residual_depth, recursive_depth)
        eval_net.load_state_dict(state_dict)
        eval_net.to(device)
        eval_net.eval()
        eval_net_size = (new_h, new_w)

    # Move the input to the same device as the model (not unconditionally .cuda()).
    with torch.no_grad():
        net_output = eval_net(net_input.to(device))

    ### crop the padding away (assumes the output keeps the input's spatial size)
    net_output = net_output[:, :, :h_original, :w_original]
    output_data = net_output.cpu().numpy()                  # B C H W
    output_data = np.transpose(output_data, (0, 2, 3, 1))   # B H W C
    # Round to the nearest level instead of truncating, which biased every pixel down.
    output_data = np.clip(output_data * 255 + 0.5, 0, 255).astype(np.uint8)

    # Save under the input's file name (batch_size is 1, hence index 0)
    Image.fromarray(output_data[0]).save(Path(out_path, img_name[0]))

    if device.type == 'cuda':
        cuda.empty_cache()

100%|██████████████████████████████████████████████████████████████████████████████████| 98/98 [06:17<00:00,  3.85s/it]
