# Complexity of large contractions

This notebook assesses the computational complexity of `einsum` on large tensor contractions. To regenerate the data pasted into the cells below, run:
```sh
GROWTH_SIZE=50 pytest -s tests/infer/test_enum.py -k growth
```

In [None]:
from matplotlib import pyplot
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
# Benchmark data for the most recently executed data cell. Each data cell
# below reassigns these globals before calling plot(); they start as None so
# that plot() can detect when no data cell has been run yet.
sizes = None
costs = None
times1 = None
times2 = None


def _resolve_series(name, value):
    """Return `value`, falling back to the notebook global `name` when it is None.

    Raises:
        ValueError: if neither the argument nor the global has been set.
    """
    if value is None:
        value = globals().get(name)
    if value is None:
        raise ValueError(
            '{!r} is not set: run a data cell first or pass {}= to plot()'.format(name, name))
    return value


def plot(title, sizes=None, costs=None, times1=None, times2=None):
    """Draw the data-structure and run-time figures for one benchmark.

    Produces two figures against `sizes`:
      1. one line per entry of `costs` (sorted by name), titled
         '<title> data structures';
      2. `times1` ('optim + compute') and `times2` ('compute') in seconds,
         titled '<title> run time'.

    Any data argument left as None falls back to the notebook global of the
    same name, so the existing `plot(title)` calls after a data cell keep
    working, while new code can pass the data explicitly.

    Args:
        title: Benchmark name used as the prefix of both figure titles.
        sizes: Problem sizes (x-axis values); must be non-empty.
        costs: Mapping from series name to values aligned with `sizes`.
        times1: Times in seconds for 'optim + compute', aligned with `sizes`.
        times2: Times in seconds for 'compute', aligned with `sizes`.

    Raises:
        ValueError: if any data is missing, `sizes` is empty, or a series
            has a different length than `sizes`.
    """
    sizes = _resolve_series('sizes', sizes)
    costs = _resolve_series('costs', costs)
    times1 = _resolve_series('times1', times1)
    times2 = _resolve_series('times2', times2)

    n_points = len(sizes)
    if n_points == 0:
        raise ValueError('sizes is empty; nothing to plot')
    # Fail with a readable message instead of matplotlib's shape-mismatch error.
    named_series = [('times1', times1), ('times2', times2)]
    named_series += [('costs[{!r}]'.format(key), values) for key, values in costs.items()]
    for label, series in named_series:
        if len(series) != n_points:
            raise ValueError('{} has {} values but sizes has {}'.format(
                label, len(series), n_points))
    xmax = max(sizes)

    # Figure 1: one line per cost series.
    fig, ax = pyplot.subplots(figsize=(8, 5))
    fig.patch.set_color('white')
    ax.set_title('{} data structures'.format(title))
    for name, series in sorted(costs.items()):
        ax.plot(sizes, series, label=name)
    ax.set_xlabel('problem size')
    ax.set_ylabel('count')
    ax.set_xlim(0, xmax)
    ax.legend(loc='best')
    fig.tight_layout()

    # Figure 2: wall-clock time with and without the optimization step.
    fig, ax = pyplot.subplots(figsize=(8, 5))
    fig.patch.set_color('white')
    ax.set_title('{} run time'.format(title))
    ax.plot(sizes, times1, label='optim + compute')
    ax.plot(sizes, times2, label='compute')
    ax.set_xlim(0, xmax)
    ax.set_xlabel('problem size')
    ax.set_ylabel('time (sec)')
    ax.legend(loc='best')
    fig.tight_layout()

In [None]:
# HMM benchmark results, pasted from the pytest run described at the top of
# the notebook. Every series below is aligned index-by-index with `sizes`.
# Problem sizes (x-axis of both figures).
sizes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
# Per-size values for each series name; plot() draws one line per key in the
# '<title> data structures' figure.
costs = {'einsum': [2, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 129, 132, 135, 138, 141, 144, 147], 'tensor': [2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 66, 70, 74, 78, 82, 86, 90, 94, 98, 102, 106, 110, 114, 118, 122, 126, 130, 134, 138, 142, 146, 150, 154, 158, 162, 166, 170, 174, 178, 182, 186, 190, 194], 'tensordot': [1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96]}
# Run times in seconds: plot() labels `times1` 'optim + compute' and
# `times2` 'compute'.
times1 = [0.009238958358764648, 0.0076100826263427734, 0.010653018951416016, 0.013020992279052734, 0.019320011138916016, 0.023880958557128906, 0.027106046676635742, 0.02959418296813965, 0.03682708740234375, 0.03769993782043457, 0.04551100730895996, 0.05617785453796387, 0.059137821197509766, 0.06655502319335938, 0.12415003776550293, 0.0887289047241211, 0.09188199043273926, 0.10331511497497559, 0.1072688102722168, 0.12005400657653809, 0.1294269561767578, 0.15391898155212402, 0.15573501586914062, 0.173508882522583, 0.20106291770935059, 0.2248392105102539, 0.22595691680908203, 0.26395392417907715, 0.2868459224700928, 0.24104595184326172, 0.3087790012359619, 0.3256571292877197, 0.41802000999450684, 0.35585904121398926, 0.34830403327941895, 0.3733639717102051, 0.36162710189819336, 0.49504995346069336, 0.42130112648010254, 0.49214792251586914, 0.4690511226654053, 0.621647834777832, 0.5029211044311523, 0.5283229351043701, 0.6394450664520264, 0.5988290309906006, 0.7579779624938965, 0.6156389713287354, 0.6547160148620605]
times2 = [0.003954172134399414, 0.006880044937133789, 0.007801055908203125, 0.01289510726928711, 0.012656927108764648, 0.014309167861938477, 0.01649188995361328, 0.02225184440612793, 0.019804954528808594, 0.029505014419555664, 0.028504133224487305, 0.03278398513793945, 0.030045032501220703, 0.037201881408691406, 0.048876047134399414, 0.04040193557739258, 0.04906010627746582, 0.05202484130859375, 0.06052207946777344, 0.0563511848449707, 0.06153106689453125, 0.08444690704345703, 0.08440399169921875, 0.08850812911987305, 0.12862896919250488, 0.18652892112731934, 0.11293220520019531, 0.1212620735168457, 0.120758056640625, 0.11810588836669922, 0.13967180252075195, 0.14763784408569336, 0.16614389419555664, 0.15189695358276367, 0.17238807678222656, 0.1842348575592041, 0.17844295501708984, 0.18040204048156738, 0.21193981170654297, 0.20941805839538574, 0.21959686279296875, 0.2227480411529541, 0.22359490394592285, 0.2915799617767334, 0.2692139148712158, 0.27510595321655273, 0.27895402908325195, 0.2873239517211914, 0.3136889934539795]

In [None]:
# Draw both HMM figures from the data assigned in the previous cell.
plot(title='HMM')

In [None]:
# DBN benchmark results, pasted from the pytest run described at the top of
# the notebook. Every series below is aligned index-by-index with `sizes`.
# Problem sizes (x-axis of both figures).
sizes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
# Per-size values for each series name; plot() draws one line per key in the
# '<title> data structures' figure.
costs = {'einsum': [10, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89, 96, 103, 110, 117, 124, 131, 138, 145, 152, 159, 166, 173, 180, 187, 194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278, 285, 292, 299, 306, 313, 320, 327, 334, 341, 348], 'tensor': [11, 20, 28, 36, 44, 52, 60, 68, 76, 84, 92, 100, 108, 116, 124, 132, 140, 148, 156, 164, 172, 180, 188, 196, 204, 212, 220, 228, 236, 244, 252, 260, 268, 276, 284, 292, 300, 308, 316, 324, 332, 340, 348, 356, 364, 372, 380, 388, 396], 'tensordot': [4, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 129, 132, 135, 138, 141, 144, 147]}
# Run times in seconds: plot() labels `times1` 'optim + compute' and
# `times2` 'compute'.
times1 = [0.012120962142944336, 0.02348804473876953, 0.03629016876220703, 0.05564689636230469, 0.06866812705993652, 0.0749199390411377, 0.09441709518432617, 0.12393307685852051, 0.14230084419250488, 0.1649949550628662, 0.2086479663848877, 0.24796605110168457, 0.2746889591217041, 0.2975029945373535, 0.36155104637145996, 0.37074708938598633, 0.4215819835662842, 0.4682760238647461, 0.5100998878479004, 0.5894119739532471, 0.7875461578369141, 0.6795549392700195, 0.823577880859375, 0.8102381229400635, 0.8138418197631836, 1.1368579864501953, 1.3406708240509033, 1.0648369789123535, 1.2017300128936768, 1.2342798709869385, 1.4699361324310303, 1.672677993774414, 1.509364128112793, 1.654857873916626, 1.7111709117889404, 2.128197193145752, 2.253509998321533, 2.113978862762451, 2.2374460697174072, 2.5715529918670654, 2.2894680500030518, 2.388978958129883, 2.5751028060913086, 3.0885629653930664, 2.85188889503479, 2.96618914604187, 3.0755038261413574, 3.733936071395874, 3.396009922027588]
times2 = [0.008181095123291016, 0.01610589027404785, 0.026497840881347656, 0.026099205017089844, 0.035585880279541016, 0.040966033935546875, 0.05190682411193848, 0.06533694267272949, 0.06711101531982422, 0.08668303489685059, 0.09986996650695801, 0.2875368595123291, 0.12397003173828125, 0.13421416282653809, 0.16019892692565918, 0.16496610641479492, 0.1838369369506836, 0.21637511253356934, 0.2744572162628174, 0.26456403732299805, 0.2825338840484619, 0.31387805938720703, 0.34751009941101074, 0.37461090087890625, 0.3890390396118164, 0.4143550395965576, 0.48529911041259766, 0.48372912406921387, 0.5464189052581787, 0.5734119415283203, 0.6411929130554199, 0.7195680141448975, 0.6919949054718018, 0.8279869556427002, 0.8089959621429443, 0.8228888511657715, 1.0133180618286133, 1.0562009811401367, 1.0218558311462402, 1.0519750118255615, 1.0679600238800049, 1.1231629848480225, 1.2386372089385986, 1.2687039375305176, 1.3194000720977783, 1.4305448532104492, 1.4448800086975098, 1.5895040035247803, 1.610414981842041]

In [None]:
# Draw both DBN figures from the data assigned in the previous cell.
plot(title='DBN')