In [65]:
import trax
from trax import layers as tl
from trax.shapes import signature
import jax.numpy as jnp
import numpy as np

In [66]:
X = np.random.normal(loc=0, scale=0.1, size=(5,3))
signature(X)

ShapeDtype{shape:(5, 3), dtype:float64}

In [79]:
def print_info(model, yhat):
    print(f"input: {model.n_in}")
    print(f"output: {model.n_out}")
    print(f"Signature: {signature(yhat)}")

In [80]:
relu = tl.Relu()
yhat = relu(X)
X, yhat

(array([[ 0.01561314, -0.01694871, -0.039871  ],
        [ 0.07181149, -0.04606745, -0.1203104 ],
        [-0.0495206 , -0.03870095,  0.0915289 ],
        [ 0.06205908, -0.03711139, -0.02789194],
        [ 0.1220953 , -0.08110519, -0.12254368]]),
 DeviceArray([[0.01561314, 0.        , 0.        ],
              [0.07181149, 0.        , 0.        ],
              [0.        , 0.        , 0.0915289 ],
              [0.06205908, 0.        , 0.        ],
              [0.12209529, 0.        , 0.        ]], dtype=float32))

In [81]:
print_info(relu, yhat)

input: 1
output: 1
Signature: ShapeDtype{shape:(5, 3), dtype:float32}


In [82]:
concat = tl.Concatenate()
yhat = concat([X, X])
print_info(concat, yhat)

input: 2
output: 1
Signature: ShapeDtype{shape:(5, 6), dtype:float32}


In [83]:
concat3 = tl.Concatenate(n_items=3)
yhat = concat3([X, X, X])
print_info(concat3, yhat)

input: 3
output: 1
Signature: ShapeDtype{shape:(5, 9), dtype:float32}


In [84]:
from trax.layers import combinators as cb

In [86]:
model1 = cb.Serial(
    tl.Dense(128),
    tl.Relu(),
)
model1.init_weights_and_state(signature(X))
yhat = model1(X)
print_info(model1, yhat)

input: 1
output: 1
Signature: ShapeDtype{shape:(5, 128), dtype:float32}


In [87]:
model2 = cb.Serial(
    tl.Dense(64),
    tl.Relu(),
    tl.Dense(32),
    tl.Relu(),
) 
model2.init_weights_and_state(signature(X))
yhat = model2(X)
print_info(model2, yhat)

input: 1
output: 1
Signature: ShapeDtype{shape:(5, 32), dtype:float32}


In [88]:
model = cb.Serial(
    cb.Branch(model1, model2)
)
model.init_weights_and_state(signature(X))
yhat = model(X)
print_info(model, yhat)

input: 1
output: 2
Signature: (ShapeDtype{shape:(5, 128), dtype:float32}, ShapeDtype{shape:(5, 32), dtype:float32})


In [89]:
model = cb.Serial(
    cb.Branch(model1, model2),
    cb.Concatenate()
)
model.init_weights_and_state(signature(X))
yhat = model(X)
print_info(model, yhat)

input: 1
output: 1
Signature: ShapeDtype{shape:(5, 160), dtype:float32}


In [90]:
X = np.random.rand(32, 10, 128)

dnn = cb.Serial(
    tl.Dense(128),
    tl.Relu(),
    tl.Dense(128),
    tl.Relu(),
) 


residual = cb.Serial(
    cb.Branch([], dnn),
    cb.Add()
)

residual.init_weights_and_state(signature(X))
yhat = residual(X)
print_info(residual, yhat)

input: 1
output: 1
Signature: ShapeDtype{shape:(32, 10, 128), dtype:float32}
