In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from layers.routing import ExpertsChooseMaskedRouter

In [35]:
# create a test tensor of shape (batch_size, seq_len, hidden_size)
x = torch.randn(256, 196, 768).cuda()
capacity = 196
num_experts = 4

router = ExpertsChooseMaskedRouter(dim=768, num_experts=4).cuda()
dispatch_mask, combine_array = router(x, capacity)

print(dispatch_mask.shape)
print(combine_array.shape)



torch.Size([256, 196, 4, 196])
torch.Size([256, 196, 4, 196])


In [39]:
class MLP(nn.Module):
    def __init__(self, in_features=768, hidden_dim=768, out_features=768):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_dim).cuda()
        self.fc2 = nn.Linear(hidden_dim, out_features).cuda()

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        return x
    
def benchmark(fn, args, num_runs=100):
    import time
    
    times = []
    for _ in range(num_runs):
        start = time.time()
        fn(*args)
        torch.cuda.synchronize()
        times.append(time.time() - start)
    return sum(times) / num_runs * 1000
    
mlp = MLP(in_features=768, hidden_dim=768, out_features=768).cuda()

def expert_mlp(x):
    x = router.dispatch(x, dispatch_mask)
    x = mlp(x)
    return router.combine(x, combine_array)

print(benchmark(expert_mlp, (x,)))




40.9356427192688
