Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 12 additions & 47 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,19 +266,12 @@ def test_trial_result_str():
prng=value.parse_random_state(0),
simulation_options=ccq.mps_simulator.MPSOptions(),
)
assert (
str(
ccq.mps_simulator.MPSTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_simulator_state=final_simulator_state,
)
)
== """measurements: m=1
output state: TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
])"""
result = ccq.mps_simulator.MPSTrialResult(
params=cirq.ParamResolver({}),
measurements={'m': np.array([[1]])},
final_simulator_state=final_simulator_state,
)
assert 'output state: TensorNetwork' in str(result)


def test_trial_result_repr_pretty():
Expand All @@ -293,40 +286,22 @@ def test_trial_result_repr_pretty():
measurements={'m': np.array([[1]])},
final_simulator_state=final_simulator_state,
)
cirq.testing.assert_repr_pretty(
result,
"""measurements: m=1
output state: TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
])""",
)
cirq.testing.assert_repr_pretty_contains(result, 'output state: TensorNetwork')
cirq.testing.assert_repr_pretty(result, "cirq.MPSTrialResult(...)", cycle=True)


def test_empty_step_result():
q0 = cirq.LineQubit(0)
sim = ccq.mps_simulator.MPSSimulator()
step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0))))
assert (
str(step_result)
== """q(0)=0
TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
])"""
)
assert 'TensorNetwork' in str(step_result)


def test_step_result_repr_pretty():
q0 = cirq.LineQubit(0)
sim = ccq.mps_simulator.MPSSimulator()
step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0))))
cirq.testing.assert_repr_pretty(
step_result,
"""q(0)=0
TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
])""",
)
cirq.testing.assert_repr_pretty_contains(step_result, 'TensorNetwork')
cirq.testing.assert_repr_pretty(step_result, "cirq.MPSSimulatorStepResult(...)", cycle=True)


Expand Down Expand Up @@ -391,13 +366,8 @@ def test_simulate_moment_steps_sample():
step._simulator_state().to_numpy(),
np.asarray([1.0 / math.sqrt(2), 0.0, 1.0 / math.sqrt(2), 0.0]),
)
assert (
str(step)
== """TensorNetwork([
Tensor(shape=(2,), inds=('i_0',), tags=oset([])),
Tensor(shape=(2,), inds=('i_1',), tags=oset([])),
])"""
)
# There are two "Tensor()" copies in the string.
assert len(str(step).split('Tensor(')) == 3
samples = step.sample([q0, q1], repetitions=10)
for sample in samples:
assert np.array_equal(sample, [True, False]) or np.array_equal(
Expand All @@ -412,13 +382,8 @@ def test_simulate_moment_steps_sample():
step._simulator_state().to_numpy(),
np.asarray([1.0 / math.sqrt(2), 0.0, 0.0, 1.0 / math.sqrt(2)]),
)
assert (
str(step)
== """TensorNetwork([
Tensor(shape=(2, 2), inds=('i_0', 'mu_0_1'), tags=oset([])),
Tensor(shape=(2, 2), inds=('mu_0_1', 'i_1'), tags=oset([])),
])"""
)
# There are two "Tensor()" copies in the string.
assert len(str(step).split('Tensor(')) == 3
samples = step.sample([q0, q1], repetitions=10)
for sample in samples:
assert np.array_equal(sample, [True, True]) or np.array_equal(
Expand Down
6 changes: 5 additions & 1 deletion cirq-core/cirq/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@
random_two_qubit_circuit_with_czs,
)

from cirq.testing.repr_pretty_tester import assert_repr_pretty, FakePrinter
from cirq.testing.repr_pretty_tester import (
assert_repr_pretty,
assert_repr_pretty_contains,
FakePrinter,
)

from cirq.testing.sample_circuits import nonoptimal_toffoli_circuit
19 changes: 19 additions & 0 deletions cirq-core/cirq/testing/repr_pretty_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,22 @@ def assert_repr_pretty(val: Any, text: str, cycle: bool = False):
p = FakePrinter()
val._repr_pretty_(p, cycle=cycle)
assert p.text_pretty == text, f"{p.text_pretty} != {text}"


def assert_repr_pretty_contains(val: Any, substr: str, cycle: bool = False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice to add a version that does a more general regex match as well:

def assert_repr_pretty_matches(val: Any, pattern: Union[str, re.Pattern], cycle: bool = False):
    p = FakePrinter()
    val._repr_pretty(p, cycle=cycle)
    assert re.match(pattern, p.text_pretty), ...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I appeal to YAGNI here - if someone needs it, they'll make it.

"""Assert that the given object has a `_repr_pretty_` output that contains the given text.

Args:
val: The object to test.
substr: The string that `_repr_pretty_` is expected to contain.
cycle: The value of `cycle` passed to `_repr_pretty_`. `cycle` represents whether
the call is made with a potential cycle. Typically one should handle the
`cycle` equals `True` case by returning text that does not recursively call
the `_repr_pretty_` to break this cycle.

Raises:
AssertionError: If `_repr_pretty_` does not pretty print the given text.
"""
p = FakePrinter()
val._repr_pretty_(p, cycle=cycle)
assert substr in p.text_pretty, f"{substr} not in {p.text_pretty}"
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,23 @@ def _repr_pretty_(self, p, cycle):

cirq.testing.assert_repr_pretty(TestClassMultipleTexts(), "I'm so pretty I am")
cirq.testing.assert_repr_pretty(TestClassMultipleTexts(), "TestClass", cycle=True)


def test_assert_repr_pretty_contains():
class TestClass:
def _repr_pretty_(self, p, cycle):
p.text("TestClass" if cycle else "I'm so pretty")

cirq.testing.assert_repr_pretty_contains(TestClass(), "pretty")
cirq.testing.assert_repr_pretty_contains(TestClass(), "Test", cycle=True)

class TestClassMultipleTexts:
def _repr_pretty_(self, p, cycle):
if cycle:
p.text("TestClass")
else:
p.text("I'm so pretty")
p.text(" I am")

cirq.testing.assert_repr_pretty_contains(TestClassMultipleTexts(), "I am")
cirq.testing.assert_repr_pretty_contains(TestClassMultipleTexts(), "Class", cycle=True)