Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qupulse/expressions/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def _try_to_numeric(self) -> Optional[numbers.Number]:
return None
if isinstance(self._original_expression, ALLOWED_NUMERIC_SCALAR_TYPES):
return self._original_expression
expr = self._sympified_expression
expr = self._sympified_expression.doit()
if isinstance(expr, bool):
# sympify can return bool
return expr
Expand Down
82 changes: 80 additions & 2 deletions qupulse/program/linspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ class LinSpaceNode:
def dependencies(self) -> Mapping[int, set]:
raise NotImplementedError

def reversed(self, offset: int, lengths: list):
"""Get the time reversed version of this linspace node. Since this is a non-local operation the arguments give
the context.

Args:
offset: Active iterations that are not reserved
lengths: Lengths of the currently active iterations that have to be reversed

Returns:
Time reversed version.
"""
raise NotImplementedError


@dataclass
class LinSpaceHold(LinSpaceNode):
Expand All @@ -60,13 +73,46 @@ def dependencies(self) -> Mapping[int, set]:
for idx, factors in enumerate(self.factors)
if factors}

def reversed(self, offset: int, lengths: list):
if not lengths:
return self
# If the iteration length is `n`, the starting point is shifted by `n - 1`
steps = [length - 1 for length in lengths]
bases = []
factors = []
for ch_base, ch_factors in zip(self.bases, self.factors):
if ch_factors is None or len(ch_factors) <= offset:
bases.append(ch_base)
factors.append(ch_factors)
else:
ch_reverse_base = ch_base + sum(step * factor
for factor, step in zip(ch_factors[offset:], steps))
reversed_factors = ch_factors[:offset] + tuple(-f for f in ch_factors[offset:])
bases.append(ch_reverse_base)
factors.append(reversed_factors)

if self.duration_factors is None or len(self.duration_factors) <= offset:
duration_factors = self.duration_factors
duration_base = self.duration_base
else:
duration_base = self.duration_base + sum((step * factor
for factor, step in zip(self.duration_factors[offset:], steps)), TimeType(0))
duration_factors = self.duration_factors[:offset] + tuple(-f for f in self.duration_factors[offset:])
return LinSpaceHold(tuple(bases), tuple(factors), duration_base=duration_base, duration_factors=duration_factors)


@dataclass
class LinSpaceArbitraryWaveform(LinSpaceNode):
"""This is just a wrapper to pipe arbitrary waveforms through the system."""
waveform: Waveform
channels: Tuple[ChannelID, ...]

def reversed(self, offset: int, lengths: list):
return LinSpaceArbitraryWaveform(
waveform=self.waveform.reversed(),
channels=self.channels,
)


@dataclass
class LinSpaceRepeat(LinSpaceNode):
Expand All @@ -81,6 +127,9 @@ def dependencies(self):
dependencies.setdefault(idx, set()).update(deps)
return dependencies

def reversed(self, offset: int, counts: list):
return LinSpaceRepeat(tuple(node.reversed(offset, counts) for node in reversed(self.body)), self.count)


@dataclass
class LinSpaceIter(LinSpaceNode):
Expand All @@ -100,6 +149,12 @@ def dependencies(self):
dependencies.setdefault(idx, set()).update(shortened)
return dependencies

def reversed(self, offset: int, lengths: list):
lengths.append(self.length)
reversed_iter = LinSpaceIter(tuple(node.reversed(offset, lengths) for node in reversed(self.body)), self.length)
lengths.pop()
return reversed_iter


class LinSpaceBuilder(ProgramBuilder):
"""This program builder supports efficient translation of pulse templates that use symbolic linearly
Expand Down Expand Up @@ -214,6 +269,14 @@ def with_iteration(self, index_name: str, rng: range,
if cmds:
self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng)))

@contextlib.contextmanager
def time_reversed(self) -> ContextManager['LinSpaceBuilder']:
self._stack.append([])
yield self
inner = self._stack.pop()
offset = len(self._ranges)
self._stack[-1].extend(node.reversed(offset, []) for node in reversed(inner))

def to_program(self) -> Optional[Sequence[LinSpaceNode]]:
if self._root():
return self._root()
Expand Down Expand Up @@ -414,8 +477,10 @@ def to_increment_commands(linspace_nodes: Sequence[LinSpaceNode]) -> List[Comman


class LinSpaceVM:
def __init__(self, channels: int):
def __init__(self, channels: int,
sample_resolution: TimeType = TimeType.from_fraction(1, 2)):
self.current_values = [np.nan] * channels
self.sample_resolution = sample_resolution
self.time = TimeType(0)
self.registers = tuple({} for _ in range(channels))

Expand All @@ -428,7 +493,20 @@ def __init__(self, channels: int):

def change_state(self, cmd: Union[Set, Increment, Wait, Play]):
if isinstance(cmd, Play):
raise NotImplementedError("TODO: Implement arbitrary waveform simulation")
dt = self.sample_resolution
t = TimeType(0)
total_duration = cmd.waveform.duration
while t <= total_duration and dt > 0:
sample_time = np.array([float(t)])
values = []
for (idx, ch) in enumerate(cmd.channels):
self.current_values[idx] = values.append(cmd.waveform.get_sampled(channel=ch, sample_times=sample_time)[0])
self.history.append(
(self.time, self.current_values.copy())
)
dt = min(total_duration - t, self.sample_resolution)
self.time += dt
t += dt
elif isinstance(cmd, Wait):
self.history.append(
(self.time, self.current_values.copy())
Expand Down
3 changes: 3 additions & 0 deletions qupulse/program/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,3 +1277,6 @@ def compare_key(self) -> Hashable:

def reversed(self) -> 'Waveform':
return self._inner

def __repr__(self):
return f"ReversedWaveform(inner={self._inner!r})"
4 changes: 4 additions & 0 deletions tests/expressions/expression_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ def test_special_function_numeric_evaluation(self):

np.testing.assert_allclose(expected, result)

def test_try_to_numeric(self):
expr = ExpressionScalar('Sum(9, (x, 0, 5), (y, 0, 7))')
self.assertEqual(expr._try_to_numeric(), 9*6*8)

def test_evaluate_with_exact_rationals(self):
expr = ExpressionScalar('1 / 3')
self.assertEqual(TimeType.from_fraction(1, 3), expr.evaluate_with_exact_rationals({}))
Expand Down
2 changes: 1 addition & 1 deletion tests/program/linspace_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def assert_vm_output_almost_equal(test: TestCase, expected, actual):
test.assertEqual(t_e, t_a, f"Differing times in {idx} element")
test.assertEqual(len(vals_e), len(vals_a), f"Differing channel count in {idx} element")
for ch, (val_e, val_a) in enumerate(zip(vals_e, vals_a)):
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} element channel {ch}")
test.assertAlmostEqual(val_e, val_a, msg=f"Differing values in {idx} of {len(expected)} element channel {ch}")


class SingleRampTest(TestCase):
Expand Down
41 changes: 36 additions & 5 deletions tests/pulses/time_reversal_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from qupulse.pulses.time_reversal_pulse_template import TimeReversalPulseTemplate
from qupulse.utils.types import TimeType
from qupulse.expressions import ExpressionScalar

from qupulse.program.loop import LoopBuilder
from qupulse.program.linspace import LinSpaceBuilder, LinSpaceVM, to_increment_commands
from tests.pulses.sequencing_dummies import DummyPulseTemplate
from tests.serialization_tests import SerializableTests

from tests.program.linspace_tests import assert_vm_output_almost_equal

class TimeReversalPulseTemplateTests(unittest.TestCase):
def test_simple_properties(self):
Expand All @@ -29,19 +30,49 @@ def test_simple_properties(self):

self.assertEqual(reversed_pt.identifier, 'reverse')

def test_time_reversal_program(self):
def test_time_reversal_loop(self):
inner = ConstantPT(4, {'a': 3}) @ FunctionPT('sin(t)', 5, channel='a')
manual_reverse = FunctionPT('sin(5 - t)', 5, channel='a') @ ConstantPT(4, {'a': 3})
time_reversed = TimeReversalPulseTemplate(inner)

program = time_reversed.create_program()
manual_program = manual_reverse.create_program()
program = time_reversed.create_program(program_builder=LoopBuilder())
manual_program = manual_reverse.create_program(program_builder=LoopBuilder())

t, data, _ = render(program, 9 / 10)
_, manual_data, _ = render(manual_program, 9 / 10)

np.testing.assert_allclose(data['a'], manual_data['a'])

def test_time_reversal_linspace(self):
constant_pt = ConstantPT(4, {'a': '3.0 + x * 1.0 + y * -0.3'})
function_pt = FunctionPT('sin(t)', 5, channel='a')
reversed_function_pt = function_pt.with_time_reversal()

inner = (constant_pt @ function_pt).with_iteration('x', 6)
inner_manual = (reversed_function_pt @ constant_pt).with_iteration('x', (5, -1, -1))

outer = inner.with_time_reversal().with_iteration('y', 8)
outer_man = inner_manual.with_iteration('y', 8)

self.assertEqual(outer.duration, outer_man.duration)

program = outer.create_program(program_builder=LinSpaceBuilder(channels=('a',)))
manual_program = outer_man.create_program(program_builder=LinSpaceBuilder(channels=('a',)))

commands = to_increment_commands(program)
manual_commands = to_increment_commands(manual_program)
self.assertEqual(commands, manual_commands)

manual_vm = LinSpaceVM(1)
manual_vm.set_commands(manual_commands)
manual_vm.run()

vm = LinSpaceVM(1)
vm.set_commands(commands)
vm.run()

assert_vm_output_almost_equal(self, manual_vm.history, vm.history)


class TimeReversalPulseTemplateSerializationTests(unittest.TestCase, SerializableTests):
@property
Expand Down
Loading