In [17]:
import torch, time
device = torch.device("cuda")
from holo.profilers import SimpleProfiler, Profiler, prettyTime
from holo.prettyFormats import prettyDataSizeOctes, get_prettyTime_Formater

def resetMemStats():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
    time.sleep(0.1)

def memDiag(name:str):
    print(f"memory status of {name}")
    print(f"torch.cuda.max_memory_allocated: {prettyDataSizeOctes(torch.cuda.max_memory_allocated())}")
    print(f"torch.cuda.max_memory_reserved: {prettyDataSizeOctes(torch.cuda.max_memory_reserved())}")
    print("---"*5)
    resetMemStats()

def showTimeStat(sp:SimpleProfiler, nbIters:int, baseBatchSize:int, convShape:tuple[int, int], nbExperts:int, nbFeatures:int):
    batchSize = baseBatchSize * convShape[0] * convShape[1]
    tPerIter = (sp.time() / nbIters)
    tPerElts = (tPerIter / batchSize)
    tPerEltsExperts = (tPerElts / nbExperts)
    ptNs = get_prettyTime_Formater("ns")
    print(f"{sp.name} -> {prettyTime(tPerIter)} | {ptNs(tPerElts)}/elts | {ptNs(tPerEltsExperts)}/(elts*experts)")
    pt2Ns = lambda t: f"{t*1e9:.0f}ns"
    print(f"| {baseBatchSize}*{convShape[0]}*{convShape[1]} | {nbExperts} | {nbFeatures} "
          f"| {pt2Ns(tPerElts)} | {pt2Ns(tPerEltsExperts)} | Mo |")

def cell(baseBatchSize:int, convShape:tuple[int, int], nbExperts:int, nbFeatures:int, nbOut:int):
    print(f"*** unsing: {baseBatchSize=}, {convShape=}, {nbExperts=}, {nbFeatures=}, {nbOut=} ****")
    resetMemStats()
    batchSize = baseBatchSize * convShape[0] * convShape[1]
    nbHidden = max(nbFeatures, nbOut) * 4
    k = 2
    gatingLogits = torch.randn((batchSize, nbExperts), device=device) * 2 # (batchSize, nbExperts)
    gatingLogits_k, topK_indices = torch.topk(gatingLogits, k=k, dim=1) # (batchSize, k), (batchSize, k)
    gatingProb = torch.softmax(gatingLogits_k, dim=1) # (batchSize, k)
    allExperts_weigths = torch.zeros((batchSize, nbExperts), device=device) # (batchSize, nbExperts)
    allExperts_weigths.scatter_(dim=1, index=topK_indices, src=gatingProb)
    datasIn = torch.randn((batchSize, nbFeatures), device=device)
    experts = [
        torch.nn.Sequential(
            torch.nn.Linear(nbFeatures, nbHidden),
            torch.nn.ReLU(), torch.nn.Linear(nbHidden, nbOut)).to(device)
        for _ in range(nbExperts)]

    _prof = Profiler(["A", "B", "C1", "C2", "D", "E", "F"])
    memDiag("initial datas")

    N = 20
    with SimpleProfiler("efficient") as sp1:
        for i in range(N):
            datasOut = torch.zeros((batchSize, nbOut), device=device)
            for iExpert in range(nbExperts):
                with _prof.mesure("A"):
                    where = (topK_indices == iExpert).any(dim=1)
                with _prof.mesure("B"):
                    expertUse, *_ = torch.where(where)
                if expertUse.shape[0] == 0:
                    continue
                with _prof.mesure("C1"):
                    inin = datasIn[expertUse]
                with _prof.mesure("C2"):
                    expertOuts = experts[iExpert](inin)
                with _prof.mesure("D"):
                    expertGates = allExperts_weigths[expertUse, iExpert, None]
                with _prof.mesure("E"):
                    res = expertGates * expertOuts
                with _prof.mesure("F"):
                    datasOut.index_add_(dim=0, index=expertUse, source=res)
                del where, expertUse, _, inin, expertGates, res
            del datasOut
    for c in _prof.categories:
        print(f"{c}:{_prof.totalMesure(c) / sp1.time():.2%}", end=", ")
    print(_prof.pretty_totalTimes())
    showTimeStat(sp1, nbIters=N, baseBatchSize=baseBatchSize, convShape=convShape, nbExperts=nbExperts, nbFeatures=nbFeatures)
    memDiag("after effiecient topK")

    with SimpleProfiler("computeAll") as sp2:
        for i in range(N):
            expertsOut = torch.stack([expert(datasIn) for expert in experts], dim=2)
            mergedOuts = torch.sum(allExperts_weigths[:, None, :] * expertsOut, dim=2)
            del expertsOut, mergedOuts
    showTimeStat(sp2, nbIters=N, baseBatchSize=baseBatchSize, convShape=convShape, nbExperts=nbExperts, nbFeatures=nbFeatures)
    memDiag("after compute all")
cell(baseBatchSize=256, convShape=(16, 16), nbExperts=6, nbFeatures=256, nbOut=256); del cell

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=6, nbFeatures=256, nbOut=256 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 93.868 Mo
torch.cuda.max_memory_reserved: 104.858 Mo
---------------
A:1.35%, B:86.23%, C1:1.04%, C2:6.94%, D:1.29%, E:0.66%, F:0.70%, {'A': '4.192 ms', 'B': '268.465 ms', 'C1': '3.246 ms', 'C2': '21.612 ms', 'D': '4.027 ms', 'E': '2.055 ms', 'F': '2.184 ms'}
efficient -> 15.566 ms | 237.522 ns/elts | 39.587 ns/(elts*experts)
| 256*16*16 | 6 | 256 | 238ns | 40ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 1.064 Go
torch.cuda.max_memory_reserved: 1.099 Go
---------------
computeAll -> 1.503 ms | 22.935 ns/elts | 3.823 ns/(elts*experts)
| 256*16*16 | 6 | 256 | 23ns | 4ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 2.712 Go
torch.cuda.max_memory_reserved: 3.506 Go
---------------


```raw
*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=6, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 19.646 Mo
torch.cuda.max_memory_reserved: 25.166 Mo
---------------
A:4.79%, B:51.24%, C1:4.05%, C2:22.61%, D:5.06%, E:2.68%, F:2.75%, {'A': '3.997 ms', 'B': '42.756 ms', 'C1': '3.382 ms', 'C2': '18.869 ms', 'D': '4.220 ms', 'E': '2.238 ms', 'F': '2.298 ms'}
efficient -> 4.172 ms | 332.598 ns/elts | 55.433 ns/(elts*experts)
| 256*7*7 | 6 | 128 | 333ns | 55ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 112.067 Mo
torch.cuda.max_memory_reserved: 132.121 Mo
---------------
computeAll -> 1.099 ms | 87.591 ns/elts | 14.598 ns/(elts*experts)
| 256*7*7 | 6 | 128 | 88ns | 15ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 270.259 Mo
torch.cuda.max_memory_reserved: 367.002 Mo
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=24, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 30.936 Mo
torch.cuda.max_memory_reserved: 35.652 Mo
---------------
A:5.98%, B:40.11%, C1:4.87%, C2:28.10%, D:6.33%, E:3.30%, F:3.42%, {'A': '16.464 ms', 'B': '110.421 ms', 'C1': '13.397 ms', 'C2': '77.363 ms', 'D': '17.435 ms', 'E': '9.099 ms', 'F': '9.420 ms'}
efficient -> 13.765 ms | 1097.356 ns/elts | 45.723 ns/(elts*experts)
| 256*7*7 | 24 | 128 | 1097ns | 46ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 117.857 Mo
torch.cuda.max_memory_reserved: 132.121 Mo
---------------
computeAll -> 3.574 ms | 284.878 ns/elts | 11.87 ns/(elts*experts)
| 256*7*7 | 24 | 128 | 285ns | 12ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 967.43 Mo
torch.cuda.max_memory_reserved: 1.216 Go
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=64, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 56.024 Mo
torch.cuda.max_memory_reserved: 77.595 Mo
---------------
A:6.19%, B:37.76%, C1:4.91%, C2:29.79%, D:6.56%, E:3.36%, F:3.56%, {'A': '41.445 ms', 'B': '252.954 ms', 'C1': '32.914 ms', 'C2': '199.601 ms', 'D': '43.955 ms', 'E': '22.535 ms', 'F': '23.830 ms'}
efficient -> 33.498 ms | 2670.427 ns/elts | 41.725 ns/(elts*experts)
| 256*7*7 | 64 | 128 | 2670ns | 42ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 141.11 Mo
torch.cuda.max_memory_reserved: 159.384 Mo
---------------
computeAll -> 9.471 ms | 755.06 ns/elts | 11.798 ns/(elts*experts)
| 256*7*7 | 64 | 128 | 755ns | 12ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 2.53 Go
torch.cuda.max_memory_reserved: 3.112 Go
---------------


######################################################################################


*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=6, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 51.018 Mo
torch.cuda.max_memory_reserved: 60.817 Mo
---------------
A:2.48%, B:73.45%, C1:2.14%, C2:12.96%, D:2.64%, E:1.36%, F:1.42%, {'A': '3.771 ms', 'B': '111.634 ms', 'C1': '3.260 ms', 'C2': '19.699 ms', 'D': '4.011 ms', 'E': '2.060 ms', 'F': '2.164 ms'}
efficient -> 7.599 ms | 115.957 ns/elts | 19.326 ns/(elts*experts)
| 256*16*16 | 6 | 128 | 116ns | 19ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 535.822 Mo
torch.cuda.max_memory_reserved: 570.425 Mo
---------------
computeAll -> 1.178 ms | 17.976 ns/elts | 2.996 ns/(elts*experts)
| 256*16*16 | 6 | 128 | 18ns | 3ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 1.361 Go
torch.cuda.max_memory_reserved: 1.766 Go
---------------

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=24, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 69.938 Mo
torch.cuda.max_memory_reserved: 92.275 Mo
---------------
A:4.86%, B:50.56%, C1:3.89%, C2:23.94%, D:5.15%, E:2.61%, F:2.75%, {'A': '15.061 ms', 'B': '156.619 ms', 'C1': '12.038 ms', 'C2': '74.152 ms', 'D': '15.950 ms', 'E': '8.076 ms', 'F': '8.535 ms'}
efficient -> 15.490 ms | 236.357 ns/elts | 9.848 ns/(elts*experts)
| 256*16*16 | 24 | 128 | 236ns | 10ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 520.138 Mo
torch.cuda.max_memory_reserved: 593.494 Mo
---------------
computeAll -> 4.552 ms | 69.465 ns/elts | 2.894 ns/(elts*experts)
| 256*16*16 | 24 | 128 | 69ns | 3ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 4.953 Go
torch.cuda.max_memory_reserved: 5.876 Go
---------------

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=64, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 111.444 Mo
torch.cuda.max_memory_reserved: 125.829 Mo
---------------
A:5.64%, B:43.56%, C1:4.69%, C2:26.67%, D:6.11%, E:3.05%, F:3.21%, {'A': '39.014 ms', 'B': '301.072 ms', 'C1': '32.409 ms', 'C2': '184.319 ms', 'D': '42.208 ms', 'E': '21.057 ms', 'F': '22.206 ms'}
efficient -> 34.557 ms | 527.294 ns/elts | 8.239 ns/(elts*experts)
| 256*16*16 | 64 | 128 | 527ns | 8ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 568.229 Mo
torch.cuda.max_memory_reserved: 591.397 Mo
---------------
computeAll -> 10.150 ms | 154.873 ns/elts | 2.42 ns/(elts*experts)
| 256*16*16 | 64 | 128 | 155ns | 2ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 13.036 Go
torch.cuda.max_memory_reserved: 15.316 Go
---------------


######################################################################################



*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=6, nbFeatures=256, nbOut=256 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 94.025 Mo
torch.cuda.max_memory_reserved: 104.858 Mo
---------------
A:1.27%, B:86.80%, C1:0.99%, C2:6.76%, D:1.22%, E:0.62%, F:0.66%, {'A': '4.034 ms', 'B': '276.673 ms', 'C1': '3.142 ms', 'C2': '21.551 ms', 'D': '3.903 ms', 'E': '1.979 ms', 'F': '2.101 ms'}
efficient -> 15.938 ms | 243.188 ns/elts | 40.531 ns/(elts*experts)
| 256*16*16 | 6 | 256 | 243ns | 41ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 1.062 Go
torch.cuda.max_memory_reserved: 1.288 Go
---------------
computeAll -> 1.406 ms | 21.449 ns/elts | 3.575 ns/(elts*experts)
| 256*16*16 | 6 | 256 | 21ns | 4ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 2.712 Go
torch.cuda.max_memory_reserved: 3.509 Go
---------------

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=24, nbFeatures=256, nbOut=256 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 141.303 Mo
torch.cuda.max_memory_reserved: 163.578 Mo
---------------
A:3.28%, B:64.95%, C1:2.67%, C2:17.39%, D:3.53%, E:1.82%, F:1.88%, {'A': '15.308 ms', 'B': '302.876 ms', 'C1': '12.467 ms', 'C2': '81.102 ms', 'D': '16.458 ms', 'E': '8.509 ms', 'F': '8.763 ms'}
efficient -> 23.315 ms | 355.764 ns/elts | 14.824 ns/(elts*experts)
| 256*16*16 | 24 | 256 | 356ns | 15ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 1.061 Go
torch.cuda.max_memory_reserved: 1.086 Go
---------------
computeAll -> 4.840 ms | 73.855 ns/elts | 3.077 ns/(elts*experts)
| 256*16*16 | 24 | 256 | 74ns | 3ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 9.906 Go
torch.cuda.max_memory_reserved: 11.752 Go
---------------

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=64, nbFeatures=256, nbOut=256 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 245.826 Mo
torch.cuda.max_memory_reserved: 260.047 Mo
---------------
A:4.85%, B:51.33%, C1:3.93%, C2:23.18%, D:5.18%, E:2.61%, F:2.77%, {'A': '40.052 ms', 'B': '423.501 ms', 'C1': '32.458 ms', 'C2': '191.239 ms', 'D': '42.755 ms', 'E': '21.501 ms', 'F': '22.875 ms'}
efficient -> 41.253 ms | 629.472 ns/elts | 9.835 ns/(elts*experts)
| 256*16*16 | 64 | 256 | 629ns | 10ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 1.134 Go
torch.cuda.max_memory_reserved: 1.21 Go
---------------
computeAll -> 0.150 Î¼s | 0.002 ns/elts | 0.0 ns/(elts*experts)
| 256*16*16 | 64 | 256 | 0ns | 0ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 258.255 Mo
torch.cuda.max_memory_reserved: 301.99 Mo
---------------

```


### using efficient computation
| batch size | nb experts | nbFeatures | time/elt | time/elt/expert | peak memory used | time ratio | mem ratio |
|---|---|---|---|---|---|---|---|
| 256\*7*7 | 6 | 128 | 333ns | 55ns | 112Mo | 3.78x slower | 2.41x less |
| 256\*7*7 | 24 | 128 | 1097ns | 46ns | 117Mo | 3.85x slower | 8.26x less |
| 256\*7*7 | 64 | 128 | 2670ns | 42ns | 141Mo | 3.54x slower | 17.94x less |
|---|---|---|---|---|---|---|---|
| 256\*16*16 | 6 | 128 | 122ns | 20ns | 534Mo | 6.78x slower | 2.55x less |
| 256\*16*16 | 24 | 128 | 251ns | 10ns | 520Mo | 3.64x slower | 9.53x less |
| 256\*16*16 | 64 | 128 | 527ns | 8ns | 568Mo | 3.40x slower | 22.95x less |
|---|---|---|---|---|---|---|---|
| 256\*16*16 | 6 | 256 | 243ns | 41ns | 1062Mo | 11.57x slower | 2.55x less |
| 256\*16*16 | 24 | 256 | 356ns | 15ns | 1061Mo | 4.81x slower | 9.34x less |
| 256\*16*16 | 64 | 256 | 602ns | 9.4ns | 1136Mo | / | / |

### using computation with all experts
| batch size | nb experts | nbFeatures | time/elt | time/elt/expert | peak memory used |
|---|---|---|---|---|---|
| 256\*7*7 | 6 | 128 | 88ns | 15ns | 270Mo |
| 256\*7*7 | 24 | 128 | 285ns | 12ns | 967Mo |
| 256\*7*7 | 64 | 128 | 755ns | 12ns | 2530Mo |
|---|---|---|---|---|---|---|---|
| 256\*16*16 | 6 | 128 | 18ns | 3.0ns | 1361Mo |
| 256\*16*16 | 24 | 128 | 69ns | 2.9ns | 4953Mo |
| 256\*16*16 | 64 | 128 | 155ns | 2.4ns | 13036Mo |
|---|---|---|---|---|---|---|---|
| 256\*16*16 | 6 | 256 | 21ns | 3.5ns | 2712Mo |
| 256\*16*16 | 24 | 256 | 74ns | 3.1ns | 9906Mo |
| 256\*16*16 | 64 | 128 | / | / | >16Go |




In [20]:
print(f"{270 / 112:.2f}")
print(f"{967 / 117:.2f}")
print(f"{2530 / 141:.2f}")
print(f"{1361 / 534:.2f}")
print(f"{4953 / 520:.2f}")
print(f"{13036 / 568:.2f}")
print(f"{2712 / 1062:.2f}")
print(f"{9906 / 1061:.2f}")
print(f"{16 / 1136:.2f}")

2.41
8.26
17.94
2.55
9.53
22.95
2.55
9.34
0.01
