diff --git a/tests/test_timeevol.py b/tests/test_timeevol.py index 676c3893..5dc3a227 100644 --- a/tests/test_timeevol.py +++ b/tests/test_timeevol.py @@ -203,6 +203,146 @@ def objective_function(params): print(objective_function(tc.backend.ones(4))) +def test_ode_evol_jit_grad(highp, jaxb): + try: + import diffrax # pylint: disable=unused-import + except ImportError: + pytest.skip("diffrax not installed, skipping test") + + zz_ham = tc.quantum.PauliStringSum2COO([[3, 3, 0, 0], [0, 3, 3, 0]], [1, 1]) + x_ham = tc.quantum.PauliStringSum2COO([[1, 0, 0, 0], [0, 1, 0, 0]], [1, 1]) + + c = tc.Circuit(4) + c.x([1, 3]) + psi0 = c.state() + + # Example with parameterized Hamiltonian and optimization + def parametrized_hamiltonian(t, *params): + # params = [J0, J1, h0, h1] - parameters to optimize + J_t = params[0] + params[1] * tc.backend.sin(2.0 * t) + h_t = params[2] + params[3] * tc.backend.cos(1.5 * t) + + return J_t * zz_ham + h_t * x_ham + + def zz_correlation(state): + n = int(np.log2(state.shape[0])) + circuit = tc.Circuit(n, inputs=state) + return circuit.expectation_ps(z=[0, 1]) + + @tc.backend.jit + @tc.backend.value_and_grad + def kv_ode_solver_(params): + states = tc.timeevol.ode_evol_global( + parametrized_hamiltonian, + psi0, + tc.backend.convert_to_tensor([0, 10.0]), + None, + *params, + atol=1.0e-15, + rtol=1.0e-15, + solver="Kvaerno5", + ode_backend="diffrax", + ) + return tc.backend.real(zz_correlation(states[-1])) + + @tc.backend.jit + @tc.backend.value_and_grad + def ts_ode_solver_(params): + states = tc.timeevol.ode_evol_global( + parametrized_hamiltonian, + psi0, + tc.backend.convert_to_tensor([0, 10.0]), + None, + *params, + ode_backend="diffrax", + atol=1.0e-13, + rtol=1.0e-13, + dt0=0.005, + ) + return tc.backend.real(zz_correlation(states[-1])) + + @tc.backend.jit + @tc.backend.value_and_grad + def do5_ode_solver_(params): + states = tc.timeevol.ode_evol_global( + parametrized_hamiltonian, + psi0, + tc.backend.convert_to_tensor([0, 10.0]), + None, + *params, + ) + return tc.backend.real(zz_correlation(states[-1])) + + paras = np.random.rand(4) + s1 = kv_ode_solver_(paras) + s2 = ts_ode_solver_(paras) + s3 = do5_ode_solver_(paras) + + v1, g1 = s1 + v2, g2 = s2 + v3, g3 = s3 + + np.testing.assert_allclose(g1, g3, atol=1e-8, rtol=0) + np.testing.assert_allclose(g1, g2, atol=1e-8, rtol=0) + np.testing.assert_allclose(v1, v3, atol=1e-8, rtol=0) + np.testing.assert_allclose(v1, v2, atol=1e-8, rtol=0) + + ###################################################################### + + def local_hamiltonian(t, Omega, phi): + angle = phi * t + coeff = Omega * tc.backend.cos(2.0 * t) # Amplitude modulation + + # Single-qubit Rabi Hamiltonian (2x2 matrix) + hx = coeff * tc.backend.cos(angle) * tc.gates.x().tensor + hy = coeff * tc.backend.sin(angle) * tc.gates.y().tensor + return hx + hy + + # Initial state: GHZ state |0000⟩ + |1111⟩ + c = tc.Circuit(4) + c.h(0) + for i in range(3): + c.cnot(i, i + 1) + psi0 = c.state() + + # Evolve with local Hamiltonian acting on qubit 1 + @tc.backend.jit + @tc.backend.value_and_grad + def ts_ode_solver_local(paras): + states = tc.timeevol.ode_evol_local( + local_hamiltonian, + psi0, + tc.backend.convert_to_tensor([0, 10.0]), + [2], # Apply to qubit 1 + None, + *paras, # Omega=1.0, phi=2.0 + ode_backend="diffrax", + ) + return tc.backend.real(zz_correlation(states[-1])) + + @tc.backend.jit + @tc.backend.value_and_grad + def do5_ode_solver_local(paras): + states = tc.timeevol.ode_evol_local( + local_hamiltonian, + psi0, + tc.backend.convert_to_tensor([0, 10.0]), + [2], # Apply to qubit 1 + None, + *paras, # Omega=1.0, phi=2.0 + ) + return tc.backend.real(zz_correlation(states[-1])) + + paras = np.random.rand(2) + s1 = ts_ode_solver_local(paras) + s2 = do5_ode_solver_local(paras) + v1, g1 = s1 + v2, g2 = s2 + + np.testing.assert_allclose(g1, g2, atol=1e-8, rtol=0) + np.testing.assert_allclose(v1, v2, atol=1e-8, rtol=0) + + @pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")]) def test_ed_evol(backend): n = 4