In [1]:
import torch
import typing
import functorch

# 2.4 Broadcasting

In [14]:
# Axis sizes, kept as one-element lists so that tensor shapes can be built
# by list concatenation (e.g. `a + c` is the shape [3, 5]).
a, b, c = [3], [4], [5]

# Short alias for the tensor type, used in the annotations below.
tt = torch.Tensor

# F: a -> b
def F(Xa: tt) -> tt:
    """Map a tensor over axis ``a`` to a tensor over axis ``b``.

    The squared entries of ``Xa`` are summed along its leading axis and the
    result is added to a ones tensor of shape ``b``. For a 1-D ``Xa`` every
    entry of the output is therefore ``1 + sum_i Xa[i] ** 2``.
    """
    # torch.sum performs the reduction as one tensor op; the builtin sum()
    # would unbind the tensor and add its entries one at a time in Python.
    # dim=0 reduces the same (leading) axis the builtin iterated over.
    return torch.sum(Xa ** 2, dim=0) + torch.ones(b)

# A raw definition,
# F c: a c -> b c
def Fc(Xac: tt) -> tt:
    """Apply ``F`` separately to each of the first ``c[0]`` columns of ``Xac``.

    Returns a tensor of shape ``b + c`` whose column ``j`` is ``F(Xac[:, j])``.
    """
    out_bc = torch.zeros(b + c)
    for col in range(c[0]):
        # One slice along the trailing axis goes through the unbroadcast F.
        x_col = Xac[:, col]
        out_bc[:, col] = F(x_col)
    return out_bc

# Broadcast using PyTorch,
# F *: a -> b *
# The last axis of the input is mapped over, and the mapped axis is placed
# last in the output as well.
Fs = torch.vmap(F, in_dims=-1, out_dims=-1)

# We can show this satisfies Definition 2.13 by,
Xac = torch.rand(a + c)

# Both broadcasts are evaluated once, outside the loop: their results do not
# depend on k, so recomputing them for every slice is redundant work.
Ybc_raw = Fc(Xac)
Ybc_vmap = Fs(Xac)

# Slice k of each broadcast must equal F applied to slice k of the input.
for k in range(c[0]):
    Yb = F(Xac[:,k])
    assert torch.allclose(Yb, Ybc_raw[:,k])
    assert torch.allclose(Yb, Ybc_vmap[:,k])


# G: a + b -> c
def G(Xa: tt, Xb: tt) -> tt:
    """Combine a tensor over axis ``a`` and one over axis ``b`` into one over ``c``.

    The sum of absolute values of ``Xa`` and the sum of absolute values of
    ``Xb`` are added to a ones tensor of shape ``c``.
    """
    # torch.abs replaces sqrt(x ** 2): it gives the same value, but squaring
    # cannot overflow for large entries and the gradient at 0 is not NaN.
    return torch.sum(torch.abs(Xa)) + torch.sum(torch.abs(Xb)) + torch.ones(c)

# Size of the extra axis d over which G is broadcast below; a one-element
# list so it can be concatenated with the other axis lists to form shapes.
d = [3]

# We can implement inner broadcasting by,
# G d1: a + (b d) -> c d
def Gd1(Xa: tt, Xbd: tt) -> tt:
    """Broadcast ``G`` over the trailing axis of its second argument only.

    ``Xa`` is passed unchanged to every call, while column ``k`` of ``Xbd``
    supplies the second argument. Returns a tensor of shape ``c + d`` whose
    column ``k`` is ``G(Xa, Xbd[:, k])``.
    """
    out_cd = torch.zeros(c + d)
    for col in range(d[0]):
        xb_col = Xbd[:, col]
        out_cd[:, col] = G(Xa, xb_col)
    return out_cd

# Broadcast using PyTorch,
# G *1: a + (b *) -> c *
# in_dims: None leaves the first argument unmapped, 1 maps over axis 1 of the
# second. It has to be a tuple (or an int); a list such as [None, 1] is
# rejected by vmap with a ValueError.
Gs1 = torch.vmap(G, in_dims=(None, 1), out_dims=1)

# We can show this satisfies Definition 2.14 by,
Xa = torch.rand(a)
Xbd = torch.rand(b + d)

# Both broadcasts are evaluated once, outside the loop: their results do not
# depend on k, so recomputing them for every slice is redundant work.
Ycd_raw = Gd1(Xa, Xbd)
Ycd_vmap = Gs1(Xa, Xbd)

# Slice k of each broadcast must equal G applied to Xa and slice k of Xbd.
for k in range(d[0]):
    Yc = G(Xa, Xbd[:,k])
    assert torch.allclose(Yc, Ycd_raw[:,k])
    assert torch.allclose(Yc, Ycd_vmap[:,k])

ValueError: vmap(G, in_dims=[None, 1], ...)(<inputs>): expected `in_dims` to be int or a (potentially nested) tuple matching the structure of inputs, got: <class 'list'>.