
# PrimTorch and broadcasting
Broadcasting allows, for example, elementwise binary operations on arrays of different sizes, such as adding a vector to each row of a matrix.
PyTorch eager mode supports implicit broadcasting and expansion of tensors in multi-tensor operations, mimicking NumPy's behavior.

PyTorch's docs: https://pytorch.org/docs/stable/notes/broadcasting.html

NumPy's docs: https://numpy.org/doc/stable/user/basics.broadcasting.html
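Both pages describe the same rule: shapes are aligned at their trailing dimensions, missing leading dimensions count as size 1, and each pair of sizes must either be equal or contain a 1, which is then stretched. A minimal pure-Python sketch of that rule on shape tuples (the helper name `broadcast_shapes` is our own here; PyTorch ships a real `torch.broadcast_shapes` for the same purpose):

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Compute the common broadcast shape of several shape tuples (NumPy rules)."""
    result = []
    # Walk dimensions from the right; missing leading dimensions count as size 1.
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {n for n in sizes if n != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        # Size-1 dimensions are stretched to the single non-1 size, if any.
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


print(broadcast_shapes((4, 1), (3,)))  # (4, 3)
print(broadcast_shapes((3, 3), (3,)))  # (3, 3)
```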

For PrimTorch, we would like to avoid a situation where accelerating backends have to reimplement broadcasting rules themselves.

Initially, `broadcast_in_dim` was added to PrimTorch as the general broadcast+expand operation following JAX's specification: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.broadcast_in_dim.html

One problem is that JAX's first backend, XLA, supports only static shapes, so XLA required the concrete output shape of an expanded tensor to be known:

In [None]:
# Let's print JAX graph with baked-in output shapes
import jax
from jax import make_jaxpr
import jax.numpy as jnp
a = jnp.ones((4, 1))
b = jnp.ones((3,))
print(make_jaxpr(lambda a, b: a + b)(a, b))

{ lambda ; a:f32[4,1] b:f32[3]. let
    c:f32[1,3] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 3)] b
    d:f32[4,3] = add a c
  in (d,) }
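The jaxpr shows why the output shape is an explicit parameter: `broadcast_in_dim` maps each operand dimension to an output dimension, and the sizes of any new output dimensions cannot be inferred from the operand alone. A pure-Python sketch of the shape check, paraphrasing the semantics on the JAX page linked above (the helper `check_broadcast_in_dim` is hypothetical and works on shape tuples, not arrays):

```python
def check_broadcast_in_dim(operand_shape, shape, broadcast_dimensions):
    """Validate a broadcast_in_dim call on shape tuples; return the output shape.

    Operand dim i is mapped to output dim broadcast_dimensions[i]; its size must
    either match that output size or be 1 (and then it is stretched).
    """
    if len(operand_shape) != len(broadcast_dimensions):
        raise ValueError("need one broadcast dimension per operand dimension")
    for i, out_dim in enumerate(broadcast_dimensions):
        if not 0 <= out_dim < len(shape):
            raise ValueError(f"broadcast dimension {out_dim} out of range")
        if operand_shape[i] not in (1, shape[out_dim]):
            raise ValueError(
                f"operand dim {i} (size {operand_shape[i]}) cannot map to "
                f"output dim {out_dim} (size {shape[out_dim]})"
            )
    # The output shape is an explicit argument, so it is recorded in the graph.
    return tuple(shape)


# As in the jaxpr: b with shape (3,) becomes (1, 3)
print(check_broadcast_in_dim((3,), (1, 3), (1,)))
# a with shape (4, 1) becomes (4, 3)
print(check_broadcast_in_dim((4, 1), (4, 3), (0, 1)))
```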


PrimTorch inherits the same problem: the concrete output shape is baked into the graph:

In [None]:
import torch
print(torch.__version__)
from torch.fx.experimental.proxy_tensor import make_fx
import torch._refs
a = torch.ones((4, 1))
b = torch.ones((3,))
print(make_fx(lambda a, b: torch._refs.add(a, b))(a, b).graph)

1.12.1+cu113
graph():
    %a_1 : [#users=1] = placeholder[target=a_1]
    %b_1 : [#users=1] = placeholder[target=b_1]
    %broadcast_in_dim : [#users=1] = call_function[target=torch.ops.prims.broadcast_in_dim](args = (%a_1, [4, 3], [0, 1]), kwargs = {})
    %broadcast_in_dim_1 : [#users=1] = call_function[target=torch.ops.prims.broadcast_in_dim](args = (%b_1, [4, 3], [1]), kwargs = {})
    %add : [#users=1] = call_function[target=torch.ops.prims.add](args = (%broadcast_in_dim, %broadcast_in_dim_1), kwargs = {})
    return add


When we add a (3, 3) matrix and a (3,) vector, broadcasting first inserts a new leading dimension into the vector (torch.unsqueeze) and then expands that size-1 dimension (torch.expand) to the larger size.

The vector's shape transformations: `(3,) -> unsqueeze -> (1, 3) -> expand -> (3, 3)`
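The same two steps can be traced on shape tuples alone. A pure-Python sketch (the helper `broadcast_steps` is hypothetical; it models the shapes that `torch.unsqueeze` and `Tensor.expand` would produce, not the tensors themselves):

```python
def broadcast_steps(shape, target):
    """Trace one operand's shape through unsqueeze, then expand, to `target`."""
    shape, target = tuple(shape), tuple(target)
    if len(target) < len(shape):
        raise ValueError(f"cannot broadcast {shape} to lower-rank {target}")
    steps = [shape]
    # unsqueeze: prepend size-1 dims until the ranks match
    padded = (1,) * (len(target) - len(shape)) + shape
    if padded != shape:
        steps.append(padded)
    # expand: only size-1 dims may be stretched; other sizes must already match
    for size, want in zip(padded, target):
        if size not in (1, want):
            raise ValueError(f"cannot expand {shape} to {target}")
    if target != padded:
        steps.append(target)
    return steps


print(broadcast_steps((3,), (3, 3)))    # [(3,), (1, 3), (3, 3)]
print(broadcast_steps((4, 1), (4, 3)))  # [(4, 1), (4, 3)]
```

The second call needs no unsqueeze because the ranks already match; only the expand step remains.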

In [None]:
a = torch.randn(4, 1, device='cuda')
b = torch.randn(3, device='cuda')

In [None]:
# Broadcast and expand is implicit in PyTorch Eager
a + b

tensor([[-0.8126,  0.9479, -0.5042],
        [-0.8294,  0.9311, -0.5210],
        [-1.9368, -0.1763, -1.6283],
        [ 0.3101,  2.0706,  0.6186]], device='cuda:0')

In [None]:
# PrimTorch doesn't allow non-matching dims (implicit broadcasting)
try:
    torch._prims.add(a, b)
except RuntimeError as e:
    print(e)

Shape torch.Size([3]) is not the expected shape torch.Size([4, 1])!


In [None]:
# PrimTorch also doesn't allow implicit expand
try:
    torch._prims.add(a, b.unsqueeze(0))
except RuntimeError as e:
    print(e)

Shape torch.Size([1, 3]) is not the expected shape torch.Size([4, 1])!


nvFuser supports implicit expand (called "stretching" on NumPy's broadcasting page). We would like to modify PrimTorch's broadcasting primitive to allow implicit expansion of tensors, since that keeps concrete output sizes out of the graph and so enables support for dynamic shapes.
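Under implicit expand, an elementwise op only needs its inputs to have the same rank; each output size is determined by the input sizes whenever the op runs, so no concrete sizes have to be stored as op parameters. A pure-Python sketch of that rule (the helper `implicit_expand_shape` is hypothetical and works on shape tuples):

```python
def implicit_expand_shape(a_shape, b_shape):
    """Output shape of an elementwise op whose inputs have equal rank.

    Size-1 dims are stretched implicitly, so the output shape is derived from
    the input shapes at run time rather than stored as a parameter of the op.
    """
    if len(a_shape) != len(b_shape):
        raise ValueError("implicit expand requires inputs of equal rank")
    out = []
    for x, y in zip(a_shape, b_shape):
        if x != y and 1 not in (x, y):
            raise ValueError(f"sizes {x} and {y} are incompatible")
        out.append(y if x == 1 else x)
    return tuple(out)


print(implicit_expand_shape((4, 1), (1, 3)))  # (4, 3)
print(implicit_expand_shape((7, 1), (1, 5)))  # (7, 5): same op, other sizes
```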

![NumPy broadcasting diagram (from NumPy's broadcasting page)](https://numpy.org/doc/stable/_images/broadcasting_4.png)

Here's an example showing that implicit expand is well supported in nvFuser:

In [None]:
from torch._C._nvfuser import Fusion, FusionDefinition

fusion1 = Fusion()

with FusionDefinition(fusion1) as fd:
    t0 = fd.define_tensor(sizes=a.shape, strides=a.stride())
    t1 = fd.define_tensor(sizes=b.unsqueeze(0).shape, strides=b.unsqueeze(0).stride())

    fd.add_input(t0)
    fd.add_input(t1)

    t2 = fd.Ops.add(t0, t1)

    fd.add_output(t2)

# This doesn't print to the cell output
# Check Runtime > View Runtime Logs
fusion1.print_ir()

# Execute Fusion
print(fusion1.execute([a, b.unsqueeze(0)]))

[tensor([[-0.8126,  0.9479, -0.5042],
        [-0.8294,  0.9311, -0.5210],
        [-1.9368, -0.1763, -1.6283],
        [ 0.3101,  2.0706,  0.6186]], device='cuda:0')]


Now let's make one small modification to PrimTorch to allow implicit expand:

In [None]:
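def is_same_shape(a, b, allow_one_size=True) -> bool:
    """
    Compares two shapes a and b. With allow_one_size=False, returns True only
    if they are identical (same rank and same lengths). With allow_one_size=True
    (the default), also returns True if the ranks match and each pair of
    corresponding lengths is equal or contains a 1 (implicit expand).
    """
    if allow_one_size:
        # zip() silently truncates, so check the ranks explicitly
        return len(a) == len(b) and all(x == y or x == 1 or y == 1 for x, y in zip(a, b))
    return tuple(a) == tuple(b)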

In [None]:
# Monkey-patching
torch._prims.utils.is_same_shape = is_same_shape

In [None]:
# Now implicit expand is allowed
torch._prims.add(a, b.unsqueeze(0))

tensor([[-0.8126,  0.9479, -0.5042],
        [-0.8294,  0.9311, -0.5210],
        [-1.9368, -0.1763, -1.6283],
        [ 0.3101,  2.0706,  0.6186]], device='cuda:0')

`broadcast_in_dim` can be split into a special unsqueeze operation and a normal expand operation. Let's define primitives for both. We call the special unsqueeze "broadcast": it takes a list of bools, one per dimension of the result, indicating whether that dimension is a new broadcast dimension. Effectively, it inserts new size-1 dimensions at the positions marked `True`.
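In terms of shapes, this special unsqueeze only ever inserts size-1 dimensions, so its result shape follows from the input shape and the bool list alone. A pure-Python sketch of that shape rule (the helper `broadcast_output_shape` is hypothetical; it mirrors what the `broadcast` implementation in the next cells computes, but on shape tuples):

```python
def broadcast_output_shape(shape, broadcast_dimensions):
    """Result shape of the special unsqueeze ("broadcast") described above.

    broadcast_dimensions[i] is True when result dim i is a new size-1 dim.
    """
    # Same precondition as the notebook's implementation: one False per input dim
    if sum(1 for flag in broadcast_dimensions if not flag) != len(shape):
        raise ValueError("need exactly one False entry per input dimension")
    sizes = iter(shape)
    # True -> insert a size-1 dim; False -> take the next input dim in order
    return tuple(1 if flag else next(sizes) for flag in broadcast_dimensions)


print(broadcast_output_shape((3,), [True, False]))        # (1, 3)
print(broadcast_output_shape((4, 1), [False, False]))     # (4, 1)
print(broadcast_output_shape((5,), [True, False, True]))  # (1, 5, 1)
```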

In [None]:
prim_impl = torch._prims.prim_impl
torch._prims.prim.define("expand(Tensor input, int[] shape) -> Tensor")

'expand'

In [None]:
def broadcast(a, broadcast_dimensions):
    """Broadcasts the tensor to the given dimensions.
       `broadcast_dimensions` is a sequence of bools indicating whether given dimension of the result is a broadcasted dimension.
    """
    # We must have same number of False entries as a.ndim
    assert a.ndim == len([x for x in broadcast_dimensions if not x])
    for idx, is_broadcast_dim in enumerate(broadcast_dimensions):
        if is_broadcast_dim:
            a = a.unsqueeze(idx)
    return a

torch._prims.prim.define("broadcast(Tensor input, bool[] broadcast_dimensions) -> Tensor")
torch._prims.prim_impl.impl("broadcast", broadcast)

In [None]:
torch._prims.prim_impl.impl("expand", lambda inp, shape: inp.expand(shape))

Here we define a replacement for `torch._refs._maybe_broadcast` that lowers into primitives mapping directly to nvFuser's broadcast.
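The dims computation in `_maybe_broadcast` depends only on ranks: operands are right-aligned with the common shape, and `refs_broadcast_in_dim` turns the dims that come from the operand into `False` entries of a bool list. A pure-Python sketch of that mapping (the helper `broadcast_flags` is hypothetical; it reproduces the lists the code below passes to `prims.broadcast`, but works on shape tuples):

```python
def broadcast_flags(x_shape, common_shape):
    """Bool list passed to prims.broadcast for one operand of _maybe_broadcast.

    Operands are right-aligned with the common shape, so the first
    len(common_shape) - len(x_shape) result dims are new (True) and the rest
    come from the operand (False).
    """
    start = len(common_shape) - len(x_shape)
    if start < 0:
        raise ValueError("operand has more dims than the common shape")
    kept = range(start, len(common_shape))  # result dims that keep operand dims
    return [i not in kept for i in range(len(common_shape))]


print(broadcast_flags((4, 1), (4, 3)))  # [False, False]
print(broadcast_flags((3,), (4, 3)))    # [True, False]
```

Only the ranks enter the result, so none of the sizes 4, 3 or 1 end up in the bool lists.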

In [None]:
from functools import reduce

# This is a new variant of "broadcast_in_dim".
# We don't necessarily need it; in the end, all we'd like to see in the graph
# is an explicit call to "prims.broadcast".
def refs_broadcast_in_dim(a, broadcast_dimensions, ndim, shape=None):
    """
    Similar to jax.lax.broadcast_in_dim but now the expand part is optional.
    We only unsqueeze to required `ndim` using `broadcast_dimensions`
    """
    if shape is not None:
        assert ndim == len(shape)

    is_broadcast_dims = [True] * ndim
    for broadcast_dimension in broadcast_dimensions:
        is_broadcast_dims[broadcast_dimension] = False

    a = torch.ops.prims.broadcast(a, is_broadcast_dims)

    if shape is not None:
        a = torch.ops.prims.expand(a, shape)
    return a

def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    # Copied from torch._refs._maybe_broadcast
    common_shape = torch._refs._broadcast_shapes(
        *map(lambda t: t.shape if isinstance(t, torch._prims.utils.TensorLike) else None, args)
    )

    def __maybe_broadcast(x, shape):
        if x is None:
            return None
        elif isinstance(x, torch._prims.utils.Number):
            return x
        elif isinstance(x, torch._prims.utils.TensorLike):
            if preserve_cpu_scalar_tensors and torch._prims.utils.is_cpu_scalar_tensor(x):
                return x

            if tuple(x.shape) != common_shape:
                common_rank = len(common_shape) + 1
                start = common_rank - (len(x.shape) + 1)
                dims = tuple(range(start, len(x.shape) + start))
                # NOTE: This line was changed in this function
                return refs_broadcast_in_dim(x, dims, len(common_shape))
            # Tensors that already have the common shape pass through unchanged
            return x
        else:
            raise RuntimeError(
                "Unexpected type when broadcasting: " + str(type(x)) + "!"
            )

    return tuple(__maybe_broadcast(x, common_shape) for x in args)

In [None]:
_maybe_broadcast(a, b)

(tensor([[-0.1531],
         [-0.1699],
         [-1.2773],
         [ 0.9696]], device='cuda:0'),
 tensor([[-0.6595,  1.1010, -0.3510]], device='cuda:0'))

In [None]:
# Monkey-patching
torch._refs._maybe_broadcast = _maybe_broadcast

Now let's run the same example from the top of this notebook again and look at the resulting graph. The concrete shapes of the inputs to the add call are no longer saved in it.

In [None]:
import torch
print(torch.__version__)
from torch.fx.experimental.proxy_tensor import make_fx
import torch._refs
a = torch.ones((4, 1))
b = torch.ones((3,))
add_fx = make_fx(lambda a, b: torch._refs.add(a, b))(a, b)
print(add_fx.graph)

1.12.1+cu113
graph():
    %a_1 : [#users=2] = placeholder[target=a_1]
    %b_1 : [#users=1] = placeholder[target=b_1]
    %broadcast : [#users=0] = call_function[target=torch.ops.prims.broadcast](args = (%a_1, [False, False]), kwargs = {})
    %broadcast_1 : [#users=1] = call_function[target=torch.ops.prims.broadcast](args = (%b_1, [True, False]), kwargs = {})
    %add : [#users=1] = call_function[target=torch.ops.prims.add](args = (%a_1, %broadcast_1), kwargs = {})
    return add


In [None]:
# Author: mruberry
# What would the grad transform of an add primitive that allowed for implicit 
# expansion be?
# I think it would have to call something like sum_to_size (reproduced below)
import torch._prims as prims
import torch._prims.utils as utils

def sum_to_size(
    a: torch.Tensor,
    *shape,
) -> torch.Tensor:
    shape = utils.extract_shape_from_varargs(shape, validate=False)
    utils.check(
        utils.is_expandable_to(shape, a.shape),
        lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
    )
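    # In ATen scalar tensors are sent through sum and the result is returned as
    # type promoted.
    # Compare exactly: is_same_shape was monkey-patched above to accept
    # mismatched size-1 dims, which would wrongly skip the reduction here.
    if tuple(shape) == tuple(a.shape) and len(shape) > 0:
        return prims.view_of(a)
    leading_dims = a.ndim - len(shape)
    reduce_dims = tuple(range(leading_dims)) + tuple(
        i
        for i in range(leading_dims, a.ndim)
        if shape[i - leading_dims] == 1 and a.shape[i] != 1
    )
    # keepdim=True leaves the summed leading dims as size 1; reshape drops them
    return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None).reshape(shape)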

# I think sum_to_size will then encode the shape in the way we're hoping to avoid
# Alternatives might be to make sum_to_size itself a primitive, or define a 
# new "unary_elementwise_backward" primitive

# With symbolic shapes, however, I think an expand need not target a specific
# size; it could target a symbol representing the other tensor's shape

In [None]:
!pip install functorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting functorch
  Downloading functorch-0.2.1-cp37-cp37m-manylinux1_x86_64.whl (20.6 MB)
Installing collected packages: functorch
Successfully installed functorch-0.2.1


In [None]:
# functorch was not preinstalled in this notebook; the `!pip install functorch`
# cell above installs it.
# This example illustrates the point about the distinct backward graphs.

import functorch
import torch
import torch.fx as fx

from functorch.compile import aot_function

def foo(a, b):
    return a + b

def print_graph(name):
    def f(fx_g: fx.GraphModule, inps):
        print(name)
        print(fx_g.code)
        return fx_g
    return f

# Pass on the compiler_fn to the aot_function API
aot_print_fn = aot_function(foo, fw_compiler=print_graph("forward"), bw_compiler=print_graph("backward"))

a = torch.randn(2, 1, requires_grad=True)
b = torch.randn(2, 2, requires_grad=True)
ref = aot_print_fn(a, b)
loss = ref.sum()
loss.backward()

# when a has shape (2, 2)
# def forward(self, primals_1, primals_2):
#     add = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
#     return [add]

# def backward(self, tangents_1):
#     return [tangents_1, tangents_1]


# when a has shape (2, 1)
# def forward(self, primals_1, primals_2):
#     add = torch.ops.aten.add.Tensor(primals_1, primals_2);  primals_1 = primals_2 = None
#     return [add]

# def backward(self, tangents_1):
#     sum_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [1], True)
#     return [sum_1, tangents_1]

# Alternative for a graph pass (is_implicitly_expanded is a hypothetical helper):
# def backward(self, tangents_1, primals_1, primals_2):
#     grad_1 = tangents_1
#     if is_implicitly_expanded(primals_1):
#         grad_1 = torch.ops.aten.sum.dim_IntList(tangents_1, [1], True)
#     grad_2 = tangents_1
#     if is_implicitly_expanded(primals_2):
#         grad_2 = ...
#     return [grad_1, grad_2]

forward



def forward(self, primals_1, primals_2):
    add = torch.ops.aten.add(primals_1, primals_2);  primals_1 = primals_2 = None
    return [add]
    
backward



def forward(self, tangents_1):
    sum_1 = torch.ops.aten.sum(tangents_1, [1], True)
    return [sum_1, tangents_1]
    


With implicit expansion, an operation's forward graph is the same whether or not expansion occurs, but its backward graph is still distinct: it has to sum the gradient over the expanded dimensions, and those dimensions are again encoded as explicit shapes. To address this issue, we also have to redesign the backward and/or adopt symbolic shapes.
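Which dimensions the backward has to reduce depends on the concrete input shapes, which is why shapes reappear there even when the forward is shape-agnostic. A pure-Python sketch of the reduction-dims rule used by `sum_to_size` above (the helper `sum_to_size_reduce_dims` is hypothetical and works on shape tuples):

```python
def sum_to_size_reduce_dims(grad_shape, input_shape):
    """Dims of the incoming gradient to sum so it matches an input's shape.

    Leading dims added by broadcasting are always summed; a remaining dim is
    summed when the input had size 1 there but the gradient does not.
    """
    leading = len(grad_shape) - len(input_shape)
    return tuple(range(leading)) + tuple(
        i
        for i in range(leading, len(grad_shape))
        if input_shape[i - leading] == 1 and grad_shape[i] != 1
    )


# Same forward add, different backward reductions (cf. the aot_function output)
print(sum_to_size_reduce_dims((2, 2), (2, 1)))  # (1,)
print(sum_to_size_reduce_dims((2, 2), (2, 2)))  # ()
```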