Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up parameter resolution for cirq.Duration #6270

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
105 changes: 68 additions & 37 deletions cirq-core/cirq/value/duration.py
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
"""A typed time delta that supports picosecond accuracy."""

from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union
from typing import AbstractSet, Any, Dict, Optional, Tuple, TYPE_CHECKING, Union, List
import datetime

import sympy
import numpy as np

from cirq import protocols
from cirq._compat import proper_repr
from cirq._compat import proper_repr, cached_method
from cirq._doc import document

if TYPE_CHECKING:
Expand Down Expand Up @@ -79,48 +79,53 @@ def __init__(
>>> print(cirq.Duration(micros=1.5 * sympy.Symbol('t')))
(1500.0*t) ns
"""
self._time_vals: List[_NUMERIC_INPUT_TYPE] = [0, 0, 0, 0]
self._multipliers = [1, 1000, 1000_000, 1000_000_000]
if value is not None and value != 0:
if isinstance(value, datetime.timedelta):
# timedelta has microsecond resolution.
micros += int(value / datetime.timedelta(microseconds=1))
self._time_vals[2] = int(value / datetime.timedelta(microseconds=1))
elif isinstance(value, Duration):
picos += value._picos
self._time_vals = value._time_vals
else:
raise TypeError(f'Not a `cirq.DURATION_LIKE`: {repr(value)}.')

val = picos + nanos * 1000 + micros * 1000_000 + millis * 1000_000_000
self._picos: _NUMERIC_OUTPUT_TYPE = float(val) if isinstance(val, np.number) else val
input_vals = [picos, nanos, micros, millis]
self._time_vals = _add_time_vals(self._time_vals, input_vals)

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self._picos)
return protocols.is_parameterized(self._time_vals)

def _parameter_names_(self) -> AbstractSet[str]:
return protocols.parameter_names(self._picos)
return protocols.parameter_names(self._time_vals)

def _resolve_parameters_(self, resolver: 'cirq.ParamResolver', recursive: bool) -> 'Duration':
return Duration(picos=protocols.resolve_parameters(self._picos, resolver, recursive))
return _duration_from_time_vals(
protocols.resolve_parameters(self._time_vals, resolver, recursive)
)

@cached_method
def total_picos(self) -> _NUMERIC_OUTPUT_TYPE:
"""Returns the number of picoseconds that the duration spans."""
return self._picos
val = sum(a * b for a, b in zip(self._time_vals, self._multipliers))
return float(val) if isinstance(val, np.number) else val

def total_nanos(self) -> _NUMERIC_OUTPUT_TYPE:
"""Returns the number of nanoseconds that the duration spans."""
return self._picos / 1000
return self.total_picos() / 1000

def total_micros(self) -> _NUMERIC_OUTPUT_TYPE:
"""Returns the number of microseconds that the duration spans."""
return self._picos / 1000_000
return self.total_picos() / 1000_000

def total_millis(self) -> _NUMERIC_OUTPUT_TYPE:
"""Returns the number of milliseconds that the duration spans."""
return self._picos / 1000_000_000
return self.total_picos() / 1000_000_000

def __add__(self, other) -> 'Duration':
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return Duration(picos=self._picos + other._picos)
return _duration_from_time_vals(_add_time_vals(self._time_vals, other._time_vals))

def __radd__(self, other) -> 'Duration':
return self.__add__(other)
Expand All @@ -129,86 +134,94 @@ def __sub__(self, other) -> 'Duration':
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return Duration(picos=self._picos - other._picos)
return _duration_from_time_vals(
_add_time_vals(self._time_vals, [-x for x in other._time_vals])
)

def __rsub__(self, other) -> 'Duration':
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return Duration(picos=other._picos - self._picos)
return _duration_from_time_vals(
_add_time_vals(other._time_vals, [-x for x in self._time_vals])
)

def __mul__(self, other) -> 'Duration':
if not isinstance(other, (int, float, sympy.Expr)):
return NotImplemented
return Duration(picos=self._picos * other)
if other == 0:
return _duration_from_time_vals([0] * 4)
return _duration_from_time_vals([x * other for x in self._time_vals])

def __rmul__(self, other) -> 'Duration':
return self.__mul__(other)

def __truediv__(self, other) -> Union['Duration', float]:
if isinstance(other, (int, float, sympy.Expr)):
return Duration(picos=self._picos / other)
new_time_vals = [x / other for x in self._time_vals]
return _duration_from_time_vals(new_time_vals)

other_duration = _attempt_duration_like_to_duration(other)
if other_duration is not None:
return self._picos / other_duration._picos
return self.total_picos() / other_duration.total_picos()

return NotImplemented

def __eq__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos == other._picos
return self.total_picos() == other.total_picos()

def __ne__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos != other._picos
return self.total_picos() != other.total_picos()

def __gt__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos > other._picos
return self.total_picos() > other.total_picos()

def __lt__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos < other._picos
return self.total_picos() < other.total_picos()

def __ge__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos >= other._picos
return self.total_picos() >= other.total_picos()

def __le__(self, other):
other = _attempt_duration_like_to_duration(other)
if other is None:
return NotImplemented
return self._picos <= other._picos
return self.total_picos() <= other.total_picos()

def __bool__(self):
return bool(self._picos)
return bool(self.total_picos())

def __hash__(self):
if isinstance(self._picos, (int, float)) and self._picos % 1000000 == 0:
return hash(datetime.timedelta(microseconds=self._picos / 1000000))
return hash((Duration, self._picos))
if isinstance(self.total_picos(), (int, float)) and self.total_picos() % 1000000 == 0:
return hash(datetime.timedelta(microseconds=self.total_picos() / 1000000))
return hash((Duration, self.total_picos()))

def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]:
picos = self.total_picos()
if (
isinstance(self._picos, sympy.Mul)
and len(self._picos.args) == 2
and isinstance(self._picos.args[0], (sympy.Integer, sympy.Float))
isinstance(picos, sympy.Mul)
and len(picos.args) == 2
and isinstance(picos.args[0], (sympy.Integer, sympy.Float))
):
scale = self._picos.args[0]
rest = self._picos.args[1]
scale = picos.args[0]
rest = picos.args[1]
else:
scale = self._picos
scale = picos
rest = 1

if scale % 1000_000_000 == 0:
Expand All @@ -234,7 +247,7 @@ def _decompose_into_amount_unit_suffix(self) -> Tuple[int, str, str]:
return amount * rest, unit, suffix

def __str__(self) -> str:
if self._picos == 0:
if self.total_picos() == 0:
return 'Duration(0)'
amount, _, suffix = self._decompose_into_amount_unit_suffix()
if not isinstance(amount, (int, float, sympy.Symbol)):
Expand All @@ -257,3 +270,21 @@ def _attempt_duration_like_to_duration(value: Any) -> Optional[Duration]:
if isinstance(value, (int, float)) and value == 0:
return Duration()
return None


def _add_time_vals(
val1: List[_NUMERIC_INPUT_TYPE], val2: List[_NUMERIC_INPUT_TYPE]
) -> List[_NUMERIC_INPUT_TYPE]:
ret: List[_NUMERIC_INPUT_TYPE] = []
for i in range(4):
if val1[i] and val2[i]:
ret.append(val1[i] + val2[i])
else:
ret.append(val1[i] or val2[i])
return ret


def _duration_from_time_vals(time_vals: List[_NUMERIC_INPUT_TYPE]):
ret = Duration()
ret._time_vals = time_vals
return ret
2 changes: 2 additions & 0 deletions cirq-core/cirq/value/duration_test.py
Expand Up @@ -168,9 +168,11 @@ def test_sub():
def test_mul():
assert Duration(picos=2) * 3 == Duration(picos=6)
assert 4 * Duration(picos=3) == Duration(picos=12)
assert 0 * Duration(picos=10) == Duration()

t = sympy.Symbol('t')
assert t * Duration(picos=3) == Duration(picos=3 * t)
assert 0 * Duration(picos=t) == Duration(picos=0)

with pytest.raises(TypeError):
_ = Duration() * Duration()
Expand Down