In [11]:
import numpy as np
from trax import layers as tl
from trax import shapes

In [8]:
# Trax version 1.3.1 or better 
!pip list | grep trax

trax                     1.3.7


In [9]:
relu = tl.Relu()

print("--Properties--")
print("Name: ",relu.name)
print(f"# of Input: {relu.n_in}")
print(f"# of Output : {relu.n_out}")


# Inputs
x = np.array([-2, -1, 0, 1, 2])
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = relu(x)
print("-- Outputs --")
print("y :", y)

--Properties--
Name:  Serial
# of Input: 1
# of Output : 1
-- Inputs --
[-2 -1  0  1  2] 

-- Outputs --
y : [0 0 0 1 2]


In [10]:
# Create a concatenate trax layer
concat = tl.Concatenate()
print("-- Properties --")
print("name :", concat.name)
print("expected inputs :", concat.n_in)
print("promised outputs :", concat.n_out, "\n")

# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")

# Outputs
y = concat([x1, x2])
print("-- Outputs --")
print("y :", y)

-- Properties --
name : Concatenate
expected inputs : 2
promised outputs : 1 

-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.] 

-- Outputs --
y : [-10. -20. -30.   1.   2.   3.]


In [12]:
help(shapes.signature)

Help on function signature in module trax.shapes:

signature(obj)
    Returns a `ShapeDtype` signature for the given `obj`.
    
    A signature is either a `ShapeDtype` instance or a tuple of `ShapeDtype`
    instances. Note that this function is permissive with respect to its inputs
    (accepts lists or tuples or dicts, and underlying objects can be any type
    as long as they have shape and dtype attributes) and returns the corresponding
    nested structure of `ShapeDtype`.
    
    Args:
      obj: An object that has `shape` and `dtype` attributes, or a list/tuple/dict
          of such objects.
    
    Returns:
      A corresponding nested structure of `ShapeDtype` instances.



In [16]:
x = np.array([0,1,2,3], dtype = 'float32')
norm = tl.LayerNorm()

norm.init(shapes.signature(x))


print("Normal shape:",x.shape, "Data Type:",type(x.shape))
print("Shapes Trax:",shapes.signature(x),"Data Type:",type(shapes.signature(x)))

# Inspect properties
print("-- Properties --")
print("name :", norm.name)
print("expected inputs :", norm.n_in)
print("promised outputs :", norm.n_out)
# Weights and biases
print("weights :", norm.weights[0])
print("biases :", norm.weights[1], "\n")

# Inputs
print("-- Inputs --")
print("x :", x)

# Outputs
y = norm(x)
print("-- Outputs --")
print("y :", y)

Normal shape: (4,) Data Type: <class 'tuple'>
Shapes Trax: ShapeDtype{shape:(4,), dtype:float32} Data Type: <class 'trax.shapes.ShapeDtype'>
-- Properties --
name : LayerNorm
expected inputs : 1
promised outputs : 1
weights : [1. 1. 1. 1.]
biases : [0. 0. 0. 0.] 

-- Inputs --
x : [0. 1. 2. 3.]
-- Outputs --
y : [-1.3416404  -0.44721344  0.44721344  1.3416404 ]


In [21]:
def Multiplication():
    layer_name = 'Multiplication'

    def func(x, y):
        return x * y

    return tl.Fn(layer_name, func)


multiply = Multiplication()
x = np.array([2])
y = np.array([3])


print("--Properties--")
print(f"Name: {multiply.name}")
print(f"# of input : {multiply.n_in}")
print(f"# of Output: {multiply.n_out}")
print("---------------\n")

print("---------Inputs--------")
print(f"x : {x}")
print(f"y:  {y}")
print("-----------------------\n")


answer = multiply((x,y))
print(f"Outputs: {answer}")

--Properties--
Name: Multiplication
# of input : 2
# of Output: 1
---------------

---------Inputs--------
x : [2]
y:  [3]
-----------------------

Outputs: [6]


In [31]:
def TimesTwo():
    layer_name = 'Times Two'

    def func(x):
        return x*2


    return tl.Fn(layer_name, func)


times_two = TimesTwo()
x = np.array([2])
y = np.array([3])


print("--Properties--")
print(f"Name: {times_two.name}")
print(f"# of input : {times_two.n_in}")
print(f"# of Output: {times_two.n_out}")
print("---------------\n")

print("---------Inputs--------")
print(f"x : {x}")
print("-----------------------\n")


answer = times_two(x)
print(f"Outputs: {answer}")


--Properties--
Name: Times Two
# of input : 1
# of Output: 1
---------------

---------Inputs--------
x : [2]
-----------------------

Outputs: [4]


In [32]:
serial = tl.Serial(
  tl.LayerNorm(),
  tl.Relu(),
  times_two, 
)

x = np.array([-2,-1,0,1,2])
serial.init(shapes.signature(x))

print("--Properties--")
print(f"Name: {serial.name}")
print(f"Sublayers : {serial.sublayers}")
print(f"# of input : {serial.n_in}")
print(f"# of Output: {serial.n_out}")
print("weights & biases:", serial.weights, "\n")
print("---------------\n")

print("---------Inputs--------")
print(f"x : {x}")
print("-----------------------\n")


answer = serial(x)
print(f"Outputs: {answer}")




--Properties--
Name: Serial
Sublayers : [LayerNorm, Serial[
  Relu
], Times Two]
# of input : 1
# of Output: 1
weights & biases: ((DeviceArray([1, 1, 1, 1, 1], dtype=int32), DeviceArray([0, 0, 0, 0, 0], dtype=int32)), ((), (), ()), ()) 

---------------

---------Inputs--------
x : [-2 -1  0  1  2]
-----------------------

Outputs: [0.        0.        0.        1.4142132 2.8284264]


In [35]:
from trax import fastmath

In [36]:
# Numpy vs fastmath.numpy have different data types
# Regular ol' numpy
x_numpy = np.array([1, 2, 3])
print("good old numpy : ", type(x_numpy), "\n")

# Fastmath and jax numpy
x_jax = fastmath.numpy.array([1, 2, 3])
print("jax trax numpy : ", type(x_jax))

good old numpy :  <class 'numpy.ndarray'> 

jax trax numpy :  <class 'jax.interpreters.xla._DeviceArray'>
