# Installing and using JAX

JAX is a library for automatic differentiation (auto-differentiation) of native Python and NumPy code, which makes it well suited to gradient-based optimization. Auto-differentiation forms the backbone of deep learning libraries such as PyTorch.
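
As a minimal sketch of what auto-differentiation means in practice, once JAX is installed: `jax.grad` takes a Python function with a scalar output and returns a new function that computes its exact derivative. The function `f` here is a toy example chosen only for illustration.

```python
import jax

# Toy function: f(x) = x**3 + 2x, so the exact derivative is f'(x) = 3x**2 + 2.
def f(x):
    return x ** 3 + 2.0 * x

# jax.grad transforms f into a new function that returns df/dx.
df = jax.grad(f)

print(df(2.0))  # 3*4 + 2 = 14.0
```

No finite differences are involved: the derivative is built from the operations inside `f`, so it is exact up to floating-point rounding.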

Activate your standard environment from Assignment-3, then run:
```
pip install --upgrade pip     
pip install --upgrade jax jaxlib 
```

See the [JAX installation instructions](https://github.com/google/jax#installation) for more information. The CPU-only version is sufficient for this project.
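
A quick sanity check after installing, sketched below, is to print the installed version and the devices JAX can see, then run a tiny computation. On a CPU-only install the device list should contain a single CPU device, and JAX may print a "No GPU/TPU found" warning, which is expected.

```python
import jax
import jax.numpy as jnp

# Installed version and the devices JAX detected (CPU-only installs list one CPU device).
print(jax.__version__)
print(jax.devices())

# A tiny computation to confirm the backend actually runs.
x = jnp.arange(3.0)      # [0., 1., 2.]
print(jnp.dot(x, x))     # 0 + 1 + 4 = 5.0
```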

In [1]:
import jax.numpy as jnp
# Use "jnp" in place of "np", our favourite NumPy library. The NumPy functions
# needed for this project have jnp equivalents with the same names and behaviour.
# Be careful, though:
#
# JAX transformations (such as jacfwd below) expect "functionally pure" Python
# functions: the output depends only on the inputs, and the function has no side
# effects such as modifying global variables or updating arrays in place.
# Separately, for this project, use arrays everywhere (i.e. `jnp.array()`) instead
# of other datatypes, such as lists, for storing vectors or matrices. Whenever you
# hit a datatype issue with JAX, first try converting the value with `jnp.array()`.
#
# Tip: jnp's error messages tend to be less readable than np's. So write and debug
# most of the code with "np" first, and once you actually need JAX (e.g. for
# derivatives), replace every np with jnp. Direct replacement usually works; the
# main exception is in-place assignment such as `x[i] = v`, which must become
# `x = x.at[i].set(v)` because JAX arrays are immutable.
from jax import jacfwd
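
The advice above to use `jnp.array()` everywhere, and the tip about swapping `np` for `jnp`, can be illustrated with a short sketch. It converts a Python list to a JAX array before reducing it, and shows one common case where a direct `np` to `jnp` replacement fails: JAX arrays are immutable, so NumPy-style item assignment raises a `TypeError`, and the functional `.at[...].set(...)` update is used instead. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# 1. Convert Python lists to JAX arrays before handing them to jnp functions
#    (many jnp functions raise a TypeError when given a plain list).
data = [[1.0, 2.0], [3.0, 4.0]]
X = jnp.array(data)
print(jnp.sum(X))  # 10.0

# 2. NumPy arrays can be modified in place...
x_np = np.zeros(3)
x_np[0] = 1.0

# ...but JAX arrays are immutable, so the same line raises a TypeError.
x_jnp = jnp.zeros(3)
try:
    x_jnp[0] = 1.0
except TypeError as err:
    print("in-place update failed:", type(err).__name__)

# The functional update returns a NEW array and leaves x_jnp unchanged.
y = x_jnp.at[0].set(1.0)
print(x_jnp, y)  # x_jnp is still all zeros; y has 1.0 in position 0
```

Returning a new array instead of mutating the old one is exactly what keeps functions pure, which is why JAX transformations like `jacfwd` can safely re-run or trace them.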

In [2]:
# Define a simple function: the logistic sigmoid, written in terms of tanh.
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Here we want the derivative of a vector-valued function (sigmoid(inputs @ a + b) is a
# vector) with respect to the input vector a, evaluated at a0. The derivative of a vector
# with respect to another vector is a matrix: the Jacobian.
def simpleJ(a, b, inputs):  # inputs: (4, 3) matrix; a: (3,) vector; b: (4,) vector
    return sigmoid(jnp.dot(inputs, a) + b)

inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])

b = jnp.array([0.2, 0.1, 0.3, 0.2])
a0 = jnp.array([0.1, 0.7, 0.7])

# Separate the variable to be differentiated (a) from the constant parameters (b, inputs)
f = lambda a: simpleJ(a, b, inputs)  # f is now a function of a alone

J = jacfwd(f)
# jacfwd(f) returns a new function that computes the Jacobian of f; it still has to be evaluated at a0.
J(a0)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Array([[ 0.07388726,  0.1591418 ,  0.10940997],
       [ 0.20861849, -0.2560318 ,  0.03555997],
       [ 0.12171669,  0.01404423, -0.30429173],
       [ 0.17407255, -0.58573055,  0.3269741 ]], dtype=float32)
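
Because this function is simple, the Jacobian above can also be derived by hand, which gives a useful check on `jacfwd`. Writing $X$ for `inputs` and $z = Xa + b$, each output is $f_i(a) = \sigma(z_i)$, so

$$J_{ij} = \frac{\partial f_i}{\partial a_j} = \sigma'(z_i)\,X_{ij}, \qquad \sigma'(z) = \sigma(z)\,\bigl(1 - \sigma(z)\bigr),$$

i.e. row $i$ of $X$ scaled by $\sigma'(z_i)$. For example, for the first row $z_1 = 1.575$ and $\sigma'(z_1) \approx 0.142$, so $J_{11} \approx 0.142 \times 0.52 \approx 0.0739$, matching the output above. The sketch below repeats the notebook's definitions so it runs on its own, compares `jacfwd` with this closed form, and also with `jax.jacrev` (reverse mode), which computes the same matrix.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

# Same definitions as in the cell above.
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

def simpleJ(a, b, inputs):
    return sigmoid(jnp.dot(inputs, a) + b)

inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
b = jnp.array([0.2, 0.1, 0.3, 0.2])
a0 = jnp.array([0.1, 0.7, 0.7])
f = lambda a: simpleJ(a, b, inputs)

# Closed form: scale row i of inputs by sigma'(z_i) = s_i * (1 - s_i).
s = f(a0)
J_analytic = (s * (1 - s))[:, None] * inputs

J_fwd = jacfwd(f)(a0)  # forward-mode Jacobian
J_rev = jacrev(f)(a0)  # reverse-mode Jacobian, same matrix

print(jnp.allclose(J_fwd, J_analytic, atol=1e-6))
print(jnp.allclose(J_fwd, J_rev, atol=1e-6))
print(J_fwd.shape)  # (4, 3): one row per output, one column per input
```

As the JAX documentation notes, forward mode (`jacfwd`) tends to be more efficient for "tall" Jacobians (more outputs than inputs) and reverse mode (`jacrev`) for "wide" ones; for a small 4×3 matrix like this either is fine.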