## Attempt at a decomposition of bilinear upsampling with antialiasing (`antialias=True`)

In [1]:
import torch

In [2]:
# Problem sizes: an (n, c, in_h, in_w) input is resized to (out_h, out_w).
n, c = 5, 3
in_h, in_w = 234, 345
out_h, out_w = 123, 234

align_corners = False

# Deterministic ramp input: element k (in row-major order) holds the value k.
input_tensor = torch.arange(n * c * in_h * in_w, dtype=torch.float32).view(n, c, in_h, in_w)

In [3]:
# UpSample.h
# template <typename scalar_t>
# static inline scalar_t area_pixel_compute_scale(
#     int64_t input_size,
#     int64_t output_size,
#     bool align_corners,
#     const c10::optional<double> scale) {
#   // see Note [area_pixel_compute_scale]
#   if(align_corners) {
#     if(output_size > 1) {
#       return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
#     } else {
#       return static_cast<scalar_t>(0);
#     }
#   } else {
#     return compute_scales_value<scalar_t>(scale, input_size, output_size);
#   }
# }
#
# Same as compute_scale in decompositions.py::upsample_bicubic2d_default
#
def _area_pixel_compute_scale(in_size, out_size, align_corners, scale=None):
    if align_corners:
        return (in_size - 1) / (out_size - 1) if out_size > 1 else 0
    else:
        return 1 / scale if scale is not None and scale > 0 else in_size / out_size        

In [4]:
def aa_linear_filter(x):
    """Elementwise triangle (tent) filter for antialiased bilinear weights.

    Evaluates to ``1 - |x|`` where ``|x| < 1`` and to ``0`` elsewhere.
    """
    return 1.0 - x.abs().clamp(max=1.0)


def _compute_indices_weights_aa(out_size, in_size, scale, interp_size, device):
    scale = _area_pixel_compute_scale(in_size, out_size, align_corners, scale=scale)
        
    support = torch.tensor((interp_size * 0.5) * scale if scale >= 1.0 else interp_size * 0.5, device=device)
    max_interp_size = torch.ceil(support).to(torch.long) * 2 + 1

    i = torch.arange(out_size, dtype=torch.long, device=device)

    center = scale * (i + 0.5)
    invscale = 1.0 / scale if scale >= 1.0 else 1.0

    # compute source indices as [xmin, xmin+1, ..., xmin+xsize-1]
    xmin = torch.clamp((center - support + 0.5).to(torch.long), min=0)
    xsize = torch.clamp((center + support + 0.5).to(torch.long), max=in_size) - xmin
    xsize = torch.clamp(xsize, 0, max_interp_size)
    
    # compute weights
    j = torch.arange(max_interp_size, dtype=torch.long, device=device).view(-1, 1)
    # TODO: use a generic function aa_filter defined for bilinear and bicubic
    weights = aa_linear_filter((j + xmin - center + 0.5) * invscale)
    weights = torch.where(j < xsize, weights, 0.0)
    total_weights = weights.sum(dim=0)
    weights = weights / total_weights

    return xmin, xsize, weights

In [5]:
from functools import partial, reduce
from typing import Callable, cast, Iterable, List, Optional, Tuple, Union
from torch import sym_float, sym_int, Tensor


def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
    return reduce(torch.add, ts)


In [9]:
def _separable_upsample_bilinear2d_aa_single_dim(in_tensor, out_size, interp_dim, align_corners, scale=None):
    # Assume that in_tensor dtype is float32
    
    assert interp_dim % 4 in (2, 3)
    
    n, c, in_h, in_w = in_tensor.shape
    interp_size = 2  # bilinear
    in_size = in_tensor.shape[interp_dim]
        
    n_idx = torch.arange(n, device=in_tensor.device).view(n, 1, 1, 1)
    c_idx = torch.arange(c, device=in_tensor.device).view(1, c, 1, 1)
    
    if interp_dim % 4 == 3:
        # horizontal pass
        xmin, xsize, weights = _compute_indices_weights_aa(out_size, in_size, scale, interp_size, device=in_tensor.device)
        in_y = torch.arange(in_h, device=in_tensor.device).view((1, 1, in_h, 1))
        xmin_idx = xmin.view(1, 1, 1, out_size)
        
        max_interp_size = len(weights)
        in_tensor_list = [in_tensor[n_idx, c_idx, in_y, torch.clamp(xmin_idx + k, max=in_w - 1)] for k in range(max_interp_size)]
        w_tensor_list = weights.unbind(dim=0)
        return _sum_tensors(in_t * w_t for in_t, w_t in zip(in_tensor_list, w_tensor_list))        
    else:
        # vertical pass
        ymin, ysize, weights = _compute_indices_weights_aa(out_size, in_size, scale, interp_size, device=in_tensor.device)

        ymin_idx = ymin.view(1, 1, out_size, 1)
        in_x = torch.arange(in_w, device=in_tensor.device).view((1, 1, 1, in_w))

        max_interp_size = len(weights)
        in_tensor_list = [in_tensor[n_idx, c_idx, torch.clamp(ymin_idx + k, max=in_h - 1), in_x] for k in range(max_interp_size)]
        w_tensor_list = weights.unsqueeze(-1).unbind(dim=0)
        return _sum_tensors(in_t * w_t for in_t, w_t in zip(in_tensor_list, w_tensor_list))        
    

In [10]:
def upsample_bilinear2d(
    input: Tensor,
    output_size: List[int],
    align_corners: bool,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> Tensor:
    """Antialiased bilinear resize of the last two dims to ``output_size = (height, width)``.

    Runs as two separable passes, width first and then height. A pass is
    skipped when that dimension of ``input`` already has the requested size.
    """
    target_w = output_size[1]
    target_h = output_size[0]

    result = input
    if target_w != input.shape[-1]:
        result = _separable_upsample_bilinear2d_aa_single_dim(
            result, target_w, -1, align_corners=align_corners, scale=scales_w
        )

    # The width pass leaves the height untouched, so compare against the input's height.
    if target_h != input.shape[-2]:
        result = _separable_upsample_bilinear2d_aa_single_dim(
            result, target_h, -2, align_corners=align_corners, scale=scales_h
        )

    return result

In [11]:
# Run the decomposition end to end: a width pass followed by a height pass.
output = upsample_bilinear2d(input_tensor, (out_h, out_w), align_corners=align_corners)

In [12]:
# Should be (n, c, out_h, out_w).
output.shape

torch.Size([5, 3, 123, 234])

In [13]:
# Reference result: PyTorch's own antialiased bilinear interpolation.
from torch.nn import functional as F

expected = F.interpolate(input_tensor, size=(out_h, out_w), mode="bilinear", align_corners=align_corners, antialias=True)

# Show the shape and the first rows of channel 0 from both results for a visual check.
print(expected.shape)
print(expected[0, 0, :5, :])
print(output[0, 0, :5, :])

# Raises if the decomposition differs from F.interpolate beyond assert_close's default tolerances.
torch.testing.assert_close(expected, output)

torch.Size([5, 3, 123, 234])
tensor([[ 225.4809,  226.8456,  228.2820,  ...,  565.9496,  567.3861,
          568.7507],
        [ 806.9192,  808.2839,  809.7203,  ..., 1147.3879, 1148.8245,
         1150.1890],
        [1459.6218, 1460.9866, 1462.4229,  ..., 1800.0905, 1801.5269,
         1802.8915],
        [2112.3245, 2113.6890, 2115.1255,  ..., 2452.7932, 2454.2295,
         2455.5938],
        [2771.9033, 2773.2686, 2774.7046,  ..., 3112.3726, 3113.8088,
         3115.1733]])
tensor([[ 225.4810,  226.8455,  228.2820,  ...,  565.9496,  567.3861,
          568.7506],
        [ 806.9191,  808.2837,  809.7202,  ..., 1147.3878, 1148.8243,
         1150.1890],
        [1459.6216, 1460.9863, 1462.4229,  ..., 1800.0906, 1801.5266,
         1802.8912],
        [2112.3242, 2113.6890, 2115.1255,  ..., 2452.7935, 2454.2292,
         2455.5938],
        [2771.9033, 2773.2681, 2774.7048,  ..., 3112.3721, 3113.8088,
         3115.1733]])


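### Horizontal pass development

Note: the outputs recorded in this section were captured in an earlier run with smaller sizes (`in_h=24`, `in_w=25`, `out_w=4`), not with the sizes set in cell [2].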

In [89]:
# Input/output sampling ratio along the width (in_w / out_w when align_corners=False).
scale = _area_pixel_compute_scale(in_w, out_w, align_corners)
interp_size = 2  # bilinear

scale

6.25

In [156]:
# Filter half-width in input pixels, widened by `scale` when downsampling (scale >= 1).
support = torch.tensor((interp_size * 0.5) * scale if scale >= 1.0 else interp_size * 0.5, device=input_tensor.device)
# Upper bound on the number of input taps per output pixel.
max_interp_size = torch.ceil(support).to(torch.long) * 2 + 1

support, max_interp_size

(tensor(6.2500), tensor(15))

In [157]:
# Output pixel indices (float here; the packaged helper builds them as int64).
i = torch.arange(out_w, dtype=input_tensor.dtype, device=input_tensor.device)

In [158]:
# Center of each output pixel in input coordinates; invscale rescales filter arguments when downsampling.
center = scale * (i + 0.5)
invscale = 1.0 / scale if scale >= 1.0 else 1.0
center, invscale

(tensor([ 3.1250,  9.3750, 15.6250, 21.8750]), 0.16)

In [184]:
# xmin = std::max(
#         static_cast<int64_t>(center - support + 0.5 + align_corners_delta), static_cast<int64_t>(0));
# xsize = std::min(
#         static_cast<int64_t>(center + support + 0.5 + align_corners_delta), input_size) - xmin;
# NOTE(review): the Python port below drops `align_corners_delta`; presumably it is 0 on
# the antialiased path -- TODO confirm against the ATen source.

# First input index and number of taps per output pixel, clamped to the input
# range and to max_interp_size.
xmin = torch.clamp((center - support + 0.5).to(torch.long), min=0)
xsize = torch.clamp((center + support + 0.5).to(torch.long), max=in_w) - xmin
xsize = torch.clamp(xsize, 0, max_interp_size)


In [185]:
# Inspect the per-output source windows.
xmin, xsize

(tensor([ 0,  3,  9, 16]), tensor([ 9, 13, 13,  9]))

In [186]:
# template<typename scalar_t>
# static inline scalar_t aa_filter(scalar_t x) {
#   x = std::abs(x);
#   if (x < 1.0) {
#     return 1.0 - x;
#   }
#   return 0.0;
# }

# def aa_filter(x):
#     x = torch.abs(x)
#     return torch.where(x < 1, 1.0 - x, 0.0)

def aa_filter(x):
    """Tensor form of the tent filter: ``1 - |x|`` inside ``(-1, 1)``, ``0`` outside."""
    return 1.0 - torch.abs(x).clamp_max(1.0)

In [187]:
def aa_filter_scalar(x):
    """Scalar tent filter: ``1 - |x|`` when ``|x| < 1``, otherwise ``0.0``."""
    magnitude = abs(x)
    return 1.0 - magnitude if magnitude < 1.0 else 0.0


tuple(aa_filter_scalar(v) for v in (-1.1, -0.7, 0.0, 0.7, 1.1, 1.0))


(0.0, 0.30000000000000004, 1.0, 0.30000000000000004, 0.0, 0.0)

In [188]:
# Check the tensor filter against the scalar version on the same inputs.
a = torch.tensor([-1.1, -0.7, 0.0, 0.7, 1.1, 1.0])

aa_filter(a)

tensor([0.0000, 0.3000, 1.0000, 0.3000, 0.0000, 0.0000])

In [189]:
# Number of valid taps per output pixel.
xsize

tensor([ 9, 13, 13,  9])

In [201]:
# Tap offsets 0..max_interp_size-1 as a column, so they broadcast against the per-output values.
j = torch.arange(max_interp_size, dtype=input_tensor.dtype, device=input_tensor.device).view(-1, 1)

In [202]:
# Raw filter weight of input index xmin + j for each output; taps with j >= xsize are zeroed.
weights = aa_filter((j + xmin - center + 0.5) * invscale)
weights = torch.where(j < xsize, weights, 0.0)

In [203]:
# Shape (max_interp_size, out_w): one row per tap, one column per output pixel.
weights.shape, weights

(torch.Size([15, 4]),
 tensor([[0.5800, 0.0600, 0.0200, 0.1400],
         [0.7400, 0.2200, 0.1800, 0.3000],
         [0.9000, 0.3800, 0.3400, 0.4600],
         [0.9400, 0.5400, 0.5000, 0.6200],
         [0.7800, 0.7000, 0.6600, 0.7800],
         [0.6200, 0.8600, 0.8200, 0.9400],
         [0.4600, 0.9800, 0.9800, 0.9000],
         [0.3000, 0.8200, 0.8600, 0.7400],
         [0.1400, 0.6600, 0.7000, 0.5800],
         [0.0000, 0.5000, 0.5400, 0.0000],
         [0.0000, 0.3400, 0.3800, 0.0000],
         [0.0000, 0.1800, 0.2200, 0.0000],
         [0.0000, 0.0200, 0.0600, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000]]))

In [204]:
# Per-output (column) sum of the raw weights.
# NOTE(review): no zero guard here; a column with no valid taps would turn into NaN below.
total_weights = weights.sum(dim=0)
total_weights

tensor([5.4600, 6.2600, 6.2600, 5.4600])

In [205]:
# Normalize so each output's weights sum to 1.
weights = weights / total_weights
weights

tensor([[0.1062, 0.0096, 0.0032, 0.0256],
        [0.1355, 0.0351, 0.0288, 0.0549],
        [0.1648, 0.0607, 0.0543, 0.0842],
        [0.1722, 0.0863, 0.0799, 0.1136],
        [0.1429, 0.1118, 0.1054, 0.1429],
        [0.1136, 0.1374, 0.1310, 0.1722],
        [0.0842, 0.1565, 0.1565, 0.1648],
        [0.0549, 0.1310, 0.1374, 0.1355],
        [0.0256, 0.1054, 0.1118, 0.1062],
        [0.0000, 0.0799, 0.0863, 0.0000],
        [0.0000, 0.0543, 0.0607, 0.0000],
        [0.0000, 0.0288, 0.0351, 0.0000],
        [0.0000, 0.0032, 0.0096, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000]])

In [195]:
xmin, xsize

(tensor([ 0,  3,  9, 16]), tensor([ 9, 13, 13,  9]))

In [196]:
# Broadcastable index grids for the batch, channel and row dimensions.
n_idx = torch.arange(n, device=input_tensor.device).view(n, 1, 1, 1)
c_idx = torch.arange(c, device=input_tensor.device).view(1, c, 1, 1)
in_y = torch.arange(in_h, device=input_tensor.device).view((1, 1, in_h, 1))

In [197]:
# Starting source column of each output pixel, laid out along the last (width) dim.
xmin_idx = xmin.view(1, 1, 1, out_w)

In [199]:
# One gathered tensor per tap k: column xmin + k of every row. The index is clamped to the
# last column; a clamped tap always has k >= xsize, so its weight is zero.
in_tensors = [input_tensor[n_idx, c_idx, in_y, torch.clamp(xmin_idx + k, max=in_w - 1)] for k in range(max_interp_size.item())]

In [209]:
# Per-tap weight rows, each of shape (out_w,), aligned with in_tensors.
w_tensors = weights.unbind(dim=0)

In [210]:
from functools import partial, reduce
from typing import Callable, cast, Iterable, List, Optional, Tuple, Union
from torch import sym_float, sym_int, Tensor


def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
    return reduce(torch.add, ts)


In [214]:
# Weighted sum over taps gives the horizontally resized tensor.
output = _sum_tensors(in_t * w_t for (in_t, w_t) in zip(in_tensors, w_tensors))

In [215]:
# Should be (n, c, in_h, out_w): only the width has been resized so far.
output.shape

torch.Size([5, 3, 24, 4])

In [223]:
from torch.nn import functional as F

# `output` above comes from the horizontal pass only, so its shape is
# (n, c, in_h, out_w). The reference must therefore be resized along the width
# only; resizing it to (out_h, out_w) makes assert_close fail on the shape
# mismatch whenever out_h != in_h (as with the sizes set in cell [2]).
expected = F.interpolate(input_tensor, size=(in_h, out_w), mode="bilinear", align_corners=align_corners, antialias=True)

# Show the shape and the first rows of channel 0 from both results for a visual check.
print(expected.shape)
print(expected[0, 0, :5, :])
print(output[0, 0, :5, :])

torch.testing.assert_close(expected, output)