In [1]:
import torch
import torch.nn as nn
from torch import Tensor

In [2]:
torch.manual_seed(11)

batch = 2
sqlen = 6
hidden_dim = 16
topk = 2
experts = 8

x = torch.randn(batch, sqlen, hidden_dim)
x_flat = x.view(-1, hidden_dim) # [12, 16]

def renormalization(input: Tensor) -> Tensor:
        total = input.sum(
            dim=-1, keepdim=True
        )  # total sum of experts' raw scores, not token count, hence -1 dim
        renormalized = input / total  # [batch*seqlen, topk]
        return renormalized

weights = renormalization(torch.rand(batch*sqlen, topk))
indices = torch.randint(0, 7, [batch*sqlen, topk])

weights, indices

(tensor([[0.5858, 0.4142],
         [0.4547, 0.5453],
         [0.7196, 0.2804],
         [0.7319, 0.2681],
         [0.5072, 0.4928],
         [0.5598, 0.4402],
         [0.2730, 0.7270],
         [0.4512, 0.5488],
         [0.1884, 0.8116],
         [0.3390, 0.6610],
         [0.0505, 0.9495],
         [0.8895, 0.1105]]),
 tensor([[3, 0],
         [1, 6],
         [4, 0],
         [2, 1],
         [1, 0],
         [3, 3],
         [4, 5],
         [3, 6],
         [6, 5],
         [4, 0],
         [0, 6],
         [4, 3]]))

In [3]:
# total weight being received
expert_load_0 = torch.where(indices==0)
token_idx, topk_idx = expert_load_0
token_idx, topk_idx

(tensor([ 0,  2,  4,  9, 10]), tensor([1, 1, 1, 1, 0]))

In [4]:
expert_0_weights_summed = sum([(weights[i][j]) for i, j in zip(token_idx, topk_idx)])
expert_0_weights_summed

tensor(1.8990)

In [5]:
def compute_load_balance_loss(weights, indices):
    total_tokens, topk_experts = weights.shape
    results = []
    for i in range(experts):
        token_idx, topk_idx = torch.where(indices==i)
        expert_i_weights_summed = weights[token_idx, topk_idx].sum()
        results.append(expert_i_weights_summed)
    print(f"Results:\n{results}")
    expert_loads = torch.stack(results)
    print(f"results with torch.stack:\n{expert_loads}")
    expert_load_fractions = expert_loads / total_tokens
    load_variance = torch.var(expert_load_fractions)
    return load_variance

In [6]:
a = compute_load_balance_loss(weights, indices)
a

Results:
[tensor(1.8990), tensor(1.2300), tensor(0.7319), tensor(2.1475), tensor(2.2211), tensor(1.5386), tensor(2.2319), tensor(0.)]
results with torch.stack:
tensor([1.8990, 1.2300, 0.7319, 2.1475, 2.2211, 1.5386, 2.2319, 0.0000])


tensor(0.0045)

In [7]:
b = torch.rand(2, 6, 16)
b.shape

torch.Size([2, 6, 16])

In [8]:
b.reshape(-1, 16).shape

torch.Size([12, 16])

In [9]:
W_router = nn.Linear(4096, 8, bias=False)
W_router.weight.shape, W_router

(torch.Size([8, 4096]), Linear(in_features=4096, out_features=8, bias=False))

In [10]:
torch.manual_seed(10)

x_in = torch.randn(2, 6, 4096)
x_flattened = x_in.reshape(-1, 4096)
x_flattened.shape

torch.Size([12, 4096])

In [11]:
router_matrix = W_router(x_flattened)
router_matrix.shape, router_matrix

(torch.Size([12, 8]),
 tensor([[ 1.4361,  0.5818, -0.3305, -1.5377,  0.5158, -0.7169, -0.0205, -0.4584],
         [ 0.0093,  0.6592, -0.3625,  1.0004,  0.6719, -0.3256,  0.1337, -0.0865],
         [-0.8257, -0.4223,  1.2087,  0.3993, -1.1458, -0.2525,  0.3592, -1.1005],
         [ 0.5657,  0.8405, -0.6680, -0.0194, -0.4599,  0.6501, -0.3727, -0.1869],
         [-0.0244,  0.6992,  0.5910, -1.7488,  0.0191,  0.6949,  0.4897,  0.5927],
         [ 0.4651,  0.7655,  0.1293,  0.8185, -0.0651, -0.2502,  0.7604, -0.9779],
         [ 0.3417,  0.2140,  0.4174, -0.3592,  0.5346, -0.2904, -0.8432, -0.0372],
         [ 0.5197, -0.8920,  0.0574,  0.4456,  0.2165, -0.2069,  0.1237, -0.3281],
         [ 0.2653, -0.2304,  0.1694,  0.0213, -0.4691,  0.0512,  0.1946,  0.4519],
         [-0.2691, -0.6196, -0.2924, -0.2852, -0.2254, -0.4779, -0.2580, -1.1385],
         [-0.6962, -0.7461,  0.1406, -0.0272, -0.2004,  0.5228, -0.1185,  1.0816],
         [ 0.0731,  1.0642,  0.7293, -0.1455,  0.3779,  1.0956,  

In [12]:
router_matrix[3].shape, router_matrix[3]

(torch.Size([8]),
 tensor([ 0.5657,  0.8405, -0.6680, -0.0194, -0.4599,  0.6501, -0.3727, -0.1869],
        grad_fn=<SelectBackward0>))

In [13]:
W_router.weight[3, :].shape

torch.Size([4096])

In [14]:
W_router.weight.T.shape, W_router.weight.T

(torch.Size([4096, 8]),
 tensor([[-0.0043,  0.0086,  0.0075,  ..., -0.0029, -0.0094,  0.0082],
         [-0.0087, -0.0008,  0.0036,  ..., -0.0107,  0.0113,  0.0142],
         [ 0.0052,  0.0141, -0.0129,  ...,  0.0122,  0.0016,  0.0136],
         ...,
         [-0.0141, -0.0149,  0.0123,  ..., -0.0124, -0.0023, -0.0058],
         [-0.0099,  0.0108, -0.0073,  ...,  0.0148,  0.0040,  0.0107],
         [ 0.0009,  0.0099,  0.0082,  ...,  0.0104,  0.0116, -0.0101]],
        grad_fn=<PermuteBackward0>))

In [15]:
torch.matmul(x_flattened, W_router.weight.T)

tensor([[ 1.4361,  0.5818, -0.3305, -1.5377,  0.5158, -0.7169, -0.0205, -0.4584],
        [ 0.0093,  0.6592, -0.3625,  1.0004,  0.6719, -0.3256,  0.1337, -0.0865],
        [-0.8257, -0.4223,  1.2087,  0.3993, -1.1458, -0.2525,  0.3592, -1.1005],
        [ 0.5657,  0.8405, -0.6680, -0.0194, -0.4599,  0.6501, -0.3727, -0.1869],
        [-0.0244,  0.6992,  0.5910, -1.7488,  0.0191,  0.6949,  0.4897,  0.5927],
        [ 0.4651,  0.7655,  0.1293,  0.8185, -0.0651, -0.2502,  0.7604, -0.9779],
        [ 0.3417,  0.2140,  0.4174, -0.3592,  0.5346, -0.2904, -0.8432, -0.0372],
        [ 0.5197, -0.8920,  0.0574,  0.4456,  0.2165, -0.2069,  0.1237, -0.3281],
        [ 0.2653, -0.2304,  0.1694,  0.0213, -0.4691,  0.0512,  0.1946,  0.4519],
        [-0.2691, -0.6196, -0.2924, -0.2852, -0.2254, -0.4779, -0.2580, -1.1385],
        [-0.6962, -0.7461,  0.1406, -0.0272, -0.2004,  0.5228, -0.1185,  1.0816],
        [ 0.0731,  1.0642,  0.7293, -0.1455,  0.3779,  1.0956,  0.1346,  0.3392]],
       grad_fn=

In [None]:
xxx = torch.randn(12, 2)

In [21]:
xxx - xxx.detach()

tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])

In [26]:
class Router(nn.Module):
    def __init__(self, hidden_dim: int, n_ffn_experts: int, topk: int, noise_std: float = 0.1):
        super().__init__()
        self.W_router = nn.Linear(hidden_dim, n_ffn_experts, bias=False)
        self.register_buffer("training_step", torch.tensor(0, dtype=torch.long))
        self.register_buffer("anneal_steps", torch.tensor(0, dtype=torch.long))

    def set_noise_annealing(self, total_steps: int, anneal_fraction: float = 0.25):
        self.anneal_steps.fill_(int(total_steps * anneal_fraction))



router = Router(hidden_dim=4096, n_ffn_experts=8, topk=2)
router.set_noise_annealing(total_steps=10000, anneal_fraction=0.25)
print(f"Training step: {router.training_step}")
print(f"Anneal steps: {router.anneal_steps}")

Training step: 0
Anneal steps: 2500
