In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [None]:
import jax

In [None]:
jax.__version__

In [None]:
# Create a PRNG key from seed 0 and use it to draw 10 standard-normal samples.
key = random.PRNGKey(0)
x = random.normal(key, (10,))

In [None]:
x

In [None]:
# Time a (size x size) float32 matrix product on an array generated by JAX.
# block_until_ready() makes the timing wait until the result is actually computed.
size = 3000
x = random.normal(key, shape=(size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

In [None]:
import numpy as np
# Same product, but the operand is now a NumPy array instead of a JAX array.
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x,x.T).block_until_ready()

In [None]:
from jax import device_put

# Same product again, after explicitly passing the NumPy array through device_put
# before the timed call.
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

In [None]:
def selu(x, alpha=1.67, lmda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Entries of ``x`` that are greater than zero pass through unchanged; all
    other entries are mapped to ``alpha * exp(x) - alpha``. The result is then
    multiplied by ``lmda``.
    """
    non_positive_branch = alpha * jnp.exp(x) - alpha
    activated = jnp.where(x > 0, x, non_positive_branch)
    return lmda * activated

# Time the un-jitted selu on one million standard-normal samples.
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

In [None]:
# Wrap selu with jit and time the compiled version on the current `x`.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

In [None]:
def sum_log(x):
    """Return the sum over all entries of the logistic sigmoid ``1 / (1 + exp(-x))``."""
    sigmoid = 1. / (1. + jnp.exp(-x))
    return jnp.sum(sigmoid)


# Differentiate sum_log with grad and print its gradient at [0., 1., 2.].
x_small = jnp.arange(3.)
der_fn = grad(sum_log)
print(der_fn(x_small))

In [None]:
import matplotlib.pyplot as plt
# Plot sum_log evaluated separately at each entry of x_small.
plt.plot(x_small, [sum_log(z) for z in x_small])

In [None]:
# Draw the matrix and the batch of vectors from *different* subkeys.
# Passing the same key to two random.normal calls makes the two draws
# deterministically related instead of independent.
mat_key, bx_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
bx = random.normal(bx_key, (10, 100))

def apply_matrix(v):
    """Multiply the module-level matrix ``mat`` (150 x 100) with ``v``."""
    return jnp.dot(mat, v)

In [None]:
jnp.dot(mat, bx[0]).shape

In [None]:
def batched_apply_matrix(bv):
    """Call ``apply_matrix`` on each element of ``bv`` in turn and stack the results."""
    outputs = []
    for v in bv:
        outputs.append(apply_matrix(v))
    return jnp.stack(outputs)

In [None]:
# Time the loop-based version, which calls apply_matrix once per row of bx.
print('Naively batched')
%timeit batched_apply_matrix(bx).block_until_ready()

In [None]:
# NOTE: this definition reuses the name of the loop-based batched_apply_matrix
# defined in an earlier cell and therefore replaces it from here on.
@jit
def batched_apply_matrix(v_batched):
    """Apply ``mat`` to every row of ``v_batched`` with one matrix product."""
    mat_transposed = jnp.transpose(mat)
    return jnp.dot(v_batched, mat_transposed)

# Time the jit-compiled version that handles the whole batch in one jnp.dot call.
print('Manually batched')
%timeit batched_apply_matrix(bx).block_until_ready()

In [None]:
batched_apply_matrix(bx).shape

How to JAX off

In [None]:
# JAX arrays are immutable, so item assignment raises a TypeError. The error is
# the point of this cell: catch and print it so that the notebook still runs
# top to bottom (Restart & Run All) instead of stopping here.
try:
    bx[0] = 1
except TypeError as err:
    print(f"TypeError: {err}")

In [None]:
# Functional alternative to item assignment: .at[0].set(1) returns an updated
# array. The result is only displayed here; it is not assigned back to bx.
bx.at[0].set(1)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Plot y = 2*sin(x)*cos(x) at 1000 evenly spaced points in [0, 10], computed with NumPy.
x_np = np.linspace(0, 10,1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np)

In [None]:
import torch

In [None]:

# The same curve computed with jax.numpy; the trailing ';' suppresses the cell's output repr.
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);

In [None]:
# NOTE(review): torch.add is documented as taking a Tensor as its first argument;
# with two Python scalars this call may raise a TypeError rather than promote
# the types — confirm whether the error is the intended demonstration.
torch.add(1, 1.)

In [None]:
# Convolve the length-3 array [1, 2, 1] with a length-10 array of ones.
x = jnp.array([1,2,1])
y = jnp.ones(10)
jnp.convolve(x,y)

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
    """Draw random weights and biases for one dense layer.

    Args:
        m: Size of the layer's input.
        n: Size of the layer's output.
        key: JAX PRNG key; it is split so weights and biases use different subkeys.
        scale: Factor applied to the standard-normal samples.

    Returns:
        A ``(weights, biases)`` tuple with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases
random_layer_params(2,3,key)

In [None]:
import torch.nn as nn

In [None]:
import tensorflow as tf

In [None]:
def init_network_params(sizes, key):
    """Create one ``(weights, biases)`` pair per consecutive pair of layer sizes.

    ``key`` is split into ``len(sizes)`` subkeys; each layer consumes one of
    them (the zip stops after the ``len(sizes) - 1`` layers).
    """
    keys = random.split(key, len(sizes))
    layer_dims = zip(sizes[:-1], sizes[1:])
    layers = []
    for (m, n), k in zip(layer_dims, keys):
        layers.append(random_layer_params(m, n, k))
    return layers

In [None]:
# Layer widths of the fully connected network: 784 -> 512 -> 512 -> 10.
layer_sizes = [784, 512, 512, 10]
# Training hyperparameters; not used by any cell visible here
# (presumably meant for a training loop — TODO confirm).
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
# Initialise one (weights, biases) pair per layer from a fixed seed and display them.
params = init_network_params(layer_sizes, random.PRNGKey(0))
params

In [None]:
from jax.scipy.special import logsumexp

def relu(x):
    """Rectified linear unit: the elementwise maximum of zero and ``x``."""
    lower_bound = 0
    return jnp.maximum(lower_bound, x)
def predict(params, image):
    """Run ``image`` through the network and return log-probabilities.

    Args:
        params: Sequence of ``(weights, biases)`` pairs, one per layer.
        image: Input passed to the first layer.

    Returns:
        The final layer's logits minus their log-sum-exp (a log-softmax).
    """
    hidden_layers = params[:-1]

    # Every layer except the last is an affine map followed by ReLU.
    activations = image
    for w, b in hidden_layers:
        activations = relu(jnp.dot(w, activations) + b)

    # The last layer is affine only; normalise its logits in log space.
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    log_normaliser = logsumexp(logits)
    return logits - log_normaliser

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Plot y = 2*sin(x)*cos(x) at 1000 evenly spaced points in [0, 10], computed with NumPy.
x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np)

In [None]:
import jax.numpy as jnp

# The same curve computed with jax.numpy; the trailing ';' suppresses the cell's output repr.
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);

In [None]:
# Build [0, 1, ..., 9] and replace the first entry with 10 using the functional
# .at[...].set(...) update, which returns a new array.
x = jnp.arange(10).at[0].set(10)
x

In [None]:
import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types.


In [None]:
from jax import lax

In [None]:
def norm(x):
    """Standardise ``x`` along axis 0: subtract the mean, divide by the standard deviation."""
    centered = x - x.mean(0)
    return centered / x.std(0)


norm(jnp.arange(10))

In [None]:
# jit-wrapped version of norm, called here on a small NumPy integer array.
norm_c = jit(norm)
# norm_c(jnp.arange(10))
norm_c(np.array([1,2,3]))

In [None]:
# Compare the plain and the jit-wrapped norm on a 10000 x 10 matrix of uniform random values.
X = jnp.array(np.random.rand(10000,10))
%timeit norm(X).block_until_ready()
%timeit norm_c(X).block_until_ready()

In [None]:
def get_negatives(x):
    """Return the entries of ``x`` that are strictly below zero, selected with a boolean mask."""
    is_negative = x < 0
    return x[is_negative]
# Call get_negatives eagerly (no jit) on 10 standard-normal samples.
x = jnp.array(np.random.randn(10))
get_negatives(x)

In [None]:
# Boolean-mask indexing produces an output whose shape depends on the values in
# x, which jit cannot trace, so this call is expected to raise an IndexError
# (jax.errors.NonConcreteBooleanIndexError in recent JAX versions).
# The error is the demonstration: catch and print it so the notebook keeps running.
try:
    negatives_jit = jit(get_negatives)(x)
except IndexError as err:
    print(f"{type(err).__name__}: {err}")
else:
    print(negatives_jit)

In [None]:
@jit
def f(x,y):
    """Return ``jnp.dot(x, y)``, printing the arguments and the result along the way.

    The function is wrapped in jit, so the print calls show whatever objects
    the function body receives when it is executed by jit.
    """
    print('run f')
    print(f'x = {x}')
    print(f"y = {y}")
    product = jnp.dot(x, y)
    print(f"result = {product}")
    return product

# Call the jitted f on random NumPy inputs of shape (3, 4) and (4,).
x = np.random.rand(3,4)
y = np.random.rand(4)
f(x,y)

## Jax grads

JAX grad recipe:
- Get a python function that does your computation
- Transform it with `grad()` -> get a gradient function
- Evaluate that grad function to get a gradient w.r.t. the first param

In [None]:
## Torch example: backward() stores gradients on the tensors created with requires_grad=True
import torch

w = torch.tensor(13., requires_grad=True)
x = torch.tensor(42., requires_grad=True)
y = w * x  # y = 42 * w, so dy/dw = 42
y.backward()
grad_w = w.grad
grad_w

In [None]:
## JAX example

def f(w, x):
    """Return the product ``w * x``."""
    return w * x


# Differentiate with respect to the first argument only (argnums=0), so that
# grad_w is the single value df/dw = x, mirroring `w.grad` in the Torch example.
# A tuple such as argnums=(0, 1) would instead return the pair (df/dw, df/dx).
dfdw = jax.grad(f, argnums=0)
w = jnp.array(13.)
x = jnp.array(42.)
grad_w = dfdw(w, x)
grad_w

In [None]:
import torch
import torch.nn as nn
class LSTMCell(torch.nn.Module):
    """One LSTM step with hand-written gate arithmetic.

    The parameters for the four gates are stacked along the first dimension of
    ``weight_ih``, ``weight_hh`` and ``bias`` and are split into four equal
    chunks in the order i, f, g, o.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        gate_rows = 4 * out_dim
        # Weights are sampled with torch.rand (uniform on [0, 1)); the bias starts at zero.
        self.weight_ih = torch.nn.Parameter(torch.rand(gate_rows, in_dim))
        self.weight_hh = torch.nn.Parameter(torch.rand(gate_rows, out_dim))
        self.bias = torch.nn.Parameter(torch.zeros(gate_rows))

    def forward(self, inputs, h, c):
        """Advance the cell by one step and return ``(new_h, new_c)``."""
        pre_activations = (
            torch.matmul(self.weight_ih, inputs)
            + torch.matmul(self.weight_hh, h)
            + self.bias
        )
        in_gate, forget_gate, candidate, out_gate = pre_activations.chunk(4)
        in_gate = torch.sigmoid(in_gate)
        forget_gate = torch.sigmoid(forget_gate)
        candidate = torch.tanh(candidate)
        out_gate = torch.sigmoid(out_gate)
        # Keep part of the old cell state and add the gated candidate.
        new_c = forget_gate * c + in_gate * candidate
        new_h = out_gate * torch.tanh(new_c)
        return (new_h, new_c)

In [None]:
class LSTMLM(torch.nn.Module):
    """Small LSTM language model that uses one matrix for both input and output embeddings."""

    def __init__(self, vocab_size, dim=17):
        super().__init__()
        self.cell = LSTMCell(dim, dim)
        self.embeddings = torch.nn.Parameter(torch.rand(vocab_size, dim))
        self.c_0 = torch.nn.Parameter(torch.zeros(dim))

    @property
    def hc_0(self):
        """Initial ``(h, c)`` state, with ``h`` computed as ``tanh(c_0)``."""
        initial_c = self.c_0
        return (torch.tanh(initial_c), initial_c)

    def forward(self, seq, hc):
        """Return the summed negative log-likelihood of ``seq`` and the final ``(h, c)`` state."""
        total_nll = torch.tensor(0.)
        for idx in seq:
            # Score every vocabulary entry against the current hidden state.
            log_probs = torch.log_softmax(torch.matmul(self.embeddings, hc[0]), dim=-1)
            total_nll -= log_probs[idx]
            # Feed the observed token's embedding into the cell.
            hc = self.cell(self.embeddings[idx, :], *hc)
        return total_nll, hc

    def greedy_argmax(self, hc, length=6):
        """Generate ``length`` token indices, always picking the highest-scoring one."""
        idxs = []
        with torch.no_grad():
            for _ in range(length):
                idx = torch.matmul(self.embeddings, hc[0]).argmax()
                idxs.append(idx.item())
                hc = self.cell(self.embeddings[idx, :], *hc)
        return idxs

In [None]:
import jax.numpy as jnp
vocab_size = 43  # prime trick! :)
# The model is a PyTorch module, so keep the token indices as a torch tensor.
# This keeps the whole loop inside one framework instead of relying on
# jax.numpy scalars happening to work as indices into torch tensors.
training_data = torch.tensor([4, 8, 15, 16, 23, 42])

lm = LSTMLM(vocab_size=vocab_size)
print("Sample before:", lm.greedy_argmax(lm.hc_0))

bptt_length = 3  # to illustrate hc.detach-ing
learning_rate = 0.1

for epoch in range(101):
    hc = lm.hc_0
    totalloss = 0.
    # Walk over the sequence in chunks of bptt_length tokens.
    for start in range(0, len(training_data), bptt_length):
        batch = training_data[start:start+bptt_length]
        loss, (h, c) = lm(batch, hc)
        # Detach the carried state so the next chunk's backward pass stops here.
        hc = (h.detach(), c.detach())
        if epoch % 50 == 0:
            totalloss += loss.item()
        loss.backward()
        # Plain SGD step; the gradient is cleared afterwards so it does not
        # accumulate into the next backward pass.
        with torch.no_grad():
            for param in lm.parameters():
                if param.grad is not None:
                    param -= learning_rate * param.grad
                    param.grad = None
    if totalloss:
        print("Loss:", totalloss)

print("Sample after:", lm.greedy_argmax(lm.hc_0))
# Example output from one run (no seed is set, so the numbers vary between runs):
# Sample before: [42, 34, 34, 34, 34, 34]
# Loss: 25.953862190246582
# Loss: 3.7642268538475037
# Loss: 1.9537211656570435
# Sample after: [4, 8, 15, 16, 23, 42]