In [1]:
from utils import simulate_CPU_devices

simulate_CPU_devices()

In [30]:
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7)]

In [2]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Sequence, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from absl import logging
from jax import lax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
from ml_collections import ConfigDict

PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]

In [3]:
from single_gpu import Batch, TrainState, accumulate_gradients, print_metrics

In [4]:
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Folds the random number generator over the given axis.

    This is useful for generating a different random number for each device
    across a certain axis (e.g. the model axis).

    Args:
        rng: The random number generator.
        axis_name: The axis name to fold the random number generator over.

    Returns:
        A new random number generator, different for each device index along the axis.
    """
    axis_index = jax.lax.axis_index(axis_name)
    return jax.random.fold_in(rng, axis_index)

In [5]:
class DPClassifier(nn.Module):
    config: ConfigDict

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        x = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.config.dtype,
            name="input_dense",
        )(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.config.dropout_rate, deterministic=not train)(x)
        x = nn.Dense(
            features=self.config.num_classes,
            dtype=self.config.dtype,
            name="output_dense",
        )(x)
        x = x.astype(jnp.float32)
        return x

In [6]:
data_config = ConfigDict(
    dict(
        batch_size=128,
        num_classes=10,
        input_size=784,
    )
)
model_config = ConfigDict(
    dict(
        hidden_size=512,
        dropout_rate=0.1,
        dtype=jnp.bfloat16,
        num_classes=data_config.num_classes,
        data_axis_name="data",
    )
)
optimizer_config = ConfigDict(
    dict(
        learning_rate=1e-3,
        num_minibatches=4,
    )
)
config = ConfigDict(
    dict(
        model=model_config,
        optimizer=optimizer_config,
        data=data_config,
        data_axis_name=model_config.data_axis_name,
        seed=42,
    )
)

In [7]:
model_dp = DPClassifier(config=config.model)
optimizer = optax.adamw(
    learning_rate=config.optimizer.learning_rate,
)

In [8]:
rng = jax.random.PRNGKey(config.seed)
model_init_rng, data_inputs_rng, data_labels_rng = jax.random.split(rng, 3)
batch = Batch(
    inputs=jax.random.normal(data_inputs_rng, (config.data.batch_size, config.data.input_size)),
    labels=jax.random.randint(
        data_labels_rng, (config.data.batch_size,), 0, config.data.num_classes
    ),
)

In [10]:
print(rng, model_init_rng, data_inputs_rng, data_labels_rng)

[ 0 42] [3134548294 3733159049] [3746501087  894150801] [ 801545058 2363201431]


In [11]:
print(batch.inputs)

[[-0.9723122  -1.2452544   0.17470965 ...  0.08267392 -0.89882183
   0.0232077 ]
 [ 0.8691865   0.64485383 -1.8292103  ... -1.555047    0.82450116
   0.2290802 ]
 [ 0.7618328  -0.07111705 -0.4348067  ...  0.7814184  -0.46264035
   0.3288694 ]
 ...
 [ 1.4671646   0.48726186 -0.51675266 ... -0.03286455 -0.5917859
  -0.7451234 ]
 [ 1.2910836  -0.43864703 -0.50237346 ...  1.1371423   0.07289556
  -0.24300338]
 [ 1.2610282   0.6610234  -0.530804   ...  0.32887977 -0.62506366
  -0.02475406]]


In [12]:
def init_dp(rng: jax.random.PRNGKey, x: jax.Array, model: nn.Module) -> TrainState:
    init_rng, rng = jax.random.split(rng)
    variables = model.init({"params": init_rng}, x, train=False)
    params = variables.pop("params")
    state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
        rng=rng,
    )
    return state

In [13]:
device_array = np.array(jax.devices())
mesh = Mesh(device_array, (config.data_axis_name,))

In [14]:
init_dp_fn = jax.jit(
    shard_map(
        functools.partial(init_dp, model=model_dp),
        mesh,
        in_specs=(P(), P(config.data_axis_name)),
        out_specs=P(),
        check_rep=False,
    ),
)

In [16]:
state_dp = init_dp_fn(model_init_rng, batch.inputs)
print("DP Parameters")
pprint(jax.tree.map(lambda x: (x.shape, x.sharding), state_dp.params))

DP Parameters
{'input_dense': {'bias': ((512,),
                          NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host)),
                 'kernel': ((784, 512),
                            NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host))},
 'output_dense': {'bias': ((10,),
                           NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host)),
                  'kernel': ((512, 10),
                             NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host))}}


In [17]:
def loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[jax.Array, Dict[str, Any]]:
    # Since dropout masks vary across the batch dimension, we want each device to generate a
    # different mask. We can achieve this by folding the rng over the data axis, so that each
    # device gets a different rng and thus mask.
    dropout_rng = fold_rng_over_axis(rng, config.data_axis_name)
    # Remaining computation is the same as before for single device.
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": dropout_rng})
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    correct_pred = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    batch_size = batch.inputs.shape[0]
    step_metrics = {"loss": (loss.sum(), batch_size), "accuracy": (correct_pred.sum(), batch_size)}
    loss = loss.mean()
    return loss, step_metrics

In [78]:
def train_step_dp(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    rng, step_rng = jax.random.split(state.rng)
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=loss_fn,
    )
    print(grads)
    # jax.debug.visualize_array_sharding(grads['output_dense']['kernel'])
    # Update parameters. We need to sync the gradients across devices before updating.
    with jax.named_scope("sync_gradients"):
        grads = jax.tree.map(lambda g: jax.lax.pmean(g, axis_name=config.data_axis_name), grads)
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Sum metrics across replicas. Alternatively, we could keep the metrics separate
    # and only synchronize them before logging. For simplicity, we sum them here.
    with jax.named_scope("sync_metrics"):
        step_metrics = jax.tree_map(
            lambda x: jax.lax.psum(x, axis_name=config.data_axis_name), step_metrics
        )
    if metrics is None:
        metrics = step_metrics
    else:
        metrics = jax.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

In [79]:
train_step_dp_fn = jax.jit(
    shard_map(
        train_step_dp,
        mesh,
        in_specs=(P(), P(), P(config.data_axis_name)),
        out_specs=(P(), P()),
        check_rep=False,
    ),
    donate_argnames=("state", "metrics"),
)

In [58]:
_, metric_shapes = jax.eval_shape(
    train_step_dp_fn,
    state_dp,
    None,
    batch,
)
metrics_dp = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)

In [80]:
for _ in range(5):
    state_dp, metrics_dp = train_step_dp_fn(state_dp, metrics_dp, batch)
final_metrics_dp = jax.tree.map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state_dp, final_metrics_dp = train_step_dp_fn(state_dp, final_metrics_dp, batch)
print_metrics(final_metrics_dp)

{'input_dense': {'bias': Traced<ShapedArray(float32[512])>with<DynamicJaxprTrace>, 'kernel': Traced<ShapedArray(float32[784,512])>with<DynamicJaxprTrace>}, 'output_dense': {'bias': Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace>, 'kernel': Traced<ShapedArray(float32[512,10])>with<DynamicJaxprTrace>}}


XlaRuntimeError: INTERNAL: Error calling inspect_sharding: Traceback (most recent call last):
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\runpy.py", line 196, in _run_module_as_main
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\runpy.py", line 86, in _run_code
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\ipykernel_launcher.py", line 18, in <module>
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\traitlets\config\application.py", line 1075, in launch_instance
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\ipykernel\kernelapp.py", line 739, in start
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\tornado\platform\asyncio.py", line 205, in start
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\asyncio\base_events.py", line 603, in run_forever
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\asyncio\base_events.py", line 1909, in _run_once
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\asyncio\events.py", line 80, in _run
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\ipykernel\kernelbase.py", line 545, in dispatch_queue
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\ipykernel\kernelbase.py", line 534, in process_one
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\ipykernel\kernelbase.py", line 437, in dispatch_shell
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\ipykernel\ipkernel.py", line 362, in execute_request
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\ipykernel\kernelbase.py", line 778, in execute_request
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\ipykernel\ipkernel.py", line 449, in do_execute
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\IPython\core\interactiveshell.py", line 3075, in run_cell
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\IPython\core\interactiveshell.py", line 3130, in _run_cell
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\IPython\core\async_helpers.py", line 128, in _pseudo_sync_runner
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\IPython\core\interactiveshell.py", line 3334, in run_cell_async
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\IPython\core\interactiveshell.py", line 3517, in run_ast_nodes
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\IPython\core\interactiveshell.py", line 3577, in run_code
  File "C:\Users\leyu0002\AppData\Local\Temp\ipykernel_27612\784880305.py", line 2, in <module>
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\pjit.py", line 337, in cache_miss
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\pjit.py", line 195, in _python_pjit_helper
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\pjit.py", line 1672, in _pjit_call_impl_python
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\interpreters\pxla.py", line 2415, in compile
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\interpreters\pxla.py", line 2923, in from_hlo
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\interpreters\pxla.py", line 2729, in _cached_compilation
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\compiler.py", line 452, in compile_or_get_cached
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\compiler.py", line 653, in _compile_and_write_cache
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\profiler.py", line 333, in wrapper
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\compiler.py", line 303, in backend_compile
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\debugging.py", line 403, in _hlo_sharding_callback
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\debugging.py", line 644, in _visualize
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\debugging.py", line 514, in visualize_sharding
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\sharding.py", line 167, in devices_indices_map
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\util.py", line 302, in wrapper
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\util.py", line 296, in cached
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\sharding.py", line 48, in common_devices_indices_map
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\sharding.py", line 184, in shard_shape
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\util.py", line 302, in wrapper
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\util.py", line 296, in cached
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\sharding.py", line 57, in _common_shard_shape
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\sharding_impls.py", line 830, in _to_xla_hlo_sharding
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\util.py", line 302, in wrapper
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\util.py", line 296, in cached
  File "C:\Users\leyu0002\Anaconda3\envs\JAX_FLAX\lib\site-packages\jax\_src\sharding_impls.py", line 702, in _positional_sharding_to_xla_hlo_sharding
ValueError: not enough values to unpack (expected 1, got 0)

In [28]:
print("DP Parameters")
pprint(jax.tree_map(lambda x: (x.shape, x.sharding), state_dp.params))
print("Metrics")
pprint(jax.tree_map(lambda x: (x.shape, x.sharding), final_metrics_dp))

DP Parameters
{'input_dense': {'bias': ((512,),
                          NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host)),
                 'kernel': ((784, 512),
                            NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host))},
 'output_dense': {'bias': ((10,),
                           NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host)),
                  'kernel': ((512, 10),
                             NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host))}}
Metrics
{'accuracy': (((),
               NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host)),
              ((),
               NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host))),
 'loss': (((),
           NamedSharding(mesh=Mesh('data': 8), spec=PartitionSpec(), memory_kind=unpinned_host)),
          ((),

In [55]:
jax.debug.visualize_array_sharding(batch.inputs)

In [45]:
jax.debug.visualize_array_sharding(state_dp.params['input_dense']['bias'])

In [47]:
jax.debug.visualize_array_sharding(state_dp.params['input_dense']['kernel'])