# Introduction

In [2]:
from collections import OrderedDict
import functools

import math
import torch
from torch.distributions import constraints

%env FUNSOR_TYPECHECK=1
import funsor
from funsor.terms import Funsor, Variable, Number, Lambda, Slice
from funsor.tensor import Tensor
from funsor.domains import Array, Bint, Real, Reals
from funsor.factory import Bound, Fresh, Has, Value, make_funsor, to_funsor
import funsor.ops as ops
from funsor.cnf import Contraction
from funsor.testing import random_tensor
from funsor.interpretations import reflect, memoize
import funsor.torch.distributions as dist

funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)

env: FUNSOR_TYPECHECK=1


# Examples

## Building blocks

In [32]:
class Layer:
    """Minimal callable base class for the layers defined in this notebook.

    Subclasses override :meth:`forward`; calling an instance passes every
    positional and keyword argument on to it.
    """

    def __init__(self) -> None:
        """Nothing to initialise here; subclasses create their own parameters."""

    def forward(self, x: Tensor) -> Tensor:
        """Compute the layer output. The base class provides no implementation."""
        raise NotImplementedError

    def __call__(self, *inputs, **options):
        """Delegate to :meth:`forward` and return whatever it returns."""
        return self.forward(*inputs, **options)

### Feedforward neural networks

\begin{aligned}
  X^0 &\in \mathbb{R}^{\mathsf{\vphantom{fg}input}} \\
  X^1 &= \sigma(W^1 \mathbin{\underset{\substack{\mathsf{\vphantom{fg}input}}}{\vphantom{fg}\odot}} X^0 + b^1) & W^1 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}_1 \times \mathsf{\vphantom{fg}input}} & b^1 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}_1} \\
  X^2 &= \sigma(W^2 \mathbin{\underset{\substack{\mathsf{\vphantom{fg}hidden}_1}}{\vphantom{fg}\odot}} X^1 + b^2) & W^2 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}_2 \times \mathsf{\vphantom{fg}hidden}_1} & b^2 &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}_2} \\
  X^3 &= \sigma(W^3 \mathbin{\underset{\substack{\mathsf{\vphantom{fg}hidden}_2}}{\vphantom{fg}\odot}} X^2 + b^3) & W^3 &\in \mathbb{R}^{\mathsf{\vphantom{fg}out}\times \mathsf{\vphantom{fg}hidden}_2} & b^3 &\in \mathbb{R}^{\mathsf{\vphantom{fg}out}}
\end{aligned}

In [29]:
class FullConnLayer(Layer):
    """Fully connected layer: sigmoid of ``W`` contracted with ``x`` over "input", plus ``b``."""

    def __init__(self, input_size: int, output_size: int) -> None:
        """Create random weight and bias tensors and flag them for gradients.

        Args:
            input_size: size of the "input" dimension of the weight.
            output_size: size of the "output" dimension of the weight and bias.
        """
        weight_inputs = OrderedDict(
            [("input", Bint[input_size]), ("output", Bint[output_size])]
        )
        bias_inputs = OrderedDict([("output", Bint[output_size])])

        self.W = random_tensor(weight_inputs)
        self.W.data.requires_grad = True

        self.b = random_tensor(bias_inputs)
        self.b.data.requires_grad = True

    def forward(self, x: Funsor) -> Funsor:
        """Apply the layer to ``x``.

        The product with ``W`` is summed over the "input" name, the bias is
        added, and the sigmoid is applied.

        Returns:
            The activation with its "output" name renamed to "input", so the
            result can be passed directly to another ``FullConnLayer``.
        """
        pre_activation = (self.W * x).reduce(ops.add, "input") + self.b
        activated = ops.sigmoid(pre_activation)
        return activated(output="input")

In [30]:
# Three stacked layers, 100 -> 32 -> 16 -> 8, created in that order.
FullConn1, FullConn2, FullConn3 = (
    FullConnLayer(n_in, n_out) for n_in, n_out in [(100, 32), (32, 16), (16, 8)]
)

In [31]:
# Push a random 100-dimensional input through the three layers in turn.
X0 = random_tensor(OrderedDict(input=Bint[100]))
X1 = FullConn1(X0)
X2 = FullConn2(X1)
X3 = FullConn3(X2)
X3

Tensor(tensor([0.3556, 0.5652, 0.9982, 0.3874, 0.9996, 0.7776, 0.8879, 0.9356],
       grad_fn=<SigmoidBackward>), OrderedDict([('input', Bint[8, ])]), 'real')

### Recurrent neural networks

\begin{aligned}
x^{t} &\in \mathbb{R}^{\mathsf{\vphantom{fg}input}} & t &= 1, \ldots, n \\
W^{\text{h}} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}\times \mathsf{\vphantom{fg}hidden}^\prime} & |\mathsf{\vphantom{fg}hidden}| &= |\mathsf{\vphantom{fg}hidden}^\prime| \\
W^{\text{i}} &\in \mathbb{R}^{\mathsf{\vphantom{fg}input}\times \mathsf{\vphantom{fg}hidden}^\prime} \\
b &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}^\prime} \\
h^{0} &\in \mathbb{R}^{\mathsf{\vphantom{fg}hidden}} \\
h^{t} &= \sigma\left( W^{\text{h}} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}hidden}}}{\vphantom{fg}\odot}} h^{t-1} + W^{\text{i}} \mathbin{\underset{\substack{\mathsf{\vphantom{fg}input}}}{\vphantom{fg}\odot}} x^{t} + b \right)_{\mathsf{\vphantom{fg}hidden}^\prime\rightarrow\mathsf{\vphantom{fg}hidden}} & t &= 1, \ldots, n
\end{aligned}

In [33]:
class RecurrentLayer(Layer):
    """Single recurrent step: ``h' = sigmoid(Wh . h + Wi . x + b)``.

    The parameters use the auxiliary name "hidden2" for the output dimension,
    which is renamed back to "hidden" at the end of :meth:`forward`.
    """

    def __init__(self, input_size: int, hidden_size: int) -> None:
        """Create random, gradient-enabled parameters.

        Args:
            input_size: size of the "input" dimension of ``Wi``.
            hidden_size: size of the "hidden" and "hidden2" dimensions.
        """
        self.Wh = random_tensor(
            OrderedDict([
                ("hidden", Bint[hidden_size]),
                ("hidden2", Bint[hidden_size])
            ])
        )
        self.Wh.data.requires_grad = True

        self.Wi = random_tensor(
            OrderedDict([
                ("input", Bint[input_size]),
                ("hidden2", Bint[hidden_size])
            ])
        )
        self.Wi.data.requires_grad = True

        self.b = random_tensor(
            OrderedDict([("hidden2", Bint[hidden_size])])
        )
        self.b.data.requires_grad = True

    def forward(self, x: Funsor, h: Funsor) -> Funsor:
        """Compute the next hidden state from input ``x`` and previous state ``h``.

        ``Wh * h`` is summed over "hidden" and ``Wi * x`` over "input"; the
        bias is added and the sigmoid applied.

        Returns:
            The new state with "hidden2" renamed to "hidden".
        """
        recurrent_term = (self.Wh * h).reduce(ops.add, "hidden")
        input_term = (self.Wi * x).reduce(ops.add, "input")
        # Use the layer's own bias: a bare `b` is not defined in this scope and
        # would either raise NameError or pick up an unrelated global.
        out = ops.sigmoid(recurrent_term + input_term + self.b)
        return out(**{"hidden2": "hidden"})

### Attention

In [34]:
@make_funsor
def Softmax(
    x: Funsor,
    ax: Bound,
    ax2: Fresh[lambda ax: ax]
) -> Fresh[lambda x: x]:
    """Softmax of ``x`` along ``ax``; the result is indexed by the fresh name ``ax2``.

    Computed as ``exp(x - logsumexp(x))`` after renaming ``ax`` to ``ax2``.
    """
    renamed = x(**{ax.name: ax2.name})
    log_normalizer = renamed.reduce(ops.logaddexp, ax2)
    return (renamed - log_normalizer).exp()

\begin{aligned}
  \text{Attention} \colon \mathbb{R}^{\mathsf{\vphantom{fg}key}} \times \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times\mathsf{\vphantom{fg}key}} \times \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times\mathsf{\vphantom{fg}val}} \times \mathbb{R}^{\mathsf{\vphantom{fg}seq}} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}val}} \\
\text{Attention}(Q, K, V, M) &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq}}}{\vphantom{fg}\mathrm{softmax}}} \left( \frac{Q \mathbin{\underset{\substack{\mathsf{\vphantom{fg}key}}}{\vphantom{fg}\odot}} K}{\sqrt{|\mathsf{\vphantom{fg}key}|}} + M \right) \mathbin{\underset{\substack{\mathsf{\vphantom{fg}seq}}}{\vphantom{fg}\odot}} V.
\end{aligned}

In [35]:
@make_funsor
def Attention(
    Q: Has[{"key"}],
    K: Has[{"key", "seq"}],
    V: Has[{"seq2"}],
    M: Has[{"seq"}],
    key: Bound,
    seq: Bound,
    seq2: Bound
) -> Fresh[lambda Q: Q]:
    """Scaled dot-product attention with an additive mask ``M``.

    ``Q * K`` is summed over ``key`` and divided by the square root of the
    size of ``key``; ``M`` is added; a softmax over ``seq`` (renamed to
    ``seq2``) gives the weights, which are multiplied with ``V`` and summed
    over ``seq2``.
    """
    scale = math.sqrt(key.output.size)
    scores = (Q * K).reduce(ops.add, key) / scale + M
    weights = Softmax(scores, seq, seq2)
    return (weights * V).reduce(ops.add, seq2)

In [36]:
# Random query, keys, values and mask; the values are indexed by "seq2"
# because that is the name Softmax produces inside Attention.
q = random_tensor(OrderedDict(key=Bint[10]))
k = random_tensor(OrderedDict(key=Bint[10], seq=Bint[3]))
v = random_tensor(OrderedDict(seq2=Bint[3], val=Bint[5]))
m = random_tensor(OrderedDict(seq=Bint[3]))
Attention(q, k, v, m, "key", "seq", "seq2")

Tensor(tensor([-0.2815,  1.1418, -0.1922, -1.3501, -1.0335]), OrderedDict([('val', Bint[5, ])]), 'real')

### Convolution

\begin{aligned}
  \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq}\\ \mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{unroll}}} \colon \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n-|\mathsf{\vphantom{fg}kernel}|+1], \mathsf{\vphantom{fg}kernel}} \\
  \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq}\\ \mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{unroll}}} X &= Y,\ \text{where} \\
  Y_{\mathsf{\vphantom{fg}seq}(i), \mathsf{\vphantom{fg}kernel}(j)} &= X_{\mathsf{\vphantom{fg}seq}(i+j - 1)}.
\end{aligned}

In [28]:
@make_funsor
def Unroll(
    x: Has[{"seq"}],
    seq: Bound,
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
    seq2: Fresh[lambda seq, k: Bint[seq.size - k + 1]]
) -> Fresh[lambda x: x]:
    """Expand ``x`` along ``seq`` into overlapping windows of length ``k``.

    The result is indexed by ``seq2`` (window start, ``seq.size - k + 1``
    positions) and ``kernel`` (offset inside the window, ``k`` positions);
    entry ``(seq2, kernel)`` is ``x`` at position ``seq2 + kernel``.
    """
    window_position = seq2 + kernel
    return x(**{seq.name: window_position})

\begin{aligned}
\text{Conv1d} \colon \mathbb{R}^{\mathsf{\vphantom{fg}chans}\times \mathsf{\vphantom{fg}seq}[n]} &\rightarrow \mathbb{R}^{\mathsf{\vphantom{fg}seq}[n^\prime]} \\
\text{Conv1d}(X; W, b) &= W \mathbin{\underset{\substack{\mathsf{\vphantom{fg}chans}\\ \mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\odot}} \mathop{\underset{\substack{\mathsf{\vphantom{fg}seq}\\ \mathsf{\vphantom{fg}kernel}}}{\vphantom{fg}\mathrm{unroll}}} X + b
\end{aligned}

\begin{aligned}
W &\in \mathbb{R}^{\mathsf{\vphantom{fg}chans}\times \mathsf{\vphantom{fg}kernel}} \\
b &\in \mathbb{R}\\
\end{aligned}

In [38]:
class Conv1d(Layer):
    """One-dimensional convolution over the "seq" name, built on ``Unroll``."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        """Create random, gradient-enabled weight and bias.

        Args:
            in_channels: size of the "chans" dimension of the weight.
            out_channels: size of the "chans2" dimension of weight and bias.
            kernel_size: size of the "kernel" dimension and the window length.
        """
        self.kernel_size = kernel_size

        weight_inputs = OrderedDict(
            [
                ("chans", Bint[in_channels]),
                ("chans2", Bint[out_channels]),
                ("kernel", Bint[kernel_size]),
            ]
        )
        self.W = random_tensor(weight_inputs)
        self.W.data.requires_grad = True

        self.b = random_tensor(OrderedDict([("chans2", Bint[out_channels])]))
        self.b.data.requires_grad = True

    def forward(self, x: Funsor) -> Funsor:
        """Convolve ``x`` along "seq".

        ``x`` is unrolled into windows, multiplied by the weight, summed over
        "chans" and "kernel", and the bias is added.

        Returns:
            The result with "chans2" renamed to "chans" and "seq2" to "seq".
        """
        windows = Unroll(x, "seq", self.kernel_size, "kernel", "seq2")
        weighted = self.W * windows
        out = weighted.reduce(ops.add, {"chans", "kernel"}) + self.b
        return out(chans2="chans", seq2="seq")

In [39]:
class Conv2d(Layer):
    """Two-dimensional convolution over "height" and "width", built on ``Unroll``."""

    def __init__(self, in_channels: int, out_channels: int, kh_size: int, kw_size: int) -> None:
        """Create random, gradient-enabled weight and bias.

        Args:
            in_channels: size of the "chans" dimension of the weight.
            out_channels: size of the "chans2" dimension of weight and bias.
            kh_size: kernel height (size of the "kh" dimension).
            kw_size: kernel width (size of the "kw" dimension).
        """
        self.W = random_tensor(
            OrderedDict([
                ("chans", Bint[in_channels]),
                ("chans2", Bint[out_channels]),
                ("kh", Bint[kh_size]),
                # Was `Kw_size`, an undefined name that raised NameError.
                ("kw", Bint[kw_size])
            ])
        )
        self.W.data.requires_grad = True

        self.b = random_tensor(
            OrderedDict([
                ("chans2", Bint[out_channels])
            ])
        )
        self.b.data.requires_grad = True

        self.kh_size = kh_size
        self.kw_size = kw_size

    def forward(self, x: Funsor) -> Funsor:
        """Convolve ``x`` along "height" and "width".

        ``x`` is unrolled first along "width", then along "height"; the
        windows are multiplied by the weight, summed over "chans", "kh" and
        "kw", and the bias is added.

        Returns:
            The result with "chans2", "height2" and "width2" renamed back to
            "chans", "height" and "width".
        """
        unrolled_w_x = Unroll(x, "width", self.kw_size, "kw", "width2")
        unrolled_hw_x = Unroll(unrolled_w_x, "height", self.kh_size, "kh", "height2")
        out = (self.W * unrolled_hw_x).reduce(ops.add, {"chans", "kh", "kw"}) + self.b
        return out(**{"chans2": "chans", "height2": "height", "width2": "width"})

### Max pooling

In [40]:
@make_funsor
def Pool(
    x: Has[{"seq"}],
    seq: Bound,
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
    seq2: Fresh[lambda seq, k: Bint[seq.size // k]],  # domain derived from seq's Bint
) -> Fresh[lambda x: x]:  # same output domain as x
    """Split ``seq`` into non-overlapping windows of length ``k``.

    The result is indexed by ``seq2`` (window number) and ``kernel`` (offset
    inside the window); entry ``(seq2, kernel)`` is ``x`` at position
    ``seq2 * k + kernel``. The size of ``seq`` must be a multiple of ``k``.
    """
    assert seq.output.size % k == 0
    # The constant k, typed with bound k + 1 (presumably so the value k itself
    # is in range).
    stride = Number(k, k + 1)
    return x(**{seq.name: seq2 * stride + kernel})

In [41]:
# Ten positions split into five windows of two.
X = random_tensor(OrderedDict(seq=Bint[10]))
Y = Pool(X, "seq", 2, "kernel", "seq2")
Y

Tensor(tensor([[ 1.5662,  0.0141],
        [-0.2258,  0.8933],
        [ 0.0305, -0.3126],
        [-0.4944, -2.0142],
        [-0.3566,  0.6063]]), OrderedDict([('seq2', Bint[5, ]), ('kernel', Bint[2, ])]), 'real')

In [42]:
@make_funsor
def MaxPool1d(
    X: Has[{"seq"}],
    seq: Bound,
    k: Value[int],
    kernel: Fresh[lambda k: Bint[k]],
    seq2: Fresh[lambda seq, k: Bint[seq.size // k]]
) -> Fresh[lambda X: X]:
    """Max over non-overlapping windows of length ``k`` along ``seq``.

    ``Pool`` splits ``seq`` into ``seq2`` x ``kernel``; the maximum is then
    taken over ``kernel``.
    """
    windows = Pool(X, seq, k, kernel, seq2)
    return windows.reduce(ops.max, kernel)

In [43]:
# Max over each of the five windows of two.
X = random_tensor(OrderedDict(seq=Bint[10]))
Y = MaxPool1d(X, "seq", 2, "kernel", "seq2")
Y

Tensor(tensor([ 0.8018, -0.1132,  0.6756,  0.9458,  0.3273]), OrderedDict([('seq2', Bint[5, ])]), 'real')

In [44]:
@make_funsor
def MaxPool2d(
    X: Has[{"height", "width"}],
    height: Bound,
    kh_size: Value[int],
    kh: Fresh[lambda kh_size: Bint[kh_size]],
    height2: Fresh[lambda height, kh_size: Bint[height.size // kh_size]],
    width: Bound,
    kw_size: Value[int],
    kw: Fresh[lambda kw_size: Bint[kw_size]],
    width2: Fresh[lambda width, kw_size: Bint[width.size // kw_size]],
) -> Fresh[lambda X: X]:
    """Max over non-overlapping ``kh_size`` x ``kw_size`` windows.

    ``Pool`` is applied along ``height`` and then along ``width``; the
    maximum is taken jointly over the two window offsets ``kh`` and ``kw``.
    """
    pooled_rows = Pool(X, height, kh_size, kh, height2)
    pooled = Pool(pooled_rows, width, kw_size, kw, width2)
    return pooled.reduce(ops.max, frozenset([kh, kw]))

In [45]:
# A 9 x 4 grid pooled with windows of 3 along "width" and 2 along "height".
X = random_tensor(OrderedDict(width=Bint[9], height=Bint[4]))
Y = MaxPool2d(X, "height", 2, "kh", "height2", "width", 3, "kw", "width2")
Y

Tensor(tensor([[0.2411, 0.4493],
        [1.9085, 1.4201],
        [1.8092, 1.9324]]), OrderedDict([('width2', Bint[3, ]), ('height2', Bint[2, ])]), 'real')

### Normalization layers

In [53]:
# version 1: move the named axis into the output shape, then use array ops
@make_funsor
def Mean(
    X: Has[{"ax"}],
    ax: Bound
) -> Fresh[lambda X: X]:
    """Mean of ``X`` over ``ax``, via ``ops.mean`` on the ``Lambda``-bound axis."""
    stacked = Lambda(ax, X)
    return ops.mean(stacked, 0)

@make_funsor
def Variance(
    X: Has[{"ax"}],
    ax: Bound
) -> Fresh[lambda X: X]:
    """Variance of ``X`` over ``ax``, via ``ops.var`` on the ``Lambda``-bound axis."""
    stacked = Lambda(ax, X)
    return ops.var(stacked, 0)

In [52]:
# version 2: reduce-based definitions; running this cell rebinds the names
# Mean and Variance defined in version 1
@make_funsor
def Mean(
    X: Has[{"ax"}],
    ax: Bound
) -> Fresh[lambda X: X]:
    """Mean of ``X`` over ``ax``: the sum over ``ax`` divided by its size."""
    total = X.reduce(ops.add, ax)
    return total / ax.output.size

@make_funsor
def Mean2(
    X: Has[{"ax", "ax2"}],
    ax: Bound,
    ax2: Bound
) -> Fresh[lambda X: X]:
    """Mean of ``X`` jointly over ``ax`` and ``ax2``."""
    count = ax.output.size * ax2.output.size
    return X.reduce(ops.add, frozenset([ax, ax2])) / count

@make_funsor
def Variance(
    X: Has[{"ax"}],
    ax: Bound
) -> Fresh[lambda X: X]:
    """Variance of ``X`` over ``ax``: the mean squared deviation from the mean."""
    centered = X - Mean(X, ax)
    return Mean(centered**2, ax)


@make_funsor
def Variance2(
    X: Has[{"ax", "ax2"}],
    ax: Bound,
    ax2: Bound
) -> Fresh[lambda X: X]:
    """Variance of ``X`` jointly over ``ax`` and ``ax2``."""
    centered = X - Mean2(X, ax, ax2)
    return Mean2(centered**2, ax, ax2)

In [53]:
@make_funsor
def Standardize(
    X: Has[{"ax"}],
    ax: Bound,
    new_ax: Fresh[lambda ax: ax]
) -> Fresh[lambda X: X]:
    """Standardize ``X`` along ``ax``; the result is indexed by ``new_ax``.

    Subtracts the mean over ``ax`` and divides by the square root of the
    variance plus the machine epsilon of ``X.data``.
    """
    renamed = X(**{ax.name: new_ax})
    centered = renamed - Mean(X, ax)
    scale = (Variance(X, ax) + ops.finfo(X.data).eps).sqrt()
    return centered / scale

@make_funsor
def Standardize2(
    X: Has[{"ax", "ax2"}],
    ax: Bound,
    ax2: Bound,
    new_ax: Fresh[lambda ax: ax],
    new_ax2: Fresh[lambda ax2: ax2]
) -> Fresh[lambda X: X]:
    """Standardize ``X`` jointly along ``ax`` and ``ax2``.

    The result is indexed by ``new_ax`` and ``new_ax2``. The joint mean is
    subtracted and the result divided by the square root of the joint variance
    plus the machine epsilon of ``X.data``.
    """
    renamed = X(**{ax.name: new_ax, ax2.name: new_ax2})
    centered = renamed - Mean2(X, ax, ax2)
    scale = (Variance2(X, ax, ax2) + ops.finfo(X.data).eps).sqrt()
    return centered / scale

In [54]:
class BatchNorm(Layer):
    """Standardize jointly over "batch" and "layer", then scale and shift per channel."""

    def __init__(self, num_channels: int) -> None:
        """Create random, gradient-enabled ``beta`` and ``gamma`` indexed by "chans".

        Args:
            num_channels: size of the "chans" dimension.
        """
        # beta is drawn before gamma.
        for param_name in ("beta", "gamma"):
            param = random_tensor(OrderedDict([("chans", Bint[num_channels])]))
            param.data.requires_grad = True
            setattr(self, param_name, param)

    def forward(self, x: Funsor) -> Funsor:
        """Normalize ``x`` and apply ``gamma`` (scale) and ``beta`` (shift).

        Returns:
            The result with "batch2" and "layer2" renamed back to "batch"
            and "layer".
        """
        standardized = Standardize2(x, "batch", "layer", "batch2", "layer2")
        out = standardized * self.gamma + self.beta
        return out(batch2="batch", layer2="layer")

In [55]:
class InstanceNorm(Layer):
    """Standardize over "layer" only, then scale and shift per channel."""

    def __init__(self, num_channels: int) -> None:
        """Create random, gradient-enabled ``beta`` and ``gamma`` indexed by "chans".

        Args:
            num_channels: size of the "chans" dimension.
        """
        # beta is drawn before gamma.
        for param_name in ("beta", "gamma"):
            param = random_tensor(OrderedDict([("chans", Bint[num_channels])]))
            param.data.requires_grad = True
            setattr(self, param_name, param)

    def forward(self, x: Funsor) -> Funsor:
        """Normalize ``x`` and apply ``gamma`` (scale) and ``beta`` (shift).

        Returns:
            The result with "layer2" renamed back to "layer".
        """
        standardized = Standardize(x, "layer", "layer2")
        out = standardized * self.gamma + self.beta
        return out(layer2="layer")

In [56]:
class LayerNorm(Layer):
    """Standardize jointly over "chans" and "layer", then scale and shift elementwise."""

    def __init__(self, num_channels: int, num_layers: int) -> None:
        """Create random, gradient-enabled ``beta`` and ``gamma``.

        Args:
            num_channels: size of the "chans" dimension of both parameters.
            num_layers: size of the "layer" dimension of both parameters.
        """
        # beta is drawn before gamma.
        for param_name in ("beta", "gamma"):
            param = random_tensor(
                OrderedDict(
                    [("chans", Bint[num_channels]), ("layer", Bint[num_layers])]
                )
            )
            param.data.requires_grad = True
            setattr(self, param_name, param)

    def forward(self, x: Funsor) -> Funsor:
        """Normalize ``x`` and apply ``gamma`` (scale) and ``beta`` (shift).

        Returns:
            The result with "chans2" and "layer2" renamed back to "chans"
            and "layer".
        """
        standardized = Standardize2(x, "chans", "layer", "chans2", "layer2")
        out = standardized * self.gamma + self.beta
        return out(chans2="chans", layer2="layer")

In [57]:
# Random input with batch, channel and layer dimensions.
x = random_tensor(OrderedDict(batch=Bint[4], chans=Bint[3], layer=Bint[5]))

BatchNorm(3)(x)

Tensor(tensor([[[-1.2064, -0.5113, -0.8763, -1.1434, -1.0650],
         [-0.7746, -1.4747,  0.1591,  0.2064,  1.0186],
         [ 1.2582,  1.2571,  1.2567,  1.2565,  1.2605]],

        [[-0.7646, -0.7682, -1.0567, -0.5184, -1.1284],
         [-0.9481, -0.8563, -0.9254, -0.9457, -1.7601],
         [ 1.2558,  1.2576,  1.2577,  1.2588,  1.2579]],

        [[-1.4398, -0.4638, -0.6619, -0.4812, -0.9846],
         [ 0.2047, -1.2126,  0.2145,  0.3439, -2.3063],
         [ 1.2588,  1.2566,  1.2576,  1.2594,  1.2609]],

        [[-0.6510, -0.7805, -0.6757, -0.4585, -1.0590],
         [-0.5686,  1.1107, -0.9664, -0.7939, -0.1247],
         [ 1.2577,  1.2584,  1.2576,  1.2582,  1.2606]]],
       grad_fn=<AddBackward0>), OrderedDict([('batch', Bint[4, ]), ('chans', Bint[3, ]), ('layer', Bint[5, ])]), 'real')

## Transformer

## LeNet

In [79]:
@make_funsor
def Relu(
    X: Funsor
) -> Fresh[lambda X: X]:
    """Rectified linear unit: the elementwise maximum of ``X`` and zero."""
    zero = Number(0.0)
    return ops.max(X, zero)

In [80]:
# Parameters: a 3 x 4 convolution kernel, a hidden layer over the pooled
# feature map, and a 5-way classifier.
W1 = random_tensor(
    OrderedDict([
        ("chans", Bint[3]),
        ("kh", Bint[3]),
        ("kw", Bint[4]),
        ("chans2", Bint[3])
    ]),
)
b1 = random_tensor(OrderedDict([("chans2", Bint[3])]))
W3 = random_tensor(
    OrderedDict([
        ("hidden", Bint[3]),
        ("height3", Bint[4]),
        ("width3", Bint[4]),
        ("chans2", Bint[3])
    ]),
)
b3 = random_tensor(OrderedDict([("hidden", Bint[3])]))
W4 = random_tensor(
    OrderedDict([
        ("hidden", Bint[3]),
        ("classes", Bint[5]),
    ]),
)
b4 = random_tensor(OrderedDict([("classes", Bint[5])]))

# Input batch: 4 images, 3 channels, 14 x 15 pixels.
X0 = random_tensor(
    OrderedDict([
        ("batch", Bint[4]),
        ("chans", Bint[3]),
        ("height", Bint[14]),
        ("width", Bint[15])
    ])
)

# Convolution written out with Unroll. The Conv2d class in this notebook takes
# layer sizes in its constructor, so calling it with (input, weight, bias,
# axis names, ...) raised TypeError. Unrolling gives width2 = 15 - 4 + 1 = 12
# and height2 = 14 - 3 + 1 = 12.
unrolled_w = Unroll(X0, "width", 4, "kw", "width2")
unrolled_hw = Unroll(unrolled_w, "height", 3, "kh", "height2")
T1 = Relu(
    (W1 * unrolled_hw).reduce(ops.add, frozenset({"chans", "kh", "kw"})) + b1
)
# 3 x 3 max pooling: height3 = width3 = 12 // 3 = 4, matching W3.
X1 = MaxPool2d(T1, "height2", 3, "kh", "height3", "width2", 3, "kw", "width3")
X3 = (W3 * X1).reduce(ops.add, frozenset({"height3", "width3", "chans2"})) + b3
O = Softmax(((W4 * X3).reduce(ops.add, "hidden") + b4), "classes", "classes2")