# TCC Memory: Scaling physical qubits

## Setup

In [1]:
import numpy as np
from typing import List, Dict
import functools
import itertools
import dataclasses

from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import cirq
from mitiq import PauliString
import stim
import stimcirq

from encoded.diagonalize import get_stabilizer_matrix_from_paulis, get_measurement_circuit_tcc, get_paulis_from_stabilizer_matrix

## Set experiment parameters

In [85]:
distance = 7
nshots = 100_000 # number of repetitions to run circuit

simulator = stimcirq.StimSampler()

noise_rate = 0.0001
noise = cirq.depolarize

PAULI_O = "Z"

## Helper functions

In [86]:
def int_to_bin_list(x, length):
    result = [0] * length
    for i in range(length):
        if x & (1 << i):
            result[length - 1 - i] = 1
    return result

def int_to_bin_str(x, length):
    return functools.reduce(lambda a, b: a + b, [str(b) for b in int_to_bin_list(x, length)])

def compute_expectation(
    pauli: cirq.PauliString,
    counts: Dict[str, int],
) -> float:
    if pauli is cirq.PauliString():
        return 1.0

    expectation = 0.0

    indices = [q.x for q in pauli.qubits]
    for key, value in counts.items():
        key = list(map(int, list(key)))
        expectation += (-1) ** sum([key[i] for i in indices]) * value
    return expectation / sum(counts.values())

def noisy_identity(qreg, depth, noise) -> cirq.Circuit:
    circuit = cirq.Circuit()
    if depth == 0:
        circuit.append(cirq.I.on_each(qreg))
    for _ in range(depth):
        circuit.append(cirq.X.on_each(qreg))

    return circuit.with_noise(noise)

def strs_to_paulis(pauli_strs : List[str]) -> List[cirq.PauliString]:
    stab_list = []
    for stab_str in pauli_strs:
        stab_list.append(PauliString(stab_str)._pauli)
    return stab_list


def measure_observable(
    pauli: cirq.PauliString,
    bitstring: str,
) -> float:
    if pauli is cirq.PauliString():
        return 1.0

    indices = [q.x for q in pauli.qubits]
    bitstring = list(map(int, list(bitstring[::-1])))
    return (-1) ** sum([bitstring[i] for i in indices])


def generate_stabilizer_elements(generators: List[cirq.PauliString]) -> List[cirq.PauliString]:
    elements = []
    for string in itertools.chain.from_iterable(itertools.combinations(generators, r) for r in range(len(generators) + 1)):
        elements.append(
            functools.reduce(lambda a, b: a * b, string, cirq.PauliString())
        )
    return elements

### Logical state preparation / encoding

In [87]:
@dataclasses.dataclass
class Tile:
    qubits: list
    color: str


def make_color_code_tiles(*, base_data_width):
    if not (base_data_width % 2 == 1 and base_data_width >= 3):
        raise ValueError(f"{base_data_width=} wasn't an odd number at least as large as 3.")
    w = base_data_width * 2 - 1

    def is_in_bounds(q: complex) -> bool:

        if q.imag < 0:

            return False
        if q.imag * 2 > q.real * 3:

            return False
        if q.imag * 2 > (w - q.real) * 3:

            return False
        return True


    tiles = []
    hexagon_offsets = [-1, +1j, +1j + 1, +2, -1j + 1, -1j]
    for x in range(1, w, 2):
        for y in range((x // 2) % 2, w, 2):
            q = x + 1j * y

            tile = Tile(
                color=['red', 'green', 'blue'][y % 3],

            
                qubits=[
                    q + d
                    for d in hexagon_offsets
                    if is_in_bounds(q + d)
                ],
            )

            if len(tile.qubits) in [4, 6]:
                tiles.append(tile)

    return tiles

def get_stabilizer_generators(distance: int):
    tiles = make_color_code_tiles(base_data_width=distance)
    all_qubits = {q for tile in tiles for q in tile.qubits}

    # Only difference here is with the chromobius notebook that we rever
    sorted_qubits = reversed(sorted(all_qubits, key=lambda q: (q.imag, q.real)))
    q2i = {q: i for i, q in enumerate(sorted_qubits)}

    sorted_tiles = []
    for tile in tiles:
        sorted_tiles.append([q2i[q] for q in tile.qubits])

    stabilizers_x = []
    stabilizers_z = []
    for tile in sorted_tiles:
        stab_x = ""
        stab_z = ""
        for i in range(int((3*distance**2+1)/4)):
            if i in tile:
                stab_x+="X"
                stab_z+="Z"
            else:
                stab_x+="I"
                stab_z+="I"
        stabilizers_x.append(stab_x)
        stabilizers_z.append(stab_z)

    return stabilizers_x + stabilizers_z

def stabilizers_to_encoder(stabilizers) -> stim.Circuit:
    
    tableau = stim.Tableau.from_stabilizers(
        stabilizers,
        allow_underconstrained=True,
    )
    # Note: Look at https://github.com/quantumlib/Stim/blob/main/doc/python_api_reference_vDev.md
    # For the different method of encoding

    return tableau.to_circuit(method='graph_state')

## Run physical experiment

## Run encoded experiment

In [88]:
generator_strs = get_stabilizer_generators(distance)
n = len(generator_strs[0])

print(generator_strs)
observable = PauliString(PAULI_O*n)._pauli

['IIIIIIIIIIIIIIIIIIIIIIIIIIIIXXIIIIIXX', 'IIIIIIIIIIIIIIIIIIXIIIIXIIIIXXIIIIIII', 'IIIIIIIIIIIIIIIIIIIIIIXXIIIXXIIIIIXXI', 'IIIIIIIIIIIIXXIIIXXIIIXXIIIIIIIIIIIII', 'IIIIIIXIIXIIXXIIIIIIIIIIIIIIIIIIIIIII', 'IIIIIIIIIIIIIIIIIIIIIIIIIIXXIIIIIXXII', 'IIIIIIIIIIIIIIIIXXIIIXXIIIXXIIIIIIIII', 'IIIIIIIIXXIXXIIIXXIIIIIIIIIIIIIIIIIII', 'IIXXIXXIXXIIIIIIIIIIIIIIIIIIIIIIIIIII', 'XXXXIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII', 'IIIIIIIIIIIIIIIIIIIIXXIIIXXIIIIIXXIII', 'IIIIIIIIIIXXIIIXXIIIXXIIIIIIIIIIIIIII', 'IIIIXXIXXIXXIIIIIIIIIIIIIIIIIIIIIIIII', 'IXXIXXIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII', 'IIIIIIIIIIIIIIIIIIIIIIIIXXIIIIIXXIIII', 'IIIIIIIIIIIIIIXXIIIXXIIIXXIIIIIIIIIII', 'IIIIIIIXIIXIIIXXIIIIIIIIIIIIIIIIIIIII', 'IIIIIIIIIIIIIIIIIIIXIIIIXIIIIIXXIIIII', 'IIIIIIIIIIIIIIIIIIIIIIIIIIIIZZIIIIIZZ', 'IIIIIIIIIIIIIIIIIIZIIIIZIIIIZZIIIIIII', 'IIIIIIIIIIIIIIIIIIIIIIZZIIIZZIIIIIZZI', 'IIIIIIIIIIIIZZIIIZZIIIZZIIIIIIIIIIIII', 'IIIIIIZIIZIIZZIIIIIIIIIIIIIIIIIIIIIII', 'IIIIIIIIIIIIIIIIIIIIIIIIIIZZIIIIIZZII', 'IIIIIIIIIIIIII

In [89]:
qreg = cirq.LineQubit.range(n)
stabilizer_generators = strs_to_paulis(generator_strs)
stabilizer_matrix = get_stabilizer_matrix_from_paulis(stabilizer_generators, qreg)


encoding_circuit = stimcirq.stim_circuit_to_cirq_circuit(stabilizers_to_encoder([stim.PauliString(s) for s in generator_strs]))
print(encoding_circuit)

                                ┌──┐                       ┌──┐                   ┌──┐                   ┌──┐                       ┌──┐                           ┌──┐                               ┌──┐           ┌──┐                               ┌──┐                       ┌──┐                       ┌──┐                   ┌──┐                   ┌──┐                               ┌──┐                               ┌──┐       ┌──┐                   ┌──┐       ┌──┐               ┌──┐               ┌──┐       ┌──┐                       ┌──┐               ┌──┐       ┌──┐   ┌──┐       ┌──┐   ┌──┐   ┌──┐       ┌──┐           ┌──┐   ┌──┐
0: ────RX───@───@───@───@───@────@───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

In [90]:
measurement_circuit, transformed_matrix = get_measurement_circuit_tcc(stabilizer_matrix, distance)
transformed_generators = get_paulis_from_stabilizer_matrix(transformed_matrix)
# stabilizer_elements = generate_stabilizer_elements(transformed_generators)
# for stab in stabilizer_elements:
#     print(stab)

transformed_observable = observable.conjugated_by(measurement_circuit**-1)
# observable_elements = [transformed_observable * stab for stab in stabilizer_elements]
print(transformed_observable)

Z(q(18))*Z(q(21))*Z(q(23))*Z(q(26))*Z(q(29))*Z(q(32))*Z(q(36))


In [91]:
circuit = encoding_circuit + measurement_circuit
circuit = circuit.with_noise(noise(noise_rate))
circuit.append(cirq.measure(qreg, key="z"))
circuit

In [92]:
counts = simulator.run(circuit, repetitions=nshots).histogram(key="z")

In [93]:
counts = {int_to_bin_str(key, n) : val for key, val in counts.items()}

## Post-processing

In [94]:
nsample = 10_000

In [95]:
"""Sample generator powers to produce elements."""
sampled_obs_elements = []
sampled_stabilizer_elements = []

ngenerators = len(transformed_generators)
for _ in range(nsample):
    powers = np.random.choice([0, 1], ngenerators)
    stabilizer_element = functools.reduce(
        lambda a, b: a * b,
        [g if power == 1 else cirq.PauliString() for g, power in zip(transformed_generators, powers)]
    )
    sampled_obs_elements.append(stabilizer_element * transformed_observable)
    sampled_stabilizer_elements.append(stabilizer_element)

In [96]:
njobs: int = 16

In [97]:
def process(bitstring, count, elements):
    n_elements = len(elements)

    values = [measure_observable(pauli, bitstring) for pauli in elements]
    qubit_ev = sum(values)
    m = qubit_ev / n_elements
    return m * count / nshots

In [98]:
numerator = Parallel(njobs)(
    delayed(process)(bitstring, count, sampled_obs_elements) for bitstring, count in counts.items()
)
np.sum(numerator)

0.49710465

In [99]:
denominator = Parallel(njobs)(
    delayed(process)(bitstring, count, sampled_stabilizer_elements) for bitstring, count in counts.items()
)
np.sum(denominator)

0.49756916999999995

In [100]:
len(sampled_stabilizer_elements) == len(sampled_obs_elements)

True

In [101]:
np.sum(numerator) / np.sum(denominator)

0.9990664212575712