In [1]:
import argparse
import time
import os
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
import torch.nn.functional as F
from model import Generator
import numpy as np

In [2]:
UPSCALE_FACTOR = 4
TEST_IMAGE = 'frame63063.jpg'
TEST_PATH = './data/test/'
OUTPUT_PATH = './test_output/'
if not os.path.exists(OUTPUT_PATH):
    os.makedirs(OUTPUT_PATH)
MODEL = './epochs/netG_epoch_4_91.pth'

In [3]:
model = Generator(UPSCALE_FACTOR).eval()

model.load_state_dict(torch.load(MODEL, map_location="cuda:0"))
model = model.cuda()

In [8]:
image = Image.open(TEST_PATH + TEST_IMAGE)
with torch.no_grad():
    image = ToTensor()(image).unsqueeze(0)
    image = F.interpolate(image, size=(450, 800), mode='bicubic', align_corners=False)
    print(image.size())
image = image.cuda()

torch.Size([1, 3, 450, 800])


In [5]:
start = time.process_time()
out = model(image)
elapsed = time.process_time() - start
print('process time : ' + str(elapsed) + 's')
out = F.interpolate(out, size=(1440, 2560), mode='bicubic', align_corners=False)
out_img = ToPILImage()(out[0].data.cpu())
out_img.save(OUTPUT_PATH + TEST_IMAGE)

process time : 2.265625s


In [18]:
import matplotlib.pyplot as plt
import os
import cv2

os.environ['KMP_DUPLICATE_LIB_OK']='True'

image = Image.open(TEST_PATH + TEST_IMAGE)
image = np.array(image)
print(image.shape)
input_h = 720
input_w = 1280
scaled_h = (input_h//88+1)*88
scaled_w = (input_w//88+1)*88
pad_h = scaled_h-input_h
pad_w = scaled_w-input_w
image = np.pad(image, [(0, pad_h),(0, pad_w),(0,0)], 'constant')
print(scaled_h, scaled_w)
out_img = np.zeros((3,scaled_h*UPSCALE_FACTOR, scaled_w*UPSCALE_FACTOR))
print(out_img.shape)
for i in range(0, (720//88)+1):
    for j in range(0, (1280//88)+1):
        tmp_img = ToTensor()(image).unsqueeze(0)[:,:,i*88:(i+1)*88,j*88:(j+1)*88]
        tmp_img = tmp_img.cuda()
        tmp_out = model(tmp_img).detach().cpu().numpy().squeeze()
        out_img[:,i*88*UPSCALE_FACTOR:(i+1)*88*UPSCALE_FACTOR,j*88*UPSCALE_FACTOR:(j+1)*88*UPSCALE_FACTOR] = tmp_out
out_img = (out_img.transpose(1,2,0)*255).astype(np.uint8)[:input_h*4, :input_w*4, :]
out_img = cv2.resize(out_img, dsize=(2560, 1440), interpolation=cv2.INTER_CUBIC)
print(out_img.shape)
pil_img = Image.fromarray(out_img)
pil_img.show()


(720, 1280, 3)
792 1320
(3, 3168, 5280)
(1440, 2560, 3)
