# JAX for neural networks

JAX, like PyTorch, has many features that are useful for neural networks, such as automatic differentiation. Below, we define the univariate function $f(x) = x^2$ and evaluate its derivative at $x = 1$.

In [9]:
from jax import grad
import jax.numpy as jnp

def f(x):
    return x**2

grad_f = grad(f)
grad_f(1.0)

Array(2., dtype=float32, weak_type=True)
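One way to see that `grad` returns an exact derivative rather than a numerical estimate is to compare it with a central finite difference, $(f(x + \epsilon) - f(x - \epsilon)) / (2\epsilon)$. The sketch below does this for the same $f$. The helper name `central_difference` and the step size `eps` are our own choices for illustration.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    return x**2

def central_difference(fn, x, eps=1e-3):
    # Numerical estimate of fn'(x); accuracy is limited by eps and float32 rounding
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

x = jnp.float32(1.0)
autodiff = float(grad(f)(x))               # derivative 2*x evaluated by JAX's autodiff
numeric = float(central_difference(f, x))  # finite-difference approximation

print(autodiff, numeric)
```

`grad` returns exactly 2.0 here. The central difference is exact for a quadratic in exact arithmetic, so any gap in `numeric` comes from float32 rounding.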

A direct consequence is that we can form the linear (tangent-line) approximation of a function at a point $a$, $f(a) + f'(a)(x - a)$. We can also apply `grad` to the gradient function itself, as in `grad(grad(f))`, to get the second derivative. Repeating this gives derivatives of any order, which lets us compute Taylor expansions up to any finite order.
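Both ideas fit in a few lines. In this sketch, `linear_approx` is a hypothetical helper that returns the tangent line of `fn` at `a`, and the higher derivatives come from nesting `grad`.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    return x**2

def linear_approx(fn, a):
    # Tangent line of fn at a: fn(a) + fn'(a) * (x - a)
    value = fn(a)
    slope = grad(fn)(a)
    return lambda x: value + slope * (x - a)

d2f = grad(grad(f))        # second derivative of x**2 is the constant 2
d3f = grad(grad(grad(f)))  # third derivative of a quadratic is 0

tangent = linear_approx(f, 1.0)
print(float(tangent(1.5)), float(d2f(1.0)), float(d3f(1.0)))
```

Each extra `grad` adds one order of differentiation, which is exactly what the Taylor-series loop later in this notebook relies on.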

Let's approximate $\sqrt{43}$ by expanding the square-root function about $x = 49$, where we know that $\sqrt{49} = 7$, and taking a step of $h = 43 - 49 = -6$.
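Written out, the first-order approximation computed in the next cell uses $\frac{d}{dx}\sqrt{x} = \frac{1}{2\sqrt{x}}$:

$$
\sqrt{43} = \sqrt{49 + h} \approx \sqrt{49} + \frac{h}{2\sqrt{49}} = 7 + \frac{-6}{14} \approx 6.5714, \qquad h = -6
$$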

In [15]:
# First order approximation

def sqrt(x):
    return x**(1/2)

d1 = grad(sqrt)

# sqrt(43) = sqrt(49 - 6) ~= sqrt(49) + sqrt'(49) * (-6)
true_val = float(jnp.sqrt(43))
approx_val = 7 + d1(49.0)*(-6)
print(f"True value: {true_val}, Approximate value: {approx_val}, Error: {true_val - approx_val}")


True value: 6.557438373565674, Approximate value: 6.5714287757873535, Error: -0.013990402221679688
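JAX can also produce the same first-order estimate in a single call with `jax.jvp`, which evaluates a function together with its directional derivative at a point. This sketch uses `jnp.sqrt` instead of the hand-written `sqrt` above. The tangent `-6.0` plays the role of the step $h$.

```python
import jax
import jax.numpy as jnp

# Primal point a = 49.0 and tangent (step) h = -6.0
value, directional_derivative = jax.jvp(jnp.sqrt, (49.0,), (-6.0,))
# value = sqrt(49); directional_derivative = sqrt'(49) * (-6)
approx = float(value + directional_derivative)  # first-order Taylor estimate of sqrt(43)

print(approx, float(jnp.sqrt(43.0)))
```

Unlike the `grad` version, `jvp` does not build a separate derivative function. It pushes the step $h$ forward through `jnp.sqrt` alongside the value, which is forward-mode differentiation.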


Now let's see how the approximation improves as we add higher-order terms of the Taylor series.
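The quantity to accumulate is the order-$n$ Taylor polynomial of $f(x) = \sqrt{x}$ about 49, where the $k$-th derivative is divided by $k!$:

$$
\sqrt{49 + h} \approx \sum_{k=0}^{n} \frac{f^{(k)}(49)}{k!}\, h^k, \qquad h = -6
$$

For example, the second-order term is $\frac{f''(49)}{2!} h^2 = \frac{-1/1372}{2} \cdot 36 \approx -0.0131$, since $f''(x) = -\frac{1}{4} x^{-3/2}$ and $49^{3/2} = 343$.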

In [25]:
highest_order = 5
d = sqrt
approx_total = 7.0  # zeroth-order term: sqrt(49)
factorial = 1
for i in range(1, highest_order + 1):
    d = grad(d)      # d is now the i-th derivative of sqrt
    factorial *= i   # factorial now equals i!
    approx_total += d(49.0) * ((-6)**i) / factorial
    print(f"Order {i}: approximate value: {approx_total}, error: {true_val - approx_total}")

In exact arithmetic, the partial sums for orders 1 through 5 are approximately 6.571429, 6.558309, 6.557506, 6.557444 and 6.557439, approaching $\sqrt{43} \approx 6.557439$; float32 results will differ in the last digits. Each extra term shrinks the error because the step $|h| = 6$ lies well inside the series' radius of convergence, which is 49 (the distance from the expansion point to the branch point of $\sqrt{x}$ at 0).


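Additionally, we can fit a linear regression with JAX: minimize the mean squared error by gradient descent and compare the result with the closed-form least-squares solution.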
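In symbols, with design matrix $X$ (a column of ones followed by X1 and X2), coefficients $b = (b_0, b_1, b_2)$ and $n$ samples, gradient descent with learning rate $\eta$ (0.005 in the code) minimizes the loss below. `grad` computes $\nabla L$ for us, and when $X^\top X$ is invertible the normal equations give the exact minimizer to compare against:

$$
L(b) = \frac{1}{n} \lVert y - Xb \rVert^2, \qquad
\nabla L(b) = -\frac{2}{n} X^\top (y - Xb), \qquad
b \leftarrow b - \eta \, \nabla L(b), \qquad
\hat{b} = (X^\top X)^{-1} X^\top y
$$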

In [94]:
from jax import grad, random
import jax.numpy as jnp

n = 1000
# Use a separate key for each draw: reusing one key returns the same samples,
# which would make X2 an exact multiple of X1 and the coefficients unidentifiable
key = random.PRNGKey(42)
k1, k2, k3 = random.split(key, 3)
X1 = random.normal(k1, (n, 1))
X2 = 3*random.normal(k2, (n, 1))
y = 3*X1 + 2*X2 + 1 + 0.001*random.normal(k3, (n, 1))

X0 = jnp.ones((n, 1))
X = jnp.concatenate([X0, X1, X2], axis=-1)  # design matrix: intercept, X1, X2

key = random.PRNGKey(43)
b = random.normal(key, (3,))  # random initial guess for (b0, b1, b2)

def loss(params_dict):
    b0, b1, b2 = params_dict['b0'], params_dict['b1'], params_dict['b2']
    yhat = b0 + b1 * X1 + b2 * X2
    return jnp.mean((y - yhat)**2)

loss_grad = grad(loss)  # build the gradient function once, outside the loop
lr = 0.005              # learning rate
tol = 1e-7              # stop when an update moves b by less than this
max_steps = 100_000
steps = 0
while steps < max_steps:
    dloss = loss_grad({'b0': b[0], 'b1': b[1], 'b2': b[2]})
    b_old = b
    b = b - lr * jnp.array([dloss['b0'], dloss['b1'], dloss['b2']])
    steps += 1
    if jnp.linalg.norm(b - b_old) < tol:
        break
loss_val = loss({'b0': b[0], 'b1': b[1], 'b2': b[2]})
print(f'Took {steps} steps to converge with loss {loss_val}', b)

# Closed-form least-squares solution (normal equations) for comparison
print(jnp.linalg.solve(X.T @ X, X.T @ y))

Because X1, X2 and the noise come from independent keys and the noise is small (standard deviation 0.001), both gradient descent and the normal equations should recover coefficients very close to the (1, 3, 2) used to generate `y`; the exact digits and the number of steps depend on the random draws.
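To check that `grad` computes the expected quantity, the sketch below compares it with the closed-form gradient of the mean squared error, $-\frac{2}{n} X^\top (y - Xb)$. It then takes compiled gradient-descent steps with `jax.jit` and `jax.value_and_grad` on a single parameter vector rather than a dictionary. It uses its own small synthetic dataset; the sample size, seed, learning rate `LR` and the helper names `mse` and `step` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax import random

# Hypothetical synthetic data: intercept column plus two standard-normal features
n = 200
k_x, k_noise = random.split(random.PRNGKey(0))
X = jnp.concatenate([jnp.ones((n, 1)), random.normal(k_x, (n, 2))], axis=1)
true_b = jnp.array([1.0, 3.0, 2.0])
y = X @ true_b + 0.001 * random.normal(k_noise, (n,))

LR = 0.05  # learning rate (arbitrary choice for this sketch)

def mse(b):
    # Mean squared error of the linear model X @ b
    return jnp.mean((y - X @ b) ** 2)

@jax.jit
def step(b):
    # One gradient-descent update; also returns the loss before the update
    loss, g = jax.value_and_grad(mse)(b)
    return b - LR * g, loss

# Compare autodiff with the closed-form gradient at an arbitrary point (here b = 0)
b = jnp.zeros(3)
autodiff_grad = jax.grad(mse)(b)
analytic_grad = -2.0 / n * X.T @ (y - X @ b)
print(jnp.allclose(autodiff_grad, analytic_grad, rtol=1e-3, atol=1e-3))

for _ in range(500):
    b, loss = step(b)
print(b, loss)
```

Keeping the parameters in one array lets a single vectorized update replace the three per-coefficient updates, and `jax.jit` compiles the whole step once rather than executing it op by op on every iteration.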
