In [2]:
import sys
import jax.numpy as jnp
import numpy as np
from trax import layers as tl
from trax.shapes import signature
from trax.layers import combinators as cb
from trax.layers.assert_shape import assert_shape

sys.path.insert(0, "../..")
from src.models.build import summary


# 1 Convolutions
We can easily create new layers by combining the base layers from Trax. Most of the building blocks you know from PyTorch have a counterpart here.
For example, a [Conv2D layer](https://github.com/google/trax/blob/master/trax/layers/convolution.py):

In [3]:
X = np.random.rand(32, 128, 128, 3)
conv = tl.Conv(filters=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME")
conv.init_weights_and_state(signature(X))

yhat = conv(X)
signature(yhat)




ShapeDtype{shape:(32, 128, 128, 64), dtype:float32}

The arguments of Trax layers largely follow the conventions of the TensorFlow API, so you can always read a bit more in the excellent [TensorFlow documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D).

We can extend this one-layer model into a model with three serial Conv2D layers, as we built in the convolutions lesson:

In [4]:
model = cb.Serial(
    tl.Conv(filters=64, kernel_size=(3, 3), strides=(2, 2), padding="SAME"),
    tl.Relu(),
    tl.Conv(filters=128, kernel_size=(3, 3), strides=(2, 2), padding="SAME"),
    tl.Relu(),
    tl.Conv(filters=256, kernel_size=(3, 3), strides=(2, 2), padding="SAME"),
    tl.Relu(),
)
model.init_weights_and_state(signature(X))
summary(model, X)


layer                   input                dtype     output               dtype 
(0) Conv                (32, 128, 128, 3)  (float64) | (32, 64, 64, 64)   (float32)
(1) Relu                (32, 64, 64, 64)   (float32) | (32, 64, 64, 64)   (float32)
(2) Conv                (32, 64, 64, 64)   (float32) | (32, 32, 32, 128)  (float32)
(3) Relu                (32, 32, 32, 128)  (float32) | (32, 32, 32, 128)  (float32)
(4) Conv                (32, 32, 32, 128)  (float32) | (32, 16, 16, 256)  (float32)
(5) Relu                (32, 16, 16, 256)  (float32) | (32, 16, 16, 256)  (float32)


ShapeDtype{shape:(32, 16, 16, 256), dtype:float32}

I have written a custom "summary" function for Trax models. It does not work (yet) for parallel layers, but I thought it would help some of you with building models.
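
The spatial sizes in the convolution summary can be checked by hand. The following is a minimal sketch, assuming the usual rule that `"SAME"` padding gives an output size of `ceil(input_size / stride)`. The helper `same_padding_output_size` is hypothetical and not part of Trax.

```python
import math


def same_padding_output_size(input_size: int, stride: int) -> int:
    """Spatial output size of a convolution with SAME padding (assumed rule)."""
    return math.ceil(input_size / stride)


# Stride 1 keeps the spatial size, as in the single-layer example.
print(same_padding_output_size(128, stride=1))  # 128

# Three stride-2 convolutions halve the size each time: 128 -> 64 -> 32 -> 16.
size = 128
sizes = []
for _ in range(3):
    size = same_padding_output_size(size, stride=2)
    sizes.append(size)
print(sizes)  # [64, 32, 16]
```

The number of channels is not affected by the stride: it is simply the `filters` argument of each `Conv` layer.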

# 2 RNNs
Let's implement an Embedding + GRU model, like we have built in the Attention lesson #5:

In [5]:
X = np.random.randint(0, 1000, size=(32, 100))

model = cb.Serial(
    tl.Embedding(vocab_size=1000, d_feature=128), tl.GRU(n_units=128), tl.Dense(2)
)
model.init_weights_and_state(signature(X))

summary(model, X)


layer                   input                dtype     output               dtype 
(0) Embedding_1000_128  (32, 100)          ( int64 ) | (32, 100, 128)     (float32)
(1) GRU_128             (32, 100, 128)     (float32) | (32, 100, 128)     (float32)
(2) Dense_2             (32, 100, 128)     (float32) | (32, 100, 2)       (float32)


ShapeDtype{shape:(32, 100, 2), dtype:float32}
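
To see where the shapes in this summary come from, here is a rough NumPy sketch of the first and last layer. It is only an illustration with random weights, not how Trax implements these layers, and it skips the GRU (which keeps the shape unchanged here because `n_units` equals `d_feature`).

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, d_feature, d_out = 1000, 128, 2
tokens = rng.integers(0, vocab_size, size=(32, 100))  # (batch, sequence)

# An embedding is a lookup table with one row of d_feature values per token id.
embedding_table = rng.normal(size=(vocab_size, d_feature))
embedded = embedding_table[tokens]
print(embedded.shape)  # (32, 100, 128)

# A dense layer only acts on the last axis, so every timestep gets its own output.
weights = rng.normal(size=(d_feature, d_out))
bias = np.zeros(d_out)
logits = embedded @ weights + bias
print(logits.shape)  # (32, 100, 2)
```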

We can wrap this in a function, so we can pass the vocabulary size, the number of units and the output size as parameters:

In [6]:
@assert_shape("bs->bsd")
def EmbGRU(vocab_size: int, d_feature: int, d_out: int):

    model = cb.Serial(
        tl.Embedding(vocab_size=vocab_size, d_feature=d_feature),
        tl.GRU(n_units=d_feature),
        tl.Dense(d_out),
    )
    return model


# 3 The assert_shape decorator
Note how we use an `@assert_shape` decorator. This is a very nice safety check.

From the [trax source code](https://github.com/google/trax/blob/master/trax/layers/assert_shape.py#L70):

```python
  Examples:
  # In Dense layer there is a single input and single output; the last dimension
  # may change in size, while the sizes of all previous dimensions, marked by
  # an ellipsis, will stay the same.
  @assert_shape('...a->...b')
  class Dense(base.Layer):
    (...)

  # DotProductCausalAttention takes three tensors as input: Queries, Keys, and
  # Values, and outputs a single tensor. Sizes of the first two dimensions in
  # all those tensors must match, while the last dimension must match only
  # between Queries and Keys, and separately between Values and output tensor.
  @assert_shape('blk,blk,bld->bld')
  class DotProductCausalAttention(base.Layer):
    (...)

  # assert_shape can also be placed before the function returning base.Layer.
  @assert_shape('...d->...')
  def ReduceSum():
    return Fn('ReduceSum', lambda x: jnp.sum(x, axis=-1, keepdims=False))
```

Our EmbGRU model expects as input a tensor with dimensions (batch, sequence length), denoted with `bs`.
The output should be (batch, sequence length, dimension), denoted with `bsd`.
`@assert_shape` will check for us if this is correct.

We can use any letter we like. Trax will simply check that the input has two dimensions, and that the output has exactly the same two dimensions plus an additional third dimension.
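
The letter-matching idea can be sketched in a few lines of plain Python. This is a simplified illustration, not the Trax implementation: the function `check_shape_spec` is hypothetical, and it only handles a single input and output with one letter per dimension (the real decorator also supports ellipses and several tensors, as the docstring above shows).

```python
def check_shape_spec(spec, input_shape, output_shape):
    """Return True if both shapes are consistent with a spec such as 'bs->bsd'."""
    in_spec, out_spec = spec.split("->")
    sizes = {}  # maps each letter to the size it was first seen with
    for letters, shape in ((in_spec, input_shape), (out_spec, output_shape)):
        if len(letters) != len(shape):
            return False  # wrong number of dimensions
        for letter, size in zip(letters, shape):
            # setdefault stores the size on first sight, then returns the stored one
            if sizes.setdefault(letter, size) != size:
                return False  # same letter, different size
    return True


print(check_shape_spec("bs->bsd", (32, 100), (32, 100, 2)))  # True
print(check_shape_spec("bs->bsd", (32, 100), (32, 50, 2)))   # False
print(check_shape_spec("xy->xyz", (32, 100), (32, 100, 2)))  # True
```

The last call shows that the letters themselves carry no meaning: only their positions and repetitions matter.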

So this works:


In [7]:
model = EmbGRU(vocab_size=1000, d_feature=128, d_out=2)
model.init_weights_and_state(signature(X))
summary(model, X)


layer                   input                dtype     output               dtype 
(0) Embedding_1000_128  (32, 100)          ( int64 ) | (32, 100, 128)     (float32)
(1) GRU_128             (32, 100, 128)     (float32) | (32, 100, 128)     (float32)
(2) Dense_2             (32, 100, 128)     (float32) | (32, 100, 2)       (float32)


ShapeDtype{shape:(32, 100, 2), dtype:float32}

But this fails! Can you see why?

In [8]:
X_ = np.random.randint(0, 1000, size=(32, 10, 3))
model = EmbGRU(vocab_size=1000, d_feature=128, d_out=2)
try:
    model.init_weights_and_state(X_)
except Exception as e:
    print(e)


Exception passing through layer  (in _forward_abstract):
  layer created in file [...]/T/ipykernel_75043/461658037.py, line 2
  layer input shapes: [[[171 621 716]
  [ 46 744 995]
  [990 738 611]
  ...
[output truncated]

# 4 Implementing your own layers
We can also implement our own layers.

Again, from the [source code of trax](https://github.com/google/trax/blob/master/trax/layers/base.py#L747):

```python
def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function `f`.

  `f` can take and return any number of arguments, and takes only positional
  arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs -- element-wise sums and maxima:

      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`

  The layer's number of inputs (`n_in`) is automatically set to number of
  positional arguments in `f`, but you must explicitly set the number of
  outputs (`n_out`) whenever it's not the default value 1.

  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
        Output tensors must be packaged as specified in the `Layer` class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.

  Returns:
    Layer executing the function `f`.
  """
  ```

Let's implement our own Hadamard product.
You might remember from the RNN lesson that we used this for the gates:

$$h_t = \Gamma \otimes h_{t-1}$$

where $\Gamma$ is the gate (which has the same size as $h$ and values between 0 and 1), $h$ is the hidden state, and $\otimes$ the Hadamard product. For example:

$$
\begin{bmatrix}
0.9 & 0.02 \\
0.25 & -0.48
\end{bmatrix}
=
\begin{bmatrix}
0.9 & 0.01 \\
0.5 & 0.2
\end{bmatrix}
\otimes
\begin{bmatrix}
1.0 & 2.0 \\
0.5 & -2.4
\end{bmatrix}
$$
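
The same element-wise product can be reproduced in NumPy, which also makes the difference with an ordinary matrix product visible. This is a small sketch; the variable names are chosen for illustration only.

```python
import numpy as np

gate = np.array([[0.9, 0.01], [0.5, 0.2]])
hidden = np.array([[1.0, 2.0], [0.5, -2.4]])

# Hadamard product: multiply the entries that sit at the same position.
elementwise = np.multiply(gate, hidden)  # same as gate * hidden
print(elementwise)

# Matrix product: rows times columns, which is a different operation.
matrix_product = gate @ hidden
print(matrix_product)
```

Because every entry of the gate lies between 0 and 1, each entry of the hidden state is scaled down individually: a gate value near 1 lets it pass, a value near 0 suppresses it.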

In [9]:
# input
X = np.random.rand(32, 20)

# the gate is created with a softmax
softmax = tl.Softmax(axis=-1)
gate = softmax(X)

# the gate is applied with a Hadamard product, implemented in numpy/jax.numpy with `multiply`
out = jnp.multiply(X, gate)
out.shape


(32, 20)

The `jnp` function itself is simple, but to use it inside a model, so that gradients flow through it during backpropagation, it has to be a proper layer.
We turn it into one with `tl.Fn`:

In [10]:
def Hadamard():
    def f(x0, x1):
        return jnp.multiply(x0, x1)

    return tl.Fn("Hadamard", f, n_out=1)


Now we can use it as part of a model.
For example, let's implement the Gated Linear Unit:
$$GLU(X) = \sigma(W_1 X + b_1) \otimes (W_2 X + b_2)$$

Let's break this down into smaller pieces. The gate is created with:
$$gate(X) = \sigma(W_1 X + b_1)$$

The input is branched into both the gate and a basic linear layer, and the two results are combined with a Hadamard product:
$$gate(X) \otimes Linear(X)$$

<img src="../../reports/figures/GLU.png" />
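
Before building this with Trax layers, it can help to write the formula out in plain NumPy. The sketch below assumes that $\sigma$ denotes the logistic sigmoid, uses random weights purely for illustration, and follows the row-vector convention `x @ W + b` for a batch of inputs.

```python
import numpy as np


def sigmoid(z):
    """Logistic sigmoid, squashing every value into the range 0 to 1."""
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
batch, d_in, units = 32, 20, 128
x = rng.random((batch, d_in))

# Two independent linear maps (randomly initialised here, for illustration only).
w1, b1 = rng.normal(size=(d_in, units)), np.zeros(units)
w2, b2 = rng.normal(size=(d_in, units)), np.zeros(units)

gate = sigmoid(x @ w1 + b1)  # gate branch, values between 0 and 1
linear = x @ w2 + b2         # plain linear branch
glu_out = gate * linear      # Hadamard product of the two branches
print(glu_out.shape)  # (32, 128)
```

Both branches receive the same input `x`, which is the role `cb.Branch` plays in a Trax model.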

In [11]:
# the formula uses a sigmoid (sigma) for the gate
gate = cb.Serial(tl.Dense(128), tl.Sigmoid())

model = cb.Serial(
    cb.Branch(gate, tl.Dense(128)),
    Hadamard(),
)
model.init_weights_and_state(signature(X))
yhat = model(X)
signature(yhat)


ShapeDtype{shape:(32, 128), dtype:float32}

Or, wrap it inside a function.

In [12]:
# the Dense layers change the size of the last dimension, so the output needs its own letter
@assert_shape("bd->bu")
def GLU(units: int):
    # the formula uses a sigmoid (sigma) for the gate
    gate = cb.Serial(tl.Dense(units), tl.Sigmoid())

    model = cb.Serial(
        cb.Branch(gate, tl.Dense(units)),
        Hadamard(),
    )
    return model


I hope this notebook convinces you that:

- Implementing the PyTorch models we have built before in Trax is not that hard.
- Adding custom layers is pretty simple, once you have figured out how to implement them in numpy/jax.numpy. It helps that you will find numerous numpy examples on the internet.
- Trax is well suited for implementing parallel architectures, because combinators like `Branch` save you a lot of code complexity.
- Building complex architectures can be done by recombining smaller units, a bit like building with LEGO.
