In [1]:
import torch
device = torch.device("cuda:0")
from holo.profilers import SimpleProfiler, Profiler, prettyTime
batchSize, nbExperts, nbFeatures, nbOut = (256*(16*16), 6, 64, 64)
k = 2
gatingLogits = torch.randn((batchSize, nbExperts), device=device) * 2 # (batchSize, nbExperts)
gatingLogits_k, topK_indices = torch.topk(gatingLogits, k=k, dim=1) # (batchSize, k), (batchSize, k)
gatingProb = torch.softmax(gatingLogits_k, dim=1) # (batchSize, k)
allExperts_weigths = torch.zeros((batchSize, nbExperts), device=device) # (batchSize, nbExperts)
allExperts_weigths.scatter_(dim=1, index=topK_indices, src=gatingProb)
datasIn = torch.randn((batchSize, nbFeatures), device=device)
experts = [
    torch.nn.Sequential(
        torch.nn.Linear(nbFeatures, nbFeatures*4),
        torch.nn.ReLU(), torch.nn.Linear(nbFeatures*4, nbOut)).to(device)
    for _ in range(nbExperts)]

_prof = Profiler(["A", "B", "C1", "C2", "D", "E", "F"])

N = 10
with SimpleProfiler("efficient") as sp1:
    for i in range(N):
        datasOut = torch.zeros((batchSize, nbOut), device=device)
        for iExpert in range(nbExperts):
            with _prof.mesure("A"):
                where = (topK_indices == iExpert).any(dim=1)
            with _prof.mesure("B"):
                expertUse, *_ = torch.where(where)
            if expertUse.shape[0] == 0:
                continue
            with _prof.mesure("C1"):
                inin = datasIn[expertUse]
            with _prof.mesure("C2"):
                expertOuts = experts[iExpert](inin)
            with _prof.mesure("D"):
                expertGates = allExperts_weigths[expertUse, iExpert, None]
                #expertGates = allExperts_weigths[expertUse, iExpert].unsqueeze(dim=1)
            with _prof.mesure("E"):
                res = expertGates * expertOuts
                #res = torch.mul(expertGates, expertOuts)
            with _prof.mesure("F"):
                datasOut.index_add_(dim=0, index=expertUse, source=res)
                #datasOut[expertUse] += res
for c in _prof.categories:
    print(f"{c}:{_prof.totalMesure(c) / sp1.time():.2%}", end=", ")
print()
print(_prof.pretty_totalTimes())
print(sp1.perttyStr())

with SimpleProfiler("computeAll") as sp2:
    for i in range(N):
        expertsOut = torch.stack([expert(datasIn) for expert in experts], dim=2)
        mergedOuts = torch.sum(allExperts_weigths[:, None, :] * expertsOut, dim=2)

print(sp2.perttyStr())

A:9.45%, B:26.68%, C1:6.11%, C2:32.78%, D:1.49%, E:0.84%, F:20.69%, 
{'A': '12.670 ms', 'B': '35.788 ms', 'C1': '8.201 ms', 'C2': '43.975 ms', 'D': '2.005 ms', 'E': '1.131 ms', 'F': '27.750 ms'}
SimpleProfiler('efficient', 134.134 ms)
SimpleProfiler('computeAll', 36.492 ms)
