# Gradient computation in JAX

JAX comes with a general automatic differentiation system. This notebook gives a quick overview of the options you have for calculating gradients in a pure JAX pipeline.

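We will look at the gradient, the Jacobian, and the Hessian, using four of JAX's differentiation functions: `grad` computes gradients of scalar-valued functions, `jacrev` and `jacfwd` compute Jacobians in reverse and forward mode, and `hessian` computes second derivatives.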

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, hessian, jacrev, jacfwd

## Pipeline with one input and one output

We define a simple pipeline that takes one array as input and reduces it to a single scalar output.

In [None]:
def pipeline(x):
    # Example pipeline function
    y = jnp.sin(x)
    z = jnp.sum(y ** 2)
    return z

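We can build the gradient of the pipeline by applying `grad` to it: `grad(pipeline)` returns a new function that computes the derivative of the pipeline's output with respect to its input. `grad` only works for functions with a scalar output, which is why the pipeline sums the squared values. For a scalar-to-scalar function, nesting the transform as `grad(grad(f))` gives the second-order derivative; for an array input like `x` below, `grad(pipeline)` returns an array, so second derivatives are computed with `hessian(pipeline)` instead.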

In [None]:
# Define the gradient function
pipeline_grad = grad(pipeline)

# Example input
x = jnp.array([1.0, 2.0, 3.0])

# Compute the gradient
grad_result = pipeline_grad(x)
print("Gradient:", grad_result)
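As a sanity check, the gradient can be compared with the analytic derivative: since d/dx sin²(x) = 2 sin(x) cos(x) = sin(2x), each entry of the gradient should equal sin(2xᵢ). The sketch below also shows two ways to get second derivatives. Nesting `grad` only works when every level returns a scalar, so it uses `scalar_pipeline`, a hypothetical single-element version of the pipeline introduced here for illustration. For the array input, `hessian` is used instead; because the pipeline is a sum of independent terms, its Hessian should be diagonal with entries 2 cos(2xᵢ).

```python
import jax.numpy as jnp
from jax import grad, hessian

def pipeline(x):
    # Same pipeline as above: sum of squared sines (scalar output)
    y = jnp.sin(x)
    return jnp.sum(y ** 2)

def scalar_pipeline(x):
    # Hypothetical scalar-to-scalar variant, so grad can be nested
    return jnp.sin(x) ** 2

x = jnp.array([1.0, 2.0, 3.0])

# First derivative vs. the analytic result sin(2x)
grad_result = grad(pipeline)(x)
print("Gradient matches sin(2x):", jnp.allclose(grad_result, jnp.sin(2 * x), atol=1e-5))

# Second derivative of a scalar function by nesting grad (analytic: 2*cos(2.0))
second_derivative = grad(grad(scalar_pipeline))(1.0)
print("Second derivative at 1.0:", second_derivative)

# For an array input, hessian returns the full matrix of second derivatives
hessian_result = hessian(pipeline)(x)
print("Hessian shape:", hessian_result.shape)
print("Hessian diagonal matches 2cos(2x):",
      jnp.allclose(jnp.diag(hessian_result), 2 * jnp.cos(2 * x), atol=1e-5))
```

Calling `grad(grad(pipeline))(x)` on the 3-element array would fail, because the inner `grad(pipeline)` already returns a non-scalar array.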

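## Pipeline with multiple outputs

`grad` only accepts functions with a scalar output. When a function returns several arrays, `jacrev` computes one Jacobian per output instead.

In [None]: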
def pipeline_multi_output(x):
    # Example pipeline function with multiple outputs
    y = jnp.sin(x)
    z = jnp.cos(x)
    return y, z

# Define the Jacobian function
pipeline_jacobian = jacrev(pipeline_multi_output)

# Compute the Jacobian
jacobian_result = pipeline_jacobian(x)
print("Jacobian:", jacobian_result)
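Because `pipeline_multi_output` returns a tuple, `jacrev` returns a tuple of Jacobians, one per output, each of shape (3, 3) for the 3-element input. Since sin and cos act elementwise, both matrices are diagonal, with cos(x) and -sin(x) on the diagonal. `jacfwd` computes the same Jacobians with forward-mode differentiation; as a rule of thumb from the JAX documentation, forward mode tends to be cheaper when a function has fewer inputs than outputs ("tall" Jacobians) and reverse mode when it has more inputs than outputs ("wide" Jacobians). A minimal comparison:

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def pipeline_multi_output(x):
    # Two elementwise outputs of the same input
    return jnp.sin(x), jnp.cos(x)

x = jnp.array([1.0, 2.0, 3.0])

# Each transform returns a tuple: (Jacobian of sin output, Jacobian of cos output)
jac_sin_rev, jac_cos_rev = jacrev(pipeline_multi_output)(x)
jac_sin_fwd, jac_cos_fwd = jacfwd(pipeline_multi_output)(x)

print("Shapes:", jac_sin_rev.shape, jac_cos_rev.shape)
# Elementwise functions give diagonal Jacobians: d sin/dx = cos, d cos/dx = -sin
print(jnp.allclose(jac_sin_rev, jnp.diag(jnp.cos(x)), atol=1e-6))
print(jnp.allclose(jac_cos_rev, jnp.diag(-jnp.sin(x)), atol=1e-6))
# Forward and reverse mode agree
print(jnp.allclose(jac_sin_rev, jac_sin_fwd, atol=1e-6),
      jnp.allclose(jac_cos_rev, jac_cos_fwd, atol=1e-6))
```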

## Pipeline with multiple inputs

In [None]:
def pipeline_mulit_input(x, y):
    # Example pipeline function with multiple inputs
    z = x ** 2 + (y-1) ** 2
    return z

x = 5.0
y = 2.0

If a function has multiple inputs, the `argnums` argument of `grad` selects which positional argument to differentiate with respect to (the default is `argnums=0`).

In [None]:
x_grad = grad(pipeline_mulit_input, argnums=0)(x,y)
y_grad = grad(pipeline_mulit_input, argnums=1)(x,y)

print("x_grad:", x_grad)
print("y_grad:", y_grad)
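For z = x² + (y − 1)² the partial derivatives are ∂z/∂x = 2x and ∂z/∂y = 2(y − 1), so at (5.0, 2.0) we expect 10.0 and 2.0. Passing a tuple to `argnums` returns both partial derivatives in one call, as a tuple in the same order as the argument indices:

```python
import jax.numpy as jnp
from jax import grad

def pipeline_mulit_input(x, y):
    # z = x^2 + (y - 1)^2, minimum at (0, 1)
    return x ** 2 + (y - 1) ** 2

x, y = 5.0, 2.0

# argnums=(0, 1) differentiates with respect to both arguments at once
x_grad, y_grad = grad(pipeline_mulit_input, argnums=(0, 1))(x, y)
print("x_grad:", x_grad)  # analytic: 2 * x = 10.0
print("y_grad:", y_grad)  # analytic: 2 * (y - 1) = 2.0
```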

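`jacrev` computes the Jacobian by reverse-mode differentiation and accepts the same `argnums` argument. With a tuple such as `argnums=(0, 1)` it returns one derivative per selected argument.

In [None]: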
pipeline_grad_x = jacrev(pipeline_mulit_input, argnums=0)
pipeline_grad_y = jacrev(pipeline_mulit_input, argnums=1)
pipeline_grad_xy = jacrev(pipeline_mulit_input, argnums=(0, 1))

grad_x_result = pipeline_grad_x(x, y)
grad_y_result = pipeline_grad_y(x, y)
grad_xy_result = pipeline_grad_xy(x, y)

print("Gradient with respect to x:", grad_x_result)
print("Gradient with respect to y:", grad_y_result)
print("Gradient with respect to x and y:", grad_xy_result)
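For a scalar-valued function of scalar inputs, the Jacobian is just the derivative, so `grad`, `jacrev`, and `jacfwd` should all return the same partial derivatives. The sketch below checks this for both arguments at once; with `argnums=(0, 1)` each transform returns a tuple with one entry per argument.

```python
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev

def pipeline_mulit_input(x, y):
    return x ** 2 + (y - 1) ** 2

x, y = 5.0, 2.0
argnums = (0, 1)

via_grad = grad(pipeline_mulit_input, argnums=argnums)(x, y)
via_jacrev = jacrev(pipeline_mulit_input, argnums=argnums)(x, y)
via_jacfwd = jacfwd(pipeline_mulit_input, argnums=argnums)(x, y)

# Each result is a tuple (dz/dx, dz/dy); compare the Jacobian versions to grad
for name, result in [("jacrev", via_jacrev), ("jacfwd", via_jacfwd)]:
    matches = all(jnp.allclose(a, b) for a, b in zip(via_grad, result))
    print(f"{name} matches grad:", matches)
```

The difference between the three only matters for non-scalar outputs, where `grad` is not applicable and the choice between `jacrev` and `jacfwd` affects performance.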


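The gradients can be used directly for optimization: the next cell takes one gradient-descent step with learning rate 0.1 and then re-evaluates the gradients at the new point, which are smaller because the step moved toward the minimum at (0, 1).

In [None]: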
x = x - 0.1 * grad_x_result
y = y - 0.1 * grad_y_result

print("Updated x:", x)
print("Updated y:", y)

grad_x_result = pipeline_grad_x(x, y)
grad_y_result = pipeline_grad_y(x, y)

print("Gradient with respect to x:", grad_x_result)
print("Gradient with respect to y:", grad_y_result)
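Repeating the update drives (x, y) toward the minimum of z at (0, 1). With a learning rate of 0.1, each step multiplies both x and y − 1 by 1 − 0.1 · 2 = 0.8, so the iterates shrink geometrically. A minimal loop, where `learning_rate` and `num_steps` are illustrative choices rather than values from the original notebook:

```python
import jax.numpy as jnp
from jax import grad

def pipeline_mulit_input(x, y):
    return x ** 2 + (y - 1) ** 2

# Gradient with respect to both inputs, returned as a tuple (dz/dx, dz/dy)
grad_xy = grad(pipeline_mulit_input, argnums=(0, 1))

learning_rate = 0.1  # illustrative choice
num_steps = 100      # illustrative choice
x, y = 5.0, 2.0

for step in range(num_steps):
    gx, gy = grad_xy(x, y)
    # Plain gradient-descent update for both inputs
    x = x - learning_rate * gx
    y = y - learning_rate * gy

print("Final x:", x)  # close to 0
print("Final y:", y)  # close to 1
print("Final z:", pipeline_mulit_input(x, y))
```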