In [1]:
import random
from multiprocessing import Process
import time


# Lookup table standing in for the complex-conjugated Clebsch-Gordan coefficients.
# Keys are (J, j1, j2) for every combination with each index in 0..6; each value is
# a nested list of shape (2*J+1, 2*j1+1, 2*j2+1).
# NOTE: the entries are drawn from random.random(), i.e. they are placeholders and
# not actual coefficient values (presumably sufficient for timing purposes).
clebsch_gordon_conj = {
    (J, j1, j2): [
        [
            [random.random() for _m2 in range(2 * j2 + 1)]
            for _m1 in range(2 * j1 + 1)
        ]
        for _M in range(2 * J + 1)
    ]
    for J in range(7)
    for j1 in range(7)
    for j2 in range(7)
}


def tp_pathway(x1, x2, j1, j2, J, weight):
    """Evaluate one weighted tensor-product pathway (j1, j2) -> J.

    Computes, for every output component M,
        out[M] = weight * sum_{m1, m2} C[M][m1][m2] * x1[m1] * x2[m2]
    where C = clebsch_gordon_conj[(J, j1, j2)] is a nested list of shape
    (2*J+1, 2*j1+1, 2*j2+1).

    Args:
        x1: sequence indexed by m1 in range(2*j1+1).
        x2: sequence indexed by m2 in range(2*j2+1).
        j1, j2, J: degrees selecting the coefficient table and loop extents.
        weight: scalar multiplied onto every output component.

    Returns:
        A list of length 2*J+1.

    Raises:
        KeyError: if (J, j1, j2) is not a key of clebsch_gordon_conj.
    """
    dim_out = 2 * J + 1
    coeffs = clebsch_gordon_conj[(J, j1, j2)]
    dim_1 = 2 * j1 + 1
    dim_2 = 2 * j2 + 1

    weighted = []
    for M in range(dim_out):
        plane = coeffs[M]
        # Explicit running total (starting from 0) keeps the summation order fixed.
        acc = 0
        for m1 in range(dim_1):
            row = plane[m1]
            for m2 in range(dim_2):
                acc += row[m2] * x1[m1] * x2[m2]
        # Apply the pathway weight once per output component.
        weighted.append(acc * weight)
    return weighted

def parse_pathways(irreps_1, irreps_2, irreps_out):
    """List every allowed (input-1 copy, input-2 copy, output copy) pathway.

    Each irreps argument is a sequence of (mul, j) pairs; an entry contributes
    `mul` consecutive blocks of width 2*j+1 to the corresponding flat vector.
    A combination is kept when abs(j1 - j2) <= J <= j1 + j2.

    Returns:
        A list of tuples (j1, j2, J, slice1, slice2, slice_out), ordered with
        the input-1 copy varying slowest and the output copy varying fastest.
        The slices address the blocks inside the flat vectors for input 1,
        input 2 and the output respectively.
    """
    def expand(irreps):
        # One (j, slice) entry per individual irrep copy, in layout order.
        blocks = []
        offset = 0
        for mul, j in irreps:
            width = 2 * j + 1
            for _ in range(mul):
                blocks.append((j, slice(offset, offset + width)))
                offset += width
        return blocks

    blocks_1 = expand(irreps_1)
    blocks_2 = expand(irreps_2)
    blocks_out = expand(irreps_out)

    return [
        (j1, j2, J, slice_1, slice_2, slice_out)
        for j1, slice_1 in blocks_1
        for j2, slice_2 in blocks_2
        for J, slice_out in blocks_out
        if abs(j1 - j2) <= J <= j1 + j2
    ]

def count_ndim(irreps):
    """Return the flat vector length described by `irreps`.

    `irreps` is an iterable of (mul, j) pairs; each pair contributes
    mul * (2*j + 1) components. An empty iterable gives 0.
    """
    return sum(mul * (2 * j + 1) for mul, j in irreps)
        
def fully_connected_tp(x1, x2, pathways, pathway_weights):
    """Evaluate every pathway of the tensor product, one weight per pathway.

    Args:
        x1, x2: flat input sequences; each pathway reads its own slice of them.
        pathways: sequence of (j1, j2, J, slice1, slice2, slice_out) tuples.
        pathway_weights: sequence with exactly one weight per pathway.

    Returns:
        A list with one tp_pathway(...) result per pathway, in pathway order.
        The outputs are not yet accumulated into a single output vector.
    """
    assert len(pathway_weights) == len(pathways)

    return [
        tp_pathway(x1[slice_1], x2[slice_2], j1=j1, j2=j2, J=J, weight=weight)
        for (j1, j2, J, slice_1, slice_2, _slice_out), weight in zip(pathways, pathway_weights)
    ]

def collate_tp_output(pathway_outputs, pathways, irreps_out):
    """Accumulate per-pathway outputs into one flat output vector.

    Args:
        pathway_outputs: one sequence of values per pathway.
        pathways: matching (j1, j2, J, slice1, slice2, slice_out) tuples; only
            slice_out (its start and stop) is used here.
        irreps_out: (mul, j) pairs that determine the output length through
            count_ndim.

    Returns:
        A list of length count_ndim(irreps_out) in which every pathway's
        values have been added at the positions given by its slice_out.
    """
    x_out = [0] * count_ndim(irreps_out)
    for values, pathway in zip(pathway_outputs, pathways):
        _j1, _j2, _J, _slice_1, _slice_2, target = pathway
        positions = range(target.start, target.stop)
        # zip stops at the shorter of the two, so surplus values or positions are ignored.
        for contribution, position in zip(values, positions):
            x_out[position] += contribution
    return x_out

# Prepare inputs

In [2]:
# Benchmark configurations: every case uses the same irreps, given as (mul, j)
# pairs with multiplicity 30, for both inputs and the output.
# Two sets of j values are used, (0, 1, 2) and then (2, 3, 4), and each set is
# listed three times in a row so that each setting is timed three times.
benchmarks = [
    {
        key: [(30, j) for j in js]  # (mul, j)
        for key in ('irreps_1', 'irreps_2', 'irreps_out')
    }
    for js in ((0, 1, 2), (2, 3, 4))
    for _repeat in range(3)
]

In [3]:
# Time fully_connected_tp once per benchmark configuration.
# Only the pathway evaluation is timed; building the inputs, enumerating the
# pathways and collating the output all happen outside the measured interval.
for n, benchmark in enumerate(benchmarks):
    irreps_1, irreps_2, irreps_out = benchmark['irreps_1'], benchmark['irreps_2'], benchmark['irreps_out']

    # Random flat input vectors sized to match the irreps of each input.
    x1 = [random.random() for _ in range(count_ndim(irreps_1))]
    x2 = [random.random() for _ in range(count_ndim(irreps_2))]

    # One random weight per allowed pathway.
    pathways = parse_pathways(irreps_1, irreps_2, irreps_out)
    pathway_weights = [random.random() for _ in range(len(pathways))]

    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # elapsed time; time.time() is wall-clock time and can jump if the system
    # clock is adjusted while the benchmark runs.
    time_init = time.perf_counter()
    pathway_outputs = fully_connected_tp(x1,x2,pathways,pathway_weights)
    time_end = time.perf_counter()
    x_out = collate_tp_output(pathway_outputs,pathways,irreps_out)
    print(f"Benchmark #{n}: {time_end - time_init} sec")

Benchmark #0: 3.2630627155303955 sec
Benchmark #1: 3.3104305267333984 sec
Benchmark #2: 2.9817121028900146 sec
Benchmark #3: 31.366021156311035 sec
Benchmark #4: 30.879258155822754 sec
Benchmark #5: 30.766552448272705 sec
