In [1]:
import torch
import typing
import functorch

# 2.4 Broadcasting
![Creative Commons logo](https://mirrors.creativecommons.org/presskit/logos/cc.logo.svg)

In [16]:
# Axis sizes used by the examples in this cell. Each one is a one-element list
# so that shapes can be built by list concatenation (e.g. `a + c` is [3, 5]).
a, b, c, d = [3], [4], [5], [3]
# Short alias used in the type annotations below.
tt = torch.Tensor

# F: a -> b
def F(Xa: tt) -> tt:
    """Map a tensor over axis ``a`` to a tensor over axis ``b``.

    The squares of the entries of ``Xa`` are summed along its leading axis and
    the result is added to a ones tensor of shape ``b``. For a 1-D ``Xa``
    every output entry therefore equals ``1 + sum(Xa ** 2)``.
    """
    # torch.sum performs the reduction in a single call instead of iterating
    # over the tensor with Python's builtin sum; dim=0 keeps the reduction on
    # the leading axis, which is what the builtin iterated over.
    return torch.sum(Xa ** 2, dim=0) + torch.ones(b)

# The broadcast written out by hand,
# F c: a c -> b c
def Fc(Xac: tt) -> tt:
    """Apply ``F`` separately to each of the first ``c[0]`` columns of ``Xac``.

    Column ``col`` of the returned tensor (shape ``b + c``) holds
    ``F(Xac[:, col])``.
    """
    out = torch.zeros(b + c)
    n_cols = c[0]
    for col in range(n_cols):
        column_in = Xac[:, col]
        out[:, col] = F(column_in)
    return out

# The same broadcast, delegated to PyTorch's vmap,
# F *: a -> b *
# The last axis of the input is the mapped one, and the mapped axis is placed
# last in the output as well.
Fs = torch.vmap(F, in_dims=-1, out_dims=-1)

# We can show this satisfies Definition 2.13 by comparing, column by column,
# F applied to a single column against both broadcast versions.
Xac = torch.rand(a + c)

# Each broadcast is evaluated once: its result does not depend on the loop
# index, so recomputing it for every column is wasted work.
Ybc_loop = Fc(Xac)
Ybc_vmap = Fs(Xac)

for k in range(c[0]):
    Yb = F(Xac[:,k])
    assert torch.allclose(Yb, Ybc_loop[:,k])
    assert torch.allclose(Yb, Ybc_vmap[:,k])


# G: a + b -> c
def G(Xa: tt, Xb: tt) -> tt:
    """Combine two tensors into a tensor of shape ``c``.

    Each input contributes the sum of ``sqrt(x ** 2)`` over all of its
    entries; the two totals are added to a ones tensor of shape ``c``.
    """
    total_a = torch.sqrt(Xa ** 2).sum()
    total_b = torch.sqrt(Xb ** 2).sum()
    return total_a + total_b + torch.ones(c)



# Inner broadcasting written out by hand,
# G d1: a + (b d) -> c d
def Gd1(Xa: tt, Xbd: tt) -> tt:
    """Apply ``G`` once per column of ``Xbd``, passing the same ``Xa`` each time.

    Column ``col`` of the returned tensor (shape ``c + d``) holds
    ``G(Xa, Xbd[:, col])`` for the first ``d[0]`` columns of ``Xbd``.
    """
    out = torch.zeros(c + d)
    n_cols = d[0]
    for col in range(n_cols):
        column_in = Xbd[:, col]
        out[:, col] = G(Xa, column_in)
    return out

# The same inner broadcast, delegated to PyTorch's vmap,
# G *1: a + (b *) -> c *
# The first argument is passed through unmapped (None); axis 1 of the second
# argument is mapped, and the mapped axis becomes axis 1 of the output.
Gs1 = torch.vmap(G, in_dims=(None, 1), out_dims=1)

# We can show this satisfies Definition 2.14 by comparing, column by column,
# G applied to a single column against both inner-broadcast versions.
Xa = torch.rand(a)
Xbd = torch.rand(b + d)

# Each broadcast is evaluated once: its result does not depend on the loop
# index, so recomputing it for every column is wasted work.
Ycd_loop = Gd1(Xa, Xbd)
Ycd_vmap = Gs1(Xa, Xbd)

for k in range(d[0]):
    Yc = G(Xa, Xbd[:,k])
    assert torch.allclose(Yc, Ycd_loop[:,k])
    assert torch.allclose(Yc, Ycd_vmap[:,k])

# Elementwise operations are implemented as usual, i.e.
# f: 1 -> 1
def f(x):
    """Return ``x`` raised to the power two."""
    return pow(x, 2)

# f *: * -> *
# vmap's default axes (in_dims=0, out_dims=0) are written out explicitly.
fs = torch.vmap(f, in_dims=0, out_dims=0)

# Check entry by entry that the broadcast fs agrees with f.
Xa = torch.rand(a)
# fs(Xa) does not depend on the loop index, so it is evaluated once up front.
Ya = fs(Xa)
for i in range(a[0]):
    assert torch.allclose(f(Xa[i]), Ya[i])

# The different forms of addition line up with PyTorch broadcasting, 
# with slight modifications.



In [21]:
# Addition
# A variable is implicitly deleted unless it is copied. Printing the local
# namespace inside each of the functions below therefore gives an idea of what
# the first part of Figure 5 is implying.
def give_xy():
    """Create two one-element random tensors, print the local namespace,
    and return the pair ``(x, y)``."""
    # The names `x` and `y` appear verbatim in the printed locals().
    x = torch.rand(1)
    y = torch.rand(1)
    print(locals())
    return x, y

def give_z():
    """Add the two tensors returned by ``give_xy``, print the local namespace,
    and return the sum."""
    # `z` is the only name bound in this function, so it is the only entry
    # that locals() reports here; the pair from give_xy is never given a name.
    z = sum(give_xy())
    print(locals())
    return z

# Run the demonstration; each function prints its own locals() when called.
give_z()

# The other addition operators are analogous to PyTorch broadcasts
i, j = [3], [4]

def distribute_and_broadcast(X0i, X1i):
    """Return ``X0i + X1i``."""
    total = X0i + X1i
    return total
def inner_broadcast(Xi : tt, Xj : tt):
    """Unfinished stub for the inner-broadcast form of addition."""
    # NOTE(review): this body looks incomplete. `view` is called without a
    # target shape, its result is discarded, `Xj` is never used, and the
    # function has no return statement.
    # Tensor.view presumably rejects a call with no arguments — confirm, and
    # fill in the intended broadcast before calling this function.
    Xi.view()


{'x': tensor([0.1909]), 'y': tensor([0.9223])}
{'z': tensor([1.1132])}


tensor([1.1132])