### Virtual Distillation

This notebook provides an implementation of virtual distillation error mitigation, as described in: https://arxiv.org/pdf/2011.07064

Virtual distillation is an error mitigation technique that uses $M$ copies of a noisy state $\rho$ to suppress the contribution of errors. It approximates the error-free expectation value of an operator $O$ as:

$$\langle O \rangle_{\text{corrected}} = \dfrac{\mathrm{Tr}(O\rho^M)}{\mathrm{Tr}(\rho^M)}$$
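
To make the formula concrete, here is a small NumPy sketch that evaluates it exactly for a single qubit whose ideal state is $|0\rangle$ but which is flipped with probability $p$, i.e. $\rho = (1-p)|0\rangle\langle 0| + p|1\rangle\langle 1|$, with $O = Z$. The noise model, the value $p = 0.1$ and the helper name `vd_expectation` are illustrative assumptions, not part of the paper or of this notebook. The raw value is $1-2p$, while for $M=2$ the formula gives $\frac{(1-p)^2 - p^2}{(1-p)^2 + p^2}$, which is much closer to the ideal value 1.

```python
import numpy as np

def vd_expectation(O: np.ndarray, rho: np.ndarray, M: int) -> float:
    """Exact virtual-distillation value Tr(O rho^M) / Tr(rho^M) (hypothetical helper)."""
    rho_M = np.linalg.matrix_power(rho, M)
    return float(np.real(np.trace(O @ rho_M) / np.trace(rho_M)))

Z = np.diag([1.0, -1.0])
p = 0.1                    # assumed bit-flip probability (illustrative)
rho = np.diag([1 - p, p])  # noisy version of |0><0|

raw = vd_expectation(Z, rho, M=1)        # M = 1 is the plain expectation Tr(Z rho)
corrected = vd_expectation(Z, rho, M=2)  # M = 2 virtual distillation

print(f"raw: {raw:.4f}, corrected (M=2): {corrected:.4f}")
# prints: raw: 0.8000, corrected (M=2): 0.9756
```

Larger $M$ suppresses the error further, since the dominant eigenvector of $\rho$ is weighted more heavily in $\rho^M$.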

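As described in the paper, we make use of the following equality:
$$\mathrm{Tr}(O\rho^M) = \mathrm{Tr}\left(O^{\mathbf{i}} S^{(M)} \rho^{\otimes M}\right)$$
where $O^{\mathbf{i}}$ is $O$ acting on the $\mathbf{i}$-th copy of $\rho$ and $S^{(M)}$ is the cyclic shift operator that permutes the $M$ copies. For $M=2$, $S^{(2)}$ is simply the SWAP between the two copies.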

This identity means we never have to prepare $\rho^M$ itself; it suffices to measure on $M$ copies of $\rho$.
Based on it, we implement the pseudocode of Algorithm 1 in the paper. Note that this notebook only implements the case $M=2$.
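
The identity can be checked numerically for $M=2$ and a single qubit, where $S^{(2)}$ is the two-qubit SWAP and $O^{\mathbf{i}} = O \otimes I$ acts on the first copy. This NumPy sketch uses a randomly generated density matrix (the seed and the choice $O = Z$ are arbitrary). It also checks the case $O = I$, which gives the purity $\mathrm{Tr}(\rho^2)$ that appears in the denominator.

```python
import numpy as np

rng = np.random.default_rng(seed=7)

# Random single-qubit density matrix: rho = A A^dagger / Tr(A A^dagger)
A = rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2))
rho = A @ A.conj().T
rho /= np.trace(rho)

Z = np.diag([1.0, -1.0])
I = np.eye(2)

# SWAP exchanges the two copies: SWAP |a>|b> = |b>|a>
SWAP = np.array([[1, 0, 0, 0],
                 [0, 0, 1, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1]])

rho2 = np.kron(rho, rho)  # two copies: rho ⊗ rho

lhs = np.trace(Z @ rho @ rho)                 # Tr(O rho^2)
rhs = np.trace(np.kron(Z, I) @ SWAP @ rho2)   # Tr(O^(1) S rho⊗rho)
purity = np.trace(SWAP @ rho2)                # Tr(S rho⊗rho) = Tr(rho^2)

print(np.isclose(lhs, rhs), np.isclose(purity, np.trace(rho @ rho)))
```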

In [13]:
import cirq
import mitiq
import numpy as np

M = 2

## Bell state
The following code builds a Bell-state circuit to serve as our $\rho$, although any N-qubit circuit on `cirq.LineQubit`s 0 to N-1 works here. Circuits are represented as `cirq.Circuit` objects.

In [15]:
def bell_state():
    '''
    This function returns a circuit that prepares a Bell state in cirq circuit format.
    '''

    circuit = cirq.Circuit()
    qubits = cirq.LineQubit.range(2)  # a Bell state needs exactly 2 qubits
    circuit.append(cirq.H(qubits[0]))
    circuit.append(cirq.CNOT(qubits[0], qubits[1]))

    return circuit
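
Because the Bell state is pure, $\rho^2 = \rho$, so virtual distillation should return the ordinary expectation values. For $|\Phi^+\rangle = (|00\rangle + |11\rangle)/\sqrt{2}$, the state prepared by `bell_state()`, these are $\langle Z_0 \rangle = \langle Z_1 \rangle = 0$, with $\langle Z_0 Z_1 \rangle = 1$. The short NumPy sketch below, which does not use Cirq, computes these reference values; the sampled estimates in the algorithm section should approach them as the number of repetitions grows.

```python
import numpy as np

# Bell state |Phi+> = (|00> + |11>) / sqrt(2), basis order |q0 q1>
psi = np.array([1, 0, 0, 1]) / np.sqrt(2)
rho = np.outer(psi, psi.conj())

Z = np.diag([1.0, -1.0])
I = np.eye(2)

purity = np.trace(rho @ rho).real            # 1 for a pure state
z0 = np.trace(np.kron(Z, I) @ rho).real      # <Z> on qubit 0
z1 = np.trace(np.kron(I, Z) @ rho).real      # <Z> on qubit 1
zz = np.trace(np.kron(Z, Z) @ rho).real      # <Z0 Z1> correlation

print(purity, z0, z1, zz)
```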

## M copies of $\rho$
If only a single copy of $\rho$ is available as a circuit, the following function builds the $M$-copy circuit from it.

In [16]:
def M_copies_of_rho(rho: cirq.Circuit, M: int=2):
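    '''
    Given a circuit rho that acts on N qubits, this function returns a circuit containing M copies of rho
    side by side on disjoint qubits: copy i acts on qubits N*i to N*i + N - 1.
    The resulting circuit therefore has N * M qubits.
    Assumes rho acts on cirq.LineQubit(0) ... cirq.LineQubit(N-1).
    '''
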
    if M <= 1:
        print("warning: M_copies_of_rho is not needed for M <= 1")
        return rho

    N = len(rho.all_qubits())

    circuit = cirq.Circuit()
    qubits = cirq.LineQubit.range(N*M)

    for i in range(M):
        circuit += rho.transform_qubits(lambda q: qubits[q.x + N*i])

    return circuit

In [16]:
# Test the M_copies_of_rho function
circuit = bell_state()
print(M_copies_of_rho(circuit, 3))

0: ───H───@───────────────────
          │
1: ───────X───────────────────

2: ───────────H───@───────────
                  │
3: ───────────────X───────────

4: ───────────────────H───@───
                          │
5: ───────────────────────X───


## Applying swaps
Here, a *copy* means one of the $M$ copies of $\rho$ in the full circuit. \
Following the paper, each $B_i$ gate couples qubit $n$ of copy 1 with qubit $n$ of copy 2, for every $n \in [1, N]$. \
To make these pairs adjacent, we apply a series of SWAP operations that interleave the two copies, so that qubit $n$ of copy 1 sits directly above qubit $n$ of copy 2. \
The SWAPs are returned as an ordered list of tuples, each holding the indices of the two qubits to swap.

In [18]:
# This algorithm only works for M = 2
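def generate_swaps(l: list) -> list[tuple]:
    '''
    Given the current order l of the 2N qubit labels (copy 1 on labels 0..N-1, copy 2 on N..2N-1),
    returns the ordered list of SWAPs (pairs of positions) that rearranges l into the
    interleaved order [0, N, 1, N+1, ..., N-1, 2N-1]. The input list is not modified.
    '''
    l = list(l)  # work on a copy so the caller's list is not mutated
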
    if len(l) % 2 != 0:
        raise ValueError("The list must have an even number of elements, since M=2")

    N = len(l) // 2

    if sorted(l) != list(range(0,2*N)):
        raise ValueError("The list must contain all the integers from 0 to 2*N-1")

    correct_list = []
    for i in range(N):
        correct_list.append(i)
        correct_list.append(i+N)

    swaps = []
    for index, value in enumerate(correct_list):
        if l[index] != value:
            l_index = l.index(value)
            l[index], l[l_index] = l[l_index], l[index]
            swaps.append((index, l_index))


    return swaps

# applies swaps to check if the generate swaps algorithm works
def apply_swaps(swaps_list: list[tuple], list_to_permute: list[int]) -> list[int]:

    permuted_list = list_to_permute.copy()
    for swap in swaps_list:
        permuted_list[swap[0]], permuted_list[swap[1]] = permuted_list[swap[1]], permuted_list[swap[0]]

    return permuted_list

In [12]:
# Testing the function
# [0,1,2,3,4,5] should map to [0,3,1,4,2,5]
# [0,1,2,3,4,5,6,7] should map to [0,4,1,5,2,6,3,7]
swaps_1 = generate_swaps([0,1,2,3,4,5])
swaps_2 = generate_swaps([0,1,2,3,4,5,6,7])

print(apply_swaps(swaps_1, [0,1,2,3,4,5]))
print(apply_swaps(swaps_2, [0,1,2,3,4,5,6,7]))

[0, 3, 1, 4, 2, 5]
[0, 4, 1, 5, 2, 6, 3, 7]


## The algorithm
With everything prepared, we can run the algorithm. \
This example uses the Bell state as $\rho$ and the Pauli $Z$ operator as $O$. \
As shown in the paper, each operator $O$ comes with a unitary $B_i$ that is applied to the $i$-th pair of qubits (qubit $i$ of copy 1 and qubit $i$ of copy 2) before measurement; here we define $B_i$ for the Pauli $Z$ operator.
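
The post-processing needs, for each measured pair, the eigenvalue of the SWAP between the two copies and of the symmetrized operator $\frac{1}{2}(Z^{(1)} + Z^{(2)})\,\mathrm{SWAP}$. The NumPy sketch below derives both from the $B_i$ matrix used in the next cell, assuming that measurement bits $b$ are mapped to Pauli $Z$ eigenvalues via $z = 1 - 2b$ and that the first qubit of each pair belongs to copy 1 (big-endian order, as in Cirq's `MatrixGate`). Under these assumptions the two eigenvalues are $(1 - z^1 + z^2 + z^1 z^2)/2$ and $(z^1 + z^2)/2$; the sketch asserts this for all four outcomes.

```python
import numpy as np

s = np.sqrt(2) / 2
# Same matrix as Bi_gate in the notebook; basis order |copy-1 bit, copy-2 bit>
B = np.array([[1, 0, 0, 0],
              [0, s, -s, 0],
              [0, s, s, 0],
              [0, 0, 0, 1]])

Z = np.diag([1.0, -1.0])
I = np.eye(2)
SWAP = np.array([[1, 0, 0, 0],
                 [0, 0, 1, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1]])
# Symmetrized observable for one pair: ((Z on copy 1 + Z on copy 2) / 2) * SWAP
sym_Z_swap = ((np.kron(Z, I) + np.kron(I, Z)) / 2) @ SWAP

# Applying B and measuring outcome k projects onto row k of B, so an observable
# A that is diagonal in that basis satisfies B A B^T = diag(value for each k).
swap_vals = B @ SWAP @ B.T
ez_vals = B @ sym_Z_swap @ B.T
assert np.allclose(swap_vals, np.diag(np.diag(swap_vals)))  # no off-diagonal terms
assert np.allclose(ez_vals, np.diag(np.diag(ez_vals)))

for k in range(4):
    b1, b2 = k >> 1, k & 1           # bit of copy 1 (first qubit) and of copy 2
    z1, z2 = 1 - 2 * b1, 1 - 2 * b2  # Z eigenvalues: bit 0 -> +1, bit 1 -> -1
    f = (1 - z1 + z2 + z1 * z2) / 2  # SWAP eigenvalue for this outcome
    g = (z1 + z2) / 2                # eigenvalue of sym_Z_swap for this outcome
    assert np.isclose(swap_vals[k, k], f) and np.isclose(ez_vals[k, k], g)
    print(f"outcome {b1}{b2}: SWAP = {f:+.0f}, (Z1 + Z2)/2 * SWAP = {g:+.0f}")
```

Transposing $B_i$ would exchange which of the outcomes 01 and 10 corresponds to the antisymmetric state, and the SWAP factor would become $(1 + z^1 - z^2 + z^1 z^2)/2$, so the sign convention in the post-processing has to match the matrix that is actually applied.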

In [22]:
K = 100
print(f"We run {K} reps which means we need M*K = {M*K} copies of rho")

# let the circuit be 2 copies of the Bell state
N = len(bell_state().all_qubits())
rho = M_copies_of_rho(bell_state(), M)

# B_i for O = Pauli Z, acting on the i-th pair (copy 1 qubit, copy 2 qubit).
# Measuring after it distinguishes |00>, |11> and (|01> -/+ |10>)/sqrt(2);
# outcome 01 corresponds to the antisymmetric state (SWAP eigenvalue -1).
Bi_gate = np.array([
        [1, 0, 0, 0],
        [0, np.sqrt(2)/2, -np.sqrt(2)/2, 0],
        [0, np.sqrt(2)/2, np.sqrt(2)/2, 0],
        [0, 0, 0, 1]
    ])

# The measurement circuit is identical for every repetition, so build it once
circuit = rho.copy()

# 1) apply swaps so that qubit n of copy 1 sits next to qubit n of copy 2
for a, b in generate_swaps(list(range(2*N))):
    circuit.append(cirq.SWAP(cirq.LineQubit(a), cirq.LineQubit(b)))

# 2) apply B_i^(2) to every pair (0, 1), (2, 3), ..., (2N-2, 2N-1)
B_gate = cirq.MatrixGate(Bi_gate)
for i in range(0, 2*N, 2):
    circuit.append(B_gate(cirq.LineQubit(i), cirq.LineQubit(i+1)))

# 3) apply measurements
for i in range(2*N):
    circuit.append(cirq.measure(cirq.LineQubit(i), key=f"{i}"))

# run all K repetitions in a single call
simulator = cirq.Simulator()
result = simulator.run(circuit, repetitions=K)

Ei = [0.0 for _ in range(N)]
D = 0.0

for k in range(K):

    # post processing: map each measured bit b to the Z eigenvalue z = 1 - 2b.
    # Even qubits hold copy 1 (z1), odd qubits hold copy 2 (z2).
    z1 = [1 - 2*int(result.measurements[str(2*n)][k, 0]) for n in range(N)]
    z2 = [1 - 2*int(result.measurements[str(2*n + 1)][k, 0]) for n in range(N)]

    # (1 - z1 + z2 + z1*z2) / 2 is the eigenvalue (+1 or -1) of the SWAP between
    # the two copies of qubit j, for the outcome produced by Bi_gate above
    swap_terms = [1 - z1[j] + z2[j] + z1[j]*z2[j] for j in range(N)]

    for i in range(N):
        productE = 1
        for j in range(N):
            if i != j:
                productE *= swap_terms[j]
        Ei[i] += 1/2**N * (z1[i] + z2[i]) * productE

    D += 1/2**N * np.prod(swap_terms)

# For this noiseless Bell state the exact value is <Z_i> = 0,
# so the estimates should scatter around 0 with shot noise of order 1/sqrt(K)
Z_i_corrected = [Ei[i] / D for i in range(N)]
print('Z_i_corrected: ', Z_i_corrected)


We run 100 reps which means we need M*K = 200 copies of rho
