In [9]:
import torch
from torch import nn

In [10]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet34', pretrained=True)
resnet = nn.Sequential(*list(model.children())[:-2])


class BasicBlockDec(nn.Module):

    def __init__(self, shape):
        super().__init__()
        if shape == 512:
            shape2 = 512
        else:
            shape2 = int(shape * 2)
        
        self.convtrans1 = nn.ConvTranspose2d(shape2, shape, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(shape)
        self.convtrans2 = nn.ConvTranspose2d(shape, shape, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(shape)
        
    def forward(self, x):
        out1 = self.convtrans1(x)
        out2 = torch.relu(self.bn1(out1))
        out2 = self.convtrans2(out2)
        out2 = torch.relu(self.bn2(out2))
        final = torch.add(out1, out2)
        
        return final


class ResNet18Dec(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], nc=3):
        super().__init__()
        self.layer1 = BasicBlockDec(512)
        self.layer2 = BasicBlockDec(256)
        self.layer3 = BasicBlockDec(128)
        self.layer4 = BasicBlockDec(64)
        self.conv1 = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)

#         self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)

    def _make_layer(self, BasicBlockDec, shape):
        return 
        strides = [stride] + [1]*(num_Blocks-1)
        layers = []
        for stride in reversed(strides):
            layers += [BasicBlockDec(self.in_planes, stride)]
        self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = torch.sigmoid(self.conv1(x))
        return x

class AutoEncoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = resnet
        self.decoder = ResNet18Dec()
        self.conv1 = nn.Conv2d(512, 512, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(512, 512, kernel_size=1, stride=1)

    def forward(self, x):
#         mean, logvar = self.encoder(x)
        x = self.encoder(x)
        x = torch.relu(self.conv1(x))
#         x = torch.relu(self.conv2(x))
#         z = self.reparameterize(mean, logvar)
        x = self.decoder(x)
        return x
    


Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


In [11]:
prefix = "resnet34AE_04211523_100epoch"
pkl_path = prefix + ".pkl"
model = AutoEncoder().cuda()
model.load_state_dict(torch.load(pkl_path))
model.eval()

AutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

In [12]:
dummy_input = torch.randn(1, 3, 256, 256, device="cuda")

torch.onnx.export(model, dummy_input, f"{prefix}.onnx", verbose=True, input_names=["input"], output_names=["output"])

graph(%input : Float(1, 3, 256, 256, strides=[196608, 65536, 256, 1], requires_grad=0, device=cuda:0),
      %decoder.layer1.convtrans1.weight : Float(512, 512, 4, 4, strides=[8192, 16, 4, 1], requires_grad=1, device=cuda:0),
      %decoder.layer1.convtrans1.bias : Float(512, strides=[1], requires_grad=1, device=cuda:0),
      %decoder.layer1.bn1.weight : Float(512, strides=[1], requires_grad=1, device=cuda:0),
      %decoder.layer1.bn1.bias : Float(512, strides=[1], requires_grad=1, device=cuda:0),
      %decoder.layer1.bn1.running_mean : Float(512, strides=[1], requires_grad=0, device=cuda:0),
      %decoder.layer1.bn1.running_var : Float(512, strides=[1], requires_grad=0, device=cuda:0),
      %decoder.layer1.convtrans2.weight : Float(512, 512, 3, 3, strides=[4608, 9, 3, 1], requires_grad=1, device=cuda:0),
      %decoder.layer1.convtrans2.bias : Float(512, strides=[1], requires_grad=1, device=cuda:0),
      %decoder.layer1.bn2.weight : Float(512, strides=[1], requires_grad=1, devic