In [89]:
import ffsim
import numpy as np
import pyscf
import pyscf.cc
from opt_einsum import contract

In [90]:
# Build the N2 molecule (two N atoms 1.0 apart along x) in the STO-6G basis
# with D∞h symmetry. Switch `basis` to "6-31g" for a larger basis.
mol = pyscf.gto.Mole()
mol.build(
    atom=[["N", (0, 0, 0)], ["N", (1.0, 0, 0)]],
    basis="sto-6g",
    symmetry="Dooh",
)

# Active space: all orbitals except the lowest `n_frozen`, which are frozen.
n_frozen = 2
active_space = range(n_frozen, mol.nao_nr())
# The same frozen orbitals are passed to CCSD, so derive them once from active_space.
frozen_orbitals = [i for i in range(mol.nao_nr()) if i not in active_space]

# Mean-field reference, then active-space molecular data and Hamiltonian.
scf = pyscf.scf.RHF(mol).run()
mol_data = ffsim.MolecularData.from_scf(scf, active_space=active_space)
norb, nelec = mol_data.norb, mol_data.nelec
mol_hamiltonian = mol_data.hamiltonian
print(f"norb = {norb}")
print(f"nelec = {nelec}")

# CCSD t1/t2 amplitudes used to initialize the ansatz.
# `pyscf.cc` must be imported explicitly: `import pyscf` does not load it.
ccsd = pyscf.cc.CCSD(scf, frozen=frozen_orbitals).run()

# Number of UCJ ansatz layers.
n_reps = 2
# Interactions implementable on a square lattice:
# neighbouring orbitals for alpha-alpha, same orbital for alpha-beta.
pairs_aa = [(p, p + 1) for p in range(norb - 1)]
pairs_ab = [(p, p) for p in range(norb)]
ucj_op = ffsim.UCJOpSpinBalanced.from_t_amplitudes(
    ccsd.t2, t1=ccsd.t1, n_reps=n_reps, interaction_pairs=(pairs_aa, pairs_ab)
)

converged SCF energy = -108.464957764796
norb = 8
nelec = (5, 5)
E(CCSD) = -108.5933309085007  E_corr = -0.1283731437052351


In [91]:
# Occupied/virtual counts from the t2 layout (nocc, nocc, nvrt, nvrt).
# Index the shape rather than unpacking into `_`: assigning `_` in IPython shadows
# the last-output cache, and IPython then stops updating `_`/`__`/`___`.
nocc, nvrt = ccsd.t2.shape[0], ccsd.t2.shape[2]
ccsd.t2.shape

(5, 5, 3, 3)

In [92]:
# Double-factorize the CCSD t2 amplitudes, capping the number of vectors at 3
# (the outputs' first axis has that length).
diag_coulomb_mats, orbital_rotations = ffsim.linalg.double_factorized_t2(
    ccsd.t2,
    tol=1e-8,
    max_vecs=3,
)

In [93]:
# Shape of the diagonal Coulomb matrices returned by the factorization.
np.shape(diag_coulomb_mats)


(3, 2, 8, 8)

In [94]:
# Merge the two leading axes into a single term axis of (norb, norb) matrices.
diag_coulomb_mats_reshape = np.reshape(diag_coulomb_mats, (-1, norb, norb))
np.shape(diag_coulomb_mats_reshape)

(6, 8, 8)

In [95]:
# Shape of the orbital rotations returned by the factorization.
np.shape(orbital_rotations)

(3, 2, 8, 8)

In [96]:
# Flatten the two leading axes so each rotation is one (norb, norb) matrix.
orbital_rotations_reshape = np.reshape(orbital_rotations, (-1, norb, norb))
np.shape(orbital_rotations_reshape)

(6, 8, 8)

In [121]:
# Number of leading double-factorization terms (first axis of diag_coulomb_mats)
# to keep when reconstructing t2. This reassigns n_reps, which was 2 when the UCJ
# ansatz was built; values above the number of available terms are clipped by slicing.
n_reps = 6

In [122]:
def reconstruct_t2(subscripts, coulomb_mats, rotations, nocc):
    """Rebuild t2 amplitudes from double-factorization terms.

    Evaluates ``1j * Z[pq] U[ap] conj(U[ip]) U[bq] conj(U[jq])`` summed over every
    index that ``subscripts`` contracts away. The result is sliced to the
    ``[:nocc, :nocc, nocc:, nocc:]`` block, which is the ``(i, j, a, b)`` layout that
    is compared against ``ccsd.t2``.

    Args:
        subscripts: ``contract`` subscripts for the operands
            ``(Z, U, conj(U), U, conj(U))``, producing ``ijab``.
        coulomb_mats: Diagonal Coulomb matrices ``Z`` with trailing ``(norb, norb)`` axes.
        rotations: Orbital rotations ``U`` with the same leading axes as ``coulomb_mats``.
        nocc: Number of occupied orbitals.

    Returns:
        Complex array of shape ``(nocc, nocc, norb - nocc, norb - nocc)``.
    """
    full = contract(
        subscripts,
        coulomb_mats,
        rotations,
        rotations.conj(),
        rotations,
        rotations.conj(),
    )
    return 1j * full[:nocc, :nocc, nocc:, nocc:]


# n_reps may exceed the number of available terms (diag_coulomb_mats.shape[0]);
# slicing clips silently, so make the effective count explicit.
n_terms = min(n_reps, diag_coulomb_mats.shape[0])

# Unflattened layout: term axis m and the size-2 axis k are both summed.
reconstructed_ori = reconstruct_t2(
    "mkpq,mkap,mkip,mkbq,mkjq->ijab",
    diag_coulomb_mats[:n_terms],
    orbital_rotations[:n_terms],
    nocc,
)
# Flattened layout: (m, k) merged into one axis, so each term contributes 2 entries.
reconstructed = reconstruct_t2(
    "mpq,map,mip,mbq,mjq->ijab",
    diag_coulomb_mats_reshape[: 2 * n_terms],
    orbital_rotations_reshape[: 2 * n_terms],
    nocc,
)

In [123]:
# Residuals of both reconstructions against the CCSD t2 amplitudes.
diff, diff_ori = (approx - ccsd.t2 for approx in (reconstructed, reconstructed_ori))

In [120]:
# Half the squared Frobenius norm of the flattened-layout residual.
0.5 * (np.abs(diff) ** 2).sum()

np.float64(0.003367372975817081)

In [124]:
# Half the squared Frobenius norm of the unflattened-layout residual.
0.5 * (np.abs(diff_ori) ** 2).sum()

np.float64(0.003367372975817081)

In [102]:
def interaction_mask(pairs, norb):
    """Return a symmetric boolean ``(norb, norb)`` mask that is True at the given pairs.

    Args:
        pairs: Iterable of ``(p, q)`` orbital-index pairs; may be empty.
        norb: Number of orbitals (mask side length).

    Returns:
        Boolean array with ``mask[p, q] == mask[q, p] == True`` for every pair.
    """
    mask = np.zeros((norb, norb), dtype=bool)
    # A loop (instead of `rows, cols = zip(*pairs)`) also works for an empty pair
    # list, e.g. pairs_aa when norb <= 1.
    for p, q in pairs:
        mask[p, q] = True
        mask[q, p] = True
    return mask


mask = interaction_mask(pairs_aa, norb)
mask

array([[False,  True, False, False, False, False, False, False],
       [ True, False,  True, False, False, False, False, False],
       [False,  True, False,  True, False, False, False, False],
       [False, False,  True, False,  True, False, False, False],
       [False, False, False,  True, False,  True, False, False],
       [False, False, False, False,  True, False,  True, False],
       [False, False, False, False, False,  True, False,  True],
       [False, False, False, False, False, False,  True, False]])

In [103]:
def double_factorized_t2_compressed(
    t2_amplitudes: np.ndarray, *, tol: float = 1e-8, max_vecs: int | None = None,
    interaction_pairs: tuple[
            list[tuple[int, int]] | None, list[tuple[int, int]] | None
        ]
        | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    pass
    


In [104]:
# theta parametrizes a toy diagonal unitary diag(exp(i*theta), exp(-i*theta)).
# The analysis below uses a real factor instead: the second flattened orbital rotation.
theta = np.pi / 3
U = orbital_rotations_reshape[1, :, :]
U

array([[-2.11507845e-18+0.00000000e+00j, -1.08839984e-18+0.00000000e+00j,
         5.27157390e-01+0.00000000e+00j, -1.68803619e-01+0.00000000e+00j,
         6.44760042e-01+0.00000000e+00j,  5.27157390e-01+0.00000000e+00j,
        -1.19593089e-17+0.00000000e+00j,  1.45042269e-17-0.00000000e+00j],
       [-2.24550520e-18-1.14936393e-16j,  1.22612234e-16+1.59592851e-16j,
        -1.77635684e-15-1.25728715e-32j, -9.67395058e-01+3.59338835e-18j,
        -2.53272189e-01+3.49459960e-17j, -1.66533454e-15-3.27785016e-32j,
         1.61268130e-16-9.56872020e-18j,  3.65673994e-17-7.88978835e-18j],
       [ 1.61427485e-01+1.75154458e-03j,  6.88423208e-01-3.40376215e-03j,
         0.00000000e+00+1.58741451e-33j, -1.94289029e-16+5.14733715e-17j,
         0.00000000e+00+7.14907938e-19j,  5.55111512e-17-1.90489741e-33j,
        -7.07007750e-01+2.58749102e-03j,  1.13487675e-02+2.13348871e-03j],
       [ 6.88430006e-01-1.49205258e-03j, -1.61410947e-01+2.89949349e-03j,
        -5.24900864e-17+1.82278203e

In [105]:
# `import scipy` alone does not guarantee that the `linalg` submodule is loaded
# (older SciPy versions have no lazy submodule loading), so import it explicitly.
import scipy.linalg

# Principal matrix logarithm of U. Despite the name, for unitary U this is the
# anti-Hermitian matrix i*H, not H itself.
tmp_H = scipy.linalg.logm(U)

In [106]:
# Hermitian generator H with U = exp(iH): tmp_H = logm(U) = iH, so H = -i * tmp_H.
raw_generator = -1j * tmp_H
# eigh reads only one triangle of its input and silently ignores any non-Hermitian
# part. Check that logm's output is Hermitian up to round-off, then symmetrize it exactly.
assert np.allclose(raw_generator, raw_generator.conj().T, atol=1e-8)
hermitian_generator = 0.5 * (raw_generator + raw_generator.conj().T)
eigs, vecs = np.linalg.eigh(hermitian_generator)

In [107]:
import scipy.linalg

# Two independent rebuilds of U from its logarithm:
#   new_U      = V diag(exp(i*eigs)) V^dagger   (spectral form of exp(iH))
#   new_U_expm = expm(logm(U))                   (direct matrix exponential)
new_U = np.einsum("ij,j,kj->ik", vecs, np.exp(1j * eigs), vecs.conj())
new_U_expm = scipy.linalg.expm(tmp_H)

# For each rebuild: its first row, U's first row, and whether they agree.
for rebuilt in (new_U, new_U_expm):
    print(rebuilt[0])
    print(U[0])
    print(np.allclose(U, rebuilt))
# Finally, whether the two rebuilds agree with each other.
print(np.allclose(new_U_expm, new_U, atol=1e-8))

[ 9.95731275e-16+8.69963823e-16j -1.17961196e-16-8.29197822e-16j
  5.27157390e-01+5.06539255e-16j -1.68803619e-01+3.05311332e-16j
  6.44760042e-01-5.72458747e-16j  5.27157390e-01-3.77302356e-16j
 -2.77555756e-16-4.37150316e-16j -3.92047506e-16-4.92661467e-16j]
[-2.11507845e-18+0.j -1.08839984e-18+0.j  5.27157390e-01+0.j
 -1.68803619e-01+0.j  6.44760042e-01+0.j  5.27157390e-01+0.j
 -1.19593089e-17+0.j  1.45042269e-17-0.j]
True
[ 1.44328993e-15+3.31332184e-16j -1.77809156e-16-8.11850587e-16j
  5.27157390e-01-3.58962094e-16j -1.68803619e-01-7.78949261e-17j
  6.44760042e-01-5.37235215e-16j  5.27157390e-01+2.32616104e-16j
 -1.90819582e-16-1.59594560e-16j -4.19030587e-16-7.21536546e-17j]
[-2.11507845e-18+0.j -1.08839984e-18+0.j  5.27157390e-01+0.j
 -1.68803619e-01+0.j  6.44760042e-01+0.j  5.27157390e-01+0.j
 -1.19593089e-17+0.j  1.45042269e-17-0.j]
True
True
