In [4]:
import torch
from torch import nn
import torch.onnx
import tensorflow as tf 
from onnx_tf.backend import prepare
import onnx

In [5]:
class ResBlock(nn.Module):
    """Residual block: two padded 3x3 conv + instance-norm stages plus a skip connection.

    Args:
        channels: number of input and output channels of both convolutions.
    """

    def __init__(self, channels):
        super(ResBlock, self).__init__()

        def conv_stage():
            # Reflection-pad by 1 so the unpadded 3x3 convolution keeps the spatial size.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(channels, channels, kernel_size=3),
                nn.InstanceNorm2d(channels, affine=True),
            ]

        # Layer order is: stage, ReLU, stage. The positions inside the
        # Sequential become the state-dict keys, so the order must not change.
        layers = conv_stage() + [nn.ReLU()] + conv_stage()
        self.resblock = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the two-stage branch applied to ``x``."""
        return self.resblock(x) + x
    
class Upsample2d(nn.Module):
    """Module wrapper around ``nn.functional.interpolate`` in ``'nearest'`` mode.

    Args:
        scale_factor: forwarded unchanged to ``interpolate`` on every call.
    """

    def __init__(self, scale_factor):
        super(Upsample2d, self).__init__()

        self.scale_factor = scale_factor
        # The functional is stored on the instance and invoked through it in forward().
        self.interp = nn.functional.interpolate

    def forward(self, x):
        """Return ``x`` resized by ``self.scale_factor`` with nearest-neighbour interpolation."""
        return self.interp(x, mode='nearest', scale_factor=self.scale_factor)

class MicroResNet(nn.Module):
    """Small convolutional network made of three sequential stages.

    * ``downsampler``: three reflection-padded conv / instance-norm / ReLU
      groups with strides 4, 2 and 2, going from 3 to 32 channels.
    * ``residual``: ``ResBlock(32)``, a grouped 1x1 convolution widening
      32 -> 64 channels, then ``ResBlock(64)``.
    * ``segmentator``: 3x3 conv to 16 channels, 2x nearest upsampling, a 9x9
      conv to a single channel and a final sigmoid.

    Attribute names and layer positions define the state-dict keys, so they
    must stay as they are for existing checkpoints to load.
    """

    def __init__(self):
        super(MicroResNet, self).__init__()

        # (in_channels, out_channels, kernel_size, stride, reflection_pad)
        down_specs = (
            (3, 8, 9, 4, 4),
            (8, 16, 3, 2, 1),
            (16, 32, 3, 2, 1),
        )
        down_layers = []
        for in_ch, out_ch, kernel, stride, pad in down_specs:
            down_layers.extend([
                nn.ReflectionPad2d(pad),
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride),
                nn.InstanceNorm2d(out_ch, affine=True),
                nn.ReLU(),
            ])
        self.downsampler = nn.Sequential(*down_layers)

        residual_layers = [
            ResBlock(32),
            # Grouped, bias-free 1x1 convolution that doubles the channel count.
            nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),
            ResBlock(64),
        ]
        self.residual = nn.Sequential(*residual_layers)

        head_layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(64, 16, kernel_size=3),
            nn.InstanceNorm2d(16, affine=True),
            nn.ReLU(),
            Upsample2d(scale_factor=2),
            nn.ReflectionPad2d(4),
            nn.Conv2d(16, 1, kernel_size=9),
            nn.Sigmoid(),
        ]
        self.segmentator = nn.Sequential(*head_layers)

    def forward(self, x):
        """Pass ``x`` through downsampler, residual and segmentator, in that order."""
        out = x
        for stage in (self.downsampler, self.residual, self.segmentator):
            out = stage(out)
        return out

In [7]:
model = MicroResNet()

# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; without it torch.load tries to restore the tensors onto the
# device they were saved from.
checkpoint = torch.load('../models/model_fast.pt', map_location='cpu')
model.load_state_dict(checkpoint)
# Switch to evaluation mode. eval() returns the module itself, so leaving it as
# the last expression also displays the architecture as the cell output.
model.eval()

MicroResNet(
  (downsampler): Sequential(
    (0): ReflectionPad2d((4, 4, 4, 4))
    (1): Conv2d(3, 8, kernel_size=(9, 9), stride=(4, 4))
    (2): InstanceNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (3): ReLU()
    (4): ReflectionPad2d((1, 1, 1, 1))
    (5): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2))
    (6): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (7): ReLU()
    (8): ReflectionPad2d((1, 1, 1, 1))
    (9): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))
    (10): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (11): ReLU()
  )
  (residual): Sequential(
    (0): ResBlock(
      (resblock): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
        (2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (3): ReLU()
        (4): ReflectionPad2d((1, 1, 1

In [8]:
class MicroResNetIOS(nn.Module):
    """Wrapper that permutes tensor axes around the already-loaded ``model``.

    The input's last axis is moved to position 1 before the wrapped network is
    called, and axis 1 of the result is moved back to the last position
    (presumably channels-last in / channels-last out for the iOS side -- TODO confirm).
    """

    def __init__(self):
        super(MicroResNetIOS, self).__init__()
        # NOTE: wraps the notebook-level `model` created in an earlier cell.
        self.micro_res_net = model

    def forward(self, x):
        """Run the wrapped network on ``x`` with its axes permuted as described on the class."""
        channels_first = x.permute(0, 3, 1, 2)
        prediction = self.micro_res_net(channels_first)
        return prediction.permute(0, 2, 3, 1)


model_ios = MicroResNetIOS()

In [9]:
# Export `model` to ONNX by running it once on a random dummy input.
with torch.no_grad():
    # Shape is (batch, channels, height, width); the first convolution of the
    # network takes 3 input channels.
    x = torch.randn(1, 3, 240, 320, requires_grad=True)
    torch_out = torch.onnx.export(
        model,
        x,
        "../models/model_fast.onnx",
        export_params=True,
    )

In [10]:
# ONNX -> TensorFlow: read the ONNX file written by the export cell, wrap it in
# an onnx-tf backend representation and write that out as a graph file.
onnx_model = onnx.load("../models/model_fast.onnx")  # load the exported ONNX model from disk
tf_rep = prepare(onnx_model)  # build the onnx-tf (TensorFlow) representation
tf_rep.export_graph("../models/model_fast.pb")  # write the TensorFlow graph to disk

  handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))
  handler.ONNX_OP, handler.DOMAIN or "ai.onnx"))


Instructions for updating:
Colocations handled automatically by placer.


In [11]:
g = tf.GraphDef()
# Read the serialized graph inside a context manager so the file handle is
# closed as soon as parsing is done instead of being left open.
with open("../models/model_fast.pb", "rb") as graph_file:
    g.ParseFromString(graph_file.read())

# Display every node as "<index> <name>"; useful for looking up the node names
# that have to be passed to the TFLite converter.
["{} {}".format(idx, n.name) for (idx, n) in enumerate(g.node)]

['0 Const',
 '1 Const_1',
 '2 Const_2',
 '3 Const_3',
 '4 Const_4',
 '5 Const_5',
 '6 Const_6',
 '7 Const_7',
 '8 Const_8',
 '9 Const_9',
 '10 Const_10',
 '11 Const_11',
 '12 Const_12',
 '13 Const_13',
 '14 Const_14',
 '15 Const_15',
 '16 Const_16',
 '17 Const_17',
 '18 Const_18',
 '19 Const_19',
 '20 Const_20',
 '21 Const_21',
 '22 Const_22',
 '23 Const_23',
 '24 Const_24',
 '25 Const_25',
 '26 Const_26',
 '27 Const_27',
 '28 Const_28',
 '29 Const_29',
 '30 Const_30',
 '31 Const_31',
 '32 Const_32',
 '33 Const_33',
 '34 Const_34',
 '35 input.1',
 '36 Const_35',
 '37 MirrorPad',
 '38 transpose/perm',
 '39 transpose',
 '40 Const_36',
 '41 Pad',
 '42 Const_37',
 '43 split/split_dim',
 '44 split',
 '45 transpose_1/perm',
 '46 transpose_1',
 '47 Const_38',
 '48 split_1/split_dim',
 '49 split_1',
 '50 convolution/dilation_rate',
 '51 convolution',
 '52 concat/concat_dim',
 '53 concat',
 '54 Add',
 '55 transpose_2/perm',
 '56 transpose_2',
 '57 Reshape/shape',
 '58 Reshape',
 '59 Reshape_1/s

In [12]:
# TensorFlow -> TensorFlow Lite conversion.
# IMPORTANT: --input_arrays / --output_arrays have to name nodes that exist in
# model_fast.pb. The input node is `input.1` (it appears in the graph's node
# list); `Sigmoid` is presumably the op produced by the model's final
# nn.Sigmoid -- verify against the node list if the export changes.
!tflite_convert --output_file=../models/model_fast.tflite --graph_def_file=../models/model_fast.pb --input_arrays=input.1 --output_arrays=Sigmoid --inference_type=FLOAT --post_training_quantize


2019-02-21 21:05:59.890808: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-02-21 21:06:00.139431: I tensorflow/core/grappler/devices.cc:53] Number of eligible GPUs (core count >= 8): 0 (Note: TensorFlow was not compiled with CUDA support)
2019-02-21 21:06:00.141560: I tensorflow/core/grappler/clusters/single_machine.cc:359] Starting new session
2019-02-21 21:06:00.521412: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:586] Optimization results for grappler item: graph_to_optimize
2019-02-21 21:06:00.521439: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:588]   model_pruner: Graph size after: 407 nodes (-88), 516 edges (-18), time = 24.064ms.
2019-02-21 21:06:00.521448: I tensorflow/core/grappler/optimizers/meta_optimizer.cc:588]   function_optimizer: Graph size after: 407 nodes (0), 516 edges (0), time = 2.411ms.
2019-02-21 21:06:00.521455: I tensorflow/core/grapple