In [1]:
import torch
import numpy as np

In [2]:
class SubPixelConv(torch.nn.Module):
    """Sub-pixel rearrangement ("periodic shuffle") of a feature map.

    Turns a tensor with ``in_channels = r**2 * C`` channels into one with ``C``
    channels and ``r`` times the height and width. The input channels are split
    into ``r**2`` contiguous chunks (``channel_divisions``); chunk ``k`` supplies
    the sub-pixel at row ``k // r``, column ``k % r`` of every ``r x r`` output
    block::

        out[..., c, h*r + a, w*r + b] = in[..., (a*r + b)*C + c, h, w]

    Note: this groups channels by sub-pixel position, which differs from the
    ordering used by ``torch.nn.PixelShuffle`` (``in[c*r**2 + a*r + b]``).
    The module has no learnable parameters.

    Args:
        in_channels: number of input channels; must be a positive multiple of ``r**2``.
        r: integer upscaling factor, ``r >= 1``.

    Raises:
        ValueError: if ``r < 1`` or ``in_channels`` is not a positive multiple of ``r**2``.
    """

    def __init__(self, in_channels, r):
        super().__init__()
        if r < 1:
            raise ValueError(f"r must be >= 1, got {r}")
        if in_channels <= 0 or in_channels % r**2 != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be a positive multiple of r**2 ({r**2})"
            )
        self.in_channels = in_channels
        self.r = r
        self.out_channels = in_channels // r**2
        # (start, stop) channel range feeding each of the r*r sub-pixel
        # positions, listed in row-major order of the r x r output block.
        ch_bounds = list(range(0, in_channels + 1, self.out_channels))
        self.channel_divisions = list(zip(ch_bounds[:-1], ch_bounds[1:]))

    def _shuffle(self, x):
        # Core rearrangement: (N, in_channels, H, W) -> (N, out_channels, H*r, W*r).
        if x.shape[1] != self.in_channels:
            raise ValueError(
                f"expected {self.in_channels} channels, got {x.shape[1]}"
            )
        n, _, h, w = x.shape
        r, c = self.r, self.out_channels
        # Split the channel axis as (a, b, c) so channel == (a*r + b)*C + c,
        # i.e. chunk (a*r + b) of channel_divisions holds sub-pixel (a, b).
        x = x.reshape(n, r, r, c, h, w)
        # Reorder to (N, C, H, a, W, b) so merged rows are h*r + a, cols w*r + b.
        x = x.permute(0, 3, 4, 1, 5, 2)
        return x.reshape(n, c, h * r, w * r)

    def forward_feature_map(self, inputs):
        """Shuffle a single feature map of shape (in_channels, H, W).

        Returns:
            Tensor of shape (in_channels // r**2, H*r, W*r).

        Raises:
            ValueError: if ``inputs`` is not 3-D or has the wrong channel count.
        """
        if inputs.ndim != 3:
            raise ValueError(f"expected a 3-D (C, H, W) tensor, got {inputs.ndim}-D")
        return self._shuffle(inputs.unsqueeze(0)).squeeze(0)

    def forward(self, inputs):
        """Shuffle a (C, H, W) feature map or an (N, C, H, W) batch.

        Raises:
            ValueError: for any other rank, or a wrong channel count.
        """
        if inputs.ndim == 3:
            return self.forward_feature_map(inputs)
        if inputs.ndim == 4:
            # Whole batch at once; equivalent to shuffling each sample separately.
            return self._shuffle(inputs)
        raise ValueError(f"expected a 3-D or 4-D tensor, got {inputs.ndim}-D")

In [3]:
# Toy configuration: upscaling factor r and input channel count.
# in_chans is chosen as a multiple of r**2 (18 = 2 * 3**2).
r, in_chans = 3, 18

In [4]:
# Each of the r*r sub-pixel positions is fed by one contiguous chunk of
# `chans_per_subsection` input channels. Fail loudly on a bad channel count:
# otherwise the number of chunks would not equal the number of convs, and the
# later zip() would silently drop channels.
assert in_chans > 0 and in_chans % r**2 == 0, "in_chans must be a positive multiple of r**2"
chans_per_subsection = in_chans // r**2
ch_bounds = list(range(0, in_chans + 1, chans_per_subsection))
channel_divisions = list(zip(ch_bounds[:-1], ch_bounds[1:]))

# One-hot r x r kernels: filter i has a single 1 at (row i // r, col i % r).
filters = []
for i in range(r**2):
    one_hot = np.zeros((1, 1, r, r), np.float32)
    one_hot[..., i // r, i % r] = 1.
    filters.append(one_hot)

# ConvTranspose2d(in=1, out=1, kernel=r, stride=r) with a one-hot kernel:
# kernel == stride, so output blocks do not overlap. Each input pixel is
# copied to one fixed position of its r x r output block; the rest is zero.
expanding_convs = []
for f in filters:
    conv = torch.nn.ConvTranspose2d(1, 1, r, r, bias=False)
    conv.weight = torch.nn.Parameter(
        torch.from_numpy(f),
        requires_grad=False
    )
    expanding_convs.append(conv)

In [5]:
# Toy batch of shape (1, in_chans, 2, 2): channel k is zero except for the
# value k + 1 at pixel (0, 0), so every output value identifies its source channel.
inputs = torch.from_numpy(
    np.stack([np.array([[1, 0], [0, 0]], np.float32) * (k + 1) for k in range(in_chans)])
)[None, ...]

In [6]:
# For sub-pixel position k, upsample every channel of division k with the k-th
# one-hot transposed conv (one channel at a time, since each conv is 1 -> 1).
# Each entry has shape (N, chans_per_subsection, H*r, W*r).
subset = [
    torch.cat([conv(inputs[:, [ch], ...]) for ch in range(lo, hi)], dim=1)
    for conv, (lo, hi) in zip(expanding_convs, channel_divisions)
]

In [7]:
# Sum the r*r sparse maps: the one-hot kernels sit at distinct positions of each
# r x r block, so every output pixel is written by exactly one map (the others
# are zero there), and the sum interleaves the sub-pixels.
upsampled = torch.stack(subset).sum(dim=0)

# Cross-check against the direct reshape/permute form of the same shuffle:
# out[n, c, h*r + a, w*r + b] == inputs[n, (a*r + b)*C_out + c, h, w].
n_batch, n_chans, height, width = inputs.shape
c_out = n_chans // r**2
reference = (
    inputs.reshape(n_batch, r, r, c_out, height, width)
    .permute(0, 3, 4, 1, 5, 2)
    .reshape(n_batch, c_out, height * r, width * r)
)
assert torch.allclose(upsampled, reference), "transposed-conv result != periodic shuffle"

upsampled

tensor([[[[ 1.,  3.,  5.,  0.,  0.,  0.],
          [ 7.,  9., 11.,  0.,  0.,  0.],
          [13., 15., 17.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.]],

         [[ 2.,  4.,  6.,  0.,  0.,  0.],
          [ 8., 10., 12.,  0.,  0.,  0.],
          [14., 16., 18.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.,  0.]]]])