In [None]:
#export
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.special import logsumexp


# **Jax Wrapper**

In [None]:
#export
seed = random.PRNGKey(0)

In [None]:

scale = 1
random_in = random.normal(seed,(100,784))
w_key, b_key = random.split(seed)
w = scale * random.normal(w_key, (784,512))
b = scale * random.normal(b_key, (512,))

In [None]:

w.shape,random_in.shape

((784, 512), (100, 784))

In [None]:
jnp.dot(random_in,w).shape

(100, 512)

In [None]:
#export
class Linear():
  """Fully connected layer computing ``x @ w + b``.

  Args:
    m: number of input features.
    n: number of output features.
    seed: JAX PRNG key used to initialise the parameters. When ``None`` a
      fixed ``random.PRNGKey(0)`` is used (``random.split`` needs a real key).
    scale: multiplier applied to the standard-normal initial weights/bias.
    bias: whether to add a bias term. When ``False``, ``self.b`` is ``None``.
  """
  def __init__(self,m,n,seed=None,scale=1e-2,bias=True):
    if seed is None:
      seed = random.PRNGKey(0)
    w_key, b_key = random.split(seed)
    self.w = scale * random.normal(w_key, (m,n))
    # Always define self.b so forward() also works when bias=False.
    self.b = scale * random.normal(b_key, (n,)) if bias else None

  def params(self):
    """Return the layer parameters as a ``(w, b)`` tuple (``b`` may be None)."""
    return (self.w, self.b)

  def __call__(self,x):
    # Calling the layer applies it, so it composes with ReLU/LogSoftMax in Sequential.
    return self.forward(x)

  def forward(self,x):
    """Apply the layer to ``x``; the last axis of ``x`` must have size ``m``."""
    out = jnp.dot(x,self.w)
    return out if self.b is None else out + self.b

class ReLU():
  """Element-wise rectified linear unit: max(x, 0)."""
  def __call__(self,x):
    return jnp.maximum(x, 0)
  

class LogSoftMax():
  """Log-softmax over the last axis (the class axis)."""
  def __call__(self,x):
    # Normalise each row independently: logsumexp over *all* axes would mix
    # different examples together when x is a (batch, classes) array.
    return x - logsumexp(x, axis=-1, keepdims=True)

In [None]:

def sigmoid(x):
    """Logistic sigmoid via the identity sigmoid(x) = (tanh(x / 2) + 1) / 2."""
    half_x = x / 2
    return (jnp.tanh(half_x) + 1) * 0.5

def predict(W, b, inputs):
    """Return the probability that each row of ``inputs`` has a true label."""
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

# Build a toy dataset: 4 examples with 3 features each, plus boolean labels.
inputs = jnp.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])
targets = jnp.array([True, True, False, True])


In [None]:


def loss(W, b):
    """Negative log-likelihood of the global toy ``targets`` under ``predict``."""
    p = predict(W, b, inputs)
    # Probability assigned to the observed label: p if True, 1 - p if False.
    label_probs = p * targets + (1 - p) * (1 - targets)
    log_likelihood = jnp.sum(jnp.log(label_probs))
    return -log_likelihood

# Initialize random model coefficients
key, W_key, b_key = random.split(seed, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, (1,))

In [None]:

W_grad = grad(loss)(W, b)
print('W_grad', W_grad)


W_grad [-0.16965583 -0.8774647  -1.4901344 ]


In [None]:

lin = Linear(784,512,random.PRNGKey(0))

In [None]:
lin = Linear(784,512,seed)
random_in = random.normal(seed,(1000,784))
random_in.shape

relu = ReLU()

relu(lin.forward(random_in)).shape

(1000, 512)

In [None]:
random_in.shape

(1000, 784)

In [None]:

%timeit -n 100 lin.forward(random_in).shape

100 loops, best of 5: 774 µs per loop


In [None]:
# import torch 
# from torch.nn import Linear

# torch_input = torch.randn(1000,784)
# torch_lin = Linear(784,512)
# torch.cuda.is_available()
# %timeit -n 100 torch_lin(torch_input).shape

In [None]:
Linear(784,512,seed=random.PRNGKey(0))

<__main__.Linear at 0x7f57607a50d0>

In [None]:
class Sequential():
  """Chain callables: each layer's output becomes the next layer's input."""
  def __init__(self,layers):
    self.layers = layers

  def __call__(self,x):
    out = x
    for layer in self.layers:
      out = layer(out)
    return out

# Give every Linear layer its own PRNG key: reusing `seed` for all of them
# made the two 512x512 hidden layers start from identical weights.
layer_keys = random.split(seed, 4)
model = Sequential([
                    Linear(784,512,layer_keys[0]),
                    ReLU(),
                    Linear(512,512,layer_keys[1]),
                    ReLU(),
                    Linear(512,512,layer_keys[2]),
                    ReLU(),
                    Linear(512,10,layer_keys[3]),
                    LogSoftMax()
                    ])

In [None]:
# model(random_in).shape

TypeError: ignored

In [None]:
%timeit -n 100 model(random_in)

100 loops, best of 5: 5.52 ms per loop


In [None]:
vmap??

In [None]:
batched_model = vmap(model, in_axes=(0))

In [None]:
batched_model(random_in).shape

TypeError: ignored

In [None]:
%timeit -n 100 batched_model(random_in)

100 loops, best of 5: 8.99 ms per loop


In [None]:
def cross_entropy_loss(preds,targets):
    """Return the mean of ``-preds * targets`` over every element.

    With log-probability ``preds`` and one-hot ``targets`` this averages over
    both the batch and class axes, i.e. the usual per-example cross-entropy
    divided by the number of classes.
    """
    elementwise = preds * targets
    return -elementwise.mean()


# JAX MNIST

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
  """Return randomly initialised ``(w, b)`` for a dense layer with m inputs and n outputs.

  ``w`` has shape (n, m), so it is applied as ``w @ x``. ``b`` has shape (n,).
  Both are standard-normal samples multiplied by ``scale``.
  """
  weight_key, bias_key = random.split(key)
  weights = scale * random.normal(weight_key, (n, m))
  biases = scale * random.normal(bias_key, (n,))
  return weights, biases

def init_network_params(sizes, key,scale):
  """Initialise every dense layer of a fully connected net with layer widths ``sizes``.

  Returns two parallel lists: ``W[i]`` with shape (sizes[i+1], sizes[i]) and
  ``b[i]`` with shape (sizes[i+1],).
  """
  keys = random.split(key, len(sizes))
  layer_specs = zip(sizes[:-1], sizes[1:], keys)
  pairs = [random_layer_params(m, n, k, scale) for m, n, k in layer_specs]
  W = [w for w, _ in pairs]
  b = [bias for _, bias in pairs]
  return W,b

layer_sizes = [784, 32, 10]
# Std of the initial weights/biases. The previous value 0.1 was never used;
# the network was always initialised with the literal 1e-2.
param_scale = 1e-2
step_size = 0.01
num_epochs = 10
batch_size = 4
n_targets = 10
W,b = init_network_params(layer_sizes, random.PRNGKey(0),param_scale)

In [None]:
W[0].mean(),W[0].std(),W[0].shape

(DeviceArray(-2.9898334e-05, dtype=float32),
 DeviceArray(0.00994793, dtype=float32),
 (32, 784))

In [None]:
for weight in W:
  print('**********')
  print(weight.mean(),weight.std(),weight.shape)

**********
-2.9898336e-05 0.009947931 (32, 784)
**********
-0.000115454604 0.00996856 (10, 32)


In [None]:
from jax.scipy.special import logsumexp

def relu(x):
  """Element-wise max(0, x)."""
  return jnp.maximum(0, x)

def predict(W,b, image):
  """Return log-probabilities for a single flattened image.

  Args:
    W: list of weight matrices; ``W[i]`` is applied as ``W[i] @ activations``.
    b: list of bias vectors, one per entry of ``W``.
    image: one flattened example; use ``vmap`` for batches.

  Returns:
    Log-softmax of the final layer's logits.
  """
  activations = image
  # Hidden layers. Use a distinct loop name for the bias: rebinding `b` here
  # replaced the bias list with the last hidden bias, so the final layer was
  # paired with a scalar element of that vector instead of b[-1].
  for w, bias in zip(W[:-1], b[:-1]):
    outputs = jnp.dot(w, activations) + bias
    activations = relu(outputs)

  final_w, final_b = W[-1], b[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)

In [None]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(W,b, random_flattened_image)
print(preds.shape)

(10,)


In [None]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
  preds = predict(W,b, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

Invalid shapes!


In [None]:
# `predict` takes (W, b, image): share W and b across the batch and map over image rows.
batched_predict = vmap(predict, in_axes=(None, None, 0))


In [None]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None,None,0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(W,b, random_flattened_images)
print(batched_preds.shape)

(10, 10)


In [None]:

def get_stats(x):
  """Return ``(mean, std, shape)`` of an array."""
  summary = (x.mean(), x.std(), x.shape)
  return summary

In [None]:
from scipy.spatial import distance

In [None]:
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  classes = jnp.arange(k)
  matches = x[:, None] == classes
  return jnp.array(matches, dtype)
  
def accuracy(W,b, images, targets):
  """Return the fraction of ``images`` whose arg-max prediction matches the one-hot ``targets``."""
  log_probs = batched_predict(W,b, images)
  hits = jnp.argmax(log_probs, axis=1) == jnp.argmax(targets, axis=1)
  return jnp.mean(hits)

def loss(W,b, images, targets):
  """Return the mean of -(log-probabilities * one-hot targets) over the whole batch array."""
  log_probs = batched_predict(W,b, images)
  return -(log_probs * targets).mean()

@jit
def update(W,b, x, y):
  """Perform one plain SGD step on ``loss``.

  Returns ``(new_W, new_b, DW, DB)``: the updated parameter lists plus the raw
  gradients, which are returned so callers can log gradient statistics.
  Note: ``step_size`` is a global captured when ``jit`` traces this function;
  changing it afterwards has no effect until the function is re-traced.
  """
  DW,DB = grad(loss,(0, 1))(W,b, x, y)
  updated_W = [w - dw * step_size for w, dw in zip(W, DW)]
  updated_b = [bias - db * step_size for bias, db in zip(b, DB)]
  return updated_W,updated_b,DW,DB

    



In [None]:
x,y = get_one_batch()
x = jnp.reshape(x, (len(x), num_pixels))
y = one_hot(y, num_labels)
x.shape, y.shape

NameError: ignored

In [None]:
len(W), len(b)

In [None]:
dw,db = grad(loss,(0, 1))(W,b, x, y)

In [None]:
len(dw[0]),len(db[1])

In [None]:
W[0]-dw[0]

NameError: ignored

In [None]:
import tensorflow_datasets as tfds

# Directory passed to tfds.load as data_dir (download/cache location).
data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
# NOTE: this rebinds the module-level `w` (the weight matrix from the first cells).
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set: images flattened to (N, num_pixels), labels one-hot encoded.
# NOTE(review): pixels are used in their stored dtype/range (presumably uint8 in
# [0, 255] — TODO confirm); no /255 scaling is applied here or in get_train_batches.
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set (same preprocessing as the train set)
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

In [None]:

list(train_data.keys())

['image', 'label']

In [None]:
# train_data: dict: ['image', 'label']
train_data[list(train_data.keys())[0]].shape,train_data[list(train_data.keys())[1]].shape

((60000, 28, 28, 1), (60000,))

In [None]:
import time

def get_train_batches():
  """Return an iterable of NumPy ``(images, labels)`` MNIST training batches of size ``batch_size``."""
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  train_ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  batched = train_ds.batch(batch_size).prefetch(1)
  return tfds.as_numpy(batched)


def get_one_batch():
  """Return the first ``(images, labels)`` training batch as NumPy arrays.

  Only the first ``batch_size`` examples are materialised. The previous version
  built a Python list of every batch in the split just to keep index 0.
  """
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  first = ds.batch(batch_size).take(1)
  return next(iter(tfds.as_numpy(first)))

In [None]:
stats_arr = []

In [None]:
import numpy 

In [None]:
def print_param_stats(_W,_b,W,b,epoch,idx,DW,DB=None):
  """Append per-layer weight/gradient statistics for one step to ``stats_arr``.

  Despite the name, nothing is printed. One dict per layer is appended to the
  global ``stats_arr`` list.

  Args:
    _W, _b: parameters after the update.
    W, b: parameters before the update.
    epoch, idx: epoch number and batch index, stored for later grouping.
    DW: per-layer weight gradients.
    DB: optional per-layer bias gradients. Accepted so callers can pass the full
      output of ``update``; recorded as ``grad_b_stats`` when given.
  """
  for layer_idx, (w1, w, dw) in enumerate(zip(_W, W, DW)):
    record = {
        "epoch":epoch,
        "idx":idx,
        "layer":layer_idx,
        "new_w_stats":get_stats(w1),
        "grad_w_stats":get_stats(dw),
        "old_w_stats":get_stats(w),
        # mean over rows of the L2 norm of each row's weight change
        "dist":numpy.linalg.norm(w1-w,axis=1).mean()
    }
    if DB is not None:
      record["grad_b_stats"] = get_stats(DB[layer_idx])
    stats_arr.append(record)




# std: 1 lr = 0.01

In [None]:
# std: 1 lr = 0.01
for epoch in range(num_epochs):
  start_time = time.time()
  for idx,(x, y) in enumerate(get_train_batches()):
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    # `update` takes and returns W and b separately (there is no `params`).
    _W, _b, DW, DB = update(W, b, x, y)
    print_param_stats(_W, _b, W, b, epoch, idx, DW)
    W, b = _W, _b

  epoch_time = time.time() - start_time

  train_acc = accuracy(W, b, train_images, train_labels)
  test_acc = accuracy(W, b, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

Epoch 0 in 19.74 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000350475311
Epoch 1 in 19.76 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000350475311
Epoch 2 in 21.65 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000350475311
Epoch 3 in 21.18 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000350475311
Epoch 4 in 20.98 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000350475311
Epoch 5 in 19.97 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000350475311
Epoch 6 in 20.37 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000350475311
Epoch 7 in 22.00 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000350475311
Epoch 8 in 20.98 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000350475311
Epoch 9 in 20.77 sec
Training set accuracy 0.09931667149066925
Test set accuracy 0.10320000

In [None]:
import pandas as pd 
df1 = pd.DataFrame(stats_arr)

df1['grad_mean'] = df1['grad_w_stats'].apply(lambda x: x[0])
df1['grad_pct_old'] = df1.apply(lambda x: x['grad_mean']*100/x['old_w_stats'][0],axis=1)

df1['grad_mean'] = df1['grad_mean'].apply(float)
df1['grad_pct_old'] = df1['grad_pct_old'].apply(float)

epoch_layers_grads1 = df1[['epoch','layer','grad_mean','grad_pct_old']].groupby(['epoch','layer']).mean().reset_index()


# std 0.01

In [None]:
W[0].mean()

NameError: ignored

In [None]:
# std: 0.01 lr = 0.01
W,b = init_network_params(layer_sizes, random.PRNGKey(0),1e-2)
# Start a fresh log so the `df` built below only holds this run
# (df1 was already built from the previous run's stats).
stats_arr.clear()
for epoch in range(3):
  start_time = time.time()
  for idx,(x, y) in enumerate(get_train_batches()):
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    _W,_b,DW,DB = update(W,b, x, y)
    # print_param_stats takes only the weight gradients DW.
    print_param_stats(_W,_b,W,b,epoch,idx,DW)
    W = _W
    b = _b

  epoch_time = time.time() - start_time

  train_acc = accuracy(W,b, train_images, train_labels)
  test_acc = accuracy(W,b, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

In [None]:
import pandas as pd 
df = pd.DataFrame(stats_arr)

In [None]:
df.shape

In [None]:
df['grad_mean'] = df['grad_w_stats'].apply(lambda x: x[0])
df['grad_pct_old'] = df.apply(lambda x: x['grad_mean']*100/x['old_w_stats'][0],axis=1)

In [None]:
for i in range(3):
  print(i)
  epoch_idx = 8
  print(f"Grad Mean: {df.loc[(df.epoch == epoch_idx) & (df.layer == i)].grad_mean.mean()}")
  print(f"Grad Pct Mean: {df.loc[(df.epoch == epoch_idx) & (df.layer == i)].grad_pct_old.mean()}")

In [None]:
df['grad_mean'] = df['grad_mean'].apply(float)
df['grad_pct_old'] = df['grad_pct_old'].apply(float)

In [None]:
epoch_layers_grads = df[['epoch','layer','grad_mean','grad_pct_old']].groupby(['epoch','layer']).mean().reset_index()

In [None]:
epoch_layers_grads.loc[epoch_layers_grads.layer == 0]

In [None]:
epoch_layers_grads.loc[epoch_layers_grads.layer == 1]

In [None]:
#default
#do not execute
epoch_layers_grads

In [None]:
epoch_layers_grads1

# Second-Order Updates


In [None]:
@jit
def update(W,b, x, y):
  """Perform one plain SGD step on ``loss`` (same behaviour as the earlier definition).

  Returns ``(new_W, new_b, DW, DB)``: the updated parameter lists plus the raw
  gradients. ``step_size`` is a global captured when ``jit`` traces this function.
  """
  DW,DB = grad(loss,(0, 1))(W,b, x, y)
  updated_W = [w - dw * step_size for w, dw in zip(W, DW)]
  updated_b = [bias - db * step_size for bias, db in zip(b, DB)]
  return updated_W,updated_b,DW,DB

    

In [None]:

@jit
def update_2(W,b, x, y):
  """Attempted Newton-style update of the weight matrices only; biases are not updated.

  Intended per-layer step: ``w - dw / ddw``, where ``dw`` is the gradient and
  ``ddw`` is the second derivative of ``loss`` with respect to the weights.

  Returns ``(updated_W, grad_W)``: the new weight list and the applied steps.

  NOTE(review): as written this likely cannot run on the MNIST network:
    * ``hessian(fw)(W)`` with ``W`` a list returns nested Hessian blocks
      (block [i][j] has shape W[i].shape + W[j].shape). ``ddw`` in the loop is
      therefore a list, not an array shaped like ``dw``, and ``dw/ddw`` fails.
      The (0, 0) block for a (32, 784) layer alone has ~6.3e8 entries.
    * ``hessian`` is defined in a later cell and must exist before this
      function is first called.
  TODO: use a diagonal or Hessian-vector-product approximation instead.
  """
  DW = grad(loss,(0))(W,b, x, y)  # `(0)` is just the int 0: gradient w.r.t. W only
  fw = lambda W: loss(W, b, x,y)  # loss as a function of the weights alone
  # fb = lambda b: loss(W, b, inputs)
  grads_w_2 = hessian(fw)(W)
  # grads_b_2 = hessian(fb)(b)
  updated_W = []
  # updated_b = []

  grad_W = []
  # grad_B = []
  grads_params = []  # unused
  layer = 0  # unused
  # for (w, _b, dw, db, ddw,ddb) in zip(W,b, DW,DB,grads_w_2,grads_b_2):
  for (w, dw, ddw) in zip(W, DW,grads_w_2):

    # Newton step: gradient divided by second derivative.
    grad_w = dw/ddw
    # grad_b = db/ddb

    new_w = w - grad_w
    # new_b = _b - grad_b

    updated_W.append(new_w)
    # updated_b.append(new_b)
    grad_W.append(grad_w)
    # grad_B.append(grad_b)
  return updated_W,grad_W

    


In [None]:
x,y = get_one_batch()
x = jnp.reshape(x, (len(x), num_pixels))
y = one_hot(y, num_labels)

In [None]:
grads = grad(loss)(params, x, y)

In [None]:
type(grads[0])

tuple

In [None]:
def f(x): return x**2
theta = 5.

In [None]:
grad_1 = grad(f)(theta)
grad_1

DeviceArray(10., dtype=float32)

In [None]:
grad_2 = grad(grad(f))(theta)
grad_2

DeviceArray(2., dtype=float32)

In [None]:
theta - (grad_1/grad_2)

DeviceArray(0., dtype=float32)

In [None]:
from jax import jacfwd, jacrev

In [None]:
def hessian(f):
    """Return a function computing the Hessian of ``f`` (forward-over-reverse)."""
    grad_fn = jacrev(f)
    return jacfwd(grad_fn)


In [None]:
hessian(f)(theta)

DeviceArray(2., dtype=float32)

In [None]:
# Newton-style training loop: update_2 returns new weights only; biases stay fixed.
W, b = init_network_params(layer_sizes, random.PRNGKey(0),1e-2)
for epoch in range(num_epochs):
  start_time = time.time()
  for idx,(x, y) in enumerate(get_train_batches()):
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    updated_W,grad_W = update_2(W,b, x, y)
    # Biases are unchanged by update_2, so `b` serves as both old and new biases.
    print_param_stats(updated_W,b,W,b,epoch,idx,grad_W)
    W = updated_W

  epoch_time = time.time() - start_time

  train_acc = accuracy(W, b, train_images, train_labels)
  test_acc = accuracy(W, b, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

TypeError: ignored

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)

In [None]:
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))

0.070650935


In [None]:
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))

-0.13621888
0.2526544


In [None]:
def f(x): return x**2

In [None]:
grad(f)(2.)

DeviceArray(4., dtype=float32)

In [None]:
def sigmoid(x):
    """Logistic sigmoid via the identity sigmoid(x) = (tanh(x / 2) + 1) / 2."""
    t = jnp.tanh(x / 2)
    return (t + 1) * 0.5

def predict(W, b, inputs):
    """Return the probability that each row of ``inputs`` has a true label."""
    scores = jnp.dot(inputs, W) + b
    return sigmoid(scores)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

def loss(W, b):
    """Negative log-likelihood of the global toy ``targets`` under ``predict``."""
    p = predict(W, b, inputs)
    # Probability assigned to the observed label: p if True, 1 - p if False.
    label_probs = p * targets + (1 - p) * (1 - targets)
    total_log_prob = jnp.sum(jnp.log(label_probs))
    return -total_log_prob

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

In [None]:
grad(loss, argnums=0)(W, b)

DeviceArray([-0.16965583, -0.8774647 , -1.4901344 ], dtype=float32)

In [None]:
grad(loss, argnums=0)(W, b)