In [1]:
import ncd
from ncd import shape

## Dot-Product Attention
Using *Neural Circuit Diagrams*, we can represent dot-product attention as follows:

<img src="Graphics/attention.png" width="700">

Each vertical section of a diagram corresponds to a shape. Columns represent either the data type we are working with or an operation between data types. Solid lines represent axes, and dashed lines separating terms represent Cartesian products. Placing a solid line adjacent to an operation, without separation, lifts it. Wirings represent Einstein operations, which include linear contractions and rearrangements.

In [2]:
def display_columns(target: ncd.Shape, name = None):
    """Render ``target`` as a column-by-column text listing.

    The first line shows ``target.dom``. Every element yielded by
    ``ncd.Composed.get_content(target)`` then contributes a ``Morphism:``
    line (the element itself) followed by an ``Object:`` line (its ``cod``).

    Args:
        target: Expression to list; must expose a ``dom`` attribute.
        name: Optional heading placed on its own line above the listing.
            Falsy values (``None``, ``''``) add no heading.

    Returns:
        The listing as a single newline-joined string.
    """
    lines = [f'Domain:   {target.dom}']
    for step in ncd.Composed.get_content(target):
        lines.append(f'Morphism: {step}')
        lines.append(f'Object:   {step.cod}')
    listing = '\n'.join(lines)
    # Same falsy test as before: an empty heading is treated as "no heading".
    return (name + '\n' + listing) if name else listing

In [3]:
from ncd import Duplicate
from ncd.nn import Einops, Linear, Addition, SoftMax
from itertools import starmap

# Short alias: Linear is instantiated four times below (q, k, v, o).
L = Linear

x = shape('x')
m = shape('m^')
# Copying is implicit: the three summed linears share a single input, and the
# 'Linears' printout below shows a Δ3 duplication placed in front of them.
linears = x >> (m @ (L('q') + L('k') + L('v')) @ 'k^')
# The axis names for einops are simply used as tags, which are attached
# to configurations.
ein1 = Einops('y k, x k -> y x')
# Printed below as [x→◁]; presumably `x >>` lifts SoftMax over the x axis,
# leaving the remaining axis as the one being normalised — TODO confirm.
softmax = x >> SoftMax()
ein2 = Einops('y x, x k -> y k')
# NOTE(review): unlike `linears`, this is not wrapped in `x >> ...`. In the
# composed printout the final morphism appears as plain `Lo` with object
# [m→], and the generated module declares Multilinear((x, k),(m)), so the
# output projection seems to consume the x axis as well. Confirm this is
# intended rather than a missing lift over x.
linOut = L('o') @ m

# Note, the printout will contain tagged axes. These are configured
# upon composition.
section_names = ['Linears', 'Einops', 'SoftMax', 'Einops', 'Linear']
sections = [linears, ein1, softmax, ein2, linOut]
print("Printout of Individual Sections;")
# zip pairs each section with its heading; starmap unpacks every pair into
# display_columns(section, heading).
print('\n'.join(starmap(display_columns, zip(sections, section_names))))

Printout of Individual Sections;
Linears
Domain:   [[36m[4mx[0m [36m[4mm[0m→]
Morphism: [[36m[4mx[0m [36m[4mm[0m→Δ3]
Object:   ([[36m[4mx[0m [36m[4mm[0m→], [[36m[4mx[0m [36m[4mm[0m→], [[36m[4mx[0m [36m[4mm[0m→])
Morphism: ([[36m[4mx[0m→Lq], [[36m[4mx[0m→Lk], [[36m[4mx[0m→Lv])
Object:   ([[36m[4mx[0m [36m[4mk[0m→], [[36m[4mx[0m [36m[4mk[0m→], [[36m[4mx[0m [36m[4mk[0m→])
Einops
Domain:   ([[34my[0m=[33my.CC[0m [34mk[0m=[33mk.DC[0m→], [[34mx[0m=[33mx.F7[0m [34mk[0m=[33mk.DC[0m→])
Morphism: [32m[4my k, x k -> y x[0m
Object:   [[34my[0m=[33my.CC[0m [34mx[0m=[33mx.F7[0m→]
SoftMax
Domain:   [[36m[4mx[0m [34m*[0m=[33m*.B1[0m→]
Morphism: [[36m[4mx[0m→[32m[4m◁[0m]
Object:   [[36m[4mx[0m [34m*[0m=[33m*.B1[0m→]
Einops
Domain:   ([[34my[0m=[33my.45[0m [34mx[0m=[33mx.DA[0m→], [[34mx[0m=[33mx.DA[0m [34mk[0m=[33mk.F1[0m→])
Morphism: [32m[4my x, x k -> y k[0m
Object:   [[34my[0m=

In [4]:
# Chain the sections with `@`. `@` binds tighter than `+`, so the middle
# factor is (ein1 @ softmax) + '*': the first two branches go through the
# scores einops and then SoftMax, while '*' stands in for the third branch.
# The printout shows that third [x k] entry carried along unchanged until
# ein2 consumes it.
attention = linears @ (ein1 @ softmax + '*') @ ein2 @ linOut
print("\nPrintout of Composed Expression;")
print(display_columns(attention))


Printout of Composed Expression;
Domain:   [[36m[4mx[0m [36m[4mm[0m→]
Morphism: [[36m[4mx[0m [36m[4mm[0m→Δ3]
Object:   ([[36m[4mx[0m [36m[4mm[0m→], [[36m[4mx[0m [36m[4mm[0m→], [[36m[4mx[0m [36m[4mm[0m→])
Morphism: ([[36m[4mx[0m→Lq], [[36m[4mx[0m→Lk], [[36m[4mx[0m→Lv])
Object:   ([[36m[4mx[0m [36m[4mk[0m→], [[36m[4mx[0m [36m[4mk[0m→], [[36m[4mx[0m [36m[4mk[0m→])
Morphism: ([32m[4my k, x k -> y x[0m, [[36m[4mx[0m [36m[4mk[0m→])
Object:   ([[36m[4mx[0m [36m[4mx[0m→], [[36m[4mx[0m [36m[4mk[0m→])
Morphism: ([[36m[4mx[0m→[32m[4m◁[0m], [[36m[4mx[0m [36m[4mk[0m→])
Object:   ([[36m[4mx[0m [36m[4mx[0m→], [[36m[4mx[0m [36m[4mk[0m→])
Morphism: [32m[4my x, x k -> y k[0m
Object:   [[36m[4mx[0m [36m[4mk[0m→]
Morphism: Lo
Object:   [[36m[4mm[0m→]


In [5]:
# We can use the marches package to disassemble an algebraic expression into
#   a graph, and to then compile it into code. Currently, PyTorch is supported.
import ncd.marches

# We redefine the expression using configurable axes so that the '__init__'
# function knows which configuration parameters are required.
x_conf = ncd.Conf('x')
m_conf = ncd.Conf('m')
k_conf = ncd.Conf('k')

# We use a functor which remaps objects to make them configurable.
# NOTE(review): the keys are shape('m') and shape('k'), whereas the expression
# was built with 'm^' and 'k^'. The printout below shows m and k remapped, so
# presumably the '^' suffix does not take part in the lookup — confirm.
make_configurable = ncd.DictFunctor({
    shape('x'): x_conf,
    shape('m'): m_conf,
    shape('k'): k_conf})

# Check that our functor worked, i.e. that it successfully mapped set objects
# to configurable objects.
print(display_columns(make_configurable(attention)))

Domain:   [[34mx[0m=[33mx.F9[0m [34mm[0m=[33mm[0m→]
Morphism: [[34mx[0m=[33mx.F9[0m [34mm[0m=[33mm[0m→Δ3]
Object:   ([[34mx[0m=[33mx.F9[0m [34mm[0m=[33mm[0m→], [[34mx[0m=[33mx.F9[0m [34mm[0m=[33mm[0m→], [[34mx[0m=[33mx.F9[0m [34mm[0m=[33mm[0m→])
Morphism: ([[34mx[0m=[33mx.F9[0m→Lq], [[34mx[0m=[33mx.F9[0m→Lk], [[34mx[0m=[33mx.F9[0m→Lv])
Object:   ([[34mx[0m=[33mx.F9[0m [34mk[0m=[33mk.23[0m→], [[34mx[0m=[33mx.F9[0m [34mk[0m=[33mk.23[0m→], [[34mx[0m=[33mx.F9[0m [34mk[0m=[33mk.23[0m→])
Morphism: ([32m[4my k, x k -> y x[0m, [[34mx[0m=[33mx.F9[0m [34mk[0m=[33mk.23[0m→])
Object:   ([[34mx[0m=[33mx.F9[0m [34mx[0m=[33mx.F9[0m→], [[34mx[0m=[33mx.F9[0m [34mk[0m=[33mk.23[0m→])
Morphism: ([[34mx[0m=[33mx.F9[0m→[32m[4m◁[0m], [[34mx[0m=[33mx.F9[0m [34mk[0m=[33mk.23[0m→])
Object:   ([[34mx[0m=[33mx.F9[0m [34mx[0m=[33mx.F9[0m→], [[34mx[0m=[33mx.F9[0m [34mk[0m=[33mk.23[0m

In [6]:
# It did! So we can compile it, with a correct __init__ function.
# 'Multilinear' is found in ncd.torch_utilities
# NOTE(review): the emitted text shown below contains no import lines, does
# not call super().__init__(), and `forward` refers to Lq/Lk/Lv/Lo without
# `self.` — it presumably needs manual touch-up before it can run as-is.
print(ncd.marches.to_torch(make_configurable(attention), "Attention"))

class Attention(nn.Module):
    def __init__(self, k, m, x):
        self.Lq = Multilinear((m),(k))
        self.Lk = Multilinear((m),(k))
        self.Lv = Multilinear((m),(k))
        self.Lo = Multilinear((x, k),(m))
    def forward(self, a):
        a, b, c = a, a, a
        a = Lq(a)
        b = Lk(b)
        c = Lv(c)
        a = einops.einsum(a, b, "y k, x k -> y x")
        a = torch.softmax(a, dim=-1)
        a = einops.einsum(a, c, "y x, x k -> y k")
        a = Lo(a)
        return a


## Multi-Head Dot Product Attention
We represent the more intricate multi-head dot-product attention as follows:

<img src="Graphics/multihead.png" width="700">

This diagram has an additional $h$ axis. The linear layers output data of size ``k h``, there is additional wiring for the Einops, and the SoftMax is lifted below. We can implement these changes using our algebraic tools.

In [7]:
from ncd import Duplicate, shape
from ncd.nn import Einops, Linear, Addition, SoftMax

# Multi-Headed Attention defined symbolically.
# We piece together individual sections.
# Upon composition, axes sizes are aligned!
# Note: L, m, x and linears rebind names already used by the single-head
# example, so re-running the single-head composition cell after this one
# would pick up the multi-head `linears`.
L = Linear
m = shape('*m^')
x = shape('x')

# Same q/k/v projections as in the single-head case, but the target is now
# the two axes '*k *h'.
linears = x >> (m @ (L('q') + L('k') + L('v')) @ '*k *h')
# Contract k between the first two inputs and keep h; the trailing '*'
# appears to pass the third input through (see the printout below).
einops = (Einops('q k h, x k h -> q x h') + '*')
# SoftMax with x on its left and '*' on its right (printed in the composed
# expression as [x→◁←h]), followed by the contraction against the third input.
softs = ((x >> SoftMax() << '*') + '*') @ Einops('q x h, x k h -> q k h')
linout = (shape('*k *h') @ L('o') @ m)

# Note, the printout will contain tagged axes. These are configured
# upon composition.
section_names = ['Linears', 'Einops', 'SoftMax + Einops', 'Linear']
sections = [linears, einops, softs, linout]
print("Printout of Individual Sections;")
# `starmap` and `display_columns` come from the earlier single-head cells, so
# this cell depends on them having been run first.
print('\n'.join(starmap(display_columns, zip(sections, section_names))))

Printout of Individual Sections;
Linears
Domain:   [[36m[4mx[0m [34mm[0m=[33mm.8E[0m→]
Morphism: [[36m[4mx[0m [34mm[0m=[33mm.8E[0m→Δ3]
Object:   ([[36m[4mx[0m [34mm[0m=[33mm.8E[0m→], [[36m[4mx[0m [34mm[0m=[33mm.8E[0m→], [[36m[4mx[0m [34mm[0m=[33mm.8E[0m→])
Morphism: ([[36m[4mx[0m→Lq], [[36m[4mx[0m→Lk], [[36m[4mx[0m→Lv])
Object:   ([[36m[4mx[0m [34mk[0m=[33mk.25[0m [34mh[0m=[33mh.3B[0m→], [[36m[4mx[0m [34mk[0m=[33mk.25[0m [34mh[0m=[33mh.3B[0m→], [[36m[4mx[0m [34mk[0m=[33mk.25[0m [34mh[0m=[33mh.3B[0m→])
Einops
Domain:   ([[34mq[0m=[33mq[0m [34mk[0m=[33mk.7C[0m [34mh[0m=[33mh.ED[0m→], [[34mx[0m=[33mx.E7[0m [34mk[0m=[33mk.7C[0m [34mh[0m=[33mh.ED[0m→], [34m[0m=[33m.D4[0m)
Morphism: ([32m[4mq k h, x k h -> q x h[0m, [34m[0m=[33m.D4[0m)
Object:   ([[34mq[0m=[33mq[0m [34mx[0m=[33mx.E7[0m [34mh[0m=[33mh.ED[0m→], [34m[0m=[33m.D4[0m)
SoftMax + Einops
Domain:   ([[36m[4

In [8]:
# Compose the four sections in order.
# NOTE(review): the last factor repeats the expression already bound to
# `linout` in the previous cell instead of reusing that name — confirm
# whether a fresh instance is required here or `linout` could be used.
multihead = linears @ einops @ softs @ (shape('*k *h') @ L('o') @ m)

print(display_columns(multihead))

Domain:   [[36m[4mx[0m [34mm[0m=[33mm.8E[0m→]
Morphism: [[36m[4mx[0m [34mm[0m=[33mm.8E[0m→Δ3]
Object:   ([[36m[4mx[0m [34mm[0m=[33mm.8E[0m→], [[36m[4mx[0m [34mm[0m=[33mm.8E[0m→], [[36m[4mx[0m [34mm[0m=[33mm.8E[0m→])
Morphism: ([[36m[4mx[0m→Lq], [[36m[4mx[0m→Lk], [[36m[4mx[0m→Lv])
Object:   ([[36m[4mx[0m [34mk[0m=[33mk.E5[0m [34mh[0m=[33mh.ED[0m→], [[36m[4mx[0m [34mk[0m=[33mk.E5[0m [34mh[0m=[33mh.ED[0m→], [[36m[4mx[0m [34mk[0m=[33mk.E5[0m [34mh[0m=[33mh.ED[0m→])
Morphism: ([32m[4mq k h, x k h -> q x h[0m, [[36m[4mx[0m [34mk[0m=[33mk.E5[0m [34mh[0m=[33mh.ED[0m→])
Object:   ([[36m[4mx[0m [36m[4mx[0m [34mh[0m=[33mh.ED[0m→], [[36m[4mx[0m [34mk[0m=[33mk.E5[0m [34mh[0m=[33mh.ED[0m→])
Morphism: ([[36m[4mx[0m→[32m[4m◁[0m←[34m[0m=[33mh.ED[0m], [[36m[4mx[0m [34mk[0m=[33mk.E5[0m [34mh[0m=[33mh.ED[0m→])
Object:   ([[36m[4mx[0m [36m[4mx[0m [34m[0m=[33mh.ED[0m

In [9]:
# We can use the "GetConfig" functor to accumulate the
# unassigned variables in its internal state. This allows
# us to quickly generate configuration parameters from an
# expression.
config = ncd.GetConfig()
# Applied only for its side effect on `config`; the return value is discarded.
config(multihead)
# The collected entries are printed below as the tagged axes k, m and h.
print(config.configs)

{[33mk.E5[0m, [33mm.8E[0m, [33mh.ED[0m}


In [10]:
# Marches is a package for compiling code.
# Here, it generates code for multi-headed attention.
import ncd.marches

# Unlike the single-head example, `multihead` is passed directly with no
# DictFunctor remapping; the generated __init__ below takes k, m and h.
print(ncd.marches.to_torch(multihead, "MultiHeadAttention"))

class MultiHeadAttention(nn.Module):
    def __init__(self, k, m, h):
        self.Lq = Multilinear((m),(k, h))
        self.Lk = Multilinear((m),(k, h))
        self.Lv = Multilinear((m),(k, h))
        self.Lo = Multilinear((k, h),(m))
    def forward(self, a):
        a, b, c = a, a, a
        a = Lq(a)
        b = Lk(b)
        c = Lv(c)
        a = einops.einsum(a, b, "q k h, x k h -> q x h")
        a = torch.softmax(a, dim=-2)
        a = einops.einsum(a, c, "q x h, x k h -> q k h")
        a = Lo(a)
        return a
