# JAX for neural networks

Jax, like torch, has a lot of features that are useful for neural networks, such as auto differentiation. Below, we look at the univariate function and evaluate its gradient at a given point.

In [1]:
from jax import grad
from jax import random as jnr
import jax.numpy as jnp
from jax.scipy.stats import norm

def f(x):
    return x**2

grad_f = grad(f)
grad_f(1.0)

Array(2., dtype=float32, weak_type=True)

A free extension of this is that we can find the linear approximation of the function at a given point. We can also apply the gradient of f to another gradient to get the second derivative. We can do this for any number of derivatives, allowing us to find higher order derivatives and compute Taylor expansions up to any finite order.

Let's do the approximation of the square root function about x=43, granted we know that the square root of 49 is 7.

In [2]:
# First order approximation

def sqrt(x):
    return x**(1/2)

d1 = grad(sqrt)

# f(43 = 49 - 6) = sqrt(49) + d1(49)(-6)
true_val = float(jnp.sqrt(43))
approx_val = 7 + d1(49.0)*(-6)
print(f"True value: {true_val}, Approximate value: {approx_val}, Error: {true_val - approx_val}")


True value: 6.557438373565674, Approximate value: 6.5714287757873535, Error: -0.013990402221679688


Now let's see how the approximation changes with the order of the derivative. 

In [3]:
highest_order = 5
d = sqrt
approx_total = 7
factorial = 1
for i in range(1, highest_order+1):
    d = grad(d)
    approx_total += d(49.0)*((-6)**i) / factorial
    factorial *= i

    print(f"True value: {float(jnp.sqrt(43))}, Approximate value: {approx_total}, Error: {true_val - approx_total}")   

True value: 6.557438373565674, Approximate value: 6.5714287757873535, Error: -0.013990402221679688
True value: 6.557438373565674, Approximate value: 6.54518985748291, Error: 0.012248516082763672


True value: 6.557438373565674, Approximate value: 6.542779922485352, Error: 0.014658451080322266
True value: 6.557438373565674, Approximate value: 6.542533874511719, Error: 0.014904499053955078
True value: 6.557438373565674, Approximate value: 6.542507648468018, Error: 0.01493072509765625


Additionally, we can perform linear regression with JAX using the one layer neural network. This is full batch gradient descent.

In [4]:
key = jnr.PRNGKey(2)
X1_key, X2_key = jnr.split(key)
X1 = jnr.normal(X1_key, (1000, 1))
X2 = jnr.normal(X2_key, (1000, 1))
y = -11*X1 + 12*X2 + 3 + 3*jnr.normal(key, (1000, 1))

X = jnp.hstack([X1, X2])

def linear_model(in_size=2, out_size=1):
    W_key, b_key = jnr.split(key)
    w = jnr.normal(W_key, (in_size, out_size))
    b = jnr.normal(b_key, (out_size,))
    return w, b

def predict(w, b, x): 
    return jnp.dot(x, w) + b

def loss(w, b, x, y):
    return jnp.mean((predict(w, b, x) - y)**2)

W, b = linear_model(2, 1)

learning_rate = 0.01
epoch = 0   
while True:
    old_W, old_b = W, b
    loss_val = loss(W, b, X, y)
    grad_W, grad_b = grad(loss, (0, 1))(W, b, X, y)
    W -= learning_rate * grad_W
    b -= learning_rate * grad_b

    if epoch % 100 == 0:  # Print loss every 100 epochs
        print(f"Epoch {epoch}, Loss: {loss_val}. Parameters: {W.flatten()}, {b[0]}")
    
    if jnp.linalg.norm(W-old_W) < 1e-5 and jnp.linalg.norm(b-old_b) < 1e-5:
        break
    
    epoch += 1
    

# Final weights and bias
print(f"Final weights: {W.flatten()}, Final bias: {b[0]}")
print(f"Loss (variance estimate): {jnp.sqrt(loss_val)}")
sigma2 = loss_val

# With estimates, we can computer the p-values of the coefficients

X_ = jnp.hstack([jnp.ones((1000, 1)), X]) # to do statistical inference on the bias
W_ = jnp.hstack([b[0], W[0][0], W[1][0]])
S = (1/sigma2)*jnp.dot(X_.T, X_)
S = jnp.linalg.eigh(S)
eig_vals, eig_vecs = S[0], S[1]
W_std = sigma2 * eig_vecs @ jnp.diag(1/jnp.sqrt(eig_vals)) @ eig_vecs.T

W_pvals = 2 * (1-norm.cdf(abs(W_.T) / jnp.diag(W_std)))
print(f"W_pvals: {W_pvals}")


Epoch 0, Loss: 283.57696533203125. Parameters: [0.53832495 1.2374916 ], -0.11600104719400406
Epoch 100, Loss: 13.434798240661621. Parameters: [-9.48734 10.64589], 2.352383852005005
Epoch 200, Loss: 9.301318168640137. Parameters: [-10.71081   11.789786], 2.786647319793701
Epoch 300, Loss: 9.236554145812988. Parameters: [-10.860037  11.930452], 2.857306957244873
Epoch 400, Loss: 9.235513687133789. Parameters: [-10.87819   11.947992], 2.8682963848114014
Final weights: [-10.880371  11.950205], Final bias: 2.869945526123047
Loss (variance estimate): 3.0389962196350098
W_pvals: [0.00123334 0.         0.        ]


We can see that everything is significant at the 5% level. Our variance estimate is exactly our final loss value.

The parameters are very close to the true values of -11 and 12, and the bias is very close to 3.

We can also do minibatch gradient descent.

In [7]:
learning_rate = 0.001
batch_size = 16
num_samples = X.shape[0]
num_batches = 100  
epoch = 0

while True:
    key = jnr.PRNGKey(epoch)
    indices = jnr.randint(key=key, shape=(num_batches, batch_size), minval=0, maxval=num_samples)
    minibatch_loss = 0
    for i in range(num_batches):
        X_batch = X[indices[i]]
        y_batch = y[indices[i]]
        
        old_W, old_b = W, b
        loss_val = loss(W, b, X_batch, y_batch)
        minibatch_loss += loss_val
        grad_W, grad_b = grad(loss, (0, 1))(W, b, X_batch, y_batch)
        W -= learning_rate * grad_W / batch_size
        b -= learning_rate * grad_b / batch_size

    if epoch % 10 == 0:  # Print loss every 10 epochs
        print(f"Epoch {epoch}, Loss: {minibatch_loss/num_batches}. Parameters: {W.flatten()}, {b[0]}")
    
    if jnp.linalg.norm(W-old_W) < 1e-5 and jnp.linalg.norm(b-old_b) < 1e-5:
        break
    
    epoch += 1
    

# Final weights and bias
print(f"Final weights: {W.flatten()}, Final bias: {b[0]}")
print(f"Loss (variance estimate): {jnp.sqrt(loss_val)}")
sigma2 = loss_val

# With estimates, we can computer the p-values of the coefficients

X_ = jnp.hstack([jnp.ones((1000, 1)), X]) # to do statistical inference on the bias
W_ = jnp.hstack([b[0], W[0][0], W[1][0]])
S = (1/sigma2)*jnp.dot(X_.T, X_)
S = jnp.linalg.eigh(S)
eig_vals, eig_vecs = S[0], S[1]
W_std = sigma2 * eig_vecs @ jnp.diag(1/jnp.sqrt(eig_vals)) @ eig_vecs.T

W_pvals = 2 * (1-norm.cdf(abs(W_.T) / jnp.diag(W_std)))
print(f"W_pvals: {W_pvals}")

Epoch 0, Loss: 9.73923110961914. Parameters: [-10.867279  11.956   ], 2.8681633472442627
Epoch 10, Loss: 9.454119682312012. Parameters: [-10.868103  11.957705], 2.868723154067993
Epoch 20, Loss: 8.845235824584961. Parameters: [-10.866956  11.957055], 2.8659415245056152
Epoch 30, Loss: 8.964531898498535. Parameters: [-10.86666   11.955393], 2.8698999881744385
Epoch 40, Loss: 9.640743255615234. Parameters: [-10.867653  11.96116 ], 2.8705267906188965
Epoch 50, Loss: 9.333024978637695. Parameters: [-10.869248  11.963926], 2.870386838912964
Epoch 60, Loss: 9.223723411560059. Parameters: [-10.871053  11.964029], 2.8773386478424072
Epoch 70, Loss: 8.821696281433105. Parameters: [-10.873198  11.964625], 2.8794002532958984
Epoch 80, Loss: 9.098153114318848. Parameters: [-10.8722725  11.961653 ], 2.8798582553863525
Epoch 90, Loss: 9.080175399780273. Parameters: [-10.872027  11.959692], 2.8780481815338135
Epoch 100, Loss: 8.819097518920898. Parameters: [-10.870214  11.952191], 2.8761041164398193


Minibatch gradient descent in this case is slower, but the estimated parameters are still very close to the true values. Furthermore, the convergence is less smooth, which is expected from the stochastic nature of the algorithm. 