# Sub-Pixel Convolution

An implementation of the Sub-Pixel Convolution operation.

Reference Links:
- https://arxiv.org/pdf/1609.05158.pdf

In [1]:
import torch
import numpy as np

In [2]:
DEVICE = "cuda"

In [3]:
class SubPixelConv(torch.nn.Module):
    def __init__(self, in_chans, r):
        super().__init__()
        self.in_chans = in_chans
        self.r = r

        # divide channels out
        self.chans_per_subsection  = in_chans // r**2
        ch_bounds = list(range(0, in_chans+1, self.chans_per_subsection))
        self.channel_divisions = list(zip(ch_bounds[:-1], ch_bounds[1:]))

        # build "upsampling" filters to expand that'll be used
        # to expand single input features into a rxr space
        filters = []
        for i in range(r**2):
            exp_filter = np.zeros((1, 1, r, r), np.float32)
            exp_filter[..., i // r, i % r] = 1.
            filters.append(exp_filter)
        
        self.expanding_transposed_convs = []
        for f in filters:
            exp_conv = torch.nn.ConvTranspose2d(1, 1, r, r, bias=False)
            exp_conv.weight = torch.nn.Parameter(
                torch.from_numpy(f),
                requires_grad=False
            )
            self.expanding_transposed_convs.append(exp_conv)
        self.expanding_transposed_convs = torch.nn.ModuleList(self.expanding_transposed_convs)

    def forward_feature_map(self, inputs):
        if inputs.shape[1] != self.in_chans:
            raise ValueError("input channels does not expected `in_chans`")

        # for each subset of channels, use expanding tranposed conv
        # and track all subset in list
        by_channel_division = []
        for exp, div in zip(self.expanding_transposed_convs, self.channel_divisions):
            by_channel_division.append(
                torch.cat([exp(inputs[:,[c],...]) for c in range(div[0], div[1])], dim=1)
            )
        
        # stack the mutually exclusive expanded by-channel subsets
        # then sum them up (this gathers the sub-divided up)
        return torch.stack([t for t in by_channel_division]).sum(dim=0) 

    def forward(self, inputs):
        if inputs.ndim == 3:
            # could throw error here instead :shrug:
            return self.forward_feature_map(inputs.unsqueeze(0))
        elif inputs.ndim == 4:
            return self.forward_feature_map(inputs)

In [4]:
r = 2
in_chans = 100
spc = SubPixelConv(in_chans, r)
inputs = torch.ones(1, in_chans, 256, 256)

In [5]:
spc.to(DEVICE)
inputs = inputs.to(DEVICE)

In [6]:
%%timeit
_ = spc.forward(inputs)

7.39 ms ± 23.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
%%timeit
_ = torch.nn.functional.interpolate(
    inputs,
    size = (inputs.shape[-2]*r, inputs.shape[-1]*r),
    mode = "bilinear",
    align_corners=False,
)

432 µs ± 31 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
