Skip to content

Commit

Permalink
Merge pull request #812 from qutech/issues/781_simple_expression
Browse files Browse the repository at this point in the history
Make SimpleExpression more expression conformant
  • Loading branch information
terrorfisch committed Mar 28, 2024
2 parents d35ac5f + 7d27495 commit 5c6fd63
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 12 deletions.
44 changes: 33 additions & 11 deletions qupulse/program/__init__.py
@@ -1,16 +1,17 @@
import contextlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable
from numbers import Real
from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict
from numbers import Real, Number

import numpy as np

from qupulse._program.waveforms import Waveform
from qupulse.utils.types import MeasurementWindow, TimeType
from qupulse.utils.types import MeasurementWindow, TimeType, FrozenMapping
from qupulse._program.volatile import VolatileRepetitionCount
from qupulse.parameter_scope import Scope
from qupulse.expressions import sympy as sym_expr
from qupulse.expressions import sympy as sym_expr, Expression
from qupulse.utils.sympy import _lambdify_modules

from typing import Protocol, runtime_checkable

Expand All @@ -30,7 +31,7 @@ class SimpleExpression(Generic[NumVal]):
"""

base: NumVal
offsets: Sequence[Tuple[str, NumVal]]
offsets: Dict[str, NumVal]

def value(self, scope: Mapping[str, NumVal]) -> NumVal:
value = self.base
Expand All @@ -43,7 +44,10 @@ def __add__(self, other):
return SimpleExpression(self.base + other, self.offsets)

if type(other) == type(self):
return SimpleExpression(self.base + other.base, self.offsets + other.offsets)
offsets = self.offsets.copy()
for name, value in other.offsets.items():
offsets[name] = value + offsets.get(name, 0)
return SimpleExpression(self.base + other.base, offsets)

return NotImplemented

Expand All @@ -57,22 +61,40 @@ def __rsub__(self, other):
(-self).__add__(other)

def __neg__(self):
return SimpleExpression(-self.base, tuple((name, -value) for name, value in self.offsets))
return SimpleExpression(-self.base, {name: -value for name, value in self.offsets.items()})

def __mul__(self, other: NumVal):
if isinstance(other, SimpleExpression):
return NotImplemented
return SimpleExpression(self.base * other, tuple((name, value * other) for name, value in self.offsets))
if isinstance(other, (float, int, TimeType)):
return SimpleExpression(self.base * other, {name: other * value for name, value in self.offsets.items()})

return NotImplemented

def __rmul__(self, other):
return self.__mul__(other)

def evaluate_in_scope(self, *args, **kwargs):
def __truediv__(self, other):
inv = 1 / other
return self.__mul__(inv)

@property
def free_symbols(self):
return ()

def _sympy_(self):
return self

def replace(self, r, s):
return self

def evaluate_in_scope_(self, *args, **kwargs):
# TODO: remove. It is currently required to avoid nesting this class in an expression for the MappedScope
# We can maybe replace is with a HardwareScope or something along those lines
return self


_lambdify_modules.append({'SimpleExpression': SimpleExpression})


RepetitionCount = Union[int, VolatileRepetitionCount, SimpleExpression[int]]
HardwareTime = Union[TimeType, SimpleExpression[TimeType]]
HardwareVoltage = Union[float, SimpleExpression[float]]
Expand Down
2 changes: 1 addition & 1 deletion qupulse/program/linspace.py
Expand Up @@ -124,7 +124,7 @@ def inner_scope(self, scope: Scope) -> Scope:
process."""
if self._ranges:
name, _ = self._ranges[-1]
return MappedScope(scope, FrozenDict({name: SimpleExpression(base=0, offsets=[(name, 1)])}))
return scope.overwrite({name: SimpleExpression(base=0, offsets=[(name, 1)])})
else:
return scope

Expand Down

0 comments on commit 5c6fd63

Please sign in to comment.