In [1]:
import torch
from torch.nn.functional import conv1d
from torch.nn.functional import conv2d
from torch.nn.functional import conv3d
from torch.nn.functional import conv_transpose3d

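# Combining 1D Convolutions

Two successive `conv1d` passes can be merged into a single pass whose kernel is the full (length $2K-1$) convolution of the two filters. Note that `conv1d` actually computes a cross-correlation. So when the combined kernel is itself built with `conv1d`, the filter passed as the weight must be flipped for asymmetric filters. The all-ones filter used here is symmetric, so the flip makes no difference.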

In [2]:
# Seed the global RNG so the random input (and every output below) is
# reproducible under Restart & Run All.
torch.manual_seed(0)
# Random signal shaped (batch=1, channels=1, length=10), the layout conv1d expects.
x = torch.rand(10).unsqueeze(0).unsqueeze(0).float()
# 3-tap all-ones (box) filter shaped (out_channels=1, in_channels=1, width=3).
f = torch.tensor([1,1,1]).unsqueeze(0).unsqueeze(0).float()

In [3]:
# Apply the 3-tap filter twice in sequence (valid mode: length 10 -> 8 -> 6).
conv1d(input=conv1d(input=x, weight=f), weight=f)

tensor([[[4.6860, 5.4231, 5.3521, 3.8716, 2.3654, 2.1638]]])

In [4]:
# Build the combined kernel once, then filter x in a single pass.
# conv1d is a cross-correlation, so two chained passes equal one
# cross-correlation with the *true* convolution f * f. Flipping the weight
# makes the inner conv1d compute that convolution. The flip is a no-op for
# this symmetric box filter, but it keeps the identity valid for asymmetric
# filters. padding = K - 1 yields the "full" result of length 2K - 1.
conv1d(x, conv1d(f, f.flip(-1), padding=f.shape[-1] - 1))

tensor([[[4.6860, 5.4231, 5.3521, 3.8716, 2.3654, 2.1638]]])

In [5]:
# The single combined pass must match the two sequential passes. The weight
# is flipped so the combined kernel is the true convolution f * f (required
# for asymmetric filters; a no-op here).
torch.isclose(conv1d(conv1d(x, f), f),
              conv1d(x, conv1d(f, f.flip(-1), padding=f.shape[-1] - 1))).all()

tensor(True)

In [6]:
# Inspect the combined kernel: the full convolution of f with itself
# (weight flipped so this is a true convolution, not a cross-correlation).
conv1d(f, f.flip(-1), padding=f.shape[-1] - 1)

tensor([[[1., 2., 3., 2., 1.]]])

# Combining 2D Convolutions

In [7]:
# 2D case: a random 10x10 single-channel image and a 3x3 box filter, both float64.
x = torch.rand(1, 1, 10, 10).to(torch.float64)
f = torch.ones(1, 1, 3, 3, dtype=torch.float64)

In [8]:
# Two sequential 3x3 passes over the image (10x10 -> 8x8 -> 6x6).
conv2d(input=conv2d(input=x, weight=f), weight=f)

tensor([[[[47.7161, 45.6197, 41.2246, 39.8701, 39.9359, 42.3967],
          [46.2592, 45.2626, 39.2659, 36.7648, 35.2137, 37.0153],
          [45.5694, 45.4108, 39.5216, 34.7581, 31.5526, 33.3274],
          [45.8617, 47.7183, 43.8074, 39.3699, 33.7974, 34.8290],
          [48.3470, 52.2104, 50.7168, 46.0268, 37.6444, 35.5286],
          [50.3107, 53.9540, 53.8947, 49.7427, 39.4608, 33.1865]]]],
       dtype=torch.float64)

In [9]:
# Single pass with the combined kernel. Flip f over both spatial axes so the
# inner conv2d computes a true convolution (a no-op for this symmetric
# filter). Derive the "full" padding K - 1 per axis instead of hard-coding 2.
conv2d(x, conv2d(f, f.flip((-2, -1)), padding=(f.shape[-2] - 1, f.shape[-1] - 1)))

tensor([[[[47.7161, 45.6197, 41.2246, 39.8701, 39.9359, 42.3967],
          [46.2592, 45.2626, 39.2659, 36.7648, 35.2137, 37.0153],
          [45.5694, 45.4108, 39.5216, 34.7581, 31.5526, 33.3274],
          [45.8617, 47.7183, 43.8074, 39.3699, 33.7974, 34.8290],
          [48.3470, 52.2104, 50.7168, 46.0268, 37.6444, 35.5286],
          [50.3107, 53.9540, 53.8947, 49.7427, 39.4608, 33.1865]]]],
       dtype=torch.float64)

In [10]:
# The single combined pass must match the two sequential passes. Flip and
# per-axis K - 1 padding make the identity hold for any filter shape/values.
torch.isclose(conv2d(conv2d(x, f), f),
              conv2d(x, conv2d(f, f.flip((-2, -1)),
                               padding=(f.shape[-2] - 1, f.shape[-1] - 1)))).all()

tensor(True)

In [11]:
# Inspect the combined kernel: the full 2D convolution of f with itself
# (flipped weight, per-axis K - 1 padding so non-square filters also work).
conv2d(f, f.flip((-2, -1)), padding=(f.shape[-2] - 1, f.shape[-1] - 1))

tensor([[[[1., 2., 3., 2., 1.],
          [2., 4., 6., 4., 2.],
          [3., 6., 9., 6., 3.],
          [2., 4., 6., 4., 2.],
          [1., 2., 3., 2., 1.]]]], dtype=torch.float64)

# Combining 3D Convolutions

In [12]:
# 3D case: a random 6x6x6 single-channel volume and a 3x3x3 box filter, float32.
x = torch.rand(1, 1, 6, 6, 6).to(torch.float32)
f = torch.ones(1, 1, 3, 3, 3, dtype=torch.float32)

In [13]:
# Two sequential 3x3x3 passes over the volume (6^3 -> 4^3 -> 2^3).
conv3d(input=conv3d(input=x, weight=f), weight=f)

tensor([[[[[322.2564, 337.0458],
           [319.8718, 341.2558]],

          [[300.5420, 317.6799],
           [303.3722, 323.1547]]]]])

In [14]:
# Single pass with the combined kernel: flip f over all three spatial axes so
# the inner conv3d is a true convolution (a no-op for this symmetric filter),
# and pad by K - 1 per axis so non-cubic filters also get the full kernel.
conv3d(x, conv3d(f, f.flip((-3, -2, -1)), padding=tuple(k - 1 for k in f.shape[2:])))

tensor([[[[[322.2565, 337.0458],
           [319.8718, 341.2557]],

          [[300.5420, 317.6798],
           [303.3721, 323.1546]]]]])

In [15]:
# The single combined pass must match the two sequential passes
# (flipped weight + per-axis K - 1 padding keep this valid for any filter).
torch.isclose(conv3d(conv3d(x, f), f),
              conv3d(x, conv3d(f, f.flip((-3, -2, -1)),
                               padding=tuple(k - 1 for k in f.shape[2:])))).all()

tensor(True)

In [16]:
# Shape of the combined kernel: 2K - 1 along each spatial axis.
conv3d(f, f.flip((-3, -2, -1)), padding=tuple(k - 1 for k in f.shape[2:])).shape

torch.Size([1, 1, 5, 5, 5])

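# Combining 3D Transposed Convolution ("Deconv") and Convolution

A `conv_transpose3d` followed by a `conv3d` with the same filters can be merged into a single `conv_transpose3d`. Its kernel is the full cross-correlation of each filter pair, `conv3d(f, f, padding=K-1)`. No flip is needed in this case, because the transposed convolution already performs a true convolution. The random, asymmetric filters below confirm this.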

In [17]:
# Two-channel 4x4x4 input plus two random 2x2x2 filters. As a conv_transpose3d
# weight, f maps 2 input channels -> 1; as a conv3d weight, 1 -> 2 channels.
x = torch.rand(1, 2, 4, 4, 4).to(torch.float32)
f = torch.rand(2, 1, 2, 2, 2).to(torch.float32)

In [18]:
# Transposed conv (2 ch, 4^3 -> 1 ch, 5^3) followed by conv (1 ch, 5^3 -> 2 ch, 4^3).
conv3d(input=conv_transpose3d(input=x, weight=f), weight=f)

tensor([[[[[ 6.7947,  9.3210,  8.7730,  5.7996],
           [ 6.6157, 10.5132, 10.2960,  5.9712],
           [ 8.3911, 10.6076,  8.9293,  5.5777],
           [ 8.6277, 10.9423,  8.8907,  5.8671]],

          [[ 7.6218,  9.6929,  9.1714,  6.3557],
           [11.0009, 14.5086, 12.7852,  8.4129],
           [13.5989, 17.7788, 16.5410, 10.8110],
           [10.7948, 14.7808, 14.5563,  9.8212]],

          [[ 9.7662, 10.2473,  8.3275,  5.9850],
           [13.6592, 16.9809, 16.4182, 11.1959],
           [13.9259, 18.3572, 19.3461, 13.1436],
           [ 9.0436, 12.7738, 14.1969,  9.3066]],

          [[ 8.0103,  8.8138,  7.8530,  5.8490],
           [10.3978, 12.7006, 12.1097,  8.2535],
           [ 8.9190, 12.1912, 13.5279,  8.0708],
           [ 4.9424,  7.1872,  9.5022,  6.9289]]],


         [[[ 4.9157,  8.5273,  8.5266,  6.6892],
           [ 4.6624,  9.7062, 10.6973,  7.6970],
           [ 5.9475,  9.8349, 10.3466,  6.8353],
           [ 5.4772,  9.6031,  8.9292,  6.4148]],

        

In [19]:
# Merge both filters into one kernel of shape (2, 2, 2K-1, 2K-1, 2K-1): the full
# cross-correlation of every filter pair. The "full" padding K - 1 is derived
# per spatial axis from f instead of being hard-coded. No flip is needed:
# conv_transpose3d already convolves, so (x conv f_c) xcorr f_o == x conv (f_c xcorr f_o).
full_pad = tuple(k - 1 for k in f.shape[2:])
combined_f = conv3d(f, f, padding=full_pad)
# The same K - 1 padding crops the transposed-conv output back to x's size.
conv_transpose3d(x, combined_f, padding=full_pad)

tensor([[[[[ 6.7947,  9.3210,  8.7730,  5.7996],
           [ 6.6157, 10.5132, 10.2960,  5.9712],
           [ 8.3911, 10.6076,  8.9293,  5.5777],
           [ 8.6277, 10.9423,  8.8907,  5.8671]],

          [[ 7.6218,  9.6929,  9.1714,  6.3557],
           [11.0009, 14.5086, 12.7852,  8.4129],
           [13.5989, 17.7789, 16.5410, 10.8110],
           [10.7948, 14.7808, 14.5563,  9.8212]],

          [[ 9.7662, 10.2473,  8.3276,  5.9850],
           [13.6592, 16.9809, 16.4182, 11.1959],
           [13.9259, 18.3572, 19.3461, 13.1436],
           [ 9.0436, 12.7737, 14.1969,  9.3066]],

          [[ 8.0103,  8.8138,  7.8530,  5.8490],
           [10.3978, 12.7006, 12.1097,  8.2535],
           [ 8.9190, 12.1912, 13.5279,  8.0708],
           [ 4.9424,  7.1872,  9.5022,  6.9289]]],


         [[[ 4.9157,  8.5273,  8.5266,  6.6892],
           [ 4.6624,  9.7062, 10.6973,  7.6970],
           [ 5.9475,  9.8349, 10.3466,  6.8353],
           [ 5.4772,  9.6031,  8.9292,  6.4148]],

        

In [20]:
# The single combined transposed conv must match the two-step version.
# Padding is K - 1 per axis (derived from f, not hard-coded).
torch.isclose(conv3d(conv_transpose3d(x, f), f),
              conv_transpose3d(x, combined_f,
                               padding=tuple(k - 1 for k in f.shape[2:]))).all()

tensor(True)

# Timing combined 3D conv and deconv

In [22]:
# Timing inputs: a random 20^3 single-channel volume and one random 8^3 filter.
x = torch.rand(1, 1, 20, 20, 20).to(torch.float32)
f = torch.rand(1, 1, 8, 8, 8).to(torch.float32)

In [23]:
# Baseline timing: transposed conv (20^3 -> 27^3) then conv (27^3 -> 20^3),
# 3 loops per run, 10 runs.
%timeit -n3 -r10 conv3d(conv_transpose3d(x, f), f)

176 ms ± 45.1 ms per loop (mean ± std. dev. of 10 runs, 3 loops each)


In [24]:
# Combined kernel of size 2K - 1 per spatial axis (15^3 for this cubic f).
# Padding is derived per axis so a non-cubic f would also get the correct
# full-correlation size instead of reusing one axis's length for all three.
combined_f = conv3d(f, f, padding=tuple(k - 1 for k in f.shape[2:]))

In [25]:
# Single combined transposed conv. Padding K - 1 (= 7 here) crops the output
# back to the 20^3 input size. f is cubic, so using f.shape[3] for every axis
# is correct in this cell.
%timeit -n3 -r10 conv_transpose3d(x, combined_f, padding=f.shape[3]-1)

99.9 ms ± 2.6 ms per loop (mean ± std. dev. of 10 runs, 3 loops each)


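# Timing Combined 3D Conv and Deconv with Many Filters

With 10 filters, the combined kernel has shape `(10, 10, ...)`: it couples every output channel to every input channel. In the runs recorded below, the single combined `conv_transpose3d` (about 11 s) is therefore far slower than the two-step version (about 0.4 s). Merging the kernels helped with one filter but hurts badly here.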

In [26]:
# A shallow (depth 4), wide 400x700 volume and 10 random 4x8x8 filters.
# `act` is the 10-channel activation those filters produce from x.
x = torch.rand(1, 1, 4, 400, 700).to(torch.float32)
f = torch.rand(10, 1, 4, 8, 8).to(torch.float32)
act = conv3d(input=x, weight=f)

In [27]:
# Activation shape: (batch, 10 filters, depth, height, width).
act.size()

torch.Size([1, 10, 1, 393, 693])

In [28]:
# Two-step baseline: transposed conv from the 10-channel act back to a
# (1, 1, 4, 400, 700) volume, then conv back to 10 channels.
%timeit -n3 -r10 conv3d(conv_transpose3d(act, f), f)

407 ms ± 75.9 ms per loop (mean ± std. dev. of 10 runs, 3 loops each)


In [29]:
# The two-step round trip returns to act's shape.
conv3d(input=conv_transpose3d(input=act, weight=f), weight=f).shape

torch.Size([1, 10, 1, 393, 693])

In [30]:
# Combined kernel for every (filter, filter) pair: shape (10, 10, ...).
# NOTE(review): padding=f.shape[3]-1 (= 7) is applied to all three spatial
# axes, but f's depth is only 4. The kernel therefore comes out 15 deep instead
# of the 7 a full correlation needs, and depth slices beyond offset +/-3 have
# no overlap, so they are all zero. Per-axis padding,
# tuple(k - 1 for k in f.shape[2:]), would give a (10, 10, 7, 15, 15) kernel
# and likely a cheaper transposed conv. The padding passed to
# conv_transpose3d in the following cells would then have to match.
combined_f = conv3d(f, f, padding=f.shape[3]-1)
combined_f.shape

torch.Size([10, 10, 15, 15, 15])

In [31]:
%timeit -n3 -r10 conv_transpose3d(act, combined_f, padding=f.shape[3]-1)

11.2 s ± 535 ms per loop (mean ± std. dev. of 10 runs, 3 loops each)
Compiler time: 0.40 s


In [32]:
# The combined pass restores act's shape. Padding is derived from the combined
# kernel's own size (2p + 1 per axis), so it stays consistent with combined_f.
conv_transpose3d(act, combined_f,
                 padding=tuple((k - 1) // 2 for k in combined_f.shape[2:])).shape

torch.Size([1, 10, 1, 393, 693])

In [33]:
# The combined single pass must match the two-step version. Padding is derived
# from combined_f's own size (2p + 1 per axis) so it matches the kernel built.
torch.isclose(conv3d(conv_transpose3d(act, f), f),
              conv_transpose3d(act, combined_f,
                               padding=tuple((k - 1) // 2 for k in combined_f.shape[2:]))).all()

tensor(True)