# Automatic differentiation and jax

In [232]:
import random

import jax
from jax import numpy as np
from jax import scipy as sp

In [233]:
def seed():
    return jax.random.PRNGKey(random.randint(0, 228))

In [272]:
def compare_differentiation_methods(f, gradf, shape, repeat_times=5):
    for i in range(repeat_times):
        x = jax.random.uniform(seed(), shape=shape)
        isclose = np.isclose(gradf(x), jax.grad(f)(x)).flatten()

        print(f"Iteration {i}: ", end='')
        if np.all(isclose == True):
            print("all components are close")
        else:
            print("some components differ")
            print(f"Max componentwise relative error is: {((gradf(x) - jax.grad(f)(x)) / gradf(x)).max()}")

## Task 1

In [16]:
def f(x, y):
    return np.exp(-(np.sin(x) - np.cos(y))**2)

In [17]:
graph = jax.xla_computation(f)(np.ones(1337), np.ones(1337))
with open("graph.dot", "w") as file:
    file.write(graph.as_hlo_dot_graph())

In [21]:
!dot graph.dot -Tpng > graph.png

1336.80s - pydevd: Sending message related to process being replaced timed-out after 5 seconds


![](graph.png)

## Task 2

$$ f(A) = \operatorname{tr}(e^A),\, A \in \mathbb{R}^{n \times n} $$

$$ \nabla f(A) = \exp(A^{\top}) $$ proved in task 4 in Matrix calculus hw

In [221]:
def f(A):
    return np.trace(sp.linalg.expm(A))

In [222]:
def gradf(A):
    return sp.linalg.expm(A.T)

In [223]:
compare_differentiation_methods(f, gradf, (20, 20))

Iteration 0: some components differ
Max componentwise relative error is: -6.508145452244207e-05
Iteration 1: some components differ
Max componentwise relative error is: -9.928335202857852e-06
Iteration 2: some components differ
Max componentwise relative error is: -4.052296208101325e-05
Iteration 3: some components differ
Max componentwise relative error is: -3.57690078089945e-05
Iteration 4: some components differ
Max componentwise relative error is: -1.8703463865676895e-05


As we can see, different approaches give us different results with maximum relative error about $$10^{-5}$$ or less

## Task 3

$$ f(x) = \frac{1}{2} \| x \|^2,\, x \in \mathbb{R}^n $$

$$ f(x) = \frac{1}{2} \langle x,\, x \rangle $$

$$ df(x) = \langle x,\, dx \rangle $$

$$ \nabla f(x) = x $$

In [244]:
def L(x_0):
    def wrapper(alpha):
        nonlocal x_0
        x = x_0
        for i in range(10):
            x = x - alpha[i] * x
        return np.linalg.norm(x) / 2
    
    return wrapper

In [271]:
x_0 = float(jax.random.uniform(seed(), shape=(1,))[0])
alpha_1 = jax.random.uniform(seed(), maxval=0.1, shape=(10,))

In [257]:
L(x_0)(alpha_1)

Array(0.04536309, dtype=float32)

... :(
    
Gradient descent doesn't work

## Task 4

$$ f(x) = -\log \det X,\, X \in \mathbb{R}^{n \times n} $$

$$ df(x) = -\frac{\det X \cdot \langle X^{-\top},\, dX \rangle}{\det X} = \langle X^{-\top},\, dX \rangle$$

$$ \nabla f(x) = -X^{-\top} $$

In [224]:
def f(X):
    return -np.log(np.linalg.det(X))

In [225]:
def gradf(X):
    return -np.linalg.inv(X).T

In [226]:
compare_differentiation_methods(f, gradf, (20, 20))

Iteration 0: some components differ
Max componentwise relative error is: 0.00013778700667899102
Iteration 1: some components differ
Max componentwise relative error is: 9.875793330138549e-05
Iteration 2: some components differ
Max componentwise relative error is: 5.780803348898189e-06
Iteration 3: some components differ
Max componentwise relative error is: 1.0148980436497368e-05
Iteration 4: some components differ
Max componentwise relative error is: 0.00010453537106513977


As we can see, different approaches give us different results with maximum relative error about $$10^{-4}$$ or less

## Task 5

$$ f(x) = x^{\top} x x^{\top} x,\, x \in \mathbb{R}^n $$

$$ f(x) = \langle x,\, x \rangle^2 $$

$$ df(x) = 4 \cdot \langle x,\, x \rangle \cdot \langle x,\, dx \rangle = \big\langle 4 \cdot \langle x,\, x \rangle \cdot x,\, dx \big\rangle $$

$$ \nabla f(x) = 4 \cdot \langle x,\, x \rangle \cdot x $$

In [227]:
def f(x):
    return (x.T @ x) * (x.T @ x)

In [228]:
def gradf(x):
    return 4 * (x.T @ x) * x

In [229]:
compare_differentiation_methods(f, gradf, (100,))

Iteration 0: all components are close
Iteration 1: all components are close
Iteration 2: all components are close
Iteration 3: all components are close
Iteration 4: all components are close
