In [1]:
import numpy as np
import itertools
from tqdm.notebook import trange, tqdm

In [2]:
# Common extent used for every axis of the random test tensors generated below.
c = 3

In [3]:
# Reproducible random inputs: three c x c matrices, two rank-3 tensors and one
# rank-4 tensor. The generators are consumed left to right, so the draws happen
# in the order lambda1, lambda2, lambda3, G1, G2, U.
np.random.seed(42)
lambda1, lambda2, lambda3 = (np.random.normal(size=(c, c)) for _ in range(3))
G1, G2 = (np.random.normal(size=(c,) * 3) for _ in range(2))
U = np.random.normal(size=(c,) * 4)

In [4]:
def Z_naive(lambda1, lambda2, lambda3, G1, G2, U):
    """Contract the six-tensor network by brute-force summation over all indices.

    Element by element this evaluates

        Z[a, h, i, j] = sum_{b, c, d, e, f, g}
            lambda1[a, b] * G1[c, b, d] * lambda2[d, e]
            * G2[f, e, g] * lambda3[g, h] * U[i, j, c, f]

    which is the contraction ``'ab,cbd,de,feg,gh,ijcf->ahij'`` in einsum
    notation. It is meant as a slow reference implementation: the loop body
    runs once per combination of the ten indices.

    Parameters
    ----------
    lambda1, lambda2, lambda3 : 2-D arrays
        Indexed as ``[a, b]``, ``[d, e]`` and ``[g, h]`` respectively.
    G1, G2 : 3-D arrays
        Indexed as ``[c, b, d]`` and ``[f, e, g]`` respectively.
    U : 4-D array
        Indexed as ``[i, j, c, f]``.

    The extent of every index is read from the tensor that carries it, so the
    axes do not all have to be the same length. Indices shared between tensors
    are expected to have matching extents; this is not validated.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(a, h, i, j)`` created with ``np.zeros`` (default
        floating-point dtype).
    """
    # One extent per index letter, each taken from a tensor that owns that axis.
    n_a, n_b = lambda1.shape
    n_d, n_e = lambda2.shape
    n_g, n_h = lambda3.shape
    n_c = G1.shape[0]
    n_f = G2.shape[0]
    n_i, n_j = U.shape[:2]

    Z = np.zeros(shape=(n_a, n_h, n_i, n_j))

    # The extents live in their own names so the loop variable `c` below does
    # not shadow a size variable.
    index_ranges = [
        range(extent)
        for extent in (n_a, n_b, n_c, n_d, n_e, n_f, n_g, n_h, n_i, n_j)
    ]
    for a, b, c, d, e, f, g, h, i, j in itertools.product(*index_ranges):
        Z[a, h, i, j] += (
            lambda1[a, b] * lambda2[d, e] * lambda3[g, h]
            * G1[c, b, d] * G2[f, e, g] * U[i, j, c, f]
        )
    return Z

In [5]:
%%timeit
# Baseline timing: the brute-force loop over all ten indices.
Z = Z_naive(lambda1, lambda2, lambda3, G1, G2, U)

48.3 ms ± 258 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
# Evaluate the brute-force contraction once and show the shape of the result;
# the tensordot cell further down compares its output against this Z.
Z = Z_naive(lambda1, lambda2, lambda3, G1, G2, U)
Z.shape

(3, 3, 3, 3)

In [7]:
# via np.einsum
%%timeit
Z=np.einsum('ab, cbd, de, feg, gh, ijcf -> ahij',
            lambda1, G1, lambda2, G2, lambda3, U,
            optimize = True)
    

139 µs ± 3.27 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [49]:
# Ask numpy which pairwise contraction order it would pick for this expression.
# einsum_path returns a pair; element 1 is the printable report.
information = np.einsum_path(
    'ab, cbd, de, feg, gh, ijcf -> ahij',
    lambda1, G1, lambda2, G2, lambda3, U,
    optimize=True,
)
print(information[1])

  Complete contraction:  ab,cbd,de,feg,gh,ijcf->ahij
         Naive scaling:  10
     Optimized scaling:  6
      Naive FLOP count:  3.543e+05
  Optimized FLOP count:  2.431e+03
   Theoretical speedup:  145.740
  Largest intermediate:  8.100e+01 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   4                 cbd,ab->acd                 de,feg,gh,ijcf,acd->ahij
   4                 feg,de->dfg                    gh,ijcf,acd,dfg->ahij
   4                 dfg,gh->dfh                       ijcf,acd,dfh->ahij
   5               dfh,acd->acfh                          ijcf,acfh->ahij
   6             acfh,ijcf->ahij                               ahij->ahij


In [68]:
%time
acd = np.tensordot(G1, lambda1, axes=([1],[1]))
dfg = np.tensordot(G2, lambda2, axes=([1],[1]))
dfh = np.tensordot(dfg, lambda3, axes=([1],[0]))
acfh = np.tensordot(dfh, acd, axes=([1], [1]))
ahij = np.tensordot(acfh, U, axes=([0, 2], [3, 2])) # Z

print(Z.shape == ahij.shape)
print(round(Z[0][0][0][0], 10) == round(ahij[0][0][0][0], 10) )

# even faster then np.einsum! (I think it just does not spend time on parsing 'ab, cbd, de, feg, gh, ijcf -> ahij')

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 6.91 µs
True
True
