In [6]:
import argparse
import time
import math

import torch
from torch import nn

from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage



class Generator32(nn.Module):
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator32, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(32)
        self.block3 = ResidualBlock(32)
        self.block4 = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32)
        )
        block5 = [UpsampleBLock(32, 2) for _ in range(upsample_block_num)]
        block5.append(nn.Conv2d(32, 3, kernel_size=9, padding=4))
        self.block5 = nn.Sequential(*block5)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block1 + block4)

        return (torch.tanh(block5) + 1) / 2

class Generator16(nn.Module):
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator16, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(16)
        self.block3 = ResidualBlock(16)
        self.block4 = nn.Sequential(
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16)
        )
        block5 = [UpsampleBLock(16, 2) for _ in range(upsample_block_num)]
        block5.append(nn.Conv2d(16, 3, kernel_size=9, padding=4))
        self.block5 = nn.Sequential(*block5)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block1 + block4)

        return (torch.tanh(block5) + 1) / 2

class Generator8(nn.Module):
    def __init__(self, scale_factor):
        upsample_block_num = int(math.log(scale_factor, 2))

        super(Generator8, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(8)
        self.block3 = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8)
        )
        block4 = [UpsampleBLock(8, 2) for _ in range(upsample_block_num)]
        block4.append(nn.Conv2d(8, 3, kernel_size=9, padding=4))
        self.block4 = nn.Sequential(*block4)

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block1 + block3)

        return (torch.tanh(block4) + 1) / 2


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.prelu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        return x + residual

class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x

# parser = argparse.ArgumentParser(description='Test Single Image')
# parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
# parser.add_argument('--test_mode', default='GPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
# parser.add_argument('--image_name', type=str, help='test low resolution image name')
# parser.add_argument('--model_name', default='netG_epoch_4_100.pth', type=str, help='generator model epoch name')
# opt = parser.parse_args()

# UPSCALE_FACTOR = opt.upscale_factor
# TEST_MODE = True if opt.test_mode == 'GPU' else False
# IMAGE_NAME = opt.image_name
# MODEL_NAME = opt.model_name

UPSCALE_FACTOR = 4
TEST_MODE = True
IMAGE_NAME = '/content/drive/My Drive/data/test/LR512.png'     # upload file rather than fetching files from google drvie will significantly improve inference speed
MODEL_NAME = '/content/drive/My Drive/data/test/netG8.pth'

model = Generator8(UPSCALE_FACTOR).eval()
if TEST_MODE:
    model.cuda()
    model.load_state_dict(torch.load(MODEL_NAME))
else:
    model.load_state_dict(torch.load(MODEL_NAME, map_location=lambda storage, loc: storage))

image = Image.open(IMAGE_NAME)
image = Variable(ToTensor()(image), volatile=True).unsqueeze(0)
if TEST_MODE:
    image = image.cuda()

start = time.clock()
out = model(image)
elapsed = (time.clock() - start)
print('cost' + str(elapsed) + 's')
out_img = ToPILImage()(out[0].data.cpu())
out_img.save("/content/drive/My Drive/data/test/8_SR.png")



cost0.0015089999999986503s


In [None]:
import torch

print(torch.__version__)

1.6.0+cu101
