In [2]:
import torch
import torch.nn as nn
import onnx
import onnxruntime as rt
from onnx import helper, shape_inference

In [3]:
print(f'torch version: {torch.__version__}')
print(f'onnx version: {onnx.__version__}')
print(f'onnxruntime version: {rt.__version__}')

torch version: 1.9.0
onnx version: 1.9.0
onnxruntime version: 1.7.2


In [4]:
class TensorCache(nn.Module):
    def __init__(self, tensor):
        super(TensorCache, self).__init__()
        self.register_buffer('cache', tensor)
#         self.cache = tensor
    
    def forward(self, x):
        # assert x.size() == self.cache[:,:,0:1].size()
        cache_update = torch.cat((self.cache[:,:,1:], x.detach()), dim=2)
        self.cache[:,:,:] = cache_update
        return self.cache

In [12]:
tc = TensorCache(torch.zeros(1,1,10))
tc_script = torch.jit.script(tc)
tc_trace = torch.jit.trace(tc, torch.tensor([[[0.0]]]))

RuntimeError: outputs_[i]->uses().empty()INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1623448238472/work/torch/csrc/jit/ir/ir.cpp":1226, please report a bug to PyTorch. 

In [6]:
print('Original:')
for inp in [1.,2.,3.,4.]:
    print(tc(torch.tensor([[[inp]]])))

print()
print('Convert using script:')
for inp in [1.,2.,3.,4.]:
    print(tc_script(torch.tensor([[[inp]]])))
    
print('')
print('Convert using trace:')
for inp in [1.,2.,3.,4.]:
    print(tc_trace(torch.tensor([[[inp]]])))

Original:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

Convert using script:
tensor([[[0., 0., 0., 0., 0., 1., 2., 3., 4., 1.]]])
tensor([[[0., 0., 0., 0., 1., 2., 3., 4., 1., 2.]]])
tensor([[[0., 0., 0., 1., 2., 3., 4., 1., 2., 3.]]])
tensor([[[0., 0., 1., 2., 3., 4., 1., 2., 3., 4.]]])

Convert using trace:


NameError: name 'tc_trace' is not defined

In [7]:
ex_out = tc(torch.tensor([[[0.5]]]))

In [13]:
torch.onnx.export(
    tc_script,
    torch.tensor([[[0.0]]]),
    f'onnxrt_test.onnx',
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=True,
    opset_version=13,
    input_names = ['input'],
    output_names = ['output'],
    example_outputs=ex_out,
    verbose=True
)

graph(%input : Float(1, 1, 1, strides=[1, 1, 1], requires_grad=0, device=cpu),
      %cache : Float(1, 1, 10, strides=[10, 10, 1], requires_grad=0, device=cpu),
      %74 : Float(1, 1, 9, strides=[10, 10, 1], requires_grad=0, device=cpu),
      %75 : Long(3, strides=[1], requires_grad=0, device=cpu),
      %87 : Long(10, strides=[1], requires_grad=0, device=cpu),
      %88 : Long(1, 1, 1, strides=[1, 1, 1], requires_grad=0, device=cpu),
      %89 : Long(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %92 : Long(3, strides=[1], requires_grad=0, device=cpu),
      %93 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %94 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %95 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %98 : Long(3, strides=[1], requires_grad=0, device=cpu)):
  %7 : Float(1, 1, 10, strides=[10, 10, 1], device=cpu) = onnx::Concat[axis=2](%74, %input) # <ipython-input-4-95994830632f>:9:23
  %9 : Float(1, 1, 10, device=cpu) = onnx::E

In [14]:
onnx_model = onnx.load('onnxrt_test.onnx')
onnx.checker.check_model(onnx_model)
inferred_model = shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(inferred_model)

In [15]:
onnx_path = 'onnxrt_test.onnx'
ort_session = rt.InferenceSession(str(onnx_path))

In [16]:
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

print('ONNX Runtime:')
for inp in [1.,2.,3.,4.]:
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(torch.tensor([[[inp]]]))}
    ort_outs = ort_session.run(None, ort_inputs)
    print(ort_outs[0])

ONNX Runtime:
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 2.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 3.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 4.]]]


In [2]:
import torch
import torch.nn as nn
import onnx
import onnxruntime as rt
from onnx import helper, shape_inference

print(f'torch version: {torch.__version__}')
print(f'onnx version: {onnx.__version__}')
print(f'onnxruntime version: {rt.__version__}')

class InputCache(nn.Module):
    def __init__(self, tensor):
        super(InputCache, self).__init__()
        self.register_buffer('cache', tensor)
    
    def forward(self, x):
        # self.cache[:,:,:-1] = self.cache.clone()[:,:,1:]
        # self.cache[:,:,-1:] = x.detach()
        cache_update = torch.cat((self.cache[:,:,1:], x.detach()), dim=2)
        self.cache = cache_update
        return self.cache

tc = InputCache(torch.zeros(1,1,10))
tc_script = torch.jit.script(InputCache(torch.zeros(1,1,10)))

# Make sure torch code is doing what it should
print('Original:')
for inp in [1.,2.,3.,4.]:
    print(tc(torch.tensor([[[inp]]])))

print()
print('Convert using script:')
for inp in [1.,2.,3.,4.]:
    print(tc_script(torch.tensor([[[inp]]])))

# Reinitialize
tc = InputCache(torch.zeros(1,1,10))
tc_script = torch.jit.script(InputCache(torch.zeros(1,1,10)))

ex_out = tc(torch.tensor([[[0.5]]]))

print()
torch.onnx.export(
    tc_script,
    torch.tensor([[[0.0]]]),
    f'onnxrt_test.onnx',
    export_params=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=True,
    opset_version=12,
    input_names = ['input'],
    output_names = ['output'],
    example_outputs=ex_out,
    verbose=True
)

# Use onnx tools to make sure model is valid
onnx_model = onnx.load('onnxrt_test.onnx')
onnx.checker.check_model(onnx_model)
inferred_model = shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(inferred_model)

# Run inference session
onnx_path = 'onnxrt_test.onnx'
ort_session = rt.InferenceSession(str(onnx_path))

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

print('ONNX Runtime:')
for inp in [1.,2.,3.,4.]:
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(torch.tensor([[[inp]]]))}
    ort_outs = ort_session.run(None, ort_inputs)
    print(ort_outs[0])

torch version: 1.9.0
onnx version: 1.9.0
onnxruntime version: 1.7.2
Original:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

Convert using script:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

graph(%input : Float(1, 1, 1, strides=[1, 1, 1], requires_grad=0, device=cpu),
      %8 : Float(1, 1, 9, strides=[10, 10, 1], requires_grad=0, device=cpu)):
  %output : Float(1, 1, 10, strides=[10, 10, 1], requires_grad=0, device=cpu) = onnx::Concat[axis=2](%8, %input)
  return (%output)

ONNX Runtime:
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 2.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 3.]]]
[[[0. 0. 0. 0. 0. 0. 0. 0. 0. 4.]]]
