# Bell State Preparation experiment: Compute $\langle \bar{\Phi^+} | \bar{Z}\bar{Z} | \bar{\Phi^+}\rangle$ with color code

In [1]:
from typing import List, Dict
import itertools
import functools
import numpy as np
from tqdm import tqdm
import cirq
import stim
import stimcirq

from mitiq import PauliString

from encoded.diagonalize import (
    get_stabilizer_matrix_from_paulis,
    get_measurement_circuit,
    get_measurement_circuit_triangular_color_code,
    get_paulis_from_stabilizer_matrix,
)

## Set parameters

In [54]:
distance: int = 7                       # Color code distance.
nshots = 100_000                        # Number of samples/shots
# NOTE(review): `depth` does not appear to be read by any cell shown in this notebook — confirm before relying on it.
depth = 0                               # Number of folded Bell state preparation circuits for added noise
k = 2                                   # Number of logical qubits.

# Argument passed to the noise channel constructor below, as noise(noise_rate).
noise_rate = 0.001
# Noise channel constructor; applied to whole circuits through Circuit.with_noise.
noise = cirq.depolarize
# Sampler used to run every circuit in this notebook.
simulator = stimcirq.StimSampler()

## Helper functions

In [55]:
def int_to_bin_list(x, length):
    """Return the lowest ``length`` bits of ``x`` as a list of 0/1 ints.

    The most significant of those bits comes first, so the last list entry
    is bit 0 of ``x``. Bits of ``x`` above position ``length - 1`` are dropped.
    """
    return [1 if x & (1 << shift) else 0 for shift in range(length - 1, -1, -1)]

def int_to_bin_str(x, length):
    """Return the lowest ``length`` bits of ``x`` as a '0'/'1' string.

    Uses the same bit order as ``int_to_bin_list`` (most significant of the
    kept bits first). Returns an empty string when ``length`` is 0; the
    previous ``functools.reduce`` form raised TypeError in that case because
    it had no initial value.
    """
    return "".join(str(bit) for bit in int_to_bin_list(x, length))

# Expectation of pauli on bitstring measured in diagonal basis.
def compute_expectation(
    pauli: "cirq.PauliString",
    counts: Dict[str, int],
) -> float:
    """Estimate the expectation value of a diagonal Pauli from measured counts.

    Args:
        pauli: Pauli string whose qubits expose an integer ``x`` attribute.
            Only the set of qubits is used; the individual Pauli operators
            and the coefficient are not inspected.
        counts: Maps each measured bitstring to the number of times it was
            observed. The bit for qubit index ``i`` is read ``i`` positions
            from the right end of the string.

    Returns:
        The count-weighted average of ``(-1) ** parity``, where ``parity`` is
        the sum of the bits on the Pauli's qubits; ``1.0`` when the Pauli acts
        on no qubits.

    Raises:
        ZeroDivisionError: if ``counts`` is empty (or sums to zero) and the
            Pauli acts on at least one qubit.
    """
    # The previous guard used `is` against a freshly built cirq.PauliString(),
    # which can never be true. Testing for an empty support expresses the
    # intended identity check and needs no new object.
    if not pauli.qubits:
        return 1.0

    indices = [q.x for q in pauli.qubits]

    expectation = 0.0
    for bitstring, count in counts.items():
        # bitstring[-1 - i] is element i of the reversed string.
        parity = sum(int(bitstring[-1 - i]) for i in indices)
        expectation += (-1) ** parity * count

    return expectation / sum(counts.values())

def measure_observable(
    pauli: "cirq.PauliString",
    bitstring: str,
) -> float:
    """Return the +/-1 value of a diagonal Pauli on one measured bitstring.

    Args:
        pauli: Pauli string whose qubits expose an integer ``x`` attribute.
            Only the set of qubits is used.
        bitstring: String of '0'/'1' characters. The bit for qubit index ``i``
            is read ``i`` positions from the right end of the string.

    Returns:
        ``(-1) ** parity`` where ``parity`` is the sum of the bits on the
        Pauli's qubits; ``1.0`` when the Pauli acts on no qubits.

    Raises:
        IndexError: if a qubit index is not smaller than ``len(bitstring)``.
    """
    # The previous guard used `is` against a freshly built cirq.PauliString(),
    # which can never be true (and cost an allocation per call). Testing for
    # an empty support expresses the intended identity check.
    if not pauli.qubits:
        return 1.0

    # bitstring[-1 - i] is element i of the reversed string; reading only the
    # needed characters avoids converting the whole string on every call.
    parity = sum(int(bitstring[-1 - q.x]) for q in pauli.qubits)
    return (-1) ** parity


def strs_to_paulis(pauli_strs : List[str]) -> List[cirq.PauliString]:
    """Convert each Pauli string label into the ``_pauli`` object of a ``PauliString``.

    The output list has the same order as ``pauli_strs``.
    """
    return [PauliString(label)._pauli for label in pauli_strs]

def generate_stabilizer_elements(generators: List[cirq.PauliString]) -> List[cirq.PauliString]:
    """Return the product of every subset of ``generators``.

    Subsets are visited by increasing size (0 up to ``len(generators)``) and,
    within a size, in ``itertools.combinations`` order. The empty subset
    contributes ``cirq.PauliString()``. The result has ``2 ** len(generators)``
    entries, so this is only practical for a small number of generators.
    """
    elements = []
    for size in range(len(generators) + 1):
        for subset in itertools.combinations(generators, size):
            product = cirq.PauliString()
            for generator in subset:
                product = product * generator
            elements.append(product)
    return elements


def _lst_block_average(counts, elements, num_blocks, num_qubits, total_shots):
    """Shot-weighted mean of the product, over blocks, of each block's element average."""
    total = 0
    for bitstring, count in counts.items():
        product = 1
        for block in range(num_blocks):
            chunk = bitstring[block * num_qubits : (block + 1) * num_qubits]
            block_mean = 0
            for element in elements:
                block_mean += measure_observable(element, chunk) / len(elements)
            product *= block_mean
        total += product * count / total_shots
    return total


def get_lst_ev(counts, observables, stabilizers, num_qubits=None):
    """Return the ratio of the observable average to the stabilizer average.

    Each bitstring is cut into consecutive blocks of ``num_qubits`` characters.
    For every block the values of ``measure_observable`` over the given
    elements are averaged; the block averages are multiplied together and the
    products are averaged over shots. This is done once with ``observables``
    (numerator) and once with ``stabilizers`` (denominator).

    Args:
        counts: Maps measured bitstrings to the number of times each occurred.
        observables: Sized, re-iterable collection of Pauli strings used for
            the numerator.
        stabilizers: Sized, re-iterable collection of Pauli strings used for
            the denominator.
        num_qubits: Number of characters per block. When omitted, the
            notebook-level global ``n`` is used, which was the only behavior
            before this parameter existed.

    Returns:
        ``numerator / denominator`` as a float.

    Raises:
        IndexError: if ``counts`` is empty.
        ZeroDivisionError: if the denominator evaluates to zero.
    """
    if num_qubits is None:
        # Backward-compatible fallback to the global defined in a later cell.
        num_qubits = n

    num_blocks = len(list(counts)[0]) // num_qubits
    total_shots = sum(counts.values())

    numerator = _lst_block_average(counts, observables, num_blocks, num_qubits, total_shots)
    denominator = _lst_block_average(counts, stabilizers, num_blocks, num_qubits, total_shots)

    return float(np.real_if_close(numerator / denominator))


In [56]:
import dataclasses


@dataclasses.dataclass
class Tile:
    """One tile of the color code layout: a group of qubit positions plus a color."""

    # Qubit positions; make_color_code_tiles fills this with complex numbers (x + 1j*y).
    qubits: list
    # Color label; make_color_code_tiles assigns 'red', 'green' or 'blue'.
    color: str


def make_color_code_tiles(*, base_data_width):
    """Build the tiles of a triangular layout with the given base width.

    Candidate tile centers are complex numbers ``x + 1j*y`` on a staggered
    grid. Each candidate collects the neighbouring positions that fall inside
    the triangle and is kept only when it has 4 or 6 of them.

    Args:
        base_data_width: Odd integer, at least 3.

    Returns:
        List of ``Tile`` objects, in the order the centers are scanned
        (``x`` outer loop, ``y`` inner loop).

    Raises:
        ValueError: if ``base_data_width`` is even or smaller than 3.
    """
    width_is_valid = base_data_width % 2 == 1 and base_data_width >= 3
    if not width_is_valid:
        raise ValueError(f"{base_data_width=} wasn't an odd number at least as large as 3.")
    w = base_data_width * 2 - 1

    def is_in_bounds(q: complex) -> bool:
        # Inside the triangle: on or above the base and on or below both slanted edges.
        return (
            q.imag >= 0
            and q.imag * 2 <= q.real * 3
            and q.imag * 2 <= (w - q.real) * 3
        )

    neighbor_steps = (-1, +1j, +1j + 1, +2, -1j + 1, -1j)
    palette = ('red', 'green', 'blue')

    tiles = []
    for x in range(1, w, 2):
        first_row = (x // 2) % 2
        for y in range(first_row, w, 2):
            center = complex(x, y)
            members = [
                center + step
                for step in neighbor_steps
                if is_in_bounds(center + step)
            ]
            # Only tiles with 4 (boundary) or 6 (full) in-bounds positions are kept.
            if len(members) not in (4, 6):
                continue
            tiles.append(Tile(qubits=members, color=palette[y % 3]))

    return tiles

def get_stabilizer_generators(distance: int):
    """Return the X-type and Z-type stabilizer generator strings for the given distance.

    Each tile from ``make_color_code_tiles`` yields one string with 'X' on the
    tile's qubits and one with 'Z' on them, with 'I' everywhere else. Every
    string has ``int((3 * distance**2 + 1) / 4)`` characters.

    Returns:
        All X-type strings (in tile order) followed by all Z-type strings
        (in the same tile order).
    """
    tiles = make_color_code_tiles(base_data_width=distance)
    all_qubits = {q for tile in tiles for q in tile.qubits}

    # The only difference from the chromobius notebook is that the sorted
    # qubit order is reversed before indices are assigned.
    ordered_qubits = sorted(all_qubits, key=lambda q: (q.imag, q.real))
    ordered_qubits.reverse()
    q2i = {q: index for index, q in enumerate(ordered_qubits)}

    num_qubits = int((3*distance**2+1)/4)

    stabilizers_x = []
    stabilizers_z = []
    for tile in tiles:
        support = {q2i[q] for q in tile.qubits}
        stabilizers_x.append(
            "".join("X" if i in support else "I" for i in range(num_qubits))
        )
        stabilizers_z.append(
            "".join("Z" if i in support else "I" for i in range(num_qubits))
        )

    return stabilizers_x + stabilizers_z

def stabilizers_to_encoder(stabilizers) -> stim.Circuit:
    """Turn a list of stabilizers into a stim circuit via a stim tableau.

    The tableau is built with ``allow_underconstrained=True`` and converted to
    a circuit with the 'graph_state' method.
    """
    # Note: see https://github.com/quantumlib/Stim/blob/main/doc/python_api_reference_vDev.md
    # for the other available encoding methods.
    return stim.Tableau.from_stabilizers(
        stabilizers,
        allow_underconstrained=True,
    ).to_circuit(method='graph_state')

## Run unmitigated experiment

In [57]:
# Unencoded baseline: H on the first qubit followed by a chain of CNOTs over k qubits.
qreg = cirq.LineQubit.range(k)

circuit = cirq.Circuit()
circuit.append(cirq.H.on(qreg[0]))
for i in range(len(qreg)-1):
    circuit.append(cirq.CNOT.on(qreg[i], qreg[i+1]))

# The measurement is appended after with_noise, so it is added to the already-noisy circuit.
circuit = circuit.with_noise(noise(noise_rate))
circuit.append(cirq.measure(*qreg, key="z"))
print(circuit)
counts = simulator.run(circuit, repetitions=nshots).histogram(key="z")
# Re-key the histogram from integers to k-character bitstrings.
counts = {int_to_bin_str(key, k) : val for key, val in counts.items()}
print(counts)

# NOTE(review): the observable is hard-coded as "ZZ", so this line assumes k == 2.
ev = compute_expectation(PauliString("ZZ")._pauli, counts)

print(ev)

0: ───H───D(0.001)[<virtual>]───@───D(0.001)[<virtual>]───M('z')───
                                │                         │
1: ───────D(0.001)[<virtual>]───X───D(0.001)[<virtual>]───M────────
{'00': 49666, '11': 50132, '10': 104, '01': 98}
0.99596


## Run encoded experiment

In [58]:
# Stabilizer generators as 'X'/'Z'/'I' strings: X-type first, then Z-type.
generator_strs = get_stabilizer_generators(distance)

In [59]:
# Length of one generator string. NOTE: get_lst_ev and process read `n` as a global.
n = len(generator_strs[0])

# Z on every one of the n qubits of a block.
observable = PauliString("Z" * n)._pauli

# Rebinds qreg (previously k qubits) to n * k qubits: k consecutive blocks of n.
qreg = cirq.LineQubit.range(n * k)

stabilizer_generators = strs_to_paulis(generator_strs)
# Only the first block's qubits are passed here.
stabilizer_matrix = get_stabilizer_matrix_from_paulis(stabilizer_generators, qreg[:n])

In [60]:
# One 'X' or 'Z' letter per row index, selecting which matrix half supplies that row below.
# The literal has 37 letters, so it is tied to the current code size — update it if `distance` changes.
row_combination = tuple("XXXZXZXXXZXZXZXZXZXXXZXZXZXZXZXZZZZZZ")

In [61]:
from itertools import product

# Row combinations that pass the rank check; get_measurement_circuit_tcc reads combs[0] as a global.
combs = []

numq = len(stabilizer_matrix) // 2 # number of qubits
nump = len(stabilizer_matrix[0]) # number of paulis
# Top half and bottom half of the stabilizer matrix.
z_matrix = stabilizer_matrix.copy()[:numq]
x_matrix = stabilizer_matrix.copy()[numq:]

# Find a combination of rows to make X matrix have rank nump
# Only the single hard-coded candidate is tested; the loop variable reuses the global name.
for row_combination in [row_combination]:
    candidate_matrix = np.array([
        z_matrix[i] if c=="Z" else x_matrix[i] for i, c in enumerate(row_combination)
    ])
    if np.linalg.matrix_rank(candidate_matrix) == nump:
        combs.append(row_combination)

In [62]:
# Must be at least 1: get_measurement_circuit_tcc indexes combs[0].
len(combs)

1

In [63]:
def get_measurement_circuit_tcc(stabilizer_matrix):
    """Build a circuit of H, SWAP, CNOT, S and CZ gates by reducing the stabilizer matrix.

    The matrix is split into its top half (``z_matrix``) and bottom half
    (``x_matrix``). Each row operation applied to those halves is recorded as
    a gate on ``cirq.LineQubit`` objects, where row ``i`` corresponds to
    qubit ``i``.

    NOTE(review): the X/Z row choice is read from the notebook-level global
    ``combs`` (its first entry), not passed as an argument.

    Args:
        stabilizer_matrix: 2-D array with ``2 * numq`` rows and ``nump``
            columns. Presumably the top half holds the Z part and the bottom
            half the X part, one column per generator — TODO confirm against
            ``get_stabilizer_matrix_from_paulis``.

    Returns:
        Tuple ``(measurement_circuit, matrix)`` where ``matrix`` is
        ``np.concatenate((z_matrix, x_matrix))`` after all updates.

    Raises:
        IndexError: if the pivot search below runs past the last row.
    """
    numq = len(stabilizer_matrix) // 2 # number of qubits
    nump = len(stabilizer_matrix[0]) # number of paulis
    z_matrix = stabilizer_matrix.copy()[:numq]
    x_matrix = stabilizer_matrix.copy()[numq:]

    measurement_circuit = cirq.Circuit()
    qreg = cirq.LineQubit.range(numq)

    # Find a combination of rows to make X matrix have rank nump
    for row_combination in [combs[0]]:
        # Example of the expected form (a 19-entry case):
        # ('X', 'X', 'X', 'Z', 'X', 'Z', 'X', 'X', 'X', 'Z', 'X', 'Z', 'X', 'Z', 'X', 'Z', 'Z', 'Z', 'Z')
        candidate_matrix = np.array([
            z_matrix[i] if c=="Z" else x_matrix[i] for i, c in enumerate(row_combination)
        ])

        # Apply Hadamards to swap X and Z rows to transform X matrix to have rank nump
        # NOTE(review): if the rank check fails, no Hadamard is added and x_matrix
        # is left unchanged; the function continues without any warning.
        if np.linalg.matrix_rank(candidate_matrix) == nump:
            print("Row combination:", row_combination)
            for i, c in enumerate(row_combination):
                if c == "Z":
                    z_matrix[i] = x_matrix[i]
                    measurement_circuit.append(cirq.H.on(qreg[i]))
            x_matrix = candidate_matrix
            break

    # print("X matrix")
    # print(x_matrix)
    # Forward pass over columns: put a nonzero entry on the diagonal, then clear the entries below it (mod 2).
    for j in range(nump):
        # print("j =", j)
        # print("x_matrix =")
        # print(x_matrix)
        if x_matrix[j,j] == 0:
            # Search downwards for a row with a nonzero entry in column j.
            # NOTE(review): there is no bounds check, so this raises IndexError
            # when no such row exists.
            i = j + 1
            while True:
                if np.isclose(x_matrix[i,j], 0.0):
                    i += 1
                else:
                    break
                # print(f"x_matrix[i, j] = x_matrix[{i}, {j}] =", x_matrix[i, j])
                # i += 1

            # Exchange rows i and j in both halves and record the matching SWAP.
            x_row = x_matrix[i].copy()
            x_matrix[i] = x_matrix[j]
            x_matrix[j] = x_row

            z_row = z_matrix[i].copy()
            z_matrix[i] = z_matrix[j]
            z_matrix[j] = z_row

            measurement_circuit.append(cirq.SWAP.on(qreg[j], qreg[i]))

        # Add row j into each lower row that has a 1 in column j; z_matrix is updated in the opposite direction.
        for i in range(j + 1, numq):
            if x_matrix[i,j] == 1:
                x_matrix[i] = (x_matrix[i] + x_matrix[j]) % 2
                z_matrix[j] = (z_matrix[j] + z_matrix[i]) % 2

                measurement_circuit.append(cirq.CNOT.on(qreg[j], qreg[i]))

    # Backward pass: clear the entries above the diagonal, from the last column towards the first.
    for j in range(nump-1, 0, -1):
        for i in range(j):
            if x_matrix[i, j] == 1:
                x_matrix[i] = (x_matrix[i] + x_matrix[j]) % 2
                z_matrix[j] = (z_matrix[j] + z_matrix[i]) % 2

                measurement_circuit.append(cirq.CNOT.on(qreg[j], qreg[i]))

    # Zero out z_matrix: an S gate for each diagonal 1, a CZ for each 1 below the diagonal
    # (the mirrored entry above the diagonal is cleared at the same time).
    for i in range(nump):
        if z_matrix[i,i] == 1:
            z_matrix[i,i] = 0
            measurement_circuit.append(cirq.S.on(qreg[i]))

        for j in range(i):
            if z_matrix[i,j] == 1:
                z_matrix[i,j] = 0
                z_matrix[j,i] = 0
                measurement_circuit.append(cirq.CZ.on(qreg[j], qreg[i]))

    # Exchange row i of the two halves and record an H on qubit i, for the first nump rows.
    for i in range(nump):
        row = x_matrix[i].copy()
        x_matrix[i] = z_matrix[i]
        z_matrix[i] = row

        measurement_circuit.append(cirq.H.on(qreg[i]))

    return measurement_circuit, np.concatenate((z_matrix, x_matrix))

In [64]:
# Single-block measurement circuit plus the stabilizer matrix after the same row operations.
m_circuit, transformed_matrix = get_measurement_circuit_tcc(stabilizer_matrix)

Row combination: ('X', 'X', 'X', 'Z', 'X', 'Z', 'X', 'X', 'X', 'Z', 'X', 'Z', 'X', 'Z', 'X', 'Z', 'X', 'Z', 'X', 'X', 'X', 'Z', 'X', 'Z', 'X', 'Z', 'X', 'Z', 'X', 'Z', 'X', 'Z', 'Z', 'Z', 'Z', 'Z', 'Z')


In [65]:
# Place the single-block circuit on the first n qubits and a remapped copy on the remaining qubits.
measurement_circuit = cirq.Circuit.concat_ragged(
    m_circuit,
    m_circuit.transform_qubits(dict(zip(qreg[:n], qreg[n:])))
)
measurement_circuit

In [66]:
# Pauli strings recovered from the transformed stabilizer matrix.
transformed_generators = get_paulis_from_stabilizer_matrix(transformed_matrix)

In [67]:
# Conjugate the observable by the inverse of the single-block measurement circuit.
transformed_observable = observable.conjugated_by(m_circuit**-1)

In [68]:
# print(transformed_generators)
# stabilizer_elements = generate_stabilizer_elements(transformed_generators)
# observable_elements = [transformed_observable * stab for stab in stabilizer_elements]
# for ob in observable_elements: print(ob)
# print(len(stabilizer_elements), len(observable_elements))

In [69]:
# Encoder for one block: stim circuit built from the generator strings, converted to cirq.
encoding_circuit = stimcirq.stim_circuit_to_cirq_circuit(stabilizers_to_encoder([stim.PauliString(s) for s in generator_strs]))
# Encoder on the first n qubits plus a remapped copy on the remaining qubits.
# NOTE: the next cell rebuilds `encoding` from scratch; this cell is kept for display.
encoding = cirq.Circuit.concat_ragged(
    encoding_circuit,
    encoding_circuit.transform_qubits(dict(zip(qreg[:n], qreg[n:])))
)
encoding

In [70]:
# Rebuild the two-block encoder (this assumes k == 2: qubit i maps to qubit i + n).
encoding = cirq.Circuit.concat_ragged(
    encoding_circuit,
    encoding_circuit.transform_qubits({qreg[i]: qreg[i + n] for i in range(n)}),
)

# prepare Bell state
# H on every qubit of the first block, then a CNOT from qubit i to qubit i + n for each i.
encoding.append(cirq.Moment(cirq.H.on_each(qreg[:n])))
encoding.append(cirq.Moment(cirq.CNOT.on_each([(qreg[i], qreg[i+n]) for i in range(n)])))


# Noise is applied to the encoding and measurement circuits together; the final measurement is appended afterwards.
circuit = encoding + measurement_circuit
circuit = circuit.with_noise(noise(noise_rate))  # TODO: Consider also a perfect measurement circuit (no noise).
circuit.append(cirq.measure(qreg, key="z"))
circuit

In [71]:
# Sample the encoded circuit. This rebinds `counts`, which held the unmitigated results.
counts = simulator.run(circuit, repetitions=nshots).histogram(key="z")

In [72]:
# Re-key the histogram from integers to (n * k)-character bitstrings.
# NOTE: not idempotent — re-running this cell alone would pass string keys to int_to_bin_str.
counts = {int_to_bin_str(key, n * k) : val for key, val in counts.items()}

In [74]:
# TODO: Save counts, observable elements, and stabilizer elements to disk.

## Direct serial calculation

In [75]:
# ev = get_lst_ev(counts, tqdm(observable_elements), tqdm(stabilizer_elements))
# print(ev)

## Sampling observables/stabilizers

In [101]:
# Number of random generator products drawn in the next cell.
nsample = 500_000

In [102]:
"""Sample generator powers to produce elements."""
# Paired lists: entry i of sampled_obs_elements is entry i of
# sampled_stabilizer_elements multiplied by the transformed observable.
sampled_obs_elements = []
sampled_stabilizer_elements = []

ngenerators = len(transformed_generators)
# NOTE(review): np.random is used without a seed, so the sampled elements differ between runs.
for _ in range(nsample):
    # One 0/1 power per generator; generators with power 0 are replaced by the identity in the product.
    powers = np.random.choice([0, 1], ngenerators)
    stabilizer_element = functools.reduce(
        lambda a, b: a * b,
        [g if power == 1 else cirq.PauliString() for g, power in zip(transformed_generators, powers)]
    )
    sampled_obs_elements.append(stabilizer_element * transformed_observable)
    sampled_stabilizer_elements.append(stabilizer_element)

In [None]:
"""Sample directly from computed elements - can only be done if number of stabilizer elements is small."""
# indices = np.random.choice(
#     list(range(len(observable_elements))), size=nsample, replace=False
# )
# sampled_obs_elements = [observable_elements[i] for i in indices]

# indices_den = np.random.choice(
#     list(range(len(stabilizer_elements))), size=nsample, replace=False
# )
# sampled_stabilizer_elements = [stabilizer_elements[i] for i in indices_den]

'Sample directly from computed elements - can only be done if number of stabilizer elements is small.'

In [None]:
# Serial estimate from the sampled elements. get_lst_ev iterates both collections once per
# bitstring and block, so this is slow for large nsample (see the recorded timings below).
ev_sampled = get_lst_ev(counts, tqdm(sampled_obs_elements), tqdm(sampled_stabilizer_elements))
ev_sampled

100%|██████████| 1000/1000 [00:00<00:00, 58093.66it/s]
100%|██████████| 1000/1000 [26:48<00:00,  1.61s/it]   


0.11974859411180984

## Parallel calculation

In [103]:
from joblib import Parallel, delayed

In [104]:
# First argument passed to joblib's Parallel in the cells below.
njobs: int = 20

In [105]:
def process(bitstring, count, elements):
    """Return one bitstring's shot-weighted contribution to the estimator.

    The bitstring is cut into ``k`` consecutive blocks of ``n`` characters.
    For each block the ``measure_observable`` values over ``elements`` are
    averaged; the block averages are multiplied together, scaled by ``count``
    and divided by ``nshots``.

    Reads the notebook-level globals ``k``, ``n`` and ``nshots``.
    """
    total_elements = len(elements)

    product = 1
    for block in range(k):
        chunk = bitstring[block * n: (block + 1) * n]
        signed_sum = sum([measure_observable(pauli, chunk) for pauli in elements])
        product *= signed_sum / total_elements

    # Start from 0.0 so the result is always a float, as before.
    weighted = 0.0 + product * count
    return weighted / nshots

In [106]:
# One task per distinct bitstring; each returns that bitstring's contribution to the numerator.
numerator = Parallel(njobs)(
    delayed(process)(bitstring, count, sampled_obs_elements) for bitstring, count in counts.items()
)
np.sum(numerator)

KeyboardInterrupt: 

In [None]:
# Same computation with the sampled stabilizer elements, giving the denominator.
denominator = Parallel(njobs)(
    delayed(process)(bitstring, count, sampled_stabilizer_elements) for bitstring, count in counts.items()
)
np.sum(denominator)

1.1276462000000003e-07

In [None]:
# Sanity check: both sampled lists are filled in the same loop, so their lengths should match.
len(sampled_stabilizer_elements) == len(sampled_obs_elements)

True

In [None]:
# Final estimate: ratio of the summed numerator and denominator contributions.
# NOTE(review): the recorded values below exceed 1, which is outside the range of a
# +/-1-valued observable's expectation — presumably a sampling effect of the very small
# denominator; confirm before quoting these numbers.
np.sum(numerator) / np.sum(denominator)
# Run 1: 1.3120205774766738 (100k samples)
# Run 2: 1.32812745699848 (100k samples)
# Run 3:   (500k samples)

1.32812745699848