In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

In [None]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

In [None]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

In [None]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

In [None]:
%timeit jnp.dot(x, x.T).block_until_ready()

In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

In [None]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

In [None]:
def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

In [None]:
def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))

In [None]:
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))

In [None]:
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

In [None]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

In [None]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

In [None]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

In [None]:
vmap(apply_matrix,(0,),0)(random.normal(key, ( 10, 100))).shape

In [None]:
help(vmap)

In [None]:
xs = jnp.arange(3. * 4.).reshape(3, 4)
xs

In [None]:
from jax import lax
print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs))

In [None]:
help(lax.psum)

In [None]:
jax.pmap(lambda x: x ** 2)(jnp.arange(1)) 

In [None]:
import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);

In [None]:
type(x_jnp)

In [None]:
y = x_jnp.at[0].set(10)
print(x_jnp)
print(y)

In [None]:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)

In [None]:
import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0)

In [None]:
from jax import jit
norm_compiled = jit(norm)

In [None]:
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)

In [None]:
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()

In [None]:
from functools import partial

# @partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

%timeit f(jnp.array([1,2]), True)

In [None]:
from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

%timeit -n1 -r1 f(1, False)

In [None]:
%timeit -n1 -r1 f(1, False)

In [None]:
%timeit -n1 -r1 f(1, True)

In [None]:
%timeit -n1 -r1 f(1, True)

In [None]:
from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  #return x.reshape((np.prod(x.shape),))
    return x[x<0]

f(x)

In [None]:
x = np.arange(10)
@jit
def func(i):
    return jnp.asarray(x)[i]

%timeit -r1 -n1 func(jnp.arange(4))  

In [None]:
%timeit -r1 -n1 func(jnp.arange(4))  

In [None]:
x = jnp.arange(10)
@jit
def func(x,):
    return jnp.split(x, 2, 0)

func(x,)
%timeit func(x, )

In [None]:
grad(func)(x)

In [None]:
x = jnp.arange(10)
# @partial(jit, static_argnums=1)
def func(x, axis):
    return jnp.split(x, 2, axis)

func(x, 0)
%timeit func(x, 0)

In [None]:
from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

x = jnp.ones((2, 3))
f(x)

In [None]:
x = jnp.ones((2, 4,6))
%timeit -n100 -r1 f(x)

In [None]:
x2=np.array([[1,2],[3,4]])
%timeit -n100 -r1  f(x2)

In [None]:
import jax.numpy as jnp
import jax
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jax.jit(slow_f)         # 静态编译slow_f;
%timeit -n10 -r3 fast_f(x)
%timeit -n10 -r3 slow_f(x)

In [None]:
jaxpr = jax.make_jaxpr(fast_f)
jaxpr(3)

In [None]:
def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_lambda(x):
  i = 0
  while i < 20:
    # Don't do this!, lambda will return
    # a function with a different hash
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x):
  i = 0
  while i < 20:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

print("jit called in a loop with lambdas:")
%timeit -n 10 g_inner_jitted_lambda(10)

print("jit called in a loop with caching:")
%timeit -n 10 g_inner_jitted_normal(10)

In [None]:
def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x
    
static_f = jax.jit(f, static_argnums=(0,))
# jax.make_jaxpr(static_f, static_argnums=(0,))(2)

In [None]:
%timeit static_f(2)

In [None]:
%timeit static_f(4)

In [None]:
static_f(2),static_f(4)

In [None]:
from jax import jit
import jax.numpy as jnp
import numpy as np

# x = np.random.randn(3, 4)


@jit
def f(x):
  x = np.prod(x.shape)
  print(x)
  return x
#   return x.reshape()

x = jnp.ones((2, 3))
%timeit -r1 -n1 f(x)
%timeit -r1 -n1 f(x)
# jax.make_jaxpr(f)(x)

In [None]:
x = jnp.ones((2, 4))
%timeit -r1 -n1 f(x)
%timeit -r1 -n1 f(x)

In [None]:
x = jnp.ones((5, 14))
%timeit -n1 -r1 f(x)
%timeit -n1 -r1 f(x)

In [None]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

In [None]:
@jit
def f(x):
  #return x.reshape(jnp.array(x.shape).prod())
    return jnp.array(x.shape).prod()

x = jnp.ones((2, 3))
f(x)
%timeit f(x)

In [None]:
@jit
def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x
print ("First call: ", (impure_print_side_effect)(4.))

In [None]:
print ("Second call: ", (impure_print_side_effect)(5.))

In [None]:
print ("Third call, different type: ", (impure_print_side_effect)(jnp.array([5.])))

In [None]:
g = 0.
def impure_saves_global(x):
  global g
  g = x
  print('x',x)
  return x

# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  # Saved global has an internal JAX value

In [None]:
print ("First call: ", jit(impure_saves_global)(5.))

In [None]:
jnp.array([1])

In [None]:
st=dict(even=0, odd=0)
@jit
def pure_uses_internal_state(x):
  state = st#dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(7.))

In [None]:
import jax.numpy as jnp
import jax.lax as lax
from jax import make_jaxpr

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

In [None]:
help(lax.scan)

In [None]:
# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

In [None]:
func11(jnp.arange(16), 5.)

In [None]:
arr=jnp.arange(16)
extra=5
ones = jnp.ones(arr.shape)
def body(state, aelems):
    print('state:',state)
    print('aelems:',aelems)
    ae1, ae2 = aelems
    return (state + ae1 * ae2 + extra, state)
lax.scan(body, 0., (arr, ones))


In [None]:
help(lax.cond)

In [None]:
# lax.cond
print('arr:',arr)
# lax.cond(arr%2==0, lambda x: x[0]+1, lambda x: x[0]-1, (arr, ones))
arr%2==0

In [None]:
b=arr%2==0
print(b)
jnp.where(b,size=40,fill_value=-1)

In [None]:
numpy_array = np.zeros((3,3), dtype=np.float32)
print("original array:")
print(numpy_array)

# In place, mutating update
numpy_array[1, :] = 1.0
print("updated array:")
print(numpy_array)

In [None]:
arr=jnp.array((3,4))
arr

In [None]:
x = random.normal(key, (3,4))
x

In [None]:
x.at[jnp.where(x>0.2)].set(1)

In [None]:
import jax
def my_log(x):
  return jnp.where(x > 0., jnp.log(x+1e-2), jnp.log(x+1e-3))

jax.grad(my_log)(0.) # ==> NaN

In [None]:
jax.make_jaxpr(jax.grad(my_log))(0.)

In [None]:
from jax.experimental import io_callback
from functools import partial

global_rng = np.random.default_rng(0)

def host_side_random_like(x):
  """Generate a random array like x using the global_rng state"""
  # We have two side-effects here:
  # - printing the shape and dtype
  # - calling global_rng, thus updating its state
#   print(f'generating {x.dtype}{x}')
  return global_rng.uniform(size=x.shape).astype(x.dtype)

@jax.jit
def numpy_random_like(x):
  return io_callback(host_side_random_like, x, x)

x = jnp.zeros(5)
numpy_random_like(x)

In [None]:
help(io_callback)

In [None]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32)
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)

In [None]:
print("original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)

In [None]:
jnp.arange(10.0).at[11].set(1)

In [None]:
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)

In [None]:
@jit
def permissive_sum(x):
  return jnp.sum(x)

x = jnp.array(list(range(10)))
permissive_sum(x)

In [None]:
%timeit permissive_sum(x)

In [None]:
help(random.split)

In [None]:
def f(x):
    return  jnp.log(x)

x=jax.random.normal(key,(1000,))
print(f(x).shape)
%timeit f(x)

In [None]:
@jit
def f(x):
    return  jnp.log(x)

x=jax.random.normal(key,(1000,))
print(f(x).shape)
%timeit f(x)

In [None]:
@jit
def f(x):
  for i in range(3):
    x = 2 * x
  return x

print(f(3))

In [None]:
jnp.prod(jnp.array(x.shape))

In [None]:
@jit
def g(x):
  s=np.prod(np.array(x.shape))
  print(s)
  return jnp.ones(s)

%timeit -n1 -r1 print(g(jnp.array([[1., 2., 3.]])))
%timeit -n1 -r1 g(jnp.array([[1., 2., 3.]]))

In [None]:
%timeit -n1 -r1 print(g(jnp.array([1., 2., 3.,4])))
%timeit -n1 -r1 g(jnp.array([1., 2., 3.,4]))

In [None]:
@partial(jit,static_argnums=(1,))
def f(x, n):
  y = 0.
  for i in range(n):
    y = y + x[i]
  return y

#f = jit(f, static_argnums=(1,))
a=jnp.array([2., 3., 4.])
f(a, 2)
%timeit f(a, 2)

In [None]:
@jit
def f(x, y):
   a = x * y
   b = (x + y) / (x - y)
   c = a + 2
   return a + b * c


x = jnp.array([2., 0.])

y = jnp.array([3., 0.])

f(x, y)


In [None]:
import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof")

In [None]:
!ls -alh

In [None]:
!pprof 
!go

In [None]:
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y

f(2.)


In [None]:
import jax
from jax import lax
def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z
f(2., 0.) # ==> Pauses during execution!

In [None]:
xs = jnp.arange(3.)

def f(x):
  jax.debug.print("x: {}", x)
  y = jnp.sin(x)
  jax.debug.print("y: {}", y)
  return y
print('v1',jax.vmap(f)(xs))
# Prints: x: 0.0
#         x: 1.0
#         x: 2.0
#         y: 0.0
#         y: 0.841471
#         y: 0.9092974
print('v2',jax.lax.map(f, xs))
# Prints: x: 0.0
#         y: 0.0
#         x: 1.0
#         y: 0.841471
#         x: 2.0
#         y: 0.9092974

In [None]:
from jax.experimental import checkify
import jax
import jax.numpy as jnp

def f(x, i):
  checkify.check(i >= 0, f"index needs to be non-negative, got {i}")
  y = x[i]
  z = jnp.sin(y)
  return z

jittable_f = checkify.checkify(f)

err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
print(err.get())
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))

In [None]:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)

err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))

err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)

err, z = checked_f(jnp.array([5, 1]), 0)
err.throw()  # if no error occurred, throw does nothing!

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

def f_host(x):
  return np.sin(x).astype(x.dtype)

def f(x):
  return jax.pure_callback(f_host, x, x)

x = jnp.arange(5.0)
f(x)

In [None]:
jax.jit(f)(x)

In [None]:
jax.vmap(f)(x)

In [None]:
def body_fun(_, x):
  return _+1, f(x)
jax.lax.scan(body_fun, 0, jnp.arange(5.0))

In [None]:
%xmode minimal
jax.grad(f)(x)

In [None]:
jnp.asarray(2)

In [None]:
jnp.array(2)

In [None]:
from jax import jit
from functools import partial
@partial(jit, static_argnums=1)
def func(x, axis):
    return x.min(axis)

In [None]:
func(jnp.arange(4), 0)  

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [None]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
  w_key, b_key = random.split(key)
  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
  keys = random.split(key, len(sizes))
  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 10
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [None]:
from jax.scipy.special import logsumexp

def relu(x):
  return jnp.maximum(0, x)

def predict(params, image):
  # per-example predictions
  activations = image
  for w, b in params[:-1]:
    outputs = jnp.dot(w, activations) + b
    activations = relu(outputs)
  
  final_w, final_b = params[-1]
  logits = jnp.dot(final_w, activations) + final_b
  return logits - logsumexp(logits)  #log(p)

In [None]:
logsumexp(np.array([-2000,-1.,0,1,20]))

In [None]:
def lx(logits):
    return logits - logsumexp(logits)
lx(jnp.array([-2000,-1.,0,1,20]))

In [None]:
# This works on single examples
random_flattened_image = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_image)
print(preds.shape)

In [None]:
# Doesn't work with a batch
random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:
  preds = predict(params, random_flattened_images)
except TypeError:
  print('Invalid shapes!')

In [None]:
# Let's upgrade it to handle batches using `vmap`

# Make a batched version of the `predict` function
batched_predict = vmap(predict, in_axes=(None, 0))

# `batched_predict` has the same call signature as `predict`
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

In [None]:
x=jnp.array([1,2,3])
k=9
jnp.array(x[:, None] == jnp.arange(k), jnp.float32)

In [None]:
def one_hot(x, k, dtype=jnp.float32):
  """Create a one-hot encoding of x of size k."""
  return jnp.array(x[:, None] == jnp.arange(k), dtype)
  
def accuracy(params, images, targets):
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
  return jnp.mean(predicted_class == target_class)

def loss(params, images, targets):
  preds = batched_predict(params, images)
  return -jnp.mean(preds * targets)

@jit
def update(params, x, y):
  grads = grad(loss)(params, x, y)
  return [(w - step_size * dw, b - step_size * db)
          for (w, b), (dw, db) in zip(params, grads)]

In [None]:
import tensorflow as tf
# Ensure TF does not see GPU and grab all GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

data_dir = '/tmp/tfds'

# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
# You can convert them to NumPy arrays (or iterables of NumPy arrays) with tfds.dataset_as_numpy
mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
mnist_data = tfds.as_numpy(mnist_data)
train_data, test_data = mnist_data['train'], mnist_data['test']
num_labels = info.features['label'].num_classes
h, w, c = info.features['image'].shape
num_pixels = h * w * c

# Full train set
train_images, train_labels = train_data['image'], train_data['label']
train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
train_labels = one_hot(train_labels, num_labels)

# Full test set
test_images, test_labels = test_data['image'], test_data['label']
test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
test_labels = one_hot(test_labels, num_labels)

In [None]:
print('Train:', train_images.shape, train_labels.shape)
print('Test:', test_images.shape, test_labels.shape)

In [None]:
import time

def get_train_batches():
  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
  # You can build up an arbitrary tf.data input pipeline
  ds = ds.batch(batch_size).prefetch(1)
  # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
  return tfds.as_numpy(ds)

for epoch in range(num_epochs):
  start_time = time.time()
  for x, y in get_train_batches():
    x = jnp.reshape(x, (len(x), num_pixels))
    y = one_hot(y, num_labels)
    params = update(params, x, y)
  epoch_time = time.time() - start_time

  train_acc = accuracy(params, train_images, train_labels)
  test_acc = accuracy(params, test_images, test_labels)
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))

In [None]:
[[e.shape for e in p] for p in params]

In [None]:
from flax import linen as nn

In [None]:

class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

In [None]:
!pip install keras-core

# JAX

In [None]:
import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import numpy as np
import tensorflow as tf
import keras_core as keras

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P


def get_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(28,28,1))
    x = keras.layers.Reshape((28*28,))(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model


def get_datasets():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # Create TF Datasets
    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    return train_data, eval_data

num_epochs = 2
batch_size = 64 

train_data, eval_data = get_datasets()
train_data = train_data.batch(batch_size, drop_remainder=True)

model = get_model()
optimizer = keras.optimizers.Adam(1e-3)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Initialize all state with .build()
(one_batch, one_batch_labels) = next(iter(train_data))
model.build(one_batch)
optimizer.build(model.trainable_variables)


# This is the loss function that will be differentiated.
# Keras provides a pure functional forward pass: model.stateless_call
def compute_loss(trainable_variables, non_trainable_variables, x, y):
    y_pred, updated_non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x
    )
    loss_value = loss(y, y_pred)
    return loss_value, updated_non_trainable_variables


# Function to compute gradients
compute_gradients = jax.value_and_grad(compute_loss, has_aux=True)


# Training step, Keras provides a pure functional optimizer.stateless_apply
@jax.jit
def train_step(train_state, x, y):
    trainable_variables, non_trainable_variables, optimizer_variables = train_state
    (loss_value, non_trainable_variables), grads = compute_gradients(
        trainable_variables, non_trainable_variables, x, y
    )

    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )

    return loss_value, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )


# Replicate the model and optimizer variable on all devices
def get_replicated_train_state(devices):
    # All variables will be replicated on all devices
    var_mesh = Mesh(devices, axis_names=("_"))
    # In NamedSharding, axes not mentioned are replicated (all axes here)
    var_replication = NamedSharding(var_mesh, P())

    # Apply the distribution settings to the model variables
    trainable_variables     = jax.device_put(model.trainable_variables,     var_replication)
    non_trainable_variables = jax.device_put(model.non_trainable_variables, var_replication)
    optimizer_variables     = jax.device_put(optimizer.variables,           var_replication)

    # Combine all state in a tuple
    return (trainable_variables, non_trainable_variables, optimizer_variables)


num_devices = len(jax.local_devices())
print(f"Running on {num_devices} devices: {jax.local_devices()}")
devices = mesh_utils.create_device_mesh((num_devices,))

# Data will be split along the batch axis
data_mesh = Mesh(devices, axis_names=("batch",))  # naming axes of the mesh
data_sharding = NamedSharding(
    data_mesh,
    P(
        "batch",
    ),
)  # naming axes of the sharded partition

# Display data sharding
x, y = next(iter(train_data))
sharded_x = jax.device_put(x.numpy(), data_sharding)
print("Data sharding")
jax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))

train_state = get_replicated_train_state(devices)

# Custom training loop
for epoch in range(4):
    import time
    start=time.time()
    data_iter = iter(train_data)
    for data in data_iter:
        x, y = data
        sharded_x = jax.device_put(x.numpy(), data_sharding)
        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())
    print("Epoch", epoch, "loss:", loss_value,'cost:',time.time()-start)

# Post-processing model state update to write them back into the model
trainable_variables, non_trainable_variables, optimizer_variables = train_state
for variable, value in zip(model.trainable_variables, trainable_variables):
    variable.assign(value)
for variable, value in zip(model.non_trainable_variables, non_trainable_variables):
    variable.assign(value)

# tensorflow

In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras_core as keras
# tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect(tpu='local')
# strategy = tf.distribute.TPUStrategy(tpu)
strategy = tf.distribute.MirroredStrategy()

In [None]:

def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(784,))
    x = keras.layers.Dense(256, activation="relu")(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    return model


def get_dataset():
    batch_size = 64 * 8 #strategy.num_replicas_in_sync
    num_val_samples = 10000

    # Return the MNIST dataset in the form of a [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset).
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    # Reserve num_val_samples samples for validation
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )



# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = get_compiled_model()

    # Train the model on all available devices.
    train_dataset, val_dataset, test_dataset = get_dataset()
    model.fit(train_dataset, epochs=4)

    # Test the model on all available devices.


# torch

In [4]:
import os

os.environ["KERAS_BACKEND"] = "torch"

import torch
import numpy as np
import keras_core as keras



def get_model():
    # Make a simple 2-layer densely-connected neural network.
    inputs = keras.Input(shape=(28,28,1))
    x = keras.layers.Reshape((28*28,))(inputs)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)
    return model


def get_dataset():
    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Scale images to the [0, 1] range
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    # Make sure images have shape (28, 28, 1)
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    print("x_train shape:", x_train.shape)

    # Create a TensorDataset
    dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train)
    )
    return dataset

def train_model(model, dataloader, num_epochs, optimizer, loss_fn):
    for epoch in range(num_epochs):
        import time
        start=time.time()
        running_loss = 0.0
        running_loss_count = 0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_loss_count += 1

        # Print loss statistics
        print(
            f"Epoch {epoch + 1}/{num_epochs}, "
            f"Loss: {running_loss / running_loss_count}, "
            f"Cost: {time.time()-start}"
        )
num_gpu = torch.cuda.device_count()
num_epochs = 2
batch_size = 64 *8
print(f"Running on {num_gpu} GPUs")


def setup_device(current_gpu_index, num_gpus):
    # Device setup
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "56492"
    device = torch.device("cuda:{}".format(current_gpu_index))
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=num_gpus,
        rank=current_gpu_index,
    )
    torch.cuda.set_device(device)


def cleanup():
    torch.distributed.destroy_process_group()


def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=num_gpus,
        rank=current_gpu_index,
        shuffle=False,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        shuffle=False,
    )
    return dataloader


def per_device_launch_fn(current_gpu_index, num_gpu):
    # Setup the process groups
    setup_device(current_gpu_index, num_gpu)

    dataset = get_dataset()
    model = get_model()

    # prepare the dataloader
    dataloader = prepare_dataloader(dataset, current_gpu_index, num_gpu, batch_size)

    # Instantiate the torch optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Instantiate the torch loss function
    loss_fn = torch.nn.CrossEntropyLoss()

    # Put model on device
    model = model.to(current_gpu_index)
    ddp_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[current_gpu_index], output_device=current_gpu_index
    )

    train_model(ddp_model, dataloader, 4, optimizer, loss_fn)

    cleanup()
    
per_device_launch_fn(0,1)

Running on 1 GPUs
x_train shape: (60000, 28, 28, 1)
Epoch 1/4, Loss: 5.487846383603952, Cost: 0.7879757881164551
Epoch 2/4, Loss: 0.6967917170565007, Cost: 0.7826569080352783
Epoch 3/4, Loss: 0.3728914860699136, Cost: 0.7845509052276611
Epoch 4/4, Loss: 0.21847973126223533, Cost: 1.052258014678955
