In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from layers.routing import ExpertsChooseRouter, ExpertsChooseMaskedRouter

In [2]:
# Seed the global torch RNG so that `torch.randn` here (and the random
# `nn.Linear` init later) give the same values on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# create a test tensor of shape (batch_size, seq_len, hidden_size)
x = torch.randn(1, 20, 128)
# number of token slots each expert fills (last axis of the router outputs)
capacity = 20
num_experts = 4

In [3]:
# Index-based expert-choice router. Take the hidden size from `x` rather than
# hard-coding it, so the router stays consistent if the test input changes.
router = ExpertsChooseRouter(num_experts=num_experts, dim=x.shape[-1])

# apply the router to the input tensor: for every (expert, slot) we get a gate
# value and the index of the token that slot holds
expert_gate, expert_indices = router(x, capacity)

print(f"expert_gate shape: {expert_gate.shape}, expert_indices shape: {expert_indices.shape}")

# both outputs are laid out as (batch, expert, capacity)
expected_routing_shape = (x.shape[0], num_experts, capacity)
assert expert_gate.shape == expected_routing_shape
assert expert_indices.shape == expected_routing_shape

expert_gate shape: torch.Size([1, 4, 20]), expert_indices shape: torch.Size([1, 4, 20])


In [4]:
# Dense / masked formulation of the same router, given identical weights so its
# outputs can be compared element-wise with the index-based router.
masked_router = ExpertsChooseMaskedRouter(num_experts=num_experts, dim=x.shape[-1])

# Copy the parameter *values*. Assigning `masked.weight = router.weight` would
# register the very same Parameter objects in both modules (tying, not copying),
# so any later in-place update to one router would silently change the other.
# NOTE(review): assumes `router_weights` is an nn.Module (e.g. nn.Linear) with
# matching parameter names in both routers — confirm.
masked_router.router_weights.load_state_dict(router.router_weights.state_dict())

dispatch_mask, combine_array = masked_router(x, capacity)

print(f"dispatch_mask shape: {dispatch_mask.shape}, combine_array shape: {combine_array.shape}")

# both dense tensors are laid out as (batch, token, expert, capacity)
expected_mask_shape = (x.shape[0], x.shape[1], num_experts, capacity)
assert dispatch_mask.shape == expected_mask_shape
assert combine_array.shape == expected_mask_shape

dispatch_mask shape: torch.Size([1, 20, 4, 20]), combine_array shape: torch.Size([1, 20, 4, 20])


In [6]:
from layers.expert_choose_linear import gather_experts

# Gather each expert's selected tokens two ways; both give (batch, expert, capacity, hidden).
# 1) advanced indexing with the router's token indices
x_selected_advanced = gather_experts(x, expert_indices)

# 2) masking: contract the dense dispatch mask over the token axis `t`
x_selected_masked = torch.einsum("bt...,btec->bec...", x, dispatch_mask)

print(f"x_selected_advanced shape: {x_selected_advanced.shape}, x_selected_masked shape: {x_selected_masked.shape}")

# The two formulations must agree; fail loudly instead of only printing the result.
gather_match = torch.allclose(x_selected_advanced, x_selected_masked)
print(f"x_selected_advanced == x_selected_masked: {gather_match}")
assert gather_match, "advanced-indexing gather and masked gather disagree"

x_selected_advanced shape: torch.Size([1, 4, 20, 128]), x_selected_masked shape: torch.Size([1, 4, 20, 128])
x_selected_advanced == x_selected_masked: True


In [12]:
# Dense linear layer whose weight is carved up into per-expert slices. Only its
# weight is used below (no bias).
hidden_size = x.shape[-1]
assert hidden_size % num_experts == 0, "hidden_size must be divisible by num_experts"
linear = nn.Linear(hidden_size, hidden_size)

# Reshape once and reuse for both paths: (out, in) -> (expert, out // num_experts, in).
# Each expert owns a contiguous block of output rows, i.e. a down-projection
# from hidden_size to hidden_size // num_experts features.
expert_weight = linear.weight.reshape(num_experts, hidden_size // num_experts, hidden_size)

# Expert computation: expert e applies its weight slice to its own selected tokens.
x_expert_advanced = torch.einsum("beci,eoi->beco", x_selected_advanced, expert_weight)
x_expert_masked = torch.einsum("beci,eoi->beco", x_selected_masked, expert_weight)

print(f"x_expert_advanced shape: {x_expert_advanced.shape}, x_expert_masked shape: {x_expert_masked.shape}")

expert_match = torch.allclose(x_expert_advanced, x_expert_masked, atol=1e-6)
print(f"x_expert_advanced == x_expert_masked: {expert_match}")
assert expert_match, "expert outputs differ between advanced and masked paths"


x_expert_advanced shape: torch.Size([1, 4, 20, 32]), x_expert_masked shape: torch.Size([1, 4, 20, 32])
x_expert_advanced == x_expert_masked: True


In [15]:
# Project each expert's reduced output back up to the model width (no scatter
# happens in this cell; tokens are still grouped by (expert, slot)).
# NOTE: `linear.weight` is simply reused, reshaped as
# (num_experts, in_features, out_features // num_experts), as a stand-in
# up-projection. This is NOT the transpose of the per-expert down-projection.
expert_up_weight = linear.weight.reshape(
    num_experts, linear.in_features, linear.out_features // num_experts
)
x_expanded = torch.einsum("beci,eoi->beco", x_expert_advanced, expert_up_weight)
print(x_expanded.shape)

# (batch, expert, capacity, hidden) — same leading axes, feature axis restored
assert x_expanded.shape == x_expert_advanced.shape[:-1] + (linear.in_features,)

torch.Size([1, 4, 20, 128])


In [22]:
from layers.expert_choose_linear import scatter_experts

# Weight each (expert, slot) row by its gate, then scatter the rows back onto
# token positions: (b, e, c, d) * (b, e, c, 1) -> scatter -> (b, t, d).
x_gated = x_expanded * expert_gate.unsqueeze(-1)
# NOTE(review): the third argument is `capacity`. Here capacity == seq_len
# (both 20), so the output shape cannot reveal whether scatter_experts expects
# the capacity or the number of tokens — confirm its signature.
x_reduced_advanced = scatter_experts(x_gated, expert_indices, capacity)

print(x_reduced_advanced.shape)

# the combined result must line up with the original input tokens
assert x_reduced_advanced.shape == x.shape

torch.Size([1, 20, 128])


In [23]:
# Masked equivalent of the scatter: contract over (expert, slot) with the dense
# combine array. Unlike the advanced path, no explicit gate multiply happens
# here, so combine_array presumably already carries the gate weights. The
# equality check below verifies that.
x_reduced_masked = torch.einsum("bec...,btec->bt...", x_expanded, combine_array)

print(x_reduced_masked.shape)

reduce_match = torch.allclose(x_reduced_advanced, x_reduced_masked, atol=1e-6)
print(f"x_reduced_advanced == x_reduced_masked: {reduce_match}")
assert reduce_match, "scatter outputs differ between advanced and masked paths"

torch.Size([1, 20, 128])
x_reduced_advanced == x_reduced_masked: True
