# And

To do classical logic with a reversible circuit (a prerequisite for a quantum circuit), we use a three-(qu)bit operation called a Toffoli gate that takes `[a, b, c]` to `[a, b, c ^ (a & b)]`. If we take `c` to be zero, this is an And gate taking `[a, b]` to `[a, b, a & b]`.
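The two claims above (the Toffoli map is reversible, and fixing `c = 0` turns it into And) can be checked exhaustively in plain Python. The helper `toffoli` below is our own illustration, not part of any library.

```python
import itertools


def toffoli(a: int, b: int, c: int) -> tuple:
    """Classical action of a Toffoli gate: flip c iff both a and b are 1."""
    return a, b, c ^ (a & b)


triples = list(itertools.product([0, 1], repeat=3))
images = [toffoli(*bits) for bits in triples]

# Reversible: the map is a bijection on the 8 three-bit strings...
assert sorted(images) == sorted(triples)
# ...and in fact it is its own inverse.
assert all(toffoli(*toffoli(*bits)) == bits for bits in triples)

# With the target fixed to 0, the third output bit is exactly a & b.
for a, b in itertools.product([0, 1], repeat=2):
    print((a, b, 0), '->', toffoli(a, b, 0))
```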

In [None]:
import itertools

# Truth table of And: enumerate every (a, b) pair of bits.
for a, b in itertools.product([0, 1], repeat=2):
    print(a, b, '->', a & b)

## Quantum operation

We provide the `And` gate as a quantum operation implementing this logic. It assumes that the third qubit (the target) is initialized to the `|0>` state, so that it maps `|a, b, 0>` to `|a, b, a & b>`.

In [None]:
import cirq
from cirq.contrib.svg import SVGCircuit
from cirq_qubitization.and_gate import And

gate = And()
r = gate.registers
quregs = r.get_named_qubits()
operation = gate.on_registers(**quregs)
circuit = cirq.Circuit(operation)
SVGCircuit(circuit)

## Testing classical operations


In fault-tolerant algorithms, the bulk of the work usually boils down to
devising unitary (reversible) operations that implement a classical function. Mathematically,
this looks like
$$
U_f |x\rangle|0\rangle = |x\rangle|f(x)\rangle
$$
for all $x$.
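The equation only fixes what $U_f$ does when the output register starts at $|0\rangle$. A common way to extend it to every basis state (an assumption here, not something this notebook states) is $|x\rangle|y\rangle \mapsto |x\rangle|y \oplus f(x)\rangle$, which is a permutation of basis states and therefore unitary. The sketch below builds that permutation for $f(x_0, x_1) = x_0 \wedge x_1$; the function names are hypothetical.

```python
import itertools


def oracle_permutation(f, n_in: int, n_out: int) -> dict:
    """Map each basis state (x, y) to (x, y XOR f(x)).

    x and y are tuples of bits; f maps an n_in-bit tuple to an n_out-bit tuple.
    """
    perm = {}
    for x in itertools.product([0, 1], repeat=n_in):
        fx = f(x)
        for y in itertools.product([0, 1], repeat=n_out):
            perm[(x, y)] = (x, tuple(yi ^ fi for yi, fi in zip(y, fx)))
    return perm


def and_f(x):
    # f(x0, x1) = x0 & x1, with a single output bit.
    return (x[0] & x[1],)


perm = oracle_permutation(and_f, n_in=2, n_out=1)
# A bijection on basis states is a permutation matrix, hence unitary.
assert len(set(perm.values())) == len(perm)
# On y = 0 it reproduces U_f |x>|0> = |x>|f(x)>.
for x in itertools.product([0, 1], repeat=2):
    print(x, '->', perm[(x, (0,))])
```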

The encompassing quantum algorithm will then run this operation over a superposition
and cleverly interfere the resulting state to read out an answer of interest. But we
can test the definition of $U_f$ by testing its action on computational basis states.
A classical operation is (by definition!) one that takes each computational basis state
to exactly one computational basis state, so we can track its action efficiently with a
list of bits instead of a full statevector.
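Because each basis state maps to a single basis state, following a reversible circuit only needs one bit per qubit rather than $2^n$ amplitudes. A minimal sketch, using a hypothetical tuple encoding for gates (not a Cirq format), run on a half adder that writes the carry with a Toffoli and the sum with a CNOT:

```python
def run_classical(circuit, bits):
    """Apply a list of classical reversible gates to a list of bits.

    Hypothetical gate encoding (not a library format):
      ('x', t)           -> flip bit t
      ('cx', c, t)       -> flip t if bit c is 1
      ('ccx', c1, c2, t) -> flip t if bits c1 and c2 are both 1
    """
    bits = list(bits)
    for gate, *qubits in circuit:
        if gate == 'x':
            (t,) = qubits
            bits[t] ^= 1
        elif gate == 'cx':
            c, t = qubits
            bits[t] ^= bits[c]
        elif gate == 'ccx':
            c1, c2, t = qubits
            bits[t] ^= bits[c1] & bits[c2]
        else:
            raise ValueError(f'unknown gate {gate!r}')
    return bits


# Half adder on [a, b, 0]: bit 2 <- a & b (carry), then bit 1 <- a ^ b (sum).
half_adder = [('ccx', 0, 1, 2), ('cx', 0, 1)]
for a in (0, 1):
    for b in (0, 1):
        print([a, b, 0], '->', run_classical(half_adder, [a, b, 0]))
```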

### Input generation

We can use `itertools.product([0, 1], repeat=n)` to generate every string
of 0s and 1s to use as test input. For gates with named registers, we provide helper methods
that generate classical inputs varying over some registers while keeping others fixed, which
keeps the number of test cases under control.
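The payoff of fixing registers is that the number of test rows is $2^k$ where $k$ counts only the *variable* bits. The helper below is a hypothetical stand-in written for illustration; it is not `cq_testing.get_classical_inputs`, whose signature and output format differ.

```python
import itertools


def classical_inputs(variable: dict, fixed: dict) -> list:
    """Enumerate test inputs register by register (hypothetical helper).

    variable: register name -> bitsize; every bit pattern is enumerated.
    fixed: register name -> tuple of bits held constant in every input.
    Returns a list of dicts mapping register name -> tuple of bits.
    """
    total = sum(variable.values())
    rows = []
    for bits in itertools.product([0, 1], repeat=total):
        row, start = {}, 0
        # Slice the flat bit string into consecutive named registers.
        for name, size in variable.items():
            row[name] = bits[start:start + size]
            start += size
        row.update(fixed)
        rows.append(row)
    return rows


rows = classical_inputs(variable={'control': 2}, fixed={'target': (0,)})
print(len(rows))  # 4
print(rows[-1])  # {'control': (1, 1), 'target': (0,)}
```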

We operate on 2-dimensional numpy arrays of bits where the second axis is over (qu)bits and the
first axis is over all the states we want to check. Numpy vectorized operations are
faster and we want to check as many input/output states as feasible!
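As a small illustration of this layout (our own sketch, independent of the helpers used below), the And update can be applied to every row of a 2D bit array with a single vectorized expression:

```python
import itertools

import numpy as np

# Rows are test states, columns are bits: (a, b, target) for every (a, b).
states = np.array(
    [(a, b, 0) for a, b in itertools.product([0, 1], repeat=2)], dtype=np.uint8
)

# One vectorized expression updates the target column for all rows at once.
out = states.copy()
out[:, 2] ^= states[:, 0] & states[:, 1]

print(out)
```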

In [None]:
import cirq_qubitization.testing as cq_testing

test_inputs = cq_testing.get_classical_inputs(
    variable_registers=[r['control']], 
    fixed_registers={
        r['target']: 0, 
        r['ancilla']: 0
    }
)
test_inputs

### Classical implementations

Gates can implement `_apply_classical_from_registers` to provide an efficient, numpy-vectorized
version of their classical operation. The wrapping method `apply_classical` manages input validation
and output munging. 
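To show the split between the vectorized kernel and the validating wrapper, here is a toy class written for this sketch. `ToyAnd`, its `register_sizes` attribute, and the exact method signatures are assumptions for illustration; the real `cirq_qubitization` classes may differ.

```python
import numpy as np


class ToyAnd:
    """Illustrative stand-in, not the cirq_qubitization class.

    Registers: 'control' (2 bits) and 'target' (1 bit).
    """

    register_sizes = {'control': 2, 'target': 1}

    def _apply_classical_from_registers(self, control, target):
        # Vectorized over rows: control has shape (n, 2), target has shape (n, 1).
        new_target = target ^ (control[:, 0:1] & control[:, 1:2])
        return {'control': control, 'target': new_target}

    def apply_classical(self, inputs):
        # Wrapper: check shapes and bit values, then hand fresh uint8 copies to the kernel.
        n_rows = None
        for name, size in self.register_sizes.items():
            arr = np.asarray(inputs[name])
            if arr.ndim != 2 or arr.shape[1] != size:
                raise ValueError(f'{name!r} must have shape (n, {size})')
            if not np.isin(arr, (0, 1)).all():
                raise ValueError(f'{name!r} must contain only 0s and 1s')
            if n_rows is None:
                n_rows = arr.shape[0]
            elif arr.shape[0] != n_rows:
                raise ValueError('all registers need the same number of rows')
        regs = {name: np.array(inputs[name], dtype=np.uint8) for name in self.register_sizes}
        return self._apply_classical_from_registers(**regs)


toy_gate = ToyAnd()
toy_outputs = toy_gate.apply_classical({
    'control': [[0, 0], [0, 1], [1, 0], [1, 1]],
    'target': [[0], [0], [0], [0]],
})
print(toy_outputs['target'].ravel())  # [0 0 0 1]
```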

In [None]:
test_outputs = gate.apply_classical(test_inputs)
test_outputs

In [None]:
import numpy as np

# Control vals should remain unmodified.
np.testing.assert_array_equal(test_inputs['control'], test_outputs['control'])

In [None]:
# not using any ancilla here
assert test_inputs['ancilla'].shape[1] == 0
assert test_outputs['ancilla'].shape[1] == 0

## Testing with Tensor Networks

We can use Quimb tensor networks to test input/output pairs. The network includes the initial state as 0 or 1 kets and the expected final state as 0 or 1 bras. With a good contraction ordering, this approach can test circuits with many qubits but small depth that would be out of reach for a naive statevector simulator.
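To see what such a network computes without Quimb, here is a NumPy-only sketch of our own (not the `cq_testing` helper): the Toffoli is reshaped into a rank-6 tensor, a ket is attached to each input leg and a bra to each output leg, and `np.einsum` contracts everything to the scalar $\langle \text{out}|U|\text{in}\rangle$. A network this small is contracted in one shot; contraction ordering only pays off for larger networks of many small gate tensors.

```python
import itertools

import numpy as np

KET = {0: np.array([1.0, 0.0]), 1: np.array([0.0, 1.0])}

# Toffoli as an 8x8 permutation matrix with qubit order (a, b, c), a most significant.
U = np.zeros((8, 8))
for a, b, c in itertools.product([0, 1], repeat=3):
    col = 4 * a + 2 * b + c
    row = 4 * a + 2 * b + (c ^ (a & b))
    U[row, col] = 1.0

# Rank-6 tensor: three output legs (A, B, C), then three input legs (a, b, c).
T = U.reshape(2, 2, 2, 2, 2, 2)


def amplitude(inp, out):
    """Contract <out| U |inp> with one bra or ket per leg (real vectors, so no conj)."""
    legs = [KET[bit] for bit in out] + [KET[bit] for bit in inp]
    return np.einsum('ABCabc,A,B,C,a,b,c->', T, *legs)


print(amplitude((1, 1, 0), (1, 1, 1)))  # 1.0, the correct output
print(amplitude((1, 1, 0), (1, 1, 0)))  # 0.0, a wrong output
```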

In [None]:
from matplotlib import pyplot as plt

# One panel per input/output pair; two control bits give four pairs.
fig, axes = plt.subplots(2, 2, figsize=(7, 5))

tn_generator = cq_testing.yield_test_tensor_networks(gate, test_inputs, test_outputs)
for tn, ax in zip(tn_generator, axes.reshape(-1)):
    tn.tn.draw(fix=tn.fix, color=['0', '1'], ax=ax, show_tags=False)
    ax.axis('off')

fig.tight_layout()

### Contraction
The tensor network will contract to `1` for a given input/output pair if the circuit is correct.

In [None]:
tn_generator = cq_testing.yield_test_tensor_networks(gate, test_inputs, test_outputs)
for tn in tn_generator:
    amp = tn.tn.contract()
    correct = np.isclose(amp, 1, atol=1e-8)
    print(tn.input_str, '->', tn.output_str, 'Check ✓' if correct else 'FAIL')   

## Efficient decomposition


The `And` specialization of the Toffoli gate permits a specialized decomposition that minimizes the `T`-gate count.

In [None]:
c2 = cirq.Circuit(cirq.decompose_once(operation))
SVGCircuit(c2)

In [None]:
simulator = cirq.Simulator()
input_states = [(a, b, 0) for a, b in itertools.product([0, 1], repeat=2)]
output_states = [(a, b, a & b) for a, b, _ in input_states]

for inp, out in zip(input_states, output_states):
    result = simulator.simulate(c2, initial_state=inp)
    print(inp, '->', result.dirac_notation())
    # dirac_notation() looks like '|011⟩'; strip the ket delimiters before comparing.
    assert result.dirac_notation()[1:-1] == "".join(str(x) for x in out)

## Multi-Control

Using a recursive definition, we can implement an And on more than two control qubits.
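One plausible reading of "recursive" (an assumption; the actual decomposition may order qubits differently) is a ladder that folds one control at a time into an ancilla. An $n$-control And then uses $n-1$ two-control Ands and $n-2$ ancillas, consistent with the empty ancilla register in the two-control case above. A classical sketch, with the hypothetical helper `multi_and_ladder`:

```python
import itertools


def multi_and_ladder(controls):
    """Classical action of an n-control And built from 2-input Ands (a sketch).

    Assumed ladder: anc[0] = c0 & c1, anc[i] = anc[i-1] & c[i+1], and the last
    2-input And writes the target. Returns (ancilla_bits, target_bit).
    """
    if len(controls) < 2:
        raise ValueError('need at least two controls')
    partial = controls[0]
    ancillas = []
    # Each middle control is folded in by one 2-input And into a fresh ancilla.
    for c in controls[1:-1]:
        partial = partial & c
        ancillas.append(partial)
    # The last 2-input And writes the target instead of an ancilla.
    target = partial & controls[-1]
    return ancillas, target


print(multi_and_ladder([1, 1, 1, 1]))  # ([1, 1], 1)
print(multi_and_ladder([1, 1, 0, 1]))  # ([1, 0], 0)

# Exhaustive check over all 4-bit control patterns.
for bits in itertools.product([0, 1], repeat=4):
    assert multi_and_ladder(list(bits))[1] == int(all(bits))
```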

In [None]:
mc_gate = And(cv=(1, 1, 1, 1))
mc_r = mc_gate.registers
mc_quregs = mc_r.get_named_qubits()
mc_operation = mc_gate.on_registers(**mc_quregs)
mc_circuit = cirq.Circuit(mc_operation)
SVGCircuit(mc_circuit)

In [None]:
SVGCircuit(cirq.Circuit(cirq.decompose_once(mc_operation)))

In [None]:
mc_test_inputs = cq_testing.get_classical_inputs(
    variable_registers=[mc_r['control']], 
    fixed_registers={
        mc_r['target']: 0, 
        mc_r['ancilla']: 0
    }
)

mc_test_outputs = mc_gate.apply_classical(mc_test_inputs)
mc_test_outputs

In [None]:
tn_generator = cq_testing.yield_test_tensor_networks(mc_gate, mc_test_inputs, mc_test_outputs)
for tn in tn_generator:
    amp = tn.tn.contract()
    correct = np.isclose(amp, 1, atol=1e-8)
    print(tn.input_str, '->', tn.output_str, 'Check ✓' if correct else 'FAIL')   

In [None]:
from matplotlib import pyplot as plt

# Four control bits give 16 input/output pairs; the 2x2 grid shows only the first four.
fig, axes = plt.subplots(2, 2, figsize=(7, 5))

tn_generator = cq_testing.yield_test_tensor_networks(mc_gate, mc_test_inputs, mc_test_outputs)
for tn, ax in zip(tn_generator, axes.reshape(-1)):
    tn.tn.draw(fix=tn.fix, color=['0', '1'], ax=ax, show_tags=False)
    ax.axis('off')

fig.tight_layout()

## Uncompute

We can save even more `T` gates when "uncomputing" an And operation (i.e. applying its adjoint) by using measurement and classical control instead of a purely unitary circuit.
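To put rough numbers on the saving, the arithmetic below uses commonly quoted costs, stated here as assumptions rather than read from this library: 7 `T` gates for a standard Clifford+T Toffoli, 4 to compute a Gidney-style temporary And, and 0 to uncompute it with measurement. It also assumes the ladder of $n-1$ two-control gates for an $n$-control And.

```python
# Assumed per-gate T costs (commonly quoted figures, not taken from this library).
TOFFOLI_T = 7        # standard Clifford+T Toffoli decomposition
AND_COMPUTE_T = 4    # Gidney-style temporary logical-AND
AND_UNCOMPUTE_T = 0  # measurement plus classically controlled phase fix-up


def t_count_compute_uncompute(n_controls: int, use_and: bool) -> int:
    """T gates to compute then uncompute an n-control And via n-1 two-control gates."""
    k = n_controls - 1
    if use_and:
        return k * (AND_COMPUTE_T + AND_UNCOMPUTE_T)
    # Uncomputing with Toffolis means running the same gates again in reverse.
    return k * 2 * TOFFOLI_T


for n in (2, 4):
    toffoli_cost = t_count_compute_uncompute(n, use_and=False)
    and_cost = t_count_compute_uncompute(n, use_and=True)
    print(f'{n} controls: Toffoli-based {toffoli_cost} T, And-based {and_cost} T')
```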

In [None]:
inv_operation = operation ** -1
inv_circuit = cirq.Circuit(inv_operation)
SVGCircuit(inv_circuit)

We reset our target using measurement and fix up phases depending on the result of that measurement:
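The phase bookkeeping can be checked on a small state vector. This sketch assumes the measurement is in the X basis (a Hadamard on the target, then a computational-basis measurement) followed by a CZ on the two controls when the outcome is 1, as in Gidney's construction; the exact gates in the decomposition below may differ. `measure_and_fix` is a hypothetical helper.

```python
import math


def measure_and_fix(amps, outcome):
    """Uncompute t = a & b via an X-basis measurement of t (a sketch).

    amps: dict (a, b) -> amplitude of |a, b, a & b>.
    Returns (probability of outcome, corrected normalized state on (a, b)).
    """
    # H maps |t> to (|0> + (-1)**t |1>) / sqrt(2); keep the branch for `outcome`.
    branch = {
        (a, b): amp * (-1) ** (outcome * (a & b)) / math.sqrt(2)
        for (a, b), amp in amps.items()
    }
    prob = sum(abs(v) ** 2 for v in branch.values())
    if outcome == 1:
        # Classically controlled CZ on (a, b) cancels the (-1)**(a & b) phase.
        branch = {(a, b): v * (-1) ** (a & b) for (a, b), v in branch.items()}
    norm = math.sqrt(prob)
    return prob, {ab: v / norm for ab, v in branch.items()}


# An arbitrary normalized superposition over (a, b).
state = {(0, 0): 0.5, (0, 1): 0.5j, (1, 0): -0.5, (1, 1): 0.5}
for m in (0, 1):
    p, fixed = measure_and_fix(state, m)
    restored = all(abs(fixed[ab] - state[ab]) < 1e-12 for ab in state)
    print(m, round(p, 3), restored)
```

Either outcome leaves the controls in their original superposition, which is why no `T` gates are needed for this step.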

In [None]:
inv_c2 = cirq.Circuit(cirq.decompose_once(inv_operation))
inv_c2

## Test Adjoint

In [None]:
input_states = [(a, b, a & b) for a, b in itertools.product([0, 1], repeat=2)]
output_states = [(a, b, 0) for a, b, _ in input_states]

for inp, out in zip(input_states, output_states):
    result = cirq.Simulator().simulate(inv_circuit, initial_state=inp)
    print(inp, '->', result.dirac_notation())
    assert result.dirac_notation()[1:-1] == "".join(str(x) for x in out)