In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import logsumexp

In [None]:
class Linear():
  """Dense affine layer computing ``jnp.dot(x, w) + b`` on JAX arrays.

  Args:
    m: number of input features (rows of ``w``).
    n: number of output features (columns of ``w``).
    key: a single JAX PRNG key; it is split here into a weight key and a bias key.
    scale: multiplier applied to the standard-normal samples.
    bias: when False, no bias vector is created and the layer returns ``jnp.dot(x, w)``.
  """
  def __init__(self,m,n,key,scale=1e-2,bias=True):
    w_key, b_key = random.split(key)
    self.w = scale * random.normal(w_key, (m,n))
    # Always bind self.b so that __call__ never trips over a missing attribute
    # when the layer was built with bias=False.
    self.b = scale * random.normal(b_key, (n,)) if bias else None

  def __call__(self,x):
    """Apply the layer to ``x`` whose last axis has length ``m``."""
    out = jnp.dot(x,self.w)
    return out if self.b is None else out + self.b


class ReLU():
  """Element-wise rectifier: negative entries become zero, the rest pass through."""
  def __call__(self,x):
    rectified = jnp.maximum(0, x)
    return rectified
  

class LogSoftMax():
  """Log-softmax over the last axis of the input.

  For a 1-D vector this is ``x - logsumexp(x)``. For batched input each row is
  normalised on its own, so ``exp(output)`` sums to one along the last axis.
  """
  def __call__(self,x):
    # Reduce over the class axis only and keep it for broadcasting; a bare
    # logsumexp(x) would collapse a whole batch into one normaliser.
    return x - logsumexp(x, axis=-1, keepdims=True)

In [None]:

# Root PRNG key for the notebook; PRNGKey(1) is the key whose raw value is [0, 1].
seed = random.PRNGKey(1)
seed

DeviceArray([0, 1], dtype=uint32)

In [None]:

# Linear splits the key itself, so hand it one PRNG key rather than the
# stacked pair of keys that random.split returns.
lin = Linear(784,512,random.PRNGKey(1))

ValueError: ignored

In [None]:
# Derive independent keys for the layer weights and the random input batch
# instead of calling random.split() without a key or reusing one key twice.
layer_key, input_key = random.split(seed)
lin = Linear(784,512,layer_key)
random_in = random.normal(input_key,(1000,784))

relu = ReLU()

relu(lin(random_in)).shape

ValueError: ignored

In [None]:

%timeit -n 100 lin(random_in).shape

100 loops, best of 5: 911 µs per loop


In [None]:
import torch
# Import the PyTorch layer under an alias so it does not rebind `Linear`,
# the name this notebook uses for its own JAX layer (whose third argument is
# a PRNG key).
from torch.nn import Linear as TorchLinear

x = torch.randn(1000,784)
lin = TorchLinear(784,512)

In [None]:
# Report whether PyTorch can see a CUDA device in this runtime.
torch.cuda.is_available()

True

In [None]:

# Time the forward pass of `lin` on `x`, 100 loops per repeat.
%timeit -n 100 lin(x).shape

100 loops, best of 5: 10.3 ms per loop


In [None]:
class Sequential():
  """Chain of callables applied one after another.

  ``layers`` is stored as given; calling the instance threads the input
  through every element in order and returns the last result (or the input
  itself when ``layers`` is empty).
  """
  def __init__(self,layers):
    self.layers = layers

  def __call__(self,x):
    out = x
    for apply_layer in self.layers:
      out = apply_layer(out)
    return out

# Give every Linear layer its own PRNG key: with one shared key, layers of the
# same shape (the two 512x512 ones) would start with identical weights.
# `Linear` must resolve to this notebook's JAX layer, whose third argument is
# a PRNG key.
layer_keys = random.split(seed, 4)
model = Sequential([
                    Linear(784,512,layer_keys[0]),
                    ReLU(),
                    Linear(512,512,layer_keys[1]),
                    ReLU(),
                    Linear(512,512,layer_keys[2]),
                    ReLU(),
                    Linear(512,10,layer_keys[3])])

ValueError: ignored

In [None]:
%timeit -n 100 model(random_in)

100 loops, best of 5: 5.39 ms per loop


In [None]:
# Map the model over the leading (batch) axis of its single argument.
batched_model = vmap(model, in_axes=0)

In [None]:
# Output shape of the vmapped model on the full input batch.
batched_model(random_in).shape

(1000, 10)

In [None]:
%timeit -n 100 batched_model(random_in)

100 loops, best of 5: 9.46 ms per loop


In [None]:
layer_sizes = [784, 512, 512, 10]
param_scale = 0.1
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
seed = random.PRNGKey(1)


def init_network_params(sizes, key, scale=1.0):
  """Create ``(weight, bias)`` pairs for a fully connected network.

  Args:
    sizes: layer widths, e.g. ``[784, 512, 512, 10]``; consecutive entries
      define one layer each.
    key: JAX PRNG key; it is split so every layer, and within a layer the
      weight and the bias, draws from its own key.
    scale: multiplier applied to the standard-normal samples.

  Returns:
    A list with ``len(sizes) - 1`` tuples ``(w, b)`` where ``w`` has shape
    ``(n_out, n_in)`` and ``b`` has shape ``(n_out,)``, which is the layout
    ``jnp.dot(w, activations) + b`` expects.
  """
  layer_keys = random.split(key, len(sizes) - 1)
  layers = []
  for n_in, n_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
    w_key, b_key = random.split(layer_key)
    layers.append((scale * random.normal(w_key, (n_out, n_in)),
                   scale * random.normal(b_key, (n_out,))))
  return layers


# NOTE(review): param_scale is declared above but not passed here; the default
# unit scale matches the weight statistics recorded below (std close to 1).
# Pass scale=param_scale if the smaller initialisation is what is wanted.
params = init_network_params(layer_sizes, seed)

NameError: ignored

In [None]:
# Mean and standard deviation of the first layer's weight matrix.
params[0][0].mean(), params[0][0].std()

(DeviceArray(-6.467347e-05, dtype=float32),
 DeviceArray(1.0007097, dtype=float32))

In [None]:
# Print mean, standard deviation and shape of every layer's weight matrix;
# the bias is unpacked but not reported.
for weight,bias in params:
  print('**********')
  print(weight.mean(),weight.std(),weight.shape)

**********
-6.4673506e-05 1.0007097 (512, 784)
**********
-0.00078998506 1.0012791 (512, 512)
**********
0.021645868 0.9996603 (10, 512)


In [None]:
from jax.scipy.special import logsumexp

def relu(x):
  """Return ``x`` with every negative entry replaced by zero."""
  rectified = jnp.maximum(0, x)
  return rectified

def predict(params, image):
  """Forward pass for a single example, returning log-probabilities.

  ``params`` is a sequence of ``(w, b)`` pairs. Every pair except the last is
  applied as ``relu(jnp.dot(w, h) + b)``; the last one produces the logits,
  from which ``logsumexp(logits)`` is subtracted.
  """
  hidden_layers, (out_w, out_b) = params[:-1], params[-1]

  h = image
  for w, b in hidden_layers:
    h = relu(jnp.dot(w, h) + b)

  logits = jnp.dot(out_w, h) + out_b
  log_norm = logsumexp(logits)
  return logits - log_norm

In [None]:
# `predict` works on a single flattened image: a vector of 28 * 28 values in,
# one vector of log-probabilities out.
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

NameError: ignored

In [None]:
# Doesn't work with a batch: `predict` computes jnp.dot(w, activations), which
# lines the weight's last axis up against the input's batch axis, so a
# (10, 784) input is rejected with a TypeError.
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

Invalid shapes!


In [None]:
# Vectorise `predict` over its second argument only (in_axes=(None, 0)).
batched_predict = vmap(predict, in_axes=(None, 0))


In [None]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function: `None` leaves `params`
# unmapped (shared by every example), `0` maps over the leading axis of the images.
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

(10, 10)
