In [1]:
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

In [2]:
class TensorCache(nn.Module):
    def __init__(self, tensor):
        super(TensorCache, self).__init__()
        self.register_buffer('cache', tensor)
    
    def forward(self, x):
        # assert x.size() == self.cache[:,:,0:1].size()
        cache_update = torch.cat((self.cache[:,:,1:], x.detach()), dim=2)
        self.cache = cache_update
        return self.cache

In [3]:
tc = TensorCache(torch.zeros(1,1,10))
tc_script = torch.jit.script(tc)
tc_trace = torch.jit.trace(tc, torch.tensor([[[0]]]))


In [4]:
print('Original:')
for inp in [1,2,3,4]:
    print(tc(torch.tensor([[[inp]]])))

print()
print('Convert using script:')
for inp in [1,2,3,4]:
    print(tc_script(torch.tensor([[[inp]]])))
    
print('')
print('Convert using trace:')
for inp in [1,2,3,4]:
    print(tc_trace(torch.tensor([[[inp]]])))

Original:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

Convert using script:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

Convert using trace:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 4.]]])


In [20]:
class TemporalInferenceBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, batch_size=1):
        super(TemporalInferenceBlock, self).__init__()
        self.in_ch, self.k, self.d = n_inputs, kernel_size, dilation

        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=0, dilation=dilation))
        self.relu1 = nn.ReLU()

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=0, dilation=dilation))
        self.relu2 = nn.ReLU()

        self.batch_size = batch_size

        self.cache1 = torch.jit.script(TensorCache(torch.zeros(
            batch_size, 
            self.conv1.in_channels, 
            (self.conv1.kernel_size[0]-1)*self.conv1.dilation[0] + 1
            )))
        
        self.cache2 = torch.jit.script(TensorCache(torch.zeros(
            batch_size, 
            self.conv2.in_channels, 
            (self.conv2.kernel_size[0]-1)*self.conv2.dilation[0] + 1
            )))
        
        self.stage1 = nn.Sequential(self.conv1, self.relu1)
        self.stage2 = nn.Sequential(self.conv2, self.relu2)


        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else nn.Identity()
        self.relu = nn.ReLU()
        self.init_weights()
    
    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if isinstance(self.downsample, nn.modules.conv.Conv1d):
            self.downsample.weight.data.normal_(0, 0.01)

    def reset_cache(self):
        device = next(self.parameters()).device
        self.cache1.cache = torch.zeros(self.cache1.cache.size()).to(device)
        self.cache2.cache = torch.zeros(self.cache2.cache.size()).to(device)
    
    def forward(self, x):
        '''
        x is of shape (B, CH, 1)
        '''
        # out = self.stage1(self.cache1(x)[:x.size()[0], :, :])
        # out = self.stage2(self.cache2(out)[:x.size()[0], :, :])
        out1 = self.stage1(self.cache1(x))
        out2 = self.stage2(self.cache2(out1))

        res = self.downsample(x)
        out = self.relu(out2 + res)
        # print(f'\t res shape: {res.size()}')
        #         print(f'x: {x} \n c1: {self.cache1.cache} \n out1: {out1} \n c2: {self.cache2.cache} \n out2: {out2} \n \n')

        return x, self.cache1.cache, out1, self.cache2.cache, out2, res, out

In [21]:
tblock = TemporalInferenceBlock(1,1,3,1,1,1)
tblock_script = torch.jit.script(tblock)
tblock_trace = torch.jit.trace(tblock, torch.tensor([[[0]]]))

In [22]:
for inp in [1,2,3,4]:
    x, c1, out1, c2, out2, res, out = tblock(torch.tensor([[[inp]]]))
    print(f'x: {x} \n c1: {c1} \n out1: {out1} \n c2: {c2} \n out2: {out2}, res: {res} \n out: {out} \n \n')

x: tensor([[[1]]]) 
 c1: tensor([[[0., 0., 1.]]]) 
 out1: tensor([[[0.5241]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.2088, 0.2088, 0.5241]]]) 
 out2: tensor([[[0.0385]]], grad_fn=<ReluBackward0>), res: tensor([[[1]]]) 
 out: tensor([[[1.0385]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[2]]]) 
 c1: tensor([[[0., 1., 2.]]]) 
 out1: tensor([[[1.0486]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.2088, 0.5241, 1.0486]]]) 
 out2: tensor([[[0.1257]]], grad_fn=<ReluBackward0>), res: tensor([[[2]]]) 
 out: tensor([[[2.1257]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[3]]]) 
 c1: tensor([[[1., 2., 3.]]]) 
 out1: tensor([[[1.6069]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.5241, 1.0486, 1.6069]]]) 
 out2: tensor([[[0.2229]]], grad_fn=<ReluBackward0>), res: tensor([[[3]]]) 
 out: tensor([[[3.2229]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[4]]]) 
 c1: tensor([[[2., 3., 4.]]]) 
 out1: tensor([[[2.1652]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[1.0486, 1.6069, 2.1652]]]) 
 out2: tenso

In [23]:
for inp in [1,2,3,4]:
    x, c1, out1, c2, out2, res, out = tblock_script(torch.tensor([[[inp]]]))
    print(f'x: {x} \n c1: {c1} \n out1: {out1} \n c2: {c2} \n out2: {out2}, res: {res} \n out: {out} \n \n')

x: tensor([[[1]]]) 
 c1: tensor([[[3., 4., 1.]]]) 
 out1: tensor([[[0.2188]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[1.6069, 2.1652, 0.2188]]]) 
 out2: tensor([[[0.]]], grad_fn=<ReluBackward0>), res: tensor([[[1]]]) 
 out: tensor([[[1.]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[2]]]) 
 c1: tensor([[[4., 1., 2.]]]) 
 out1: tensor([[[0.2243]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[2.1652, 0.2188, 0.2243]]]) 
 out2: tensor([[[0.]]], grad_fn=<DifferentiableGraphBackward>), res: tensor([[[2]]]) 
 out: tensor([[[2.]]], grad_fn=<DifferentiableGraphBackward>) 
 

x: tensor([[[3]]]) 
 c1: tensor([[[1., 2., 3.]]]) 
 out1: tensor([[[0.2335]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.2188, 0.2243, 0.2335]]]) 
 out2: tensor([[[0.]]], grad_fn=<DifferentiableGraphBackward>), res: tensor([[[3]]]) 
 out: tensor([[[3.]]], grad_fn=<DifferentiableGraphBackward>) 
 

x: tensor([[[4]]]) 
 c1: tensor([[[2., 3., 4.]]]) 
 out1: tensor([[[0.2418]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.2243, 

In [24]:
for inp in [1,2,3,4]:
    x, c1, out1, c2, out2, res, out = tblock_trace(torch.tensor([[[inp]]]))
    print(f'x: {x} \n c1: {c1} \n out1: {out1} \n c2: {c2} \n out2: {out2}, res: {res} \n out: {out} \n \n')

x: tensor([[[1]]]) 
 c1: tensor([[[3., 4., 1.]]]) 
 out1: tensor([[[1.4625]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.2335, 0.2418, 1.4625]]]) 
 out2: tensor([[[0.3777]]], grad_fn=<ReluBackward0>), res: tensor([[[1]]]) 
 out: tensor([[[1.3777]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[2]]]) 
 c1: tensor([[[4., 1., 2.]]]) 
 out1: tensor([[[1.1836]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.2418, 1.4625, 1.1836]]]) 
 out2: tensor([[[0.]]], grad_fn=<ReluBackward0>), res: tensor([[[2]]]) 
 out: tensor([[[2.]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[3]]]) 
 c1: tensor([[[1., 2., 3.]]]) 
 out1: tensor([[[1.6069]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[1.4625, 1.1836, 1.6069]]]) 
 out2: tensor([[[0.3796]]], grad_fn=<ReluBackward0>), res: tensor([[[3]]]) 
 out: tensor([[[3.3796]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[4]]]) 
 c1: tensor([[[2., 3., 4.]]]) 
 out1: tensor([[[2.1652]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[1.1836, 1.6069, 2.1652]]]) 
 out2: tensor([[[0.3