
# Introduction to the Trax Library, Maintained by the Google Brain Team

In this code walk-through we'll cover the fundamentals of the Trax framework and learn about its basic building blocks including classes, sub-classes and data generators.



## Background

### Why Use Trax Instead of TensorFlow, Keras, Caffe or PyTorch? (And Why We Used It for Our IR Mini-Project)

TensorFlow and PyTorch are both extensive frameworks that can do almost anything in deep learning. They offer a lot of flexibility, but that often means verbosity of syntax and extra time to code.

Trax is much more concise. It runs on a TensorFlow backend but lets you train models with one-line commands. Trax also runs end to end: you can get data, build a model and train it with a few terse statements. This means you can focus on learning instead of spending hours on the idiosyncrasies of a big framework's implementation.

### Why Not Keras, Then?

Keras has been part of TensorFlow itself since version 2.0. Trax, in contrast, is well suited to implementing recent state-of-the-art architectures such as Transformers, Reformers and BERT, because it is actively maintained by the Google Brain team for advanced deep learning tasks. It also runs on CPUs, GPUs and TPUs with comparatively few code changes.

### How to Code Neural Networks in Trax?
Building models in Trax relies on two key concepts: **layers** and **combinators**.

Trax layers are simple objects that process data and perform computations. They can be chained together into composite layers using Trax combinators, allowing you to build layers and models of any complexity.

### Trax, JAX, TensorFlow and Tensor2Tensor

You already know that Trax uses TensorFlow as a backend, but it also uses the JAX library to speed up computation. You can think of JAX as an enhanced and optimized version of numpy.

**Sometimes you'll find us writing `import trax.fastmath.numpy as np`. In that case, calling `np` really calls Trax's version of numpy, which is compatible with JAX.**

As a result, where you used to see the type `numpy.ndarray`, you will now find the type `jax.interpreters.xla.DeviceArray`.

Tensor2Tensor is another name you may have heard. It started as an end-to-end solution, much like Trax, but it grew unwieldy and complicated. So one can view Trax as its new, improved version that is faster and simpler to operate.

### More Resources & References:

- Trax source code can be found on Github: [Trax](https://github.com/google/trax)
- Read more about JAX library over here: [JAX](https://jax.readthedocs.io/en/latest/index.html)


## Installing Trax

Trax depends on JAX and related libraries, which are not yet supported on [Windows](https://github.com/google/jax/blob/1bc5896ee4eab5d7bb4ec6f161d8b2abb30557be/README.md#installation).

Officially maintained documentation: [trax-ml](https://trax-ml.readthedocs.io/en/latest/).

In [None]:
# !pip install trax==1.3.1 


## Imports

In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np  # regular ol' numpy

from trax import layers as tl  # core building block
from trax import shapes  # data signatures: dimensionality and type
from trax import fastmath  # uses jax, offers numpy on steroids

In [None]:
# Trax version 1.3.1 or better 
!pip list | grep trax

trax                          1.3.1                


## Layers
Layers are the core building blocks in Trax, represented by its base classes.

They take inputs, compute functions/custom calculations and return outputs.

We can also inspect layer properties. Let us do a code walk-through.


### Relu Layer
First we'll see how to build a Relu activation function as a layer. A layer like this is one of the simplest types. Notice that it has no weights and needs no `init` call before use, so it works just like a math function.

**Note: Activation functions are also layers in Trax, which might look odd if you have been using other frameworks for a longer time. Traditionally layers in deep learning refer to layers of neural networks. In Trax, on the other hand, everything can be considered and implemented as a layer.**

In [None]:
# Layers
# Create a relu trax layer
relu = tl.Relu()

# Inspect properties
print("-- Properties --")
print("name :", relu.name)
print("expected inputs :", relu.n_in)
print("promised outputs :", relu.n_out, "\n")

# Inputs
x = np.array([-2, -1, 0, 1, 2])
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = relu(x)
print("-- Outputs --")
print("y :", y)

-- Properties --
name : Relu
expected inputs : 1
promised outputs : 1 

-- Inputs --
x : [-2 -1  0  1  2] 

-- Outputs --
y : [0 0 0 1 2]


### Concatenate Layer
Now we see how to build a layer that takes 2 inputs. Notice the change in the expected inputs property from 1 to 2. 

In [None]:
# Create a concatenate trax layer
concat = tl.Concatenate()
print("-- Properties --")
print("name :", concat.name)
print("expected inputs :", concat.n_in)
print("promised outputs :", concat.n_out, "\n")

# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")

# Outputs
y = concat([x1, x2])
print("-- Outputs --")
print("y :", y)

-- Properties --
name : Concatenate
expected inputs : 2
promised outputs : 1 

-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.] 

-- Outputs --
y : [-10. -20. -30.   1.   2.   3.]


## Layers are Configurable
We can change the default settings of layers. For example, we can change the expected inputs for a concatenate layer from 2 to 3 using the optional parameter `n_items`. This is very useful in practice.

In [None]:
# Configure a concatenate layer
concat_3 = tl.Concatenate(n_items=3)  # configure the layer's expected inputs
print("-- Properties --")
print("name :", concat_3.name)
print("expected inputs :", concat_3.n_in)
print("promised outputs :", concat_3.n_out, "\n")

# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
x3 = x2 * 0.99
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2)
print("x3 :", x3, "\n")

# Outputs
y = concat_3([x1, x2, x3])
print("-- Outputs --")
print("y :", y)

-- Properties --
name : Concatenate
expected inputs : 3
promised outputs : 1 

-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.]
x3 : [0.99 1.98 2.97] 

-- Outputs --
y : [-10.   -20.   -30.     1.     2.     3.     0.99   1.98   2.97]


## Layers can have Weights
Some layer types include mutable weights and biases that are used in computation and training. Layers of this type require initialization before use.

For example, the `LayerNorm` layer normalizes its input and then scales and shifts the result using weights and biases. During initialization you pass the shape and data type of the inputs, so the layer can initialize compatible arrays of weights and biases.

In [None]:
# Layer initialization
norm = tl.LayerNorm()
# You first must know what the input data will look like
x = np.array([0, 1, 2, 3], dtype="float")

# Use the input data signature to get shape and type for initializing weights and biases
norm.init(shapes.signature(x)) # shapes.signature describes the array as a Trax ShapeDtype (shape plus dtype)

print("Normal shape:",x.shape, "Data Type:",type(x.shape))
print("Shapes Trax:",shapes.signature(x),"Data Type:",type(shapes.signature(x)))

# Inspect properties
print("-- Properties --")
print("name :", norm.name)
print("expected inputs :", norm.n_in)
print("promised outputs :", norm.n_out)
# Weights and biases
print("weights :", norm.weights[0])
print("biases :", norm.weights[1], "\n")

# Inputs
print("-- Inputs --")
print("x :", x)

# Outputs
y = norm(x)
print("-- Outputs --")
print("y :", y)

Normal shape: (4,) Data Type: <class 'tuple'>
Shapes Trax: ShapeDtype{shape:(4,), dtype:float64} Data Type: <class 'trax.shapes.ShapeDtype'>
-- Properties --
name : LayerNorm
expected inputs : 1
promised outputs : 1
weights : [1. 1. 1. 1.]
biases : [0. 0. 0. 0.] 

-- Inputs --
x : [0. 1. 2. 3.]
-- Outputs --
y : [-1.3416404  -0.44721344  0.44721344  1.3416404 ]
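
The output values above can be reproduced by hand. The following is a minimal plain-Python sketch of the computation, assuming the layer subtracts the mean, divides by the square root of the variance plus a small `epsilon`, and then applies the weights and biases element-wise. The helper name `layer_norm` and the value `epsilon=1e-6` are assumptions made for illustration, not Trax's own code.

```python
import math

def layer_norm(x, scale, bias, epsilon=1e-6):
    """Normalize x to zero mean and unit variance, then scale and shift it."""
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)
    denom = math.sqrt(variance + epsilon)  # epsilon is an assumed stabilizer
    return [s * (v - mean) / denom + b for v, s, b in zip(x, scale, bias)]

x = [0.0, 1.0, 2.0, 3.0]
# Weights of ones and biases of zeros, as printed by the cell above
y = layer_norm(x, scale=[1.0] * 4, bias=[0.0] * 4)
print([round(v, 4) for v in y])  # [-1.3416, -0.4472, 0.4472, 1.3416]
```

With weights of one and biases of zero, the layer reduces to plain standardization, which is why the result is symmetric around zero.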


## Custom Layers & Functions
This is where things start getting much more interesting than in Keras and PyTorch! Here we can also create our own custom layers by wrapping custom functions with `tl.Fn`. Let us see how.

In [None]:
help(tl.Fn)

Help on function Fn in module trax.layers.base:

Fn(name, f, n_out=1)
    Returns a layer with no weights that applies the function `f`.
    
    `f` can take and return any number of arguments, and takes only positional
    arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
    The following, for example, would create a layer that takes two inputs and
    returns two outputs -- element-wise sums and maxima:
    
        `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
    
    The layer's number of inputs (`n_in`) is automatically set to number of
    positional arguments in `f`, but you must explicitly set the number of
    outputs (`n_out`) whenever it's not the default value 1.
    
    Args:
      name: Class-like name for the resulting layer; for use in debugging.
      f: Pure function from input tensors to output tensors, where each input
          tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
          

In [None]:
# Define a custom layer
# In this example we will create a simple layer to just calculate the input times 2

def TimesTwo():
    layer_name = "TimesTwo" # Trax uses this to identify the layer

    # Custom function for the custom layer
    def func(x):
        return x * 2

    return tl.Fn(layer_name, func)


# Testing
times_two = TimesTwo()

# Inspect properties
print("-- Properties --")
print("name :", times_two.name)
print("expected inputs :", times_two.n_in)
print("promised outputs :", times_two.n_out, "\n")

# Inputs
x = np.array([1, 2, 3])
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = times_two(x)
print("-- Outputs --")
print("y :", y)

-- Properties --
name : TimesTwo
expected inputs : 1
promised outputs : 1 

-- Inputs --
x : [1 2 3] 

-- Outputs --
y : [2 4 6]
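
The help text for `tl.Fn` says that `n_in` is set automatically from the number of positional arguments of `f`. The sketch below shows one way such a count can be obtained with the standard `inspect` module. The helper `count_positional_args` is hypothetical and is not claimed to be how Trax does it internally.

```python
import inspect

def count_positional_args(f):
    """Count the positional parameters that a function declares."""
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    params = inspect.signature(f).parameters.values()
    return sum(1 for p in params if p.kind in positional_kinds)

def times_two_func(x):
    return x * 2

def sum_and_max(x0, x1):
    return x0 + x1, max(x0, x1)

print(count_positional_args(times_two_func))  # 1
print(count_positional_args(sum_and_max))     # 2
```

The number of outputs cannot be read from a signature in the same way, which is consistent with the help text asking for `n_out` to be given explicitly.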


## Combinators
We can also combine separate simple layers to build more complex layers. As we highlighted above, Trax provides a set of objects named combinator layers to make this happen. Combinators are themselves layers, so they can be nested and combined just like any other layer.



### Serial Combinator
This is the most common and easiest to use combinator layer. For example, we could build a simple neural network by combining layers into a single layer using the `Serial` combinator. This new layer then acts just like a single layer, so you can inspect its inputs, outputs and weights, or even combine it into another layer! Combinators can then be used as trainable neural network models.


In [None]:
# help(tl.Serial)
# help(tl.Parallel)

In [None]:
# Serial combinator
serial = tl.Serial(
    tl.LayerNorm(),         # normalize input values
    tl.Relu(),              # convert negative values to zero using Relu activation function
    times_two,              # the custom layer we created above, multiplies the input received from the previous layer by 2

#     tl.Dense(n_units=2),  # we can add more layers
#     tl.Dense(n_units=1),  # Binary classification
#     tl.LogSoftmax()       # Yes, LogSoftmax is also a layer
)

# Initialization
x = np.array([-2, -1, 0, 1, 2]) #input
serial.init(shapes.signature(x)) #initialising serial instance

print("-- Serial Model --")
print(serial,"\n")
print("-- Properties --")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
print("weights & biases:", serial.weights, "\n")

# Inputs
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = serial(x)
print("-- Outputs --")
print("y :", y)

-- Serial Model --
Serial[
  LayerNorm
  Relu
  TimesTwo
] 

-- Properties --
name : Serial
sublayers : [LayerNorm, Relu, TimesTwo]
expected inputs : 1
promised outputs : 1
weights & biases: [(DeviceArray([1, 1, 1, 1, 1], dtype=int32), DeviceArray([0, 0, 0, 0, 0], dtype=int32)), (), ()] 

-- Inputs --
x : [-2 -1  0  1  2] 

-- Outputs --
y : [0.        0.        0.        1.4142132 2.8284264]
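
Conceptually, a serial combinator is function composition: the output of each layer becomes the input of the next. The sketch below reproduces the result above in plain Python, under the assumption that the normalization weights stay at one and the biases at zero (so they are omitted). The names `serial`, `layer_norm`, `relu` and `times_two` are illustrative stand-ins, not the Trax objects.

```python
import math
from functools import reduce

def layer_norm(x, epsilon=1e-6):
    # Standardize to zero mean and unit variance (epsilon is an assumed value)
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(variance + epsilon) for v in x]

def relu(x):
    return [max(v, 0.0) for v in x]

def times_two(x):
    return [v * 2 for v in x]

def serial(*layers):
    """Return a function that applies each layer in order."""
    def apply(x):
        return reduce(lambda value, layer: layer(value), layers, x)
    return apply

model = serial(layer_norm, relu, times_two)
y = model([-2, -1, 0, 1, 2])
print([round(v, 4) for v in y])  # [0.0, 0.0, 0.0, 1.4142, 2.8284]
```

Because `serial` returns an ordinary function, its result can itself be passed to another `serial` call, which mirrors how a combinator can be nested inside another combinator.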


## JAX
When working with `trax`, we should be careful about which numpy we are using: regular numpy or Trax's JAX-compatible numpy. Both are commonly imported under the alias `np`, so we will point out which one is in use whenever it matters.

**Note: Some operations that work in regular numpy are not yet possible in `fastmath.numpy`, so we will sometimes switch between the two to get our work done.**

In [None]:
# Numpy vs fastmath.numpy have different data types
# Regular old numpy

x_numpy = np.array([1, 2, 3])
print("good old numpy : ", type(x_numpy), "\n")

# Fastmath and jax numpy
x_jax = fastmath.numpy.array([1, 2, 3])
print("jax trax numpy : ", type(x_jax))

good old numpy :  <class 'numpy.ndarray'> 

jax trax numpy :  <class 'jax.interpreters.xla._DeviceArray'>



# Data generators in Trax



In Python, a generator is a function that behaves like an iterator: each call to `next` returns the next item. Here is a [link](https://wiki.python.org/moin/Generators) to delve deeper into Python generators. In many AI applications it is advantageous to have a data generator that handles loading and transforming data. Every deep learning framework needs some sort of data loader mechanism.

We will now implement a custom data generator. In the following code walk-through, we use a set of samples `a` to derive a new set of samples `b` with more elements than the original set. This demonstrates how a data generator can be manipulated.

**Note: Pay close attention to how the list `lines_index` and the variable `index` are used to traverse the original list.**

In [1]:
import random 
import numpy as np

# Example of traversing a list of indexes to create a circular list
a = [1, 2, 3, 4]
b = [0] * 10

a_size = len(a)
b_size = len(b)

lines_index = [*range(a_size)] # same as [i for i in range(a_size)]; the * unpacks the range directly into the list
index = 0                      # similar to index in data_generator below
for i in range(b_size):        # `b` is longer than `a` forcing a wrap
    
    # We wrap by resetting index to 0 so the sequences circle back at the end to point to the first index
    if index >= a_size:
        index = 0
    
    b[i] = a[lines_index[index]]     # `lines_index[index]` points to an index of a. Store the value in b
    index += 1
    
print(b)
# As the output below shows, after traversing the list from 1 to 4 it wraps back to 1 and continues until b is full

[1, 2, 3, 4, 1, 2, 3, 4, 1, 2]
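
For comparison, the same wrap-around can be written more compactly with the standard library. This is a sketch of two equivalent alternatives, not a replacement for the explicit `index` version, which the walk-through keeps because the later cells need a hook at the wrap point to re-shuffle.

```python
from itertools import cycle, islice

a = [1, 2, 3, 4]
b_size = 10

# cycle() repeats `a` endlessly; islice() stops after b_size items
b_cycle = list(islice(cycle(a), b_size))

# Equivalent wrap using the modulo operator on a running position
b_modulo = [a[i % len(a)] for i in range(b_size)]

print(b_cycle)              # [1, 2, 3, 4, 1, 2, 3, 4, 1, 2]
print(b_cycle == b_modulo)  # True
```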


## Shuffling the data order

In the next code walk-through we do the same as before, but shuffle the order of the elements in the output list. Here our strategy of traversing with `lines_index` and `index` becomes very important, because it lets us simulate a shuffle of the input data without actually reordering it. Simply put, the output is shuffled while the original data keeps its order.

**So why does it matter at all?**

**In deep learning parlance, one pass of the training algorithm over all the training examples is called an epoch. Shuffling the examples for each epoch is known to reduce variance, making models more generalizable and less prone to overfitting.**





In [2]:
# Example of traversing a list of indexes to create a circular list
a = [1, 2, 3, 4]
b = []

a_size = len(a)
b_size = 10

lines_index = [*range(a_size)]
print("Original order of index:", lines_index)

# If we shuffle lines_index we can change the order of our circular list
# without modifying the order of our original data

random.shuffle(lines_index) # Shuffle the order
print("Shuffled order of index:", lines_index)

print("New value order for first batch:", [a[index] for index in lines_index]) # using list comprehension

batch_counter = 1
index = 0                # similar to index in data_generator below

for i in range(b_size):  # b_size exceeds a_size, forcing a wrap
    # We wrap by resetting index to 0
    if index >= a_size:
        index = 0
        batch_counter += 1
        random.shuffle(lines_index) # Re-shuffle the order
        print("\nShuffled Indexes for Batch No.{} :{}".format(batch_counter,lines_index))
        print("Values for Batch No.{} :{}".format(batch_counter,[a[index] for index in lines_index]))
    
    b.append(a[lines_index[index]])     # `lines_index[index]` points to an index of a. Append the value to b
    index += 1
print()    
print("Final value of b:", b)

Original order of index: [0, 1, 2, 3]
Shuffled order of index: [3, 2, 1, 0]
New value order for first batch: [4, 3, 2, 1]

Shuffled Indexes for Batch No.2 :[3, 1, 2, 0]
Values for Batch No.2 :[4, 2, 3, 1]

Shuffled Indexes for Batch No.3 :[1, 3, 2, 0]
Values for Batch No.3 :[2, 4, 3, 1]

Final value of b: [4, 3, 2, 1, 4, 2, 3, 1, 2, 4]
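
Since `random.shuffle` gives a different order on every run, the output above will not repeat exactly. The sketch below is one way to make such an experiment repeatable and to check its key property: every epoch visits each sample exactly once, and the original list is never reordered. The function `shuffled_epochs` and its `seed` parameter are hypothetical names introduced for illustration.

```python
import random

def shuffled_epochs(data, n_epochs, seed=0):
    """Return one freshly shuffled view of `data` per epoch, leaving `data` untouched."""
    rng = random.Random(seed)          # private, seeded generator for repeatable runs
    indexes = list(range(len(data)))
    epochs = []
    for _ in range(n_epochs):
        rng.shuffle(indexes)           # only the index list is reordered
        epochs.append([data[i] for i in indexes])
    return epochs

a = [1, 2, 3, 4]
epochs = shuffled_epochs(a, n_epochs=3)
for number, values in enumerate(epochs, start=1):
    print("epoch", number, ":", values)
print("original data:", a)
```

Using a private `random.Random` instance keeps the seed from affecting other code that relies on the module-level random functions.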


## Final Custom Data Generator Code Walk-Through

Here we implement a data generator function that takes `batch_size, x, y, shuffle`, where `x` could be a large list of samples and `y` is the list of tags associated with those samples. The generator returns a subset of those inputs as a tuple of two lists `(X, Y)`, each of length `batch_size`. If `shuffle=True`, the data is traversed in a new random order in each subsequent epoch (defined above).

**Implementation Details:**

The code has an outer loop that looks like this:
```
while True:  
...  
yield((X,Y))  
```

This loop runs indefinitely, as generators do, pausing each time it yields the next values. Each pass of the loop produces one batch of `batch_size` items.

It has an inner loop that stores, in the temporary lists `(X, Y)`, the data samples to be included in the next batch to be yielded.

There are 3 slightly out-of-the-ordinary features that deserve some thought:

1. The first is the use of a list of predefined size to store the data for each batch. A predefined-size list reduces computation time if the elements are of a fixed size, like numbers. If the elements are of different sizes, it is better to start with an empty list and append one element at a time during the loop. When you know the sizes in advance, go for a list of predefined size.

2. The second is tracking the current location in the incoming lists of samples. A generator's variables keep their values between yields, so we create an `index` variable, initialize it to zero, and increment it by one for each sample included in a batch. However, we do not use `index` to access the positions of the sample list directly. Instead, we use it to select one index from a list of indexes. In this way we can change the order in which we traverse the original list while leaving the original data untouched.

3. The third relates to wrapping. Because `batch_size` and the length of the input lists are not necessarily aligned, gathering a batch of `batch_size` inputs may involve wrapping back to the beginning of the input lists. In our approach it is enough to reset `index` to 0. At that point we can also re-shuffle the list of indexes to produce different batches in each subsequent epoch.

In [3]:
def data_generator(batch_size, data_x, data_y, shuffle=True):
    '''
      Input: 
        batch_size - integer describing the batch size
        data_x - list containing samples
        data_y - list containing labels
        shuffle - Shuffle the data order
      Output:
        a tuple containing 2 elements:
        X - list of dim (batch_size) of samples
        Y - list of dim (batch_size) of labels
    '''
    
    data_lng = len(data_x) # len(data_x) must be equal to len(data_y)
    index_list = [*range(data_lng)] # Create a list with the ordered indexes of sample data
    
    
    # If shuffle is set to true, we traverse the list in a random way
    if shuffle:
        random.shuffle(index_list) # Inplace shuffle of the list
    
    index = 0 # Start with the first element
      
    while True:
        X = [0] * batch_size # Preallocate a list of batch_size elements for the samples
        Y = [0] * batch_size # Preallocate a list of batch_size elements for the labels
        
        for i in range(batch_size):
            
            # Wrap the index each time that we reach the end of the list
            if index >= data_lng:
                index = 0
                # Shuffle the index_list if shuffle is true
                if shuffle:
                    random.shuffle(index_list) # re-shuffle the order (the module is imported as `random`, not `rnd`)
            
            X[i] = data_x[index_list[index]] # Set the corresponding element in X
            Y[i] = data_y[index_list[index]] # Set the corresponding element in Y
           
            index += 1
        
        yield((X, Y))
    

In [4]:
def testing_custom_data_generator():
    x = [1, 2, 3, 4]
    y = [xi ** 2 for xi in x]
    
    generator = data_generator(3, x, y, shuffle=False)

    assert np.allclose(next(generator), ([1, 2, 3], [1, 4, 9])),  "First batch does not match"
    assert np.allclose(next(generator), ([4, 1, 2], [16, 1, 4])), "Second batch does not match"
    assert np.allclose(next(generator), ([3, 4, 1], [9, 16, 1])), "Third batch does not match"
    assert np.allclose(next(generator), ([2, 3, 4], [4, 9, 16])), "Fourth batch does not match"

    print("\033[92mAll tests passed!")

testing_custom_data_generator()

All tests passed!

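Because the generator's `while True` loop never exits, a plain `for` loop over it would never end either. A minimal sketch of one way to take a fixed number of batches uses `itertools.islice`. The generator `cyclic_batches` below is a hypothetical, simplified stand-in for `data_generator` with shuffling left out, so that the block runs on its own.

```python
from itertools import islice

def cyclic_batches(batch_size, data_x, data_y):
    """Yield (X, Y) batches forever, wrapping around the end of the data."""
    index = 0
    while True:
        X, Y = [], []
        for _ in range(batch_size):
            if index >= len(data_x):
                index = 0              # wrap back to the first sample
            X.append(data_x[index])
            Y.append(data_y[index])
            index += 1
        yield X, Y

x = [1, 2, 3, 4]
y = [xi ** 2 for xi in x]

# islice() takes a fixed number of batches from the endless generator
batches = list(islice(cyclic_batches(3, x, y), 4))
for X, Y in batches:
    print(X, Y)
```

The four batches match the ones asserted in the test cell above, with the wrap happening in the middle of the second batch.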

## Summary
Trax is a concise deep learning framework, built on TensorFlow, for end-to-end deep learning. Its key building blocks are layers and combinators. This code walk-through gives a taste of the library our project is based upon. It also sets us up with some key intuitions about custom layers and data generators, which we carry forward into the next few submissions, where we have built an end-to-end Neural Machine Translation model.

## Code References:
- [Trax Layers Introduction](https://trax-ml.readthedocs.io/en/latest/notebooks/layers_intro.html)
- [JAX, AKA NUMPY ON STEROIDS](https://iaml.it/blog/jax-intro-english)
- [JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [You don't know JAX](https://colinraffel.com/blog/you-don-t-know-jax.html)
- [Trax — Deep Learning with Clear Code and Speed](https://github.com/google/trax)