# Training

In [None]:
# Training requires two 3090 GPUs and takes roughly 4-5 days; uncomment one of the commands below to run it.

# ! python3 ../Uformer-RSBlur/train/train_RealisticGoProABMEDeblur.py --arch Uformer_B --batch_size 8 --gpu '0' \
# --train_ps 256 --train_dir ../Uformer-RSBlur/datasets/GOPRO_INTER_ABME \
# --val_ps 256 --val_dir ../Uformer-RSBlur/datasets/RealBlurJ_test --env _RealisticGoProABMEDeblur \
# --mode deblur --nepoch 1500 --checkpoint 100 --dataset GoPro --warmup --train_workers 12

# ! python3 ../Uformer-RSBlur/train/train_NaiveGoProABMEDeblur.py --arch Uformer_B --batch_size 8 --gpu '0' \
# --train_ps 256 --train_dir ../Uformer-RSBlur/datasets/GOPRO_INTER_ABME \
# --val_ps 256 --val_dir ../Uformer-RSBlur/datasets/RealBlurJ_test --env _NaiveGoProABMEDeblur \
# --mode deblur --nepoch 1500 --checkpoint 100 --dataset GoPro --warmup --train_workers 12

# Test trained models

| Models | Train set | Realistic Pipeline | PSNR / SSIM    |
| :---:|:---:  |  :---:|:---:|
| Uformer-B |   GoPro |     | 00.00 / 0.0000 |
| Uformer-B |  GoPro  |✓     | 30.98 / 0.9067 |
| Uformer-B | GoPro_U |     | 28.93 / 0.8673 |
| Uformer-B | GoPro_U |  ✓   | 31.19 / 0.9143 |


In [None]:
import os
import sys
import math
import matplotlib.pyplot as plt

sys.path.insert(0, "../Uformer-RSBlur")
from dataset.dataset_motiondeblur import *
import utils

def expand2square_reflect(timg, factor=128):
    """Reflect-pad a 4-D image batch to a square whose side is a multiple of `factor`.

    Args:
        timg: tensor with four dimensions; the last two are treated as
            height and width.
        factor: the padded side length X is the smallest multiple of this
            value that is >= max(height, width).

    Returns:
        (img, mask): `img` is `timg` centred in an X-by-X canvas whose border
        is filled by reflection. `mask` is a (1, 1, X, X) tensor of the same
        type as `timg`, equal to 1 over the original image area and 0 over
        the padded border.

    A single reflect pad cannot add as many pixels on a side as the dimension
    being padded currently has, so the padding is applied in stages when the
    input is strongly non-square. Inputs whose padding already fits within
    that limit still go through exactly one `F.pad` call.
    """
    _, _, h, w = timg.size()
    X = int(math.ceil(max(h, w) / float(factor)) * factor)

    pad_w = X - w
    pad_h = X - h
    top, left = pad_h // 2, pad_w // 2

    # Padding still to be added per side, in F.pad order: left, right, top, bottom.
    remaining = [left, pad_w - left, top, pad_h - top]

    img = timg
    while True:
        cur_h, cur_w = img.shape[-2:]
        # Reflection can add at most (size - 1) pixels per side in one call.
        limits = (cur_w - 1, cur_w - 1, cur_h - 1, cur_h - 1)
        step = [min(need, cap) for need, cap in zip(remaining, limits)]
        if not any(step):
            # Either nothing is left to add, or a size-1 dimension blocks any
            # progress; hand the raw request to F.pad so that in the latter
            # case it raises its own error instead of looping forever.
            step = list(remaining)
        img = F.pad(img, tuple(step), 'reflect')
        remaining = [need - done for need, done in zip(remaining, step)]
        if not any(remaining):
            break

    # Mark the location of the original pixels inside the padded canvas.
    mask = torch.zeros(1, 1, X, X).type_as(timg)
    mask[:, :, top:top + h, left:left + w].fill_(1)

    return img, mask

class Args():
    """Plain attribute container used in this notebook as the `opt` object.

    Fields can be passed as keyword arguments (`Args(arch='Uformer_B')`) or
    assigned after construction (`opt = Args(); opt.arch = 'Uformer_B'`).
    """

    def __init__(self, **options):
        # Every keyword becomes an instance attribute; with no keywords the
        # instance starts out empty.
        self.__dict__.update(options)

    def __repr__(self):
        # Show the stored fields so the object is readable when echoed in a cell.
        fields = ', '.join(
            '{}={!r}'.format(name, value) for name, value in vars(self).items()
        )
        return '{}({})'.format(type(self).__name__, fields)

# Visualization helper
def viz_two_images(img1, img2, title1, title2):
    """Show two images side by side in a single 16x16-inch figure.

    Each image is drawn in its own subplot under the matching title. An image
    whose `shape` has exactly two entries is rendered with the 'gray'
    colormap; any other image is handed to `imshow` as is.
    """
    fig = plt.figure(figsize=(16, 16))
    n_rows, n_cols = 1, 2

    panels = ((img1, title1), (img2, title2))
    for position, (picture, caption) in enumerate(panels, start=1):
        axis = fig.add_subplot(n_rows, n_cols, position)
        axis.set_title(caption)

        if len(picture.shape) == 2:
            axis.imshow(picture, 'gray')
        else:
            axis.imshow(picture)

    fig.tight_layout()
    plt.show()
    
# Location of the test inputs: <datasets>/RealBlurJ_test/test/input.
rgb_dir_test = os.path.join(
    '../Uformer-RSBlur/datasets',
    'RealBlurJ_test',
    'test',
    'input',
)
# Build the test dataset from that folder; an empty options dict is passed.
test_dataset = get_test_data(rgb_dir_test, img_options={})

In [None]:
def _load_on_gpu(options, checkpoint_path):
    """Build the network described by `options`, load `checkpoint_path` into it, and move it to the GPU."""
    net = utils.get_arch(options)
    utils.load_checkpoint(net, checkpoint_path)
    net.cuda()
    return net


# Options handed to utils.get_arch when constructing both networks.
opt = Args()
opt.arch = 'Uformer_B'
opt.train_ps = 256
opt.dd_in = 3

# Checkpoint trained with our pipeline (31.19 PSNR).
weight_path = '../Uformer-RSBlur/logs/Uformer_B_RealisticGoProUDeblur.pth'
model_restoration = _load_on_gpu(opt, weight_path)
model_restoration.eval()

# Checkpoint from naive training (28.93 PSNR).
weight_path = '../Uformer-RSBlur/logs/Uformer_B_NaiveGoProUDeblur.pth'
model_restoration_naive = _load_on_gpu(opt, weight_path)
model_restoration_naive.eval()

In [None]:
def _restore_image(model, padded, mask, height, width):
    """Run `model` on the padded input and return the un-padded result as a (height, width, 3) array.

    `mask` selects the pixels belonging to the original image, so the
    masked values can be reshaped back to (1, 3, height, width). Values are
    clamped to [0, 1] before conversion to NumPy.
    """
    output = model(padded)
    output = torch.masked_select(output, mask.bool()).reshape(1, 3, height, width)
    # squeeze(0) drops only the batch axis; a bare squeeze() would also drop a
    # height or width of 1 and break the 3-axis transpose below.
    return torch.clamp(output, 0, 1).cpu().numpy().squeeze(0).transpose((1, 2, 0))


# Sample 18 of the test dataset is used for the comparison.
data_test = test_dataset[18]

with torch.no_grad():
    # Release cached GPU memory before running the two networks.
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()

    # First element of the sample: moved to the GPU and given a batch axis.
    input_    = data_test[0].cuda().unsqueeze(0)
    _, _, h, w = input_.shape
    # Second element of the sample (presumably the file name; not used in this cell).
    filenames = data_test[1]

    # Pad to a square whose side is a multiple of 128; `mask` marks the original pixels.
    input_, mask = expand2square_reflect(input_, factor=128)

    restored = _restore_image(model_restoration, input_, mask, h, w)
    restored_naive = _restore_image(model_restoration_naive, input_, mask, h, w)

viz_two_images(restored, restored_naive, 'restored', 'restored_naive')