# Bell State Preparation experiment: Compute $\langle \bar{\Phi^+} | \bar{Z}\bar{Z} | \bar{\Phi^+}\rangle$ with color code

In [1]:
from typing import List, Dict
import itertools
import functools
import numpy as np
from tqdm import tqdm
import cirq
import stim
import stimcirq

from mitiq import PauliString

from encoded.diagonalize import get_stabilizer_matrix_from_paulis, get_measurement_circuit, get_paulis_from_stabilizer_matrix

## Set parameters

In [2]:
distance: int = 3                       # Color code distance.
nshots = 100_000                        # Number of samples/shots
depth = 0                               # Number of folded Bell state preparation circuits for added noise
k = 2                                   # Number of logical qubits.

noise_rate = 0.001
noise = cirq.depolarize
simulator = stimcirq.StimSampler()

## Helper functions

In [3]:
def int_to_bin_list(x, length):
    result = [0] * length
    for i in range(length):
        if x & (1 << i):
            result[length - 1 - i] = 1
    return result

def int_to_bin_str(x, length):
    return functools.reduce(lambda a, b: a + b, [str(b) for b in int_to_bin_list(x, length)])

# Expectation of pauli on bitstring measured in diagonal basis.
def compute_expectation(
    pauli: cirq.PauliString,
    counts: Dict[str, int],
) -> float:
    if pauli is cirq.PauliString():
        return 1.0

    expectation = 0.0

    indices = [q.x for q in pauli.qubits]
    for key, value in counts.items():
        key = list(map(int, list(key[::-1])))
        expectation += (-1) ** sum([key[i] for i in indices]) * value

    return expectation / sum(counts.values())

def measure_observable(
    pauli: cirq.PauliString,
    bitstring: str,
) -> float:
    if pauli is cirq.PauliString():
        return 1.0

    indices = [q.x for q in pauli.qubits]
    bitstring = list(map(int, list(bitstring[::-1])))
    return (-1) ** sum([bitstring[i] for i in indices])


def strs_to_paulis(pauli_strs : List[str]) -> List[cirq.PauliString]:
    stab_list = []
    for stab_str in pauli_strs:
        stab_list.append(PauliString(stab_str)._pauli)
    return stab_list

def generate_stabilizer_elements(generators: List[cirq.PauliString]) -> List[cirq.PauliString]:
    elements = []
    for string in itertools.chain.from_iterable(itertools.combinations(generators, r) for r in range(len(generators) + 1)):
        elements.append(
            functools.reduce(lambda a, b: a * b, string, cirq.PauliString())
        )
    return elements

def get_lst_ev(counts, observables, stabilizers):
    k = len(list(counts)[0]) // n  # TODO: Input n or define in function.
    nshots = sum(counts.values())

    numerator = 0
    for bitstring, count in counts.items():
        m = 1
        for i in range(k):
            qubit_ev = 0
            for ob in observables:
                qubit_ev += measure_observable(ob, bitstring[i*n : (i+1)*n]) / len(observables)
            m *= qubit_ev
        numerator += m * count / nshots

    denominator = 0
    for bitstring, count in counts.items():
        m = 1
        for i in range(k):
            qubit_ev = 0
            for stab in stabilizers:
                qubit_ev += measure_observable(stab, bitstring[i*n : (i+1)*n]) / len(stabilizers)
            m *= qubit_ev
        denominator += m * count / nshots
        
    return float(np.real_if_close(numerator / denominator))


In [4]:
import dataclasses


@dataclasses.dataclass
class Tile:
    qubits: list
    color: str


def make_color_code_tiles(*, base_data_width):
    if not (base_data_width % 2 == 1 and base_data_width >= 3):
        raise ValueError(f"{base_data_width=} wasn't an odd number at least as large as 3.")
    w = base_data_width * 2 - 1

    def is_in_bounds(q: complex) -> bool:

        if q.imag < 0:

            return False
        if q.imag * 2 > q.real * 3:

            return False
        if q.imag * 2 > (w - q.real) * 3:

            return False
        return True


    tiles = []
    hexagon_offsets = [-1, +1j, +1j + 1, +2, -1j + 1, -1j]
    for x in range(1, w, 2):
        for y in range((x // 2) % 2, w, 2):
            q = x + 1j * y

            tile = Tile(
                color=['red', 'green', 'blue'][y % 3],

            
                qubits=[
                    q + d
                    for d in hexagon_offsets
                    if is_in_bounds(q + d)
                ],
            )

            if len(tile.qubits) in [4, 6]:
                tiles.append(tile)

    return tiles

def get_stabilizer_generators(distance: int):
    tiles = make_color_code_tiles(base_data_width=distance)
    all_qubits = {q for tile in tiles for q in tile.qubits}

    # Only difference here is with the chromobius notebook that we rever
    sorted_qubits = reversed(sorted(all_qubits, key=lambda q: (q.imag, q.real)))
    q2i = {q: i for i, q in enumerate(sorted_qubits)}

    sorted_tiles = []
    for tile in tiles:
        sorted_tiles.append([q2i[q] for q in tile.qubits])

    stabilizers_x = []
    stabilizers_z = []
    for tile in sorted_tiles:
        stab_x = ""
        stab_z = ""
        for i in range(int((3*distance**2+1)/4)):
            if i in tile:
                stab_x+="X"
                stab_z+="Z"
            else:
                stab_x+="I"
                stab_z+="I"
        stabilizers_x.append(stab_x)
        stabilizers_z.append(stab_z)

    return stabilizers_x + stabilizers_z

def stabilizers_to_encoder(stabilizers) -> stim.Circuit:
    
    tableau = stim.Tableau.from_stabilizers(
        stabilizers,
        allow_underconstrained=True,
    )
    # Note: Look at https://github.com/quantumlib/Stim/blob/main/doc/python_api_reference_vDev.md
    # For the different method of encoding

    return tableau.to_circuit(method='graph_state')

### Run unmitigated experiment

In [12]:
qreg = cirq.LineQubit.range(k)

circuit = cirq.Circuit()
circuit.append(cirq.H.on(qreg[0]))
for i in range(len(qreg)-1):
    circuit.append(cirq.CNOT.on(qreg[i], qreg[i+1]))

circuit = circuit.with_noise(noise(noise_rate))
circuit.append(cirq.measure(*qreg, key="z"))
print(circuit)
counts = simulator.run(circuit, repetitions=nshots).histogram(key="z")
counts = {int_to_bin_str(key, k) : val for key, val in counts.items()}
print(counts)

ev = measure_observable(PauliString("ZZ")._pauli, counts)

print(ev)

0: ───H───D(0.001)[<virtual>]───@───D(0.001)[<virtual>]───M('z')───
                                │                         │
1: ───────D(0.001)[<virtual>]───X───D(0.001)[<virtual>]───M────────
{'00': 49780, '11': 49992, '10': 119, '01': 109}


TypeError: unhashable type: 'slice'

## Run encoded experiment

In [6]:
generator_strs = get_stabilizer_generators(distance)

n = len(generator_strs[0])

observable = PauliString("Z" * n)._pauli

qreg = cirq.LineQubit.range(n * k)

stabilizer_generators = strs_to_paulis(generator_strs)
stabilizer_matrix = get_stabilizer_matrix_from_paulis(stabilizer_generators, qreg[:n])
m_circuit, transformed_matrix = get_measurement_circuit(stabilizer_matrix)
measurement_circuit = cirq.Circuit.concat_ragged(
    m_circuit,
    m_circuit.transform_qubits(dict(zip(qreg[:n], qreg[n:])))
)

transformed_generators = get_paulis_from_stabilizer_matrix(transformed_matrix)
# print(transformed_generators)
stabilizer_elements = generate_stabilizer_elements(transformed_generators)

transformed_observable = observable.conjugated_by(m_circuit**-1)
observable_elements = [transformed_observable * stab for stab in stabilizer_elements]
# for ob in observable_elements: print(ob)
print(len(stabilizer_elements), len(observable_elements))

64 64


In [7]:
encoding_circuit = stimcirq.stim_circuit_to_cirq_circuit(stabilizers_to_encoder([stim.PauliString(s) for s in generator_strs]))
encoding = cirq.Circuit.concat_ragged(
    encoding_circuit,
    encoding_circuit.transform_qubits(dict(zip(qreg[:n], qreg[n:])))
)
encoding

In [8]:
encoding = cirq.Circuit.concat_ragged(
    encoding_circuit,
    encoding_circuit.transform_qubits({qreg[i]: qreg[i + n] for i in range(n)}),
)

# prepare Bell state
encoding.append(cirq.Moment(cirq.H.on_each(qreg[:n])))
encoding.append(cirq.Moment(cirq.CNOT.on_each([(qreg[i], qreg[i+n]) for i in range(n)])))


circuit = encoding + measurement_circuit
circuit = circuit.with_noise(noise(noise_rate))
circuit.append(cirq.measure(qreg, key="z"))
circuit

In [9]:
counts = simulator.run(circuit, repetitions=nshots).histogram(key="z")

In [10]:
counts = {int_to_bin_str(key, n * k) : val for key, val in counts.items()}

In [11]:
ev = get_lst_ev(counts, tqdm(observable_elements), tqdm(stabilizer_elements))
print(ev)

100%|██████████| 64/64 [00:00<00:00, 21793.90it/s]
100%|██████████| 64/64 [00:01<00:00, 56.69it/s]


0.9784160338034602
