In [1]:
import jax
import jax.numpy as np
import jax.random as random
import jax.nn as nn

# Log the JAX version so the output below can be tied to a specific release.
print('jax version: {}'.format(jax.__version__))


# Randomly initialize the weights and bias of a single dense (fully connected) layer
def random_layer_params(nin, nout, key, scale=1e-2):
    """Draw Gaussian-initialized weights and bias for one dense layer.

    Args:
        nin: number of input features.
        nout: number of output features.
        key: JAX PRNG key; split once into a weight key and a bias key.
        scale: multiplier applied to the standard-normal samples, i.e. the
            standard deviation of both the weights and the bias.

    Returns:
        A ``(w, b)`` tuple where ``w`` has shape ``(nout, nin)`` and ``b``
        has shape ``(nout, 1)`` (a column vector).
    """
    weight_key, bias_key = random.split(key)
    weights = random.normal(weight_key, (nout, nin))
    bias = random.normal(bias_key, (nout, 1))
    return weights * scale, bias * scale


# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    """Initialize every layer of a fully connected network.

    Args:
        sizes: sequence of layer widths, input first and output last; layer
            ``i`` maps ``sizes[i]`` features to ``sizes[i + 1]``.
        key: JAX PRNG key, split into one subkey per entry of ``sizes``.

    Returns:
        A list of ``(w, b)`` tuples from ``random_layer_params``, one per
        consecutive pair in ``sizes`` (``len(sizes) - 1`` layers in total).
    """
    layer_keys = random.split(key, len(sizes))
    layer_shapes = zip(sizes[:-1], sizes[1:])
    # There is one more subkey than there are layers; zip stops at the
    # shorter input, so the last subkey is simply left unused.
    return [
        random_layer_params(fan_in, fan_out, layer_key)
        for (fan_in, fan_out), layer_key in zip(layer_shapes, layer_keys)
    ]


def forward(params, x):
    """Run a forward pass through the fully connected network.

    Each layer computes ``np.dot(w, act) + b``. Every layer except the last
    is followed by a ReLU; the last layer's raw (pre-activation) output is
    returned.

    Args:
        params: sequence of ``(w, b)`` layer tuples, e.g. from
            ``init_network_params``.
        x: input activations, multiplied on the right of each weight matrix.

    Returns:
        The output of the final layer, with no activation applied.
    """
    hidden_layers, output_layer = params[:-1], params[-1]
    activation = x
    for weight, bias in hidden_layers:
        activation = nn.relu(np.dot(weight, activation) + bias)
    final_w, final_b = output_layer
    return np.dot(final_w, activation) + final_b


def test():
    """Demonstrate that JAX clamps out-of-bounds indices instead of raising.

    Builds a tiny 2 -> 6 -> 2 network, runs one forward pass, then indexes
    past the end of the JAX output (silently clamped) and past the end of
    the input ``x`` (raises ``IndexError``).
    """
    D = 2
    x = np.arange(1, D+1)
    x = np.expand_dims(x, axis=1)  # column vector, shape (D, 1)
    x = np.float32(x)

    hidden_sizes = [6]
    sizes = [D] + hidden_sizes + [D]

    params = init_network_params(sizes, random.PRNGKey(0))

    forward_output = forward(params, x)

    print(forward_output.shape)

    # JAX does not raise on out-of-bounds indexing: indices into a JAX array
    # are clamped into range by design. ``forward_output[0]`` has shape (1,),
    # so ``[10]`` is clamped to ``[0]`` and returns the same value as
    # ``forward_output[0][0]`` instead of raising.
    print(forward_output)
    print(forward_output[0][10])

    # This raises IndexError. Presumably ``np.float32(x)`` above returns a
    # plain NumPy array in this JAX version, and NumPy bounds-checks indices.
    # NOTE(review): on newer JAX releases ``jnp.float32(x)`` returns a JAX
    # array, in which case this lookup would be clamped too — confirm.
    print(x[0][10])
        
if __name__ == "__main__":
    # Run the demo only when executed directly (script or notebook cell), not
    # on import: test() deliberately ends by raising IndexError.
    test()

jax version: 0.1.46
(2, 1)
[[0.00296329]
 [0.00434136]]
0.0029632905


IndexError: index 10 is out of bounds for axis 0 with size 1