Skip to content

Commit

Permalink
Merge pull request #2150 from Ericgig/misc.progressbar
Browse files Browse the repository at this point in the history
Add a test for progress bar
  • Loading branch information
Ericgig committed Apr 12, 2023
2 parents df08d96 + 041441c commit eb06fc4
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 87 deletions.
1 change: 1 addition & 0 deletions doc/changes/2150.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a test for progress_bar
11 changes: 5 additions & 6 deletions qutip/ipynbtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,15 @@ def parallel_map(task, values, task_args=None, task_kwargs=None,
view.wait(ar_list)
else:
if progress_bar is True:
progress_bar = HTMLProgressBar()

n = len(ar_list)
progress_bar.start(n)
progress_bar = HTMLProgressBar(len(ar_list))
prev_finished = 0
while True:
n_finished = sum([ar.progress for ar in ar_list])
progress_bar.update(n_finished)
for _ in range(prev_finished, n_finished):
progress_bar.update()
prev_finished = n_finished

if view.wait(ar_list, timeout=0.5):
progress_bar.update(n)
break
progress_bar.finished()

Expand Down
14 changes: 6 additions & 8 deletions qutip/solve/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,13 @@ def serial_map(task, values, task_args=tuple(), task_kwargs={}, **kwargs):
try:
progress_bar = kwargs['progress_bar']
if progress_bar is True:
progress_bar = TextProgressBar()
progress_bar = TextProgressBar(len(values))
except:
progress_bar = BaseProgressBar()
progress_bar = BaseProgressBar(len(values))

progress_bar.start(len(values))
results = []
for n, value in enumerate(values):
progress_bar.update(n)
progress_bar.update()
result = task(value, *task_args, **task_kwargs)
results.append(result)
progress_bar.finished()
Expand Down Expand Up @@ -199,16 +198,15 @@ def parallel_map(task, values, task_args=tuple(), task_kwargs={}, **kwargs):
try:
progress_bar = kwargs['progress_bar']
if progress_bar is True:
progress_bar = TextProgressBar()
progress_bar = TextProgressBar(len(values))
except:
progress_bar = BaseProgressBar()
progress_bar = BaseProgressBar(len(values))

progress_bar.start(len(values))
nfinished = [0]

def _update_progress_bar(x):
nfinished[0] += 1
progress_bar.update(nfinished[0])
progress_bar.update()

try:
pool = Pool(processes=kw['num_cpus'])
Expand Down
9 changes: 3 additions & 6 deletions qutip/solve/pdpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ def __init__(self, H=None, state0=None, times=None, c_ops=[], sc_ops=[],
if options is None:
options = SolverOptions()

if progress_bar is None:
progress_bar = TextProgressBar()

self.H = H
self.d1 = d1
self.d2 = d2
Expand Down Expand Up @@ -200,6 +197,9 @@ def __init__(self, H=None, state0=None, times=None, c_ops=[], sc_ops=[],
else:
self.map_func = serial_map

if progress_bar is None:
self.progress_bar = TextProgressBar(self.ntraj)

self.map_kwargs = map_kwargs if map_kwargs is not None else {}

# Does any operator depend on time?
Expand Down Expand Up @@ -352,7 +352,6 @@ def _ssepdpsolve_generic(sso, options, progress_bar):
for c in sso.c_ops:
Heff += -0.5j * c.dag() * c

progress_bar.start(sso.ntraj)
for n in range(sso.ntraj):
progress_bar.update(n)
psi_t = _data.dense.fast_from_numpy(sso.state0.full().ravel())
Expand Down Expand Up @@ -483,8 +482,6 @@ def _smepdpsolve_generic(sso, options, progress_bar):
# needs to be modified for TD systems
L = liouvillian(sso.H, sso.c_ops)

progress_bar.start(sso.ntraj)

for n in range(sso.ntraj):
progress_bar.update(n)
rho_t = _data.dense.fast_from_numpy(sso.rho0.full())
Expand Down
5 changes: 2 additions & 3 deletions qutip/solve/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,6 @@ def __init__(self, me, H=None, c_ops=[], sc_ops=[], state0=None,
if options is None:
options = SolverOptions()

if progress_bar is None:
progress_bar = TextProgressBar()

# System
# Cast to QobjEvo so the code has only one version for both the
# constant and time-dependent case.
Expand Down Expand Up @@ -420,6 +417,8 @@ def __init__(self, me, H=None, c_ops=[], sc_ops=[], state0=None,
self.noise_type = 0

# Map
if progress_bar is None:
progress_bar = TextProgressBar(self.ntraj)
self.progress_bar = progress_bar
if self.ntraj > 1 and map_func:
self.map_func = map_func
Expand Down
7 changes: 4 additions & 3 deletions qutip/solver/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .heom.bofin_solvers import HEOMSolver

from .steadystate import steadystate
from ..ui.progressbar import progess_bars
from ..ui.progressbar import progress_bars

# -----------------------------------------------------------------------------
# PUBLIC API
Expand Down Expand Up @@ -495,8 +495,9 @@ def _correlation_3op_dm(solver, state0, tlist, taulist, A, B, C):
solver.options["normalize_output"] = False
solver.options["progress_bar"] = False

progress_bar = progess_bars[old_opt['progress_bar']]()
progress_bar.start(len(taulist) + 1, **old_opt['progress_kwargs'])
progress_bar = progress_bars[old_opt['progress_bar']](
len(taulist) + 1, **old_opt['progress_kwargs']
)
rho_t = solver.run(state0, tlist).states
corr_mat = np.zeros([np.size(tlist), np.size(taulist)], dtype=complex)
progress_bar.update()
Expand Down
7 changes: 4 additions & 3 deletions qutip/solver/floquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .integrator import Integrator
from .result import Result
from time import time
from ..ui.progressbar import progess_bars
from ..ui.progressbar import progress_bars


class FloquetBasis:
Expand Down Expand Up @@ -918,8 +918,9 @@ def run(self, state0, tlist, *, floquet=False, args=None, e_ops=None):
results.add(tlist[0], self._restore_state(_data0, copy=False))
stats["preparation time"] += time() - _time_start

progress_bar = progess_bars[self.options["progress_bar"]]()
progress_bar.start(len(tlist) - 1, **self.options["progress_kwargs"])
progress_bar = progress_bars[self.options["progress_bar"]](
len(tlist) - 1, **self.options["progress_kwargs"]
)
for t, state in self._integrator.run(tlist):
progress_bar.update()
results.add(t, self._restore_state(state, copy=False))
Expand Down
19 changes: 11 additions & 8 deletions qutip/solver/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
import threading
import concurrent.futures
from qutip.ui.progressbar import progess_bars
from qutip.ui.progressbar import progress_bars
from qutip.settings import available_cpu_count

if sys.platform == 'darwin':
Expand Down Expand Up @@ -94,8 +94,9 @@ def serial_map(task, values, task_args=None, task_kwargs=None,
if task_kwargs is None:
task_kwargs = {}
map_kw = _read_map_kw(map_kw)
progress_bar = progess_bars[progress_bar]()
progress_bar.start(len(values), **progress_bar_kwargs)
progress_bar = progress_bars[progress_bar](
len(values), **progress_bar_kwargs
)
end_time = map_kw['timeout'] + time.time()
results = None
if reduce_func is None:
Expand All @@ -104,7 +105,7 @@ def serial_map(task, values, task_args=None, task_kwargs=None,
for n, value in enumerate(values):
if time.time() > end_time:
break
progress_bar.update(n)
progress_bar.update()
try:
result = task(value, *task_args, **task_kwargs)
except Exception as err:
Expand Down Expand Up @@ -177,8 +178,9 @@ def parallel_map(task, values, task_args=None, task_kwargs=None,
end_time = map_kw['timeout'] + time.time()
job_time = map_kw['job_timeout']

progress_bar = progess_bars[progress_bar]()
progress_bar.start(len(values), **progress_bar_kwargs)
progress_bar = progress_bars[progress_bar](
len(values), **progress_bar_kwargs
)

errors = {}
if reduce_func is not None:
Expand Down Expand Up @@ -311,8 +313,9 @@ def loky_pmap(task, values, task_args=None, task_kwargs=None,
os.environ['QUTIP_IN_PARALLEL'] = 'TRUE'
from loky import get_reusable_executor, TimeoutError

progress_bar = progess_bars[progress_bar]()
progress_bar.start(len(values), **progress_bar_kwargs)
progress_bar = progress_bars[progress_bar](
len(values), **progress_bar_kwargs
)

executor = get_reusable_executor(max_workers=map_kw['num_cpus'])
end_time = map_kw['timeout'] + time.time()
Expand Down
7 changes: 4 additions & 3 deletions qutip/solver/solver_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..core import stack_columns, unstack_columns
from .result import Result
from .integrator import Integrator
from ..ui.progressbar import progess_bars
from ..ui.progressbar import progress_bars
from time import time


Expand Down Expand Up @@ -150,8 +150,9 @@ def run(self, state0, tlist, *, args=None, e_ops=None):
results.add(tlist[0], self._restore_state(_data0, copy=False))
stats['preparation time'] += time() - _time_start

progress_bar = progess_bars[self.options['progress_bar']]()
progress_bar.start(len(tlist)-1, **self.options['progress_kwargs'])
progress_bar = progress_bars[self.options['progress_bar']](
len(tlist)-1, **self.options['progress_kwargs']
)
for t, state in self._integrator.run(tlist):
progress_bar.update()
results.add(t, self._restore_state(state, copy=False))
Expand Down
44 changes: 44 additions & 0 deletions qutip/tests/test_progressbar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from qutip.ui.progressbar import progress_bars
import pytest
import time


bars = ["base", "text", "Enhanced"]

try:
import tqdm
bars.append("tqdm")
except ImportError:
bars.append(
pytest.param("tqdm", marks=pytest.mark.skip("module not installed"))
)

try:
import IPython
bars.append("html")
except ImportError:
bars.append(
pytest.param("html", marks=pytest.mark.skip("module not installed"))
)


@pytest.mark.parametrize("pbar", bars)
def test_progressbar(pbar):
N = 5
bar = progress_bars[pbar](N)
assert bar.total_time() < 0
for _ in range(N):
time.sleep(0.25)
bar.update()
bar.finished()
assert bar.total_time() > 0


@pytest.mark.parametrize("pbar", bars[1:])
def test_progressbar_has_print(pbar, capsys):
N = 2
bar = progress_bars[pbar](N)
bar.update()
bar.finished()
out, err = capsys.readouterr()
assert out + err != ""

0 comments on commit eb06fc4

Please sign in to comment.