In [1]:
import torch
import typing
import functorch
import itertools

# 2.3 Tensored Shapes
<img src="SVG/tensoredshapes.svg" width="700">

In [2]:
# Tensored shapes actually express a rather simple relationship. The definition attached to the diagram can be shown as;
a, b = [3], [4]
Xab = torch.rand(a + b)
for i, j in itertools.product(range(a[0]), range(b[0])):
    assert Xab[i, j] == Xab[i, :][j]
    assert Xab[i, j] == Xab[:, j][i]

# 2.4 Broadcasting
## Definition 2.13 (Broadcasting with indexes)
<img src="SVG/broadcasting_indexes.svg" width="700">

In [3]:
a, b, c, d = [3], [4], [5], [3]
tt = torch.Tensor

# We have some function from a to b;
def F(Xa: tt) -> tt:
    """ F: a -> b """
    return sum(Xa**2) + torch.ones(b)

# We could bootstrap a definition of broadcasting,
# Note that we are using spaces to indicate tensoring. 
# We will use commas for tupling, which is in line with standard notation while writing code.
def Fc(Xac: tt) -> tt:
    """ F c : a c -> b c """
    Ybc = torch.zeros(b + c)
    for j in range(c[0]):
        Ybc[:,j] = F(Xac[:,j])
    return Ybc

# Or use a PyTorch command,
# F *: a * -> b *
Fs = torch.vmap(F, -1, -1)

# We feed a random input, and see whether Definition 2.13 is satisfied,
Xac = torch.rand(a + c)
for k in range(c[0]):
    assert torch.allclose(F(Xac[:,k]), Fc(Xac)[:,k])
    assert torch.allclose(F(Xac[:,k]), Fs(Xac)[:,k])

# This shows how our definition of broadcasting lines up with that used by PyTorch vmap.

## Definition 2.14 (Inner broadcasting)
<img src="SVG/inner_broadcasting.svg" width="700">

In [4]:
# We have some function which can be inner broadcast,
def G(Xa: tt, Xb: tt) -> tt:
    """ G: a, b -> c"""
    return torch.sum(torch.sqrt(Xa**2)) + torch.sum(torch.sqrt(Xb ** 2))  + torch.ones(c)

# We can bootstrap inner broadcasting,
def Gd1(Xa: tt, Xbd : tt) -> tt:
    """ G d1: a, b d -> c d"""
    Ycd = torch.zeros(c + d)
    for k in range(d[0]):
        Ycd[:, k] = G(Xa, Xbd[:,k])
    return Ycd

# But vmap offers a clear way of doing it,
# G *1: a + (b *) -> c *
Gs1 = torch.vmap(G, (None,1), 1)

# We can show this satisfies Definition 2.14 by,
Xa = torch.rand(a)
Xbd = torch.rand(b + d)
for k in range(d[0]):
    assert torch.allclose(G(Xa, Xbd[:,k]), Gd1(Xa, Xbd)[:,k])
    assert torch.allclose(G(Xa, Xbd[:,k]), Gs1(Xa, Xbd)[:,k])


# 2.5 Elementwise Operations
<img src="SVG/elementwise.svg" width="700">

In [5]:

# Elementwise operations are implemented as usual ie
def f(x):
    "f : 1 -> 1"
    return x ** 2

# We broadcast an elementwise operation,
# f *: * -> *
fs = torch.vmap(f)

Xa = torch.rand(a)
for i in range(a[0]):
    # And see that it aligns with the index before = index after framework.
    assert torch.allclose(f(Xa[i]), fs(Xa)[i])
    # But, elementwise operations are implied, so no special implementation is needed. 
    assert torch.allclose(f(Xa[i]), f(Xa)[i])

# 2.6 Addition and Copying
<img src="SVG/utility_morphisms.svg" width="700">

In [6]:
# The different forms of addition line up with PyTorch broadcasting, with slight modifications.

# The first diagrams - adding a tuple of singular values - is clear.
# More precisely, local memory *only* contains the variables indicated by each shape.
# Therefore, the initial x, y values are deleted when z is calculated.
def give_xy():
    """ -> 1, 1 """
    x, y = torch.rand(1), torch.rand(1)
    print(locals())
    return x, y
def give_z():
    """ 1, 1 -> 1 """
    z = sum(give_xy())
    print(locals())
    return z
give_z()

{'x': tensor([0.4905]), 'y': tensor([0.9016])}
{'z': tensor([1.3921])}


tensor([1.3921])

In [7]:
# The other addition operators are analogous to PyTorch broadcasts.
i, j = [3], [4]

def distribute_and_broadcast(X0i, X1i):
    """ i, i -> i """
    return X0i + X1i
def inner_broadcast_0(Xi : tt, Xj : tt):
    """ i, j -> i j """
    return Xi.view(Xi.shape + [1]) + Xj.view([1] + Xj.shape)
def adding_tuples(X0ij : tt, X1ij : tt):
    """ (i, j), (i, j) -> i j """
    return X0ij + X1ij
def inner_broadcast_1(Xi: tt, Xij: tt):
    """ i, i j -> i j """
    return Xi.view(Xi.shape + [1]) + Xij

# Copying automatically broadcasts. 
# Other than copying, everything we do is inplace. 
# In PyTorch, this is indicated by functions starting with _
# Copying is properly given by the clone operation.
def copy(x : tt):
    """ a -> a, a """
    return x, torch.clone(x)


In [8]:
import math
import einops
x, y, k, h = [5], [3], [4], [2]
Q = torch.rand(y + k + h)
K = torch.rand(x + k + h)
k = 4

# Local memory contains,
# Q: y k h # K: x k h
# Outer products, transposes, inner products, and
# diagonalization reduce to einops expressions.
# Transpose K,
K = einops.einsum(K, 'x k h -> k x h')
# Outer product and diagonalize,
X = einops.einsum(Q, K, 'y k1 h, k2 x h -> y k1 k2 x h')
# Inner product,
X = einops.einsum(X, 'y k k x h -> y x h')
# Scale,
X = X / math.sqrt(k)

x, y, k, h = [5], [3], [4], [2]
Q = torch.rand(y + k + h)
K = torch.rand(x + k + h)
k = 4

# Local memory contains,
# Q: y k h # K: x k h
X = einops.einsum(Q, K, 'y k h, x k h -> y x h')
X = X / math.sqrt(k)


# 2.10 Summing Over and Rearranging Data
<img src="SVG/common_operations.svg" width="700">

In [9]:
import math
import einops

# Summing over and rearranging data are common operations. 
# As they are linear, they can be simultaneously applied.
# Scaled dot product attention is an operation that is difficult to express with standard notation,
# implementing a diagram as code will teach us about using the einsum package in conjunction with
# neural circuit diagrams.
y, k, heads, x = [16], [32], [8], [12]

def MultiHeadDotProductAttention(q: tt, k: tt, v: tt) -> tt:
    ''' y k h, x k h, x k h -> y k h '''
    klength = k.size()[-2]
    x = einops.einsum(q, k, '... y k h, ... x k h -> ... y x h')
    x = torch.nn.Softmax(-2)(x / math.sqrt(klength))
    x = einops.einsum(x, v, '... y x h, ... x k h -> ... y k h')
    return x

In [10]:
# As an exercise, we can track the memory usage *exactly*;
Q, K, V = torch.rand(y+k+heads), torch.rand(x+k+heads), torch.rand(x + k + heads)
X, V    = einops.einsum(Q, K, 'y k h,x k h -> y x h'), V
del Q, K
X, V    = torch.nn.Softmax(-2)(X / math.sqrt(k[0])), V
X       = einops.einsum(X, V, 'y x h,x k h -> y k h')
del V

In [11]:
# Local memory contains,
# Q: y k h # K: x k h
# Transpose K
Q, K = Q, einops.einsum(K, 'x k h -> k x h')
# Implicit outer product and diagonalize
X = einops.einsum(Q, K, 'y k h, k x h -> y k1 k2 x h')
# Inner product
X = einops.einsum(X, 'y k k x h -> y x h')
# Scale
X = X / math.sqrt(k)

NameError: name 'Q' is not defined

In [None]:
# Local memory contains,
# Q: y k h # K: x k h
X = einops.einsum(Q, K, 'y k h,x k h->y x h')
X = X / math.sqrt(k)

# 2.11 Associated Tensors and Linear Algebra
<img src="SVG/associated_tensors.svg" width="700">

In [None]:
a, b, c, d = [3], [4], [5], [3]

# We will be using this guy *a lot*
es = einops.einsum

# F: a b c
F_matrix = torch.rand(a + b + c)

# As an exericse we will show that F: a -> b c can be transposed in two ways.
# Either, we can broadcast, or take an outer product. We will show these are the same.

# Transposing by broadcasting
# 
def F_func(Xa: tt):
    """ F: a -> b c """
    return es(Xa,F_matrix,'a,a b c->b c',)
# * F: * a -> * b c
F_broadcast = torch.vmap(F_func, 0, 0)

# We then reduce it, as in the diagram,
# b a -> b b c -> c
def F_broadcast_transpose(Xba: tt):
    """ (b F) (.b c): b a -> c """
    Xbbc = F_broadcast(Xba)
    return es(Xbbc, 'b b c -> c')

# Transpoing by linearity
#
# We take the outer product of Id(b) and F, and follow up with a dot product.
# This gives us,
F_outerproduct = es(torch.eye(b[0]), F_matrix,'b0 b1, a b2 c->b0 b1 a b2 c',)
# Think of this as Id(b) F: b0 a -> b1 b2 c
# We then take the dot product
F_linear_transpose = es(F_outerproduct,'b B a B c->b a c',)

# We contend that these are the same.
#
Xba = torch.rand(b + a)
assert torch.allclose(
    F_broadcast_transpose(Xba), 
    es(Xba,F_linear_transpose, 'b a, b a c -> c'))

# Furthermore, lets prove the Kronecker Delta-Dot product identity.
#
# The first step is an outer product with the Kronecker delta,
outerKronecker = lambda Xb: es(Xb, torch.eye(b[0]), 'b0, b1 b2 -> b0 b1 b2')
# The next is a dot product over the first two axes,
dotOuter = lambda Xbbb: es(Xbbb, 'b0 b0 b1 -> b1')
# Applying both of these *should* be the identity, and hence leave any input unchanged.
Xb = torch.rand(b)
assert torch.allclose(
    Xb,
    dotOuter(outerKronecker(Xb)))

# Therefore, we can confidently use the expressions in Figure 8 to manipulate expressions.

# 3.1 Basic Multi-Layer Perceptron
<img src="SVG/imagerec.svg" width="700">

In [None]:
import torch.nn as nn

# Basic Image Recogniser
# This is a close copy of an introductory PyTorch tutorial:
# https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
    def forward(self, x):
        x = self.flatten(x)
        x = self.linear_relu_stack(x)
        y_pred = nn.Softmax(x)
        return y_pred
    
my_NeuralNetwork = NeuralNetwork()
my_NeuralNetwork.forward(torch.rand([1,28,28]))

Softmax(
  dim=tensor([[ 0.0282,  0.0299, -0.0271, -0.0193, -0.0753,  0.1084, -0.0097,  0.0123,
           -0.0920,  0.1043]], grad_fn=<AddmmBackward0>)
)

# 3.2 Multi-Head Attention
<img src="SVG/multihead_attention.svg" width="700">

We will be implementing this algorithm. This shows us how we go from diagrams to implementations, and begins to give an idea of how organized diagrams leads to organized code.

In [None]:
# Defined previously, closely follows the diagram.
def MultiHeadDotProductAttention(q: tt, k: tt, v: tt) -> tt:
    ''' ykh, xkh, xkh -> ykh '''
    klength = k.size()[-2]
    x = einops.einsum(q, k, '... y k h, ... x k h -> ... y x h')
    x = torch.nn.Softmax(-2)(x / math.sqrt(klength))
    x = einops.einsum(x, v, '... y x h, ... x k h -> ... y k h')
    return x

# We implement this component as a neural network model.
# This is necessary when there are bold, learned components that need to be initialized.
class MultiHeadAttention(nn.Module):
    # Multi-Head attention has various settings, which become variables
    # for the initializer.
    def __init__(self, m, k, h):
        super().__init__()
        self.m, self.k, self.h = m, k, h
        # Set up all the boldface, learned components
        # Note how they bind axes we want to split, which we do later with einops.
        self.Lq = nn.Linear(m, k*h, False)
        self.Lk = nn.Linear(m, k*h, False)
        self.Lv = nn.Linear(m, k*h, False)
        self.Lo = nn.Linear(k*h, m, False)


    # We have endogenous data (Eym) and external / injected data (Xxm)
    def forward(self, Eym, Xxm):
        """ y m, x m -> y m """
        # We first generate query, key, and value vectors.
        # Linear layers are automatically broadcast.

        # However, the k and h axes are bound. We define an unbinder to handle the outputs,
        unbind = lambda x: einops.rearrange(x, '... (k h)->... k h', h=self.h)
        q = unbind(self.Lq(Eym))
        k = unbind(self.Lk(Xxm))
        v = unbind(self.Lv(Xxm))

        # We feed q, k, and v to standard Multi-Head Dot Product Attention
        o = MultiHeadDotProductAttention(q, k, v)

        # Rebind to feed to the final learned layer,
        o = einops.rearrange(o, '... k h-> ... (k h)', h=self.h)
        return self.Lo(o)

# Now we can run it on fake data;
y, x, m, k, heads = [20], [22], [128], [16], 4
# Internal Data
Eym = torch.rand(y + m)
# External Data
Xxm = torch.rand(x + m)

mha = MultiHeadAttention(m[0],k[0],heads)
assert list(mha.forward(Eym, Xxm).size()) == y + m


# 3.4 Computer Vision
## Figure 15: Identity Residual Network
<img src="SVG/Identity ResNet.svg" width="700">

Here, we really start to understand why splitting diagrams into ``fenced off'' blocks aids implementation. 
In addition to making diagrams easier to understand and patterns more clearn, blocks indicate how code can structured and organized.

In [None]:
# For Figure 15, every fenced off region is its own module.

# Batch norm and then activate is a repeated motif,
class NormActivate(nn.Sequential):
    def __init__(self, nf, Norm=nn.BatchNorm2d, Activation=nn.ReLU):
        super().__init__(Norm(nf), Activation())

def size_to_string(size):
    return " ".join(map(str,list(size)))

# The Identity ResNet block breaks down into a manageable sequence of components.
class IdentityResNet(nn.Sequential):
    def __init__(self, N=3, n_mu=[16,64,128,256], y=10):
        super().__init__(
            nn.Conv2d(3, n_mu[0], 3, padding=1),
            Block(1, N, n_mu[0], n_mu[1]),
            Block(2, N, n_mu[1], n_mu[2]),
            Block(2, N, n_mu[2], n_mu[3]),
            NormActivate(n_mu[3]),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(n_mu[3], y),
            nn.Softmax(-1),
            )

<img src="SVG/Block.svg" width="700">

The Block can be defined in a seperate model, keeping the code manageable and closely connected to the diagram.

In [None]:
# We then follow how diagrams define each ``block''
class Block(nn.Sequential):
    def __init__(self, s, N, n0, n1):
        """ n0 and n1 as inputs to the initializer are implicit from having them in the domain and codomain in the diagram. """
        nb = n1 // 4
        super().__init__(
            *[
            NormActivate(n0),
            ResidualConnection(
                nn.Sequential(
                    nn.Conv2d(n0, nb, 1, s),
                    NormActivate(nb),
                    nn.Conv2d(nb, nb, 3, padding=1),
                    NormActivate(nb),
                    nn.Conv2d(nb, n1, 1),
                ),
                nn.Conv2d(n0, n1, 1, s),
            )
            ] + [
            ResidualConnection(
                nn.Sequential(
                    NormActivate(n1),
                    nn.Conv2d(n1, nb, 1),
                    NormActivate(nb),
                    nn.Conv2d(nb, nb, 3, padding=1),
                    NormActivate(nb),
                    nn.Conv2d(nb, n1, 1)
                ),
                )
            ] * N
            
            )   
# Residual connections are a repeated pattern in the diagram. So, we are motivated to encapsulate them
# as a seperate module.
class ResidualConnection(nn.Module):
    def __init__(self, mainline : nn.Module, connection : nn.Module | None = None) -> None:
        super().__init__()
        self.main = mainline
        self.secondary = nn.Identity() if connection == None else connection
    def forward(self, x):
        return self.main(x) + self.secondary(x)

In [None]:
# A standard image processing algorithm has inputs shaped b c h w.
b, c, hw = [3], [3], [16, 16]

idresnet = IdentityResNet()
Xbchw = torch.rand(b + c + hw)

# And we see if the overall size is maintained,
assert list(idresnet.forward(Xbchw).size()) == b + [10]

## Figure 16: The UNet architecture
<img src="SVG/unet.svg" width="700">

The UNet is a more complicated algorithm than residual networks. This makes its diagram, and hence implementation, more complicated.

In [None]:
# We notice that double convolution where the numbers of channels change is a repeated motif.
# We denote the input with c0 and output with c1. 
# This can also be done for subsequent members of an iteration.
# When we go down an iteration eg. 5, 4, etc. we may have the input be c1 and the output c0.
class DoubleConvolution(nn.Sequential):
    def __init__(self, c0, c1, Activation=nn.ReLU):
        super().__init__(
            nn.Conv2d(c0, c1, 3, padding=1),
            Activation(),
            nn.Conv2d(c0, c1, 3, padding=1),
            Activation(),
            )

# The model is specified for a very specific number of layers,
# so we will not make it very flexible.
class UNet(nn.Module):
    def __init__(self, y=2):
        super().__init__()
        # Set up the channel sizes;
        c = [1 if i == 0 else 64 * 2 ** i for i in range(6)]

        # Saving and loading from memory means we can not use a single,
        # sequential chain.

        # Set up and initialize the components;
        self.DownScaleBlocks = [
            DownScaleBlock(c[i],c[i+1])
            for i in range(0,4)
        ] # Note how this imitates the lambda operators in the diagram.
        self.middleDoubleConvolution = DoubleConvolution(c[4], c[5])
        self.middleUpscale = nn.ConvTranspose2d(c[5], c[4], 2, 2, 1)
        self.upScaleBlocks = [
            UpScaleBlock(c[5-i],c[4-i])
            for i in range(1,4)
        ]
        self.finalConvolution = nn.Conv2d(c[1], y)

    def forward(self, x):
        cLambdas = []
        for dsb in self.DownScaleBlocks:
            x, cLambda = dsb(x)
            cLambdas.append(cLambda)
        x = self.middleDoubleConvolution(x)
        x = self.middleUpscale(x)
        for usb in self.upScaleBlocks:
            cLambda = cLambdas.pop()
            x = usb(x, cLambda)
        x = self.finalConvolution(x)

class DownScaleBlock(nn.Module):
    def __init__(self, c0, c1) -> None:
        super().__init__()
        self.doubleConvolution = DoubleConvolution(c0, c1)
        self.downScaler = nn.MaxPool2d(2, 2, 1)
    def forward(self, x):
        cLambda = self.doubleConvolution(x)
        x = self.downScaler(cLambda)
        return x, cLambda

class UpScaleBlock(nn.Module):
    def __init__(self, c1, c0) -> None:
        super().__init__()
        self.doubleConvolution = DoubleConvolution(2*c1, c1)
        self.upScaler = nn.ConvTranspose2d(c1,c0,2,2,1)
    def forward(self, x, cLambda):
        # Concatenation occurs over the C channel axis (dim=1)
        x = torch.concat(x, cLambda, 1)
        x = self.doubleConvolution(x)
        x = self.upScaler(x)
        return x

## Figure 17: Visual Attention
<img src="SVG/visual_attention.svg" width="700">

We adapt our code for Multi-Head Attention to apply it to the vision case.
This is a good exercise in how Neural Circuit Diagrams allow code to be easily adapted for new modalities.

In [None]:
class VisualAttention(nn.Module):
    def __init__(self, c, k, heads = 1, kernel = 1, stride = 1):
        super().__init__()
        
        # w gives the kernel size, which we make adjustable.
        self.c, self.k, self.h, self.w = c, k, heads, kernel
        # Set up all the boldface, learned components
        # Note how standard components may not have axes bound in 
        # the same way as diagrams. This requires us to rearrange
        # using the einops package.

        # The learned layers form convolutions
        self.Cq = nn.Conv2d(c, k * heads, kernel, stride)
        self.Ck = nn.Conv2d(c, k * heads, kernel, stride)
        self.Cv = nn.Conv2d(c, k * heads, kernel, stride)
        self.Co = nn.ConvTranspose2d(
                            k * heads, c, kernel, stride)

    # Defined previously, closely follows the diagram.
    def MultiHeadDotProductAttention(self, q: tt, k: tt, v: tt) -> tt:
        ''' ykh, xkh, xkh -> ykh '''
        klength = k.size()[-2]
        x = einops.einsum(q, k, '... y k h, ... x k h -> ... y x h')
        x = torch.nn.Softmax(-2)(x / math.sqrt(klength))
        x = einops.einsum(x, v, '... y x h, ... x k h -> ... y k h')
        return x

    # We have endogenous data (EYc) and external / injected data (XXc)
    def forward(self, EcY, XcX):
        """ cY, cX -> cY 
        The visual attention algorithm. Injects information from Xc into Yc. """
        # query, key, and value vectors.
        # We unbind the k h axes which were produced by the convolutions, and feed them
        # in the normal manner to MultiHeadDotProductAttention.
        unbind = lambda x: einops.rearrange(x, 'N (k h) H W -> N (H W) k h', h=self.h)
        # Save size to recover it later
        q = self.Cq(EcY)
        W = q.size()[-1]

        # By appropriately managing the axes, minimal changes to our previous code
        # is necessary.
        q = unbind(q)
        k = unbind(self.Ck(XcX))
        v = unbind(self.Cv(XcX))
        o = self.MultiHeadDotProductAttention(q, k, v)

        # Rebind to feed to the transposed convolution layer.
        o = einops.rearrange(o, 'N (H W) k h -> N (k h) H W', 
                             h=self.h, W=W)
        return self.Co(o)

# Single batch element,
b = [1]
Y, X, c, k = [16, 16], [16, 16], [33], 8
# The additional configurations,
heads, kernel, stride = 4, 3, 3

# Internal Data,
EYc = torch.rand(b + c + Y)
# External Data,
XXc = torch.rand(b + c + X)

# We can now run the algorithm,
visualAttention = VisualAttention(c[0], k, heads, kernel, stride)

# Interestingly, the height/width reduces by 1 for stride
# values above 1. Otherwise, it stays the same.
visualAttention.forward(EYc, XXc).size()

torch.Size([1, 33, 15, 15])

# Appendix

In [None]:
# A container to track the size of modules,
class Tracker(nn.Module):
    def __init__(self, module: nn.Module, name : str = ""):
        super().__init__()
        self.module = module
        if name:
            self.name = name
        else:
            self.name = self.module._get_name()
    def forward(self, x):
        x_size = size_to_string(x.size())
        x = self.module.forward(x)
        y_size = size_to_string(x.size())
        print(f"{self.name}: \t {x_size} -> {y_size}")
        return x