# Complexity of large contractions

This notebook assesses the computational complexity of `einsum` on large contractions. To generate the data (pasted into the cells below), run:
```sh
GROWTH_SIZE=50 pytest -s tests/infer/test_enum.py -k growth
```

In [None]:
from matplotlib import pyplot
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
# Placeholders for the benchmark data that plot() reads from module scope.
# Each data cell further down rebinds all four names before plot() is called.
sizes = costs = times1 = times2 = None

def plot(title):
    """Render two figures from the benchmark data currently held in globals.

    Reads the module-level ``sizes``, ``costs``, ``times1`` and ``times2``.
    The first figure draws one curve per entry of ``costs``; the second
    draws ``times1`` and ``times2``. Nothing is returned.

    Args:
        title: Value formatted into both figure titles
            (``'<title> data structures'`` and ``'<title> run time'``).
    """
    def open_figure(title_template):
        # Start a fresh 8x5 figure, colour its background patch white,
        # then set the title on the figure's current axes.
        fig = pyplot.figure(figsize=(8, 5))
        fig.patch.set_color('white')
        pyplot.title(title_template.format(title))

    def close_figure():
        # Shared finishing touches for both figures.
        pyplot.legend(loc='best')
        pyplot.tight_layout()

    # Figure 1: every series in ``costs`` against ``sizes``, in sorted key order.
    open_figure('{} data structures')
    for label, values in sorted(costs.items()):
        pyplot.plot(sizes, values, label=label)
    pyplot.xlabel('problem size')
    pyplot.xlim(0, max(sizes))
    close_figure()

    # Figure 2: the two timing series against ``sizes``.
    open_figure('{} run time')
    for label, values in (('optim + compute', times1), ('compute', times2)):
        pyplot.plot(sizes, values, label=label)
    pyplot.xlim(0, max(sizes))
    pyplot.xlabel('problem size')
    pyplot.ylabel('time (sec)')
    close_figure()

In [None]:
# Problem sizes for this data set: the consecutive integers 2 through 50.
sizes = list(range(2, 51))
# Pasted benchmark output (see the command in the notebook intro); plotted by the plot('HMM') cell.
# costs: one series per key; plot() draws each series against `sizes`, index by index.
costs = {'einsum': [2, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 127, 130, 133, 136, 139, 142, 145, 148], 'tensor': [2, 8, 13, 18, 23, 30, 36, 43, 51, 61, 63, 79, 85, 93, 105, 118, 129, 140, 165, 176, 167, 169, 179, 192, 201, 211, 221, 237, 244, 253, 265, 295, 301, 308, 312, 344, 351, 358, 421, 348, 362, 376, 382, 396, 412, 426, 441, 453, 464], 'tensordot': [1, 3, 6, 9, 12, 17, 21, 26, 32, 40, 40, 54, 58, 64, 74, 85, 94, 103, 126, 135, 124, 124, 132, 143, 150, 158, 166, 180, 185, 192, 202, 230, 234, 239, 241, 271, 276, 281, 342, 267, 279, 291, 295, 307, 321, 333, 346, 356, 365]}
# times1: timings that plot() labels 'optim + compute' on a 'time (sec)' axis.
times1 = [0.008372783660888672, 0.009359121322631836, 0.011464118957519531, 0.015782833099365234, 0.016071796417236328, 0.01944112777709961, 0.023526906967163086, 0.030776023864746094, 0.03602004051208496, 0.04520106315612793, 0.04832887649536133, 0.062474966049194336, 0.07393193244934082, 0.0805349349975586, 0.08371806144714355, 0.0923609733581543, 0.11249184608459473, 0.11324906349182129, 0.14364004135131836, 0.14569592475891113, 0.2268989086151123, 0.18666291236877441, 0.17470383644104004, 0.19195914268493652, 0.20164704322814941, 0.2327430248260498, 0.22559595108032227, 0.25400304794311523, 0.3351759910583496, 0.27175092697143555, 0.29911088943481445, 0.34028100967407227, 0.352888822555542, 0.4202589988708496, 0.38995909690856934, 0.423145055770874, 0.40044283866882324, 0.4354410171508789, 0.5694520473480225, 0.5093300342559814, 0.5235121250152588, 0.5504801273345947, 0.6867399215698242, 0.6036691665649414, 0.5951530933380127, 0.6306290626525879, 0.7909841537475586, 0.6858761310577393, 0.7134449481964111]
# times2: timings that plot() labels 'compute' on the same axis.
times2 = [0.005914211273193359, 0.007597923278808594, 0.009119987487792969, 0.00985407829284668, 0.010882139205932617, 0.013438940048217773, 0.01553797721862793, 0.021770000457763672, 0.0218658447265625, 0.024719953536987305, 0.032356977462768555, 0.03748202323913574, 0.0403439998626709, 0.04552507400512695, 0.04369306564331055, 0.057772159576416016, 0.06459307670593262, 0.06221294403076172, 0.07624697685241699, 0.08031320571899414, 0.09099507331848145, 0.09505510330200195, 0.08849000930786133, 0.0975649356842041, 0.10746312141418457, 0.10596489906311035, 0.10487890243530273, 0.11904406547546387, 0.12820792198181152, 0.1451740264892578, 0.1548779010772705, 0.16187405586242676, 0.17342615127563477, 0.17133617401123047, 0.18398308753967285, 0.2108469009399414, 0.19859814643859863, 0.21503710746765137, 0.22825002670288086, 0.22601985931396484, 0.25327301025390625, 0.2293708324432373, 0.2724459171295166, 0.2701280117034912, 0.2770829200744629, 0.28898096084594727, 0.2914009094238281, 0.333981990814209, 0.3231480121612549]

In [None]:
# Draw both figures for the data assigned in the preceding cell, titled 'HMM'.
plot(title='HMM')

In [None]:
# Problem sizes for this data set: the consecutive integers 2 through 50.
sizes = list(range(2, 51))
# Pasted benchmark output (see the command in the notebook intro); plotted by the plot('DBN') cell.
# costs: one series per key; plot() draws each series against `sizes`, index by index.
costs = {'einsum': [14, 26, 39, 51, 64, 79, 96, 115, 136, 115, 128, 142, 149, 164, 178, 189, 194, 209, 226, 241, 243, 253, 265, 274, 313, 321, 332, 344, 353, 1088, 1101, 1110, 1125, 1146, 1169, 1194, 1221, 1250, 1281, 1303, 1326, 772, 805, 840, 877, 916, 957, 1000, 1028], 'tensor': [16, 37, 51, 64, 82, 103, 129, 158, 191, 149, 169, 190, 192, 214, 233, 247, 249, 270, 295, 317, 317, 331, 348, 356, 409, 421, 435, 457, 464, 1867, 1886, 1896, 1919, 1953, 1990, 2031, 2077, 2126, 2179, 2214, 2253, 1235, 1296, 1361, 1431, 1505, 1582, 1664, 1713], 'tensordot': [5, 16, 19, 22, 29, 37, 48, 60, 74, 55, 64, 73, 70, 79, 86, 91, 90, 98, 108, 117, 117, 123, 130, 131, 147, 153, 158, 170, 170, 840, 848, 851, 861, 876, 892, 910, 931, 953, 977, 992, 1010, 548, 578, 610, 645, 682, 720, 761, 784]}
# times1: timings that plot() labels 'optim + compute' on a 'time (sec)' axis.
times1 = [0.009994029998779297, 0.02237081527709961, 0.033674001693725586, 0.04198598861694336, 0.0538630485534668, 0.07793188095092773, 0.09697389602661133, 0.12314987182617188, 0.1406879425048828, 0.17415308952331543, 0.20070600509643555, 0.23083710670471191, 0.2576451301574707, 0.30806684494018555, 0.34606099128723145, 0.38559699058532715, 0.4183800220489502, 0.4967918395996094, 0.5417690277099609, 0.6058049201965332, 0.6377570629119873, 0.716437816619873, 0.7564048767089844, 0.8115310668945312, 0.9168391227722168, 1.0167579650878906, 1.0937302112579346, 1.156134843826294, 1.2684600353240967, 1.8358500003814697, 1.5880110263824463, 1.749824047088623, 1.8276331424713135, 2.1938741207122803, 2.119447946548462, 2.164933919906616, 2.2648820877075195, 2.7475831508636475, 2.4922289848327637, 2.6320111751556396, 2.9630181789398193, 2.8099000453948975, 2.8417391777038574, 3.023730993270874, 3.2213869094848633, 3.393570899963379, 3.5461208820343018, 3.6010329723358154, 3.695289134979248]
# times2: timings that plot() labels 'compute' on the same axis.
times2 = [0.007719993591308594, 0.015820980072021484, 0.021553993225097656, 0.028441190719604492, 0.03739118576049805, 0.04699110984802246, 0.056149959564208984, 0.06093597412109375, 0.0740509033203125, 0.08436298370361328, 0.10250306129455566, 0.1187129020690918, 0.1287248134613037, 0.14644718170166016, 0.15980291366577148, 0.17705607414245605, 0.2018280029296875, 0.22140002250671387, 0.4134490489959717, 0.2622520923614502, 0.2707970142364502, 0.30434203147888184, 0.33132004737854004, 0.35626697540283203, 0.6417820453643799, 0.4756011962890625, 0.5306429862976074, 0.5519900321960449, 0.5825059413909912, 0.8147339820861816, 0.8448278903961182, 0.8757419586181641, 0.9480810165405273, 1.0090808868408203, 1.081186056137085, 1.156675100326538, 1.1782598495483398, 1.233062982559204, 1.3140740394592285, 1.3824069499969482, 1.8663380146026611, 1.3181798458099365, 1.3746099472045898, 1.5419700145721436, 1.9973039627075195, 1.6606531143188477, 1.7595770359039307, 1.77606201171875, 2.3110439777374268]

In [None]:
# Draw both figures for the data assigned in the preceding cell, titled 'DBN'.
plot(title='DBN')