In [73]:
import torch
import torch.nn as nn
from torch2trt import torch2trt
from torch.nn.utils import weight_norm

In [65]:
class TensorCache(nn.Module):
    def __init__(self, tensor):
        super(TensorCache, self).__init__()
        # self.register_buffer('cache', tensor)
        self.cache = tensor
    
    def forward(self, x):
        # assert x.size() == self.cache[:,:,0:1].size()
        cache_update = torch.cat((self.cache[:,:,1:], x.detach()), dim=2)
        self.cache[:,:,:] = cache_update
        return self.cache

In [66]:
tc = TensorCache(torch.zeros(1,1,10))
tc_script = torch.jit.script(tc)
# tc_trace = torch.jit.trace(tc, torch.tensor([[[0.0]]))

In [67]:
print('Original:')
for inp in [1.,2.,3.,4.]:
    print(tc(torch.tensor([[[inp]]])))

print()
print('Convert using script:')
for inp in [1.,2.,3.,4.]:
    print(tc_script(torch.tensor([[[inp]]])))


# print('Original:')
# for inp in [1.,2.,3.,4.]:
#     print(tc(torch.tensor([[[inp]]])))

# print()
# print('Convert using script:')
# for inp in [1.,2.,3.,4.]:
#     print(tc_script(torch.tensor([[[inp]]])))
    
# print('')
# print('Convert using trace:')
# for inp in [1.,2.,3.,4.]:
#     print(tc_trace(torch.tensor([[[inp]]])))

Original:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

Convert using script:
tensor([[[0., 0., 0., 0., 0., 1., 2., 3., 4., 1.]]])
tensor([[[0., 0., 0., 0., 1., 2., 3., 4., 1., 2.]]])
tensor([[[0., 0., 0., 1., 2., 3., 4., 1., 2., 3.]]])
tensor([[[0., 0., 1., 2., 3., 4., 1., 2., 3., 4.]]])


In [68]:
ex_in = torch.tensor([[[0.0]]])
ex_in.size()
tc_trt = torch2trt(tc, ex_in)

RuntimeError: torch.cat(): Tensors must have same number of dimensions: got 3 and 2

In [33]:
t = torch.zeros(10)
torch.cat((t[1:], torch.tensor([1])))

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])

In [76]:
class TemporalInferenceBlock(nn.Module):
# class TemporalInferenceBlock(torch.jit.ScriptModule):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, batch_size=1):
        super(TemporalInferenceBlock, self).__init__()
        self.in_ch, self.k, self.d = n_inputs, kernel_size, dilation

        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=0, dilation=dilation))
        self.relu1 = nn.ReLU()

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=0, dilation=dilation))
        self.relu2 = nn.ReLU()

        self.batch_size = batch_size

#         self.cache1 = torch.jit.script(TensorCache(torch.zeros(
#             batch_size, 
#             self.conv1.in_channels, 
#             (self.conv1.kernel_size[0]-1)*self.conv1.dilation[0] + 1
#             )))
        
#         self.cache2 = torch.jit.script(TensorCache(torch.zeros(
#             batch_size, 
#             self.conv2.in_channels, 
#             (self.conv2.kernel_size[0]-1)*self.conv2.dilation[0] + 1
#             )))

        self.cache1 = torch.zeros(
            batch_size, 
            self.conv1.in_channels, 
            (self.conv1.kernel_size[0]-1)*self.conv1.dilation[0] + 1
            )
            
        self.cache2 = torch.zeros(
            batch_size, 
            self.conv2.in_channels, 
            (self.conv2.kernel_size[0]-1)*self.conv2.dilation[0] + 1
            )
        
        self.stage1 = nn.Sequential(self.conv1, self.relu1)
        self.stage2 = nn.Sequential(self.conv2, self.relu2)


        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else nn.Identity()
        self.relu = nn.ReLU()
        self.init_weights()
    
    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if isinstance(self.downsample, nn.modules.conv.Conv1d):
            self.downsample.weight.data.normal_(0, 0.01)

    def reset_cache(self):
        self.cache1.zero_cache()
        self.cache2.zero_cache()
    
#     def forward(self, x):
#         '''
#         x is of shape (B, CH, 1)
#         '''
#         # out = self.stage1(self.cache1(x)[:x.size()[0], :, :])
#         # out = self.stage2(self.cache2(out)[:x.size()[0], :, :])
#         # out1 = self.stage1(self.cache1(x))
#         # out2 = self.stage2(self.cache2(out1))
        
#         out1 = self.relu1(self.conv1(self.cache1(x)))
#         out2 = self.relu2(self.conv2(self.cache2(out1)))
#         # self.cache1.zero_cache()
#         # out1 = self.cache1()
#         # self.cache2.zero_cache()
#         # out2 = self.cache2()

#         res = self.downsample(x)
#         out = self.relu(out2 + res)
#         # print(f'\t res shape: {res.size()}')
#         #         print(f'x: {x} \n c1: {self.cache1.cache} \n out1: {out1} \n c2: {self.cache2.cache} \n out2: {out2} \n \n')

#         return x, self.cache1.cache, out1, self.cache2.cache, out2, res, out
    
#     @torch.jit.script_method
    def forward(self, x):
        '''
        x is of shape (B, CH, 1)
        '''
        # out = self.stage1(self.cache1(x)[:x.size()[0], :, :])
        # out = self.stage2(self.cache2(out)[:x.size()[0], :, :])
        cache_update = torch.cat((self.cache1[:,:,1:], x.detach()), dim=2)
        self.cache1[:,:,:] = cache_update
        out1 = self.stage1(self.cache1)
        
        cache_update = torch.cat((self.cache2[:,:,1:], out1), dim=2)
        self.cache2[:,:,:] = cache_update
        
        out2 = self.stage2(self.cache2)

        res = self.downsample(x)
        out = self.relu(out2 + res)
        # print(f'\t res shape: {res.size()}')
        #         print(f'x: {x} \n c1: {self.cache1.cache} \n out1: {out1} \n c2: {self.cache2.cache} \n out2: {out2} \n \n')

        return x, self.cache1, out1, self.cache2, out2, res, out

In [75]:
tblock = TemporalInferenceBlock(1,1,7,1,1,1)
tblock.eval()
tblock.cuda()

tblock_script = torch.jit.script(tblock)
tblock_trace = torch.jit.trace(tblock, torch.tensor([[[0]]]).cuda())


IndentationError: unexpected indent (<unknown>, line 1)

In [71]:
print('Original:')
for inp in [1,2,3,4]:
    print(tc(torch.tensor([[[inp]]])))

print()
print('Convert using script:')
for inp in [1,2,3,4]:
    print(tc_script(torch.tensor([[[inp]]])))

Original:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

Convert using script:
tensor([[[0., 0., 0., 0., 0., 1., 2., 3., 4., 1.]]])
tensor([[[0., 0., 0., 0., 1., 2., 3., 4., 1., 2.]]])
tensor([[[0., 0., 0., 1., 2., 3., 4., 1., 2., 3.]]])
tensor([[[0., 0., 1., 2., 3., 4., 1., 2., 3., 4.]]])
