In [1]:
import numpy as np

# Params

In [2]:
N = 7 # number of nodes
c = 2 # number of replicas (expert slots) hosted per node; N * c slots in total
E = 4 # number of experts
f = 2 # fault tolerance, min number of replicas per expert

In [3]:
# Each of the E experts needs at least f replicas, so the N * c slots must cover E * f.
assert E * f <= N * c, "Not enough replicas to tolerate f faults"

# Imaginary Stats

In [4]:
# number of tokens routed to each expert: 100, 200, ..., 100 * E
t = 100 * np.arange(1, E + 1)
t

array([100, 200, 300, 400])

# Expert Allocation

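$$ r_e = \max \left\{ \left\lfloor \frac{t_e}{\sum_{e'=e}^{E} t_{e'}} \cdot \left(N \cdot c - \sum_{e'=1}^{e-1} r_{e'} \right) \right\rfloor, f \right\} $$

Here $r_e$ is the number of replicas allocated to expert $e$, $t_e$ is the number of tokens routed to it, $N \cdot c$ is the total number of replica slots, and $f$ is the minimum number of replicas per expert. Experts are processed in order $e = 1, \dots, E$, each one receiving its token-weighted share of the slots not yet allocated.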

In [5]:
# number of allocated replicas per expert
#
# Experts are handled in index order. Expert e receives its token-weighted share
# of the slots that experts 0..e-1 left over, but never fewer than f replicas.
r = np.zeros(E)

for e in range(E):
    # share of expert e among the tokens of the experts not yet allocated
    weighted_t = t[e] / np.sum(t[e:])
    # replicas already handed out to experts 0..e-1
    total_r = np.sum(r[:e])
    r_count = np.floor(weighted_t * (N * c - total_r))
    # enforce the fault-tolerance minimum
    r[e] = max(r_count, f)

# The f minimum can push the total above N * c, e.g. when a heavily loaded
# expert is processed before lightly loaded ones (t = [1000, 1, 1, 1] yields
# 13 + 2 + 2 + 2 = 19 replicas for 14 slots). Fail here instead of letting the
# placement cells run with an allocation that cannot fit.
assert r.sum() <= N * c, "Allocated more replicas than available slots (N * c)"

r

array([2., 2., 4., 6.])

# Expert Placement
## Case 2: $E > c$

In [6]:
# number of expert partitions: ceil(E / c), computed with integer arithmetic
Ec = (E + c - 1) // c
Ec

2

In [7]:
# expert partitions: consecutive expert indices grouped c at a time
# (the final partition is shorter when E is not a multiple of c)
ep = [list(range(k * c, min((k + 1) * c, E))) for k in range(Ec)]
ep

[[0, 1], [2, 3]]

In [8]:
# node partitions

# Sum, over all expert partitions, of the replica count of each partition's
# last expert. The partition is indexed through `ep` rather than with
# `c * i + 1`, which only addresses the last expert when c == 2 and runs past
# the end of `r` when E is not a multiple of c.
# NOTE(review): presumably this is the number of nodes the groups would occupy,
# capped at N below; confirm that "last expert" is the intended pick for c > 2.
s = np.sum([r[part[-1]] for part in ep])
min(N, s)

7

In [9]:
# node groups
#
# Give each expert partition a contiguous run of nodes and pin one replica of
# every expert of that partition onto each node of the run.

# slots[node, slot] holds the expert index placed there; -1 marks an empty slot
slots = np.full((N, c), -1, dtype=np.int32)
last_n = 0          # index of the first node not yet assigned to a group
r_rem = r.copy()    # replicas still to be placed, per expert

for k in range(Ec):
    if k < Ec - 1:
        # group size = replica count of the partition's first expert
        n = r[ep[k][0]]
    else:
        # last group: replica count of the partition's last expert, capped by
        # the nodes left after reserving r[last expert] nodes for every earlier
        # partition. Partitions are indexed through `ep` so this also works
        # when c != 2 or when the last partition holds fewer than c experts.
        reserved = np.sum([r[part[-1]] for part in ep[:-1]])
        n = min(N - reserved, r[ep[k][-1]])
    n = int(n)

    print(f"Node group {k}: {n} nodes. Belongs to experts {ep[k]}")

    for j in range(last_n, last_n + n):
        for i, e in enumerate(ep[k]):
            r_rem[e] -= 1
            slots[j, i] = e

    last_n += n

# NOTE(review): every expert of a partition is placed on all n nodes of its
# group, so an expert whose allocation is smaller than n ends up with a
# negative r_rem (with r = [2, 2, 4, 6] expert 2 is placed on 5 nodes).
# Confirm whether this over-placement is intended.

Node group 0: 2 nodes. Belongs to experts [0, 1]
Node group 1: 5 nodes. Belongs to experts [2, 3]


In [10]:
# fill remaining slots
#
# Walk the slots in node order and give every still-empty slot (-1) to the
# lowest-indexed expert that has replicas left to place.

e_i = 0

for i in range(N):
    for j in range(c):
        if slots[i][j] == -1:
            # `<= 0` rather than `== 0`: the node-group pass can leave r_rem
            # negative for an over-placed expert, and such an expert must not
            # receive further replicas. The bound on e_i avoids indexing past
            # the end of r_rem.
            while e_i < E and r_rem[e_i] <= 0:
                e_i += 1
            if e_i == E:
                raise RuntimeError(
                    f"No expert has replicas left to fill slot ({i}, {j})"
                )
            slots[i][j] = e_i
            r_rem[e_i] -= 1

In [11]:
# final placement: one row per node, one column per slot, value = expert index
slots

array([[0, 1],
       [0, 1],
       [2, 3],
       [2, 3],
       [2, 3],
       [2, 3],
       [2, 3]], dtype=int32)