In [1]:
print('ahoj')

ahoj


In [2]:
import numpy as np  # regular ol' numpy

from trax import layers as tl  # core building block
from trax import shapes  # data signatures: dimensionality and type
from trax import fastmath  # uses jax, offers numpy on steroids

In [4]:
!pip list | grep trax



trax                     1.3.7                


In [6]:
relu = tl.Relu()


# Inspect properties
print("-- Properties --")
print("name :", relu.name)
print("expected inputs :", relu.n_in)
print("promised outputs :", relu.n_out, "\n")

# Inputs
x = np.array([-2, -1, 0, 1, 2])
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = relu(x)
print("-- Outputs --")
print("y :", y)


print(relu)


-- Properties --
name : Serial
expected inputs : 1
promised outputs : 1 

-- Inputs --
x : [-2 -1  0  1  2] 





-- Outputs --
y : [0 0 0 1 2]
Serial[
  Relu
]


In [7]:
# Create a concatenate trax layer
concat = tl.Concatenate()
print("-- Properties --")
print("name :", concat.name)
print("expected inputs :", concat.n_in)
print("promised outputs :", concat.n_out, "\n")

# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2, "\n")

# Outputs
y = concat([x1, x2])
print("-- Outputs --")
print("y :", y)

-- Properties --
name : Concatenate
expected inputs : 2
promised outputs : 1 

-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.] 

-- Outputs --
y : [-10. -20. -30.   1.   2.   3.]


In [8]:

# Configure a concatenate layer
concat_3 = tl.Concatenate(n_items=3)  # configure the layer's expected inputs
print("-- Properties --")
print("name :", concat_3.name)
print("expected inputs :", concat_3.n_in)
print("promised outputs :", concat_3.n_out, "\n")

# Inputs
x1 = np.array([-10, -20, -30])
x2 = x1 / -10
x3 = x2 * 0.99
print("-- Inputs --")
print("x1 :", x1)
print("x2 :", x2)
print("x3 :", x3, "\n")

# Outputs
y = concat_3([x1, x2, x3])
print("-- Outputs --")
print("y :", y)

-- Properties --
name : Concatenate
expected inputs : 3
promised outputs : 1 

-- Inputs --
x1 : [-10 -20 -30]
x2 : [1. 2. 3.]
x3 : [0.99 1.98 2.97] 

-- Outputs --
y : [-10.   -20.   -30.     1.     2.     3.     0.99   1.98   2.97]


In [9]:
help(tl.Concatenate)


Help on class Concatenate in module trax.layers.combinators:

class Concatenate(trax.layers.base.Layer)
 |  Concatenate(n_items=2, axis=-1)
 |  
 |  Concatenates a number of tensors into a single tensor.
 |  
 |  For example::
 |  
 |      x = np.array([1, 2])
 |      y = np.array([3, 4])
 |      z = np.array([5, 6])
 |      concat3 = tl.Concatenate(n_items=3)
 |      z = concat3((x, y, z))  # z = [1, 2, 3, 4, 5, 6]
 |  
 |  Use the `axis` argument to specify on which axis to concatenate the tensors.
 |  By default it's the last axis, `axis=-1`, and `n_items=2`.
 |  
 |  Method resolution order:
 |      Concatenate
 |      trax.layers.base.Layer
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, n_items=2, axis=-1)
 |      Creates a partially initialized, unconnected layer instance.
 |      
 |      Args:
 |        n_in: Number of inputs expected by this layer.
 |        n_out: Number of outputs promised by this layer.
 |        name: Class-like name for thi

In [10]:
help(tl.LayerNorm)
help(shapes.signature)

Help on class LayerNorm in module trax.layers.normalization:

class LayerNorm(LayerNorm)
 |  LayerNorm(epsilon=1e-06)
 |  
 |  Layer normalization.
 |  
 |  Method resolution order:
 |      LayerNorm
 |      LayerNorm
 |      trax.layers.base.Layer
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __init__(self, epsilon=1e-06)
 |  
 |  ----------------------------------------------------------------------
 |  Methods inherited from LayerNorm:
 |  
 |  forward(self, x)
 |      Computes this layer's output as part of a forward pass through the model.
 |      
 |      A layer subclass overrides this method to define how the layer computes
 |      outputs from inputs. If the layer depends on weights, state, or randomness
 |      as part of the computation, the needed information can be accessed as
 |      properties of the layer object: `self.weights`, `self.state`, and
 |      `self.rng`. (See numerous examples in `trax.layers.core`.)
 |      
 |      Args:
 |        inputs:

In [14]:
# Layer initialization
norm = tl.LayerNorm()
# You first must know what the input data will look like
x = np.array([0, 1, 2, 3], dtype="float32")

# Use the input data signature to get shape and type for initializing weights and biases
norm.init(shapes.signature(x)) # We need to convert the input datatype from usual tuple to trax ShapeDtype

print("Normal shape:",x.shape, "Data Type:",type(x.shape))
print("Shapes Trax:",shapes.signature(x),"Data Type:",type(shapes.signature(x)))

# Inspect properties
print("-- Properties --")
print("name :", norm.name)
print("expected inputs :", norm.n_in)
print("promised outputs :", norm.n_out)
# Weights and biases
print("weights :", norm.weights[0])
print("biases :", norm.weights[1], "\n")

# Inputs
print("-- Inputs --")
print("x :", x)

# Outputs
y = norm(x)
print("-- Outputs --")
print("y :", y)

Normal shape: (4,) Data Type: <class 'tuple'>
Shapes Trax: ShapeDtype{shape:(4,), dtype:float32} Data Type: <class 'trax.shapes.ShapeDtype'>
-- Properties --
name : LayerNorm
expected inputs : 1
promised outputs : 1
weights : [1. 1. 1. 1.]
biases : [0. 0. 0. 0.] 

-- Inputs --
x : [0. 1. 2. 3.]
-- Outputs --
y : [-1.3416404  -0.44721344  0.44721344  1.3416404 ]


In [15]:
help(tl.Fn)

Help on function Fn in module trax.layers.base:

Fn(name, f, n_out=1)
    Returns a layer with no weights that applies the function `f`.
    
    `f` can take and return any number of arguments, and takes only positional
    arguments -- no default or keyword arguments. It often uses JAX-numpy (`jnp`).
    The following, for example, would create a layer that takes two inputs and
    returns two outputs -- element-wise sums and maxima:
    
        `Fn('SumAndMax', lambda x0, x1: (x0 + x1, jnp.maximum(x0, x1)), n_out=2)`
    
    The layer's number of inputs (`n_in`) is automatically set to number of
    positional arguments in `f`, but you must explicitly set the number of
    outputs (`n_out`) whenever it's not the default value 1.
    
    Args:
      name: Class-like name for the resulting layer; for use in debugging.
      f: Pure function from input tensors to output tensors, where each input
          tensor is a separate positional arg, e.g., `f(x0, x1) --> x0 + x1`.
          

In [34]:
def TimesByTwo():
    func_name = "TimesTwoFunction"
    
    def func(x):
        
        print(x)
        
        
        return x * 2
    
    return tl.Fn(func_name, func)
    
    
    
    
def TimesByTwoMoreArguments():
    func_name = "TimesTwoFunctionMoreArgs"
    
    def func(x0, x1):     
        print(x0, x1)
        return x * x
    return tl.Fn(func_name, func)
        

times_two = TimesByTwo()
times_two_more_arguments = TimesByTwoMoreArguments()




print(times_two)
print(times_two.name)
print(times_two.n_in)    
print(times_two.n_out)


print(times_two_more_arguments)
print(times_two_more_arguments.name)
print(times_two_more_arguments.n_in)    
print(times_two_more_arguments.n_out)



# x_tmp = [1,2,3]
# y_tmp = times_two(x_tmp)



x = np.array([1,2,3])
y = times_two(x)


_y = times_two_more_arguments((x, x))




print(x)
print(y)
print(_y)




    

TimesTwoFunction
TimesTwoFunction
1
1
TimesTwoFunctionMoreArgs_in2
TimesTwoFunctionMoreArgs
2
1
[1 2 3]
[1 2 3] [1 2 3]
[1 2 3]
[2 4 6]
[1 4 9]


In [35]:
# Serial combinator
serial = tl.Serial(
    tl.LayerNorm(),         # normalize input
    tl.Relu(),              # convert negative values to zero
    times_two,              # the custom layer you created above, multiplies the input recieved from above by 2
    
    ### START CODE HERE
#     tl.Dense(n_units=2),  # try adding more layers. eg uncomment these lines
#     tl.Dense(n_units=1),  # Binary classification, maybe? uncomment at your own peril
#     tl.LogSoftmax()       # Yes, LogSoftmax is also a layer
    ### END CODE HERE
)

# Initialization
x = np.array([-2, -1, 0, 1, 2]) #input
serial.init(shapes.signature(x)) #initialising serial instance

print("-- Serial Model --")
print(serial,"\n")
print("-- Properties --")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out)
print("weights & biases:", serial.weights, "\n")

# Inputs
print("-- Inputs --")
print("x :", x, "\n")

# Outputs
y = serial(x)
print("-- Outputs --")
print("y :", y)

Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=1/0)>
-- Serial Model --
Serial[
  LayerNorm
  Serial[
    Relu
  ]
  TimesTwoFunction
] 

-- Properties --
name : Serial
sublayers : [LayerNorm, Serial[
  Relu
], TimesTwoFunction]
expected inputs : 1
promised outputs : 1
weights & biases: ((DeviceArray([1, 1, 1, 1, 1], dtype=int32), DeviceArray([0, 0, 0, 0, 0], dtype=int32)), ((), (), ()), ()) 

-- Inputs --
x : [-2 -1  0  1  2] 

[0.        0.        0.        0.7071066 1.4142132]
-- Outputs --
y : [0.        0.        0.        1.4142132 2.8284264]


  lax._check_user_dtype_supported(dtype, "ones")
  lax._check_user_dtype_supported(dtype, "zeros")


In [36]:

# Numpy vs fastmath.numpy have different data types
# Regular ol' numpy
x_numpy = np.array([1, 2, 3])
print("good old numpy : ", type(x_numpy), "\n")

# Fastmath and jax numpy
x_jax = fastmath.numpy.array([1, 2, 3])
print("jax trax numpy : ", type(x_jax))

good old numpy :  <class 'numpy.ndarray'> 

jax trax numpy :  <class 'jax.interpreters.xla._DeviceArray'>
