diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index 8c664880b2e..9c520e00d9b 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -266,19 +266,12 @@ def test_trial_result_str(): prng=value.parse_random_state(0), simulation_options=ccq.mps_simulator.MPSOptions(), ) - assert ( - str( - ccq.mps_simulator.MPSTrialResult( - params=cirq.ParamResolver({}), - measurements={'m': np.array([[1]])}, - final_simulator_state=final_simulator_state, - ) - ) - == """measurements: m=1 -output state: TensorNetwork([ - Tensor(shape=(2,), inds=('i_0',), tags=oset([])), -])""" + result = ccq.mps_simulator.MPSTrialResult( + params=cirq.ParamResolver({}), + measurements={'m': np.array([[1]])}, + final_simulator_state=final_simulator_state, ) + assert 'output state: TensorNetwork' in str(result) def test_trial_result_repr_pretty(): @@ -293,13 +286,7 @@ def test_trial_result_repr_pretty(): measurements={'m': np.array([[1]])}, final_simulator_state=final_simulator_state, ) - cirq.testing.assert_repr_pretty( - result, - """measurements: m=1 -output state: TensorNetwork([ - Tensor(shape=(2,), inds=('i_0',), tags=oset([])), -])""", - ) + cirq.testing.assert_repr_pretty_contains(result, 'output state: TensorNetwork') cirq.testing.assert_repr_pretty(result, "cirq.MPSTrialResult(...)", cycle=True) @@ -307,26 +294,14 @@ def test_empty_step_result(): q0 = cirq.LineQubit(0) sim = ccq.mps_simulator.MPSSimulator() step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0)))) - assert ( - str(step_result) - == """q(0)=0 -TensorNetwork([ - Tensor(shape=(2,), inds=('i_0',), tags=oset([])), -])""" - ) + assert 'TensorNetwork' in str(step_result) def test_step_result_repr_pretty(): q0 = cirq.LineQubit(0) sim = ccq.mps_simulator.MPSSimulator() step_result = next(sim.simulate_moment_steps(cirq.Circuit(cirq.measure(q0)))) - cirq.testing.assert_repr_pretty( - step_result, - """q(0)=0 -TensorNetwork([ - Tensor(shape=(2,), inds=('i_0',), tags=oset([])), -])""", - ) + cirq.testing.assert_repr_pretty_contains(step_result, 'TensorNetwork') cirq.testing.assert_repr_pretty(step_result, "cirq.MPSSimulatorStepResult(...)", cycle=True) @@ -391,13 +366,8 @@ def test_simulate_moment_steps_sample(): step._simulator_state().to_numpy(), np.asarray([1.0 / math.sqrt(2), 0.0, 1.0 / math.sqrt(2), 0.0]), ) - assert ( - str(step) - == """TensorNetwork([ - Tensor(shape=(2,), inds=('i_0',), tags=oset([])), - Tensor(shape=(2,), inds=('i_1',), tags=oset([])), -])""" - ) + # There are two "Tensor()" copies in the string. + assert len(str(step).split('Tensor(')) == 3 samples = step.sample([q0, q1], repetitions=10) for sample in samples: assert np.array_equal(sample, [True, False]) or np.array_equal( @@ -412,13 +382,8 @@ def test_simulate_moment_steps_sample(): step._simulator_state().to_numpy(), np.asarray([1.0 / math.sqrt(2), 0.0, 0.0, 1.0 / math.sqrt(2)]), ) - assert ( - str(step) - == """TensorNetwork([ - Tensor(shape=(2, 2), inds=('i_0', 'mu_0_1'), tags=oset([])), - Tensor(shape=(2, 2), inds=('mu_0_1', 'i_1'), tags=oset([])), -])""" - ) + # There are two "Tensor()" copies in the string. + assert len(str(step).split('Tensor(')) == 3 samples = step.sample([q0, q1], repetitions=10) for sample in samples: assert np.array_equal(sample, [True, True]) or np.array_equal( diff --git a/cirq-core/cirq/testing/__init__.py b/cirq-core/cirq/testing/__init__.py index 9e0165f0c7f..81a38d020ca 100644 --- a/cirq-core/cirq/testing/__init__.py +++ b/cirq-core/cirq/testing/__init__.py @@ -92,6 +92,10 @@ random_two_qubit_circuit_with_czs, ) -from cirq.testing.repr_pretty_tester import assert_repr_pretty, FakePrinter +from cirq.testing.repr_pretty_tester import ( + assert_repr_pretty, + assert_repr_pretty_contains, + FakePrinter, +) from cirq.testing.sample_circuits import nonoptimal_toffoli_circuit diff --git a/cirq-core/cirq/testing/repr_pretty_tester.py b/cirq-core/cirq/testing/repr_pretty_tester.py index a55c3e27b3d..7b25cf058ff 100644 --- a/cirq-core/cirq/testing/repr_pretty_tester.py +++ b/cirq-core/cirq/testing/repr_pretty_tester.py @@ -53,3 +53,22 @@ def assert_repr_pretty(val: Any, text: str, cycle: bool = False): p = FakePrinter() val._repr_pretty_(p, cycle=cycle) assert p.text_pretty == text, f"{p.text_pretty} != {text}" + + +def assert_repr_pretty_contains(val: Any, substr: str, cycle: bool = False): + """Assert that the given object has a `_repr_pretty_` output that contains the given text. + + Args: + val: The object to test. + substr: The string that `_repr_pretty_` is expected to contain. + cycle: The value of `cycle` passed to `_repr_pretty_`. `cycle` represents whether + the call is made with a potential cycle. Typically one should handle the + `cycle` equals `True` case by returning text that does not recursively call + the `_repr_pretty_` to break this cycle. + + Raises: + AssertionError: If `_repr_pretty_` does not pretty print the given text. + """ + p = FakePrinter() + val._repr_pretty_(p, cycle=cycle) + assert substr in p.text_pretty, f"{substr} not in {p.text_pretty}" diff --git a/cirq-core/cirq/testing/repr_pertty_tester_test.py b/cirq-core/cirq/testing/repr_pretty_tester_test.py similarity index 67% rename from cirq-core/cirq/testing/repr_pertty_tester_test.py rename to cirq-core/cirq/testing/repr_pretty_tester_test.py index 2641bb4fc5b..ac8d7a369df 100644 --- a/cirq-core/cirq/testing/repr_pertty_tester_test.py +++ b/cirq-core/cirq/testing/repr_pretty_tester_test.py @@ -42,3 +42,23 @@ def _repr_pretty_(self, p, cycle): cirq.testing.assert_repr_pretty(TestClassMultipleTexts(), "I'm so pretty I am") cirq.testing.assert_repr_pretty(TestClassMultipleTexts(), "TestClass", cycle=True) + + +def test_assert_repr_pretty_contains(): + class TestClass: + def _repr_pretty_(self, p, cycle): + p.text("TestClass" if cycle else "I'm so pretty") + + cirq.testing.assert_repr_pretty_contains(TestClass(), "pretty") + cirq.testing.assert_repr_pretty_contains(TestClass(), "Test", cycle=True) + + class TestClassMultipleTexts: + def _repr_pretty_(self, p, cycle): + if cycle: + p.text("TestClass") + else: + p.text("I'm so pretty") + p.text(" I am") + + cirq.testing.assert_repr_pretty_contains(TestClassMultipleTexts(), "I am") + cirq.testing.assert_repr_pretty_contains(TestClassMultipleTexts(), "Class", cycle=True)