In [None]:
# Based on the PyTorch tutorial: https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

In [8]:
import io
import numpy as np

from torch import nn
import torch.onnx

class Generator(nn.Module):
    """Super-resolution generator that upsamples height and width by 8.

    Pipeline:
      * two conv(3x3) -> BatchNorm -> ReLU blocks at the input resolution
        (3 -> 8 -> 8 channels);
      * three upsampling stages, each conv(3x3, 8 -> 32) -> ReLU ->
        PixelShuffle(2), turning 32 channels into 8 channels at twice the
        spatial size (2 * 2 * 2 = 8x overall);
      * a final conv(3x3, 8 -> 3) whose tanh output is rescaled to [0, 1].

    The conv/BN attribute names below are the keys of the saved state_dict,
    so they must not be renamed or checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()

        # Feature extraction at the input resolution.
        self.conv1_1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(8)
        self.relu1_1 = nn.ReLU()

        self.conv2_1 = nn.Conv2d(8, 8, kernel_size=3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(8)
        self.relu2_1 = nn.ReLU()

        # Upsampling stages: 32 = 8 * 2**2 channels feed each PixelShuffle(2).
        self.conv3_1 = nn.Conv2d(8, 32, kernel_size=3, padding=1)
        self.relu3_1 = nn.ReLU()
        self.pixel_shuffle3_1 = nn.PixelShuffle(2)

        self.conv4_1 = nn.Conv2d(8, 32, kernel_size=3, padding=1)
        self.relu4_1 = nn.ReLU()
        self.pixel_shuffle4_1 = nn.PixelShuffle(2)

        self.conv5_1 = nn.Conv2d(8, 32, kernel_size=3, padding=1)
        self.relu5_1 = nn.ReLU()
        self.pixel_shuffle5_1 = nn.PixelShuffle(2)

        # Projection back to 3 channels.
        self.conv6_1 = nn.Conv2d(8, 3, kernel_size=3, padding=1)

        # Pin all floating-point parameters/buffers to float32, independent of
        # torch's global default dtype.
        self.to(dtype=torch.float)

    def forward(self, x):
        """Map a 3-channel image batch to a 3-channel batch 8x larger in H and W, valued in [0, 1]."""
        for conv, norm, act in (
            (self.conv1_1, self.bn1_1, self.relu1_1),
            (self.conv2_1, self.bn2_1, self.relu2_1),
        ):
            x = act(norm(conv(x)))

        # Each stage doubles the spatial size.
        for conv, act, shuffle in (
            (self.conv3_1, self.relu3_1, self.pixel_shuffle3_1),
            (self.conv4_1, self.relu4_1, self.pixel_shuffle4_1),
            (self.conv5_1, self.relu5_1, self.pixel_shuffle5_1),
        ):
            x = shuffle(act(conv(x)))

        # tanh is in [-1, 1]; shift by 1 and halve to land in [0, 1].
        return (torch.tanh(self.conv6_1(x)) + 1) / 2

# Paths on the mounted Google Drive (Colab); edit these to match your layout.
CHECKPOINT_PATH = "/content/drive/My Drive/data/epochs/netG_epoch_8_90.pth"
ONNX_PATH = "/content/drive/My Drive/data/test/8xSRGANmodel.onnx"

torch_model = Generator()
# map_location="cpu" places every tensor on the CPU, like the identity lambda did,
# so a GPU-trained checkpoint loads on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints you trust; on torch >= 1.13 consider weights_only=True.
torch_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
# eval() makes BatchNorm use its running statistics, which is what the exported
# inference graph should contain.
torch_model.eval()

# Dummy input used for tracing; its H and W are baked into the exported graph
# (only the batch axis is declared dynamic below).
x = torch.randn(1, 3, 256, 256, requires_grad=True)

# Sanity-check forward pass. no_grad avoids recording an autograd graph over the
# full-resolution activations, which would only waste memory here.
with torch.no_grad():
    torch_out = torch_model(x)
print(torch_out.size())
# Three PixelShuffle(2) stages => 8x upsampling of H and W.
assert torch_out.shape == (x.shape[0], 3, x.shape[2] * 8, x.shape[3] * 8), torch_out.shape

# Export the model
torch.onnx.export(torch_model,               # model being run
        x,                         # model input (or a tuple for multiple inputs)
        ONNX_PATH,                 # where to save the model (can be a file or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=10,          # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names = ['input'],   # the model's input names
        output_names = ['output'], # the model's output names
        dynamic_axes={'input' : {0 : 'batch_size'},    # variable-length axes
              'output' : {0 : 'batch_size'}})



torch.Size([1, 3, 2048, 2048])
