In [40]:
import sympy as sp

def compute_final_expr(evals, A_list, verbose=True):
    """Build one symbolic dissipator-style matrix expression per operator.

    For the operator at position ``idx`` in ``A_list`` every entry ``A[i, j]``
    is multiplied by ``sqrt(gamma_idx(evals[j] - evals[i]))``, where
    ``gamma_idx`` is an undefined SymPy function.  The scaled entries are then
    split into one ``d x d`` matrix per distinct argument of ``gamma_idx``.
    For each such matrix ``L`` the term

        L * rho * L.T - (L.T * L * rho + rho * L.T * L) / 2

    is formed, with ``rho`` a fully symbolic ``d x d`` matrix whose entries are
    named ``rho_{i},{j}``.  The terms of one operator are summed into a single
    matrix.

    Parameters
    ----------
    evals : sequence
        ``d`` values (numbers or SymPy expressions); ``d = len(evals)`` fixes
        the matrix dimension.
    A_list : iterable
        Operators; each one is passed to ``sp.Matrix`` and indexed over
        ``range(d)`` in both dimensions.
    verbose : bool, optional
        When true (the default, matching the historical behaviour) every
        grouped matrix is printed as it is processed.

    Returns
    -------
    tuple
        ``(final_exprs, grouped_matrices)`` where ``final_exprs`` is a list
        with one ``d x d`` SymPy matrix per operator, and ``grouped_matrices``
        maps each ``gamma`` argument to its matrix **for the last operator
        only** (an empty dict when ``A_list`` is empty).
    """
    d = len(evals)

    # Fully symbolic d x d matrix rho.
    rho = sp.Matrix([
        [sp.Symbol(f"rho_{i},{j}") for j in range(d)]
        for i in range(d)
    ])

    final_exprs = []
    # Bound before the loop so an empty A_list returns ([], {}) instead of
    # raising UnboundLocalError at the return statement.
    grouped_matrices = {}

    for idx, A in enumerate(A_list):
        A = sp.Matrix(A)

        # One undefined symbolic function per operator: gamma_0, gamma_1, ...
        gamma = sp.Function(f"gamma_{idx}")

        # Scale every entry by sqrt(gamma(evals[j] - evals[i])) and file it
        # under the argument of the gamma it carries.  The stored entry keeps
        # its sqrt(gamma(...)) factor; nothing is divided out.  Entries whose
        # product is zero contain no gamma and are therefore skipped.
        grouped_matrices = {}
        for i in range(d):
            for j in range(d):
                entry = A[i, j] * sp.sqrt(gamma(sp.simplify(-evals[i] + evals[j])))
                for f in entry.atoms(sp.Function):
                    if f.func == gamma:
                        arg = f.args[0]
                        if arg not in grouped_matrices:
                            grouped_matrices[arg] = sp.zeros(d)
                        grouped_matrices[arg][i, j] = sp.simplify(entry)

        terms = []
        for A_mat in grouped_matrices.values():
            if verbose:
                print(A_mat)
            # NOTE(review): plain transpose is used here.  If complex-valued
            # operators are ever passed in, the conjugate transpose is
            # presumably what is wanted -- confirm before relying on it.
            A_mat_d = A_mat.T
            terms.append(
                A_mat * rho * A_mat_d
                - sp.Rational(1, 2) * (A_mat_d * A_mat * rho + rho * A_mat_d * A_mat)
            )

        # Start from a zero matrix so the sum is a matrix even with no terms.
        final_exprs.append(sum(terms, sp.zeros(d, d)))

    return final_exprs, grouped_matrices

# ---------------------------------------------------------------------------
# Example usage
# ---------------------------------------------------------------------------
# Disabled alternative: four levels, with the last energy set to
# epsilon_1 + epsilon_2 and the operators built with np.kron.
# d = 4
# evals = [
#     0 if i == 0 else sp.Symbol("epsilon1") if i == 1 else sp.Symbol(f"epsilon_{i}")
#     for i in range(d)
# ]
# evals[3] = sp.Symbol(f"epsilon_{1}") + sp.Symbol(f"epsilon_{2}")
# import numpy as np
# A_list = [
#     np.kron(np.array([[0,1],[1,0]]),np.eye(2)),
# ]
# A_list = [
#     np.sqrt(1/2)*np.kron(np.array([[-1,0],[0,1]]),np.eye(2)),
# ]

# Active example: two levels.  Level 0 has energy 0, level 1 has the symbolic
# energy "epsilon1"; any further level would be named "epsilon_<index>".
d = 2
evals = [
    0 if level == 0
    else sp.Symbol("epsilon1" if level == 1 else f"epsilon_{level}")
    for level in range(d)
]

# A single operator with ones on the off-diagonal; a diagonal alternative is
# kept below, disabled.
A_list = [
    [[0, 1], [1, 0]],
    # [[-sp.sqrt(1/2), 0], [0, sp.sqrt(1/2)]]
]

final_exprs, matrices = compute_final_expr(evals, A_list)

Matrix([[0, sqrt(gamma_0(epsilon1))], [0, 0]])
Matrix([[0, 0], [sqrt(gamma_0(-epsilon1)), 0]])


In [45]:
from sympy.physics.quantum import TensorProduct

# Second 2 x 2 symbolic matrix, with entries named "s_{i},{j}" so they cannot
# be confused with the "rho_{i},{j}" symbols inside final_exprs.
rho_symbols = [[sp.Symbol(f"s_{row},{col}") for col in range(2)] for row in range(2)]
rho = sp.Matrix(rho_symbols)

# Tensor product of the first operator's expression with the new matrix.
TensorProduct(final_exprs[0], rho)

Matrix([
[    s_0,0*(-rho_0,0*gamma_0(-epsilon1) + rho_1,1*gamma_0(epsilon1)),     s_0,1*(-rho_0,0*gamma_0(-epsilon1) + rho_1,1*gamma_0(epsilon1)), s_0,0*(-rho_0,1*gamma_0(-epsilon1)/2 - rho_0,1*gamma_0(epsilon1)/2), s_0,1*(-rho_0,1*gamma_0(-epsilon1)/2 - rho_0,1*gamma_0(epsilon1)/2)],
[    s_1,0*(-rho_0,0*gamma_0(-epsilon1) + rho_1,1*gamma_0(epsilon1)),     s_1,1*(-rho_0,0*gamma_0(-epsilon1) + rho_1,1*gamma_0(epsilon1)), s_1,0*(-rho_0,1*gamma_0(-epsilon1)/2 - rho_0,1*gamma_0(epsilon1)/2), s_1,1*(-rho_0,1*gamma_0(-epsilon1)/2 - rho_0,1*gamma_0(epsilon1)/2)],
[s_0,0*(-rho_1,0*gamma_0(-epsilon1)/2 - rho_1,0*gamma_0(epsilon1)/2), s_0,1*(-rho_1,0*gamma_0(-epsilon1)/2 - rho_1,0*gamma_0(epsilon1)/2),      s_0,0*(rho_0,0*gamma_0(-epsilon1) - rho_1,1*gamma_0(epsilon1)),      s_0,1*(rho_0,0*gamma_0(-epsilon1) - rho_1,1*gamma_0(epsilon1))],
[s_1,0*(-rho_1,0*gamma_0(-epsilon1)/2 - rho_1,0*gamma_0(epsilon1)/2), s_1,1*(-rho_1,0*gamma_0(-epsilon1)/2 - rho_1,0*gamma_0(epsilon1)/2),      s_1,0*(rho_0,

In [36]:
# Show the second value returned by compute_final_expr: a dict mapping each
# gamma argument to its grouped matrix, built for the last operator in A_list.
matrices

{epsilon_2: Matrix([
 [0, 0, 1.0*sqrt(gamma_0(epsilon_2)),                            0],
 [0, 0,                            0, 1.0*sqrt(gamma_0(epsilon_2))],
 [0, 0,                            0,                            0],
 [0, 0,                            0,                            0]]),
 -epsilon_2: Matrix([
 [                            0,                             0, 0, 0],
 [                            0,                             0, 0, 0],
 [1.0*sqrt(gamma_0(-epsilon_2)),                             0, 0, 0],
 [                            0, 1.0*sqrt(gamma_0(-epsilon_2)), 0, 0]])}

In [39]:
# Cell-local import: the only other `import numpy as np` visible in this
# notebook is commented out, so without it this cell fails on a fresh kernel.
import numpy as np

# Kronecker product of a 2 x 2 matrix with a single 1 at [0, 1] and the 2 x 2
# identity.
np.kron(np.array([[0,1],[0,0]]),np.eye(2))

array([[0., 0., 1., 0.],
       [0., 0., 0., 1.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

In [37]:
# Show the summed expression for the first operator in A_list (element 0 of
# the list returned by compute_final_expr).
final_exprs[0]

Matrix([
[-1.0*rho_0,0*gamma_0(-epsilon_2) + 1.0*rho_2,2*gamma_0(epsilon_2), -1.0*rho_0,1*gamma_0(-epsilon_2) + 1.0*rho_2,3*gamma_0(epsilon_2), -0.5*rho_0,2*gamma_0(-epsilon_2) - 0.5*rho_0,2*gamma_0(epsilon_2), -0.5*rho_0,3*gamma_0(-epsilon_2) - 0.5*rho_0,3*gamma_0(epsilon_2)],
[-1.0*rho_1,0*gamma_0(-epsilon_2) + 1.0*rho_3,2*gamma_0(epsilon_2), -1.0*rho_1,1*gamma_0(-epsilon_2) + 1.0*rho_3,3*gamma_0(epsilon_2), -0.5*rho_1,2*gamma_0(-epsilon_2) - 0.5*rho_1,2*gamma_0(epsilon_2), -0.5*rho_1,3*gamma_0(-epsilon_2) - 0.5*rho_1,3*gamma_0(epsilon_2)],
[-0.5*rho_2,0*gamma_0(-epsilon_2) - 0.5*rho_2,0*gamma_0(epsilon_2), -0.5*rho_2,1*gamma_0(-epsilon_2) - 0.5*rho_2,1*gamma_0(epsilon_2),  1.0*rho_0,0*gamma_0(-epsilon_2) - 1.0*rho_2,2*gamma_0(epsilon_2),  1.0*rho_0,1*gamma_0(-epsilon_2) - 1.0*rho_2,3*gamma_0(epsilon_2)],
[-0.5*rho_3,0*gamma_0(-epsilon_2) - 0.5*rho_3,0*gamma_0(epsilon_2), -0.5*rho_3,1*gamma_0(-epsilon_2) - 0.5*rho_3,1*gamma_0(epsilon_2),  1.0*rho_1,0*gamma_0(-epsilon_2) - 1.0*rho_3,2

In [18]:
# Show the summed expression for the first operator in A_list (element 0 of
# the list returned by compute_final_expr).
final_exprs[0]

Matrix([
[            -1.0*rho_0,0*gamma_0(-epsilon_2) + 1.0*rho_2,2*gamma_0(epsilon_2),            -0.5*rho_0,1*gamma_0(-epsilon_2) - 0.5*rho_0,1*gamma_0(epsilon1 - epsilon_3),            -0.5*rho_0,2*gamma_0(-epsilon_2) - 0.5*rho_0,2*gamma_0(epsilon_2),           -0.5*rho_0,3*gamma_0(-epsilon_2) - 0.5*rho_0,3*gamma_0(-epsilon1 + epsilon_3)],
[ -0.5*rho_1,0*gamma_0(-epsilon_2) - 0.5*rho_1,0*gamma_0(epsilon1 - epsilon_3), -1.0*rho_1,1*gamma_0(epsilon1 - epsilon_3) + 1.0*rho_3,3*gamma_0(-epsilon1 + epsilon_3),  -0.5*rho_1,2*gamma_0(epsilon_2) - 0.5*rho_1,2*gamma_0(epsilon1 - epsilon_3), -0.5*rho_1,3*gamma_0(-epsilon1 + epsilon_3) - 0.5*rho_1,3*gamma_0(epsilon1 - epsilon_3)],
[            -0.5*rho_2,0*gamma_0(-epsilon_2) - 0.5*rho_2,0*gamma_0(epsilon_2),             -0.5*rho_2,1*gamma_0(epsilon_2) - 0.5*rho_2,1*gamma_0(epsilon1 - epsilon_3),             1.0*rho_0,0*gamma_0(-epsilon_2) - 1.0*rho_2,2*gamma_0(epsilon_2),            -0.5*rho_2,3*gamma_0(epsilon_2) - 0.5*rho_2,3*gamma_0(-epsi

In [13]:
# Show the summed expression for the first operator in A_list (element 0 of
# the list returned by compute_final_expr).
final_exprs[0]

Matrix([
[    -rho_0,0*gamma_0(-epsilon1) + rho_1,1*gamma_0(epsilon1), -rho_0,1*gamma_0(-epsilon1)/2 - rho_0,1*gamma_0(epsilon1)/2],
[-rho_1,0*gamma_0(-epsilon1)/2 - rho_1,0*gamma_0(epsilon1)/2,      rho_0,0*gamma_0(-epsilon1) - rho_1,1*gamma_0(epsilon1)]])

In [14]:
# Show the summed expression for the second operator in A_list.
# NOTE(review): this needs A_list to hold at least two operators; the second
# entry is commented out in the visible setup cell, so re-enable it (or expect
# an IndexError) before running this cell on a fresh kernel.
final_exprs[1]

Matrix([
[                      0, -1.0*rho_0,1*gamma_1(0)],
[-1.0*rho_1,0*gamma_1(0),                       0]])