In [1]:
import torch
import typing
import functorch

# 2.4 Broadcasting
![Alt text](https://mirrors.creativecommons.org/presskit/logos/cc.logo.svg)

In [16]:
a = [3]
b = [4]
c = [5]
d = [3]
tt = torch.Tensor

# F: a -> b
def F(Xa: tt) -> tt:
    return sum(Xa**2) + torch.ones(b)

# A raw definition,
# F c: a c -> b c
def Fc(Xac: tt) -> tt:
    Ybc = torch.zeros(b + c)
    for j in range(c[0]):
        Ybc[:,j] = F(Xac[:,j])
    return Ybc

# Broadcast using PyTorch,
# F *: a * -> b *
Fs = torch.vmap(F, -1, -1)

# We can show this satisfies Definition 2.13 by,
Xac = torch.rand(a + c)

for k in range(c[0]):
    assert torch.allclose(F(Xac[:,k]), Fc(Xac)[:,k])
    assert torch.allclose(F(Xac[:,k]), Fs(Xac)[:,k])


# G: a + b -> c
def G(Xa: tt, Xb: tt) -> tt:
    return torch.sum(torch.sqrt(Xa**2)) + torch.sum(torch.sqrt(Xb ** 2))  + torch.ones(c)



# We can implement inner broadcasting by,
# G d1: a + (b d) -> c d
def Gd1(Xa: tt, Xbd : tt) -> tt:
    Ycd = torch.zeros(c + d)
    for k in range(d[0]):
        Ycd[:, k] = G(Xa, Xbd[:,k])
    return Ycd

# Broadcast using PyTorch,
# G *1: a + (b *) -> c *
Gs1 = torch.vmap(G, (None,1), 1)

# We can show this satisfies Definition 2.14 by,
Xa = torch.rand(a)
Xbd = torch.rand(b + d)

for k in range(d[0]):
    assert torch.allclose(G(Xa, Xbd[:,k]), Gd1(Xa, Xbd)[:,k])
    assert torch.allclose(G(Xa, Xbd[:,k]), Gs1(Xa, Xbd)[:,k])

# Elementwise operations are implemented as usual ie
# f: 1 -> 1
def f(x):
    return x ** 2

# f *: * -> *
fs = torch.vmap(f)

Xa = torch.rand(a)
for i in range(a[0]):
    assert torch.allclose(f(Xa[i]), fs(Xa)[i])

# The different forms of addition line up with PyTorch broadcasting, 
# with slight modifications.



In [21]:
# Addition
# Variables are implicitly deleted if not copied. This sequence of variables
# therefore gives an idea of what the first part of Figure 5 is implying.
def give_xy():
    x, y = torch.rand(1), torch.rand(1)
    print(locals())
    return x, y

def give_z():
    z = sum(give_xy())
    print(locals())
    return z

give_z()

# The other addition operators are analogous to PyTorch broadcasts.
i = [3]
j = [4]

def distribute_and_broadcast(X0i, X1i):
    return X0i + X1i
def inner_broadcast_0(Xi : tt, Xj : tt):
    return Xi.view(Xi.shape + [1]) + Xj.view([1] + Xj.shape)
def adding_tuples(X0ij : tt, X1ij : tt):
    return X0ij + X1ij
def inner_broadcast_1(Xi: tt, Xij: tt):
    return Xi.view(Xi.shape + [1]) + Xij

# Copying, this automatically broadcasts.
def copy(x : tt):
    return x, torch.clone(x)


{'x': tensor([0.1909]), 'y': tensor([0.9223])}
{'z': tensor([1.1132])}


tensor([1.1132])

In [49]:
import math
import einops
# Summing and Rearranging Data
# The best example for this is scaled dot product attention. We start with,
y, k, heads, x = [16], [32], [8], [12]

# Following memory *exactly*,
Q, K, V = torch.rand(y+k+heads), torch.rand(x+k+heads), torch.rand(x + k + heads)
X, V    = einops.einsum(Q, K, 'y k h,x k h -> y x h'), V
del Q, K
X, V    = torch.nn.Softmax(-2)(X / math.sqrt(k[0])), V
X       = einops.einsum(X, V, 'y x h,x k h -> y k h')
del V

# Or, we can define,
def MultiHeadDotProductAttention(q: tt, k: tt, v: tt) -> tt:
    ''' ykh, xkh, xkh -> ykh '''
    klength = k.size()[-2]
    x = einops.einsum(q, k, 'y k h, x k h -> y x h')
    x = torch.nn.Softmax(-2)(x / math.sqrt(klength))
    x = einops.einsum(x, v, 'y x h, x k h -> y k h')
    return x


In [54]:
es = einops.einsum

# Linear Algebra
a = [3]
b = [4]
c = [5]
d = [3]

# F: a b c
F_matrix = torch.rand(a + b + c)

# As in Figure 8,
# F: a -> b c
def F_func(Xa: tt):
    return es(Xa,F_matrix,'a,a b c->b c',)

# Transpoing by linearity
# We take the outer product of Id(b) and F, and follow up with a dot product.
# Flt: b a c
F_linear_transpose = es(torch.eye(b[0]),F_matrix,'b B, a B c->b a c',)

# We contend this this is the same as broadcasting F, and following
# with a dot product.
# * F: * a -> * b c
F_broadcast = torch.vmap(F_func, 0, 0)

# b a -> b b c -> c
def F_broadcast_transpose(Xba: tt):
    Xbbc = F_broadcast(Xba)
    return es(Xbbc, 'b b c -> c')

Xba = torch.rand(b + a)

assert torch.allclose(
    F_broadcast_transpose(Xba), 
    es(Xba,F_linear_transpose, 'b a,b a c -> c'))

# The Kronecker-Dot identity,
# The first step is an outer product with the Kronecker delta,
outerKronecker = lambda Xb: es(Xb, torch.eye(b[0]), 'b0,b1 b2->b0 b1 b2')
# The next is a dot product over the first two axes,
dotOuter = lambda Xbbb: es(Xbbb, 'b0 b0 b1 -> b1')

Xb = torch.rand(b)
assert torch.allclose(
    Xb,
    dotOuter(outerKronecker(Xb)))

# Therefore, we can confidently use the expressions in Figure 8 to transpose expressions.


In [55]:
import torch.nn as nn

# Basic Image Recogniser
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
    def forward(self, x):
        x = self.flatten(x)
        x = self.linear_relu_stack(x)
        y_pred = nn.Softmax(x)
        return y_pred
    
my_NeuralNetwork = NeuralNetwork()
my_NeuralNetwork.forward(torch.rand([1,28,28]))

Softmax(
  dim=tensor([[-0.1206, -0.0320,  0.0639, -0.0608, -0.1142, -0.1160, -0.0625, -0.0675,
           -0.0875,  0.0338]], grad_fn=<AddmmBackward0>)
)

In [61]:
# Multi-Head Attention (Figure 11)

class MultiHeadAttention(nn.Module):
    def __init__(self, m, k, h):
        super().__init__()

        self.m, self.k, self.h = m, k, h
        # Set up all the boldface, learned components
        # Note how they bind axes we want to split, we do this with
        # einops.
        self.Lq = nn.Linear(m, k*h, False)
        self.Lk = nn.Linear(m, k*h, False)
        self.Lv = nn.Linear(m, k*h, False)
        self.Lo = nn.Linear(k*h, m, False)

    # Defined previously, closely follows the diagram.
    def MultiHeadDotProductAttention(q: tt, k: tt, v: tt) -> tt:
        ''' ykh, xkh, xkh -> ykh '''
        klength = k.size()[-2]
        x = einops.einsum(q, k, 'y k h, x k h -> y x h')
        x = torch.nn.Softmax(-2)(x / math.sqrt(klength))
        x = einops.einsum(x, v, 'y x h, x k h -> y k h')
        return x

    # We have endogenous data (Eym) and external / injected data (Xxm)
    def forward(self, Eym, Xxm):
        # query, key, and value vectors.
        # Linear layers are automatically broadcast.
        # However, the k and h axes are bound.
        unbind = lambda x: einops.rearrange(x, '... (k h)->... k h', h=self.h)
        q = unbind(self.Lq(Eym))
        k = unbind(self.Lk(Xxm))
        v = unbind(self.Lv(Xxm))
        o = MultiHeadDotProductAttention(q, k, v)
        # Rebind to feed to Lo
        o = einops.rearrange(o, '... k h-> ... (k h)', h=self.h)
        return self.Lo(o)

y, x, m, k, heads = [20], [22], [128], [16], [8]

# Internal Data
Eym = torch.rand(y + m)
# External Data
Xxm = torch.rand(x + m)

# We can now run the algorithm,
mha = MultiHeadAttention(m[0],k[0],heads[0])
mha.forward(Eym, Xxm).size()


torch.Size([20, 128])

In [None]:
# For Figure 15, we will interpret the ``fenced off'' regions as seperate modules.

class NormActivate(nn.Sequential):
    def __init__(self, nf, Norm=nn.BatchNorm2d, Activation=nn.ReLU):
        super().__init__(Norm(nf), Activation())

class IdentityResNet(nn.Sequential):
    def __init__(self, N=3, n_mu=[16,64,128,256]):
        super().__init__(
            nn.Conv2d(3, n_mu[0], 3, padding=1),
            Block(1, N, n_mu[0], n_mu[1]),
            Block(2, N, n_mu[1], n_mu[2]),
            Block(2, N, n_mu[2], n_mu[3]),
            NormActivate(n_mu[3]),
            nn.AvgPool2d(8),
            nn.Linear(n_mu[3]),
            nn.Softmax(-1),
            )

class Block(nn.Sequential):
    def __init__(self, s, N, n0, n1):
        """ n0 and n1 as inputs to the initializer are implicit from having them in the domain and codomain in the diagram. """
        nb = n1 / 4
        super().__init__(
            ([
            NormActivate(n0),
            ResidualConnection(
                [
                    nn.Conv2d(n0, nb, 1, s),
                    NormActivate(nb),
                    nn.Conv2d(n0, nb, 3, padding=1),
                    NormActivate(nb),
                    nn.Conv2d(nb, n1),
                ],
                [
                    nn.Conv2d(n0, nb, 1, s),
                ]
            )
            ] + [
            ResidualConnection([
                    NormActivate(n1),
                    nn.Conv2d(n1, n1, 1, stride=2),
                    NormActivate(nb),
                    nn.Conv2d(nb, nb, 3, padding=1),
                    NormActivate(nb),
                    nn.Conv2d(nb, n1, 1)
                ],)
            ] * N
            )
            )   

class ResidualConnection(nn.Module):
    def __init__(self, mainline : list[nn.Module], connection : list[nn.Module] = []) -> None:
        super().__init__()
        self.main = nn.Sequential(mainline)
        self.secondary = nn.Identity if connection == [] else nn.Sequential(connection)
    def forward(self, x):
        return self.main(x) + self.secondary(x)

In [None]:
# For Figure 15, we will interpret the ``fenced off'' regions as seperate modules.

class DoubleConvolution(nn.Sequential):
    def __init__(self, c0, c1, Activation=nn.ReLU):
        super().__init__(
            nn.Conv2d(c0, c1, 3, padding=1),
            Activation(),
            nn.Conv2d(c0, c1, 3, padding=1),
            Activation(),
            )

class UNet(nn.Module):
    def __init__(self, y=2):
        super().__init__()
        c = [1 if i == 0 else 64 * 2 ** i for i in range(6)]

        # Set up the components
        self.DownScaleBlocks = [
            DownScaleBlock(c[i],c[i+1])
            for i in range(0,4)
        ]
        self.middleDoubleConvolution = DoubleConvolution(c[4], c[5])
        self.middleUpscale = nn.ConvTranspose2d(c[5], c[4], 2, 2, 1)
        self.upScaleBlocks = [
            UpScaleBlock(c[5-i],c[4-i])
            for i in range(1,4)
        ]
        self.finalConvolution = nn.Conv2d(c[1], y)

    def forward(self, x):
        cLambdas = []
        for dsb in self.DownScaleBlocks:
            x, cLambda = dsb(x)
            cLambdas.append(cLambda)
        x = self.middleDoubleConvolution(x)
        x = self.middleUpscale(x)
        for usb in self.upScaleBlocks:
            cLambda = cLambdas.pop()
            x = usb(x, cLambda)
        x = self.finalConvolution(x)

class DownScaleBlock(nn.Module):
    def __init__(self, c0, c1) -> None:
        super().__init__()
        self.doubleConvolution = DoubleConvolution(c0, c1)
        self.downScaler = nn.MaxPool2d(2, 2, 1)
    def forward(self, x):
        cLambda = self.doubleConvolution(x)
        x = self.downScaler(cLambda)
        return x, cLambda

class UpScaleBlock(nn.Module):
    def __init__(self, c1, c0) -> None:
        super().__init__()
        self.doubleConvolution = DoubleConvolution(2*c1, c1)
        self.upScaler = nn.ConvTranspose2d(c1,c0,2,2,1)
    def forward(self, x, cLambda):
        # Concatenation occurs over the C channel axis (dim=1)
        x = torch.concat(x, cLambda, 1)
        x = self.doubleConvolution(x)
        x = self.upScaler(x)
        return x

In [85]:
# Visual Attention

class VisualAttention(nn.Module):
    def __init__(self, c, k, heads = 1, kernel = 1, stride = 1):
        super().__init__()
        
        # w gives the kernel size, which we make adjustable.
        self.c, self.k, self.h, self.w = c, k, heads, kernel
        # Set up all the boldface, learned components
        # Note how they bind axes we want to split, we do this with
        # einops.

        # The learned layers form convolutions
        self.Cq = nn.Conv2d(c, k * heads, kernel, stride)
        self.Ck = nn.Conv2d(c, k * heads, kernel, stride)
        self.Cv = nn.Conv2d(c, k * heads, kernel, stride)
        self.Co = nn.ConvTranspose2d(
                            k * heads, c, kernel, stride)

    # Defined previously, closely follows the diagram.
    def MultiHeadDotProductAttention(self, q: tt, k: tt, v: tt) -> tt:
        ''' ykh, xkh, xkh -> ykh '''
        klength = k.size()[-2]
        x = einops.einsum(q, k, '... y k h, ... x k h -> ... y x h')
        x = torch.nn.Softmax(-2)(x / math.sqrt(klength))
        x = einops.einsum(x, v, '... y x h, ... x k h -> ... y k h')
        return x

    # We have endogenous data (EYc) and external / injected data (XXc)
    def forward(self, EcY, XcX):
        """ cY, cX -> cY 
        The visual attention algorithm. Injects information from Xc into Yc. """
        # query, key, and value vectors.
        # We unbind the k h axes which were produced by the convolutions, and feed them
        # in the normal manner to MultiHeadDotProductAttention.
        unbind = lambda x: einops.rearrange(x, 'N (k h) H W -> N (H W) k h', h=self.h)
        # Save size to recover it later
        q = self.Cq(EcY)
        W = q.size()[-1]
        q = unbind(q)
        k = unbind(self.Ck(XcX))
        v = unbind(self.Cv(XcX))
        o = self.MultiHeadDotProductAttention(q, k, v)
        # Rebind to feed to the transposed convolution layer.
        o = einops.rearrange(o, 'N (H W) k h -> N (k h) H W', 
                             h=self.h, W=W)
        return self.Co(o)

# Single batch element
b = [1]
Y, X, c, k = [16, 16], [16, 16], [33], 8
# The additional configurations
heads, kernel, stride = 4, 3, 3

# Internal Data
EYc = torch.rand(b + c + Y)
# External Data
XXc = torch.rand(b + c + X)

# We can now run the algorithm,
visualAttention = VisualAttention(c[0], k, heads, kernel, stride)

# Interestingly, the height/width reduces by 1 for stride
# values above 1.
visualAttention.forward(EYc, XXc).size()

torch.Size([1, 33, 18, 18])