In [1]:
import torch
from emu_base.math.double_krylov import double_krylov, double_krylov_2
from emu_sv.hamiltonian import RydbergHamiltonian
import sys
sys.path.append('./test/utils_testing/')
from utils_dense_hamiltonians import dense_rydberg_hamiltonian

In [2]:
# Reproducible random Rydberg problem on N sites.
torch.manual_seed(1337)

dtype = torch.complex128  # dtype of the state / gradient vectors
dtype_params = torch.float64  # dtype of the Hamiltonian parameters
tolerance = 1e-5  # passed to the Krylov routines and reused as atol in the checks
N = 6
dt = 0.5

# Random per-site parameters; draw order (omegas, deltas, phis) fixes the values.
omegas = torch.randn(N, dtype=dtype_params)
deltas = torch.randn(N, dtype=dtype_params)
phis = torch.randn(N, dtype=dtype_params)

# Strictly upper-triangular couplings: interactions[i, j] = 1 / (j - i) for j > i,
# zero on and below the diagonal.
interactions = torch.zeros(N, N, dtype=dtype_params)
row_idx, col_idx = torch.triu_indices(N, N, offset=1)
interactions[row_idx, col_idx] = (col_idx - row_idx).to(dtype_params).reciprocal()

# Random normalized state and an (unnormalized) random gradient vector.
state = torch.randn(2**N, dtype=dtype)
state = state.div(state.norm())
grad = torch.randn(2**N, dtype=dtype)

ham = RydbergHamiltonian(
    omegas=omegas,
    deltas=deltas,
    phis=phis,
    interaction_matrix=interactions,
    device=state.device,
)


def op(x):
    """Apply the scaled generator ``-1j * dt * H`` to the vector ``x``."""
    return -1j * dt * (ham * x)

In [3]:
# Iteration budget handed to double_krylov: the Hilbert-space dimension 2**N,
# capped at 80.
iteration_count = min(2**N, 80)


lanczos_vectors_even, odd_block, eT = double_krylov(
    op, grad, state, iteration_count, tolerance
)
even_block = torch.stack(lanczos_vectors_even)  # one Lanczos vector per row

n_odd = odd_block.shape[0]
n_even = even_block.shape[0]
# Odd-indexed rows (one per odd vector) and even-indexed columns
# (one per even vector) of eT.
Hess_L = eT[1 : 2 * n_odd : 2, : 2 * n_even : 2]
print(f"odd: {odd_block.shape}")
print(f"eT: {eT.shape} Hess_L {Hess_L.shape}")
print(f"even: {even_block.shape}")

# L = V_odd @ Hess_L @ V_even^H, where V_odd / V_even hold the Lanczos vectors
# as columns (hence the transpose of the row-stacked odd block).
L = odd_block.mT @ Hess_L @ even_block.conj()

print(L)

odd: torch.Size([65, 64])
eT: torch.Size([132, 132]) Hess_L torch.Size([65, 11])
even: torch.Size([11, 64])
tensor([[ 0.0351+0.0376j, -0.0534-0.0205j,  0.1043-0.0616j,  ...,
          0.0273-0.0592j, -0.0576-0.0014j, -0.0324+0.0662j],
        [ 0.0753-0.0537j, -0.0491+0.0815j, -0.0661-0.1832j,  ...,
         -0.0797-0.0724j, -0.0263+0.0942j,  0.0807+0.0821j],
        [ 0.0276-0.0354j, -0.0012+0.0345j, -0.0520-0.0322j,  ...,
         -0.0161-0.0134j, -0.0006+0.0281j,  0.0076+0.0087j],
        ...,
        [-0.0445+0.0193j,  0.0417-0.0262j, -0.0233+0.0966j,  ...,
          0.0208+0.0755j,  0.0422-0.0463j, -0.0070-0.0907j],
        [-0.0041+0.0435j,  0.0007-0.0531j,  0.0778+0.0846j,  ...,
          0.0926+0.0129j, -0.0261-0.0651j, -0.1097-0.0358j],
        [-0.0055+0.0342j, -0.0051-0.0238j,  0.0385+0.0158j,  ...,
          0.0187+0.0220j,  0.0001-0.0277j, -0.0031-0.0277j]],
       dtype=torch.complex128)


In [4]:
# Dense reference. With A = -1j*dt*H and the rank-one direction E = |state><grad|,
# the top-right block of exp([[A, E], [0, A]]) equals the Frechet derivative of
# the matrix exponential at A in direction E (Van Loan's block-triangular identity).
Hsv = dense_rydberg_hamiltonian(omegas, deltas, phis, interactions)
sizeH = Hsv.shape[0]
E = state[:, None] @ grad.conj()[None, :]

generator = -1j * dt * Hsv
big_mat = torch.block_diag(generator, generator)
big_mat[:sizeH, sizeH:] = E  # off-diagonal coupling block

big_exp = torch.linalg.matrix_exp(big_mat)
expected_L = big_exp[:sizeH, sizeH:]

print(expected_L)

tensor([[ 0.0351+0.0376j, -0.0534-0.0205j,  0.1043-0.0616j,  ...,
          0.0273-0.0592j, -0.0576-0.0014j, -0.0324+0.0662j],
        [ 0.0753-0.0537j, -0.0491+0.0815j, -0.0661-0.1832j,  ...,
         -0.0797-0.0724j, -0.0263+0.0942j,  0.0807+0.0821j],
        [ 0.0276-0.0354j, -0.0012+0.0345j, -0.0520-0.0322j,  ...,
         -0.0161-0.0134j, -0.0006+0.0281j,  0.0076+0.0087j],
        ...,
        [-0.0445+0.0193j,  0.0417-0.0262j, -0.0233+0.0966j,  ...,
          0.0208+0.0755j,  0.0422-0.0463j, -0.0070-0.0907j],
        [-0.0041+0.0435j,  0.0007-0.0531j,  0.0778+0.0846j,  ...,
          0.0926+0.0129j, -0.0261-0.0651j, -0.1097-0.0358j],
        [-0.0055+0.0342j, -0.0051-0.0238j,  0.0385+0.0158j,  ...,
          0.0187+0.0220j,  0.0001-0.0277j, -0.0031-0.0277j]],
       dtype=torch.complex128)


In [5]:
# Compare the double_krylov reconstruction with the dense reference.
# torch.dist is the 2-norm of (L - expected_L) over all entries; it is printed
# *before* the elementwise allclose check so a failing run still reports it.
error = torch.dist(L, expected_L)
print(error)
assert torch.allclose(L, expected_L, atol=tolerance), (
    f"L differs from expected_L beyond atol={tolerance}: dist={error.item():.3e}"
)

tensor(6.2744e-07, dtype=torch.float64)


In [6]:
# Same check with double_krylov_2, which takes no explicit iteration count and
# returns both Lanczos bases as lists of vectors. `op` (-1j * dt * H) comes from
# the setup cell, so it is not redefined here.
lanczos_vectors_even, lanczos_vectors_odd, eT = double_krylov_2(
    op, grad, state, tolerance
)

even_block = torch.stack(lanczos_vectors_even)
odd_block = torch.stack(lanczos_vectors_odd)

# Odd-indexed rows / even-indexed columns of eT, one per odd / even vector.
Hess_L = eT[1 : 2 * odd_block.shape[0] : 2, : 2 * even_block.shape[0] : 2]
print("odd:", odd_block.shape)
print("eT:", eT.shape, "Hess_L", Hess_L.shape)
print("even:", even_block.shape)

# L = V_odd @ Hess_L @ V_even^H, with the Lanczos vectors stored as rows.
L = odd_block.mT @ Hess_L @ even_block.conj()

# Report the error before asserting so a failing run still shows its magnitude.
error = torch.dist(L, expected_L)
print(error)
assert torch.allclose(L, expected_L, atol=tolerance), (
    f"L differs from expected_L beyond atol={tolerance}: dist={error.item():.3e}"
)
print(L)

odd: torch.Size([101, 64])
eT: torch.Size([204, 204]) Hess_L torch.Size([101, 11])
even: torch.Size([11, 64])
tensor(6.2744e-07, dtype=torch.float64)
tensor([[ 0.0351+0.0376j, -0.0534-0.0205j,  0.1043-0.0616j,  ...,
          0.0273-0.0592j, -0.0576-0.0014j, -0.0324+0.0662j],
        [ 0.0753-0.0537j, -0.0491+0.0815j, -0.0661-0.1832j,  ...,
         -0.0797-0.0724j, -0.0263+0.0942j,  0.0807+0.0821j],
        [ 0.0276-0.0354j, -0.0012+0.0345j, -0.0520-0.0322j,  ...,
         -0.0161-0.0134j, -0.0006+0.0281j,  0.0076+0.0087j],
        ...,
        [-0.0445+0.0193j,  0.0417-0.0262j, -0.0233+0.0966j,  ...,
          0.0208+0.0755j,  0.0422-0.0463j, -0.0070-0.0907j],
        [-0.0041+0.0435j,  0.0007-0.0531j,  0.0778+0.0846j,  ...,
          0.0926+0.0129j, -0.0261-0.0651j, -0.1097-0.0358j],
        [-0.0055+0.0342j, -0.0051-0.0238j,  0.0385+0.0158j,  ...,
          0.0187+0.0220j,  0.0001-0.0277j, -0.0031-0.0277j]],
       dtype=torch.complex128)


In [7]:
# Inspect the odd-row / even-column block of eT from the double_krylov_2 run.
print(f"{Hess_L}")

tensor([[  3.6892e+00-4.2706e+00j,  -1.5690e+00+2.6163e+00j,
           8.5068e-01-1.4703e+00j,  ...,
           1.4332e-04-4.5069e-04j,  -1.8892e-05+5.2614e-05j,
           2.0398e-06-5.4976e-06j],
        [  1.9872e+00-2.3117e+00j,  -6.2293e-01+9.7552e-01j,
           2.5071e-01-4.1571e-01j,  ...,
           1.7000e-05-4.9196e-05j,  -1.9861e-06+5.2104e-06j,
           1.9465e-07-4.9840e-07j],
        [  5.9526e-01-1.4460e+00j,  -1.5248e-01+4.3774e-01j,
           5.3928e-02-1.4786e-01j,  ...,
           1.9184e-06-7.4003e-06j,  -2.1348e-07+7.1844e-07j,
           1.9716e-08-6.3329e-08j],
        ...,
        [ 1.9501e-169-3.8074e-169j, -2.1709e-171+4.2666e-171j,
          3.5050e-173-6.8942e-173j,  ...,
          3.5576e-184-7.2919e-184j, -3.8228e-186+7.7887e-186j,
          3.9794e-188-8.1038e-188j],
        [ 3.2585e-171-6.2263e-171j, -3.5913e-173+6.9085e-173j,
          5.7407e-175-1.1054e-174j,  ...,
          5.5032e-186-1.1045e-185j, -5.8575e-188+1.1690e-187j,
          6.0409e