233 changes: 125 additions & 108 deletions qiskit_experiments/database_service/db_experiment_data.py
@@ -172,43 +172,21 @@ def _set_service_from_backend(self, backend: Union[Backend, BaseBackend]) -> None:
def add_data(
self,
data: Union[Result, List[Result], Job, List[Job], Dict, List[Dict]],
post_processing_callback: Optional[Callable] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Add experiment data.

Note:
This method is not thread safe and should not be called by the
`post_processing_callback` function.

Note:
If `data` is a ``Job``, this method waits for the job to finish
and calls the `post_processing_callback` function asynchronously.

Args:
data: Experiment data to add.
Several types are accepted for convenience:

* Result: Add data from this ``Result`` object.
* List[Result]: Add data from the ``Result`` objects.
* Job: Add data from the job result.
* List[Job]: Add data from the job results.
* Dict: Add this data.
* List[Dict]: Add this list of data.

post_processing_callback: Callback function invoked when data is
added. If `data` is a ``Job``, the callback is only invoked when
the job finishes successfully.
The following positional arguments are provided to the callback function:

* This ``DbExperimentData`` object.
* Additional keyword arguments passed to this method.

timeout: Timeout waiting for job to finish, if `data` is a ``Job``.

**kwargs: Keyword arguments to be passed to the callback function.

Raises:
TypeError: If the input data type is invalid.
"""
@@ -218,88 +196,115 @@ def add_data(
"Not all post-processing has finished. Adding new data "
"may create unexpected analysis results."
)
if not isinstance(data, list):
data = [data]

# Extract job data and directly add non-job data
job_data = []
for datum in data:
if isinstance(datum, (Job, BaseJob)):
job_data.append(datum)
elif isinstance(datum, dict):
self._add_single_data(datum)
elif isinstance(datum, Result):
self._add_result_data(datum)
else:
raise TypeError(f"Invalid data type {type(datum)}.")

if isinstance(data, (Job, BaseJob)):
if self.backend and self.backend.name() != data.backend().name():
# Add futures for job data
for job in job_data:
if self.backend and self.backend.name() != job.backend().name():
LOG.warning(
"Adding a job from a backend (%s) that is different "
"than the current backend (%s). "
"The new backend will be used, but "
"service is not changed if one already exists.",
data.backend(),
job.backend(),
self.backend,
)
self._backend = data.backend()
self._backend = job.backend()
Collaborator comment:

maybe

Suggested change:
-    self._backend = job.backend()
+    if not self._backend:
+        self._backend = job.backend()

?

if not self._service:
self._set_service_from_backend(self._backend)

self._jobs[data.job_id()] = data
job_kwargs = {
"job": data,
"job_done_callback": post_processing_callback,
"timeout": timeout,
**kwargs,
}
self._job_futures.append(
(
job_kwargs,
self._executor.submit(self._wait_for_job, **job_kwargs),
)
self._jobs[job.job_id()] = job

job_kwargs = {
"jobs": job_data,
"timeout": timeout,
}
self._job_futures.append(
(
job_kwargs,
self._executor.submit(self._add_jobs_result, **job_kwargs),
)
if self.auto_save:
self.save_metadata()
return
)
if self.auto_save:
self.save_metadata()

if isinstance(data, dict):
self._add_single_data(data)
elif isinstance(data, Result):
self._add_result_data(data)
elif isinstance(data, list):
for dat in data:
self.add_data(dat)
def add_processing_callback(
self,
callback: Callable,
**kwargs: Any,
):
"""Add processing callback for after experiment job has run.

This method waits for the last set of jobs to finish and calls
the `callback` function asynchronously.

Note:
This method is not thread safe and should not be called by the
`callback` function.

Args:
callback: Callback function invoked when the jobs finish successfully.
The following positional arguments are provided to the callback function:

* This ``DbExperimentData`` object.
* Additional keyword arguments passed to this method.
**kwargs: Keyword arguments to be passed to the callback function.
"""
# Check if there are no futures to wait on
if not self._job_futures:
callback(self, **kwargs)
else:
raise TypeError(f"Invalid data type {type(data)}.")
# Get the last added future and add a done callback
_, future = self._job_futures[-1]
Collaborator comment:

What happens to self._job_futures[:-1]? Are jobs always completed in the order they were submitted? I guess this is not guaranteed, because jobs are scheduled according to some formula that prioritizes them by cost, i.e. number of circuits, shots, etc.


def future_callback(fut):
if not fut.cancelled():
callback(self, **kwargs)

if post_processing_callback is not None:
post_processing_callback(self, **kwargs)
future.add_done_callback(future_callback)

def _wait_for_job(
if self.auto_save:
self.save_metadata()
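
As background for the `fut.cancelled()` guard above: `concurrent.futures` done callbacks fire on cancellation as well as on completion, so the guard keeps post-processing from running on a cancelled future. A standalone sketch (not PR code):

```python
from concurrent.futures import ThreadPoolExecutor
import time

def work():
    time.sleep(0.1)
    return 42

def on_done(fut):
    # Done callbacks also fire for cancelled futures, where fut.result()
    # would raise CancelledError; hence the cancelled() check.
    if fut.cancelled():
        print("cancelled; skipping post-processing")
    else:
        print("post-processing result:", fut.result())

with ThreadPoolExecutor() as executor:
    fut = executor.submit(work)
    fut.add_done_callback(on_done)
```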

def _add_jobs_result(
self,
job: Union[Job, BaseJob],
job_done_callback: Optional[Callable] = None,
jobs: List[Union[Job, BaseJob]],
timeout: Optional[float] = None,
**kwargs: Any,
) -> None:
"""Wait for a job to finish.
"""Wait for a job to finish and add job result data.

Args:
job: Job to wait for.
job_done_callback: Callback function to invoke when job finishes.
jobs: Jobs to wait for.
timeout: Timeout waiting for job to finish.
**kwargs: Keyword arguments to be passed to the callback function.

Raises:
Exception: If post processing failed.
Exception: If any of the jobs failed.
"""
LOG.debug("Waiting for job %s to finish.", job.job_id())
try:
for job in jobs:
LOG.debug("Waiting for job %s to finish.", job.job_id())
try:
job_result = job.result(timeout=timeout)
except TypeError: # Not all jobs take timeout.
job_result = job.result()
with self._data.lock:
# Hold the lock so we add the block of results together.
self._add_result_data(job_result)
except Exception: # pylint: disable=broad-except
LOG.warning("Job %s failed:\n%s", job.job_id(), traceback.format_exc())
raise

try:
if job_done_callback:
job_done_callback(self, **kwargs)
except Exception: # pylint: disable=broad-except
LOG.warning("Post processing function failed:\n%s", traceback.format_exc())
raise
try:
job_result = job.result(timeout=timeout)
except TypeError: # Not all jobs take timeout.
job_result = job.result()
with self._data.lock:
# Hold the lock so we add the block of results together.
self._add_result_data(job_result)
except Exception: # pylint: disable=broad-except
LOG.warning("Job %s failed:\n%s", job.job_id(), traceback.format_exc())
raise
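
The `with self._data.lock` block above keeps each job's results contiguous when several futures add data concurrently. A minimal standalone sketch of the same pattern (a toy container, not the PR's thread-safe list):

```python
import threading

class LockedList:
    """Toy container mimicking the lock-guarded batch insert."""

    def __init__(self):
        self.lock = threading.Lock()
        self._items = []

    def add_batch(self, batch):
        # Hold the lock across the whole batch so another thread
        # cannot interleave its items with this one's.
        with self.lock:
            for item in batch:
                self._items.append(item)
```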

def _add_result_data(self, result: Result) -> None:
"""Add data from a Result object
@@ -811,12 +816,13 @@ def load(cls, experiment_id: str, service: DatabaseServiceV1) -> "DbExperimentData":
def cancel_jobs(self) -> None:
"""Cancel any running jobs."""
for kwargs, fut in self._job_futures.copy():
job = kwargs["job"]
if not fut.done() and job.status() not in JOB_FINAL_STATES:
try:
job.cancel()
except Exception as err: # pylint: disable=broad-except
LOG.info("Unable to cancel job %s: %s", job.job_id(), err)
if not fut.done():
for job in kwargs["jobs"]:
if job.status() not in JOB_FINAL_STATES:
try:
job.cancel()
except Exception as err: # pylint: disable=broad-except
LOG.info("Unable to cancel job %s: %s", job.job_id(), err)

def block_for_results(self, timeout: Optional[float] = None) -> "DbExperimentDataV1":
"""Block until all pending jobs and their post processing finish.
@@ -827,11 +833,21 @@ def block_for_results(self, timeout: Optional[float] = None) -> "DbExperimentDataV1":
Returns:
The experiment data with finished jobs and post-processing.
"""
job_ids = []
job_futures = []
for kwargs, fut in self._job_futures.copy():
job = kwargs["job"]
LOG.info("Waiting for job %s and its post processing to finish.", job.job_id())
job_ids += [job.job_id() for job in kwargs["jobs"]]
job_futures.append(fut)
LOG.info("Waiting for jobs %s and its post processing to finish.", job_ids)

# The Python concurrency module does not invoke the future done callback
# functions when calling Future.result(), so we force them to run before
# returning by explicitly calling Future.set_result(), which does invoke
# the done callback functions.
for fut in job_futures:
with contextlib.suppress(Exception):
fut.result(timeout)
fut_res = fut.result(timeout=timeout)
fut.set_result(fut_res)
return self
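
Typical usage (hypothetical, following the pattern above): block before inspecting status or results.

```python
exp_data.block_for_results(timeout=300)  # wait for jobs and their callbacks
print(exp_data.status())  # e.g. "DONE" once everything has finished
```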

def status(self) -> str:
@@ -862,24 +878,26 @@ def status(self) -> str:
with self._job_futures.lock:
for idx, item in enumerate(self._job_futures):
kwargs, fut = item
job = kwargs["job"]
job_status = job.status()
statuses.add(job_status)
if job_status == JobStatus.ERROR:
job_err = "."
if hasattr(job, "error_message"):
job_err = ": " + job.error_message()
self._errors.append(f"Job {job.job_id()} failed{job_err}")

if fut.done():
self._job_futures[idx] = None
ex = fut.exception()
if ex:
self._errors.append(
f"Post processing for job {job.job_id()} failed: \n"
+ "".join(traceback.format_exception(type(ex), ex, ex.__traceback__))
)
statuses.add(JobStatus.ERROR)
for job in kwargs["jobs"]:
job_status = job.status()
statuses.add(job_status)
if job_status == JobStatus.ERROR:
job_err = "."
if hasattr(job, "error_message"):
job_err = ": " + job.error_message()
self._errors.append(f"Job {job.job_id()} failed{job_err}")

if fut.done():
self._job_futures[idx] = None
ex = fut.exception()
if ex:
self._errors.append(
f"Post processing for job {job.job_id()} failed: \n"
+ "".join(
traceback.format_exception(type(ex), ex, ex.__traceback__)
)
)
statuses.add(JobStatus.ERROR)

self._job_futures = ThreadSafeList(list(filter(None, self._job_futures)))

@@ -946,12 +964,11 @@ def _copy_metadata(
# inherits an abstract class.
extra_kwargs = {}
for key, val in orig_kwargs.items():
if key not in ["job", "job_done_callback", "timeout"]:
if key not in ["jobs", "timeout"]:
extra_kwargs[key] = val

new_instance.add_data(
data=orig_kwargs["job"],
post_processing_callback=orig_kwargs["job_done_callback"],
data=orig_kwargs["jobs"],
timeout=orig_kwargs["timeout"],
**extra_kwargs,
)