# Complexity of large contractions

This notebook assesses the computational complexity of `einsum` on large contractions. To generate the data (pasted into the cells below), run:
```sh
GROWTH_SIZE=50 pytest -s tests/infer/test_enum.py -k growth
```

In [None]:
from matplotlib import pyplot
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
# Placeholders for the benchmark series read by `plot`; each data cell below
# rebinds all four names before `plot` is called.
sizes = costs = times1 = times2 = None

def plot(title):
    """Draw the two summary figures for the currently loaded benchmark data.

    Reads the notebook-level globals ``sizes``, ``costs``, ``times1`` and
    ``times2``; a data cell must have assigned them before this is called.

    Args:
        title: Prefix inserted into both figure titles.
    """
    figure_size = (8, 5)

    # First figure: one labelled curve per entry of `costs`, in sorted order.
    counts_fig = pyplot.figure(figsize=figure_size)
    counts_fig.patch.set_color('white')
    pyplot.title('{} data structures'.format(title))
    ordered_costs = sorted(costs.items())
    for label, values in ordered_costs:
        pyplot.plot(sizes, values, label=label)
    pyplot.xlabel('problem size')
    pyplot.xlim(0, max(sizes))
    pyplot.legend(loc='best')
    pyplot.tight_layout()

    # Second figure: the two timing series against the same x axis.
    timing_fig = pyplot.figure(figsize=figure_size)
    timing_fig.patch.set_color('white')
    pyplot.title('{} run time'.format(title))
    for label, values in (('optim + compute', times1), ('compute', times2)):
        pyplot.plot(sizes, values, label=label)
    pyplot.xlim(0, max(sizes))
    pyplot.xlabel('problem size')
    pyplot.ylabel('time (sec)')
    pyplot.legend(loc='best')
    pyplot.tight_layout()

In [None]:
# Benchmark results pasted from the output of the pytest command shown in the
# introduction; the next cell plots them under the title 'HMM'.
# sizes: problem sizes, used as the x axis of both figures.
sizes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
# costs: one series per name; `plot` draws each as a labelled curve in the first figure.
costs = {'einsum': [2, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 129, 132, 135, 138, 141, 144, 147], 'tensor': [2, 6, 10, 14, 18, 22, 26, 30, 34, 38, 42, 46, 50, 54, 58, 62, 66, 70, 74, 78, 82, 86, 90, 94, 98, 102, 106, 110, 114, 118, 122, 126, 130, 134, 138, 142, 146, 150, 154, 158, 162, 166, 170, 174, 178, 182, 186, 190, 194], 'tensordot': [1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, 96]}
# times1: plotted by `plot` as 'optim + compute' on the 'time (sec)' axis.
times1 = [0.007970094680786133, 0.0070438385009765625, 0.008547067642211914, 0.014365911483764648, 0.016897916793823242, 0.023266077041625977, 0.021757841110229492, 0.027092933654785156, 0.0368039608001709, 0.03662919998168945, 0.0506591796875, 0.04990720748901367, 0.05901598930358887, 0.07366204261779785, 0.06887388229370117, 0.08624100685119629, 0.09469795227050781, 0.09825706481933594, 0.1133430004119873, 0.12122893333435059, 0.1850879192352295, 0.14140701293945312, 0.15995383262634277, 0.16268491744995117, 0.18590593338012695, 0.19794702529907227, 0.20801591873168945, 0.23165488243103027, 0.3067150115966797, 0.2564411163330078, 0.2590038776397705, 0.27591800689697266, 0.27986788749694824, 0.3199141025543213, 0.4127638339996338, 0.341094970703125, 0.36144495010375977, 0.3946821689605713, 0.3910231590270996, 0.4228341579437256, 0.4439270496368408, 0.4661390781402588, 0.46268701553344727, 0.6442289352416992, 0.5377390384674072, 0.5631959438323975, 0.5840370655059814, 0.7751121520996094, 0.6266789436340332]
# times2: plotted by `plot` as 'compute' on the 'time (sec)' axis.
times2 = [0.003628969192504883, 0.005155086517333984, 0.010263919830322266, 0.011459112167358398, 0.011826038360595703, 0.013258934020996094, 0.019121170043945312, 0.016191959381103516, 0.018826961517333984, 0.028467893600463867, 0.02474188804626465, 0.026965856552124023, 0.03708696365356445, 0.03941798210144043, 0.0377202033996582, 0.0455169677734375, 0.04416608810424805, 0.05702090263366699, 0.057134151458740234, 0.06262588500976562, 0.07068705558776855, 0.07330203056335449, 0.07079315185546875, 0.07384896278381348, 0.0794670581817627, 0.10179495811462402, 0.10263395309448242, 0.09877896308898926, 0.10967493057250977, 0.12395191192626953, 0.12283015251159668, 0.12361001968383789, 0.14192509651184082, 0.13588380813598633, 0.1673719882965088, 0.16988682746887207, 0.16839909553527832, 0.17453384399414062, 0.28158998489379883, 0.205488920211792, 0.20661401748657227, 0.22320795059204102, 0.23522114753723145, 0.23851394653320312, 0.236083984375, 0.2559959888458252, 0.26449108123779297, 0.2799370288848877, 0.2828099727630615]

In [None]:
# Draw both figures for the data assigned in the previous cell, titled 'HMM'.
plot('HMM')

In [None]:
# Benchmark results pasted from the output of the pytest command shown in the
# introduction; the next cell plots them under the title 'DBN'.
# sizes: problem sizes, used as the x axis of both figures.
sizes = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
# costs: one series per name; `plot` draws each as a labelled curve in the first figure.
costs = {'einsum': [10, 19, 26, 33, 40, 47, 54, 61, 68, 75, 82, 89, 96, 103, 110, 117, 124, 131, 138, 145, 152, 159, 166, 173, 180, 187, 194, 201, 208, 215, 222, 229, 236, 243, 250, 257, 264, 271, 278, 285, 292, 299, 306, 313, 320, 327, 334, 341, 348], 'tensor': [11, 20, 28, 36, 44, 52, 60, 68, 76, 84, 92, 100, 108, 116, 124, 132, 140, 148, 156, 164, 172, 180, 188, 196, 204, 212, 220, 228, 236, 244, 252, 260, 268, 276, 284, 292, 300, 308, 316, 324, 332, 340, 348, 356, 364, 372, 380, 388, 396], 'tensordot': [4, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45, 48, 51, 54, 57, 60, 63, 66, 69, 72, 75, 78, 81, 84, 87, 90, 93, 96, 99, 102, 105, 108, 111, 114, 117, 120, 123, 126, 129, 132, 135, 138, 141, 144, 147]}
# times1: plotted by `plot` as 'optim + compute' on the 'time (sec)' axis.
times1 = [0.010128974914550781, 0.017036914825439453, 0.03534197807312012, 0.04351019859313965, 0.05243515968322754, 0.07219195365905762, 0.08379006385803223, 0.10441899299621582, 0.13198399543762207, 0.15781497955322266, 0.18526601791381836, 0.21451497077941895, 0.25508689880371094, 0.2816019058227539, 0.3124680519104004, 0.3756380081176758, 0.5669558048248291, 0.44170498847961426, 0.48691892623901367, 0.5495238304138184, 0.5782699584960938, 0.6426401138305664, 0.7075190544128418, 0.9607229232788086, 0.8614039421081543, 0.9977800846099854, 1.0618460178375244, 1.1694221496582031, 1.5327379703521729, 1.3507568836212158, 1.334341049194336, 1.4470839500427246, 1.5896799564361572, 2.0667009353637695, 1.7337899208068848, 1.942903995513916, 1.8648650646209717, 2.00423002243042, 2.1050820350646973, 2.2586281299591064, 2.3501861095428467, 2.427324056625366, 2.5920469760894775, 2.661595106124878, 2.869610071182251, 3.574591875076294, 3.116729974746704, 3.220271110534668, 3.3714301586151123]
# times2: plotted by `plot` as 'compute' on the 'time (sec)' axis.
times2 = [0.007364034652709961, 0.015048027038574219, 0.0211489200592041, 0.022801876068115234, 0.031123876571655273, 0.03901100158691406, 0.046463966369628906, 0.05262184143066406, 0.06025385856628418, 0.07826995849609375, 0.08720993995666504, 0.10081791877746582, 0.11187314987182617, 0.12726187705993652, 0.14103198051452637, 0.16651415824890137, 0.20389413833618164, 0.2030019760131836, 0.22921991348266602, 0.23041200637817383, 0.2771189212799072, 0.3134799003601074, 0.3136579990386963, 0.36713600158691406, 0.39543604850769043, 0.46870899200439453, 0.47461915016174316, 0.5353858470916748, 0.6657888889312744, 0.6057989597320557, 0.6413729190826416, 0.6303989887237549, 0.7721939086914062, 0.8170120716094971, 0.7907459735870361, 0.8479220867156982, 1.218614101409912, 0.9534728527069092, 0.9779610633850098, 1.025270938873291, 1.4383699893951416, 1.17112398147583, 1.2239770889282227, 1.2886719703674316, 1.3440167903900146, 1.4269680976867676, 1.4906339645385742, 1.5290870666503906, 1.6254239082336426]

In [None]:
# Draw both figures for the data assigned in the previous cell, titled 'DBN'.
plot('DBN')