In [146]:
import torch
import numpy as np
import stim
from typing import List
from dataclasses import dataclass

In [147]:
@dataclass
class DEM_Matrices:
    check_matrix: torch.Tensor
    logical_matrix: torch.Tensor
    priors: torch.Tensor

def DEM_to_matrices(DEM: stim.DetectorErrorModel) -> DEM_Matrices:

    priors = np.zeros(DEM.num_errors)
    check_matrix = np.zeros((DEM.num_detectors, DEM.num_errors))
    logical_matrix = np.zeros((DEM.num_observables, DEM.num_errors))
    
    e = 0
    
    for instruction in DEM.flattened():
        
        if instruction.type == "error":
            
            detectors: List[int] = []
            logicals: List[int] = []
            t: stim.DemTarget
            p = instruction.args_copy()[0]
            for t in instruction.targets_copy():
                if t.is_relative_detector_id():
                    detectors.append(t.val)
                elif t.is_logical_observable_id():
                    logicals.append(t.val)

            priors[e] = p
            check_matrix[detectors, e] = 1
            logical_matrix[logicals, e] = 1
            
            e += 1
            
        elif instruction.type == "detector":
            pass
        elif instruction.type == "logical_observable":
            pass
        else:
            raise NotImplementedError()
    
    check_matrix = torch.tensor(check_matrix, dtype=torch.int)
    logical_matrix = torch.tensor(logical_matrix, dtype=torch.int)
    priors = torch.tensor(priors, dtype=torch.float32)
    
    priors = torch.log((1 - priors) / priors)
    
    return DEM_Matrices(
        check_matrix=check_matrix,
        logical_matrix=logical_matrix,
        priors=priors
    )

In [148]:
circuit = stim.Circuit.generated(
                "repetition_code:memory",
                rounds=3,
                distance=3,
                after_clifford_depolarization=0.01,
                after_reset_flip_probability=0.01,
                before_measure_flip_probability=0.01,
                before_round_data_depolarization=0.01)

dem = circuit.detector_error_model(decompose_errors=True)
sampler = dem.compile_sampler()
syndromes, logical_flips, errors = sampler.sample(shots=2, return_errors=True)

syndromes = torch.from_numpy(syndromes).int()
syndromes = torch.reshape(syndromes, (len(syndromes), len(syndromes[0]), 1))
logical_flips = torch.from_numpy(logical_flips).int()
errors = torch.from_numpy(errors).int()

matrices = DEM_to_matrices(dem)

print('Check matrix')
print(matrices.check_matrix)
print('')
print('Logical matrix')
print(matrices.logical_matrix)
print('')
print('Priors')
print(matrices.priors)
print('')
print('Syndromes')
print(syndromes)
print('')
print('Logical flips')
print(logical_flips)
print('')
print('Errors')
print(errors)

Check matrix
tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
         0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0,
         0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1,
         1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1,
         0, 1, 1]], dtype=torch.int32)

Logical matrix
tensor([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
         0, 1, 1]], dtype=torch.int32)

Priors
tensor([3.9378, 3.9378,

In [149]:
weights_llr = []
weights_vc = []
marg_weights_llr = []
marg_weights_vc = []

layers = 2

for _ in range(layers):
    weights_vc.append(torch.ones_like(matrices.check_matrix, dtype=float))
    weights_llr.append(torch.ones_like(matrices.priors, dtype=float))
        
    marg_weights_vc.append(torch.ones_like(matrices.check_matrix, dtype=float))
    marg_weights_llr.append(torch.ones_like(matrices.priors, dtype=float))

rhos = torch.ones(layers, dtype=float)
residuals = torch.zeros(layers, dtype=float)

weights_llr = np.array(weights_llr)
weights_vc = np.array(weights_vc)
marg_weights_llr = np.array(marg_weights_llr)
marg_weights_vc = np.array(marg_weights_vc)

weights_llr = torch.from_numpy(np.array(weights_llr)).float()
weights_vc = torch.from_numpy(np.array(weights_vc)).float()
marg_weights_llr = torch.from_numpy(np.array(marg_weights_llr)).float()
marg_weights_vc = torch.from_numpy(np.array(marg_weights_vc)).float()

print('weights_llr')
print(weights_llr)
print('')
print('weights_vc')
print(weights_vc)
print('')
print('marg_weights_llr')
print(marg_weights_llr)
print('')
print('marg_weights_vc')
print(marg_weights_vc)

weights_llr
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1.]])

weights_vc
tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.

In [150]:
batch_size = 2

loss_array = torch.zeros(batch_size, layers).float()

In [151]:
messages_en_to_dn = torch.zeros((batch_size, matrices.check_matrix.size()[0], matrices.check_matrix.size()[1]))
messages_dn_to_en = torch.zeros((batch_size, matrices.check_matrix.size()[0], matrices.check_matrix.size()[1]))

messages_en_to_dn += matrices.check_matrix * matrices.priors * weights_llr[0]
weighted_messages = messages_dn_to_en * weights_vc[0]
messages_en_to_dn += torch.sum(weighted_messages, dim=1, keepdim=True)
messages_en_to_dn *= matrices.check_matrix
messages_en_to_dn -= messages_dn_to_en

messages_en_to_dn

tensor([[[3.9378, 3.9378, 3.7775, 5.2284, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 5.9216, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000],
         [0.0000, 3.9378, 0.0000, 0.0000, 3.7775, 3.9378, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 5.9216, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 3.7775, 0.0000, 0.0000, 0.0000, 4.4168, 4.4168,
          3.7775, 5.2284, 5.9216, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 5.9216, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 5.2284, 3.7775, 0.0000, 0.0000, 4.4168,
          0.0000, 0.0000, 0.0000, 3.7775, 4.4168, 5.9216, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 5.9216, 0.0000, 0.0000,
       

In [152]:
divider = torch.tanh(messages_en_to_dn/2)
divider[divider==0] = 1
messages_dn_to_en = 2*torch.atanh(torch.prod(divider, dim=2, keepdim=True) / divider)
multiplicator = torch.pow(-1, syndromes)
multiplicator = multiplicator*matrices.check_matrix
multiplicator
# messages_dn_to_en *= multiplicator
# messages_dn_to_en

tensor([[[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0],
         [0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0],
         [0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
          0, 0, 0, 0],
         [0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0,
          0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0,
          0, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0,
          0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1,
          1, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0,
          1, 0, 1, 1]],

        [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0],
         [0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
 

In [153]:
weighted_messages = messages_dn_to_en * marg_weights_vc[0]
beliefs = torch.sum(weighted_messages, dim=1)
beliefs += matrices.priors*marg_weights_llr

In [154]:
beliefs

tensor([[25.5964, 25.9552, 25.8384, 26.7091, 25.8797, 25.6291, 25.9109, 26.0725,
         25.7755, 26.6981, 27.3268, 25.7755, 25.9109, 27.3301, 25.9109, 26.0725,
         26.0356, 26.7291, 27.3215, 25.9614, 25.9109, 27.3215, 25.9796, 26.2357,
         27.3416, 25.9484, 27.3363],
        [25.5964, 25.9552, 25.8384, 26.7091, 25.8797, 25.6291, 25.9109, 26.0725,
         25.7755, 26.6981, 27.3268, 25.7755, 25.9109, 27.3301, 25.9109, 26.0725,
         26.0356, 26.7291, 27.3215, 25.9614, 25.9109, 27.3215, 25.9796, 26.2357,
         27.3416, 25.9484, 27.3363]])

In [155]:
1 - errors

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1]], dtype=torch.int32)

In [156]:
(1-errors) * beliefs

tensor([[25.5964, 25.9552, 25.8384, 26.7091, 25.8797, 25.6291, 25.9109, 26.0725,
         25.7755, 26.6981, 27.3268, 25.7755, 25.9109, 27.3301, 25.9109, 26.0725,
         26.0356, 26.7291, 27.3215, 25.9614, 25.9109, 27.3215, 25.9796, 26.2357,
         27.3416, 25.9484, 27.3363],
        [25.5964, 25.9552, 25.8384, 26.7091, 25.8797, 25.6291, 25.9109, 26.0725,
         25.7755, 26.6981, 27.3268, 25.7755, 25.9109, 27.3301, 25.9109, 26.0725,
         26.0356, 26.7291, 27.3215, 25.9614, 25.9109, 27.3215, 25.9796, 26.2357,
         27.3416, 25.9484, 27.3363]])

In [157]:
softplus = torch.nn.Softplus(beta=1.0, threshold=50)
softplus(beliefs)

tensor([[25.5964, 25.9552, 25.8384, 26.7091, 25.8797, 25.6291, 25.9109, 26.0725,
         25.7755, 26.6981, 27.3268, 25.7755, 25.9109, 27.3301, 25.9109, 26.0725,
         26.0356, 26.7291, 27.3215, 25.9614, 25.9109, 27.3215, 25.9796, 26.2357,
         27.3416, 25.9484, 27.3363],
        [25.5964, 25.9552, 25.8384, 26.7091, 25.8797, 25.6291, 25.9109, 26.0725,
         25.7755, 26.6981, 27.3268, 25.7755, 25.9109, 27.3301, 25.9109, 26.0725,
         26.0356, 26.7291, 27.3215, 25.9614, 25.9109, 27.3215, 25.9796, 26.2357,
         27.3416, 25.9484, 27.3363]])

In [158]:
loss = (1 - errors) * beliefs
loss += softplus(beliefs)
loss = torch.sum(loss, dim=1)
loss 

tensor([1421.0270, 1421.0270])

In [159]:
loss_array[:, 0] = loss
loss_array[:, 1] = 1

In [160]:
loss_array

tensor([[1.4210e+03, 1.0000e+00],
        [1.4210e+03, 1.0000e+00]])

In [161]:
rhos = rhos / layers

In [162]:
loss_array * rhos

tensor([[7.1051e+02, 5.0000e-01],
        [7.1051e+02, 5.0000e-01]], dtype=torch.float64)

In [163]:
sigmoid = torch.nn.Sigmoid()
predictions = sigmoid(-beliefs)

In [164]:
beliefs

tensor([[25.5964, 25.9552, 25.8384, 26.7091, 25.8797, 25.6291, 25.9109, 26.0725,
         25.7755, 26.6981, 27.3268, 25.7755, 25.9109, 27.3301, 25.9109, 26.0725,
         26.0356, 26.7291, 27.3215, 25.9614, 25.9109, 27.3215, 25.9796, 26.2357,
         27.3416, 25.9484, 27.3363],
        [25.5964, 25.9552, 25.8384, 26.7091, 25.8797, 25.6291, 25.9109, 26.0725,
         25.7755, 26.6981, 27.3268, 25.7755, 25.9109, 27.3301, 25.9109, 26.0725,
         26.0356, 26.7291, 27.3215, 25.9614, 25.9109, 27.3215, 25.9796, 26.2357,
         27.3416, 25.9484, 27.3363]])

In [165]:
errors

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0]], dtype=torch.int32)

In [166]:
matrices.check_matrix

tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
         0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0,
         0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1,
         1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1,
         0, 1, 1]], dtype=torch.int32)

In [167]:
errors.T.size()

torch.Size([27, 2])

In [168]:
matrices.check_matrix.size()

torch.Size([8, 27])

In [169]:
e  = matrices.check_matrix @ errors.T
e

tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0]], dtype=torch.int32)

In [170]:
e = torch.sum(e, dim=0)
e

tensor([0, 0])

In [171]:
syndromes

tensor([[[0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0]],

        [[0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0],
         [0]]], dtype=torch.int32)

In [172]:
messages_dn_to_en

tensor([[[2.9882, 2.9882, 3.0577, 2.7417, 2.6621, 2.6621, 2.6621, 2.6621,
          2.6621, 2.6621, 2.7011, 2.6621, 2.6621, 2.6621, 2.6621, 2.6621,
          2.6621, 2.6621, 2.6621, 2.6621, 2.6621, 2.6621, 2.6621, 2.6621,
          2.6621, 2.6621, 2.6621],
         [2.7417, 3.1005, 2.7417, 2.7417, 3.1786, 3.1005, 2.7417, 2.7417,
          2.7417, 2.7417, 2.7417, 2.7417, 2.7417, 2.7840, 2.7417, 2.7417,
          2.7417, 2.7417, 2.7417, 2.7417, 2.7417, 2.7417, 2.7417, 2.7417,
          2.7417, 2.7417, 2.7417],
         [2.5199, 2.5199, 2.8527, 2.5199, 2.5199, 2.5199, 2.6815, 2.6815,
          2.8527, 2.5885, 2.5536, 2.5199, 2.5199, 2.5199, 2.5199, 2.5199,
          2.5199, 2.5199, 2.5536, 2.5199, 2.5199, 2.5199, 2.5199, 2.5199,
          2.5199, 2.5199, 2.5199],
         [2.5199, 2.5199, 2.5199, 2.5885, 2.8527, 2.5199, 2.5199, 2.6815,
          2.5199, 2.5199, 2.5199, 2.8527, 2.6815, 2.5536, 2.5199, 2.5199,
          2.5199, 2.5199, 2.5199, 2.5199, 2.5199, 2.5536, 2.5199, 2.5199,
       

In [173]:
residuals

tensor([0., 0.], dtype=torch.float64)