In [1]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

## `np.sort`

In [5]:
x = np.array([[2,3,4],
              [3,2,1]])
print(x.shape)
np.sort(x)

(2, 3)


DeviceArray([[2, 3, 4],
             [1, 2, 3]], dtype=int32)

In [4]:
np.sort(x, axis=0)

DeviceArray([[2, 2, 1],
             [3, 3, 4]], dtype=int32)

In [6]:
np.sort(x, axis=1)

DeviceArray([[2, 3, 4],
             [1, 2, 3]], dtype=int32)

# Can `jit` avoid wasted work? (Only the first of 20,000 batched dot products is used.)

In [2]:
batched_vdot = vmap(np.vdot)

In [38]:
def test(x, y):
    """Row-wise dot products of `x` and `y` tiled 20,000 times; return only the first.

    Every tiled row is identical and all but one result is discarded, so this
    is used below to compare how long the wasted batched work takes with and
    without `jit`. Relies on the module-level `batched_vdot` (``vmap(np.vdot)``).
    """
    reps = 2 * 10**4
    x_rows, y_rows = (np.tile(arr, (reps, 1)) for arr in (x, y))
    return batched_vdot(x_rows, y_rows)[0]

jt = jit(test)

In [39]:
x = np.array([1,1,1])
y = np.array([1,2,3])

In [40]:
%timeit test(x, y)

1.87 s ± 392 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [43]:
%timeit jt(x, y)

146 µs ± 2.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


Yup, it can: the jitted version takes about 146 µs per call, versus about 1.87 s without `jit`.

## Kernel computation: check that `single_rbf` and `ard` agree

In [None]:
from utils import single_rbf, ard

# Check that `single_rbf` and `ard` (from the local `utils` module) agree on a
# grid of 1-D inputs and values of `h` (presumably a lengthscale -- see utils).
# They are separate implementations, so compare with a floating-point
# tolerance rather than exact `==`, which can fail on harmless rounding.
for x_val in np.linspace(-10, 10, 30):
    x = np.array([x_val])
    for y_val in np.linspace(-10, 10, 30):
        y = np.array([y_val])
        for h in np.linspace(1, 100, 5):
            rbf_xy, ard_xy = single_rbf(x, y, h), ard(x, y, h)
            assert np.allclose(rbf_xy, ard_xy), (x, y, h, rbf_xy, ard_xy)

## Miscellaneous JAX

In [None]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

In [None]:
# multiply matrices
size = 3000
x = random.normal(key, (size, size), dtype=np.float32)
%timeit np.dot(x, x.T).block_until_ready()  # runs on the GPU

In [None]:
import numpy as onp  # original CPU-backed NumPy
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T).block_until_ready()

In [None]:
from jax import device_put

x = onp.random.normal(size=(size, size)).astype(onp.float32)
x = device_put(x)
%timeit np.dot(x, x.T).block_until_ready()

Using `jit`

In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * (alpha * exp(x) - alpha)``
    elsewhere. The defaults are presumably rounded from the standard SELU
    constants (alpha ~ 1.6733, lambda ~ 1.0507).
    """
    negative_part = alpha * np.exp(x) - alpha
    return lmbda * np.where(x > 0, x, negative_part)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

In [None]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

Taking derivatives with `grad`


In [None]:
@jit
def sum_logistic(x):
    """Sum of the logistic sigmoid ``1 / (1 + exp(-x))`` over all elements of `x`."""
    sigmoid = 1.0 / (1.0 + np.exp(-x))
    return np.sum(sigmoid)

x = np.arange(3.)
print(sum_logistic(x))
print(grad(sum_logistic)(x))

In [None]:
x = random.normal(key, (10, 3))
batched_sum = vmap(sum_logistic)
batched_sum(x)

In [None]:
def test(x, y):
    return np.sum(x**2 + y**2)

In [None]:
# Derive two independent keys with `random.split` instead of `key + 1`:
# arithmetic on a PRNG key is not a supported way to get a new key (it raises
# for typed keys from `jax.random.key`). Note this changes the sampled values.
key_x, key_y = random.split(key)
x, y = random.normal(key_x, (10, 3)), random.normal(key_y, (10, 3))
print('single argument:', test(x[0], y[0]), '\n')
batch_out = vmap(test)(x, y)  # evaluate the batched call once, print twice
print('batch output shape:', batch_out.shape)
print('batch output:', batch_out)

In [None]:
np.append(np.array([1,2,3]), 4)

# PyTorch distance matrix

In [None]:
import torch

In [None]:
x = torch.rand(10, 2)
x

In [None]:
x.split(1)

In [None]:
row = x.split(1)[0]
row

In [None]:
r_v = row.expand_as(x)
r_v

In [None]:
sq_dist = torch.sum((r_v - x) ** 2, 1)
print(sq_dist.shape)
sq_dist

In [None]:
sq_dist.view(1, -1).shape

In [None]:
def row_pairwise_distances(x, y=None, dist_mat=None):
    """Squared Euclidean distance between every row of `x` and every row of `y`.

    Args:
        x: tensor of shape (n, d) -- rows are the points.
        y: tensor of shape (m, d); defaults to `x` (all pairs within `x`).
        dist_mat: optional preallocated (n, m) tensor to fill in place.

    Returns:
        The (n, m) tensor whose entry [i, j] is ``sum((x[i] - y[j]) ** 2)``.
        Note these are *squared* distances; no square root is taken.
    """
    if y is None:
        y = x
    if dist_mat is None:
        # The original used `Variable(...)`, which was never imported (NameError)
        # and is deprecated; a plain tensor with x's dtype and device suffices.
        dist_mat = x.new_empty((x.size(0), y.size(0)))

    # One row at a time keeps peak memory at O(m * d) instead of O(n * m * d).
    for i, row in enumerate(x.split(1)):
        dist_mat[i] = torch.sum((row.expand_as(y) - y) ** 2, 1)
    return dist_mat

# Miscellaneous experiments

## First question
Does `jit` reuse results it has already computed? That is, does it skip over repeated computations?

In [None]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

In [None]:
m = 10**4
def h(x):
    """Add 0 + 1 + ... + (m - 1) to `x`, one step at a time.

    `m` is read from the notebook's global scope at call time. The explicit
    Python loop is deliberate: under `jit` it is unrolled into `m` separate
    additions, which is what the timing experiments below probe.
    """
    for step in range(m):
        x += step
    return x

In [None]:
h(2)

In [None]:
def f1(x):
    """Return ``sum(h(x) + i for i in range(10))``, calling `h(x)` afresh each time.

    The ten calls to `h` are deliberately redundant; compare with `f2`,
    which computes h(x) once and reuses it.
    """
    return sum(h(x) + i for i in range(10))

In [None]:
def f2(x):
    """Return the same value as `f1`, but evaluate `h(x)` only once and reuse it."""
    hx = h(x)
    return sum(hx + i for i in range(10))

In [None]:
s = 4

In [None]:
assert f1(s) == f2(s)

In [None]:
%timeit f1(s)
%timeit f2(s)
%timeit jit(f1)(s).block_until_ready()
%timeit jit(f2)(s).block_until_ready()

## Next question
Does `lax.fori_loop` compile more quickly than an explicit Python loop?

In [None]:
from jax import lax

In [None]:
m = 10**4
s = 4

In [None]:
def h1(x):
    """Same result as `h`: apply ``step(i, val) = val + i`` for i in range(m).

    Written with an explicit step function so it mirrors the `fori_loop` body
    used by `h2`; `m` is read from the global scope. As a Python loop, it is
    still unrolled `m` times under `jit`.
    """
    def step(i, val):
        return val + i

    for i in range(m):
        x = step(i, x)
    return x

In [None]:
def h2(x):
    """Compute the same value as `h1`, but with `lax.fori_loop` over range(m).

    The loop body is traced once rather than unrolled `m` times. `m` is read
    from the global scope.
    """
    def step(i, val):
        return val + i

    return lax.fori_loop(0, m, step, x)

In [None]:
jit(h1)(s)

In [None]:
jit(h2)(s)

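Yes, it does! The Python `for` loop in `h1` is unrolled into `m` separate operations when traced, whereas `lax.fori_loop` traces its body only once.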

## Next question
When we use `lax.fori_loop`, do we get the same speedup from avoiding repeated computations?

In [None]:
assert h(s) == h1(s)

In [None]:
# now we use the lax fori loop h1
def f1_lax(x):
    """`lax.fori_loop` counterpart of `f1`: ``h1(x)`` is evaluated inside the loop body.

    Returns ``sum(h1(x) + i for i in range(10))``. Because the call to `h1`
    sits in the per-iteration body, this is the "repeated computation"
    variant; compare with `f2_lax`.
    """
    return lax.fori_loop(0, 10, lambda i, val: val + h1(x) + i, init_val=0)

def f2_lax(x):
    """`lax.fori_loop` counterpart of `f2`: ``h1(x)`` is computed once, before the loop.

    Returns ``sum(h1(x) + i for i in range(10))``; the loop body only reads
    the precomputed value ``hx``.
    """
    hx = h1(x)
    return lax.fori_loop(0, 10, lambda i, val: val + hx + i, init_val=0)

In [None]:
assert f1_lax(s) == f1(s)
assert f2_lax(s) == f2(s)
assert f1(s) == f2(s)

In [None]:
# still compiles fast:
jit(f1_lax)(s)

In [None]:
%timeit f1_lax(s)
%timeit f2_lax(s)
%timeit jit(f1_lax)(s).block_until_ready()
%timeit jit(f2_lax)(s).block_until_ready()

We do indeed get a speedup.

## How to make a 3D plot

In [None]:
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter

In [None]:
# Make data: the paraboloid Z = X**2 + Y**2 on a 50 x 50 grid over [-5, 5]^2.
X = np.linspace(-5, 5, 50)
Y = np.linspace(-5, 5, 50)
X, Y = np.meshgrid(X, Y)  # both shape (50, 50)

Z = X**2 + Y**2

# plot
fig = plt.figure()
# `Figure.gca(projection=...)` was deprecated in Matplotlib 3.4 and removed in
# 3.6; `add_subplot(..., projection='3d')` is the supported way to get 3D axes.
ax = fig.add_subplot(111, projection='3d')

surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)