# Complexity of large contractions

This notebook assesses the computational complexity of `einsum` on large contractions, using HMM and DBN models of increasing size. The data pasted into the cells below was generated by running:
```sh
GROWTH_SIZE=50 pytest -s tests/infer/test_enum.py -k growth
```

In [None]:
from matplotlib import pyplot
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
# Notebook-level benchmark series read by `plot` when no data is passed in.
# Each data cell below reassigns all four before calling `plot(...)`; they
# start as None so `plot` can report clearly that nothing has been loaded.
sizes = None
costs = None
times1 = None
times2 = None


def _new_axes(title):
    """Create an 8x5-inch figure with a white background and return its titled axes."""
    fig, ax = pyplot.subplots(figsize=(8, 5))
    fig.patch.set_color('white')
    ax.set_title(title)
    return ax


def plot(title, sizes=None, costs=None, times1=None, times2=None):
    """Draw the data-structure and run-time figures for one benchmark.

    The first figure draws one line per entry of ``costs`` against problem
    size. The second draws the two timing series against problem size.

    Args:
        title: Model name used as the prefix of both figure titles
            (e.g. ``'HMM'``).
        sizes: Problem sizes (x-axis values).
        costs: Mapping from series name to a list of values, one per size.
            Series are drawn in sorted-name order.
        times1: Run times in seconds, labelled 'optim + compute'.
        times2: Run times in seconds, labelled 'compute'.

    Any data argument left as ``None`` falls back to the notebook-level global
    of the same name. This keeps the call form ``plot('HMM')`` working after a
    data cell has assigned those globals.

    Raises:
        ValueError: if any series is still ``None`` after the fallback, i.e.
            no data cell has been run and nothing was passed explicitly.
    """
    namespace = globals()
    # Fall back to the globals assigned by the data cells.
    if sizes is None:
        sizes = namespace.get('sizes')
    if costs is None:
        costs = namespace.get('costs')
    if times1 is None:
        times1 = namespace.get('times1')
    if times2 is None:
        times2 = namespace.get('times2')

    supplied = [('sizes', sizes), ('costs', costs),
                ('times1', times1), ('times2', times2)]
    missing = [name for name, value in supplied if value is None]
    if missing:
        raise ValueError(
            'no benchmark data for {}: run a data cell first or pass the '
            'series to plot() explicitly'.format(', '.join(missing)))

    xmax = max(sizes)

    # Figure 1: one line per series in `costs`.
    ax = _new_axes('{} data structures'.format(title))
    for name, series in sorted(costs.items()):
        ax.plot(sizes, series, label=name)
    ax.set_xlabel('problem size')
    # NOTE(review): presumably the number of each structure built per run;
    # confirm against the test harness that produced the data.
    ax.set_ylabel('count')
    ax.set_xlim(0, xmax)
    ax.legend(loc='best')
    ax.figure.tight_layout()

    # Figure 2: both timing series on shared axes.
    ax = _new_axes('{} run time'.format(title))
    ax.plot(sizes, times1, label='optim + compute')
    ax.plot(sizes, times2, label='compute')
    ax.set_xlim(0, xmax)
    ax.set_xlabel('problem size')
    ax.set_ylabel('time (sec)')
    ax.legend(loc='best')
    ax.figure.tight_layout()

In [None]:
# HMM benchmark results, pasted from the pytest run described at the top.
# These assignments overwrite the notebook-level globals that `plot` reads.
# Problem sizes (x-axis values).
sizes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
# One value per size for each series; drawn in the "data structures" figure.
costs = {'einsum': [5, 11, 16, 21, 26, 31, 36, 41, 46, 51, 56, 61, 66, 71, 76, 81, 86, 91, 96, 101, 106, 111, 116, 121, 126, 131, 136, 141, 146, 151, 156, 161, 166, 171, 176, 181, 186, 191, 196, 201, 206, 211, 216, 221, 226, 231, 236, 241, 246], 'tensor': [5, 11, 17, 23, 29, 35, 41, 47, 53, 59, 65, 71, 77, 83, 89, 95, 101, 107, 113, 119, 125, 131, 137, 143, 149, 155, 161, 167, 173, 179, 185, 191, 197, 203, 209, 215, 221, 227, 233, 239, 245, 251, 257, 263, 269, 275, 281, 287, 293], 'tensordot': [1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96]}
# Run times in seconds, drawn with the label 'optim + compute'.
times1 = [0.008840084075927734, 0.012933969497680664, 0.013402938842773438, 0.014806032180786133, 0.01769113540649414, 0.020265817642211914, 0.02418994903564453, 0.028639793395996094, 0.037155866622924805, 0.04076790809631348, 0.04963994026184082, 0.05830216407775879, 0.07251191139221191, 0.0738379955291748, 0.13530206680297852, 0.08753705024719238, 0.09625887870788574, 0.10686802864074707, 0.10413289070129395, 0.11687684059143066, 0.13324308395385742, 0.14621710777282715, 0.16944313049316406, 0.18355202674865723, 0.20417189598083496, 0.26263999938964844, 0.21850991249084473, 0.22833514213562012, 0.24388790130615234, 0.25165414810180664, 0.2781820297241211, 0.3772099018096924, 0.31725001335144043, 0.3150670528411865, 0.34282994270324707, 0.35338616371154785, 0.46845293045043945, 0.4063379764556885, 0.4095590114593506, 0.4369330406188965, 0.4560120105743408, 0.5853540897369385, 0.49960780143737793, 0.5211260318756104, 0.5366721153259277, 0.7134480476379395, 0.6329171657562256, 0.616657018661499, 0.6515438556671143]
# Run times in seconds, drawn with the label 'compute'.
times2 = [0.0037908554077148438, 0.008044958114624023, 0.009128093719482422, 0.010202169418334961, 0.013428926467895508, 0.013532161712646484, 0.015846967697143555, 0.020440101623535156, 0.021090030670166016, 0.026157140731811523, 0.026005983352661133, 0.030009984970092773, 0.0348970890045166, 0.03642916679382324, 0.04022407531738281, 0.04281187057495117, 0.04643416404724121, 0.05612993240356445, 0.05883002281188965, 0.06754899024963379, 0.07302999496459961, 0.08381986618041992, 0.08524203300476074, 0.09988212585449219, 0.09123611450195312, 0.09980082511901855, 0.10865092277526855, 0.11475586891174316, 0.11963605880737305, 0.11794900894165039, 0.13663697242736816, 0.13257503509521484, 0.14009690284729004, 0.15079498291015625, 0.1653740406036377, 0.17958593368530273, 0.1826319694519043, 0.18945717811584473, 0.19835782051086426, 0.21675801277160645, 0.2163541316986084, 0.23033690452575684, 0.2395031452178955, 0.25373101234436035, 0.2556610107421875, 0.2820451259613037, 0.28106188774108887, 0.29551100730895996, 0.300861120223999]

In [None]:
# Draw the data-structure and run-time figures for the HMM series above.
plot(title='HMM')

In [None]:
# DBN benchmark results, pasted from the pytest run described at the top.
# These assignments overwrite the notebook-level globals that `plot` reads.
# Problem sizes (x-axis values).
sizes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
# One value per size for each series; drawn in the "data structures" figure.
costs = {'einsum': [20, 31, 42, 53, 64, 75, 86, 97, 108, 119, 130, 141, 152, 163, 174, 185, 196, 207, 218, 229, 240, 251, 262, 273, 284, 295, 306, 317, 328, 339, 350, 361, 372, 383, 394, 405, 416, 427, 438, 449, 460, 471, 482, 493, 504, 515, 526, 537, 548], 'tensor': [19, 30, 42, 54, 66, 78, 90, 102, 114, 126, 138, 150, 162, 174, 186, 198, 210, 222, 234, 246, 258, 270, 282, 294, 306, 318, 330, 342, 354, 366, 378, 390, 402, 414, 426, 438, 450, 462, 474, 486, 498, 510, 522, 534, 546, 558, 570, 582, 594], 'tensordot': [2, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 127, 130, 133, 136, 139, 142, 145]}
# Run times in seconds, drawn with the label 'optim + compute'.
times1 = [0.01014399528503418, 0.02048206329345703, 0.03332209587097168, 0.04213905334472656, 0.05838298797607422, 0.07877993583679199, 0.08360099792480469, 0.11620116233825684, 0.14879703521728516, 0.17735600471496582, 0.3362879753112793, 0.22052502632141113, 0.2483820915222168, 0.2895829677581787, 0.32854294776916504, 0.3544960021972656, 0.3895840644836426, 0.4500439167022705, 0.487637996673584, 0.5347411632537842, 0.7929251194000244, 0.6686029434204102, 0.7271420955657959, 0.7760331630706787, 0.8295149803161621, 0.9403810501098633, 1.0393428802490234, 1.1553730964660645, 1.2608709335327148, 1.4212009906768799, 1.6353850364685059, 1.5396859645843506, 1.5881941318511963, 1.611825942993164, 1.8403120040893555, 2.2241549491882324, 2.040048122406006, 2.1147170066833496, 2.2779018878936768, 2.4893229007720947, 2.394083023071289, 2.5438458919525146, 2.780435085296631, 3.338901996612549, 2.877775192260742, 3.0492379665374756, 3.5574729442596436, 3.8712759017944336, 3.52891206741333]
# Run times in seconds, drawn with the label 'compute'.
times2 = [0.008752107620239258, 0.015112876892089844, 0.018546104431152344, 0.022469043731689453, 0.035398006439208984, 0.04074811935424805, 0.04929399490356445, 0.061003923416137695, 0.08153605461120605, 0.08604192733764648, 0.0934908390045166, 0.11775994300842285, 0.12784194946289062, 0.14388203620910645, 0.15465903282165527, 0.17068195343017578, 0.1893310546875, 0.21833109855651855, 0.23193001747131348, 0.25475001335144043, 0.2785217761993408, 0.3174769878387451, 0.3349418640136719, 0.3678150177001953, 0.3949568271636963, 0.6391921043395996, 0.47622203826904297, 0.5376009941101074, 0.5879840850830078, 0.7873790264129639, 0.6488220691680908, 0.7457289695739746, 0.7377579212188721, 0.772442102432251, 0.8408718109130859, 0.8956530094146729, 0.9778108596801758, 1.0253138542175293, 1.443005084991455, 1.1630909442901611, 1.1220598220825195, 1.2206590175628662, 1.2912120819091797, 1.2837011814117432, 1.4531478881835938, 1.4732799530029297, 1.6431920528411865, 1.8390159606933594, 1.7101619243621826]

In [None]:
# Draw the data-structure and run-time figures for the DBN series above.
plot(title='DBN')