In [None]:
import torch, time
# Every tensor and expert network created in this notebook is placed on this CUDA device.
device = torch.device("cuda")
from holo.calc import product
from holo.profilers import SimpleProfiler, Profiler, prettyTime
from holo.prettyFormats import prettyDataSizeOctes, get_prettyTime_Formater

def resetMemStats():
    """Clear the CUDA allocator cache and zero torch's memory counters.

    Releases the unoccupied cached memory held by the caching allocator,
    resets the peak and the accumulated memory statistics, then pauses for
    0.1 second before returning.
    """
    cleanupSteps = (
        torch.cuda.empty_cache,
        torch.cuda.reset_peak_memory_stats,
        torch.cuda.reset_accumulated_memory_stats,
    )
    for step in cleanupSteps:
        step()
    time.sleep(0.1)

def memDiag(name:str):
    """Print the peak CUDA memory counters under the label `name`, then reset them.

    Reports torch's peak allocated and peak reserved memory (formatted with
    prettyDataSizeOctes), prints a separator line and finally calls
    resetMemStats() so the next measurement starts from fresh counters.
    """
    peakAllocated = prettyDataSizeOctes(torch.cuda.max_memory_allocated())
    peakReserved = prettyDataSizeOctes(torch.cuda.max_memory_reserved())
    reportLines = (
        f"memory status of {name}",
        f"torch.cuda.max_memory_allocated: {peakAllocated}",
        f"torch.cuda.max_memory_reserved: {peakReserved}",
        "---"*5,
    )
    print(*reportLines, sep="\n")
    resetMemStats()

def showTimeStat(sp:SimpleProfiler, nbIters:int, baseBatchSize:int, convShape:tuple[int, int], nbExperts:int, nbFeatures:int):
    """Print the timing summary of one benchmark run, followed by a markdown table row.

    The total read from `sp.time()` is divided by `nbIters`, then by the number
    of rows (baseBatchSize * convShape[0] * convShape[1]), then by `nbExperts`.
    The values are treated as seconds: the table row multiplies them by 1e9
    and labels them "ns". `nbFeatures` is only echoed in the table row.
    """
    nbElts = baseBatchSize * convShape[0] * convShape[1]
    secPerIter = sp.time() / nbIters
    secPerElt = secPerIter / nbElts
    secPerEltExpert = secPerElt / nbExperts

    # human readable line
    asNs = get_prettyTime_Formater("ns")
    print(f"{sp.name} -> {prettyTime(secPerIter)} | {asNs(secPerElt)}/elts | {asNs(secPerEltExpert)}/(elts*experts)")

    # markdown row, with durations rounded to whole nanoseconds
    def wholeNs(t):
        return f"{t*1e9:.0f}ns"

    print(f"| {baseBatchSize}*{convShape[0]}*{convShape[1]} | {nbExperts} | {nbFeatures} "
          f"| {wholeNs(secPerElt)} | {wholeNs(secPerEltExpert)} | Mo |")

def cell(baseBatchSize:int, convShape:tuple[int, int], nbExperts:int, nbFeatures:int, nbOut:int, syncStages:bool=False):
    """Benchmark two ways of evaluating a top-2 mixture of experts on the GPU.

    Random gating logits select k=2 experts per input row. Each expert is a
    Linear -> ReLU -> Linear network whose hidden width is
    4 * max(nbFeatures, nbOut). Two strategies are timed and their peak CUDA
    memory is printed:
      * "efficient": each expert only processes the rows routed to it and the
        weighted results are accumulated with index_add_;
      * "computeAll": every expert processes every row and the outputs are
        blended with the gate weights (zero for the experts not selected).

    CUDA kernels are launched asynchronously, so the device is synchronized
    right before each timed region starts and right before it ends. Without
    that, a host-side timer mostly captures kernel-launch cost (this assumes
    SimpleProfiler / Profiler read a host clock -- confirm in holo.profilers).

    NOTE: no warm-up iteration is run and N is 1, so one-time CUDA
    initialisation costs land in the first timed region ("efficient").

    Args:
        baseBatchSize: first factor of the number of input rows.
        convShape: two more factors; the number of rows is
            baseBatchSize * convShape[0] * convShape[1].
        nbExperts: number of expert networks.
        nbFeatures: input width of each expert.
        nbOut: output width of each expert.
        syncStages: when True, the device is also synchronized at the end of
            every profiled stage (A..F) so GPU time is charged to the stage
            that launched it; this adds synchronization overhead to the
            "efficient" total. When False (default) only torch.where forces a
            sync, so stage "B" absorbs the GPU time of the other stages.
    """
    print(f"*** unsing: {baseBatchSize=}, {convShape=}, {nbExperts=}, {nbFeatures=}, {nbOut=} ****")
    resetMemStats()
    batchSize = baseBatchSize * convShape[0] * convShape[1]
    nbHidden = max(nbFeatures, nbOut) * 4
    k = 2
    gatingLogits = torch.randn((batchSize, nbExperts), device=device) * 2 # (batchSize, nbExperts)
    gatingLogits_k, topK_indices = torch.topk(gatingLogits, k=k, dim=1) # (batchSize, k), (batchSize, k)
    gatingProb = torch.softmax(gatingLogits_k, dim=1) # (batchSize, k)
    # dense gate matrix: softmax weight for the k selected experts, 0 elsewhere
    allExperts_weigths = torch.zeros((batchSize, nbExperts), device=device) # (batchSize, nbExperts)
    allExperts_weigths.scatter_(dim=1, index=topK_indices, src=gatingProb)
    datasIn = torch.randn((batchSize, nbFeatures), device=device)
    experts = [
        torch.nn.Sequential(
            torch.nn.Linear(nbFeatures, nbHidden),
            torch.nn.ReLU(), torch.nn.Linear(nbHidden, nbOut)).to(device)
        for _ in range(nbExperts)]
    print(f"nb params: {sum(sum(product(p.shape) for p in e.parameters()) for e in experts):_d}")

    _prof = Profiler(["A", "B", "C1", "C2", "D", "E", "F"])
    memDiag("initial datas")

    def endStage():
        # optionally wait for the kernels launched by the current stage,
        # so their execution time is charged to that stage
        if syncStages:
            torch.cuda.synchronize()

    N = 1
    # make sure the setup work above is finished before the timer starts
    torch.cuda.synchronize()
    with SimpleProfiler("efficient") as sp1:
        for _ in range(N):
            datasOut = torch.zeros((batchSize, nbOut), device=device)
            for iExpert in range(nbExperts):
                with _prof.mesure("A"):
                    # rows that selected this expert among their top-k
                    where = (topK_indices == iExpert).any(dim=1)
                    endStage()
                with _prof.mesure("B"):
                    # `where` is 1-D, so torch.where returns a 1-tuple of row indices
                    expertUse = torch.where(where)[0]
                    endStage()
                if expertUse.shape[0] == 0:
                    continue
                with _prof.mesure("C1"):
                    inin = datasIn[expertUse]
                    endStage()
                with _prof.mesure("C2"):
                    expertOuts = experts[iExpert](inin)
                    endStage()
                with _prof.mesure("D"):
                    # trailing None keeps a column axis so the gates broadcast over the outputs
                    expertGates = allExperts_weigths[expertUse, iExpert, None]
                    endStage()
                with _prof.mesure("E"):
                    res = expertGates * expertOuts
                    endStage()
                with _prof.mesure("F"):
                    datasOut.index_add_(dim=0, index=expertUse, source=res)
                    endStage()
                # expertOuts is released too, so it does not stay referenced
                # during the "computeAll" memory measurement
                del where, expertUse, inin, expertOuts, expertGates, res
            del datasOut
        # count the kernels that are still queued on the device
        torch.cuda.synchronize()
    for c in _prof.categories:
        print(f"{c}:{_prof.totalMesure(c) / sp1.time():.2%}", end=", ")
    print(_prof.pretty_totalTimes())
    showTimeStat(sp1, nbIters=N, baseBatchSize=baseBatchSize, convShape=convShape, nbExperts=nbExperts, nbFeatures=nbFeatures)
    memDiag("after effiecient topK")

    torch.cuda.synchronize()
    with SimpleProfiler("computeAll") as sp2:
        for _ in range(N):
            expertsOut = torch.stack([expert(datasIn) for expert in experts], dim=2)
            mergedOuts = torch.sum(allExperts_weigths[:, None, :] * expertsOut, dim=2)
            del expertsOut, mergedOuts
        # nothing above forces a sync: wait here so the GPU work is actually timed
        torch.cuda.synchronize()
    showTimeStat(sp2, nbIters=N, baseBatchSize=baseBatchSize, convShape=convShape, nbExperts=nbExperts, nbFeatures=nbFeatures)
    memDiag("after compute all")
# Run one benchmark configuration, then remove the name `cell` from the notebook namespace.
cell(baseBatchSize=256, convShape=(16, 16), nbExperts=6, nbFeatures=256, nbOut=256); del cell

```raw
*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=6, nbFeatures=64, nbOut=64 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 13.53 Mo
torch.cuda.max_memory_reserved: 23.069 Mo
---------------
A:4.90%, B:51.64%, C1:3.82%, C2:22.51%, D:5.25%, E:2.60%, F:2.80%, {'A': '3.638 ms', 'B': '38.351 ms', 'C1': '2.834 ms', 'C2': '16.714 ms', 'D': '3.896 ms', 'E': '1.930 ms', 'F': '2.082 ms'}
efficient -> 3.713 ms | 296.013 ns/elts | 49.335 ns/(elts*experts)
| 256*7*7 | 6 | 64 | 296ns | 49ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 64.473 Mo
torch.cuda.max_memory_reserved: 90.178 Mo
---------------
computeAll -> 0.855 ms | 68.193 ns/elts | 11.366 ns/(elts*experts)
| 256*7*7 | 6 | 64 | 68ns | 11ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 138.886 Mo
torch.cuda.max_memory_reserved: 205.521 Mo
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=24, nbFeatures=64, nbOut=64 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 18.475 Mo
torch.cuda.max_memory_reserved: 25.166 Mo
---------------
A:6.34%, B:37.91%, C1:5.06%, C2:29.05%, D:6.71%, E:3.38%, F:3.63%, {'A': '14.268 ms', 'B': '85.378 ms', 'C1': '11.405 ms', 'C2': '65.434 ms', 'D': '15.108 ms', 'E': '7.620 ms', 'F': '8.172 ms'}
efficient -> 11.261 ms | 897.71 ns/elts | 37.405 ns/(elts*experts)
| 256*7*7 | 24 | 64 | 898ns | 37ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 62.982 Mo
torch.cuda.max_memory_reserved: 88.08 Mo
---------------
computeAll -> 3.107 ms | 247.685 ns/elts | 10.32 ns/(elts*experts)
| 256*7*7 | 24 | 64 | 248ns | 10ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 486.748 Mo
torch.cuda.max_memory_reserved: 645.923 Mo
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=64, nbFeatures=64, nbOut=64 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 27.904 Mo
torch.cuda.max_memory_reserved: 31.457 Mo
---------------
A:6.02%, B:40.98%, C1:4.77%, C2:27.67%, D:6.31%, E:3.23%, F:3.39%, {'A': '41.205 ms', 'B': '280.654 ms', 'C1': '32.650 ms', 'C2': '189.491 ms', 'D': '43.217 ms', 'E': '22.141 ms', 'F': '23.240 ms'}
efficient -> 34.244 ms | 2729.95 ns/elts | 42.655 ns/(elts*experts)
| 256*7*7 | 64 | 64 | 2730ns | 43ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 71.166 Mo
torch.cuda.max_memory_reserved: 94.372 Mo
---------------
computeAll -> 10.371 ms | 826.763 ns/elts | 12.918 ns/(elts*experts)
| 256*7*7 | 64 | 64 | 827ns | 13ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 1.265 Go
torch.cuda.max_memory_reserved: 1.636 Go
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=128, nbFeatures=64, nbOut=64 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 42.813 Mo
torch.cuda.max_memory_reserved: 60.817 Mo
---------------
A:6.07%, B:41.92%, C1:4.95%, C2:26.27%, D:6.40%, E:3.27%, F:3.43%, {'A': '85.980 ms', 'B': '593.813 ms', 'C1': '70.071 ms', 'C2': '372.172 ms', 'D': '90.628 ms', 'E': '46.317 ms', 'F': '48.653 ms'}
efficient -> 70.829 ms | 5646.42 ns/elts | 44.113 ns/(elts*experts)
| 256*7*7 | 128 | 64 | 5646ns | 44ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 86.07 Mo
torch.cuda.max_memory_reserved: 102.76 Mo
---------------
computeAll -> 30.070 ms | 2397.125 ns/elts | 18.728 ns/(elts*experts)
| 256*7*7 | 128 | 64 | 2397ns | 19ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 2.513 Go
torch.cuda.max_memory_reserved: 3.236 Go
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=256, nbFeatures=64, nbOut=64 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 71.77 Mo
torch.cuda.max_memory_reserved: 85.983 Mo
---------------
A:6.56%, B:38.90%, C1:5.21%, C2:27.07%, D:6.89%, E:3.53%, F:3.74%, {'A': '155.690 ms', 'B': '923.658 ms', 'C1': '123.689 ms', 'C2': '642.800 ms', 'D': '163.606 ms', 'E': '83.845 ms', 'F': '88.772 ms'}
efficient -> 118.720 ms | 9464.308 ns/elts | 36.97 ns/(elts*experts)
| 256*7*7 | 256 | 64 | 9464ns | 37ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 115.053 Mo
torch.cuda.max_memory_reserved: 127.926 Mo
---------------
computeAll -> 91.033 ms | 7257.062 ns/elts | 28.348 ns/(elts*experts)
| 256*7*7 | 256 | 64 | 7257ns | 28ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 5.008 Go
torch.cuda.max_memory_reserved: 6.445 Go
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=512, nbFeatures=64, nbOut=64 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 131.408 Mo
torch.cuda.max_memory_reserved: 144.703 Mo
---------------
A:6.05%, B:43.11%, C1:4.78%, C2:25.23%, D:6.27%, E:3.27%, F:3.44%, {'A': '312.127 ms', 'B': '2.224 sec', 'C1': '246.804 ms', 'C2': '1.302 sec', 'D': '323.344 ms', 'E': '168.561 ms', 'F': '177.604 ms'}
efficient -> 257.920 ms | 20561.24 ns/elts | 40.159 ns/(elts*experts)
| 256*7*7 | 512 | 64 | 20561ns | 40ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 174.886 Mo
torch.cuda.max_memory_reserved: 186.647 Mo
---------------
computeAll -> 217.858 ms | 17367.543 ns/elts | 33.921 ns/(elts*experts)
| 256*7*7 | 512 | 64 | 17368ns | 34ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 10.0 Go
torch.cuda.max_memory_reserved: 12.847 Go
---------------




######################################################################################



*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=6, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 19.646 Mo
torch.cuda.max_memory_reserved: 25.166 Mo
---------------
A:4.79%, B:51.24%, C1:4.05%, C2:22.61%, D:5.06%, E:2.68%, F:2.75%, {'A': '3.997 ms', 'B': '42.756 ms', 'C1': '3.382 ms', 'C2': '18.869 ms', 'D': '4.220 ms', 'E': '2.238 ms', 'F': '2.298 ms'}
efficient -> 4.172 ms | 332.598 ns/elts | 55.433 ns/(elts*experts)
| 256*7*7 | 6 | 128 | 333ns | 55ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 112.067 Mo
torch.cuda.max_memory_reserved: 132.121 Mo
---------------
computeAll -> 1.099 ms | 87.591 ns/elts | 14.598 ns/(elts*experts)
| 256*7*7 | 6 | 128 | 88ns | 15ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 270.259 Mo
torch.cuda.max_memory_reserved: 367.002 Mo
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=24, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 30.936 Mo
torch.cuda.max_memory_reserved: 35.652 Mo
---------------
A:5.98%, B:40.11%, C1:4.87%, C2:28.10%, D:6.33%, E:3.30%, F:3.42%, {'A': '16.464 ms', 'B': '110.421 ms', 'C1': '13.397 ms', 'C2': '77.363 ms', 'D': '17.435 ms', 'E': '9.099 ms', 'F': '9.420 ms'}
efficient -> 13.765 ms | 1097.356 ns/elts | 45.723 ns/(elts*experts)
| 256*7*7 | 24 | 128 | 1097ns | 46ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 117.857 Mo
torch.cuda.max_memory_reserved: 132.121 Mo
---------------
computeAll -> 3.574 ms | 284.878 ns/elts | 11.87 ns/(elts*experts)
| 256*7*7 | 24 | 128 | 285ns | 12ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 967.43 Mo
torch.cuda.max_memory_reserved: 1.216 Go
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=64, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 56.024 Mo
torch.cuda.max_memory_reserved: 77.595 Mo
---------------
A:6.19%, B:37.76%, C1:4.91%, C2:29.79%, D:6.56%, E:3.36%, F:3.56%, {'A': '41.445 ms', 'B': '252.954 ms', 'C1': '32.914 ms', 'C2': '199.601 ms', 'D': '43.955 ms', 'E': '22.535 ms', 'F': '23.830 ms'}
efficient -> 33.498 ms | 2670.427 ns/elts | 41.725 ns/(elts*experts)
| 256*7*7 | 64 | 128 | 2670ns | 42ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 141.11 Mo
torch.cuda.max_memory_reserved: 159.384 Mo
---------------
computeAll -> 9.471 ms | 755.06 ns/elts | 11.798 ns/(elts*experts)
| 256*7*7 | 64 | 128 | 755ns | 12ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 2.53 Go
torch.cuda.max_memory_reserved: 3.112 Go
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=128, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 95.625 Mo
torch.cuda.max_memory_reserved: 111.149 Mo
---------------
A:5.86%, B:41.93%, C1:4.60%, C2:27.83%, D:6.12%, E:3.14%, F:3.28%, {'A': '19.868 ms', 'B': '142.207 ms', 'C1': '15.605 ms', 'C2': '94.382 ms', 'D': '20.761 ms', 'E': '10.644 ms', 'F': '11.132 ms'}
efficient -> 67.827 ms | 5407.124 ns/elts | 42.243 ns/(elts*experts)
| 256*7*7 | 128 | 128 | 5407ns | 42ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 180.646 Mo
torch.cuda.max_memory_reserved: 192.938 Mo
---------------
computeAll -> 24.483 ms | 1951.733 ns/elts | 15.248 ns/(elts*experts)
| 256*7*7 | 128 | 128 | 1952ns | 15ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 5.035 Go
torch.cuda.max_memory_reserved: 6.149 Go
---------------

*** unsing: baseBatchSize=256, convShape=(7, 7), nbExperts=256, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 175.907 Mo
torch.cuda.max_memory_reserved: 186.647 Mo
---------------
A:5.96%, B:40.75%, C1:4.71%, C2:27.98%, D:6.20%, E:3.22%, F:3.41%, {'A': '38.863 ms', 'B': '265.676 ms', 'C1': '30.681 ms', 'C2': '182.425 ms', 'D': '40.413 ms', 'E': '20.999 ms', 'F': '22.238 ms'}
efficient -> 130.381 ms | 10393.916 ns/elts | 40.601 ns/(elts*experts)
| 256*7*7 | 256 | 128 | 10394ns | 41ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 260.899 Mo
torch.cuda.max_memory_reserved: 289.407 Mo
---------------
computeAll -> 56.494 ms | 4503.704 ns/elts | 17.593 ns/(elts*experts)
| 256*7*7 | 256 | 128 | 4504ns | 18ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 10.048 Go
torch.cuda.max_memory_reserved: 12.231 Go
---------------


######################################################################################


*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=6, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 51.018 Mo
torch.cuda.max_memory_reserved: 60.817 Mo
---------------
A:2.48%, B:73.45%, C1:2.14%, C2:12.96%, D:2.64%, E:1.36%, F:1.42%, {'A': '3.771 ms', 'B': '111.634 ms', 'C1': '3.260 ms', 'C2': '19.699 ms', 'D': '4.011 ms', 'E': '2.060 ms', 'F': '2.164 ms'}
efficient -> 7.599 ms | 115.957 ns/elts | 19.326 ns/(elts*experts)
| 256*16*16 | 6 | 128 | 116ns | 19ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 535.822 Mo
torch.cuda.max_memory_reserved: 570.425 Mo
---------------
computeAll -> 1.178 ms | 17.976 ns/elts | 2.996 ns/(elts*experts)
| 256*16*16 | 6 | 128 | 18ns | 3ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 1.361 Go
torch.cuda.max_memory_reserved: 1.766 Go
---------------

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=24, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 69.938 Mo
torch.cuda.max_memory_reserved: 92.275 Mo
---------------
A:4.86%, B:50.56%, C1:3.89%, C2:23.94%, D:5.15%, E:2.61%, F:2.75%, {'A': '15.061 ms', 'B': '156.619 ms', 'C1': '12.038 ms', 'C2': '74.152 ms', 'D': '15.950 ms', 'E': '8.076 ms', 'F': '8.535 ms'}
efficient -> 15.490 ms | 236.357 ns/elts | 9.848 ns/(elts*experts)
| 256*16*16 | 24 | 128 | 236ns | 10ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 520.138 Mo
torch.cuda.max_memory_reserved: 593.494 Mo
---------------
computeAll -> 4.552 ms | 69.465 ns/elts | 2.894 ns/(elts*experts)
| 256*16*16 | 24 | 128 | 69ns | 3ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 4.953 Go
torch.cuda.max_memory_reserved: 5.876 Go
---------------

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=64, nbFeatures=128, nbOut=128 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 111.444 Mo
torch.cuda.max_memory_reserved: 125.829 Mo
---------------
A:5.64%, B:43.56%, C1:4.69%, C2:26.67%, D:6.11%, E:3.05%, F:3.21%, {'A': '39.014 ms', 'B': '301.072 ms', 'C1': '32.409 ms', 'C2': '184.319 ms', 'D': '42.208 ms', 'E': '21.057 ms', 'F': '22.206 ms'}
efficient -> 34.557 ms | 527.294 ns/elts | 8.239 ns/(elts*experts)
| 256*16*16 | 64 | 128 | 527ns | 8ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 568.229 Mo
torch.cuda.max_memory_reserved: 591.397 Mo
---------------
computeAll -> 10.150 ms | 154.873 ns/elts | 2.42 ns/(elts*experts)
| 256*16*16 | 64 | 128 | 155ns | 2ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 13.036 Go
torch.cuda.max_memory_reserved: 15.316 Go
---------------


######################################################################################



*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=6, nbFeatures=256, nbOut=256 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 94.025 Mo
torch.cuda.max_memory_reserved: 104.858 Mo
---------------
A:1.27%, B:86.80%, C1:0.99%, C2:6.76%, D:1.22%, E:0.62%, F:0.66%, {'A': '4.034 ms', 'B': '276.673 ms', 'C1': '3.142 ms', 'C2': '21.551 ms', 'D': '3.903 ms', 'E': '1.979 ms', 'F': '2.101 ms'}
efficient -> 15.938 ms | 243.188 ns/elts | 40.531 ns/(elts*experts)
| 256*16*16 | 6 | 256 | 243ns | 41ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 1.062 Go
torch.cuda.max_memory_reserved: 1.288 Go
---------------
computeAll -> 1.406 ms | 21.449 ns/elts | 3.575 ns/(elts*experts)
| 256*16*16 | 6 | 256 | 21ns | 4ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 2.712 Go
torch.cuda.max_memory_reserved: 3.509 Go
---------------

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=24, nbFeatures=256, nbOut=256 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 141.303 Mo
torch.cuda.max_memory_reserved: 163.578 Mo
---------------
A:3.28%, B:64.95%, C1:2.67%, C2:17.39%, D:3.53%, E:1.82%, F:1.88%, {'A': '15.308 ms', 'B': '302.876 ms', 'C1': '12.467 ms', 'C2': '81.102 ms', 'D': '16.458 ms', 'E': '8.509 ms', 'F': '8.763 ms'}
efficient -> 23.315 ms | 355.764 ns/elts | 14.824 ns/(elts*experts)
| 256*16*16 | 24 | 256 | 356ns | 15ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 1.061 Go
torch.cuda.max_memory_reserved: 1.086 Go
---------------
computeAll -> 4.840 ms | 73.855 ns/elts | 3.077 ns/(elts*experts)
| 256*16*16 | 24 | 256 | 74ns | 3ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 9.906 Go
torch.cuda.max_memory_reserved: 11.752 Go
---------------

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=64, nbFeatures=256, nbOut=256 ****
memory status of initial datas
torch.cuda.max_memory_allocated: 245.826 Mo
torch.cuda.max_memory_reserved: 260.047 Mo
---------------
A:4.85%, B:51.33%, C1:3.93%, C2:23.18%, D:5.18%, E:2.61%, F:2.77%, {'A': '40.052 ms', 'B': '423.501 ms', 'C1': '32.458 ms', 'C2': '191.239 ms', 'D': '42.755 ms', 'E': '21.501 ms', 'F': '22.875 ms'}
efficient -> 41.253 ms | 629.472 ns/elts | 9.835 ns/(elts*experts)
| 256*16*16 | 64 | 256 | 629ns | 10ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 1.134 Go
torch.cuda.max_memory_reserved: 1.21 Go
---------------
computeAll -> 0.150 μs | 0.002 ns/elts | 0.0 ns/(elts*experts)
| 256*16*16 | 64 | 256 | 0ns | 0ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 258.255 Mo
torch.cuda.max_memory_reserved: 301.99 Mo
---------------

*** unsing: baseBatchSize=256, convShape=(16, 16), nbExperts=4096, nbFeatures=256, nbOut=256 ****
nb params: 2_152_726_528
memory status of initial datas
torch.cuda.max_memory_allocated: 10.836 Go
torch.cuda.max_memory_reserved: 10.849 Go
---------------
A:6.00%, B:50.65%, C1:3.67%, C2:23.19%, D:4.79%, E:2.56%, F:2.67%, {'A': '177.811 ms', 'B': '1.502 sec', 'C1': '108.765 ms', 'C2': '687.508 ms', 'D': '141.880 ms', 'E': '76.007 ms', 'F': '79.094 ms'}
efficient -> 2.965 sec | 45235.869 ns/elts | 11.044 ns/(elts*experts)
| 256*16*16 | 4096 | 256 | 45236ns | 11ns | Mo |
memory status of after effiecient topK
torch.cuda.max_memory_allocated: 11.714 Go
torch.cuda.max_memory_reserved: 11.746 Go
---------------
computeAll -> 2.600 μs | 0.04 ns/elts | 0.0 ns/(elts*experts)
| 256*16*16 | 4096 | 256 | 0ns | 0ns | Mo |
memory status of after compute all
torch.cuda.max_memory_allocated: 10.836 Go
torch.cuda.max_memory_reserved: 10.853 Go
---------------

```


### using efficient computation
| batch size | nb experts | nbFeatures | nb parameters | time | time/elt | time/elt/expert | peak memory used | time ratio | mem ratio |
|---|---|---|---|---|---|---|---|---|---|
| 256\*7*7 | 6 | 64 | 198_528 | 3.713 ms | 296ns | 49ns | 64Mo | 4.35x slower | 2.16x less |
| 256\*7*7 | 24 | 64 | 794_112 | 11.261 ms | 898ns | 37ns | 62Mo | 3.62x slower | 7.84x less |
| 256\*7*7 | 64 | 64 | 2_117_632 | 34.244 ms | 2'730ns | 43ns | 71Mo | 3.30x slower | 17.82x less |
| 256\*7*7 | 128 | 64 | 4_235_264 | 70.829 ms | 5'646ns | 44ns | 86Mo | 2.36x slower | 29.22x less |
| 256\*7*7 | 256 | 64 | 8_470_528 | 118.720 ms | 9'464ns | 37ns | 115Mo | 1.30x slower | 43.55x less |
| 256\*7*7 | 512 | 64 | 16_941_056 | 257.920 ms | 20'561ns | 40ns | 174Mo | 1.18x slower | 57.55x less |
|---|---|---|---|---|---|---|---|---|---|
| 256\*7*7 | 6 | 128 | 790_272 | 4.172 ms | 333ns | 55ns | 112Mo | 3.78x slower | 2.41x less |
| 256\*7*7 | 24 | 128 | 3_161_088 | 13.765 ms | 1'097ns | 46ns | 117Mo | 3.85x slower | 8.26x less |
| 256\*7*7 | 64 | 128 | 8_429_568 | 33.498 ms | 2'670ns | 42ns | 141Mo | 3.54x slower | 17.94x less |
| 256\*7*7 | 128 | 128 | 16_859_136 | 67.827 ms | 5'407ns | 42ns | 180Mo | 2.77x slower | 27.97x less |
| 256\*7*7 | 256 | 128 | 33_718_272 | 130.381 ms | 10'394ns | 41ns | 260Mo | 2.31x slower | 38.65x less |
|---|---|---|---|---|---|---|---|---|---|
| 256\*16*16 | 6 | 128 | 790_272 | 7.599 ms | 116ns | 19ns | 535Mo | 6.45x slower | 2.54x less |
| 256\*16*16 | 24 | 128 | 3_161_088 | 15.490 ms | 236ns | 10ns | 520Mo | 3.40x slower | 9.53x less |
| 256\*16*16 | 64 | 128 | 8_429_568 | 34.557 ms | 527ns | 8ns | 568Mo | 3.40x slower | 22.95x less |
|---|---|---|---|---|---|---|---|---|---|
| 256\*16*16 | 6 | 256 | 3_153_408 | 15.938 ms | 243ns | 41ns | 1'062Mo | 11.57x slower | 2.55x less |
| 256\*16*16 | 24 | 256 | 12_613_632 | 23.315 ms | 356ns | 15ns | 1'061Mo | 4.81x slower | 9.34x less |
| 256\*16*16 | 64 | 256 | 33_636_352 | 41.253 ms | 629ns | 9.8ns | 1'134Mo | / | / |
| 256\*16*16 | 4096 | 256 | 2_152_726_528 | 2'965.0 ms | 45'236ns | 11ns | 11'714Mo | / | / |

### using computation with all experts
| batch size | nb experts | nbFeatures | time | time/elt | time/elt/expert | peak memory used |
|---|---|---|---|---|---|---|
| 256\*7*7 | 6 | 64 | 0.855 ms | 68ns | 11ns | 138Mo |
| 256\*7*7 | 24 | 64 | 3.107 ms | 248ns | 10ns | 486Mo |
| 256\*7*7 | 64 | 64 | 10.371 ms | 827ns | 13ns | 1'265Mo |
| 256\*7*7 | 128 | 64 | 30.070 ms | 2397ns | 19ns | 2'513Mo |
| 256\*7*7 | 256 | 64 | 91.033 ms | 7257ns | 28ns | 5'008Mo |
| 256\*7*7 | 512 | 64 | 217.858 ms | 17368ns | 34ns | 10'013Mo |
|---|---|---|---|---|---|---|
| 256\*7*7 | 6 | 128 | 1.099 ms | 88ns | 15ns | 270Mo |
| 256\*7*7 | 24 | 128 | 3.574 ms | 285ns | 12ns | 967Mo |
| 256\*7*7 | 64 | 128 | 9.471 ms | 755ns | 12ns | 2'530Mo |
| 256\*7*7 | 128 | 128 | 24.483 ms | 1952ns | 15ns | 5'035Mo |
| 256\*7*7 | 256 | 128 | 56.494 ms | 4504ns | 18ns | 10'048Mo |
|---|---|---|---|---|---|---|
| 256\*16*16 | 6 | 128 | 1.178 ms | 18ns | 3.0ns | 1'361Mo |
| 256\*16*16 | 24 | 128 | 4.552 ms | 69ns | 2.9ns | 4'953Mo |
| 256\*16*16 | 64 | 128 | 10.150 ms | 155ns | 2.4ns | 13'036Mo |
|---|---|---|---|---|---|---|
| 256\*16*16 | 6 | 256 | 1.406 ms | 21ns | 3.5ns | 2'712Mo |
| 256\*16*16 | 24 | 256 | 4.840 ms | 74ns | 3.1ns | 9'906Mo |
| 256\*16*16 | 64 | 256 | / | / | / | ~ 26Go |
| 256\*16*16 | 4096 | 256 | / | / | / | ~ 1.7To |


