In [1]:
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

In [2]:
class TensorCache(nn.Module):
    def __init__(self, tensor):
        super(TensorCache, self).__init__()
        # cache = nn.Parameter(tensor)
        # self.register_parameter('cache', cache)
        self.register_buffer('cache', tensor)
    
    def forward(self, x):
        # assert x.size() == self.cache[:,:,0:1].size()
        cache_update = torch.cat((self.cache[:,:,1:], x.detach()), dim=2)
        self.cache[:,:,:] = cache_update
        return self.cache
    
#     def forward(self):
#         return self.cache
    
    @torch.jit.export
    def zero_cache(self):
        self.cache[:,:,:] = torch.ones(self.cache.size())

In [3]:
@torch.jit.script
def update_cache(cache, x):
#     return torch.cat((cache[:,:,1:], x.detach()), dim=2)
#     cache_update = torch.cat((cache[:,:,1:], x.detach()), dim=2)
#     cache[:,:,:] = cache_update
    cache[:,:,-1] = x[:,:,0].detach()

In [4]:
tc = TensorCache(torch.zeros(1,1,10))
tc_script = torch.jit.script(tc)
tc_trace = torch.jit.trace(tc, torch.tensor([[[0]]]))


RuntimeError: outputs_[i]->uses().empty() INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/ir/ir.cpp":1176, please report a bug to PyTorch. 

In [268]:
print('Original:')
for inp in [1,2,3,4]:
    print(tc(torch.tensor([[[inp]]])))

print()
print('Convert using script:')
for inp in [1,2,3,4]:
    print(tc_script(torch.tensor([[[inp]]])))
    
print('')
print('Convert using trace:')
for inp in [1,2,3,4]:
    print(tc_trace(torch.tensor([[[inp]]])))

Original:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

Convert using script:
tensor([[[0., 0., 0., 0., 0., 1., 2., 3., 4., 1.]]])
tensor([[[0., 0., 0., 0., 1., 2., 3., 4., 1., 2.]]])
tensor([[[0., 0., 0., 1., 2., 3., 4., 1., 2., 3.]]])
tensor([[[0., 0., 1., 2., 3., 4., 1., 2., 3., 4.]]])

Convert using trace:
tensor([[[0., 1., 2., 3., 4., 1., 2., 3., 4., 1.]]])
tensor([[[1., 2., 3., 4., 1., 2., 3., 4., 1., 2.]]])
tensor([[[2., 3., 4., 1., 2., 3., 4., 1., 2., 3.]]])
tensor([[[3., 4., 1., 2., 3., 4., 1., 2., 3., 4.]]])


In [416]:
class TemporalInferenceBlock(nn.Module):
# class TemporalInferenceBlock(torch.jit.ScriptModule):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, batch_size=1):
        super(TemporalInferenceBlock, self).__init__()
        self.in_ch, self.k, self.d = n_inputs, kernel_size, dilation

        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=0, dilation=dilation))
        self.relu1 = nn.ReLU()

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=0, dilation=dilation))
        self.relu2 = nn.ReLU()

        self.batch_size = batch_size

        self.cache1 = torch.jit.script(TensorCache(torch.zeros(
            batch_size, 
            self.conv1.in_channels, 
            (self.conv1.kernel_size[0]-1)*self.conv1.dilation[0] + 1
            )))
        
        self.cache2 = torch.jit.script(TensorCache(torch.zeros(
            batch_size, 
            self.conv2.in_channels, 
            (self.conv2.kernel_size[0]-1)*self.conv2.dilation[0] + 1
            )))

#         self.cache1 = torch.zeros(
#             batch_size, 
#             self.conv1.in_channels, 
#             (self.conv1.kernel_size[0]-1)*self.conv1.dilation[0] + 1
#             )
            
#         self.cache2 = torch.zeros(
#             batch_size, 
#             self.conv2.in_channels, 
#             (self.conv2.kernel_size[0]-1)*self.conv2.dilation[0] + 1
#             )
        
        self.stage1 = nn.Sequential(self.conv1, self.relu1)
        self.stage2 = nn.Sequential(self.conv2, self.relu2)


        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else nn.Identity()
        self.relu = nn.ReLU()
        self.init_weights()
    
    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if isinstance(self.downsample, nn.modules.conv.Conv1d):
            self.downsample.weight.data.normal_(0, 0.01)

    def reset_cache(self):
        self.cache1.zero_cache()
        self.cache2.zero_cache()
    
    def forward(self, x):
        '''
        x is of shape (B, CH, 1)
        '''
        # out = self.stage1(self.cache1(x)[:x.size()[0], :, :])
        # out = self.stage2(self.cache2(out)[:x.size()[0], :, :])
        # out1 = self.stage1(self.cache1(x))
        # out2 = self.stage2(self.cache2(out1))
        
        out1 = self.relu1(self.conv1(self.cache1(x)))
        out2 = self.relu2(self.conv2(self.cache2(out1)))
        # self.cache1.zero_cache()
        # out1 = self.cache1()
        # self.cache2.zero_cache()
        # out2 = self.cache2()

        res = self.downsample(x)
        out = self.relu(out2 + res)
        # print(f'\t res shape: {res.size()}')
        #         print(f'x: {x} \n c1: {self.cache1.cache} \n out1: {out1} \n c2: {self.cache2.cache} \n out2: {out2} \n \n')

        return x, self.cache1.cache, out1, self.cache2.cache, out2, res, out
    
#     @torch.jit.script_method
#     def forward(self, x):
#         '''
#         x is of shape (B, CH, 1)
#         '''
#         # out = self.stage1(self.cache1(x)[:x.size()[0], :, :])
#         # out = self.stage2(self.cache2(out)[:x.size()[0], :, :])
#         update_cache(self.cache1, x)
#         out1 = self.stage1(self.cache1)
#         update_cache(self.cache2, out1)
#         out2 = self.stage2(self.cache2)

#         res = self.downsample(x)
#         out = self.relu(out2 + res)
#         # print(f'\t res shape: {res.size()}')
#         #         print(f'x: {x} \n c1: {self.cache1.cache} \n out1: {out1} \n c2: {self.cache2.cache} \n out2: {out2} \n \n')

#         return x, self.cache1, out1, self.cache2, out2, res, out

In [417]:
tblock = TemporalInferenceBlock(1,1,7,1,1,1)
tblock.eval()
tblock.cuda()

tblock_script = torch.jit.script(tblock)
tblock_trace = torch.jit.trace(tblock, torch.tensor([[[0]]]).cuda())


In [389]:
for inp in [1,2,3,4]:
    x, c1, out1, c2, out2, res, out = tblock(torch.tensor([[[inp]]]).cuda())
    print(f'x: {x} \n c1: {c1} \n out1: {out1} \n c2: {c2} \n out2: {out2}, res: {res} \n out: {out} \n \n')

x: tensor([[[1]]], device='cuda:0') 
 c1: tensor([[[0., 0., 0., 0., 0., 0., 1.]]], device='cuda:0') 
 out1: tensor([[[0.]]], device='cuda:0', grad_fn=<ReluBackward0>) 
 c2: tensor([[[0., 0., 0., 0., 0., 0., 0.]]], device='cuda:0') 
 out2: tensor([[[0.1345]]], device='cuda:0', grad_fn=<ReluBackward0>), res: tensor([[[1]]], device='cuda:0') 
 out: tensor([[[1.1345]]], device='cuda:0', grad_fn=<ReluBackward0>) 
 

x: tensor([[[2]]], device='cuda:0') 
 c1: tensor([[[0., 0., 0., 0., 0., 1., 2.]]], device='cuda:0') 
 out1: tensor([[[0.]]], device='cuda:0', grad_fn=<ReluBackward0>) 
 c2: tensor([[[0., 0., 0., 0., 0., 0., 0.]]], device='cuda:0') 
 out2: tensor([[[0.1345]]], device='cuda:0', grad_fn=<ReluBackward0>), res: tensor([[[2]]], device='cuda:0') 
 out: tensor([[[2.1345]]], device='cuda:0', grad_fn=<ReluBackward0>) 
 

x: tensor([[[3]]], device='cuda:0') 
 c1: tensor([[[0., 0., 0., 0., 1., 2., 3.]]], device='cuda:0') 
 out1: tensor([[[0.]]], device='cuda:0', grad_fn=<ReluBackward0>) 
 c

In [390]:
for inp in [1,2,3,4]:
    x, c1, out1, c2, out2, res, out = tblock_script(torch.tensor([[[inp]]]).cuda())
    print(f'x: {x} \n c1: {c1} \n out1: {out1} \n c2: {c2} \n out2: {out2}, res: {res} \n out: {out} \n \n')

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<ipython-input-326-ecfa1ebd866b>", line 65, in forward
        # out = self.stage1(self.cache1(x)[:x.size()[0], :, :])
        # out = self.stage2(self.cache2(out)[:x.size()[0], :, :])
        out1 = self.stage1(self.cache1(x))
               ~~~~~~~~~~~ <--- HERE
        out2 = self.stage2(self.cache2(out1))
        # self.cache1.zero_cache()
  File "/home/s1bhavsa/.local/lib/python3.7/site-packages/torch/nn/modules/container.py", line 117, in forward
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
  File "/home/s1bhavsa/.local/lib/python3.7/site-packages/torch/nn/modules/container.py", line 117, in forward
    def forward(self, input):
        for module in self:
            input = module(input)
                    ~~~~~~ <--- HERE
        return input
  File "/home/s1bhavsa/.local/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 258, in forward
                            self.weight, self.bias, self.stride,
                            _single(0), self.dilation, self.groups)
        return F.conv1d(input, self.weight, self.bias, self.stride,
               ~~~~~~~~ <--- HERE
                        self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same


In [391]:
for inp in [1,2,3,4]:
    x, c1, out1, c2, out2, res, out = tblock_trace(torch.tensor([[[inp]]]).cuda())
    print(f'x: {x} \n c1: {c1} \n out1: {out1} \n c2: {c2} \n out2: {out2}, res: {res} \n out: {out} \n \n')

x: tensor([[[1]]], device='cuda:0') 
 c1: tensor([[[0., 1., 2., 3., 4., 1., 1.]]], device='cuda:0') 
 out1: tensor([[[0.2580]]], device='cuda:0', grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2580]]],
       device='cuda:0') 
 out2: tensor([[[0.0769]]], device='cuda:0', grad_fn=<ReluBackward0>), res: tensor([[[1]]], device='cuda:0') 
 out: tensor([[[1.0769]]], device='cuda:0', grad_fn=<ReluBackward0>) 
 

x: tensor([[[2]]], device='cuda:0') 
 c1: tensor([[[1., 2., 3., 4., 1., 1., 2.]]], device='cuda:0') 
 out1: tensor([[[0.5571]]], device='cuda:0', grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2580, 0.5571]]],
       device='cuda:0') 
 out2: tensor([[[0.]]], device='cuda:0', grad_fn=<ReluBackward0>), res: tensor([[[2]]], device='cuda:0') 
 out: tensor([[[2.]]], device='cuda:0', grad_fn=<ReluBackward0>) 
 

x: tensor([[[3]]], device='cuda:0') 
 c1: tensor([[[2., 3., 4., 1., 1., 2., 3.]]], device='cuda:0') 

In [418]:
test_out = tblock_trace(torch.tensor([[[0.0]]]).cuda())
test_out

(tensor([[[0.]]], device='cuda:0'),
 tensor([[[0., 0., 0., 0., 0., 0., 0.]]], device='cuda:0'),
 tensor([[[0.]]], device='cuda:0', grad_fn=<ReluBackward0>),
 tensor([[[0., 0., 0., 0., 0., 0., 0.]]], device='cuda:0'),
 tensor([[[0.1974]]], device='cuda:0', grad_fn=<ReluBackward0>),
 tensor([[[0.]]], device='cuda:0'),
 tensor([[[0.1974]]], device='cuda:0', grad_fn=<ReluBackward0>))

In [419]:
torch.onnx.export(
    tblock_trace,
    torch.tensor([[[0.0]]]).cuda(),
    f'tblock_test.onnx',
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=True,
    opset_version=12,
    input_names = ['input'],
    output_names = ['output'],
#     dynamic_axes={
#                  'input' : {0 : 'batch_size'}, 
#                  'output' : {0 : 'batch_size'}
#                  },
    example_outputs=test_out,
    verbose=True
)

graph(%output : Float(1:1, 1:1, 1:1, requires_grad=0, device=cuda:0),
      %1 : Float(1:1, requires_grad=1, device=cuda:0),
      %4 : Float(1:1, requires_grad=1, device=cuda:0),
      %7 : Float(1:7, 1:7, 7:1, requires_grad=0, device=cuda:0),
      %8 : Float(1:7, 1:7, 7:1, requires_grad=0, device=cuda:0),
      %169 : Float(1:7, 1:7, 6:1, requires_grad=0, device=cuda:0),
      %170 : Long(3:1, requires_grad=0, device=cpu),
      %173 : Long(1:1, requires_grad=0, device=cpu),
      %176 : Long(1:1, requires_grad=0, device=cpu),
      %179 : Long(1:1, requires_grad=0, device=cpu),
      %181 : Long(0:1, requires_grad=0, device=cpu),
      %184 : Float(1:7, 1:7, 7:1, requires_grad=1, device=cuda:0),
      %185 : Float(1:7, 1:7, 6:1, requires_grad=0, device=cuda:0),
      %186 : Long(3:1, requires_grad=0, device=cpu),
      %189 : Long(1:1, requires_grad=0, device=cpu),
      %192 : Long(1:1, requires_grad=0, device=cpu),
      %195 : Long(1:1, requires_grad=0, device=cpu),
      %197 :

In [420]:
import onnx
import onnxruntime
from onnx import helper, shape_inference

In [421]:
onnx_model = onnx.load('tblock_test.onnx')
onnx.checker.check_model(onnx_model)

In [422]:
inferred_model = shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(inferred_model)

In [423]:
onnx_path = 'tblock_test.onnx'
ort_session = onnxruntime.InferenceSession(str(onnx_path))

Fail: [ONNXRuntimeError] : 1 : FAIL : Exception during loading: /onnxruntime_src/onnxruntime/core/graph/function.cc:391 onnxruntime::FunctionImpl::FunctionImpl(const onnxruntime::Graph&, const NodeIndex&, const onnx::FunctionProto&, const onnxruntime::logging::Logger&) status.IsOK() was false. Resolve subgraph failed:Node (0x55678227b6c0_2) Op (Loop) [TypeInferenceError] Graph attribute inferencing failed: Node:0x55678227b6c0_2 Output:cond [ShapeInferenceError] Mismatch between number of source and target dimensions. Source=1 Target=0


In [357]:
dir(onnx_model.graph)

['ByteSize',
 'Clear',
 'ClearExtension',
 'ClearField',
 'CopyFrom',
 'DESCRIPTOR',
 'DiscardUnknownFields',
 'Extensions',
 'FindInitializationErrors',
 'FromString',
 'HasExtension',
 'HasField',
 'IsInitialized',
 'ListFields',
 'MergeFrom',
 'MergeFromString',
 'ParseFromString',
 'RegisterExtension',
 'SerializePartialToString',
 'SerializeToString',
 'SetInParent',
 'UnknownFields',
 'WhichOneof',
 '_CheckCalledFromGeneratedFile',
 '_SetListener',
 '__class__',
 '__deepcopy__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__unicode__',
 '_extensions_by_name',
 '_extensions_by_number',
 'doc_string',
 'initializer',
 'input',
 'name',
 'node',
 'output',

In [362]:
onnx_model.graph.node

[input: "161"
input: "output"
output: "12"
name: "Concat_0"
op_type: "Concat"
attribute {
  name: "axis"
  i: 2
  type: INT
}
, input: "12"
input: "162"
output: "14"
name: "Expand_1"
op_type: "Expand"
, output: "19"
name: "Constant_2"
op_type: "Constant"
attribute {
  name: "value"
  t {
    data_type: 7
    raw_data: "\000\000\000\000\000\000\000\000"
  }
  type: TENSOR
}
, output: "20"
name: "Constant_3"
op_type: "Constant"
attribute {
  name: "value"
  t {
    data_type: 7
    raw_data: "\001\000\000\000\000\000\000\000"
  }
  type: TENSOR
}
, input: "19"
input: "165"
input: "20"
output: "21"
name: "Range_4"
op_type: "Range"
, output: "26"
name: "Constant_5"
op_type: "Constant"
attribute {
  name: "value"
  t {
    data_type: 7
    raw_data: "\000\000\000\000\000\000\000\000"
  }
  type: TENSOR
}
, output: "27"
name: "Constant_6"
op_type: "Constant"
attribute {
  name: "value"
  t {
    data_type: 7
    raw_data: "\001\000\000\000\000\000\000\000"
  }
  type: TENSOR
}
, input: "26"


torch.float32