In [1]:
import jax.numpy as jnp

In [9]:
from jax import random

key = random.key(42)

x = random.normal(key, (3, 3))
y = random.normal(key, (3, 3))

print(x == y)

[[ True  True  True]
 [ True  True  True]
 [ True  True  True]]


In [13]:
new_key, subkey = random.split(key)
x = random.normal(subkey, (3, 3))
y = random.normal(subkey, (3, 3))

[[False False False]
 [False False False]
 [False False False]]


In [2]:
x = jnp.arange(10)
x[0] = 10

TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [3]:
y = x.at[0].set(10)

In [4]:
import jax

In [5]:
x.devices()

{CpuDevice(id=0)}

In [6]:
from jax import grad

In [7]:
def logit(x):
    logits = 1/(1+jnp.exp(-x))
    return jnp.sum(logits)

In [8]:
x_ = jnp.arange(3, dtype=jnp.float32)

In [9]:
y = logit(x_)
dy = grad(logit)
dy(x_)

Array([0.25      , 0.19661197, 0.10499357], dtype=float32)

In [10]:
from jax import jacobian

jacobian(logit)(x_)


Array([0.25      , 0.19661197, 0.10499357], dtype=float32)

In [11]:
from jax import random

seed = 1

key = random.key(seed)

layers = 3
input_dim = 15
hidden_dims = [15, 15]
output_dim = 1

dims = [input_dim] + hidden_dims + [output_dim]

# initialize Ws
Ws = []
Bs = []

for layer in range(layers):

    b = random.normal(key, shape=(dims[layer+1]))*1/10

    W = random.normal(key, shape=(dims[layer],dims[layer+1]))*1/10

    Bs.append(b)
    Ws.append(W)

X = jnp.arange(input_dim)
Y = jnp.ones(shape=(1,))

def forward(X, weights, biases):

    n_layers = len(weights)

    hidden_layers = n_layers-1

    s = X.copy()

    # hidden layers
    for i_layer in range(hidden_layers):

        # unpack weights and biases
        b = biases[i_layer]
        W = weights[i_layer]

        s = jax.nn.relu(b + W[i_layer].T @ X)

        print(s)

    # final layer
    
    y = jax.nn.sigmoid(b[-1] + W[-1].T @ s) # shapes (inputdim x layerdim).T x (inputdim,) = (layerdim,)

    return y

def compute_loss(X, Y, weights, biases):

    return jnp.sum((forward(X, weights, biases)-Y)**2)

compute_loss(X, Y, Ws, Bs)

[4.9282985 4.952213  4.9301443 4.928239  5.070409  4.958572  5.1578984
 5.0440097 4.9147086 4.9795766 4.8729496 4.9191866 5.0323005 5.022354
 5.032635 ]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


Array(0.22828464, dtype=float32)

In [12]:
jacobian(compute_loss, argnums=2)(X, Y, Ws, Bs)

LinearizeTracer<float32[15]>
LinearizeTracer<float32[15]>


[Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

In [35]:
Bs

[Array([-0.15443718,  0.08470728, -0.13598049, -0.15503626,  1.2666674 ,
         0.14829758,  2.1415603 ,  1.0026742 , -0.29033586,  0.3583448 ,
        -0.70792735, -0.24555527,  0.8855825 ,  0.7861191 ,  0.88892716],      dtype=float32),
 Array([-0.15443718,  0.08470728, -0.13598049, -0.15503626,  1.2666674 ,
         0.14829758,  2.1415603 ,  1.0026742 , -0.29033586,  0.3583448 ,
        -0.70792735, -0.24555527,  0.8855825 ,  0.7861191 ,  0.88892716],      dtype=float32),
 Array([-0.15443718,  0.08470728, -0.13598049, -0.15503626,  1.2666674 ,
         0.14829758,  2.1415603 ,  1.0026742 , -0.29033586,  0.3583448 ,
        -0.70792735, -0.24555527,  0.8855825 ,  0.7861191 ,  0.88892716],      dtype=float32)]