In [1]:
import random
import itertools

import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp

In [2]:
def predict(params, inputs):
  """Forward pass of a tanh MLP.

  params is an iterable of (W, b) layers; each layer computes
  ``np.dot(inputs, W) + b``. tanh is applied between layers, and the last
  layer's pre-activation value is returned (no output nonlinearity).

  Raises:
    ValueError: if params contains no layers, so there is no output to return.
  """
  outputs = None
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  if outputs is None:
    raise ValueError("predict() needs at least one (W, b) layer in params")
  return outputs

def logprob_fun(params, inputs, targets):
  """Sum of squared errors between predict(params, inputs) and targets.

  Despite its name, this is a squared-error objective rather than a
  log-probability; it is the function that grad_fun differentiates.
  """
  residuals = predict(params, inputs) - targets
  return np.sum(residuals ** 2)

# Compiled gradient of logprob_fun with respect to its first argument (params).
grad_fun = jax.jit(jax.grad(logprob_fun, argnums=0))
# Per-example gradients: vmap shares params across the batch (in_axis None) and
# maps over the leading axis of inputs and targets; the result is then compiled.
perex_grads = jax.jit(jax.vmap(grad_fun, in_axes=(None, 0, 0)))

In [3]:
# Sigmoid nonlinearity
def sigmoid(x):
    """Logistic sigmoid, 1 / (1 + exp(-x)), elementwise.

    Delegates to jax.nn.sigmoid. In float32 the naive formula overflows
    exp(-x) to inf for very negative x (below roughly -88). The value is still 0,
    but reverse-mode autodiff then multiplies 0 by inf and the gradient becomes NaN.
    """
    return jax.nn.sigmoid(x)

# Forward pass of the XOR network: one tanh hidden layer, sigmoid output.
def net(params, x):
    """Return the network's output probability for a single input x.

    params unpacks as (w1, b1, w2, b2); w1 is applied as np.dot(w1, x), so it
    is laid out as (hidden, features).
    """
    w1, b1, w2, b2 = params
    hidden_pre = np.dot(w1, x) + b1
    hidden = np.tanh(hidden_pre)
    logit = np.dot(w2, hidden) + b2
    return sigmoid(logit)

# Cross-entropy loss
def loss(params, x, y):
    """Binary cross-entropy of the network's prediction for x against target y.

    The value is computed from the pre-sigmoid logit with jax.nn.log_sigmoid
    instead of taking log(sigmoid(.)). Once the float32 sigmoid saturates to
    exactly 0 or 1, the naive form hits log(0) = -inf. The term
    (1 - y) * log(1 - out) then evaluates 0 * -inf = NaN, and the gradients
    become inf/NaN.

    The forward pass below mirrors `net` (tanh hidden layer), stopping just
    before the sigmoid.

    Args:
      params: (w1, b1, w2, b2), as produced by initial_params.
      x: a single input vector.
      y: the target; 0 or 1 in this notebook.

    Returns:
      The scalar -y*log(p) - (1-y)*log(1-p), where p = sigmoid(logit).
    """
    w1, b1, w2, b2 = params
    hidden = np.tanh(np.dot(w1, x) + b1)
    logit = np.dot(w2, hidden) + b2
    # log(1 - sigmoid(z)) == log(sigmoid(-z)); both terms stay finite for finite z.
    return -y * jax.nn.log_sigmoid(logit) - (1 - y) * jax.nn.log_sigmoid(-logit)

# Evaluate the net on every possible XOR input, printing each prediction.
def test_all_inputs(inputs, params):
    """Print the thresholded prediction for each input and check it against XOR.

    A prediction is 1 when net(params, inp) > 0.5, else 0. Returns True only
    when every prediction equals onp.bitwise_xor of that input's two entries.
    """
    predictions = []
    for inp in inputs:
        predictions.append(int(net(params, inp) > 0.5))
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    targets = [onp.bitwise_xor(*inp) for inp in inputs]
    return predictions == targets

In [4]:
def initial_params(rng=None):
    """Draw random parameters [w1, b1, w2, b2] for the 2-3-1 XOR network.

    All entries are drawn with ``randn``.

    Args:
      rng: optional ``numpy.random.RandomState`` (any object with a ``randn``
        method) for reproducible initialization. Defaults to numpy's global
        RNG (``onp.random``).

    Returns:
      [w1 of shape (3, 2), b1 of shape (3,), w2 of shape (3,), b2 scalar].
    """
    if rng is None:
        rng = onp.random
    return [
        rng.randn(3, 2),  # w1: 2 inputs -> 3 hidden units
        rng.randn(3),  # b1
        rng.randn(3),  # w2: 3 hidden units -> scalar output
        rng.randn(),  # b2
    ]

In [5]:
# Single-example SGD on XOR using a jit-compiled gradient of `loss` w.r.t. params.
loss_grad = jax.jit(jax.grad(loss))

# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# Upper bound on SGD steps: without it the loop below never terminates if
# training fails to solve XOR (e.g. it stalls for this random initialization).
max_iterations = 10000

# Initialize parameters randomly
params = initial_params()

for n in range(max_iterations):
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Compute the target output
    y = onp.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break
else:
    # Reached only when the loop exhausts max_iterations without a `break`.
    print('XOR not solved within {} iterations; re-run to try new initial parameters.'.format(max_iterations))

Iteration 0
[0 0] -> 0
[0 1] -> 1
[1 0] -> 0
[1 1] -> 1
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [6]:
# Minibatch SGD on XOR: vmap the per-example gradient over the leading axis of
# x and y (params shared via in_axes=None), then jit-compile the result.
# Reuses `learning_rate` and `inputs` from the previous cell.
loss_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0))

params = initial_params()

batch_size = 100
# Upper bound on SGD steps so a run that never solves XOR still terminates.
max_iterations = 10000

for n in range(max_iterations):
    # Generate a batch of inputs
    x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
    y = onp.bitwise_xor(x[:, 0], x[:, 1])
    # The call to loss_grad remains the same!
    grads = loss_grad(params, x, y)
    # Note that we now need to average gradients over the batch
    params = [param - learning_rate * np.mean(grad, axis=0)
              for param, grad in zip(params, grads)]
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break
else:
    # Reached only when the loop exhausts max_iterations without a `break`.
    print('XOR not solved within {} iterations; re-run to try new initial parameters.'.format(max_iterations))

Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 1
Iteration 100
[0 0] -> 1
[0 1] -> 1
[1 0] -> 0
[1 1] -> 0
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 0
[1 1] -> 0
Iteration 300
[0 0] -> 0
[0 1] -> 1
[1 0] -> 0
[1 1] -> 0
Iteration 400
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [8]:
!pip install tensorflow

Collecting tensorflow
  Downloading tensorflow-2.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (511.7 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m511.7/511.7 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0mm eta [36m0:00:01[0m[36m0:00:02[0m
Collecting tensorflow-io-gcs-filesystem>=0.23.1
  Downloading tensorflow_io_gcs_filesystem-0.26.0-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (2.4 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m MB/s[0m eta [36m0:00:01[0mm:01[0m
Collecting tensorflow-estimator<2.10.0,>=2.9.0rc0
  Downloading tensorflow_estimator-2.9.0-py2.py3-none-any.whl (438 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m438.7/438.7 KB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m31m11.9 MB/s[0m eta [36m0:00:01[0m
Collecting keras-preprocessing>=1.1.1
  Downloading Keras_Preprocessing-1.1.2-py2.p

    Uninstalling tensorboard-2.7.0:
      Successfully uninstalled tensorboard-2.7.0
Successfully installed astunparse-1.6.3 flatbuffers-1.12 gast-0.4.0 google-pasta-0.2.0 keras-2.9.0 keras-preprocessing-1.1.2 libclang-14.0.1 tensorboard-2.9.0 tensorflow-2.9.1 tensorflow-estimator-2.9.0 tensorflow-io-gcs-filesystem-0.26.0
You should consider upgrading via the '/home/silviu/anaconda3/bin/python -m pip install --upgrade pip' command.

In [10]:
# List the GPUs TensorFlow can see (prints [] when no GPU is usable).
import tensorflow as tf

print(tf.config.list_physical_devices('GPU'))


[]


2022-05-27 11:22:41.307266: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/usr/local/cuda/lib64:/home/silviu/.mujoco/mjpro150/bin:/home/silviu/.mujoco/mujoco200/bin:/usr/lib/nvidia-375:/usr/local/cuda/lib64:/home/silviu/.mujoco/mjpro150/bin:/home/silviu/.mujoco/mujoco200/bin:/usr/lib/nvidia-375
2022-05-27 11:22:41.307282: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [11]:
!pip install objax
import random

import numpy as np

import objax
from objax.zoo.wide_resnet import WideResNet

You should consider upgrading via the '/home/silviu/anaconda3/bin/python -m pip install --upgrade pip' command.

# Data: CIFAR-10 via the Keras dataset loader. The (0, 3, 1, 2) transpose moves
# the last axis to position 1 (assumes Keras returns channels-last images, so
# the result is channels-first, matching nin=3 below); dividing by 255.0
# rescales the pixel values.
(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.cifar10.load_data()
X_train, X_test = [split.transpose(0, 3, 1, 2) / 255.0 for split in (X_train, X_test)]

# Model: WideResNet with depth 28 and width 2 for 10 classes, trained with
# Adam over all model variables.
model = WideResNet(nin=3, nclass=10, depth=28, width=2)
opt = objax.optimizer.Adam(model.vars())

# Losses: batch mean of the sparse cross-entropy on the model's logits, with
# the model run in training mode.
# NOTE(review): this rebinds the global name `loss` that the XOR cells above use.
@objax.Function.with_vars(model.vars())
def loss(x, label):
    logits = model(x, training=True)
    per_example_xent = objax.functional.loss.cross_entropy_logits_sparse(logits, label)
    return per_example_xent.mean()

# gv(x, label) yields a (gradients, values) pair: gradients of `loss` with
# respect to model.vars(), plus what `loss` returned.
gv = objax.GradValues(loss, model.vars())

# One optimization step: compute gradients and loss values, apply an optimizer
# update at learning rate `lr`, and return the loss values.
@objax.Function.with_vars(model.vars() + opt.vars())
def train_op(x, y, lr):
    grads, values = gv(x, y)
    opt(lr=lr, grads=grads)
    return values


# JIT-compile the training step and an inference pipeline: the model forced
# into evaluation mode, followed by a softmax over its outputs.
# NOTE(review): this rebinds the global `predict`, which `logprob_fun` in an
# earlier cell also looks up by name.
train_op = objax.Jit(train_op)
predict = objax.Jit(objax.nn.Sequential([
    objax.ForceArgs(model, training=False),
    objax.functional.softmax,
]))


def augment(x, pad=4):
    """Randomly mirror and shift a batch of images (one transform for the whole batch).

    Args:
      x: image batch indexed as x[batch, channel, row, col]. Assumes the
        channels-first layout produced by the data-loading transpose.
      pad: maximum shift, in pixels, along each spatial axis (default 4).

    Returns:
      An array with the same shape as x.
    """
    # Function-local imports: later cells rebind the globals `np` (to jax.numpy)
    # and `random` (to jax.random), which would break this function on a re-run.
    import random
    import numpy as np

    if random.random() < .5:
        x = x[:, :, :, ::-1]  # Mirror left-right by reversing the last (column) axis
    # Pixel-shift all images in the batch by up to `pad` pixels in any direction.
    height, width = x.shape[2], x.shape[3]
    x_pad = np.pad(x, [[0, 0], [0, 0], [pad, pad], [pad, pad]], 'reflect')
    # randint's upper bound is exclusive: offsets 0..2*pad give shifts -pad..+pad.
    rx, ry = np.random.randint(0, 2 * pad + 1), np.random.randint(0, 2 * pad + 1)
    return x_pad[:, :, rx:rx + height, ry:ry + width]


# Training: 30 epochs of shuffled minibatches of 64; the learning rate is 4e-3
# for the first 20 epochs and 4e-4 afterwards. Test accuracy is printed per epoch.
for epoch in range(30):
    # Train. Per-step losses go in `epoch_losses` (not `loss`) so the global
    # loss functions defined in earlier cells are not shadowed by a list.
    epoch_losses = []
    sel = np.arange(len(X_train))
    np.random.shuffle(sel)
    for it in range(0, X_train.shape[0], 64):
        batch_idx = sel[it:it + 64]
        epoch_losses.append(train_op(augment(X_train[batch_idx]), Y_train[batch_idx].flatten(),
                                     4e-3 if epoch < 20 else 4e-4))

    # Eval: split the test set into 50 equal batches (assumes len(X_test) is a multiple of 50).
    test_predictions = [predict(x_batch).argmax(1) for x_batch in X_test.reshape((50, -1) + X_test.shape[1:])]
    accuracy = np.array(test_predictions).flatten() == Y_test.flatten()
    print(f'Epoch {epoch + 1:4d}  Loss {np.mean(epoch_losses):.2f}  Accuracy {100 * np.mean(accuracy):.2f}')

In [12]:
import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp

from jax import random

In [13]:
def f(x):
  return np.sum(3 * x ** 2)

x = np.ones((2, 3))
# jax.vjp evaluates f at x and returns the output plus a pullback function.
y, vjp_fun = jax.vjp(f, x)
# Pull back the cotangent v = 1.0 (f's output is a scalar) to get v^T J, i.e. J^T v.
vjp = vjp_fun(np.array(1.))

In [14]:
# Display the VJP result: a 1-tuple holding df/dx, which for f(x) = sum(3 x^2)
# at x = ones is 6 in every entry.
vjp

(DeviceArray([[6., 6., 6.],
              [6., 6., 6.]], dtype=float32),)

In [15]:
def my_grad(f, x):
  """Gradient of f at x, computed with reverse-mode jax.vjp.

  Pulls back a cotangent of ones. For scalar-valued f this equals
  jax.grad(f)(x); for array-valued f it is the gradient of sum(f(x)).
  np.ones_like(y) keeps the cotangent's dtype equal to the output's dtype
  (e.g. float16 or float64), which jax.vjp's pullback checks; np.ones(y.shape)
  would always use the default float dtype.
  """
  y, vjp_fn = jax.vjp(f, x)
  return vjp_fn(np.ones_like(y))[0]

# Sanity check: the VJP-based gradient should agree with jax.grad on the same input.
print(f"my_grad:\n {my_grad(f, np.ones((2, 3)))}")
print(f"jax grad:\n {jax.grad(f)(np.ones((2, 3)))}")

my_grad:
 [[6. 6. 6.]
 [6. 6. 6.]]
jax grad:
 [[6. 6. 6.]
 [6. 6. 6.]]


In [16]:
# Draw 10 normal samples from a PRNG key seeded with 0 and print them.
key = random.PRNGKey(0)
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [17]:
# NOTE(review): `size` is not referenced in the cells visible here; presumably a
# sample count for a later experiment -- confirm it is used, or remove it.
size = 3_000

In [18]:
# Bind jnp to jax.numpy explicitly instead of aliasing the global `np`, which
# earlier cells rebind between numpy and jax.numpy. With a plain-numpy `jnp`,
# jax.jit(selu) below would likely fail when numpy functions receive JAX tracers.
import jax.numpy as jnp

In [19]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied elementwise.

  Returns lmbda * x where x > 0, and lmbda * (alpha * exp(x) - alpha) elsewhere.
  NOTE(review): the defaults look like rounded versions of the standard SELU
  constants (about 1.6733 and 1.0507); confirm if exact SELU is required.
  """
  negative_branch = alpha * jnp.exp(x) - alpha
  return lmbda * jnp.where(x > 0, x, negative_branch)

# Benchmark the un-jitted selu on 1,000,000 samples from `key`.
# block_until_ready() makes each timed call wait for the result itself,
# since JAX may return before the computation has finished.
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

388 µs ± 73 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
# Same benchmark with selu compiled by jax.jit. The first call traces and
# compiles; %timeit's many repeated calls amortize that one-time cost.
selu_jit = jax.jit(selu)
%timeit selu_jit(x).block_until_ready()

34.9 µs ± 276 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
