In [1]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

In [2]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)



[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427
 -0.6713536  -0.59086424  0.73168874  0.56730247]


In [3]:
# Multiply a large random matrix by its transpose and time the product.
# block_until_ready() is called on the result so the timing covers the
# finished computation (presumably needed because JAX dispatches work
# asynchronously — see the JAX async-dispatch notes).
size = 3000
x = random.normal(key, (size, size), dtype=np.float32)
# NOTE(review): "runs on the GPU" only holds if JAX's default backend on this
# machine is a GPU — confirm before relying on the timing below.
%timeit np.dot(x, x.T).block_until_ready()  # runs on the GPU

332 ms ± 63.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
import numpy as onp  # original CPU-backed NumPy
x = onp.random.normal(size=(size, size)).astype(onp.float32)
%timeit np.dot(x, x.T).block_until_ready()

294 ms ± 59.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
from jax import device_put

x = onp.random.normal(size=(size, size)).astype(onp.float32)
x = device_put(x)
%timeit np.dot(x, x.T).block_until_ready()

254 ms ± 36.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


Using `jit`

In [6]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lmbda * x`` where ``x > 0`` and
    ``lmbda * alpha * (exp(x) - 1)`` everywhere else.

    Args:
        x: input array (or scalar) accepted by ``jax.numpy``.
        alpha: scale of the negative (exponential) branch.
        lmbda: overall output scale.

    Returns:
        Array with the same shape as ``x``.
    """
    # expm1 keeps precision for small |x|: in float32, exp(-1e-8) rounds to
    # exactly 1.0, so `alpha * exp(x) - alpha` would collapse to 0.
    return lmbda * np.where(x > 0, x, alpha * np.expm1(x))

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

5.04 ms ± 728 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

584 µs ± 106 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


Taking derivatives with `grad`


In [8]:
@jit
def sum_logistic(x):
    """Sum the logistic sigmoid ``1 / (1 + exp(-x))`` over every element of ``x``."""
    denominator = 1.0 + np.exp(-x)
    sigmoid = 1.0 / denominator
    return np.sum(sigmoid)

x = np.arange(3.)
print(sum_logistic(x))
print(grad(sum_logistic)(x))

2.1118555
[0.25       0.19661197 0.10499357]


In [9]:
x = random.normal(key, (10, 3))
batched_sum = vmap(sum_logistic)
batched_sum(x)

DeviceArray([1.4159583, 2.2595613, 1.4873992, 1.2502856, 2.0268779,
             1.7852714, 1.1136012, 1.4845107, 0.8522337, 1.1135309],            dtype=float32)

In [10]:
def test(x, y):
    return np.sum(x**2 + y**2)

In [11]:
# Draw two (10, 3) batches, then compare calling `test` on the first row pair
# with calling the vmapped `test` on the whole batch.
# NOTE(review): `key + 1` does arithmetic on the raw key data instead of deriving
# a subkey (e.g. `key_x, key_y = random.split(key)`), and `key` itself has
# already been used for earlier draws in this notebook. Changing this would
# change the printed values below, so it is left as-is here.
x, y  = [random.normal(key, (10,3)), random.normal(key + 1, (10,3))]
print('single argument:', test(x[0], y[0]), '\n')
print('batch output shape:', vmap(test)(x, y).shape)
print('batch output:', vmap(test)(x, y))

single argument: 2.4101882 

batch output shape: (10,)
batch output: [2.4101882 5.7268    4.604978  5.98869   3.2325125 4.0696073 3.6179295
 4.550433  7.862453  7.104902 ]


In [12]:
np.append(np.array([1,2,3]), 4)

DeviceArray([1, 2, 3, 4], dtype=int32)

# PyTorch distance matrix

In [13]:
import torch

In [14]:
x = torch.rand(10, 2)
x

tensor([[0.1648, 0.2830],
        [0.7454, 0.9028],
        [0.1389, 0.5697],
        [0.9894, 0.2126],
        [0.2389, 0.9959],
        [0.4940, 0.5586],
        [0.5405, 0.6170],
        [0.6420, 0.8481],
        [0.6479, 0.1027],
        [0.6203, 0.7241]])

In [15]:
x.split(1)

(tensor([[0.1648, 0.2830]]),
 tensor([[0.7454, 0.9028]]),
 tensor([[0.1389, 0.5697]]),
 tensor([[0.9894, 0.2126]]),
 tensor([[0.2389, 0.9959]]),
 tensor([[0.4940, 0.5586]]),
 tensor([[0.5405, 0.6170]]),
 tensor([[0.6420, 0.8481]]),
 tensor([[0.6479, 0.1027]]),
 tensor([[0.6203, 0.7241]]))

In [16]:
row = x.split(1)[0]
row

tensor([[0.1648, 0.2830]])

In [17]:
r_v = row.expand_as(x)
r_v

tensor([[0.1648, 0.2830],
        [0.1648, 0.2830],
        [0.1648, 0.2830],
        [0.1648, 0.2830],
        [0.1648, 0.2830],
        [0.1648, 0.2830],
        [0.1648, 0.2830],
        [0.1648, 0.2830],
        [0.1648, 0.2830],
        [0.1648, 0.2830]])

In [18]:
sq_dist = torch.sum((r_v - x) ** 2, 1)
print(sq_dist.shape)
sq_dist

torch.Size([10])


tensor([0.0000, 0.7212, 0.0829, 0.6849, 0.5137, 0.1843, 0.2527, 0.5471, 0.2658,
        0.4021])

In [19]:
sq_dist.view(1, -1).shape

torch.Size([1, 10])

In [20]:
def row_pairwise_distances(x, y=None, dist_mat=None):
    """Compute squared Euclidean distances between the rows of ``x`` and ``y``.

    The matrix is filled one row of ``x`` at a time, so only a single
    ``y``-sized temporary is alive at any point.

    Note that, despite the name, no square root is taken: entry ``[i, j]``
    is ``sum((x[i] - y[j]) ** 2)``.

    Args:
        x: tensor of shape ``(n, d)``.
        y: tensor of shape ``(m, d)``; defaults to ``x``.
        dist_mat: optional preallocated ``(n, m)`` tensor that is written
            in place and returned. When omitted, a new tensor with the
            dtype and device of ``x`` is allocated.

    Returns:
        The ``(n, m)`` tensor of squared distances (``dist_mat`` itself when
        one was supplied).
    """
    if y is None:
        y = x
    if dist_mat is None:
        # new_empty inherits dtype/device from x; this replaces the legacy
        # Variable(torch.Tensor(...).type(...)) construction, which relied on
        # a `Variable` name that was never imported.
        dist_mat = x.new_empty((x.size(0), y.size(0)))

    # Index rows directly instead of iterating x.split(1): split() yields one
    # empty chunk for an empty x, which made the original raise IndexError.
    for i in range(x.size(0)):
        dist_mat[i] = torch.sum((x[i] - y) ** 2, dim=1)
    return dist_mat

# Miscellaneous `jit` experiments

## First question
Does `jit` cache intermediate results that it needs again? That is, does it avoid recomputing a repeated subexpression such as `h(x)`?

In [21]:
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random

In [22]:
m = 10 ** 4


def h(x):
    """Add each of 0, 1, ..., m - 1 to ``x`` in turn and return the result."""
    for increment in range(m):
        x += increment
    return x

In [23]:
h(2)

49995002

In [24]:
def f1(x):
    """
    Sum ``h(x) + i`` for i = 0..9, calling h(x) again for every term
    (the repeated work is deliberate: this is the "wasteful" variant).
    """
    return sum(h(x) + i for i in range(10))

In [25]:
def f2(x):
    """
    Sum ``h(x) + i`` for i = 0..9, evaluating h(x) a single time up front.
    """
    hx = h(x)
    return sum(hx + i for i in range(10))

In [26]:
s = 4

In [27]:
assert f1(s) == f2(s)

In [28]:
# Time f1 and f2 as plain Python, then through jit.
# NOTE(review): jit(f1) / jit(f2) are re-wrapped inside every timed call, so the
# jitted timings include that wrapping step; binding e.g. `f1_jit = jit(f1)`
# once outside %timeit would isolate the execution time.
%timeit f1(s)
%timeit f2(s)
%timeit jit(f1)(s).block_until_ready()
%timeit jit(f2)(s).block_until_ready()

4.21 ms ± 639 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
424 µs ± 62 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
448 µs ± 112 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
324 µs ± 47.1 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Next question
Does a loop written with `lax.fori_loop` compile more quickly under `jit` than the equivalent Python `for` loop?

In [29]:
from jax import lax

In [30]:
m = 10**4
s = 4

In [31]:
def h1(x):
    """Add 0, 1, ..., m - 1 to ``x`` step by step with a plain Python loop."""
    def add_index(i, val):
        return val + i

    for step in range(m):
        x = add_index(step, x)
    return x

In [32]:
def h2(x):
    """Add 0, 1, ..., m - 1 to ``x`` using ``lax.fori_loop``."""
    def add_index(i, val):
        return val + i

    return lax.fori_loop(0, m, add_index, init_val=x)

In [33]:
jit(h1)(s)

DeviceArray(49995004, dtype=int32)

In [34]:
jit(h2)(s)

DeviceArray(49995004, dtype=int32)

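Yes, it does! Under `jit`, the Python `for` loop in `h1` is unrolled into `m` separate additions while tracing, whereas the body of the `lax.fori_loop` in `h2` is traced only once.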

## Next question
When we use `lax.fori_loop`, do we get the same speedup for repeated computations?

In [35]:
assert h(s) == h1(s)

In [36]:
# The outer loop is now a lax.fori_loop. The inner work still calls h1, which
# is the plain Python-loop variant (h2 is the lax.fori_loop one).
def f1_lax(x):
    """
    h1(x) is evaluated inside the loop body (needlessly repeated work).

    Mirrors the Python loop
        out = 0
        for i in range(10):
            out += h1(x) + i
    """
    def accumulate(i, running_total):
        return running_total + h1(x) + i

    return lax.fori_loop(0, 10, accumulate, init_val=0)

def f2_lax(x):
    """
    h1(x) is evaluated once, before the loop, and reused in every iteration.

    Mirrors the Python loop
        out = 0
        for i in range(10):
            out += hx + i
    """
    hx = h1(x)

    def accumulate(i, running_total):
        return running_total + hx + i

    return lax.fori_loop(0, 10, accumulate, init_val=0)

In [37]:
assert f1_lax(s) == f1(s)
assert f2_lax(s) == f2(s)
assert f1(s) == f2(s)

In [38]:
# still compiles fast:
jit(f1_lax)(s)

DeviceArray(499950085, dtype=int32)

In [39]:
# Time the fori_loop-based variants without jit, then through jit.
# NOTE(review): jit(...) is re-wrapped inside every timed call, so the jitted
# timings include that wrapping step; binding the jitted functions once outside
# %timeit would isolate the execution time.
%timeit f1_lax(s)
%timeit f2_lax(s)
%timeit jit(f1_lax)(s).block_until_ready()
%timeit jit(f2_lax)(s).block_until_ready()

12.9 ms ± 1.32 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.2 ms ± 179 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
116 µs ± 3.87 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
The slowest run took 4.48 times longer than the fastest. This could mean that an intermediate result is being cached.
389 µs ± 233 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


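We do indeed get a speedup: the jitted versions run in a few hundred microseconds or less, versus roughly 12 ms without `jit`. Note that the `jit(f2_lax)` measurement is noisy (389 µs ± 233 µs), so the two jitted timings should not be compared too closely.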