In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt

## Automatic differentiation
The key ingredient for optimizing a neural network is the ability to compute gradients with respect to the parameters of the model. This is achieved with automatic differentiation.

### In Tensorflow

In [None]:
def f(x):
    """Parabola (x - 1)^2, used as the toy function to differentiate."""
    shifted = x - 1
    return shifted ** 2


# Point at which the derivative is evaluated.
x = tf.Variable(3.0)

# Record the forward pass so that TensorFlow can replay it backwards.
with tf.GradientTape() as tape:
    y = f(x)

# dy/dx at the value held by x
tape.gradient(y, x)

### In Jax

In [None]:
# jax.grad turns a scalar-valued function into a function returning its derivative.
# (The bare name `grad` was never imported, so it must be reached through `jax`.)
f_dx = jax.grad(f)
f_dx(3.)

In [None]:
def f(x):
    """Return the square of the summed offsets x[i] - i for i = 0..3."""
    targets = (0., 1., 2., 3.)
    total = x[0] - targets[0]
    for i in (1, 2, 3):
        total = total + (x[i] - targets[i])
    return total ** 2

In [None]:
import jax.numpy as jnp
import jax

# Gradient of the four-component `f`; `grad` only exists as `jax.grad` in this notebook.
f_dx = jax.grad(f)
f_dx([1., 1., 1., 1.])

### Minimize the function with and without the exact gradient
Start the minimization very far from the minimum

In [None]:
from scipy.optimize import minimize

In [None]:
# Gradient-free call: BFGS estimates the gradient by finite differences.
# Start from the same point as the Newton-CG run so the iteration counts are
# comparable; the previous start [0, 2, 2, 2] already gives f = 0, i.e. it sits
# on a minimum and nothing is optimised.
minimize(f, x0=[2., 2., 2., 2.], method='BFGS')

In [None]:
# Same minimisation, this time handing SciPy the exact (jit-compiled) JAX gradient.
minimize(
    f,
    x0=(2, 2, 2, 2),
    method="Newton-CG",
    jac=jax.jit(f_dx),
)

If we know the gradient, we can optimize a function in fewer steps.

In [None]:
def f(x):
    """Piecewise toy function built with ordinary Python control flow.

    For x > 2 the value is obtained by applying x <- x + sqrt(x) ten times;
    otherwise the result is cos(x^3).
    """
    if not (x > 2):
        return jnp.cos(x ** 3)

    for _ in range(10):
        x += jnp.sqrt(x)
    return x
    
# Differentiate straight through the Python `if` / `for` inside f.
f_dx = jax.grad(f)

# Evaluate the derivative one point at a time: f branches on the value of x.
xspace = jnp.linspace(-2, 5, 200)
yi = np.asarray(list(map(f_dx, xspace)))

plt.plot(xspace, yi)
plt.show()

In [None]:
# `f` branches on `x > 2` and jax.grad requires a scalar output, so passing the
# whole array at once raises; evaluate the derivative point by point instead.
np.asarray([f_dx(xx) for xx in np.linspace(0., 1., 100)])

In [None]:
# Show the documentation of SciPy's one-dimensional minimiser.
# The name was never imported, so the lookup could not find it; import it here
# and use help(), which is plain Python rather than IPython-only `?` syntax.
from scipy.optimize import minimize_scalar

help(minimize_scalar)

## Not only ML

### Statistics

Let us define the likelihood of a counting experiment: one category, one signal, and an uncertainty on the background. The parameters are the POI (the signal strength) and the NP describing the background uncertainty.

In [None]:
import pyhf
# Use JAX as pyhf's tensor backend so the likelihood can be differentiated with jax.
pyhf.set_backend('jax')

# make a counting experiment
model = pyhf.simplemodels.uncorrelated_background(signal=[5.], bkg=[10.], bkg_uncertainty=[3.5])
# `model` is the only model object defined here (the name `m` did not exist).
pars = jnp.array(model.config.suggested_init())

# generate an Asimov dataset (e.g. 15 events observed)
data = jnp.array(model.expected_data(model.config.suggested_init()))

bestfit = pyhf.infer.mle.fit(data, model)  # not really needed since it is an Asimov
bestfit

In [None]:
# H is the Hessian of -2 ln L with respect to the parameters at the best-fit point.
# The trailing [0] selects the first entry along the leading axis of the result
# (presumably the length-1 output axis of model.logpdf -- confirm against pyhf).
H = -2 * jax.hessian(model.logpdf)(bestfit, data)[0]

# The covariance is the inverse Hessian of -ln L (not of -2 ln L), i.e. 2 * H^-1.
# Inverting H alone underestimates every variance by a factor of two.
2 * np.linalg.inv(H)

We are able to compute the expected errors without any minimization!

Plot the likelihood as a function of the parameters, ***together with its gradient***

In [None]:
# --- filled contours of log L on a fine 101 x 101 grid ----------------------
fine_axis = np.linspace(0.5, 1.5, 101)
grid = np.meshgrid(fine_axis, fine_axis)
x, y = grid

# Flatten the grid into an (N, 2) list of parameter points.
points = np.asarray(grid).transpose(2, 1, 0).reshape(-1, 2)
logpdf_on_points = jax.vmap(model.logpdf, in_axes=(0, None))
v = logpdf_on_points(points, data)
# Undo the axis swap made when flattening, so v lines up with x and y.
v = np.swapaxes(v.reshape(101, 101), 0, 1)
plt.contourf(x, y, v, levels=100)
plt.contour(x, y, v, levels=20, colors='w')


# --- gradient arrows on a coarse 11 x 11 grid -------------------------------
coarse_axis = np.linspace(0.5, 1.5, 11)
grid = np.meshgrid(coarse_axis, coarse_axis)
x, y = grid
points = np.asarray(grid).transpose(2, 1, 0).reshape(-1, 2)

# value_and_grad needs a scalar output, hence the [0] on the logpdf result.
scalar_logpdf = lambda p, d: model.logpdf(p, d)[0]
values, gradients = jax.vmap(
    jax.value_and_grad(scalar_logpdf),
    in_axes=(0, None),
)(points, data)

plt.quiver(
    *points.T,
    *gradients.T,
    angles='xy',
    scale=75,
)
# Mark the best-fit point.
plt.scatter(bestfit[0], bestfit[1], c='r')

ax = plt.gca()
ax.set_xlim(0.5, 1.5)
ax.set_ylim(0.5, 1.5)
plt.gcf().set_size_inches(5, 5)

## Heavy number crunching

In [None]:
# Window of the complex plane and its resolution.
xmin, xmax = -1.5, 1.5
ymin, ymax = -1.5, 1.5
nx, ny = 500, 500

x_axis = np.linspace(xmin, xmax, nx)
y_axis = np.linspace(ymin, ymax, ny)
X, Y = np.meshgrid(x_axis, y_axis)
Z = X + Y * 1j

# Grid of complex numbers
xs = tf.constant(Z.astype(np.complex64))

# Z-values for determining divergence; initialized at zero
zs = tf.zeros_like(xs)

# N-values store the number of iterations taken before divergence
ns = tf.Variable(tf.zeros(xs.shape, dtype=tf.float32))

def step(c, z, n):
    """Apply one iteration z -> z*z + c and update the per-point counter.

    The counter `n` is incremented by one wherever the new |z| is still below 4.
    Returns the tuple (c, new z, new n).
    """
    z_next = z * z + c
    still_bounded = tf.cast(tf.abs(z_next) < 4, tf.float32)
    return c, z_next, tf.add(n, still_bounded)

# Number of times the map is applied for each fractal.
iterations = 1000
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 7))

# mandelbrot: c is the grid point, z starts from zero
for _ in range(iterations):
    xs, zs, ns = step(c=xs, z=zs, n=ns)

def shade_fractal(fractal, max_iterations=None):
    """Turn per-point iteration counts into log-scaled shading values.

    Parameters
    ----------
    fractal : array-like
        Iteration counts, one per grid point.
    max_iterations : number, optional
        Value substituted for counts equal to zero (which would otherwise give
        log10(0)). Defaults to the notebook-level ``iterations`` so existing
        calls ``shade_fractal(ns)`` behave exactly as before.

    Returns
    -------
    numpy.ndarray
        ``log10(counts / counts.max())`` after the zero substitution.
    """
    if max_iterations is None:
        # Fall back to the iteration budget defined at notebook level.
        max_iterations = iterations
    counts = np.where(fractal == 0, max_iterations, fractal)
    counts = counts / counts.max()
    return np.log10(counts)

# Left panel: the Mandelbrot counts accumulated so far.
axs[0].pcolormesh(X, Y, shade_fractal(ns), shading='gouraud')

#julia
# Reset the counters; this time the constant is fixed and the grid values in
# `xs` are the ones being iterated (step returns (c, z, n), hence the order of
# the names on the left-hand side below).
julia_c = -0.7269 + 0.1889j
zs = tf.zeros_like(xs)
ns = tf.Variable(tf.zeros_like(xs, tf.float32))

for _ in range(iterations):
    zs, xs, ns = step(julia_c, xs, ns)

# Right panel: the Julia counts.
axs[1].pcolormesh(X, Y, shade_fractal(ns), shading='gouraud')

for panel in axs:
    panel.set_aspect('equal')