In [1]:
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

In [2]:
class TensorCache(nn.Module):
    def __init__(self, tensor):
        super(TensorCache, self).__init__()
        self.register_buffer('cache', tensor)
    
    def forward(self, x):
        # assert x.size() == self.cache[:,:,0:1].size()
        cache_update = torch.cat((self.cache[:,:,1:], x.detach()), dim=2)
        self.cache = cache_update
        return self.cache

In [3]:
tc = TensorCache(torch.zeros(1,1,10))
tc_script = torch.jit.script(tc)
tc_trace = torch.jit.trace(tc, torch.tensor([[[0]]]))


In [4]:
print('Original:')
for inp in [1,2,3,4]:
    print(tc(torch.tensor([[[inp]]])))

print()
print('Convert using script:')
for inp in [1,2,3,4]:
    print(tc_script(torch.tensor([[[inp]]])))
    
print('')
print('Convert using trace:')
for inp in [1,2,3,4]:
    print(tc_trace(torch.tensor([[[inp]]])))

Original:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

Convert using script:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 2., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 1., 2., 3., 4.]]])

Convert using trace:
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 2.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 3.]]])
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 4.]]])


In [19]:
class TemporalInferenceBlock(nn.Module):
    def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, batch_size=1):
        super(TemporalInferenceBlock, self).__init__()
        self.in_ch, self.k, self.d = n_inputs, kernel_size, dilation

        self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size,
                                           stride=stride, padding=0, dilation=dilation))
        self.relu1 = nn.ReLU()

        self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size,
                                           stride=stride, padding=0, dilation=dilation))
        self.relu2 = nn.ReLU()

        self.batch_size = batch_size

        self.cache1 = TensorCache(torch.zeros(
            batch_size, 
            self.conv1.in_channels, 
            (self.conv1.kernel_size[0]-1)*self.conv1.dilation[0] + 1
            ))
        
        self.cache2 = TensorCache(torch.zeros(
            batch_size, 
            self.conv2.in_channels, 
            (self.conv2.kernel_size[0]-1)*self.conv2.dilation[0] + 1
            ))
        
        self.stage1 = nn.Sequential(self.conv1, self.relu1)
        self.stage2 = nn.Sequential(self.conv2, self.relu2)


        self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else nn.Identity()
        self.relu = nn.ReLU()
        self.init_weights()
    
    def init_weights(self):
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if isinstance(self.downsample, nn.modules.conv.Conv1d):
            self.downsample.weight.data.normal_(0, 0.01)

    def reset_cache(self):
        device = next(self.parameters()).device
        self.cache1.cache = torch.zeros(self.cache1.cache.size()).to(device)
        self.cache2.cache = torch.zeros(self.cache2.cache.size()).to(device)
    
    def forward(self, x):
        '''
        x is of shape (B, CH, 1)
        '''
        # out = self.stage1(self.cache1(x)[:x.size()[0], :, :])
        # out = self.stage2(self.cache2(out)[:x.size()[0], :, :])
        out1 = self.stage1(self.cache1(x))
        out2 = self.stage2(self.cache2(out1))

        res = self.downsample(x)
        out = self.relu(out2 + res)
        # print(f'\t res shape: {res.size()}')
        #         print(f'x: {x} \n c1: {self.cache1.cache} \n out1: {out1} \n c2: {self.cache2.cache} \n out2: {out2} \n \n')

        return x, self.cache1.cache, out1, self.cache2.cache, out2, res, out

RuntimeError: 
Tried to access nonexistent attribute or method 'stage1' of type 'Tensor (inferred)'.:
  File "<ipython-input-19-c5fa523e36a1>", line 54
        # out = self.stage1(self.cache1(x)[:x.size()[0], :, :])
        # out = self.stage2(self.cache2(out)[:x.size()[0], :, :])
        out1 = self.stage1(self.cache1(x))
               ~~~~~~~~~~~ <--- HERE
        out2 = self.stage2(self.cache2(out1))


In [15]:
tblock = TemporalInferenceBlock(1,1,3,1,1,1)
tblock_script = torch.jit.script(tblock)
tblock_trace = torch.jit.trace(tblock, torch.tensor([[[0]]]))

With rtol=1e-05 and atol=1e-05, found 2 element(s) (out of 3) whose difference(s) exceeded the margin of error (including 0 nan comparisons). The greatest difference was 0.5023746490478516 (0.0 vs. 0.5023746490478516), which occurred at index (0, 0, 0).
  _module_class,


In [16]:
for inp in [1,2,3,4]:
    x, c1, out1, c2, out2, res, out = tblock(torch.tensor([[[inp]]]))
    print(f'x: {x} \n c1: {c1} \n out1: {out1} \n c2: {c2} \n out2: {out2}, res: {res} \n out: {out} \n \n')

x: tensor([[[1]]]) 
 c1: tensor([[[0., 0., 1.]]]) 
 out1: tensor([[[0.8534]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.5024, 0.5024, 0.8534]]]) 
 out2: tensor([[[0.]]], grad_fn=<ReluBackward0>), res: tensor([[[1]]]) 
 out: tensor([[[1.]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[2]]]) 
 c1: tensor([[[0., 1., 2.]]]) 
 out1: tensor([[[1.4260]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.5024, 0.8534, 1.4260]]]) 
 out2: tensor([[[0.]]], grad_fn=<ReluBackward0>), res: tensor([[[2]]]) 
 out: tensor([[[2.]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[3]]]) 
 c1: tensor([[[1., 2., 3.]]]) 
 out1: tensor([[[2.0169]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.8534, 1.4260, 2.0169]]]) 
 out2: tensor([[[0.]]], grad_fn=<ReluBackward0>), res: tensor([[[3]]]) 
 out: tensor([[[3.]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[4]]]) 
 c1: tensor([[[2., 3., 4.]]]) 
 out1: tensor([[[2.6079]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[1.4260, 2.0169, 2.6079]]]) 
 out2: tensor([[[0.]]], grad_fn=<Rel

In [17]:
for inp in [1,2,3,4]:
    x, c1, out1, c2, out2, res, out = tblock_script(torch.tensor([[[inp]]]))
    print(f'x: {x} \n c1: {c1} \n out1: {out1} \n c2: {c2} \n out2: {out2}, res: {res} \n out: {out} \n \n')

x: tensor([[[1]]]) 
 c1: tensor([[[0., 0., 1.]]]) 
 out1: tensor([[[0.5109]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.0000, 0.0000, 0.5109]]]) 
 out2: tensor([[[0.1303]]], grad_fn=<ReluBackward0>), res: tensor([[[1]]]) 
 out: tensor([[[1.1303]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[2]]]) 
 c1: tensor([[[0., 1., 2.]]]) 
 out1: tensor([[[0.5139]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.0000, 0.5109, 0.5139]]]) 
 out2: tensor([[[0.1276]]], grad_fn=<DifferentiableGraphBackward>), res: tensor([[[2]]]) 
 out: tensor([[[2.1276]]], grad_fn=<DifferentiableGraphBackward>) 
 

x: tensor([[[3]]]) 
 c1: tensor([[[1., 2., 3.]]]) 
 out1: tensor([[[0.5055]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.5109, 0.5139, 0.5055]]]) 
 out2: tensor([[[0.1293]]], grad_fn=<DifferentiableGraphBackward>), res: tensor([[[3]]]) 
 out: tensor([[[3.1293]]], grad_fn=<DifferentiableGraphBackward>) 
 

x: tensor([[[4]]]) 
 c1: tensor([[[2., 3., 4.]]]) 
 out1: tensor([[[0.4972]]], grad_fn=<ReluBackward0>) 

In [18]:
for inp in [1,2,3,4]:
    x, c1, out1, c2, out2, res, out = tblock_trace(torch.tensor([[[inp]]]))
    print(f'x: {x} \n c1: {c1} \n out1: {out1} \n c2: {c2} \n out2: {out2}, res: {res} \n out: {out} \n \n')

x: tensor([[[1]]]) 
 c1: tensor([[[0., 0., 1.]]]) 
 out1: tensor([[[0.8534]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.0000, 0.0000, 0.8534]]]) 
 out2: tensor([[[0.]]], grad_fn=<ReluBackward0>), res: tensor([[[1]]]) 
 out: tensor([[[1.]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[2]]]) 
 c1: tensor([[[0., 0., 2.]]]) 
 out1: tensor([[[1.2045]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.0000, 0.0000, 1.2045]]]) 
 out2: tensor([[[0.]]], grad_fn=<ReluBackward0>), res: tensor([[[2]]]) 
 out: tensor([[[2.]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[3]]]) 
 c1: tensor([[[0., 0., 3.]]]) 
 out1: tensor([[[1.5556]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.0000, 0.0000, 1.5556]]]) 
 out2: tensor([[[0.]]], grad_fn=<ReluBackward0>), res: tensor([[[3]]]) 
 out: tensor([[[3.]]], grad_fn=<ReluBackward0>) 
 

x: tensor([[[4]]]) 
 c1: tensor([[[0., 0., 4.]]]) 
 out1: tensor([[[1.9066]]], grad_fn=<ReluBackward0>) 
 c2: tensor([[[0.0000, 0.0000, 1.9066]]]) 
 out2: tensor([[[0.]]], grad_fn=<Rel