In [2]:
import sys
import jax.numpy as jnp
import numpy as np
from trax import layers as tl
from trax.shapes import signature
from trax.layers import combinators as cb
from trax.layers.assert_shape import assert_shape

sys.path.insert(0, "../..")
from src.models.build import summary


# 1 Convolutions
We can easily create new layers by combining the base layers from trax. You will find everything you know from torch.
For example, a Conv2D layer:

In [3]:
X = np.random.rand(32, 256, 256, 3)
conv = tl.Conv(filters=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME")
conv.init_weights_and_state(signature(X))

yhat = conv(X)
signature(yhat)



ShapeDtype{shape:(32, 256, 256, 64), dtype:float32}

Or a model with three serial Conv2d layers as we have built in the convolutions lesson

In [5]:
model = cb.Serial(
    tl.Conv(filters=64, kernel_size=(3,3), strides=(2,2), padding="SAME"),
    tl.Relu(),
    tl.Conv(filters=128, kernel_size=(3,3), strides=(2,2), padding="SAME"),
    tl.Relu(),
    tl.Conv(filters=256, kernel_size=(3,3), strides=(2,2), padding="SAME"),
    tl.Relu(),
)
model.init_weights_and_state(X)
summary(model, X)

layer                   input                dtype     output               dtype 
(0) Conv                (32, 256, 256, 3)  (float64) | (32, 128, 128, 64) (float32)
(1) Relu                (32, 128, 128, 64) (float32) | (32, 128, 128, 64) (float32)
(2) Conv                (32, 128, 128, 64) (float32) | (32, 64, 64, 128)  (float32)
(3) Relu                (32, 64, 64, 128)  (float32) | (32, 64, 64, 128)  (float32)
(4) Conv                (32, 64, 64, 128)  (float32) | (32, 32, 32, 256)  (float32)
(5) Relu                (32, 32, 32, 256)  (float32) | (32, 32, 32, 256)  (float32)


ShapeDtype{shape:(32, 32, 32, 256), dtype:float32}

# 2 RNNs
Let's implement an Embedding + GRU + Linear model, like we built in the Attention lesson:

In [14]:
X = np.random.randint(0, 1000, size=(32, 100))

model = cb.Serial(
    tl.Embedding(vocab_size=1000, d_feature=128),
    tl.GRU(n_units=128),
    tl.Dense(2)
)
model.init_weights_and_state(signature(X))

summary(model, X)

layer                   input                dtype     output               dtype 
(0) Embedding_1000_128  (32, 100)          ( int64 ) | (32, 100, 128)     (float32)
(1) GRU_128             (32, 100, 128)     (float32) | (32, 100, 128)     (float32)
(2) Dense_2             (32, 100, 128)     (float32) | (32, 100, 2)       (float32)


ShapeDtype{shape:(32, 100, 2), dtype:float32}

We can wrap this in a function, so we can pass the vocab_size and the units as parameters

In [15]:
@assert_shape('bs->bsd')
def EmbGRU(vocab_size: int, d_feature: int, d_out: int):

    model = cb.Serial(
        tl.Embedding(vocab_size=vocab_size, d_feature=d_feature),
        tl.GRU(n_units=d_feature),
        tl.Dense(d_out)
    )
    return model


## 3 The assert_shape decorator
Note how we use an `@assert_shape` decorator. This is a very nice safety check.

From the [trax source code](https://github.com/google/trax/blob/master/trax/layers/assert_shape.py#L70):

```python
  Examples:
  # In Dense layer there is a single input and single output; the last dimension
  # may change in size, while the sizes of all previous dimensions, marked by
  # an ellipsis, will stay the same.
  @assert_shape('...a->...b')
  class Dense(base.Layer):
    (...)

  # DotProductCausalAttention takes three tensors as input: Queries, Keys, and
  # Values, and outputs a single tensor. Sizes of the first two dimensions in
  # all those tensors must match, while the last dimension must match only
  # between Queries and Keys, and separately between Values and output tensor.
  @assert_shape('blk,blk,bld->bld')
  class DotProductCausalAttention(base.Layer):
    (...)

  # assert_shape can also be placed before the function returning base.Layer.
  @assert_shape('...d->...')
  def ReduceSum():
    return Fn('ReduceSum', lambda x: jnp.sum(x, axis=-1, keepdims=False))
```

Our EmbGRU models expects as input a Tensor with dimensions (batch, sequencelenght) denoted with `bs`.
The output should be (batch, sequencelenght, dimension) denoted with `bsd`.
`@assert_shape` will check for us if this is correct.

So this works:


In [16]:
model = EmbGRU(vocab_size=1000, d_feature=128, d_out=2)
model.init_weights_and_state(signature(X))
summary(model, X)

layer                   input                dtype     output               dtype 
(0) Embedding_1000_128  (32, 100)          ( int64 ) | (32, 100, 128)     (float32)
(1) GRU_128             (32, 100, 128)     (float32) | (32, 100, 128)     (float32)
(2) Dense_2             (32, 100, 128)     (float32) | (32, 100, 2)       (float32)


ShapeDtype{shape:(32, 100, 2), dtype:float32}

But this fails! Can you see why?

In [9]:
X = np.random.randint(0, 1000, size=(32, 10, 3))
model = EmbGRU(vocab_size=1000, d_feature=128)
try:
    model.init_weights_and_state(X)
except Exception as e:
    print(e)

Exception passing through layer  (in _forward_abstract):
  layer created in file [...]/T/ipykernel_95071/1363168418.py, line 2
  layer input shapes: [[[691 843 671]
  [  3 738 178]
  [392 605 432]
  [346 714 483]
  [556 664 928]
  [366  49 579]
  [493 563 993]
  [736 418 599]
  [ 81 100  50]
  [488 421 448]]

 [[523 889 227]
  [528 514  46]
  [495 214 271]
  [457 234 916]
  [389 156 881]
  [464 524 231]
  [857 703  71]
  [  4 115 815]
  [830 340 192]
  [774 585 435]]

 [[784 933 481]
  [182 682 769]
  [530 392 896]
  [448 255 285]
  [907 120 907]
  [146 613 319]
  [457 378  41]
  [332 144 795]
  [ 16  71 878]
  [503 982 823]]

 [[215 908 415]
  [692  47 110]
  [250 769 275]
  [360  16 364]
  [304 384  74]
  [523 214 958]
  [998 849 322]
  [116 126  82]
  [ 55 739  66]
  [135 313 447]]

 [[607 343 187]
  [263 368 619]
  [382 152 872]
  [553 871 644]
  [699 366 249]
  [ 88 296 230]
  [178 401 501]
  [938 616  46]
  [382 937 444]
  [458 673 998]]

 [[907 180 546]
  [559 557 995]
  [293 22

# 4 Implementing your own layers
If we would like to implement our own layers, we could do that too.

Again, from the [source code of trax](https://github.com/google/trax/blob/master/trax/layers/base.py#L747):

```python
def Fn(name, f, n_out=1):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function `f`.

  `f` can take and return any number of arguments, and takes only positional
  arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
  The following, for example, would create a layer that takes two inputs and
  returns two outputs -- element-wise sums and maxima:

      `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`

  The layer's number of inputs (`n_in`) is automatically set to number of
  positional arguments in `f`, but you must explicitly set the number of
  outputs (`n_out`) whenever it's not the default value 1.

  Args:
    name: Class-like name for the resulting layer; for use in debugging.
    f: Pure function from input tensors to output tensors, where each input
        tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
        Output tensors must be packaged as specified in the `Layer` class
        docstring.
    n_out: Number of outputs promised by the layer; default value 1.

  Returns:
    Layer executing the function `f`.
  """
  ```

  Let's implement our own Hadamard product.
  You might remember from the RNN lesson, that we used this for the gates:

$$h_t = \Gamma \otimes h_{t-1}$$

where $\Gamma$ is the gate (which has the same size as $h$ and values between 0 and 1), $h$ is the hidden state, and $\otimes$ the Hadamard product. For example:

$$
\begin{bmatrix}
0.9 & 0.04 \\
0.25 & -0.48
\end{bmatrix}
=
\begin{bmatrix}
0.9 & 0.01 \\
0.5 & 0.2
\end{bmatrix}
\otimes
\begin{bmatrix}
1.0 & 2.0 \\
0.5 & -2.4
\end{bmatrix}
$$

In [10]:
softmax = tl.Softmax(axis=-1)
X = np.random.rand(32, 20)
gate = softmax(X)

out = jnp.multiply(X, gate)
out.shape

(32, 20)

The `jnp` function is kind of simple. But we want to make sure this is part of the backpropagation of our model.
We make it into a proper layer with `tl.Fn`

In [11]:
def Hadamard():
    def f(x0, x1):
        return jnp.multiply(x0, x1)
    
    return tl.Fn("Hadamard", f, n_out=1)

Now we can use is as part of a model.
For example, let's implement the Gated Linear Unit:
$$GLU(X) = \sigma(W_1 X + b_1) \otimes (W_2 X + b_2)$$

In [12]:
gate = cb.Serial(
    tl.Dense(128),
    tl.Softmax(axis=-1)
)

model = cb.Serial(
    cb.Branch(gate, tl.Dense(128)),
    Hadamard(),
)
model.init_weights_and_state(signature(X))
yhat = model(X)
signature(yhat)

ShapeDtype{shape:(32, 128), dtype:float32}

Or, wrap it inside a function.

In [13]:
@assert_shape('bd->bd')
def GLU(units: int):
    gate = cb.Serial(
    tl.Dense(units),
    tl.Softmax(axis=-1)
    )

    model = cb.Serial(
        cb.Branch(gate, tl.Dense(units)),
        Hadamard(),
    )
    return model