In [1]:
import torch
# Sanity-check the runtime: report the PyTorch build string, then whether CUDA is usable.
for runtime_fact in (torch.__version__, torch.cuda.is_available()):
    print(runtime_fact)

1.7.0+cu101
True


In [7]:
import torch
from torch import nn
from torch.utils.mobile_optimizer import optimize_for_mobile
import math
import time
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage

class Generator2X(torch.nn.Module):
    """Lightweight 2x super-resolution generator.

    Architecture:
      * two feature stages: conv(3x3) -> BatchNorm -> ReLU, 8 channels each;
      * one upsampling stage: conv(8 -> 32) -> ReLU -> PixelShuffle(2);
      * a final conv(8 -> 3), mapped to [0, 1] with ``(tanh(x) + 1) / 2``.

    Input must be a 4-D tensor with 3 channels, ``(N, 3, H, W)``. The value range
    the trained weights expect (presumably RGB in [0, 1]) is not established here,
    so TODO confirm it. Output has shape ``(N, 3, 2H, 2W)`` with values in [0, 1].

    The submodule attribute names (``conv1_1``, ``bn1_1``, ``relu1_1``, ...) are
    referenced by ``torch.quantization.fuse_modules`` calls and by saved
    state_dicts, so they must stay stable.
    """

    def __init__(self):
        super(Generator2X, self).__init__()

        # Feature stage 1: 3 -> 8 channels (conv + BN + ReLU, fusable as a group).
        self.conv1_1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to(dtype=torch.float)
        self.bn1_1 = torch.nn.BatchNorm2d(8).to(dtype=torch.float)
        self.relu1_1 = torch.nn.ReLU()

        # Feature stage 2: 8 -> 8 channels.
        self.conv2_1 = torch.nn.Conv2d(8, 8, kernel_size=3, padding=1).to(dtype=torch.float)
        self.bn2_1 = torch.nn.BatchNorm2d(8).to(dtype=torch.float)
        self.relu2_1 = torch.nn.ReLU()

        # Upsampling stage: 8 -> 32 = 8 * 2**2 channels, then PixelShuffle(2)
        # trades those channels for a 2x larger spatial grid (back to 8 channels).
        self.conv3_1 = torch.nn.Conv2d(8, 32, kernel_size=3, padding=1).to(dtype=torch.float)
        self.relu3_1 = torch.nn.ReLU()
        self.pixel_shuffle3_1 = torch.nn.PixelShuffle(2)

        # Output projection back to 3 channels.
        self.conv4_1 = torch.nn.Conv2d(8, 3, kernel_size=3, padding=1).to(dtype=torch.float)

        # Quantization boundaries; these are identity ops until the model is
        # prepared/converted for static quantization.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Upscale ``x`` of shape (N, 3, H, W) to (N, 3, 2H, 2W) in [0, 1]."""
        # Tensor.contiguous() returns a new tensor rather than converting in place.
        # The previous bare call discarded the result, so channels_last was never used.
        x = x.contiguous(memory_format=torch.channels_last)
        x = self.quant(x)

        x = self.conv1_1(x)
        x = self.bn1_1(x)
        x = self.relu1_1(x)

        x = self.conv2_1(x)
        x = self.bn2_1(x)
        x = self.relu2_1(x)

        x = self.conv3_1(x)
        x = self.relu3_1(x)
        x = self.pixel_shuffle3_1(x)

        x = self.conv4_1(x)

        # Map the unbounded conv output into [0, 1].
        # NOTE(review): these tensor ops run before dequant. If static quantization
        # is ever applied, the +1 and /2 on a quantized tensor may need
        # FloatFunctional or need to move after dequant. Confirm before adopting.
        out = (torch.tanh(x) + 1) / 2

        out = self.dequant(out)

        return out


class Generator4X(torch.nn.Module):
    """Lightweight 4x super-resolution generator.

    Architecture:
      * two feature stages: conv(3x3) -> BatchNorm -> ReLU, 8 channels each;
      * two upsampling stages: conv(8 -> 32) -> ReLU -> PixelShuffle(2);
      * a final conv(8 -> 3), mapped to [0, 1] with ``(tanh(x) + 1) / 2``.

    Input must be a 4-D tensor with 3 channels, ``(N, 3, H, W)``. The value range
    the trained weights expect (presumably RGB in [0, 1]) is not established here,
    so TODO confirm it. Output has shape ``(N, 3, 4H, 4W)`` with values in [0, 1].

    The submodule attribute names (``conv1_1``, ``bn1_1``, ``relu1_1``, ...) are
    referenced by ``torch.quantization.fuse_modules`` calls and by saved
    state_dicts, so they must stay stable.
    """

    def __init__(self):
        super(Generator4X, self).__init__()

        # Feature stage 1: 3 -> 8 channels (conv + BN + ReLU, fusable as a group).
        self.conv1_1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to(dtype=torch.float)
        self.bn1_1 = torch.nn.BatchNorm2d(8).to(dtype=torch.float)
        self.relu1_1 = torch.nn.ReLU()

        # Feature stage 2: 8 -> 8 channels.
        self.conv2_1 = torch.nn.Conv2d(8, 8, kernel_size=3, padding=1).to(dtype=torch.float)
        self.bn2_1 = torch.nn.BatchNorm2d(8).to(dtype=torch.float)
        self.relu2_1 = torch.nn.ReLU()

        # Upsampling stage 1 (x2): 8 -> 32 = 8 * 2**2 channels, then PixelShuffle(2).
        self.conv3_1 = torch.nn.Conv2d(8, 32, kernel_size=3, padding=1).to(dtype=torch.float)
        self.relu3_1 = torch.nn.ReLU()
        self.pixel_shuffle3_1 = torch.nn.PixelShuffle(2)

        # Upsampling stage 2 (x2, total x4).
        self.conv4_1 = torch.nn.Conv2d(8, 32, kernel_size=3, padding=1).to(dtype=torch.float)
        self.relu4_1 = torch.nn.ReLU()
        self.pixel_shuffle4_1 = torch.nn.PixelShuffle(2)

        # Output projection back to 3 channels.
        self.conv5_1 = torch.nn.Conv2d(8, 3, kernel_size=3, padding=1).to(dtype=torch.float)

        # Quantization boundaries; these are identity ops until the model is
        # prepared/converted for static quantization.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Upscale ``x`` of shape (N, 3, H, W) to (N, 3, 4H, 4W) in [0, 1]."""
        # Tensor.contiguous() returns a new tensor rather than converting in place.
        # The previous bare call discarded the result, so channels_last was never used.
        x = x.contiguous(memory_format=torch.channels_last)
        x = self.quant(x)

        x = self.conv1_1(x)
        x = self.bn1_1(x)
        x = self.relu1_1(x)

        x = self.conv2_1(x)
        x = self.bn2_1(x)
        x = self.relu2_1(x)

        x = self.conv3_1(x)
        x = self.relu3_1(x)
        x = self.pixel_shuffle3_1(x)

        x = self.conv4_1(x)
        x = self.relu4_1(x)
        x = self.pixel_shuffle4_1(x)

        x = self.conv5_1(x)

        # Map the unbounded conv output into [0, 1].
        # NOTE(review): these tensor ops run before dequant. If static quantization
        # is ever applied, the +1 and /2 on a quantized tensor may need
        # FloatFunctional or need to move after dequant. Confirm before adopting.
        out = (torch.tanh(x) + 1) / 2

        out = self.dequant(out)

        return out

class Generator8X(torch.nn.Module):
    """Lightweight 8x super-resolution generator.

    Architecture:
      * two feature stages: conv(3x3) -> BatchNorm -> ReLU, 8 channels each;
      * three upsampling stages: conv(8 -> 32) -> ReLU -> PixelShuffle(2);
      * a final conv(8 -> 3), mapped to [0, 1] with ``(tanh(x) + 1) / 2``.

    Input must be a 4-D tensor with 3 channels, ``(N, 3, H, W)``. The value range
    the trained weights expect (presumably RGB in [0, 1]) is not established here,
    so TODO confirm it. Output has shape ``(N, 3, 8H, 8W)`` with values in [0, 1].

    The submodule attribute names (``conv1_1``, ``bn1_1``, ``relu1_1``, ...) are
    referenced by ``torch.quantization.fuse_modules`` calls and by saved
    state_dicts, so they must stay stable.
    """

    def __init__(self):
        super(Generator8X, self).__init__()

        # Feature stage 1: 3 -> 8 channels (conv + BN + ReLU, fusable as a group).
        self.conv1_1 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1).to(dtype=torch.float)
        self.bn1_1 = torch.nn.BatchNorm2d(8).to(dtype=torch.float)
        self.relu1_1 = torch.nn.ReLU()

        # Feature stage 2: 8 -> 8 channels.
        self.conv2_1 = torch.nn.Conv2d(8, 8, kernel_size=3, padding=1).to(dtype=torch.float)
        self.bn2_1 = torch.nn.BatchNorm2d(8).to(dtype=torch.float)
        self.relu2_1 = torch.nn.ReLU()

        # Upsampling stage 1 (x2): 8 -> 32 = 8 * 2**2 channels, then PixelShuffle(2).
        self.conv3_1 = torch.nn.Conv2d(8, 32, kernel_size=3, padding=1).to(dtype=torch.float)
        self.relu3_1 = torch.nn.ReLU()
        self.pixel_shuffle3_1 = torch.nn.PixelShuffle(2)

        # Upsampling stage 2 (x2, total x4).
        self.conv4_1 = torch.nn.Conv2d(8, 32, kernel_size=3, padding=1).to(dtype=torch.float)
        self.relu4_1 = torch.nn.ReLU()
        self.pixel_shuffle4_1 = torch.nn.PixelShuffle(2)

        # Upsampling stage 3 (x2, total x8).
        self.conv5_1 = torch.nn.Conv2d(8, 32, kernel_size=3, padding=1).to(dtype=torch.float)
        self.relu5_1 = torch.nn.ReLU()
        self.pixel_shuffle5_1 = torch.nn.PixelShuffle(2)

        # Output projection back to 3 channels.
        self.conv6_1 = torch.nn.Conv2d(8, 3, kernel_size=3, padding=1).to(dtype=torch.float)

        # Quantization boundaries; these are identity ops until the model is
        # prepared/converted for static quantization.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Upscale ``x`` of shape (N, 3, H, W) to (N, 3, 8H, 8W) in [0, 1]."""
        # Tensor.contiguous() returns a new tensor rather than converting in place.
        # The previous bare call discarded the result, so channels_last was never used.
        x = x.contiguous(memory_format=torch.channels_last)
        x = self.quant(x)

        x = self.conv1_1(x)
        x = self.bn1_1(x)
        x = self.relu1_1(x)

        x = self.conv2_1(x)
        x = self.bn2_1(x)
        x = self.relu2_1(x)

        x = self.conv3_1(x)
        x = self.relu3_1(x)
        x = self.pixel_shuffle3_1(x)

        x = self.conv4_1(x)
        x = self.relu4_1(x)
        x = self.pixel_shuffle4_1(x)

        x = self.conv5_1(x)
        x = self.relu5_1(x)
        x = self.pixel_shuffle5_1(x)

        x = self.conv6_1(x)

        # Map the unbounded conv output into [0, 1].
        # NOTE(review): these tensor ops run before dequant. If static quantization
        # is ever applied, the +1 and /2 on a quantized tensor may need
        # FloatFunctional or need to move after dequant. Confirm before adopting.
        out = (torch.tanh(x) + 1) / 2

        out = self.dequant(out)

        return out


# ---- Export configuration --------------------------------------------------
# Upscale factor to export; must be a key of EXPORT_SPECS below.
SCALE = 8
# Optional path to a trained generator state_dict (e.g. a netG8.pth checkpoint).
# With None the exported model carries randomly initialised weights.
WEIGHTS_PATH = None

# Per scale: generator class and its number of conv+ReLU+PixelShuffle stages.
EXPORT_SPECS = {
    2: (Generator2X, 1),
    4: (Generator4X, 2),
    8: (Generator8X, 3),
}

generator_cls, n_upsample_stages = EXPORT_SPECS[SCALE]
model = generator_cls()
if WEIGHTS_PATH is not None:
    # Load BEFORE fusing. Fusion replaces bn*/relu* with Identity and nests each
    # conv inside a ConvReLU2d (see the printed model below), so the checkpoint's
    # original keys would no longer match. map_location lets a GPU-trained
    # checkpoint load on a CPU-only export machine.
    model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
# Put the model in eval mode before fusing so BatchNorm is folded into the
# preceding conv rather than kept as a separate training-time module.
model.eval()

# Fusion groups: conv+bn+relu for the two feature stages, and conv+relu for every
# upsampling stage (conv3_1, conv4_1, ...). The final output conv feeds tanh and
# is left unfused.
fuse_groups = [['conv1_1', 'bn1_1', 'relu1_1'],
               ['conv2_1', 'bn2_1', 'relu2_1']]
fuse_groups += [[f'conv{i}_1', f'relu{i}_1'] for i in range(3, 3 + n_upsample_stages)]
fused_model = torch.quantization.fuse_modules(model, fuse_groups, inplace=True)

# NOTE(review): dynamic quantization targets Linear/RNN-style modules. The printed
# model below still shows float ConvReLU2d/Conv2d layers, so this step leaves the
# convolutions in float32. Real int8 convs would need static quantization
# (prepare -> calibrate -> convert).
type_to_quantize = {torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU}
quantized_model = torch.quantization.quantize_dynamic(fused_model, type_to_quantize, dtype=torch.qint8)
print(quantized_model)

# Script, apply mobile-specific graph optimisations, and write e.g. "8xSRGANmodel.pt".
torchscript_model = torch.jit.script(quantized_model)
torchscript_model_optimized = optimize_for_mobile(torchscript_model)
torch.jit.save(torchscript_model_optimized, f"{SCALE}xSRGANmodel.pt")



Generator8X(
  (conv1_1): ConvReLU2d(
    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (bn1_1): Identity()
  (relu1_1): Identity()
  (conv2_1): ConvReLU2d(
    (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (bn2_1): Identity()
  (relu2_1): Identity()
  (conv3_1): ConvReLU2d(
    (0): Conv2d(8, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (relu3_1): Identity()
  (pixel_shuffle3_1): PixelShuffle(upscale_factor=2)
  (conv4_1): ConvReLU2d(
    (0): Conv2d(8, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (relu4_1): Identity()
  (pixel_shuffle4_1): PixelShuffle(upscale_factor=2)
  (conv5_1): ConvReLU2d(
    (0): Conv2d(8, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
  )
  (relu5_1): Identity()
  (pixel_shuffle5_1): PixelShuffle(upscale_factor=2)
  (conv6_1): Conv2d(8, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1

In [13]:
# Time a single PixelShuffle that turns a (1, 3*r*r, 256, 256) map into
# (1, 3, 256*r, 256*r).
# time.clock() was removed in Python 3.8 (AttributeError on modern Python).
# perf_counter() is the recommended high-resolution timer for short intervals.
UPSCALE = 8
pixel_shuffle = nn.PixelShuffle(UPSCALE)
shuffle_input = torch.randn(1, 3 * UPSCALE ** 2, 256, 256)  # avoid shadowing builtin `input`

start = time.perf_counter()
output = pixel_shuffle(shuffle_input)
elapsed = time.perf_counter() - start
print('cost' + str(elapsed) + 's')
print(output.size())

cost0.028972999999999693s
torch.Size([1, 3, 2048, 2048])


In [19]:
# Reshape a (1, 12, 4, 4) tensor to (1, 3, 8, 8).
# NOTE: torch.reshape only reinterprets the existing row-major element order. It is
# therefore NOT equivalent to PixelShuffle(2), which interleaves each group of
# 2*2 channels into 2x2 spatial blocks.
torch.manual_seed(0)  # fixed seed so the printed tensors are reproducible
demo_input = torch.randn(1, 3 * 4, 4, 4)  # avoid shadowing builtin `input`
reshaped = torch.reshape(demo_input, (1, 3, 8, 8))

print(demo_input)
print(reshaped)

tensor([[[[-2.0880, -1.3108, -1.2320, -0.5744],
          [ 1.0028,  0.3629,  0.8612, -0.0603],
          [ 0.0400,  1.4703,  1.4256, -1.1825],
          [-0.2345,  0.9554, -0.9627,  0.2814]],

         [[-0.6836, -0.4423, -0.2498,  2.3824],
          [-0.8862, -1.2942,  1.0113,  0.3052],
          [-0.4126, -2.4467, -0.2619, -1.4474],
          [-1.0889, -2.3128,  0.5568, -0.4500]],

         [[-0.5391,  1.0020,  1.4213, -0.4633],
          [-0.2724, -0.6767,  1.3044,  0.1974],
          [-1.1362,  0.2133, -0.4271,  0.9641],
          [-1.7120,  1.6189, -1.1301, -1.7036]],

         [[-1.1017,  0.7903,  0.6104,  0.3426],
          [-0.1507, -0.6247, -1.6053,  0.2127],
          [ 0.0465, -0.6396,  1.5956, -0.4067],
          [-2.0572,  1.9096,  0.3948, -2.6080]],

         [[ 0.6629, -1.6017,  0.8885,  0.5681],
          [ 1.1213, -0.0567, -1.1208,  0.2856],
          [-1.0980,  1.1732,  0.5369, -1.9119],
          [-2.1718,  0.2090, -1.0048,  0.7340]],

         [[-0.3422, -1.6605,  