In this experiment, we decode Shor's nine-qubit quantum error-correcting code.
We do so using Dephasing DMRG, our own built-in DMRG-like optimisation algorithm for solving the main component problem, that is, the problem of finding the computational basis state that contributes the most to a given state.
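
To make the main component problem concrete, here is a minimal brute-force sketch on a dense state vector. It is not the Dephasing DMRG algorithm and does not use `mdopt`; the helper name `main_component` is hypothetical, and the approach is only feasible for a handful of sites because the dense vector grows exponentially.

```python
import numpy as np


def main_component(state_vector, num_sites):
    """Return the bitstring with the largest squared amplitude, and that weight."""
    probabilities = np.abs(np.asarray(state_vector)) ** 2
    index = int(np.argmax(probabilities))
    # Write the index as a bitstring, most significant bit first.
    return format(index, f"0{num_sites}b"), float(probabilities[index])


# Toy three-site state with two non-zero amplitudes.
psi = np.zeros(8, dtype=complex)
psi[0b101] = 0.8
psi[0b010] = 0.6

bitstring, weight = main_component(psi, num_sites=3)
print(bitstring, weight)
```

An MPS-based method aims to find the same bitstring without ever building the dense vector.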

In [199]:
import numpy as np
import qecstruct as qec
from mdopt.mps.utils import marginalise, create_custom_product_state
from mdopt.contractor.contractor import mps_mpo_contract
from mdopt.optimiser.utils import (
    COPY_RIGHT,
    SWAP,
    XOR_BULK,
    XOR_LEFT,
    XOR_RIGHT,
)
from examples.decoding.decoding import (
    linear_css_code_checks,
    css_code_logicals,
    get_css_code_constraint_sites,
    get_css_code_logicals_sites,
    apply_parity_constraints,
    apply_bias_channel,
)

In [1]:
from functools import reduce
from typing import Iterable
import numpy as np
import pytest
from opt_einsum import contract

from mdopt.mps.utils import (
    create_state_vector,
    mps_from_dense,
    inner_product,
    is_canonical,
    find_orth_centre,
    create_simple_product_state,
)
from mdopt.contractor.contractor import apply_one_site_operator, mps_mpo_contract
from mdopt.mps.canonical import CanonicalMPS
from mdopt.utils.utils import mpo_from_matrix

In [10]:
for _ in range(100):
    num_sites, phys_dim = 8, 2
    psi = create_state_vector(num_sites=num_sites, phys_dim=phys_dim)
    mps = mps_from_dense(psi, phys_dim=phys_dim, form="Right-canonical")
    mps_copy = mps.copy()

    sites_all = list(range(num_sites))
    sites_to_marginalise = []
    for site in sites_all:
        if np.random.uniform() < 1 / 2:
            sites_to_marginalise.append(site)
    sites_left = [site for site in sites_all if site not in sites_to_marginalise]

    mps_marginalised = mps_copy.marginal(sites_to_marginalise, canonicalise=False)
    mps_marginalised_canonical = mps.marginal(sites_to_marginalise, canonicalise=True)

    with pytest.raises(ValueError):
        mps.marginal([100, 200])

    if isinstance(mps_marginalised, CanonicalMPS):
        assert mps_marginalised.num_sites == len(sites_left)
        assert is_canonical(mps_marginalised_canonical)
    else:
        assert isinstance(mps_marginalised, np.complex128)

Let us first import the code from `qecstruct` and take a look at it.

In [200]:
code = qec.shor_code()
code

X stabilizers:
[0, 1, 2, 3, 4, 5]
[3, 4, 5, 6, 7, 8]
Z stabilizers:
[0, 1]
[1, 2]
[3, 4]
[4, 5]
[6, 7]
[7, 8]

This quantum error-correcting code is defined on 9 physical qubits and has 2 logical operators. Each physical qubit takes two sites in the MPS and each logical operator takes one, so we need $9 \times 2 + 2 = 20$ sites.

In [201]:
num_sites = 2 * len(code) + code.num_x_logicals() + code.num_z_logicals()
assert num_sites == 20

Now, let us define the initial state. As a first check, the absence of an error should require no correction: starting from the all-zeros state and decoding should return the all-zeros state on the logical-operator sites (the final logical operator is the identity). Note that in the cell below the last two sites, which hold the logical operators, are set to `+`, and the first two sites are set to `1`.

In [202]:
string_state = "11" + "0" * (num_sites - 4) + "++"
error_state = create_custom_product_state(string=string_state, form="Right-canonical")

Here, we get the sites where the checks will be applied. We will need to construct MPOs using this data.

In [203]:
checks_x, checks_z = linear_css_code_checks(code)
print("X checks:")
for check in checks_x:
    print(check)
print("Z checks:")
for check in checks_z:
    print(check)

X checks:
[ 0  2  4  6  8 10]
[ 6  8 10 12 14 16]
Z checks:
[1 3]
[3 5]
[7 9]
[ 9 11]
[13 15]
[15 17]
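
Comparing the printed checks with the stabilizers shown earlier suggests a simple pattern: qubit `q` of an X stabilizer lands on site `2q`, and qubit `q` of a Z stabilizer lands on site `2q + 1`. The sketch below reproduces the printed lists under that assumption. It is inferred from the output only and is not the implementation of `linear_css_code_checks`; the helper names are hypothetical.

```python
def x_check_sites(qubits):
    """Assumed mapping: X-stabilizer qubit q acts on MPS site 2q."""
    return [2 * q for q in qubits]


def z_check_sites(qubits):
    """Assumed mapping: Z-stabilizer qubit q acts on MPS site 2q + 1."""
    return [2 * q + 1 for q in qubits]


# Stabilizers as printed by qecstruct for the Shor code.
x_stabilizers = [[0, 1, 2, 3, 4, 5], [3, 4, 5, 6, 7, 8]]
z_stabilizers = [[0, 1], [1, 2], [3, 4], [4, 5], [6, 7], [7, 8]]

x_sites = [x_check_sites(s) for s in x_stabilizers]
z_sites = [z_check_sites(s) for s in z_stabilizers]
print(x_sites)
print(z_sites)
```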


These lists mention only the sites where we will apply the XOR constraints. However, the MPOs will also consist of other tensors, such as SWAPs (wire crossings) and boundary XOR constraints. In what follows we define the list of these auxiliary tensors and the corresponding sites where they reside.

In [204]:
tensors_constraints = [XOR_LEFT, XOR_BULK, SWAP, XOR_RIGHT]
csscode_constraint_sites = get_css_code_constraint_sites(code)
print("Full X-check lists of sites:")
for string in csscode_constraint_sites[0]:
    print(string)
print("Full Z-check lists of sites:")
for string in csscode_constraint_sites[1]:
    print(string)

Full X-check lists of sites:
[[0], [2, 4, 6, 8], [1, 3, 5, 7, 9], [10]]
[[6], [8, 10, 12, 14], [7, 9, 11, 13, 15], [16]]
Full Z-check lists of sites:
[[1], [], [2], [3]]
[[3], [], [4], [5]]
[[7], [], [8], [9]]
[[9], [], [10], [11]]
[[13], [], [14], [15]]
[[15], [], [16], [17]]
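
To illustrate why a chain of XOR tensors and wire crossings enforces a parity check, here is a small sketch in plain NumPy. It uses its own three-index convention (left bond, bit, right bond) on classical bit weights, which is an assumption made for this example; the actual `XOR_LEFT`, `XOR_BULK`, `SWAP` and `XOR_RIGHT` tensors in `mdopt` are MPO tensors whose index layout is not shown here.

```python
import itertools
import numpy as np


def xor_tensor():
    """The bond carries a running parity; the site's bit is added to it mod 2."""
    tensor = np.zeros((2, 2, 2))  # (left bond, bit, right bond)
    for left, bit in itertools.product(range(2), repeat=2):
        tensor[left, bit, left ^ bit] = 1.0
    return tensor


def crossing_tensor():
    """Wire crossing: the running parity passes through, ignoring the bit."""
    tensor = np.zeros((2, 2, 2))
    for left, bit in itertools.product(range(2), repeat=2):
        tensor[left, bit, left] = 1.0
    return tensor


def parity_weights(check_sites):
    """Weights over bits on sites first..last: 1 if the checked bits XOR to 0."""
    first, last = min(check_sites), max(check_sites)
    env = np.array([1.0, 0.0])  # left boundary: parity starts at 0
    for site in range(first, last + 1):
        tensor = xor_tensor() if site in check_sites else crossing_tensor()
        # Contract the bond index and keep the new bit index.
        env = np.tensordot(env, tensor, axes=([-1], [0]))
    return env[..., 0]  # right boundary: final parity must be 0


weights = parity_weights([1, 3])  # spans sites 1, 2, 3
print(weights.shape)
```

Sites outside the span of a check carry no tensor at all, which is why each layout above only lists sites between the first and last check site.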


Let us now take a look at the logical operators.

In [205]:
print(code.x_logicals_binary())
print(code.z_logicals_binary())

[0, 1, 2]

[0, 3, 6]



We again need to translate these into our MPO language by changing the indices, since we add the logical-operator sites at the end of the MPS.

In [206]:
print(css_code_logicals(code)[0])
print(css_code_logicals(code)[1])

[ 1  7 13]
[0 2 4]


Next, as for the checks, we add the sites where the auxiliary tensors go.

In [207]:
print(get_css_code_logicals_sites(code)[0])
print(get_css_code_logicals_sites(code)[1])

[[1], [7, 13], [2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 14, 15, 16, 17], [18]]
[[0], [2, 4], [1, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18], [19]]
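
The four-part lists printed for the checks and for the logicals follow the same visible pattern: the first site, the remaining parity sites, every other site in between (the wire crossings), and the final site. The sketch below rebuilds them under that assumption. The helper `constraint_layout` and its `target_site` parameter are hypothetical and are not part of `mdopt`.

```python
def constraint_layout(parity_sites, target_site=None):
    """Split sites into [[left], bulk, crossings, [right]].

    Without target_site, the last parity site closes the chain (checks).
    With target_site, the chain ends on that extra site (logicals).
    """
    parity_sites = sorted(parity_sites)
    left = parity_sites[0]
    if target_site is None:
        right = parity_sites[-1]
        bulk = parity_sites[1:-1]
    else:
        right = target_site
        bulk = parity_sites[1:]
    occupied = set(parity_sites) | {right}
    # Every site strictly inside the chain that holds no parity tensor.
    crossings = [s for s in range(left + 1, right) if s not in occupied]
    return [[left], bulk, crossings, [right]]


print(constraint_layout([0, 2, 4, 6, 8, 10]))
print(constraint_layout([1, 3]))
print(constraint_layout([1, 7, 13], target_site=18))
```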


In [208]:
tensors_logicals = [XOR_LEFT, XOR_BULK, SWAP, COPY_RIGHT]
constraint_sites_logicals = get_css_code_logicals_sites(code)

Now the fun part: applying the bias channel and contracting the parity-check and logical MPOs.

In [209]:
renormalise = True
error_state = apply_bias_channel(basis_mps=error_state, codeword_string=string_state, prob_channel=0.2)
error_state = apply_parity_constraints(
    error_state, csscode_constraint_sites[0], tensors_constraints, renormalise=renormalise
)
error_state = apply_parity_constraints(
    error_state, csscode_constraint_sites[1], tensors_constraints, renormalise=renormalise
)
error_state = apply_parity_constraints(
    error_state, constraint_sites_logicals, tensors_logicals, renormalise=renormalise
)

100%|██████████| 2/2 [00:00<00:00, 309.78it/s]
100%|██████████| 6/6 [00:00<00:00, 1551.91it/s]
100%|██████████| 2/2 [00:00<00:00, 310.23it/s]
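
As a rough classical analogue of the cell above, the sketch below biases a dense distribution towards a reference bitstring, zeroes out the bitstrings that violate a parity check, and renormalises. It assumes an independent flip probability per bit, which may differ from what `apply_bias_channel` does internally, and the helper names are hypothetical.

```python
import itertools
import numpy as np


def biased_distribution(reference, prob_flip):
    """Product distribution: each bit differs from the reference with prob_flip."""
    dist = np.ones(())
    for bit in reference:
        local = np.full(2, prob_flip)
        local[bit] = 1.0 - prob_flip
        dist = np.multiply.outer(dist, local)  # append one axis per bit
    return dist


def apply_parity(dist, check_sites, renormalise=True):
    """Zero out bitstrings whose bits on check_sites have odd parity."""
    out = dist.copy()
    for bits in itertools.product(range(2), repeat=dist.ndim):
        if sum(bits[s] for s in check_sites) % 2 == 1:
            out[bits] = 0.0
    if renormalise:
        out /= out.sum()
    return out


dist = biased_distribution([1, 0, 0], prob_flip=0.2)
dist = apply_parity(dist, check_sites=[0, 1])
print(dist.sum())
```

The MPS version performs the same kind of projection one MPO at a time, without storing all $2^{20}$ entries.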


Marginalise over the message bits to get the logical operator MPS.

In [210]:
sites_to_marginalise = list(range(2 * len(code)))
logical = marginalise(mps=error_state, sites_to_marginalise=sites_to_marginalise)
print(logical.dense(flatten=True))

[1. 0. 0. 0.]
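
The printed vector puts all its weight on the first entry, where both logical sites are 0; this is the identity outcome described earlier. For intuition, marginalising a dense distribution amounts to summing over the axes of the discarded sites. The sketch below shows that classical analogue in NumPy; it is an illustration only, and `marginalise_dense` is a hypothetical helper rather than the `mdopt` routine, which works on the MPS directly.

```python
import numpy as np


def marginalise_dense(dist, sites_to_marginalise):
    """Sum a dense distribution over the given site axes."""
    return dist.sum(axis=tuple(sites_to_marginalise))


# Random normalised distribution over four bits.
rng = np.random.default_rng(0)
dist = rng.random((2, 2, 2, 2))
dist /= dist.sum()

# Keep only the last two sites, mirroring the two logical sites above.
marginal = marginalise_dense(dist, sites_to_marginalise=[0, 1])
print(marginal.reshape(-1))
```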
