# Decoding Quantum Toric Codes

In this experiment, we’ll use ``mdopt`` to compute the threshold of the toric code. Hereafter, we assume an independent noise model as well as perfect syndrome measurements.

In [1]:
import numpy as np
from tqdm import tqdm
import qecstruct as qc
import qecsim.paulitools as pt
import matplotlib.pyplot as plt
from scipy.sparse import hstack, kron, eye, csc_matrix, block_diag

from mdopt.mps.utils import marginalise, create_custom_product_state
from mdopt.contractor.contractor import mps_mpo_contract
from mdopt.optimiser.utils import (
    SWAP,
    COPY_LEFT,
    XOR_BULK,
    XOR_LEFT,
    XOR_RIGHT,
)
from examples.decoding.decoding import (
    css_code_checks,
    css_code_logicals,
    css_code_logicals_sites,
    css_code_constraint_sites,
)
from examples.decoding.decoding import (
    apply_constraints,
    apply_bitflip_bias,
)
from examples.decoding.decoding import (
    pauli_to_mps,
    decode_shor,
)

In [27]:
def repetition_code(n):
    """
    Parity check matrix of a repetition code with length n.
    """
    row_ind, col_ind = zip(*((i, j) for i in range(n) for j in (i, (i + 1) % n)))
    data = np.ones(2 * n, dtype=np.uint8)
    return csc_matrix((data, (row_ind, col_ind)))


def toric_code_x_checks(L):
    """
    Sparse check matrix for the X stabilisers of a toric code with
    lattice size L, constructed as the hypergraph product of
    two repetition codes.
    """
    Hr = repetition_code(L)
    H = hstack(
        [kron(Hr, eye(Hr.shape[1])), kron(eye(Hr.shape[0]), Hr.T)], dtype=np.uint8
    )
    H.data = H.data % 2
    H.eliminate_zeros()
    checks = csc_matrix(H).toarray()
    return [list(np.nonzero(check)[0]) for check in checks]


def toric_code_x_logicals(L):
    """
    Sparse binary matrix with each row corresponding to an X logical operator
    of a toric code with lattice size L. Constructed from the
    homology groups of the repetition codes using the Kunneth
    theorem.
    """
    H1 = csc_matrix(([1], ([0], [0])), shape=(1, L), dtype=np.uint8)
    H0 = csc_matrix(np.ones((1, L), dtype=np.uint8))
    x_logicals = block_diag([kron(H1, H0), kron(H0, H1)])
    x_logicals.data = x_logicals.data % 2
    x_logicals.eliminate_zeros()
    logicals = csc_matrix(x_logicals).toarray()
    return [list(np.nonzero(logical)[0]) for logical in logicals]


def toric_code_constraint_sites(L):
    stabilizers = toric_code_x_checks(L)
    sites_x = [np.nonzero(stabilizers[i])[0] for i in range(len(stabilizers))]
    sites_x = [2 * site + 3 for site in sites_x]

    constraints_strings_x = []

    for sites in sites_x:
        xor_left_sites_x = [sites[0]]
        xor_bulk_sites_x = [sites[i] for i in range(1, len(sites) - 1)]
        xor_right_sites_x = [sites[-1]]

        swap_sites_x = list(range(sites[0] + 1, sites[-1]))
        for k in range(1, len(sites) - 1):
            swap_sites_x.remove(sites[k])

        constraints_strings_x.append(
            [xor_left_sites_x, xor_bulk_sites_x, swap_sites_x, xor_right_sites_x]
        )

    return constraints_strings_x


def toric_code_logicals_sites(L):
    sites = toric_code_x_logicals(L)
    sites_x, sites_z = np.array(sites[0]), np.array(sites[1])
    sites_x = list(2 * sites_x + 3)
    sites_z = list(2 * sites_z + 3)

    copy_site_x = [0]
    copy_site_z = [1]

    xor_right_site_x = [sites_x[-1]]
    xor_right_site_z = [sites_z[-1]]

    xor_bulk_sites_x = [sites_x[i] for i in range(len(sites_x) - 1)]
    xor_bulk_sites_z = [sites_z[i] for i in range(len(sites_z) - 1)]

    swap_sites_x = list(range(copy_site_x[0] + 1, xor_right_site_x[0]))
    swap_sites_x = [site for site in swap_sites_x if site not in xor_bulk_sites_x]
    swap_sites_z = list(range(copy_site_z[0] + 1, xor_right_site_z[0]))
    swap_sites_z = [site for site in swap_sites_z if site not in xor_bulk_sites_z]

    string_x = [copy_site_x, xor_bulk_sites_x, swap_sites_x, xor_right_site_x]
    string_z = [copy_site_z, xor_bulk_sites_z, swap_sites_z, xor_right_site_z]

    return string_x, string_z

In [28]:
toric_code_constraint_sites(L=3)

[[[5], [7], [6, 8], [9]],
 [[3], [5, 7], [4, 6, 8], [9]],
 [[3], [5, 7], [4, 6, 8], [9]],
 [[3], [5, 7], [4, 6, 8], [9]],
 [[3], [5, 7], [4, 6, 8], [9]],
 [[3], [5, 7], [4, 6, 8], [9]],
 [[5], [7], [6, 8], [9]],
 [[3], [5, 7], [4, 6, 8], [9]],
 [[3], [5, 7], [4, 6, 8], [9]]]

In [17]:
sites = toric_code_x_logicals(L=3)
sites_x, sites_z = np.array(sites[0]), np.array(sites[1])
sites_x = list(2 * sites_x + 3)
sites_z = list(2 * sites_z + 3)

copy_site_x = [0]
copy_site_z = [1]

xor_right_site_x = [sites_x[-1]]
xor_right_site_z = [sites_z[-1]]

xor_bulk_sites_x = [sites_x[i] for i in range(len(sites_x) - 1)]
xor_bulk_sites_z = [sites_z[i] for i in range(len(sites_z) - 1)]

swap_sites_x = list(range(copy_site_x[0] + 1, xor_right_site_x[0]))
swap_sites_x = [site for site in swap_sites_x if site not in xor_bulk_sites_x]
swap_sites_z = list(range(copy_site_z[0] + 1, xor_right_site_z[0]))
swap_sites_z = [site for site in swap_sites_z if site not in xor_bulk_sites_z]

In [24]:
toric_code_constraint_sites(L=3)

[[[1], [2], [], [3]],
 [[0], [1, 2], [], [3]],
 [[0], [1, 2], [], [3]],
 [[0], [1, 2], [], [3]],
 [[0], [1, 2], [], [3]],
 [[0], [1, 2], [], [3]],
 [[1], [2], [], [3]],
 [[0], [1, 2], [], [3]],
 [[0], [1, 2], [], [3]]]

In [19]:
toric_code_x_logicals(L=3)

[[0, 1, 2], [9, 12, 15]]

In [18]:
sites_z

[21, 27, 33]

In [16]:
toric_code_logicals_sites(L=3)

([[0], [3, 5], [1, 2, 4, 6], [7]],
 [[1],
  [21, 27],
  [2,
   3,
   4,
   5,
   6,
   7,
   8,
   9,
   10,
   11,
   12,
   13,
   14,
   15,
   16,
   17,
   18,
   19,
   20,
   22,
   23,
   24,
   25,
   26,
   28,
   29,
   30,
   31,
   32],
  [33]])

In [None]:
toric_code_x_logicals(L=3)

In [None]:
[0, 3, 6] -> [3, 9, 15]
[0, 1, 2] -> [2, 4, 6]

In [None]:
code = shor_code()
log_matrix_x = code.z_logicals_binary()
array_x = np.zeros((log_matrix_x.num_rows(), log_matrix_x.num_columns()), dtype=int)
for row, cols in enumerate(log_matrix_x.rows()):
    for col in cols:
        array_x[row, col] = 1

In [None]:
array_x

In [None]:
from qecstruct import shor_code

In [None]:
shor_code().x_logicals_binary()

In [None]:
shor_code().z_logicals_binary()

In [None]:
css_code_logicals(shor_code())

In [None]:
toric_code_x_logicals(3)

In [None]:
def css_code_logicals(code: CssCode):
    """
    Returns the list of MPS sites where the logical constraints should be applied.

    Parameters
    ----------
    code : qec.CssCode
        The CSS code object.

    Returns
    -------
    logicals : Tuple[List[int]]
        List of logical operators, first X, then Z.
    """

    log_matrix_x = code.z_logicals_binary()
    array_x = np.zeros((log_matrix_x.num_rows(), log_matrix_x.num_columns()), dtype=int)
    for row, cols in enumerate(log_matrix_x.rows()):
        for col in cols:
            array_x[row, col] = 1

    log_matrix_z = code.x_logicals_binary()
    array_z = np.zeros((log_matrix_z.num_rows(), log_matrix_z.num_columns()), dtype=int)
    for row, cols in enumerate(log_matrix_z.rows()):
        for col in cols:
            array_z[row, col] = 1

    x_logicals = [
        2 * np.nonzero(row)[0] + code.num_x_logicals() + code.num_z_logicals() + 1
        for row in array_x
    ]
    x_logicals = [list(x_logical) for x_logical in x_logicals]
    z_logicals = [
        2 * np.nonzero(row)[0] + code.num_x_logicals() + code.num_z_logicals()
        for row in array_z
    ]
    z_logicals = [list(z_logical) for z_logical in z_logicals]

    return z_logicals[0], x_logicals[0]

In [None]:
[0, 1, 2] -> [2]

In [None]:
def decode_toric(
    lattice_size: int,
    error: str,
    renormalise: bool = True,
    silent: bool = False,
):
    """
    This function does error-based decoding of the toric code.
    It takes as input an error string and returns the most likely Pauli correction.

    Parameters
    ----------
    lattice_size : int
        The lattice size of the toric code.
    error : str
        The error in a string format.
        The way the decoder takes the error is as follows:
        "X_0 Z_0 X_1 Z_1 ..."
    renormalise : bool
        Whether to renormalise the singular values during contraction.
    silent : bool
        Whether to show the progress bars or not.

    Raises
    ------
    ValueError
        If the error string length does not correspond to the code.
    """

    num_logicals = 2
    num_sites = 2 * lattice_size**2 + num_logicals

    if len(error) != num_sites - num_logicals:
        raise ValueError(
            f"The error length is {len(error)}, expected {num_sites - num_logicals}."
        )

    logicals_state = "+" * num_logicals
    state_string = logicals_state + error
    error_mps = create_custom_product_state(string=state_string)

    constraints_tensors = [XOR_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
    logicals_tensors = [COPY_LEFT, XOR_BULK, SWAP, XOR_RIGHT]

    constraints_sites = toric_code_constraint_sites(lattice_size)
    logicals_sites = toric_code_logicals_sites(lattice_size)
    sites_to_bias = list(range(num_logicals, num_sites))

    error_mps = apply_bitflip_bias(
        mps=error_mps, sites_to_bias=sites_to_bias, renormalise=renormalise
    )

    error_mps = apply_constraints(
        error_mps,
        constraints_sites[0],
        constraints_tensors,
        renormalise=renormalise,
        silent=silent,
    )
    error_mps = apply_constraints(
        error_mps,
        constraints_sites[1],
        constraints_tensors,
        renormalise=renormalise,
        silent=silent,
    )
    error_mps = apply_constraints(
        error_mps,
        logicals_sites,
        logicals_tensors,
        renormalise=renormalise,
        silent=silent,
    )

    sites_to_marginalise = list(range(num_logicals, len(error) + num_logicals))
    logical = marginalise(
        mps=error_mps,
        sites_to_marginalise=sites_to_marginalise,
    ).dense(flatten=True, renormalise=True, norm=1)

    if np.argmax(logical) == 0:
        return "I", logical
    if np.argmax(logical) == 1:
        return "X", logical
    if np.argmax(logical) == 2:
        return "Z", logical
    if np.argmax(logical) == 3:
        return "Y", logical

In [None]:
shor_code().z_logicals_binary()

In [None]:
toric_code_x_logicals(3)

In [None]:
def css_code_logicals(code: CssCode):
    """
    Returns the list of MPS sites where the logical constraints should be applied.

    Parameters
    ----------
    code : qec.CssCode
        The CSS code object.

    Returns
    -------
    logicals : Tuple[List[int]]
        List of logical operators, first X, then Z.
    """

    log_matrix_x = code.z_logicals_binary()
    array_x = np.zeros((log_matrix_x.num_rows(), log_matrix_x.num_columns()), dtype=int)
    for row, cols in enumerate(log_matrix_x.rows()):
        for col in cols:
            array_x[row, col] = 1

    log_matrix_z = code.x_logicals_binary()
    array_z = np.zeros((log_matrix_z.num_rows(), log_matrix_z.num_columns()), dtype=int)
    for row, cols in enumerate(log_matrix_z.rows()):
        for col in cols:
            array_z[row, col] = 1

    x_logicals = [
        2 * np.nonzero(row)[0] + code.num_x_logicals() + code.num_z_logicals() + 1
        for row in array_x
    ]
    x_logicals = [list(x_logical) for x_logical in x_logicals]
    z_logicals = [
        2 * np.nonzero(row)[0] + code.num_x_logicals() + code.num_z_logicals()
        for row in array_z
    ]
    z_logicals = [list(z_logical) for z_logical in z_logicals]

    return z_logicals[0], x_logicals[0]

In [None]:
css_code_logicals_sites(shor_code())[0]

In [None]:
toric_code_x_logicals(L=3)

In [None]:
stabilizers = toric_code_x_stabilisers(L=3)
logicals = toric_code_x_logicals(L=3)

In [None]:
stabilizers

In [None]:
logicals

In [None]:
from examples.decoding.decoding import (
    css_code_logicals,
    css_code_logicals_sites,
    css_code_checks,
)
from qecstruct import shor_code

In [None]:
css_code_checks(shor_code())

In [None]:
toric_code_x_logicals(3)

In [None]:
css_code_logicals(shor_code())

In [None]:
css_code_logicals_sites(shor_code())

In [None]:
toric_code_x_logicals(L=3)

In [None]:
toric_code_logicals_sites(L=3)