In [1]:
import trax
from trax import layers as tl
from trax.shapes import signature
import jax.numpy as jnp
import numpy as np


# 1 Input, output, Signatures
Let's start with some dummy data:

In [2]:
X = np.random.normal(loc=0, scale=0.1, size=(5, 3))
signature(X)


ShapeDtype{shape:(5, 3), dtype:float64}

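Trax has a `signature` function, which is similar to `.shape` but more flexible: it returns a `ShapeDtype` that records both the shape and the dtype, and it also works on a tuple or list of arrays.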
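To make "more flexible" concrete, here is a toy re-creation of the idea in NumPy. `ToyShapeDtype` and `toy_signature` are hypothetical names, not trax's implementation (the real function lives in `trax.shapes` and returns `ShapeDtype` objects). The sketch only shows that a signature keeps the dtype as well as the shape, and that a list of arrays gets one entry per array.

```python
import numpy as np
from typing import NamedTuple, Tuple, Union


class ToyShapeDtype(NamedTuple):
    """Hypothetical stand-in for trax's ShapeDtype: only shape and dtype."""
    shape: Tuple[int, ...]
    dtype: np.dtype


def toy_signature(obj) -> Union[ToyShapeDtype, tuple]:
    """Describe an array, or a list/tuple of arrays, by shape and dtype."""
    if isinstance(obj, (list, tuple)):
        # Recurse, so several inputs get one description each.
        return tuple(toy_signature(x) for x in obj)
    arr = np.asarray(obj)
    return ToyShapeDtype(arr.shape, arr.dtype)


X = np.random.normal(loc=0, scale=0.1, size=(5, 3))
print(X.shape)                # shape only
print(toy_signature(X))       # shape and dtype
print(toy_signature([X, X]))  # one entry per array
```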

In [3]:
def print_info(model, yhat):
    print(f"input: {model.n_in}")
    print(f"output: {model.n_out}")
    print(f"Signature: {signature(yhat)}")


We can create a trax layer, for example a `Relu` layer, and apply it to `X` like this:

In [4]:
relu = tl.Relu()
yhat = relu(X)
X, yhat




(array([[ 0.0977031 ,  0.1033898 ,  0.00127924],
        [-0.15047795, -0.041198  ,  0.02144017],
        [-0.01992723, -0.06893442,  0.03013939],
        [-0.08105862, -0.12708502,  0.04827876],
        [-0.07423597, -0.054421  ,  0.19920343]]),
 DeviceArray([[0.0977031 , 0.10338981, 0.00127924],
              [0.        , 0.        , 0.02144017],
              [0.        , 0.        , 0.03013939],
              [0.        , 0.        , 0.04827876],
              [0.        , 0.        , 0.19920343]], dtype=float32))

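This works as expected: negative entries are set to 0 and positive entries are kept. Note that the output is `float32` even though `X` is `float64`, because JAX uses 32-bit precision by default.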

In [5]:
print_info(relu, yhat)


input: 1
output: 1
Signature: ShapeDtype{shape:(5, 3), dtype:float32}


This layer has one input and one output, and the output shape is exactly the same as the input shape.
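`Relu` computes max(0, x) element-wise, which is why every negative entry of `X` became 0 and why the shape cannot change. A NumPy sketch of the same computation (assumption: plain NumPy, no trax involved):

```python
import numpy as np

X = np.random.normal(loc=0, scale=0.1, size=(5, 3))

# ReLU keeps positive entries and replaces negative ones with 0.
relu_out = np.maximum(X, 0.0)

# Element-wise operation: the output shape equals the input shape.
print(X.shape, relu_out.shape)
```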

In [6]:
concat = tl.Concatenate()
yhat = concat([X, X])
print_info(concat, yhat)


input: 2
output: 1
Signature: ShapeDtype{shape:(5, 6), dtype:float32}


By default, `Concatenate` takes two inputs and merges them into one along the last axis.
We can tell `Concatenate` to take three inputs as well:

In [7]:
concat3 = tl.Concatenate(n_items=3)
yhat = concat3([X, X, X])
print_info(concat3, yhat)


input: 3
output: 1
Signature: ShapeDtype{shape:(5, 9), dtype:float32}


Or to concatenate along another axis:

In [8]:
concat = tl.Concatenate(axis=0)
yhat = concat([X, X])
print_info(concat, yhat)


input: 2
output: 1
Signature: ShapeDtype{shape:(10, 3), dtype:float32}
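The three shapes above are what `np.concatenate` gives for the same inputs. A NumPy sketch, assuming `tl.Concatenate` joins its inputs the same way (`n_items` inputs along `axis`, which defaults to the last axis):

```python
import numpy as np

X = np.random.normal(loc=0, scale=0.1, size=(5, 3))

# Default: two inputs joined along the last axis (axis=-1).
two = np.concatenate([X, X], axis=-1)
# Three inputs, as with n_items=3.
three = np.concatenate([X, X, X], axis=-1)
# Joined along the rows instead, as with axis=0.
rows = np.concatenate([X, X], axis=0)

print(two.shape, three.shape, rows.shape)
```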


# 2 Combinators
The most interesting part of trax is its combinators.
## 2.1 Serial combinator
The `Serial` layer is similar to `nn.Sequential` in torch: it applies its layers one after another, feeding the output of each layer into the next.
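For single-input, single-output layers, the idea behind `Serial` is plain function composition. A toy sketch in NumPy; `toy_serial`, `double` and `relu` are hypothetical helpers, and the real `Serial` additionally manages weights and a stack of several inputs:

```python
from functools import reduce
from typing import Callable

import numpy as np


def toy_serial(*layers: Callable[[np.ndarray], np.ndarray]) -> Callable[[np.ndarray], np.ndarray]:
    """Hypothetical helper: compose one-in, one-out layers from left to right."""
    def model(x: np.ndarray) -> np.ndarray:
        # Feed the output of each layer into the next one.
        return reduce(lambda acc, layer: layer(acc), layers, x)
    return model


def double(x: np.ndarray) -> np.ndarray:
    return 2.0 * x


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


model = toy_serial(double, relu)  # relu(double(x))
print(model(np.array([-1.0, 0.5])))
```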

In [9]:
from trax.layers import combinators as cb


In [10]:
model1 = cb.Serial(
    tl.Dense(128),
    tl.Relu(),
)
model1.init_weights_and_state(signature(X))
yhat = model1(X)
print_info(model1, yhat)


input: 1
output: 1
Signature: ShapeDtype{shape:(5, 128), dtype:float32}


Note that we don't have to specify the size of the input: `tl.Dense(128)` only states the number of output units.
Calling `.init_weights_and_state` with a specific signature infers the input sizes needed to create the weights.
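Roughly, this works because a dense layer's weight matrix needs shape `(input_features, n_units)`, and `input_features` can be read off the last axis of the input signature. A minimal NumPy sketch of that idea; `toy_dense_init` is a hypothetical helper, and the scaled normal initializer is an assumption, not necessarily what trax uses:

```python
import numpy as np


def toy_dense_init(input_shape, n_units, seed=0):
    """Hypothetical helper: create Dense weights once the input shape is known."""
    rng = np.random.default_rng(seed)
    in_features = input_shape[-1]  # inferred from the input, not given by the user
    w = rng.normal(scale=1.0 / np.sqrt(in_features), size=(in_features, n_units))
    b = np.zeros(n_units)
    return w, b


X = np.random.normal(loc=0, scale=0.1, size=(5, 3))
w, b = toy_dense_init(X.shape, n_units=128)
y = np.maximum(X @ w + b, 0.0)  # Dense(128) followed by Relu
print(w.shape, y.shape)
```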

Expanding the model is as simple as adding layers.

In [11]:
model2 = cb.Serial(
    tl.Dense(64),
    tl.Relu(),
    tl.Dense(32),
    tl.Relu(),
)

model2.init_weights_and_state(signature(X))
yhat = model2(X)
print_info(model2, yhat)


input: 1
output: 1
Signature: ShapeDtype{shape:(5, 32), dtype:float32}


## 2.2 Branch combinator
With `torch`, we have seen skip connections. In the `forward` method of a `torch` module, one could write them like this:

```python
...
def forward(self, x):
    # torch implementation
    skip = x
    x = self.neuralnetwork(x)
    out = skip + x
    return out
```

We have also seen parallel processing, e.g. in the GoogLeNet (Inception) architecture, roughly like this:

```python
...
def forward(self, x):
    # torch implementation
    x1 = self.conv1(x)
    x2 = self.conv2(x)
    out = torch.cat([x1, x2], dim=1)  # join along the channel dimension
    return out
```

With `trax`, we can use `Branch` to build parallel branches.
Trax passes data between layers on a stack. `Branch` gives each of its layers a copy of the inputs, and each layer takes from the top of the stack as many inputs as it needs.

For example, suppose one has three layers:

- F: 1 input, 1 output
- G: 3 inputs, 1 output
- H: 2 inputs, 2 outputs (h1, h2)

`Branch(F, G)` takes three inputs and gives 2 outputs.
Every layer takes from the stack only what it needs, so the branch needs as many inputs as its hungriest layer: three.
So with input (a, b, c), we get F(a) and G(a, b, c).

`Branch(F, G, H)` takes 3 inputs and gives 4 outputs:

- inputs: a, b, c
- outputs: F(a), G(a, b, c), h1, h2, where h1, h2 = H(a, b)
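To make these stack semantics concrete, here is a toy model in plain Python. `ToyLayer` and `toy_branch` are hypothetical names, not part of trax; the sketch only tracks how many items each layer reads and writes, and ignores weights and everything else a real trax layer does. The layers return labelled strings so the data flow is visible.

```python
from typing import Callable, List, NamedTuple, Tuple


class ToyLayer(NamedTuple):
    """Hypothetical layer: a function plus how many stack items it reads and writes."""
    fn: Callable[..., Tuple]
    n_in: int
    n_out: int


def toy_branch(*layers: ToyLayer) -> ToyLayer:
    """Sketch of Branch semantics: every layer sees a copy of the same inputs."""
    n_in = max(layer.n_in for layer in layers)  # as many as the hungriest layer
    n_out = sum(layer.n_out for layer in layers)

    def fn(*inputs):
        outputs: List = []
        for layer in layers:
            # Each layer reads its n_in items from the top of its copy of the stack.
            outputs.extend(layer.fn(*inputs[: layer.n_in]))
        return tuple(outputs)

    return ToyLayer(fn, n_in, n_out)


# The F, G and H layers from the text.
F = ToyLayer(lambda a: (f"F({a})",), n_in=1, n_out=1)
G = ToyLayer(lambda a, b, c: (f"G({a},{b},{c})",), n_in=3, n_out=1)
H = ToyLayer(lambda a, b: (f"h1({a},{b})", f"h2({a},{b})"), n_in=2, n_out=2)

fg = toy_branch(F, G)
fgh = toy_branch(F, G, H)
print(fg.n_in, fg.n_out, fg.fn("a", "b", "c"))
print(fgh.n_in, fgh.n_out, fgh.fn("a", "b", "c"))
```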

### Example of Branch
Above, we created two neural networks, `model1` and `model2`.
Let's say we want to take one input and pass it through both models in parallel.

<img src="../../reports/figures/parallel.png"/>

We need a model that:
- takes 1 input `x`;
- processes this input as `model1(x)` and `model2(x)`;
- `model1` outputs `(batch, 128)`, while `model2` outputs `(batch, 32)`;
- `model1` and `model2` both have one output, so the combined output is `m1, m2`.


In [12]:
model = cb.Serial(cb.Branch(model1, model2))
model.init_weights_and_state(signature(X))
yhat = model(X)
print_info(model, yhat)


input: 1
output: 2
Signature: (ShapeDtype{shape:(5, 128), dtype:float32}, ShapeDtype{shape:(5, 32), dtype:float32})


Finally, we merge those two outputs with a layer such as `Concatenate`, which takes two inputs and returns a single array with 128 + 32 = 160 features.

In [14]:
model = cb.Serial(cb.Branch(model1, model2), cb.Concatenate())
model.init_weights_and_state(signature(X))
yhat = model(X)
print_info(model, yhat)


input: 1
output: 1
Signature: ShapeDtype{shape:(5, 160), dtype:float32}


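If we pass an empty list as one of the branches, as in `cb.Branch([], model)`, that branch is the identity: one copy of `x` is passed through without being processed.
This is useful for skip connections, as in residual architectures. (Trax also provides `tl.Residual`, which wraps this pattern.)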
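In NumPy terms, `Branch([], dnn)` followed by `Add` computes y = x + dnn(x). A minimal sketch with hypothetical random weights (biases omitted for brevity); the last layer outputs 128 features so the two stack items have the same shape and can be added:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((32, 10, 128))

# Hypothetical weights for a two-layer Dense/Relu block (no biases).
w1 = rng.normal(scale=0.05, size=(128, 128))
w2 = rng.normal(scale=0.05, size=(128, 128))


def dnn(x):
    h = np.maximum(x @ w1, 0.0)
    return np.maximum(h @ w2, 0.0)


# Branch([], dnn) leaves two items on the stack: x itself and dnn(x).
identity_out, dnn_out = X, dnn(X)
# Add sums them: the residual connection y = x + dnn(x).
y = identity_out + dnn_out
print(y.shape)
```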

In [13]:
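X = np.random.rand(32, 10, 128)

# The last Dense layer outputs 128 features, matching the last axis of X,
# so that the input and the output of dnn can be added together.
dnn = cb.Serial(
    tl.Dense(128),
    tl.Relu(),
    tl.Dense(128),
    tl.Relu(),
)

# Branch([], dnn) puts x and dnn(x) on the stack; Add sums them.
residual = cb.Serial(cb.Branch([], dnn), cb.Add())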

residual.init_weights_and_state(signature(X))
yhat = residual(X)
print_info(residual, yhat)


input: 1
output: 1
Signature: ShapeDtype{shape:(32, 10, 128), dtype:float32}


## 2.3 Parallel combinator
First, let's look at the embedding layer, which we have already seen in `torch`:

In [86]:
emb = tl.Embedding(vocab_size=1000, d_feature=16)
X = np.random.randint(0, 1000, size=(32, 10))
emb.init_weights_and_state(signature(X))
X_ = emb(X)
signature(X_)


ShapeDtype{shape:(32, 10, 16), dtype:float32}

This is very similar to `torch`. From the torch documentation:

> torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None)

In trax, `num_embeddings` and `embedding_dim` simply have different names: `vocab_size` and `d_feature`.
Not much new here.
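Conceptually, an embedding layer is a trainable table with one row per token id, and applying it is integer indexing into that table. A NumPy sketch, assuming a randomly initialized table (trax uses its own initializer and learns the table during training):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, d_feature = 1000, 16

# One row of d_feature numbers per token id.
table = rng.normal(size=(vocab_size, d_feature))

X = rng.integers(0, vocab_size, size=(32, 10))
# Looking up every id appends a feature axis: (32, 10) -> (32, 10, 16).
X_ = table[X]
print(X_.shape)
```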

### Parallel embeddings use case
Now consider a medical example where we need three embeddings.
Let there be three categorical inputs:
1. 20 different types of medication
2. 1000 different medical diagnoses
3. 128 different locations where patients are treated.

We would need three embedding layers, each with a different `vocab_size`. With `torch`, implementing this would be a bit more complex, especially if you want your model to be flexible enough to be able to use *any* number of embedding layers that you specify at the start of the model with a parameter.

One approach in `torch` would be to create a `ModuleDict`, where you can collect multiple layers with a name, keep track of every name-layer pair, and call the right layer when needed.

To let you appreciate the simplicity of `trax`, here is part of the implementation of a multi-embedding layer in pytorch-forecasting. For simplicity, I removed parts of the code; they are marked with dots `...`. The only thing I want you to take away from this example is that it is fairly complex and takes a lot of code.

```python
class MultiEmbedding(nn.Module):
    ...
    def init_embeddings(self):
        self.embeddings = nn.ModuleDict()
        for name in self.embedding_sizes.keys():
            embedding_size = self.embedding_sizes[name][1]
            ...
            # convert to list to become mutable
            self.embedding_sizes[name] = list(self.embedding_sizes[name])
            self.embedding_sizes[name][1] = embedding_size
            ...
                self.embeddings[name] = nn.Embedding(
                    self.embedding_sizes[name][0],
                    embedding_size,
                    padding_idx=padding_idx,
                )
    ...
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        input_vectors = {}
        for name, emb in self.embeddings.items():
            if name in self.categorical_groups:
                input_vectors[name] = emb(
                    x[
                        ...,
                        [self.x_categoricals.index(cat_name) for cat_name in self.categorical_groups[name]],
                    ]
                )
            else:
                input_vectors[name] = emb(x[..., self.x_categoricals.index(name)])
```

You can look up the full implementation on [GitHub](https://github.com/jdb78/pytorch-forecasting/blob/master/pytorch_forecasting/models/nn/embeddings.py#L32), which is 163 lines long!


To do this in `trax`, we will use the `Parallel` layer. From the [source code](https://github.com/google/trax/blob/master/trax/layers/combinators.py#L138) of `trax`:
> For example, suppose one has three layers:
>    - F: 1 input, 1 output
>    - G: 3 inputs, 1 output
>    - H: 2 inputs, 2 outputs (h1, h2)
>
>  Then Parallel(F, G, H) will take 6 inputs and give 4 outputs:
>
>    - inputs: a, b, c, d, e, f
>    - outputs: F(a), G(b, c, d), h1, h2     where h1, h2 = H(e, f)

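This is similar to `Branch`, except that `Parallel` does not copy its inputs: each layer consumes its own inputs from the stack. As a result, `Parallel(F, G, H)` needs the sum of its layers' inputs (6), whereas `Branch(F, G, H)` needs only the maximum (3).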
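The same toy stack model can express `Parallel`: instead of handing every layer a copy of the inputs, it hands out consecutive slices. As before, `ToyLayer` and `toy_parallel` are hypothetical names, not trax's implementation, and the sketch ignores weights.

```python
from typing import Callable, List, NamedTuple, Tuple


class ToyLayer(NamedTuple):
    """Hypothetical layer: a function plus how many stack items it reads and writes."""
    fn: Callable[..., Tuple]
    n_in: int
    n_out: int


def toy_parallel(*layers: ToyLayer) -> ToyLayer:
    """Sketch of Parallel semantics: inputs are split between layers, not copied."""
    n_in = sum(layer.n_in for layer in layers)
    n_out = sum(layer.n_out for layer in layers)

    def fn(*inputs):
        outputs: List = []
        start = 0
        for layer in layers:
            # Each layer consumes the next n_in items from the stack.
            outputs.extend(layer.fn(*inputs[start : start + layer.n_in]))
            start += layer.n_in
        return tuple(outputs)

    return ToyLayer(fn, n_in, n_out)


# The F, G and H layers from the quoted docstring.
F = ToyLayer(lambda a: (f"F({a})",), n_in=1, n_out=1)
G = ToyLayer(lambda a, b, c: (f"G({a},{b},{c})",), n_in=3, n_out=1)
H = ToyLayer(lambda a, b: (f"h1({a},{b})", f"h2({a},{b})"), n_in=2, n_out=2)

fgh = toy_parallel(F, G, H)
print(fgh.n_in, fgh.n_out, fgh.fn("a", "b", "c", "d", "e", "f"))
```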

Now, have a look at the `trax` implementation of the multi-embedding:

In [97]:
# first, we set up vocab sizes and some random input
vocab_sizes = [20, 1000, 128]
inputs = [np.random.randint(0, v, size=(32, 10)) for v in vocab_sizes]


Now the full model:

In [98]:
def multiembedding(vocab_sizes):
    # one embedding layer per categorical input, each with its own vocab size
    embeddings = [tl.Embedding(vocab_size=vocab, d_feature=16) for vocab in vocab_sizes]

    # Parallel feeds the i-th input to the i-th embedding layer
    model = cb.Serial(cb.Parallel(*embeddings))
    return model


That's it. Let's test it:

In [99]:
model = multiembedding(vocab_sizes)
model.init_weights_and_state(signature(inputs))
yhat = model(inputs)
print_info(model, yhat)


input: 3
output: 3
Signature: (ShapeDtype{shape:(32, 10, 16), dtype:float32}, ShapeDtype{shape:(32, 10, 16), dtype:float32}, ShapeDtype{shape:(32, 10, 16), dtype:float32})
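Stripped of trax, the model above pairs the i-th lookup table with the i-th input. A NumPy sketch of that behaviour, with hypothetical randomly initialized tables; supporting a fourth categorical input would only mean adding one entry to `vocab_sizes`:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_sizes = [20, 1000, 128]
d_feature = 16

# One hypothetical lookup table per categorical input.
tables = [rng.normal(size=(v, d_feature)) for v in vocab_sizes]
inputs = [rng.integers(0, v, size=(32, 10)) for v in vocab_sizes]

# Parallel semantics: table i only sees input i.
outputs = [table[ids] for table, ids in zip(tables, inputs)]
print([o.shape for o in outputs])
```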


I hope to have convinced you that `trax` makes models simpler and more elegant to write, easier to read, and faster to build.