In [1]:
import numpy as np
from trax import layers as tl
from trax import shapes
from trax import fastmath

INFO:tensorflow:tokens_length=568 inputs_length=512 targets_length=114 noise_density=0.15 mean_noise_span_length=3.0 


In [2]:
!pip list | grep trax

trax                               1.3.4


# ReLU layer

In [3]:
relu=tl.Relu()

In [4]:
relu.name,relu.n_in,relu.n_out

('Relu', 1, 1)

In [5]:
x=np.array([-2,-1,0,1,2])

In [6]:
y=relu(x)



In [7]:
type(y)

jax.interpreters.xla.DeviceArray

In [8]:
print(y)

[0 0 0 1 2]


# Concatenate layer

In [9]:
concat=tl.Concatenate()

In [10]:
concat.name,concat.n_in,concat.n_out

('Concatenate', 2, 1)

In [11]:
x1=np.array([-10,-20,-30])

In [12]:
x2=x1/-10

In [13]:
y=concat([x1,x2])

In [14]:
print(y)

[-10. -20. -30.   1.   2.   3.]


# Configurable layers

In [15]:
concat_3=tl.Concatenate(n_items=3)

In [16]:
x1=fastmath.numpy.arange(0,3)
x2=fastmath.numpy.arange(3,6)
x3=fastmath.numpy.arange(6,9)

In [17]:
y=concat_3([x1,x2,x3])

In [18]:
concat_3.name,concat_3.n_in,concat_3.n_out

('Concatenate', 3, 1)

In [19]:
print(y)

[0 1 2 3 4 5 6 7 8]


In [20]:
help(tl.Concatenate)

Help on class Concatenate in module trax.layers.combinators:

class Concatenate(trax.layers.base.Layer)
 |  Concatenate(n_items=2, axis=-1)
 |  
 |  Concatenates n tensors into a single tensor.
 |  
 |  Method resolution order:
 |      Concatenate
 |      trax.layers.base.Layer
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, n_items=2, axis=-1)
 |      Creates a partially initialized, unconnected layer instance.
 |      
 |      Args:
 |        n_in: Number of inputs expected by this layer.
 |        n_out: Number of outputs promised by this layer.
 |        name: Class-like name for this layer; for use when printing this layer.
 |        sublayers_to_print: Sublayers to display when printing out this layer;
 |          By default (when None) we display all sublayers.
 |  
 |  forward(self, xs)
 |      Computes this layer's output as part of a forward pass through the model.
 |      
 |      Authors of new layer subclasses should override this method to d

# Layers with weights

In [21]:
norm=tl.LayerNorm()

In [22]:
x=np.array([0,1,2,3],dtype='float')

In [23]:
norm.init(shapes.signature(x))



((DeviceArray([1., 1., 1., 1.], dtype=float32),
  DeviceArray([0., 0., 0., 0.], dtype=float32)),
 ())

In [24]:
type(shapes.signature(x))

trax.shapes.ShapeDtype

In [25]:
norm.name,norm.n_in,norm.n_out

('LayerNorm', 1, 1)

In [26]:
norm.weights[0]

DeviceArray([1., 1., 1., 1.], dtype=float32)

In [27]:
norm.weights[1]

DeviceArray([0., 0., 0., 0.], dtype=float32)

# Custom Layers

In [28]:
# Fn builds a layer with no weights from a plain function; n_in is taken from the function's positional arguments.
help(tl.Fn)

Help on function Fn in module trax.layers.base:

Fn(name, f, n_out=1)
    Returns a layer with no weights that applies the function `f`.
    
    `f` can take and return any number of arguments, and takes only positional
    arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
    The following, for example, would create a layer that takes two inputs and
    returns two outputs -- element-wise sums and maxima:
    
        `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
    
    The layer's number of inputs (`n_in`) is automatically set to number of
    positional arguments in `f`, but you must explicitly set the number of
    outputs (`n_out`) whenever it's not the default value 1.
    
    Args:
      name: Class-like name for the resulting layer; for use in debugging.
      f: Pure function from input tensors to output tensors, where each input
          tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
          

In [29]:
def TimesTwo():
    """Build a layer with no weights, named 'TimesTwo!', that doubles its single input.

    The wrapped function takes exactly one positional argument, so Fn infers n_in == 1;
    n_out keeps Fn's default of 1.
    """
    def double(x):
        return x * 2

    return tl.Fn('TimesTwo!', double)

In [30]:
times_two=TimesTwo()

In [31]:
x=tl.fastmath.numpy.arange(10)

In [32]:
times_two.n_in,times_two.n_out,times_two.name

(1, 1, 'TimesTwo!')

In [33]:
y=times_two(x)

In [34]:
print(y)

[ 0  2  4  6  8 10 12 14 16 18]


# Combinators

In [35]:
# Serial composes the layers in order: LayerNorm, then Relu, then the doubling layer.
serial = tl.Serial(
    tl.LayerNorm(),
    tl.Relu(),
    times_two,
)

In [36]:
x=np.array(list(range(-2,3)))

In [37]:
print(x)

[-2 -1  0  1  2]


In [38]:
serial.init(shapes.signature(x))



(((DeviceArray([1, 1, 1, 1, 1], dtype=int32),
   DeviceArray([0, 0, 0, 0, 0], dtype=int32)),
  (),
  ()),
 ((), (), ()))

In [39]:
y=serial(x)

In [40]:
print(y)

[0.        0.        0.        1.4142132 2.8284264]
