In [1]:
import random
import itertools

import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp

In [2]:
def predict(params, inputs):
  """Run `inputs` through a stack of affine layers with tanh between them.

  Args:
    params: iterable of `(W, b)` pairs, one pair per layer.
    inputs: array fed to the first layer as `np.dot(inputs, W) + b`.

  Returns:
    The affine output of the last layer. The tanh of that value is computed
    but not returned.
  """
  activations = inputs
  for weights, bias in params:
    outputs = np.dot(activations, weights) + bias
    # The squashed value only serves as input to the next layer, if any.
    activations = np.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  """Sum of squared differences between `predict(params, inputs)` and `targets`.

  Despite the name, the value computed here is a sum of squared errors.
  """
  residuals = predict(params, inputs) - targets
  return np.sum(residuals**2)

# Gradient of `logprob_fun` with respect to its first argument (`params`),
# wrapped in `jax.jit`.
grad_fun = jax.jit(jax.grad(logprob_fun))  # compiled gradient evaluation function
# `in_axes=(None, 0, 0)`: `params` is shared across the batch, while `inputs`
# and `targets` are mapped over their leading axis, giving one gradient per example.
perex_grads = jax.jit(jax.vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads

In [3]:
# Sigmoid nonlinearity
def sigmoid(x):
    """Logistic function 1 / (1 + exp(-x)), evaluated with `np.exp`."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

# Computes our network's output
def net(params, x):
    """One tanh hidden layer followed by a sigmoid output.

    Args:
        params: sequence unpacked as `(w1, b1, w2, b2)`.
        x: input passed to the hidden layer as `np.dot(w1, x) + b1`.

    Returns:
        `sigmoid(np.dot(w2, hidden) + b2)`.
    """
    w1, b1, w2, b2 = params
    hidden_pre = np.dot(w1, x) + b1
    hidden = np.tanh(hidden_pre)
    logit = np.dot(w2, hidden) + b2
    return sigmoid(logit)

# Cross-entropy loss
def loss(params, x, y):
    """Binary cross-entropy between `net(params, x)` and the target `y`.

    Args:
        params: network parameters, passed straight through to `net`.
        x: network input, passed straight through to `net`.
        y: target value; the formula weights `log(out)` by `y` and
            `log(1 - out)` by `1 - y`.

    Returns:
        `-y * log(out) - (1 - y) * log(1 - out)`, where `out` is the network
        output kept strictly between 0 and 1.
    """
    out = net(params, x)
    # A saturated sigmoid returns exactly 0.0 or 1.0 in floating point. The
    # unclipped formula then evaluates 0 * log(0) = nan (for example y == 1
    # with out == 1.0), and that nan flows through the gradient into every
    # parameter on the next update. Clipping by one machine epsilon keeps both
    # logarithms finite and leaves non-saturated outputs untouched.
    eps = np.finfo(out.dtype).eps
    out = np.clip(out, eps, 1 - eps)
    cross_entropy = -y * np.log(out) - (1 - y) * np.log(1 - out)
    return cross_entropy

# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
    """Print the thresholded prediction for each input and compare with XOR.

    Args:
        inputs: iterable of inputs; each one is passed to `net` and unpacked
            into `onp.bitwise_xor` to obtain its target.
        params: parameters forwarded to `net`.

    Returns:
        The result of comparing the list of predictions (network output
        thresholded at 0.5) with the list of XOR targets.
    """
    predictions = []
    for sample in inputs:
        predictions.append(int(net(params, sample) > 0.5))
    for sample, bit in zip(inputs, predictions):
        print(sample, '->', bit)
    targets = [onp.bitwise_xor(*sample) for sample in inputs]
    return (predictions == targets)

In [4]:
def initial_params():
    """Draw a fresh set of network parameters with `onp.random.randn`.

    Returns:
        The list `[w1, b1, w2, b2]`: a (3, 2) array, two length-3 arrays and
        a scalar, drawn in that order.
    """
    w1 = onp.random.randn(3, 2)
    b1 = onp.random.randn(3)
    w2 = onp.random.randn(3)
    b2 = onp.random.randn()
    return [w1, b1, w2, b2]

In [7]:
# Gradient of the single-example loss with respect to `params`, jit-compiled
loss_grad = jax.jit(jax.grad(loss))

# Step size applied to every stochastic gradient descent update
learning_rate = 1.
# The four rows of the XOR truth table
inputs = np.array([
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1],
])

# Start from a random draw of the parameters
params = initial_params()

for n in itertools.count():
    # Pick one of the rows at random
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Its XOR target
    y = onp.bitwise_xor(*x)
    # Gradient of the loss for this (input, target) pair
    grads = loss_grad(params, x, y)
    # One plain gradient-descent step on every parameter
    params = [p - learning_rate * g for p, g in zip(params, grads)]
    # On iteration 0 and every 100th one after it, report progress and
    # stop as soon as all four inputs are classified correctly
    if n % 100 == 0:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [8]:
# Map the per-example gradient over the leading axis of x and y (params are
# shared across the batch), then jit-compile the batched function
loss_grad = jax.jit(
    jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0)
)

params = initial_params()

batch_size = 100

for n in itertools.count():
    # Sample a batch of truth-table rows and compute their XOR targets
    x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    # Same call signature as in the single-example version
    grads = loss_grad(params, x, y)
    # Per-example gradients come back stacked along axis 0, so average them
    # before taking the step
    params = [p - learning_rate * np.mean(g, axis=0) for p, g in zip(params, grads)]
    if n % 100 == 0:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break

Iteration 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [9]:
!pip install objax
import random

import numpy as np
import tensorflow as tf

import objax
from objax.zoo.wide_resnet import WideResNet

Collecting objax
  Downloading objax-1.3.1.tar.gz (45 kB)
     |████████████████████████████████| 45 kB 957 kB/s eta 0:00:01
Collecting parameterized
  Downloading parameterized-0.8.1-py2.py3-none-any.whl (26 kB)
Building wheels for collected packages: objax
  Building wheel for objax (setup.py) ... done
  Created wheel for objax: filename=objax-1.3.1-py3-none-any.whl size=72554 sha256=616d559eefc2a762748425af9dd99f9d8630930551652be7945472b8900c362c
  Stored in directory: /home/silviu/.cache/pip/wheels/8f/dc/bd/2b83372b783db0915ab63e24ac8947169ad70c514ab24ead41
Successfully built objax
Installing collected packages: parameterized, objax
Successfully installed objax-1.3.1 parameterized-0.8.1


# Data
# Load CIFAR-10 through Keras as (train, test) pairs of (images, labels).
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data()
# Reorder the axes as (0, 3, 1, 2), so the former last axis sits at position 1,
# then divide by 255.
# assumes load_data() yields channels-last 8-bit images, making the result
# channels-first values in [0, 1] — TODO confirm
X_train = X_train.transpose(0, 3, 1, 2) / 255.0
X_test = X_test.transpose(0, 3, 1, 2) / 255.0

# Model
# Wide ResNet from the objax model zoo, built with nin=3, nclass=10, depth=28, width=2.
model = WideResNet(nin=3, nclass=10, depth=28, width=2)
# Adam optimizer constructed over the model's variables.
opt = objax.optimizer.Adam(model.vars())

# Losses
@objax.Function.with_vars(model.vars())
def loss(x, label):
    """Mean sparse cross-entropy of the model's training-mode logits for `x` against `label`."""
    logits = model(x, training=True)
    per_example = objax.functional.loss.cross_entropy_logits_sparse(logits, label)
    return per_example.mean()

# Callable built from `loss` and the model's variables; `train_op` unpacks its
# result as a (gradients, values) pair.
gv = objax.GradValues(loss, model.vars())

@objax.Function.with_vars(model.vars() + opt.vars())
def train_op(x, y, lr):
    """Run one optimizer update on the batch `(x, y)` with learning rate `lr`.

    Returns:
        The value part of what `gv(x, y)` returns.
    """
    grads, values = gv(x, y)
    opt(lr=lr, grads=grads)
    return values


# Replace `train_op` with its objax.Jit-wrapped version.
train_op = objax.Jit(train_op)
# Inference callable: the model forced to `training=False`, followed by
# softmax, wrapped in objax.Jit.
# NOTE(review): this rebinds the name `predict`, which an earlier cell defines
# as a plain function.
predict = objax.Jit(objax.nn.Sequential([
    objax.ForceArgs(model, training=False), objax.functional.softmax
]))


def augment(x, max_shift=4):
    """Randomly mirror and translate a whole batch of images.

    One flip decision and one pair of offsets are drawn per call, so every
    image in the batch receives the same transformation.

    Args:
        x: 4-D array; only the last two axes are mirrored, padded and cropped.
        max_shift: largest translation, in pixels, along each of the last two
            axes. Defaults to 4.

    Returns:
        An array with the same shape as `x`.
    """
    if random.random() < .5:
        x = x[:, :, :, ::-1]  # Mirror by reversing the last axis
    height, width = x.shape[2], x.shape[3]
    # Reflect-pad the last two axes so that a crop of the original size can
    # start anywhere from 0 to 2 * max_shift.
    x_pad = np.pad(
        x,
        [[0, 0], [0, 0], [max_shift, max_shift], [max_shift, max_shift]],
        'reflect',
    )
    # randint's upper bound is exclusive: the bound has to be
    # 2 * max_shift + 1 for the shift of +max_shift to be reachable at all.
    row_off = np.random.randint(0, 2 * max_shift + 1)
    col_off = np.random.randint(0, 2 * max_shift + 1)
    return x_pad[:, :, row_off:row_off + height, col_off:col_off + width]


# Training
for epoch in range(30):
    # Train
    # Batch losses are collected under their own name so that the `loss`
    # function defined above is not overwritten by a list.
    epoch_losses = []
    # Learning rate passed to train_op: 4e-3 for epochs 0-19, 4e-4 afterwards.
    lr = 4e-3 if epoch < 20 else 4e-4
    # Fresh random ordering of the training indices for this epoch.
    sel = np.arange(len(X_train))
    np.random.shuffle(sel)
    for it in range(0, X_train.shape[0], 64):
        # Indices of the current mini-batch (up to 64 examples).
        batch = sel[it:it + 64]
        epoch_losses.append(train_op(augment(X_train[batch]), Y_train[batch].flatten(), lr))

    # Eval
    # The test set is reshaped into 50 equal batches; argmax over axis 1 picks
    # the predicted class of every example.
    test_predictions = [predict(x_batch).argmax(1) for x_batch in X_test.reshape((50, -1) + X_test.shape[1:])]
    accuracy = np.array(test_predictions).flatten() == Y_test.flatten()
    print(f'Epoch {epoch + 1:4d}  Loss {np.mean(epoch_losses):.2f}  Accuracy {100 * np.mean(accuracy):.2f}')

In [90]:
import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp

from jax import random

In [139]:
def f(x):
  """Sum of 3 * x**2 over every element of `x`."""
  return np.sum(3 * x ** 2)

x = np.ones((2, 3))
# jax.vjp evaluates f at x and also hands back the pullback function
y, vjp_fun = jax.vjp(f, x)
# Pull the scalar cotangent 1.0 back through f, i.e. compute J^T v
vjp = vjp_fun(np.array(1.))

In [140]:
vjp

(DeviceArray([[6., 6., 6.],
              [6., 6., 6.]], dtype=float32),)

In [141]:
def my_grad(f, x):
  """Gradient of `f` at `x`, computed by hand with `jax.vjp`.

  Args:
    f: function of one array argument.
    x: point at which to differentiate.

  Returns:
    The first (and only) element of the pullback's result: the cotangent for
    `x` obtained by pulling back an all-ones cotangent shaped like `f(x)`.
  """
  y, vjp_fn = jax.vjp(f, x)
  # Seed the pullback with ones in the output's own shape and dtype. A
  # default-dtype `np.ones(y.shape)` no longer matches the output as soon as
  # `f` returns something other than the default float type.
  return vjp_fn(np.ones_like(y))[0]

# Compare the hand-rolled gradient with jax.grad on the same function and input.
print("my_grad:\n {}".format(my_grad(f, np.ones((2, 3)))))
print("jax grad:\n {}".format(jax.grad(f)(np.ones((2, 3)))))

my_grad:
 [[6. 6. 6.]
 [6. 6. 6.]]
jax grad:
 [[6. 6. 6.]
 [6. 6. 6.]]


In [144]:
# Explicit PRNG key built from seed 0; the sampler below takes it as an argument.
key = random.PRNGKey(0)
# Draw 10 normally distributed samples using that key.
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [150]:
# NOTE(review): `size` is not read by any code visible in this part of the
# notebook — confirm it is still needed.
size = 3000

In [153]:
# Alias so the cells below can use the `jnp` name.
# assumes `np` currently refers to jax.numpy (the most recent
# `import jax.numpy as np`) rather than plain numpy — confirm execution order
jnp = np

In [154]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit.

  Returns `lmbda * x` where `x > 0` and `lmbda * (alpha * exp(x) - alpha)`
  elsewhere.
  """
  negative_branch = alpha * jnp.exp(x) - alpha
  return lmbda * jnp.where(x > 0, x, negative_branch)

# One million samples to time against, drawn with the same `key` as before.
x = random.normal(key, (1000000,))
# block_until_ready() is called on the result so the timing covers the
# computation itself rather than just launching it.
%timeit selu(x).block_until_ready()

2.17 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [157]:
# jit-compiled version of `selu`, timed on the same input for comparison.
selu_jit = jax.jit(selu)
%timeit selu_jit(x).block_until_ready()

102 µs ± 29.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
