In [1]:
import torch
from torch import nn
import torch.onnx
import tensorflow as tf 
from onnx_tf.backend import prepare



In [2]:
class ResBlock(nn.Module):
    def __init__(self, channels):
        super(ResBlock, self).__init__()
        
        self.resblock = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.InstanceNorm2d(channels, affine=True),
        )
        
    def forward(self, x):
        out = self.resblock(x)
        return out + x

class MicroResNet(nn.Module):
    def __init__(self):
        super(MicroResNet, self).__init__()
        
        self.downsampler = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=9, padding=4),
            nn.InstanceNorm2d(8, affine=True),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(16, affine=True),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=4),
            nn.InstanceNorm2d(32, affine=True),
            nn.ReLU()
        )
        
        self.residual = nn.Sequential(
            ResBlock(32),
            ResBlock(32)
        )
        
        self.segmentator = nn.Sequential(
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.InstanceNorm2d(16, affine=True),
            nn.ReLU(),
            nn.Conv2d(16, 1, kernel_size=9, padding=4),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        out = self.downsampler(x)
        for i in range(2): # recursively use the same simple block 2 times
            out = self.residual(out)
        out = self.segmentator(out)
        return out

In [3]:
model = MicroResNet()

checkpoint = torch.load('saliency_model_v4.pt')
model.load_state_dict(checkpoint)
model.eval()

MicroResNet(
  (downsampler): Sequential(
    (0): Conv2d(3, 8, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
    (1): InstanceNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): ReLU()
    (3): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (5): ReLU()
    (6): Conv2d(16, 32, kernel_size=(3, 3), stride=(4, 4), padding=(1, 1))
    (7): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (8): ReLU()
  )
  (residual): Sequential(
    (0): ResBlock(
      (resblock): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (2): ReLU()
        (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): InstanceNorm2d(32, eps=1e-05, momentum=0.1,

In [20]:
with torch.no_grad():
    x = torch.randn(1, 3, 240, 320, requires_grad=True)
    torch_out = torch.onnx._export(model, x, "saliency.onnx", export_params=True)

In [47]:
onnx_model = onnx.load("saliency.onnx")  # load onnx model
tf_rep = prepare(onnx_model)  # prepare tf representation
tf_rep.export_graph("saliency.pb")  # export the model

In [4]:
g = tf.GraphDef()
g.ParseFromString(open("saliency.pb", "rb").read())

# g.node
[n for n in g.node if n.name == '0']

[name: "0"
 op: "Placeholder"
 attr {
   key: "dtype"
   value {
     type: DT_FLOAT
   }
 }
 attr {
   key: "shape"
   value {
     shape {
       dim {
         size: 1
       }
       dim {
         size: 3
       }
       dim {
         size: 240
       }
       dim {
         size: 320
       }
     }
   }
 }]

In [13]:
# IMPORTANT --input_arrays=0 --output_arrays=Sigmoid
!tflite_convert --output_file=saliency.tflite --graph_def_file=saliency.pb --input_arrays=0 --output_arrays=Sigmoid


Traceback (most recent call last):
  File "/usr/local/bin/tflite_convert", line 6, in <module>
    from tensorflow.lite.python.tflite_convert import main
  File "/usr/local/lib/python3.6/site-packages/tensorflow/__init__.py", line 34, in <module>
    from tensorflow._api.v1 import compat
  File "/usr/local/lib/python3.6/site-packages/tensorflow/_api/v1/compat/__init__.py", line 21, in <module>
    from tensorflow._api.v1.compat import v1
  File "/usr/local/lib/python3.6/site-packages/tensorflow/_api/v1/compat/v1/__init__.py", line 71, in <module>
    from tensorflow._api.v1.compat.v1 import test
  File "/usr/local/lib/python3.6/site-packages/tensorflow/_api/v1/compat/v1/test/__init__.py", line 24, in <module>
    from tensorflow.python.platform.googletest import mock
ImportError: cannot import name 'mock'
