In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
key = random.PRNGKey(0)

In [3]:
print(random.normal(key, (5,)))

[ 0.18784378 -1.2833427  -0.27109176  1.2490592   0.24446994]


In [4]:
size = 3000
x = random.normal(key, (size, size), jnp.float32)

In [5]:
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

347 ms ± 13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
%timeit jnp.dot(x, x.T)  # runs on the GPU

345 ms ± 5.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
from jax import grad, jit, vmap

In [8]:
def random_layer_params(m,n,key,scale=1e-2):
    wk,bk = random.split(key)
    return scale * random.normal(wk, (n,m)), \
            scale * random.normal(bk, (n,))
random_layer_params(3,4,key)

(DeviceArray([[ 0.01037525,  0.00798739, -0.00124815],
              [ 0.00078822,  0.00026833,  0.00265739],
              [-0.00680038,  0.0122535 , -0.003527  ],
              [-0.01284488,  0.00135566,  0.00207909]], dtype=float32),
 DeviceArray([ 0.01137878, -0.00143314, -0.00591536,  0.00794662], dtype=float32))

In [9]:
# Init all layers for a dense NN with `sizes`
# Units at each layer
def init_network_params(sizes, key):
    keys = random.split(key, len(sizes))
    return [
        random_layer_params(m,n,k)
        for m,n,k in zip(sizes[:-1],sizes[1:], keys)
    ]

params = init_network_params([2,3,4],key)

In [10]:
params

[(DeviceArray([[-0.00263652, -0.0033948 ],
               [-0.00245806,  0.00532352],
               [-0.00156567, -0.0001147 ]], dtype=float32),
  DeviceArray([ 0.00967246, -0.00562784,  0.00379132], dtype=float32)),
 (DeviceArray([[ 0.00208515, -0.01319962, -0.01186628],
               [ 0.00837928,  0.01667681, -0.01895897],
               [ 0.00806379, -0.00242121,  0.00733277],
               [ 0.00142338,  0.00047958,  0.0125117 ]], dtype=float32),
  DeviceArray([-0.01917566, -0.02784671, -0.00641505, -0.00611423], dtype=float32))]

In [11]:
from jax.scipy.special import logsumexp

def relu(x):
    return jnp.maximum(0,x)

def predict(params, image):
    activations = image
    for w,b in params[:-1]:
        print(w.shape, activations.shape,  b.shape,)
        outputs = jnp.dot(w, activations) + b
        print(outputs.shape)
        activations = relu(outputs)
    final_w, final_b = params[-1]
    print(final_w.shape, final_b.shape, activations.shape)
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

import numpy as np
# predict(params, np.random.normal(size=(28*28,)))

In [12]:
layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [13]:

random_flattened_image = random.normal(key,(28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

(512, 784) (784,) (512,)
(512,)
(512, 512) (512,) (512,)
(512,)
(10, 512) (10,) (512,)
(10,)


The error was due to a typo wherein the `activations` variable wasn't getting updated and invalid shape exception was raised.

In [14]:
batched_predict = vmap(fun=predict, in_axes=[None, 0])
bs = random.normal(key,[10,28 * 28])
batched_predict(params, bs).shape

(512, 784) (784,) (512,)
(512,)
(512, 512) (512,) (512,)
(512,)
(10, 512) (10,) (512,)


(10, 10)

In [15]:
def one_hot(x, k, dtype=jnp.float32):
    """Create onehot encoding of x of size k"""
    return jnp.array(x[:,None] == jnp.arange(k), dtype)

one_hot(np.arange(3),3)

DeviceArray([[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]], dtype=float32)

In [16]:
def accuracy(params, images, targets):
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(batch)

In [17]:
def mult_stoopid(a,b):
    r = 0
    for i in range(a):
        for j in range(b):
            r += 1
    return r

mult_stoopid(2,3)

6

In [18]:
from jax import grad

mul_g = grad(mult_stoopid,allow_int=True)

In [19]:
# mul_g(2,3)

In [20]:
import jax
def mapping(v):
    x,y,z=v
    return jnp.array([x*x, y*z])
f = jax.jacfwd(mapping)
v = jnp.array([4.,5.,6.,])
print(f(v))

[[8. 0. 0.]
 [0. 6. 5.]]



## VMAP auto-vectorization

In [21]:
lsa = [1,2,3]
lsb = [3,2,1]
result = []
for i in range(3):
    result.append(lsa[i] + lsb[i])
result

[4, 4, 4]

In [22]:
import itertools
import jax
import jax.numpy as np

In [23]:
import numpy.random as random

In [24]:
def sigmoid(x):
    return 1/(1 + np.exp(-x))
import matplotlib.pyplot as plt

In [25]:
# x = np.linspace(-3,3,20)
# # y = sigmoid(x)
# sigmoid_prime = jax.grad(sigmoid,)
# plt.plot(x,sigmoid(x))
# plt.plot(x, sigmoid_prime(x,))

In [26]:
def net(params, x):
    w1,b1,w2,b2 = params
    hidden = np.tanh(np.dot(w1, x) + b1)
    return sigmoid(np.dot(w2,hidden) + b2)

def loss(params, x, y):
    out = net(params, x)
    cross_entropy = - y * np.log(out) \
        - (1 - y) * np.log(1 - out)
    return cross_entropy
params = [
    random.normal(size=(3,2)),
    random.normal(size=(3,)),
    random.normal(size=(1,3)),
    random.normal(size=(1,)),
]
from IPython.display import display

input = np.array([0,1])
target = np.array(0.)
display(net(params, input))
loss(params, input, target)

DeviceArray([0.50114065], dtype=float32)

DeviceArray([0.6954311], dtype=float32)

In [27]:

# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [np.bitwise_xor(*inp) for inp in inputs])

In [28]:
def initial_params():
    return [
        random.randn(3, 2),  # w1
        random.randn(3),  # b1
        random.randn(3),  # w2
        random.randn(),  #b2
    ]
loss_grad = grad(loss, argnums=0)
learning_rate = 1
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Initialize parameters randomly
params = initial_params()

for n in itertools.count():
    # Grab a single random input
    x = inputs[random.choice(inputs.shape[0])]
    # Compute the target output
    y = np.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 1
[0 1] -> 0
[1 0] -> 1
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [29]:
%timeit loss_grad(params, x, y)
loss_grad_jit = jit(loss_grad)
loss_grad_jit(params, x, y)
%timeit loss_grad_jit(params, x, y)

9.86 ms ± 1.35 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
10.9 µs ± 72.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [33]:
import numpy as onp
params = initial_params()

for n in itertools.count():
    x = inputs[onp.random.choice(inputs.shape[0])]
    y = onp.bitwise_xor(*x)
    grads = loss_grad(params, x, y)
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [34]:
random.choice(inputs.shape[0])

1

JAX vmap takes two important args.

- `in_axes` is a tuple or int tells jax over which axes the funtions args should be parallelized. if list, same len as num args, if int, should only have one argument. (None, 0, 0) means ignore first arg, and parallelize 0 axis for last two. 
- `out_axes` is similar. Say 0, means output a batch of losses.

In [36]:
loss

<function __main__.loss(params, x, y)>

In [42]:
loss_grad = jax.jit(vmap(grad(loss), in_axes=(None,0,0), out_axes=0))
params = initial_params()
batch_size = 100


for n in itertools.count():
    # Generate a batch of inputs
    x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    # The call to loss_grad remains the same!
    grads = loss_grad(params, x, y)
    # Note that we now need to average gradients over the batch
    params = [param - learning_rate * np.mean(grad, axis=0)
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [57]:
a = np.array([3,-1,5,0]).reshape(2,2)
b = np.array([-2,3,6,1]).reshape(2,2)
print('a:',a)
print('b:',b)
print('jnp.add',jnp.add(a,b))
vadd = vmap(jnp.add, (0,0), 0)
print(vadd(a[None,...],b[None,...]).shape)
print(vadd(a,b).shape)

a: [[ 3 -1]
 [ 5  0]]
b: [[-2  3]
 [ 6  1]]
jnp.add [[ 1  2]
 [11  1]]
(1, 2, 2)
(2, 2)


In [38]:
inputs

DeviceArray([[0, 0],
             [0, 1],
             [1, 0],
             [1, 1]], dtype=int32)

In [58]:
from jax import lax

## How to think in JAX

In [65]:
x = jnp.arange(10)
# x[0] = 10
x = x.at[0].set(10)
x

DeviceArray([10,  1,  2,  3,  4,  5,  6,  7,  8,  9], dtype=int32)

In [73]:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)

DeviceArray([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)

In [74]:
from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1,1,3).astype(float),
    y.reshape(1,1,10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)]
)

In [78]:
import numpy as np
@jit
def f(x,y):
    print("Running f():")
    print(f"  x = {x}")
    print(f"  y = {y}")
    result = jnp.dot(x + 1, y + 1)
    print(f"  result = {result}")
    return result
x = np.random.randn(3,4)
y = np.random.randn(4)
f(x,y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>


DeviceArray([7.203108, 6.721834, 5.179496], dtype=float32)