# Multihost TPU Jupyter Notebook Demo
This notebook shows how to run a Jupyter notebook on multihost and multislice TPUs.

## Connect to `ipyparallel` cluster as a client

Set `code_dir` to the path from which you ran `ipp_tool.py`. The client expects to find `ipcontroller-client.json` in the `code_dir/ipython/security/` folder.

In [1]:
import ipyparallel as ipp
import os
code_dir = '/home/yejingxin/src/'
rc = ipp.Client(connection_info=os.path.join(code_dir, 'ipython/security/ipcontroller-client.json'))
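
If the connection file is missing, the client cannot reach the controller. The following is a minimal sketch of a pre-flight check that builds the expected path from `code_dir` and fails early with a readable message. The helper names `client_connection_file` and `check_connection_file` are hypothetical and not part of `ipyparallel` or `ipp_tool.py`.

```python
import os


def client_connection_file(code_dir):
    """Return the expected path of ipcontroller-client.json under code_dir."""
    return os.path.join(code_dir, "ipython", "security", "ipcontroller-client.json")


def check_connection_file(code_dir):
    """Return the connection file path, or raise if the file does not exist."""
    path = client_connection_file(code_dir)
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"Expected {path}; check that code_dir matches the path used to run ipp_tool.py"
        )
    return path
```

The returned path could then be passed as `connection_info` when creating the client, as in the cell above.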

## Start each cell with the cell magic `%%px --block --group-outputs=engine`
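This cell magic sends the cell's code block to every TPU host. `--block` waits for all hosts to finish before returning, and `--group-outputs=engine` groups the printed output by engine (one engine per host).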

In [2]:
%%px --block --group-outputs=engine
import jax
print(jax.device_count())

[stderr:0] E0330 16:44:16.203816036   26621 credentials_generic.cc:35]            Could not get HOME environment variable.


[stderr:1] E0330 16:44:16.225357127   26637 credentials_generic.cc:35]            Could not get HOME environment variable.



[stdout:0] 8



[stdout:1] 8




In [3]:
%%px --block --group-outputs=engine
from functools import partial

import numpy as np

import jax
jax.config.update("jax_array", True)
# This allows replicated jax.Arrays to be used for computation on the host.
jax.config.update("jax_spmd_mode", "allow_all")
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('i', 'j'))

a = jnp.arange( 8 * 16.).reshape(8, 16)
b = jnp.arange(16 * 32.).reshape(16, 32)

@partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P('j', None)),
         out_specs=P('i', None))
def matmul_basic(a_block, b_block):
  z_partialsum = jnp.dot(a_block, b_block)
  z_block = jax.lax.psum(z_partialsum, 'j')
  return z_block

c = matmul_basic(a, b)  # c: f32[8, 32]
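
To make the sharding in `matmul_basic` concrete, here is a single-host NumPy sketch of the same arithmetic. It is an illustration under stated assumptions, not how `shard_map` is implemented: the function `simulate_sharded_matmul` is hypothetical, the `(4, 2)` mesh is emulated with plain array splits, and the `psum` over `'j'` is emulated with an ordinary sum.

```python
import numpy as np


def simulate_sharded_matmul(a, b, mesh_shape=(4, 2)):
    """Emulate matmul_basic on one host: per-block matmul plus a sum over 'j'."""
    n_i, n_j = mesh_shape
    a_rows = np.split(a, n_i, axis=0)  # in_specs P('i', ...): split rows of a over 'i'
    b_rows = np.split(b, n_j, axis=0)  # in_specs P('j', None): split rows of b over 'j'
    out_blocks = []
    for a_row in a_rows:
        # in_specs P(..., 'j'): split the columns of a over 'j'
        a_blocks = np.split(a_row, n_j, axis=1)
        # Each (i, j) position computes a partial product; summing over j
        # plays the role of jax.lax.psum(z_partialsum, 'j').
        z_block = sum(np.dot(a_blk, b_blk) for a_blk, b_blk in zip(a_blocks, b_rows))
        out_blocks.append(z_block)
    # out_specs P('i', None): concatenate the row blocks along axis 0
    return np.concatenate(out_blocks, axis=0)


a_np = np.arange(8 * 16.).reshape(8, 16)
b_np = np.arange(16 * 32.).reshape(16, 32)
c_ref = simulate_sharded_matmul(a_np, b_np)

# The blockwise result matches an ordinary matrix product.
assert c_ref.shape == (8, 32)
assert np.allclose(c_ref, a_np @ b_np)
```

With the `(4, 2)` mesh, each device holds a `2 x 8` block of `a` and an `8 x 32` block of `b`, so every partial product has shape `2 x 32` before the sum over `'j'`.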


In [4]:
%%px --block --group-outputs=engine
jax.debug.visualize_array_sharding(c)

[output:0]

[output:1]

In [6]:
%%px --block --group-outputs=engine
from jax.experimental import multihost_utils
print(multihost_utils.process_allgather(c).shape)

[stdout:1] (8, 32)


[stdout:0] (8, 32)
