# Naive hybrid device mesh from JAX

In [17]:
import dataclasses
import collections
import numpy as np
from typing import Sequence, Any
from jax.experimental import mesh_utils
import jax



@dataclasses.dataclass
class Device:
  """Lightweight fake device used to build and inspect meshes in this notebook.

  Instances are created from plain integers, so device layouts can be examined
  without constructing real accelerator devices.
  """
  process_index: int  # granule key when a mesh builder is called with `process_is_granule=True`
  slice_index: int  # granule key used by default when grouping devices into granules
  uid: int  # unique id; the cells below display meshes as arrays of these ids
  device_kind: str = ''  # NOTE(review): presumably read by jax's mesh_utils to pick a layout handler - confirm
  platform: str = 'cpu'  # NOTE(review): presumably keeps jax's mesh_utils off any TPU-specific path - confirm

# Fake topology: 512 devices, uids 0-255 in slice/process 0 and 256-511 in slice/process 1.
fake_devices = [Device(process_index=i // 256, slice_index=i // 256, uid=i) for i in range(512)]

# Stock JAX hybrid mesh: a 64x4 mesh per slice, with the 2 slices stacked along the first axis.
naive_mesh = mesh_utils.create_hybrid_device_mesh([64, 4], [2, 1], fake_devices)
# Keep only the uids so the resulting layout is easy to read.
devices = np.array([[dev.uid for dev in row] for row in naive_mesh])

print(devices.shape)

devices


(128, 4)


array([[  0,   1,   2,   3],
       [  4,   5,   6,   7],
       [  8,   9,  10,  11],
       [ 12,  13,  14,  15],
       [ 16,  17,  18,  19],
       [ 20,  21,  22,  23],
       [ 24,  25,  26,  27],
       [ 28,  29,  30,  31],
       [ 32,  33,  34,  35],
       [ 36,  37,  38,  39],
       [ 40,  41,  42,  43],
       [ 44,  45,  46,  47],
       [ 48,  49,  50,  51],
       [ 52,  53,  54,  55],
       [ 56,  57,  58,  59],
       [ 60,  61,  62,  63],
       [ 64,  65,  66,  67],
       [ 68,  69,  70,  71],
       [ 72,  73,  74,  75],
       [ 76,  77,  78,  79],
       [ 80,  81,  82,  83],
       [ 84,  85,  86,  87],
       [ 88,  89,  90,  91],
       [ 92,  93,  94,  95],
       [ 96,  97,  98,  99],
       [100, 101, 102, 103],
       [104, 105, 106, 107],
       [108, 109, 110, 111],
       [112, 113, 114, 115],
       [116, 117, 118, 119],
       [120, 121, 122, 123],
       [124, 125, 126, 127],
       [128, 129, 130, 131],
       [132, 133, 134, 135],
       [136, 1

# Custom hybrid mesh where each 2x2 group forms a ring

In [22]:
import collections
import numpy as np
from typing import Sequence, Any
from jax.experimental import mesh_utils
import jax

def create_custom_64x4_device_mesh(
    mesh_shape: Sequence[int],
    dcn_mesh_shape: Sequence[int],
    devices: Sequence[Any],
    process_is_granule: bool = False,
    should_sort_granules_by_key: bool = True,
) -> np.ndarray:
  """Custom device mesh for 64x4 ici parallelism.

  Devices are grouped into granules (by `slice_index`, or by `process_index`
  when `process_is_granule` is set). Each granule must hold exactly 256
  devices: it is laid out as a 16x16 mesh by `mesh_utils.create_device_mesh`
  and then cut into 64 non-overlapping 2x2 blocks. Every block becomes one row
  of 4 devices listed in ring order (top-left, top-right, bottom-right,
  bottom-left), so consecutive entries of a row - including the wrap-around
  from last to first - are neighbours in the 16x16 mesh. The per-granule
  meshes are finally tiled with `np.block` according to `dcn_mesh_shape`.

  Args:
    mesh_shape: shape that each per-granule (64, 4) ring mesh is reshaped to;
      it must therefore describe 256 elements in total.
    dcn_mesh_shape: how the granules are tiled; its product must equal the
      number of granules found in `devices`.
    devices: the devices to arrange. Each must expose the granule attribute
      (`slice_index` or `process_index`).
    process_is_granule: group devices by `process_index` instead of
      `slice_index`.
    should_sort_granules_by_key: if True, granules are ordered by their sorted
      key; otherwise they keep the order in which their keys first appear in
      `devices`.

  Returns:
    An array of devices obtained by tiling the per-granule meshes.

  Raises:
    ValueError: if `devices` is empty, if the first device lacks the granule
      attribute, if the number of granules does not match `dcn_mesh_shape`, or
      if any granule does not contain exactly 256 devices.
    AssertionError: if `len(devices)` is not a multiple of 256.
  """
  granule_side = 16  # each granule is first arranged as a 16x16 mesh
  devices_per_granule = granule_side * granule_side  # 256

  # An empty input passes the `% 256` check below, so reject it explicitly
  # instead of failing later with an IndexError on `devices[0]`.
  if len(devices) == 0:
    raise ValueError("This custom mesh requires at least one granule of 256 devices, but no devices were given")
  assert len(devices) % 256 == 0, f"This custom mesh is not valid for {len(devices)} devices"
  attr = "process_index" if process_is_granule else "slice_index"
  if not hasattr(devices[0], attr):
    raise ValueError(f"Device {devices[0]} does not have attribute {attr}. See" " `process_is_granule` option.")
  granule_dict = collections.defaultdict(list)
  for dev in devices:
    granule_dict[getattr(dev, attr)].append(dev)
  # Unsorted order is the dict's insertion order, i.e. first appearance in `devices`.
  granule_keys = sorted(granule_dict.keys()) if should_sort_granules_by_key else list(granule_dict.keys())
  granules = [granule_dict[key] for key in granule_keys]
  if np.prod(dcn_mesh_shape) != len(granules):
    raise ValueError(f"Number of slices {len(granules)} must equal the product of " f"dcn_mesh_shape {dcn_mesh_shape}")
  # The total being a multiple of 256 does not guarantee that every granule
  # has 256 devices, which is what the 16x16 layout below actually needs.
  for key, granule in zip(granule_keys, granules):
    if len(granule) != devices_per_granule:
      raise ValueError(
          f"This custom mesh requires exactly {devices_per_granule} devices per granule, "
          f"but granule {key} ({attr}) has {len(granule)}"
      )
  per_granule_meshes = [
      mesh_utils.create_device_mesh(
          [granule_side, granule_side],
          granule,
          allow_split_physical_axes=False,
      )
      for granule in granules
  ]

  def reshape_mesh_to_rings(a):
    """Turns a 16x16 mesh into 64 rows, one per 2x2 block, in ring order."""
    half = granule_side // 2
    # Axes after the reshape are (block_row, row_in_block, block_col, col_in_block);
    # the transpose groups the two in-block axes last so each 2x2 block
    # flattens into 4 consecutive entries: TL, TR, BL, BR.
    blocks = np.reshape(a, (half, 2, half, 2)).transpose(0, 2, 1, 3).reshape(half * half, 4)
    # forms a ring of size 4: TL, TR, BR, BL
    return blocks[:, [0, 1, 3, 2]]

  per_granule_meshes = [np.reshape(reshape_mesh_to_rings(x), mesh_shape) for x in per_granule_meshes]
  # TODO(jekbradbury): handle non-uniform DCN topologies
  granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape)
  blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])(granule_mesh)
  device_mesh = np.block(blocks.tolist())
  return device_mesh


# Same fake topology as for the naive mesh: 512 devices, 256 per slice/process.
fake_devices = [Device(process_index=i // 256, slice_index=i // 256, uid=i) for i in range(512)]

# Custom mesh: every row of 4 is a 2x2 block of the per-slice 16x16 mesh, in ring order.
ring_mesh = create_custom_64x4_device_mesh([64, 4], [2, 1], fake_devices)
# Keep only the uids so the layout can be compared with the naive mesh.
devices = np.array([[dev.uid for dev in row] for row in ring_mesh])

print(devices.shape)

devices


(128, 4)


array([[  0,   1,  17,  16],
       [  2,   3,  19,  18],
       [  4,   5,  21,  20],
       [  6,   7,  23,  22],
       [  8,   9,  25,  24],
       [ 10,  11,  27,  26],
       [ 12,  13,  29,  28],
       [ 14,  15,  31,  30],
       [ 32,  33,  49,  48],
       [ 34,  35,  51,  50],
       [ 36,  37,  53,  52],
       [ 38,  39,  55,  54],
       [ 40,  41,  57,  56],
       [ 42,  43,  59,  58],
       [ 44,  45,  61,  60],
       [ 46,  47,  63,  62],
       [ 64,  65,  81,  80],
       [ 66,  67,  83,  82],
       [ 68,  69,  85,  84],
       [ 70,  71,  87,  86],
       [ 72,  73,  89,  88],
       [ 74,  75,  91,  90],
       [ 76,  77,  93,  92],
       [ 78,  79,  95,  94],
       [ 96,  97, 113, 112],
       [ 98,  99, 115, 114],
       [100, 101, 117, 116],
       [102, 103, 119, 118],
       [104, 105, 121, 120],
       [106, 107, 123, 122],
       [108, 109, 125, 124],
       [110, 111, 127, 126],
       [128, 129, 145, 144],
       [130, 131, 147, 146],
       [132, 1

In [19]:
# Raw fake-device list: 512 devices, 256 per slice/process.
devices = [Device(process_index=i // 256, slice_index=i // 256, uid=i) for i in range(512)]

# Flatten the ring-ordered mesh row by row into a 1-D list of uids.
ring_mesh = create_custom_64x4_device_mesh([64, 4], [2, 1], devices)
flat_devices = np.array([dev.uid for dev in ring_mesh.ravel()])
flat_devices

array([  0,   1,  17,  16,   2,   3,  19,  18,   4,   5,  21,  20,   6,
         7,  23,  22,   8,   9,  25,  24,  10,  11,  27,  26,  12,  13,
        29,  28,  14,  15,  31,  30,  32,  33,  49,  48,  34,  35,  51,
        50,  36,  37,  53,  52,  38,  39,  55,  54,  40,  41,  57,  56,
        42,  43,  59,  58,  44,  45,  61,  60,  46,  47,  63,  62,  64,
        65,  81,  80,  66,  67,  83,  82,  68,  69,  85,  84,  70,  71,
        87,  86,  72,  73,  89,  88,  74,  75,  91,  90,  76,  77,  93,
        92,  78,  79,  95,  94,  96,  97, 113, 112,  98,  99, 115, 114,
       100, 101, 117, 116, 102, 103, 119, 118, 104, 105, 121, 120, 106,
       107, 123, 122, 108, 109, 125, 124, 110, 111, 127, 126, 128, 129,
       145, 144, 130, 131, 147, 146, 132, 133, 149, 148, 134, 135, 151,
       150, 136, 137, 153, 152, 138, 139, 155, 154, 140, 141, 157, 156,
       142, 143, 159, 158, 160, 161, 177, 176, 162, 163, 179, 178, 164,
       165, 181, 180, 166, 167, 183, 182, 168, 169, 185, 184, 17

# Naive single pod mesh

In [23]:
import torch_xla.distributed.spmd as xs

# Single-pod layout: 256 devices arranged as a 64-way 'fsdp' axis by a 4-way 'tensor' axis.
fsdp_axis, tensor_axis = 64, 4
num_devices = 256
mesh_shape = (fsdp_axis, tensor_axis)
# Device ids are passed in plain ascending order (no ring reordering).
naive_device_ids = np.arange(num_devices)
spmd_mesh = xs.Mesh(naive_device_ids, mesh_shape, ('fsdp', 'tensor'))
spmd_mesh.device_ids

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
       156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
       169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 18

In [21]:
# NOTE(review): the result depends on whichever cell last bound `devices`. The recorded
# 512 matches the raw list of 512 fake devices; other cells rebind `devices` to a
# (128, 4) uid array, for which this would be 128 - confirm with Restart & Run All.
len(devices)

512