# Automatic differentiation and jax

In [27]:
import jax
from jax import numpy as np
from jax import scipy as sp

In [28]:
seed = jax.random.PRNGKey(228)

In [125]:
def compare_differentiation_methods(f, gradf, shape, repeat_times=5):
    for i in range(repeat_times):
        x = jax.random.uniform(seed, shape=shape)
        isclose = np.isclose(gradf(x), jax.grad(f)(x), atol=1e-6, rtol=1e-4)
        if len(shape) > 1:
            isclose = isclose.flatten()

        print(f"Iteration {i}: ", end='')
        if np.all(isclose == True):
            print("all components are close")
        else:
            print("some components differ")
            print(f"Relative error is: {(gradf(x) - jax.grad(f)(x)).norm() / gradf(x).norm()}")

## Task 1

In [16]:
def f(x, y):
    return np.exp(-(np.sin(x) - np.cos(y))**2)

In [17]:
graph = jax.xla_computation(f)(np.ones(1337), np.ones(1337))
with open("graph.dot", "w") as file:
    file.write(graph.as_hlo_dot_graph())

In [21]:
!dot graph.dot -Tpng > graph.png

1336.80s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


![](graph.png)

## Task 2

$ f(A) = \operatorname{tr}(e^A),\, A \in \mathbb{R}^{n \times n} $

$ \nabla f(A) = \exp(A^{\top}) $ from Matrix calculus task 4

In [117]:
def f(A):
    return np.trace(sp.linalg.expm(A))

In [118]:
def gradf(A):
    return sp.linalg.expm(A.T)

In [126]:
compare_differentiation_methods(f, gradf, (20, 20))

Iteration 0: some components differ


AttributeError: 'ArrayImpl' object has no attribute 'norm'

## Task 3

## Task 4

$ f(x) = -\log \det X,\, X \in \mathbb{R}^{n \times n} $

$ df(x) = \frac{\det X \cdot \langle X^{-\top},\, dX \rangle}{\det X} = \langle X^{-\top},\, dX \rangle$

$ \nabla f(x) = X^{-\top} $

In [122]:
def f(X):
    return -np.log(np.linalg.det(X))

In [123]:
def gradf(X):
    return np.linalg.inv(X).T

In [124]:
compare_differentiation_methods(f, gradf, (10, 10))

Iteration 0: some components differ
Values are:
105.81229
Iteration 1: some components differ
Values are:
105.81229
Iteration 2: some components differ
Values are:
105.81229
Iteration 3: some components differ
Values are:
105.81229
Iteration 4: some components differ
Values are:
105.81229


## Task 5

$ f(x) = x^{\top} x x^{\top} x,\, x \in \mathbb{R}^n $

$ f(x) = \langle x,\, x \rangle^2 $

$ df(x) = 4 \cdot \langle x,\, x \rangle \cdot \langle x,\, dx \rangle = \big\langle 4 \cdot \langle x,\, x \rangle \cdot x,\, dx \big\rangle $

$ \nabla f(x) = 4 \cdot \langle x,\, x \rangle \cdot x $

In [109]:
def f(x):
    return (x.T @ x) * (x.T @ x)

In [110]:
def gradf(x):
    return 4 * (x.T @ x) * x

In [111]:
compare_differentiation_methods(f, gradf, (100,))

Iteration 0: all components are close
Iteration 1: all components are close
Iteration 2: all components are close
Iteration 3: all components are close
Iteration 4: all components are close
