Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create tests for functions in qutip #64

Merged
merged 11 commits into from
Aug 26, 2024
Merged

Conversation

rochisha0
Copy link
Contributor

@rochisha0 rochisha0 commented Aug 7, 2024

In this task, we aim to create comprehensive tests for the functions within the QuTiP library to check their compatibility with jax.grad and jax.jit. These tests will ensure the correctness and robustness of the implemented functions and will cover a wide range of scenarios.

@rochisha0 rochisha0 marked this pull request as draft August 7, 2024 15:56
Comment on lines 19 to 23

with qutip.CoreOptions(default_dtype="jax"):
X = qutip.sigmax()
I = qutip.qeye(2)
CNOT = qutip.tensor(qutip.basis(2, 0) * qutip.basis(2, 0).dag(), I) + qutip.tensor(qutip.basis(2, 1) * qutip.basis(2, 1).dag(), X)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CNOT is not used, let's remove it.

Comment on lines 35 to 36
print(f"{name} (original):", result)
print(f"{name} (JIT):", result_jit)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No print in tests.

Suggested change
print(f"{name} (original):", result)
print(f"{name} (JIT):", result_jit)

with qutip.CoreOptions(default_dtype="jax"):
basis_0 = qutip.basis(2, 0)
basis_1 = qutip.basis(2, 1)
bell_state = (qutip.tensor(basis_0, basis_1) + qutip.tensor(basis_1, basis_0)).unit()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
bell_state = (qutip.tensor(basis_0, basis_1) + qutip.tensor(basis_1, basis_0)).unit()
bell_state = qutip.bell_state("10")

Comment on lines 16 to 17
density_matrix = bell_state * bell_state.dag()
dm = qutip.rand_dm([5, 5], distribution="pure")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bell_dm and rand_dm?

Comment on lines 19 to 22
with qutip.CoreOptions(default_dtype="jax"):
X = qutip.sigmax()
I = qutip.qeye(2)
CNOT = qutip.tensor(qutip.basis(2, 0) * qutip.basis(2, 0).dag(), I) + qutip.tensor(qutip.basis(2, 1) * qutip.basis(2, 1).dag(), X)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not used. We have it as qutip.gates.cnot instead of building it yourself.

@rochisha0 rochisha0 changed the title Create tests for metrics and entropy Create tests for functions in qutip Aug 14, 2024
# Pytest test case for gradient computation
@pytest.mark.parametrize("omega_val", [1.0, 2.0, 3.0])
def test_gradient_mcsolve(omega_val):
H, state, tlist, c_ops, e_ops = setup_system(size=2)
Copy link
Contributor

@Sampreet Sampreet Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the system has a continuous variable harmonic oscillator model (defined by the ladder operator a inside setup_system), would it be better to use a Hilbert space size (defined by the parameter size here) greater than or equal to 10?


# Test setup for gradient calculation
def setup_system(size=2):
a = qt.destroy(size).to("jax")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the Hamiltonian contains two sub-systems, I think the operators would be tensor products with corresponding identities for the other operators:

a = qt.tensor(qt.destroy(size), qt.qeye(2)).to('jaxdia')

Or alternatively,

a = qt.destroy(size).to('jaxdia') & qt.qeye(2).to('jaxdia')

Same goes for sm:

sm = qt.qeye(size).to('jaxdia') & qt.sigmax().to('jaxdia')

Accordingly, the initial state would be:

state = qt.basis(size, size - 1).to('jax') & qutip.basis(2, 1).to('jax')

Or alternatively,

state = qt.basis([size, 2], [size - 1, 1]).to('jax')

@rochisha0 rochisha0 marked this pull request as ready for review August 22, 2024 19:53
@coveralls
Copy link

coveralls commented Aug 23, 2024

Pull Request Test Coverage Report for Build 10554916907

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 0 of 0 changed or added relevant lines in 0 files are covered.
  • 6 unchanged lines in 2 files lost coverage.
  • Overall coverage decreased (-0.2%) to 90.303%

Files with Coverage Reduction New Missed Lines %
src/qutip_jax/binops.py 1 97.0%
src/qutip_jax/settings.py 5 70.59%
Totals Coverage Status
Change from base Build 9842191187: -0.2%
Covered Lines: 1220
Relevant Lines: 1351

💛 - Coveralls

Copy link
Member

@Ericgig Ericgig left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall look good.
Let's just make the mcsolve test faster.

return result.expect[0][-1].real

# Pytest test case for gradient computation
@pytest.mark.parametrize("omega_val", [1.0, 2.0, 3.0])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is the slowest and grad functionality is not affected by the value of omega, so let's run it only once.
We could loop over options (improved_sampling, store_states, keep_runs_results) or mixed state input later.

Suggested change
@pytest.mark.parametrize("omega_val", [1.0, 2.0, 3.0])
@pytest.mark.parametrize("omega_val", [2.0])

e_ops = [a.dag() * a, sm.dag() * sm]

# Time list
tlist = jnp.linspace(0.0, 10.0, 101)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's speed up the tests, the actual range does not affect the validity.

Suggested change
tlist = jnp.linspace(0.0, 10.0, 101)
tlist = jnp.linspace(0.0, 1.0, 101)

Comment on lines 41 to 43
H[1][1] = qt.coefficient(H_1_coeff, args={"omega": omega})

result = mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={"method": "diffrax"})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing args to a solver should overwrite the existing values.

Suggested change
H[1][1] = qt.coefficient(H_1_coeff, args={"omega": omega})
result = mcsolve(H, state, tlist, c_ops, e_ops, ntraj=10, options={"method": "diffrax"})
result = mcsolve(
H, state, tlist, c_ops, e_ops, ntraj=10,
args={"omega": omega},
options={"method": "diffrax"}
)

@Ericgig Ericgig merged commit 0cc6256 into qutip:master Aug 26, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants