# Verifying Measurement-Based Uncomputation

Quantum information cannot be destroyed, but during a computation we may produce intermediate values that we wish to discard. We can "uncompute" these values by running the computation in reverse. The ordinary uncomputation strategy requires paying the cost of the computation twice, but [*Halving the cost of quantum addition.* Gidney 2017](https://arxiv.org/abs/1709.06648) shows how measuring a bit in the X basis can effectively discard it without expensive uncomputation. The consequence is that the remaining states of the system pick up phases that depend on the random measurement result. [*Verifying Measurement Based Uncomputation.* Gidney 2019](https://algassert.com/post/1903) provides more detail about these phases. It also describes a procedure for using a phased-classical simulator to "fuzz test" measurement-based uncomputation circuits.
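To make the source of these phases concrete, here is a minimal pure-Python sketch (not Qualtran code) of the amplitude for measuring a computational basis state $|x\rangle$ in the X basis. We assume outcome $m = 0$ means projecting onto $|+\rangle$ and $m = 1$ onto $|-\rangle$; the helper name `x_basis_amplitude` is hypothetical and used only for illustration.

```python
import math

def x_basis_amplitude(x: int, m: int) -> float:
    """Return <m_X|x>, assuming |0_X> = |+> and |1_X> = |->.

    |+> = (|0> + |1>)/sqrt(2) and |-> = (|0> - |1>)/sqrt(2), so the overlap
    with the basis state |x> is (-1)**(x*m) / sqrt(2).
    """
    plus = (1 / math.sqrt(2), 1 / math.sqrt(2))
    minus = (1 / math.sqrt(2), -1 / math.sqrt(2))
    return (plus, minus)[m][x]

for x in (0, 1):
    for m in (0, 1):
        amp = x_basis_amplitude(x, m)
        # Both outcomes are equally likely; only the sign of the amplitude differs.
        print(f"x={x}, m={m}: prob={amp ** 2:.2f}, phase={'+' if amp > 0 else '-'}")
```

Every outcome occurs with probability 1/2, but the case $x = 1, m = 1$ carries a phase of $-1$. That sign is what a measurement-based uncomputation circuit has to correct.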

Here, we show how Qualtran can be used to verify measurement-based uncomputation circuits following Gidney's proposal.

## Uncomputing $\mathrm{And}$

As a warm-up, we can use the reference classical action of `And(uncompute=True)` to verify the truth table of the operation. First, we check the bloq over valid inputs.

In [None]:
import itertools
from qualtran.bloqs.mcmt import And

and_dag = And(uncompute=True)
for q1, q2 in itertools.product(range(2), repeat=2):
    target = int(q1==1 and q2 == 1)
    print(f'{q1=}, {q2=}, {target=}', end='  ')
    (q1_out, q2_out), = and_dag.call_classically(ctrl=[q1,q2], target=target)
    assert q1_out == q1
    assert q2_out == q2
    print('✓')

In a quantum computer there is no error handling, but the classical simulation will helpfully inform you if you supply invalid inputs to the bloq. Here, an error is raised because the `target` register does not contain the result of a (forward) computation of $\mathrm{And}$ on the control bits.

In [None]:
try:
    and_dag.call_classically(ctrl=[1,1], target=0)
except ValueError as e:
    print(e)
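The validity check performed by the classical simulation amounts to something like the following pure-Python sketch. The function `and_dagger_classical` and its error message are hypothetical stand-ins for illustration, not Qualtran's implementation.

```python
from typing import Tuple

def and_dagger_classical(q1: int, q2: int, target: int) -> Tuple[int, int]:
    """Hypothetical classical reference for uncomputing And.

    The target must hold q1 AND q2; any other value is not a valid input
    to the uncomputation.
    """
    if target != (q1 & q2):
        raise ValueError(f"target={target} is not And({q1}, {q2})")
    # The target bit is consumed; only the two control bits are returned.
    return q1, q2

print(and_dagger_classical(1, 1, 1))
try:
    and_dagger_classical(1, 1, 0)
except ValueError as e:
    print(e)
```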

## Naive attempt at $\mathrm{And}^\dagger$

What happens if we just measure the target bit in the X basis and throw it away? We'll build this simple circuit below so we can use the phased-classical simulator to find out.

In [None]:
from qualtran import BloqBuilder, Register, QBit, Side, CtrlSpec, CBit
from qualtran.bloqs.basic_gates import MeasX, Discard, CZ

bb = BloqBuilder()
q1 = bb.add_register('q1', 1)
q2 = bb.add_register('q2', 1)
target = bb.add_register(Register('target', QBit(), side=Side.LEFT))

# Measure the target in the X basis, then throw away the classical result.
ctarget = bb.add(MeasX(), q=target)
bb.add(Discard(), c=ctarget)

throw_out_target = bb.finalize(q1=q1, q2=q2)
from qualtran.drawing import show_bloq
show_bloq(throw_out_target, 'musical_score')

## Fuzz testing measurement circuits

Given a computational basis state input, the X-basis measurement operation returns a random outcome. We explicitly supply a random number generator to the phased classical simulation function to support these circuits.

Since our simulation is now stochastic, we run it 10 times and see if we get the right answer.

In [None]:
import numpy as np
from qualtran.simulation.classical_sim import do_phased_classical_simulation

rng = np.random.default_rng(seed=123)
in_vals = {'q1': 1, 'q2': 1, 'target': 1}
for _ in range(10):
    out_vals, phase = do_phased_classical_simulation(throw_out_target, in_vals, rng=rng)
    assert out_vals['q1'] == 1
    assert out_vals['q2'] == 1
    assert 'target' not in out_vals
    if phase == 1:
        print("✓", end=' ')
    else:
        print(f"Bad phase: {phase}")
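The stochastic behaviour being fuzzed here is simple enough to model by hand. The sketch below is a toy model in plain Python (not Qualtran's simulator): it assumes that measuring a basis-state bit `target` in the X basis yields a uniformly random outcome `m` and a phase of $(-1)^{\mathrm{target} \cdot m}$, since $\langle - | 1 \rangle = -1/\sqrt{2}$. The helper `naive_discard_phase` is hypothetical.

```python
import random

def naive_discard_phase(target: int, rng: random.Random) -> int:
    """Toy model of MeasX followed by Discard on a basis-state target bit.

    The outcome m is uniformly random; the accumulated phase is (-1)**(target*m).
    """
    m = rng.randrange(2)
    return (-1) ** (target * m)

rng = random.Random(123)
n_trials = 1000
for target in (0, 1):
    bad = sum(naive_discard_phase(target, rng) != 1 for _ in range(n_trials))
    print(f"target={target}: {bad}/{n_trials} runs with a bad phase")
```

In this model a `target` of `0` never produces a bad phase, while a `target` of `1` produces one in roughly half of the runs.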

A phase on a single computational basis state becomes a *relative phase amongst* the computational basis states when this operation is applied to a register in superposition, so these spurious phases must be fixed.
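For instance, in a toy statevector calculation (plain Python, not Qualtran) with the controls in a uniform superposition and the target holding $q_1 \wedge q_2$, the outcome $m = 1$ leaves the $(1, 1)$ branch with the opposite sign from the other three. The helper `naive_outcome_state` is hypothetical, and the amplitudes are left unnormalized.

```python
import itertools
import math
from typing import Dict, Tuple

def naive_outcome_state(m: int) -> Dict[Tuple[int, int], float]:
    """Unnormalized (q1, q2) amplitudes after measuring the And target in the X basis.

    Input: uniform superposition over (q1, q2) with target = q1 AND q2.
    Projecting the target onto outcome m multiplies each branch by (-1)**(target*m)/sqrt(2).
    """
    state = {}
    for q1, q2 in itertools.product(range(2), repeat=2):
        target = q1 & q2
        amp_in = 0.5  # uniform superposition over the four control values
        state[(q1, q2)] = amp_in * (-1) ** (target * m) / math.sqrt(2)
    return state

for m in (0, 1):
    signs = {k: '+' if v > 0 else '-' for k, v in naive_outcome_state(m).items()}
    print(f"m={m}: {signs}")
```

Because only one branch flips sign, this is not a global phase: the post-measurement state for $m = 1$ differs from the intended one.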

## MBUC circuit for $\mathrm{And}^\dagger$

So simply measuring the bit in the X basis and throwing it away doesn't work. The fix is straightforward: a phase of $-1$ appears exactly when the target bit is `1` and the random measurement outcome is also `1`, so we can flip it back. Because the target holds $q_1 \wedge q_2$, we flip the phase conditioned on (1) both control qubits being `1` and (2) the classical measurement result being `1`. Condition (1) alone can be achieved with a `CZ` on the two controls. Classically controlling that `CZ` on the measurement result implements conditions (1) *and* (2) using only a Clifford operation.
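The phase bookkeeping behind this fix can be checked by hand: the measurement contributes $(-1)^{t m}$ and the classically-controlled `CZ` contributes $(-1)^{q_1 q_2 m}$; since $t = q_1 q_2$ on every valid input, the product is always $+1$. A toy pure-Python check (the helper `mbuc_phase` is hypothetical and models phases only, not the Qualtran bloqs):

```python
import itertools

def mbuc_phase(q1: int, q2: int, m: int) -> int:
    """Total phase of the MBUC And-dagger circuit on a valid basis-state input (toy model)."""
    target = q1 & q2                     # valid input: target holds q1 AND q2
    meas_phase = (-1) ** (target * m)    # MeasX with outcome m on the target bit
    fixup_phase = (-1) ** (m * q1 * q2)  # CZ on (q1, q2), applied only when m == 1
    return meas_phase * fixup_phase

for q1, q2, m in itertools.product(range(2), repeat=3):
    assert mbuc_phase(q1, q2, m) == 1
print("phase is +1 for every control value and measurement outcome")
```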

In [None]:
bb = BloqBuilder()
q1 = bb.add_register('q1', 1)
q2 = bb.add_register('q2', 1)
target = bb.add_register(Register('target', QBit(), side=Side.LEFT))

ctarget = bb.add(MeasX(), q=target)
classically_controlled_cz = CZ().controlled(CtrlSpec(qdtypes=[CBit()]))
# CZ on (q1, q2), applied only when the classical bit `ctrl` is 1.
ctarget, q1, q2 = bb.add(classically_controlled_cz, ctrl=ctarget, q1=q1, q2=q2)
bb.add(Discard(), c=ctarget)

mbuc_target = bb.finalize(q1=q1, q2=q2)
show_bloq(mbuc_target, 'musical_score')

## Fuzz testing MBUC

We can continue to use random measurement results in simulation to "fuzz test" our construction. Here, all ten runs pass our check.

In [None]:
rng = np.random.default_rng(seed=123)
in_vals = {'q1': 1, 'q2': 1, 'target': 1}
for _ in range(10):
    out_vals, phase = do_phased_classical_simulation(mbuc_target, in_vals, rng=rng)
    assert out_vals['q1'] == 1
    assert out_vals['q2'] == 1
    assert 'target' not in out_vals
    if phase == 1:
        print("✓", end=' ')
    else:
        print(f"Bad phase: {phase}")

## Exhaustive testing of MBUC

With some additional work, we can inject particular patterns of measurement results to check all possible cases. For circuits with a small number of `MeasX` bloqs, this can be more valuable than fuzz testing because no case is left to chance. The number of cases grows exponentially: $m$ measured bits give $2^m$ patterns of measurement results.
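A minimal sketch of the enumeration, assuming measurement results are keyed by bloq-instance index as in the cells below. The helper `measurement_patterns` and the indices `3` and `7` are hypothetical, chosen only to show the $2^m$ growth.

```python
import itertools
from typing import Dict, Hashable, Iterator, Sequence

def measurement_patterns(meas_ids: Sequence[Hashable]) -> Iterator[Dict[Hashable, int]]:
    """Yield every assignment of 0/1 outcomes to the given measurement ids.

    Each yielded dict maps measurement id -> forced outcome; there are
    2**len(meas_ids) of them.
    """
    for bits in itertools.product((0, 1), repeat=len(meas_ids)):
        yield dict(zip(meas_ids, bits))

patterns = list(measurement_patterns([3, 7]))  # hypothetical bloq-instance indices
print(len(patterns), patterns)
```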

### Preparation: find the bloq instance index of our measurement operation

In [None]:
# Prep work: find the bloq instance indices of measurement operations.
# Here, there's only one; but this code snippet will work for MBUC circuits
# with additional MeasX bloqs
cbloq = mbuc_target
meas_binst_is = [binst.i for binst in cbloq.bloq_instances if binst.bloq_is(MeasX)]
assert len(meas_binst_is) == 1, 'this circuit only has one'
meas_binst_i = meas_binst_is[0]
meas_binst_i

### Loop over inputs *and* measurement results

In [None]:
from qualtran.simulation.classical_sim import PhasedClassicalSimState
import itertools

for q1, q2 in itertools.product(range(2), repeat=2):
    target = int(q1 == 1 and q2 == 1)
    print(f'{q1=}, {q2=}, {target=}')

    for meas_result in [0, 1]:
        print(f'  meas {meas_result}', end=' ')
        # Force the MeasX bloq instance to return `meas_result` instead of a random outcome.
        fixed_rnd_vals = {meas_binst_i: meas_result}
        sim = PhasedClassicalSimState.from_cbloq(
            cbloq,
            vals={'q1': q1, 'q2': q2, 'target': target},
            fixed_random_vals=fixed_rnd_vals,
        )
        out_vals = sim.simulate()

        assert out_vals['q1'] == q1
        assert out_vals['q2'] == q2
        assert 'target' not in out_vals
        # The accumulated phase lives on the simulator state, not in the output values.
        assert sim.phase == 1.0
        print(' ✓')

### Inspecting the phase during simulation

For visibility into the progress of the simulation, we extend the `step` method of the simulator to print out the current phase of the system. We've also modified the exhaustive loop to use `itertools.product` so this code snippet can handle circuits with multiple `MeasX` gates (with exponential scaling). 

In [None]:
class DebugPhasedClassicalSim(PhasedClassicalSimState):
    """Phased-classical simulator that prints debug information."""
    
    def step(self):
        """At each step, print a brief representation of the current phase."""
        super().step()
        # Use `self`: this method runs on the simulator instance being stepped.
        if self.phase == 1.0:
            print('+', end='')
        elif self.phase == -1.0:
            print('-', end='')
        else:
            print('?', end='')

In [None]:
meas_binst_is = [binst.i for binst in cbloq.bloq_instances if binst.bloq_is(MeasX)]

for q1, q2 in itertools.product(range(2), repeat=2):
    target = int(q1==1 and q2 == 1)
    print(f'{q1=}, {q2=}, {target=}')

    for meas_result in itertools.product(range(2), repeat=len(meas_binst_is)):
        print(f'  meas {meas_result}', end=' ')
        fixed_rnd_vals = {binst_i: meas_result[j] for j, binst_i in enumerate(meas_binst_is)}

        sim = DebugPhasedClassicalSim.from_cbloq(
            cbloq,
            {'q1': q1, 'q2': q2, 'target': target},
            fixed_random_vals=fixed_rnd_vals
        )
        out_vals = sim.simulate()
        
        assert out_vals['q1'] == q1
        assert out_vals['q2'] == q2
        assert 'target' not in out_vals
        assert sim.phase == 1.0
        print(' ✓')

The phase is unaffected in every case except when the `target` bit is `1` *and* the measurement result is `1`. In that case the measurement introduces a phase of $-1$, which is immediately undone by the classically-controlled CZ.
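The same pattern can be reproduced with a toy model that tracks only the phase after each of the three operations (MeasX, classically-controlled CZ, Discard). This is plain Python; `trace_mbuc` is a hypothetical helper, and Qualtran's simulator may take additional steps between these, so its printed string need not match this one character for character.

```python
def trace_mbuc(q1: int, q2: int, m: int) -> str:
    """Toy step-by-step phase trace of the MBUC And-dagger circuit.

    Records '+' or '-' after each step: MeasX, the classically controlled CZ,
    and Discard. Only the phase is modeled.
    """
    target = q1 & q2
    phase = 1
    trace = []
    phase *= (-1) ** (target * m)   # MeasX with outcome m
    trace.append('+' if phase == 1 else '-')
    phase *= (-1) ** (q1 * q2 * m)  # CZ on (q1, q2), applied only if m == 1
    trace.append('+' if phase == 1 else '-')
    # Discard leaves the phase unchanged.
    trace.append('+' if phase == 1 else '-')
    return ''.join(trace)

for q1 in (0, 1):
    for q2 in (0, 1):
        print(f'{q1=}, {q2=}', [trace_mbuc(q1, q2, m) for m in (0, 1)])
```

Only the `q1 = q2 = 1`, `m = 1` case dips to `-`, and it recovers at the very next step.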