In [1]:
import random
import multiprocessing as mp
import time
import math


def _random_cg_block(dim_out, dim_1, dim_2):
    """Return a dim_out x dim_1 x dim_2 nested list filled with random.random() values."""
    return [
        [[random.random() for _ in range(dim_2)] for _ in range(dim_1)]
        for _ in range(dim_out)
    ]


# Stand-in for the complex conjugate of the Clebsch-Gordan coefficients.
# Keys are (J, j1, j2) for every combination of values in 0..6; each entry is a
# nested list indexed as [M][m1][m2] with shape (2*J+1, 2*j1+1, 2*j2+1).
# The numbers are random placeholders, not real coefficients.
clebsch_gordon_conj = {}
for J in range(7):
    for j1 in range(7):
        for j2 in range(7):
            clebsch_gordon_conj[(J, j1, j2)] = _random_cg_block(2*J + 1, 2*j1 + 1, 2*j2 + 1)


def tp_pathway(x1, x2, j1, j2, J, weight):
    """Evaluate one weighted tensor-product pathway.

    Contracts ``x1`` (indexed 0..2*j1) and ``x2`` (indexed 0..2*j2) against the
    module-level table entry ``clebsch_gordon_conj[(J, j1, j2)]``, which is
    indexed ``[M][m1][m2]`` with shape (2*J+1, 2*j1+1, 2*j2+1).

    Returns:
        A list of length 2*J+1 whose M-th entry is
        ``weight * sum_{m1, m2} C[M][m1][m2] * x1[m1] * x2[m2]``.
    """
    coeffs = clebsch_gordon_conj[(J, j1, j2)]
    dim_out, dim_1, dim_2 = 2*J + 1, 2*j1 + 1, 2*j2 + 1

    weighted = []
    for M in range(dim_out):
        plane = coeffs[M]
        acc = 0
        for m1 in range(dim_1):
            row = plane[m1]
            left = x1[m1]
            for m2 in range(dim_2):
                # Same left-to-right product and accumulation order throughout.
                acc += row[m2] * left * x2[m2]
        # The pathway weight is applied once per output component.
        weighted.append(acc * weight)
    return weighted

def parse_pathways(irreps_1, irreps_2, irreps_out):
    """Enumerate every allowed (input 1, input 2, output) pathway.

    Each ``irreps_*`` argument is a sequence of ``(mul, j)`` pairs; an irrep of
    degree ``j`` occupies ``2*j+1`` consecutive entries and is repeated ``mul``
    times. A pathway is kept when ``|j1 - j2| <= J <= j1 + j2``.

    Returns:
        A list of tuples ``(j1, j2, J, slice_1, slice_2, slice_out)`` where the
        slices locate the irrep copies inside the flat x1 / x2 / output vectors.
        Pathways are ordered with input 1 outermost and the output innermost.
    """

    def _copies(irreps):
        # Walk the flat layout, yielding (j, slice) for each copy of each irrep.
        offset = 0
        for mul, j in irreps:
            width = 2*j + 1
            for _ in range(mul):
                yield j, slice(offset, offset + width)
                offset += width

    return [
        (j1, j2, J, slice_1, slice_2, slice_out)
        for j1, slice_1 in _copies(irreps_1)
        for j2, slice_2 in _copies(irreps_2)
        for J, slice_out in _copies(irreps_out)
        if abs(j1 - j2) <= J <= j1 + j2
    ]

def count_ndim(irreps):
    """Return the flat vector length for ``irreps``, a sequence of ``(mul, j)`` pairs.

    Each pair contributes ``mul * (2*j + 1)`` entries.
    """
    return sum(mul * (2*j + 1) for mul, j in irreps)

def _fctp(query):
    """Worker: evaluate one chunk of weighted pathways.

    Args:
        query: a tuple ``(x1, x2, weighted_pathways)`` where ``weighted_pathways``
            is an iterable of ``(pathway, weight)`` and each pathway is
            ``(j1, j2, J, slice_1, slice_2, slice_out)``.

    Returns:
        A list of ``(values, pathway)`` tuples, one per pathway and in input
        order, where ``values`` is the result of ``tp_pathway`` for that pathway.
    """
    x1, x2, weighted_pathways = query

    def _evaluate(pathway, weight):
        # slice_out is not needed here; it travels with the pathway for collation.
        j1, j2, J, slice_1, slice_2, _ = pathway
        values = tp_pathway(x1[slice_1], x2[slice_2], j1=j1, j2=j2, J=J, weight=weight)
        return values, pathway

    return [_evaluate(pathway, weight) for pathway, weight in weighted_pathways]
        
def fully_connected_tp(x1, x2, weighted_pathways, n_thread=1):
    """Evaluate all weighted pathways, spread over up to ``n_thread`` worker processes.

    Args:
        x1, x2: flat input vectors; each pathway slices into them.
        weighted_pathways: list of ``(pathway, weight)`` pairs (must support
            ``len`` and slicing).
        n_thread: maximum number of worker processes / chunks; must be >= 1.

    Returns:
        A flat list of ``(values, pathway)`` tuples in the same order as
        ``weighted_pathways``.

    Raises:
        ValueError: if ``n_thread`` is smaller than 1.
    """
    if n_thread < 1:
        # Previously 0 raised ZeroDivisionError and negative values silently
        # produced an empty result, dropping every pathway.
        raise ValueError(f"n_thread must be a positive integer, got {n_thread!r}")

    n_total = len(weighted_pathways)
    if n_total == 0:
        # Nothing to compute; avoid spawning worker processes for no work.
        return []

    # Contiguous chunks of equal size (the last one may be shorter). Stepping by
    # chunk_size never creates empty chunks and yields at most n_thread of them.
    chunk_size = math.ceil(n_total / n_thread)
    queries = [
        (x1, x2, weighted_pathways[start:start + chunk_size])
        for start in range(0, n_total, chunk_size)
    ]

    # Size the pool to the number of chunks so n_thread really bounds the
    # parallelism instead of always defaulting to the machine's CPU count.
    with mp.Pool(processes=len(queries)) as pool:
        outs = pool.map(_fctp, queries)

    # pool.map preserves chunk order, so flattening restores the input order.
    pathway_outputs = []
    for sublist in outs:
        pathway_outputs.extend(sublist)

    return pathway_outputs

def collate_tp_output(pathway_outputs, irreps_out):
    """Accumulate per-pathway results into one flat output vector.

    Args:
        pathway_outputs: iterable of ``(values, pathway)`` where ``pathway`` is
            ``(j1, j2, J, slice_1, slice_2, slice_out)``.
        irreps_out: sequence of ``(mul, j)`` pairs defining the output layout.

    Returns:
        A list of length ``count_ndim(irreps_out)``; pathways that share a
        ``slice_out`` are summed element-wise into the same positions.
    """
    x_out = [0] * count_ndim(irreps_out)

    for values, pathway in pathway_outputs:
        _, _, _, _, _, slice_out = pathway
        targets = range(slice_out.start, slice_out.stop)
        # zip stops at the shorter of the slice span and the value list.
        for index, value in zip(targets, values):
            x_out[index] += value

    return x_out

In [2]:
# Three irrep configurations, each listed three times in a row (nine runs in
# total). Every irreps entry is a list of (mul, j) pairs; within one
# configuration all irreps share the same multiplicity.
benchmarks = [
    {
        "irreps_1": [(mul, j) for j in degrees_1],
        "irreps_2": [(mul, j) for j in degrees_2],
        "irreps_out": [(mul, j) for j in degrees_out],
    }
    for mul, degrees_1, degrees_2, degrees_out in (
        (30, (0, 1, 2), (0, 1, 2), (0, 1, 2)),
        (30, (2, 3, 4), (2, 3, 4), (2, 3, 4)),
        (70, (0, 5), (0, 1), (0, 4)),
    )
    for _ in range(3)
]

In [3]:
# Benchmark run: pathways kept in parse order, n_thread=10.
for n, benchmark in enumerate(benchmarks):
    irreps_1, irreps_2, irreps_out = benchmark['irreps_1'], benchmark['irreps_2'], benchmark['irreps_out']

    # Random inputs sized to match the flat layout of each irreps list.
    x1 = [random.random() for _ in range(count_ndim(irreps_1))]
    x2 = [random.random() for _ in range(count_ndim(irreps_2))]

    # One random weight per pathway.
    pathways = parse_pathways(irreps_1, irreps_2, irreps_out)
    pathway_weights = [random.random() for _ in range(len(pathways))]
    weighted_pathways = list(zip(pathways, pathway_weights))
    # No shuffle here: pathways stay in the order parse_pathways produced them.

    # Only the fully_connected_tp call is timed; collation happens afterwards.
    # perf_counter is a monotonic, high-resolution clock, so the measured
    # interval cannot be distorted by system clock adjustments.
    time_init = time.perf_counter()
    pathway_outputs = fully_connected_tp(x1,x2,weighted_pathways, n_thread=10)
    time_end = time.perf_counter()
    x_out = collate_tp_output(pathway_outputs,irreps_out)
    print(f"Benchmark #{n}: {time_end - time_init} sec")

Benchmark #0: 3.409900426864624 sec
Benchmark #1: 4.896377086639404 sec
Benchmark #2: 5.096034049987793 sec
Benchmark #3: 10.633584976196289 sec
Benchmark #4: 10.106495141983032 sec
Benchmark #5: 10.578734874725342 sec
Benchmark #6: 14.299248933792114 sec
Benchmark #7: 9.251941442489624 sec
Benchmark #8: 11.39824366569519 sec


In [4]:
# Benchmark run: pathways shuffled before chunking, n_thread=10.
for n, benchmark in enumerate(benchmarks):
    irreps_1, irreps_2, irreps_out = benchmark['irreps_1'], benchmark['irreps_2'], benchmark['irreps_out']

    # Random inputs sized to match the flat layout of each irreps list.
    x1 = [random.random() for _ in range(count_ndim(irreps_1))]
    x2 = [random.random() for _ in range(count_ndim(irreps_2))]

    # One random weight per pathway.
    pathways = parse_pathways(irreps_1, irreps_2, irreps_out)
    pathway_weights = [random.random() for _ in range(len(pathways))]
    weighted_pathways = list(zip(pathways, pathway_weights))
    # Randomize which pathways land in which chunk.
    random.shuffle(weighted_pathways)

    # Only the fully_connected_tp call is timed; collation happens afterwards.
    # perf_counter is a monotonic, high-resolution clock, so the measured
    # interval cannot be distorted by system clock adjustments.
    time_init = time.perf_counter()
    pathway_outputs = fully_connected_tp(x1,x2,weighted_pathways, n_thread=10)
    time_end = time.perf_counter()
    x_out = collate_tp_output(pathway_outputs,irreps_out)
    print(f"Benchmark #{n}: {time_end - time_init} sec")

Benchmark #0: 5.8359575271606445 sec
Benchmark #1: 4.9255945682525635 sec
Benchmark #2: 4.27514123916626 sec
Benchmark #3: 12.266233682632446 sec
Benchmark #4: 17.75212335586548 sec
Benchmark #5: 14.629468441009521 sec
Benchmark #6: 8.654037237167358 sec
Benchmark #7: 8.46125316619873 sec
Benchmark #8: 11.350107908248901 sec


In [5]:
# Benchmark run: pathways kept in parse order, n_thread=1 (single chunk).
for n, benchmark in enumerate(benchmarks):
    irreps_1, irreps_2, irreps_out = benchmark['irreps_1'], benchmark['irreps_2'], benchmark['irreps_out']

    # Random inputs sized to match the flat layout of each irreps list.
    x1 = [random.random() for _ in range(count_ndim(irreps_1))]
    x2 = [random.random() for _ in range(count_ndim(irreps_2))]

    # One random weight per pathway.
    pathways = parse_pathways(irreps_1, irreps_2, irreps_out)
    pathway_weights = [random.random() for _ in range(len(pathways))]
    weighted_pathways = list(zip(pathways, pathway_weights))

    # Only the fully_connected_tp call is timed; collation happens afterwards.
    # perf_counter is a monotonic, high-resolution clock, so the measured
    # interval cannot be distorted by system clock adjustments.
    time_init = time.perf_counter()
    pathway_outputs = fully_connected_tp(x1,x2,weighted_pathways, n_thread=1)
    time_end = time.perf_counter()
    x_out = collate_tp_output(pathway_outputs,irreps_out)
    print(f"Benchmark #{n}: {time_end - time_init} sec")

Benchmark #0: 17.01779270172119 sec
Benchmark #1: 10.721267461776733 sec
Benchmark #2: 13.802447080612183 sec
Benchmark #3: 47.55847477912903 sec
Benchmark #4: 46.29413414001465 sec
Benchmark #5: 48.74101972579956 sec
Benchmark #6: 32.651564598083496 sec
Benchmark #7: 32.034849643707275 sec
Benchmark #8: 31.428672552108765 sec
