diff --git a/cirq-core/cirq/work/pauli_sum_collector.py b/cirq-core/cirq/work/pauli_sum_collector.py index cbb408fcf46..45b8e9c564e 100644 --- a/cirq-core/cirq/work/pauli_sum_collector.py +++ b/cirq-core/cirq/work/pauli_sum_collector.py @@ -92,9 +92,9 @@ def estimated_energy(self) -> Union[float, complex]: if a + b: energy += coef * (a - b) / (a + b) energy = complex(energy) + energy += self._identity_offset if energy.imag == 0: energy = energy.real - energy += self._identity_offset return energy diff --git a/cirq-core/cirq/work/pauli_sum_collector_test.py b/cirq-core/cirq/work/pauli_sum_collector_test.py index 75727304c29..cb6994455f6 100644 --- a/cirq-core/cirq/work/pauli_sum_collector_test.py +++ b/cirq-core/cirq/work/pauli_sum_collector_test.py @@ -22,12 +22,16 @@ async def test_pauli_string_sample_collector(): a, b = cirq.LineQubit.range(2) p = cirq.PauliSumCollector( circuit=cirq.Circuit(cirq.H(a), cirq.CNOT(a, b), cirq.X(a), cirq.Z(b)), - observable=cirq.X(a) * cirq.X(b) - 16 * cirq.Y(a) * cirq.Y(b) + 4 * cirq.Z(a) * cirq.Z(b), + observable=(1 + 0j) * cirq.X(a) * cirq.X(b) + - 16 * cirq.Y(a) * cirq.Y(b) + + 4 * cirq.Z(a) * cirq.Z(b) + + (1 - 0j), samples_per_term=100, ) completion = p.collect_async(sampler=cirq.Simulator()) assert await completion is None - assert p.estimated_energy() == 11 + energy = p.estimated_energy() + assert isinstance(energy, float) and energy == 12 @pytest.mark.asyncio