Skip to content

Commit

Permalink
Merge pull request #2382 from magzpavz/master
Browse files Browse the repository at this point in the history
fix end condition on mcsolve when using target tolerance
  • Loading branch information
Ericgig committed Apr 8, 2024
2 parents a9bbac4 + 8eb3c98 commit 8035590
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/changes/2382.bugfix
@@ -0,0 +1 @@
Ensure that end_condition of mcsolve result doesn't say target tolerance reached when it hasn't
5 changes: 4 additions & 1 deletion qutip/solver/result.py
Expand Up @@ -657,7 +657,10 @@ def _target_tolerance_end(self):

self._estimated_ntraj = min(target_ntraj, self._target_ntraj)
if (self._estimated_ntraj - self.num_trajectories) <= 0:
self.stats["end_condition"] = "target tolerance reached"
if (self._estimated_ntraj - self._target_ntraj) < 0:
self.stats["end_condition"] = "target tolerance reached"
else:
self.stats["end_condition"] = "ntraj reached"
return self._estimated_ntraj - self.num_trajectories

def _post_init(self):
Expand Down
22 changes: 22 additions & 0 deletions qutip/tests/solver/test_mcsolve.py
Expand Up @@ -391,6 +391,28 @@ def test_timeout(improved_sampling):
timeout=1e-6)
assert res.stats['end_condition'] == 'timeout'

@pytest.mark.parametrize("improved_sampling", [True, False])
def test_target_tol(improved_sampling):
size = 10
ntraj = 100
a = qutip.destroy(size)
H = qutip.num(size)
state = qutip.basis(size, size-1)
times = np.linspace(0, 1.0, 100)
coupling = 0.5
n_th = 0.05
c_ops = np.sqrt(coupling * (n_th + 1)) * a
e_ops = [qutip.num(size)]

options = {'map': 'serial', "improved_sampling": improved_sampling}

res = mcsolve(H, state, times, c_ops, e_ops, ntraj=ntraj, options=options,
target_tol = 0.5)
assert res.stats['end_condition'] == 'target tolerance reached'

res = mcsolve(H, state, times, c_ops, e_ops, ntraj=ntraj, options=options,
target_tol = 1e-6)
assert res.stats['end_condition'] == 'ntraj reached'

@pytest.mark.parametrize("improved_sampling", [True, False])
def test_super_H(improved_sampling):
Expand Down

0 comments on commit 8035590

Please sign in to comment.