In [1]:
import torch

## Parameters

NOTE: This only works for odd kernel sizes AND even strides (due to the iterative requirement).

In [2]:
# Kernel width (in samples) and upsampling factor used by the first part of the notebook.
kernel_size = 5
stride = 2

# The construction in this part is only stated to work for odd kernel sizes and
# even strides, so fail fast instead of silently producing mismatching outputs.
if kernel_size % 2 != 1 or stride % 2 != 0:
    raise ValueError(
        f"expected an odd kernel_size and an even stride, got kernel_size={kernel_size}, stride={stride}"
    )

## Setup Causal and Non-Causal Transposed Convolutions 


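Padding is computed from the kernel size ("same"-style padding, `(kernel_size - 1) // 2`) so that the output length does not depend on the kernel size: for an odd kernel it is `(input_length - 1) * stride + 1`.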

In [3]:
# Symmetric padding: half of (kernel_size - 1), rounded down. Passed as `padding`
# to the non-causal ConvTranspose1d.
padding = (kernel_size - 1) // 2

Set up the non-causal transposed convolution with all-ones weights and a zero bias.

In [4]:
# Non-causal reference layer: one input channel, one output channel, symmetric padding.
tconv1d = torch.nn.ConvTranspose1d(1, 1, kernel_size, stride=stride, padding=padding)

# Replace the random initialisation with all-ones weights and a zero bias so the
# outputs are plain sums of input samples and easy to check by hand.
# The parameters are modified in place under no_grad instead of assigning to
# `.data`, which keeps their shape/dtype and does not bypass autograd bookkeeping.
with torch.no_grad():
    tconv1d.weight.fill_(1.0)
    tconv1d.bias.zero_()



In [5]:

# Causal variant: same kernel and stride as the non-causal layer, but with
# padding of kernel_size - 1 (twice the symmetric padding when kernel_size is odd).
causal_padding = kernel_size - 1
tcausalconv1d = torch.nn.ConvTranspose1d(1, 1, kernel_size, stride=stride, padding=causal_padding)

# All-ones weights and zero bias, set in place under no_grad rather than by
# assigning to `.data`.
with torch.no_grad():
    tcausalconv1d.weight.fill_(1.0)
    tcausalconv1d.bias.zero_()

# Number of zero / past input samples concatenated around the input in the cells
# that follow; equals (kernel_size - 1) // stride for odd kernel sizes.
buffer_size = (2 * padding) // stride

Test Input

In [6]:
# Fixed toy input with three dimensions of sizes (1, 1, 4). Gradient tracking is
# switched on, which is why the displayed results carry a grad_fn.
# To try a random input instead, use: torch.randn(1, 1, 4, requires_grad=True)
x = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]]).requires_grad_(True)

## Non-Causal Transposed Convolution

In [7]:
# Extend the input on the right with `buffer_size` zeros, run the non-causal
# layer, and display only the first len(x) * stride output samples.
right_zeros = torch.zeros(1, 1, buffer_size)
y = tconv1d(torch.cat((x, right_zeros), dim=-1))
n_out = stride * x.shape[-1]
y[:, :, :n_out]

tensor([[[3., 3., 6., 5., 9., 7., 7., 4.]]], grad_fn=<SliceBackward0>)

## Causal Transposed Convolution

In [8]:
# Surround the whole sequence with `buffer_size` zeros on both sides, run the
# causal layer once, and keep the first len(x) * stride output samples.
buffer = torch.zeros(1, 1, buffer_size)
x_in = torch.cat((buffer, x, buffer), dim=-1)
n_out = stride * x.shape[-1]
y = tcausalconv1d(x_in)[:, :, :n_out]
y

tensor([[[1., 1., 3., 3., 6., 5., 9., 7.]]], grad_fn=<SliceBackward0>)

## Iterated Causal Transposed Convolution

In [9]:
# Streaming state: the last `buffer_size` input samples (all zeros before the first step).
buffer = torch.zeros((1, 1, buffer_size))
# The trailing zeros are the same on every step, so they are allocated once.
tail_zeros = torch.zeros((1, 1, buffer_size))

y_list = []

for t in range(x.shape[2]):
    # Current input sample, sliced so that it keeps all three dimensions.
    x_t = x[:, :, t:t+1]
    x_in = torch.cat([buffer, x_t, tail_zeros], dim=-1)
    # Each step contributes the first `stride` output samples.
    y = tcausalconv1d(x_in)[:, :, :stride]
    # Slide the history window: drop the oldest sample, append the current one.
    buffer = torch.cat([buffer[:, :, 1:], x_t], dim=-1)
    print(y)
    y_list.append(y)

y_iter = torch.cat(y_list, dim=-1)

# Safety net: the step-by-step result must match the causal layer applied to the
# whole zero-surrounded sequence at once.
y_ref = tcausalconv1d(torch.cat([tail_zeros, x, tail_zeros], dim=-1))[:, :, :x.shape[2] * stride]
assert torch.allclose(y_iter, y_ref), "iterated output differs from the batch causal output"

y_iter


tensor([[[1., 1.]]], grad_fn=<SliceBackward0>)
tensor([[[3., 3.]]], grad_fn=<SliceBackward0>)
tensor([[[6., 5.]]], grad_fn=<SliceBackward0>)
tensor([[[9., 7.]]], grad_fn=<SliceBackward0>)


tensor([[[1., 1., 3., 3., 6., 5., 9., 7.]]], grad_fn=<CatBackward0>)

# Alternative implementation for the special case kernel_size = 2 * stride (stride >= 2)

This case leads to well-behaved upsampling properties (e.g. it avoids grid-like artifacts), which are often needed in generative models.

In [10]:
# Parameters for the special case: the kernel is exactly twice as wide as the stride.
# NOTE: these overwrite the values used by the first part of the notebook.
kernel_size = 8
stride = 4

# This part is only meant for kernel_size == 2 * stride with stride >= 2; fail
# fast on anything else.
if stride < 2 or kernel_size != 2 * stride:
    raise ValueError(
        f"expected kernel_size == 2 * stride with stride >= 2, got kernel_size={kernel_size}, stride={stride}"
    )

In [11]:
# Non-causal reference layer for the kernel_size = 2 * stride case. The padding
# expression evaluates to 1 whenever kernel_size == 2 * stride.
tconv1d = torch.nn.ConvTranspose1d(1, 1, kernel_size, stride=stride, padding=(kernel_size // stride) - 1)

# All-ones weights and zero bias, set in place under no_grad rather than by
# assigning to `.data`.
with torch.no_grad():
    tconv1d.weight.fill_(1.0)
    tconv1d.bias.zero_()

In [12]:
# Causal layer for the kernel_size = 2 * stride case: no padding inside the layer;
# the surplus output samples are sliced off after each call instead.
tcausalconv1d = torch.nn.ConvTranspose1d(1, 1, kernel_size, stride=stride, padding=0)

# All-ones weights and zero bias, set in place under no_grad rather than by
# assigning to `.data`.
with torch.no_grad():
    tcausalconv1d.weight.fill_(1.0)
    tcausalconv1d.bias.zero_()

# Number of past input samples kept as context; 1 when kernel_size == 2 * stride.
buffer_size = (kernel_size // stride) - 1
buffer_size

1

In [13]:
y = tconv1d(x)

# Trim the output down to exactly x.shape[-1] * stride samples.
# Neighbouring kernels overlap by kernel_size - stride samples, which have to be
# removed in total (split between the two ends). The layer's own padding has
# already removed `tconv1d.padding[0]` samples per side, so only the remainder is
# cut here. The previous fixed crop of (kernel_size // stride) - 1 gave the right
# length only for stride == 4; for kernel_size = 8, stride = 4 this is identical.
overlap = kernel_size - stride
crop_left = overlap // 2 - tconv1d.padding[0]
crop_right = (overlap - overlap // 2) - tconv1d.padding[0]
# A non-negative stop index is used so that a crop of 0 does not yield an empty slice.
y = y[:, :, crop_left:y.shape[-1] - crop_right]
y

tensor([[[1., 1., 3., 3., 3., 3., 5., 5., 5., 5., 7., 7., 7., 7., 4., 4.]]],
       grad_fn=<SliceBackward0>)

In [14]:
# Prepend `buffer_size` zeros as history, run the causal layer over the whole
# sequence, then drop `stride` samples from each end of the output.
buffer = torch.zeros(1, 1, buffer_size)
x_in = torch.cat((buffer, x), dim=-1)
y = tcausalconv1d(x_in)[:, :, stride:-stride]
y

tensor([[[1., 1., 1., 1., 3., 3., 3., 3., 5., 5., 5., 5., 7., 7., 7., 7.]]],
       grad_fn=<SliceBackward0>)

In [15]:
# Streaming state: the last `buffer_size` input samples (all zeros before the first step).
buffer = torch.zeros((1, 1, buffer_size))

y_list = []

for t in range(x.shape[2]):
    # Current input sample, sliced so that it keeps all three dimensions.
    x_t = x[:, :, t:t+1]
    x_in = torch.cat([buffer, x_t], dim=-1)
    # Drop `stride` samples from each end of the step's output.
    y = tcausalconv1d(x_in)[:, :, stride:-stride]
    # Slide the history window: drop the oldest sample, append the current one.
    buffer = torch.cat([buffer[:, :, 1:], x_t], dim=-1)
    print(y)
    y_list.append(y)

y_iter = torch.cat(y_list, dim=-1)

# Safety net: the step-by-step result must match the causal layer applied to the
# whole zero-prefixed sequence at once.
y_ref = tcausalconv1d(torch.cat([torch.zeros((1, 1, buffer_size)), x], dim=-1))[:, :, stride:-stride]
assert torch.allclose(y_iter, y_ref), "iterated output differs from the batch causal output"

y_iter

tensor([[[1., 1., 1., 1.]]], grad_fn=<SliceBackward0>)
tensor([[[3., 3., 3., 3.]]], grad_fn=<SliceBackward0>)
tensor([[[5., 5., 5., 5.]]], grad_fn=<SliceBackward0>)
tensor([[[7., 7., 7., 7.]]], grad_fn=<SliceBackward0>)


tensor([[[1., 1., 1., 1., 3., 3., 3., 3., 5., 5., 5., 5., 7., 7., 7., 7.]]],
       grad_fn=<CatBackward0>)