In [None]:
%run ../../DataFiles_and_Notebooks/talktools.py

<img src="https://jax.readthedocs.io/en/latest/_static/jax_logo_250px.png">

JAX is accelerated NumPy (and more): https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html


- JAX provides a NumPy-inspired interface for convenience.

- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.

- Unlike NumPy arrays, JAX arrays are always immutable.
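
As a minimal sketch of the immutability point above (the variable names `x_np`, `x_jax`, and `y_jax` are illustrative only), in-place assignment works on a NumPy array but is rejected by a JAX array; the functional `.at[...].set(...)` update returns a new array instead:

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(5)
x_np[0] = 10                  # NumPy arrays can be modified in place

x_jax = jnp.arange(5)
try:
    x_jax[0] = 10             # JAX arrays reject item assignment
except TypeError as err:
    print("TypeError:", err)

# The functional update returns a new array and leaves x_jax unchanged.
y_jax = x_jax.at[0].set(10)
print(x_jax)
print(y_jax)
```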

Python code can be compiled just in time into highly efficient machine code using XLA (Accelerated Linear Algebra). Note that not all JAX code can be JIT-compiled, because compilation requires array shapes to be static and known at compile time.
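
A small sketch of what the static-shape requirement means in practice. The functions `positive_only` and `positive_or_zero` are hypothetical examples, not part of JAX: the first returns an array whose length depends on the data, so it runs eagerly but should fail when compiled; the second always returns an array of the input's shape, so it compiles.

```python
import jax
import jax.numpy as jnp

def positive_only(x):
    # Output length depends on the values in x, so the shape is not static.
    return x[x > 0]

def positive_or_zero(x):
    # Output always has the input's shape, so it can be compiled.
    return jnp.where(x > 0, x, 0)

x = jnp.array([-1.0, 2.0, -3.0, 4.0])

print(positive_only(x))              # works without jit

try:
    jax.jit(positive_only)(x)
    jit_failed = False
except Exception as err:             # JAX raises an indexing error while tracing
    jit_failed = True
    print(type(err).__name__)

print(jax.jit(positive_or_zero)(x))
```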

In [None]:
import jax
import jax.numpy as jnp

In [None]:
jnp.int32, jnp.float64

In [None]:
x = jnp.arange(10)
print(x)

In [None]:
type(x)

In [None]:
n_devices = jax.local_device_count() 
n_devices

JAX uses the XLA compiler under the hood and lets you just-in-time (JIT) compile your code to make it faster and more efficient. This is the purpose of the `@jit` decorator.

In [None]:
from jax import jit

@jit
def bar(a, b, c):
    return a + b * c

@jit
def foo(a, b, c):
    return a + b * c

print(foo)
print(foo(1, 2, 3))

In [None]:
a = jnp.sqrt((1+2j).real**2 + (1+2j).imag**2)

foo(2,a,0j)

In [None]:
@jit
def square(x):
    return x ** 2

@jit
def hypot(x, y):
    return jnp.sqrt(square(x) + square(y))

In [None]:
hypot(4,5)

In [None]:
@jit
def f2(x, y):
    return x + y

print(f2(1, 2))
print(f2("a", "b"))  # raises TypeError: strings are not valid JAX types

In [None]:
def na_var(data):
    sample_mean = 0.0
    
    # 1st loop
    for x in data:
        sample_mean =  sample_mean + x
    
    sample_mean = sample_mean / len(data)
    
    # second loop
    sum_of_squared_errors = 0.0
    for x in data:
        sum_of_squared_errors += (x - sample_mean) ** 2
    
    ret =  sum_of_squared_errors / (len(data) - 1.0)
    return ret

In [None]:
import numpy as np
%timeit na_var(np.arange(1000))

In [None]:
jax_na_var = jit(na_var)

In [None]:
%timeit jax_na_var(jnp.arange(1000)).block_until_ready()

In [None]:
from jax import numpy as jnp, random

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

v = random.normal(random.PRNGKey(42), (1000000,))
%timeit selu(v).block_until_ready()

In [None]:
import jax

selu_jit = jax.jit(selu)
%timeit selu_jit(v).block_until_ready()

In [None]:
long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()

In [None]:
def filter2d(image, filt):
    M, N = image.shape
    Mf, Nf = filt.shape
    Mf2 = Mf // 2
    Nf2 = Nf // 2
    result = jnp.zeros_like(image)
    for i in range(Mf2, M - Mf2):
        for j in range(Nf2, N - Nf2):
            num = 0.0
            for ii in range(Mf):
                for jj in range(Nf):
                    num += (filt[Mf-1-ii, Nf-1-jj] * image[i-Mf2+ii, j-Nf2+jj])
            # result[i, j] = num
            result = result.at[i, j].set(num)
    return result

### JAX Gotchas

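Looping and flow control (e.g. `if`/`else` statements) are non-trivial in JAX: under `jit`, Python loops are unrolled while the function is traced, and branching on the value of a traced array raises an error. You cannot simply `@jit` any Python function. See: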

https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html
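
A minimal sketch of the structured alternatives, assuming a 1D floating-point input; `sum_of_squares` and `abs_value` are illustrative names. `jax.lax.fori_loop` replaces a Python `for` loop and `jax.lax.cond` replaces a Python `if`/`else`, so both functions can be compiled:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def sum_of_squares(data):
    # lax.fori_loop(lower, upper, body_fun, init_val); body_fun takes (i, carry)
    def body(i, total):
        return total + data[i] ** 2
    init = jnp.zeros((), dtype=data.dtype)
    return lax.fori_loop(0, data.shape[0], body, init)

@jax.jit
def abs_value(x):
    # lax.cond(pred, true_fun, false_fun, *operands) replaces a Python if/else
    return lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

data = jnp.arange(4.0)
print(sum_of_squares(data))   # 0 + 1 + 4 + 9
print(abs_value(-3.0))
```

For simple elementwise branching, `jnp.where` is often enough and avoids explicit control flow altogether.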


## Numba  ##

An LLVM-based JIT compiler for Python (brought to you by Continuum, now Anaconda). It should already be in everyone's conda installation, but you might want to update it (it has changed recently):

   `conda update numba`

Docs: http://numba.pydata.org/numba-doc/0.54.1/index.html

In [None]:
import numba
numba.__version__

In [None]:
import numpy

def filter2d(image, filt):
    M, N = image.shape
    Mf, Nf = filt.shape
    Mf2 = Mf // 2
    Nf2 = Nf // 2
    result = numpy.zeros_like(image)
    for i in range(Mf2, M - Mf2):
        for j in range(Nf2, N - Nf2):
            num = 0.0
            for ii in range(Mf):
                for jj in range(Nf):
                    num += (filt[Mf-1-ii, Nf-1-jj] * image[i-Mf2+ii, j-Nf2+jj])
            result[i, j] = num
    return result

In [None]:
# Note: this rebinds the name `jit`, previously imported from jax, to numba's jit
from numba import double, jit

numbafilter_2d = jit(double[:,:](double[:,:], double[:,:]))(filter2d)

# Now numbafilter_2d runs at speeds as if you had first translated
# it to C, compiled the code and wrapped it with Python
image = numpy.random.random((100, 100))
filt = numpy.random.random((10, 10))
res = numbafilter_2d(image, filt)

In [None]:
%timeit numbafilter_2d(image, filt)

In [None]:
%timeit filter2d(image, filt)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
plt.imshow(res)

In [None]:
plt.imshow(image)

Numba also works with GPUs (and JAX with GPUs and TPUs).

```python
try:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
except Exception:  # not running on a Colab TPU runtime
    pass
```
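
To see which accelerator JAX is actually using in the current session, a quick check along these lines may help (a sketch; the output depends entirely on the machine it runs on):

```python
import jax

print(jax.default_backend())   # e.g. "cpu", "gpu", or "tpu", depending on the runtime
print(jax.devices())           # the devices JAX can place arrays on
```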

# Breakout

a. Write a nested for loop that computes the dot product (matrix multiplication) of `A` and `B`, both 2D arrays, and prints the results.

In [None]:
import numpy as np

def my_dot(A, B):
    # FIXME
    
    return outarray

b. Measure the runtime for A.shape = (30,50) and B.shape = (50,15)

In [None]:
A = np.ones((30,50))
B = np.ones((50,15))

rez = my_dot(A, B)

# make sure that your code gives the right answers
np.allclose(A.dot(B), rez)

In [None]:
%timeit my_dot(A, B)

c. Try using numba to make it faster.

In [None]:
from numba import jit

# FIXME

d. How does the numba speed compare to the native matrix multiplication in numpy (`numpy.dot`)?

In [None]:
%timeit A.dot(B)