In [3]:
import dataclasses
import collections
import numpy as np
from typing import Sequence, Any
from jax.experimental import mesh_utils
import jax

def create_custom_64x2_device_mesh(
    mesh_shape: Sequence[int],
    dcn_mesh_shape: Sequence[int],
    devices: Sequence[Any],
    process_is_granule: bool = False,
    should_sort_granules_by_key: bool = True,
) -> np.ndarray:
  """Custom device mesh for 64x2 ici parallelism.

  Devices are grouped into granules by `slice_index` (or by `process_index`
  when `process_is_granule` is set). Each granule is laid out as an 8x16 mesh
  with `mesh_utils.create_device_mesh`; every 2x2 tile of that mesh is then
  flattened into four consecutive entries, the result is reshaped to (64, 2)
  and finally to `mesh_shape`. The per-granule meshes are tiled according to
  `dcn_mesh_shape` with `np.block`.

  Args:
    mesh_shape: Shape each per-granule mesh is reshaped to; it must be
      compatible with the 128 devices of one granule.
    dcn_mesh_shape: Arrangement of the granules; its product must equal the
      number of granules found in `devices`.
    devices: Devices to arrange. Each must expose the granule attribute.
    process_is_granule: Group by `process_index` instead of `slice_index`.
    should_sort_granules_by_key: Order granules by ascending granule key
      instead of by first appearance in `devices`.

  Returns:
    An array of devices combining all per-granule meshes.

  Raises:
    AssertionError: If the device count is not a multiple of 128.
    ValueError: If the first device lacks the granule attribute, if the number
      of granules differs from the product of `dcn_mesh_shape`, or if a
      granule does not hold exactly 128 devices.
  """
  granule_shape = [8, 16]
  devices_per_granule = int(np.prod(granule_shape))
  # Each granule is an 8x16 (= 128 device) mesh, so the total must be a
  # multiple of 128 rather than 256.
  assert len(devices) % devices_per_granule == 0, f"This custom mesh is not valid for {len(devices)} devices"
  attr = "process_index" if process_is_granule else "slice_index"
  if not hasattr(devices[0], attr):
    raise ValueError(f"Device {devices[0]} does not have attribute {attr}. See" " `process_is_granule` option.")
  granule_dict = collections.defaultdict(list)
  for dev in devices:
    granule_dict[getattr(dev, attr)].append(dev)
  if should_sort_granules_by_key:
    granules = [granule_dict[key] for key in sorted(granule_dict.keys())]
  else:
    # Materialise the view so `granules` is a list in both branches.
    granules = list(granule_dict.values())
  if np.prod(dcn_mesh_shape) != len(granules):
    raise ValueError(f"Number of slices {len(granules)} must equal the product of " f"dcn_mesh_shape {dcn_mesh_shape}")
  # Fail early with a clear message instead of deep inside create_device_mesh.
  for granule in granules:
    if len(granule) != devices_per_granule:
      raise ValueError(
          f"Each granule must contain exactly {devices_per_granule} devices, but one has {len(granule)}"
      )
  per_granule_meshes = [
      mesh_utils.create_device_mesh(
          granule_shape,
          granule,
          allow_split_physical_axes=False,
      )
      for granule in granules
  ]

  def reshape_mesh_to_rings(a):
    """Flattens every 2x2 tile of the 8x16 mesh `a` into 4 consecutive entries and returns a (64, 2) array."""
    b = []
    for i in range(4):
      b.append([])
      for j in range(8):
        a_i = i * 2
        a_j = j * 2
        # forms a ring of size 4: walk the 2x2 tile top-left, top-right,
        # bottom-right, bottom-left
        b[i].append([a[a_i, a_j], a[a_i, a_j + 1], a[a_i + 1, a_j + 1], a[a_i + 1, a_j]])
    b = np.array(b)
    b = np.reshape(b, (64, 2))
    return b

  per_granule_meshes = [np.reshape(reshape_mesh_to_rings(x), mesh_shape) for x in per_granule_meshes]
  # TODO(jekbradbury): handle non-uniform DCN topologies
  granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
  # Replace each granule index with its mesh, then stitch the blocks together.
  blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
  device_mesh = np.block(blocks.tolist())
  return device_mesh


@dataclasses.dataclass
class Device:
  """Minimal device record used in place of a real device when building meshes in this notebook.

  `process_index` and `slice_index` are the granule keys, and `uid` identifies
  the device in the resulting mesh. `device_kind` and `platform` are presumably
  consulted by `mesh_utils.create_device_mesh` -- TODO confirm.
  """

  process_index: int
  slice_index: int
  uid: int
  device_kind: str = dataclasses.field(default="")
  platform: str = dataclasses.field(default="cpu")


def get_hybrid_mesh(ici_mesh_shape: Sequence[int], dcn_mesh_shape: Sequence[int], num_devices: int, num_slices: int) -> np.ndarray:
  """Builds the custom 64x2 mesh from fake devices and returns it as a flat array of device uids.

  Args:
    ici_mesh_shape: Per-slice mesh shape forwarded to `create_custom_64x2_device_mesh`.
    dcn_mesh_shape: Slice arrangement forwarded to `create_custom_64x2_device_mesh`.
    num_devices: Total number of fake devices to create; uids run from 0 to `num_devices - 1`.
    num_slices: Number of slices the devices are split into by integer division.

  Returns:
    1-D array holding the `uid` of every device in flattened mesh order.
  """
  num_devices_per_granule = num_devices // num_slices
  # Consecutive uids share a slice; process_index and slice_index are set to the same value.
  fake_devices = [Device(i // num_devices_per_granule, i // num_devices_per_granule, i) for i in range(num_devices)]
  device_mesh = create_custom_64x2_device_mesh(ici_mesh_shape, dcn_mesh_shape, fake_devices)
  # A plain comprehension replaces the deprecated `jax.tree_map` over a flat list.
  return np.array([device.uid for device in device_mesh.reshape(-1)])


In [4]:
# Show the flattened uid order for 256 fake devices split into 2 slices,
# with ICI mesh shape (1, 64, 2) and DCN mesh shape (2, 1, 1).
get_hybrid_mesh(ici_mesh_shape=(1, 64, 2), dcn_mesh_shape=(2, 1, 1), num_devices=256, num_slices=2)

  devices = np.array(jax.tree_map(lambda d: d.uid, devices))


array([  0,   1,  17,  16,   2,   3,  19,  18,   4,   5,  21,  20,   6,
         7,  23,  22,   8,   9,  25,  24,  10,  11,  27,  26,  12,  13,
        29,  28,  14,  15,  31,  30,  32,  33,  49,  48,  34,  35,  51,
        50,  36,  37,  53,  52,  38,  39,  55,  54,  40,  41,  57,  56,
        42,  43,  59,  58,  44,  45,  61,  60,  46,  47,  63,  62,  64,
        65,  81,  80,  66,  67,  83,  82,  68,  69,  85,  84,  70,  71,
        87,  86,  72,  73,  89,  88,  74,  75,  91,  90,  76,  77,  93,
        92,  78,  79,  95,  94,  96,  97, 113, 112,  98,  99, 115, 114,
       100, 101, 117, 116, 102, 103, 119, 118, 104, 105, 121, 120, 106,
       107, 123, 122, 108, 109, 125, 124, 110, 111, 127, 126, 128, 129,
       145, 144, 130, 131, 147, 146, 132, 133, 149, 148, 134, 135, 151,
       150, 136, 137, 153, 152, 138, 139, 155, 154, 140, 141, 157, 156,
       142, 143, 159, 158, 160, 161, 177, 176, 162, 163, 179, 178, 164,
       165, 181, 180, 166, 167, 183, 182, 168, 169, 185, 184, 17