In [2]:
import torch
import torch.nn as nn
from torch import tensor
import stim
import numpy as np
from typing import List
from dataclasses import dataclass

In [18]:
@dataclass()
class DEM_Matrices:
    """Dense matrix view of a detector error model.

    Columns of both matrices index error mechanisms (one per ``error`` instruction).
    """

    # (num_detectors, num_errors) 0/1 matrix: [d, e] == 1 when error e flips detector d.
    check_matrix: torch.Tensor
    # (num_observables, num_errors) 0/1 matrix: [o, e] == 1 when error e flips observable o.
    logical_matrix: torch.Tensor
    # One entry per error mechanism. NOTE: despite the name, DEM_to_matrices fills this
    # with log((1 - p) / p) rather than the raw probability p.
    priors: torch.Tensor

def DEM_to_matrices(DEM: stim.DetectorErrorModel) -> DEM_Matrices:
    """Convert a detector error model into dense check/logical matrices and prior LLRs.

    Every ``error(p) ...`` instruction of the flattened model becomes one column,
    in instruction order, of both matrices.

    Args:
        DEM: the detector error model to convert.

    Returns:
        DEM_Matrices with
        - check_matrix: int32 tensor of shape (num_detectors, num_errors);
          entry [d, e] is 1 iff error e flips detector d.
        - logical_matrix: int32 tensor of shape (num_observables, num_errors);
          entry [o, e] is 1 iff error e flips observable o.
        - priors: float32 tensor of shape (num_errors,) holding log((1 - p) / p)
          for each error probability p (p == 0 gives +inf, p == 1 gives -inf).

    Raises:
        NotImplementedError: for any instruction type other than
            "error", "detector" or "logical_observable".
    """
    check_matrix = np.zeros((DEM.num_detectors, DEM.num_errors), dtype=np.uint8)
    logical_matrix = np.zeros((DEM.num_observables, DEM.num_errors), dtype=np.uint8)
    probabilities = np.zeros(DEM.num_errors)

    e = 0
    # flattened() unrolls repeat blocks and detector shifts, so relative detector
    # ids in the result can be used directly as row indices.
    for instruction in DEM.flattened():
        kind = instruction.type
        if kind == "error":
            probabilities[e] = instruction.args_copy()[0]
            for t in instruction.targets_copy():
                # XOR rather than plain assignment: a target listed twice within one
                # error (e.g. once in each '^'-separated component) cancels mod 2.
                if t.is_relative_detector_id():
                    check_matrix[t.val, e] ^= 1
                elif t.is_logical_observable_id():
                    logical_matrix[t.val, e] ^= 1
                # Any other target (separators) carries no symptom and is skipped.
            e += 1
        elif kind in ("detector", "logical_observable"):
            # These instructions do not contribute to the matrices.
            continue
        else:
            raise NotImplementedError(f"unsupported DEM instruction type: {kind!r}")

    probabilities = torch.tensor(probabilities, dtype=torch.float32)

    return DEM_Matrices(
        check_matrix=torch.tensor(check_matrix, dtype=torch.int),
        logical_matrix=torch.tensor(logical_matrix, dtype=torch.int),
        priors=torch.log((1 - probabilities) / probabilities),
    )

In [20]:
def generate_M(n):
    """Build the 2n x 2n block matrix [[0, I_n], [I_n, 0]] as an int32 tensor.

    Multiplying a length-2n vector by it swaps the vector's first and second
    halves (presumably used as the binary symplectic form — confirm with callers).

    Args:
        n: size of each identity block; must be a non-negative int.

    Returns:
        torch.Tensor of dtype torch.int32 and shape (2n, 2n).
    """
    size = 2 * n
    swap = np.zeros((size, size), dtype=int)
    idx = np.arange(n)
    # Upper-right identity block, then lower-left identity block.
    swap[idx, idx + n] = 1
    swap[idx + n, idx] = 1
    return torch.tensor(swap, dtype=torch.int)

In [None]:
class Nbp_decoder(nn.Module):
    """Weighted ("neural") belief-propagation decoder over a circuit's detector error model.

    The check matrix ``H`` (detectors x error mechanisms), logical matrix ``L``
    (observables x error mechanisms) and per-mechanism prior LLRs ``p`` come from
    ``DEM_to_matrices``. ``check_matrix`` and ``priors`` are aliases of ``H`` and
    ``p`` used by the message-passing methods.

    Messages are stored densely with the check matrix's shape (checks on dim -2,
    variables on dim -1). Entries that are not edges of the check matrix are kept
    at zero.

    All trainable weights are ``nn.Parameter``s held in ``nn.ParameterList``s, so they
    appear in ``parameters()`` and ``state_dict()``.

    NOTE(review): ``loss_function`` needs ``self.orthogonal_check_matrix`` and
    ``self.M`` (presumably ``generate_M`` output). Nothing derives them yet, so they
    start as ``None`` and must be assigned before ``decode`` is called.
    """

    # Margin kept between |tanh product| and 1 so that atanh stays finite.
    atanh_clip = 1e-6

    def __init__(self,
                 circuit: stim.Circuit,
                 layers: int = 20,
                 weights: str = None,
                 batch_size: int = 1):
        """Build the decoder from ``circuit``'s (undecomposed) detector error model.

        Args:
            circuit: circuit whose detector error model defines the Tanner graph.
            layers: number of weight layers. ``decode`` runs ``layers - 1`` iterations.
            weights: optional path of a checkpoint written by ``save_weights``.
            batch_size: number of rows of the per-layer loss buffer used by ``decode``.
        """
        super().__init__()

        self.device = 'cpu'

        self.matrices = DEM_to_matrices(circuit.detector_error_model(decompose_errors=False))

        self.H = self.matrices.check_matrix
        self.L = self.matrices.logical_matrix
        self.p = self.matrices.priors

        # Names used by the message-passing methods below.
        self.check_matrix = self.H
        self.priors = self.p

        # Required by loss_function; not constructed anywhere yet (see class docstring).
        self.M = None
        self.orthogonal_check_matrix = None

        self.layers = layers
        self.batch_size = batch_size

        # Parameters must exist before a checkpoint can be loaded into them.
        self.ini_weight_as_one()
        if weights:
            self.load_weights(weights, self.device)

    def load_weights(self, path, device=None):
        """Load weights previously written by ``save_weights``.

        Args:
            path: checkpoint file path.
            device: ``map_location`` for ``torch.load``; defaults to ``self.device``.

        NOTE: ``torch.load`` unpickles the file — only load trusted checkpoints.
        """
        state = torch.load(path, map_location=device if device is not None else self.device)
        self.load_state_dict(state)

    def save_weights(self, path):
        """Write all registered weights (``state_dict``) to ``path`` via ``torch.save``."""
        torch.save(self.state_dict(), path)

    def ini_weight_as_one(self):
        """(Re)initialise every trainable weight to one.

        Creates ``layers + 1`` edge-weight matrices (shape of the check matrix) and
        ``layers + 1`` prior-weight vectors (shape of the priors). The last entry of
        each list is used by ``marginalise`` by default. ``residual_weights`` and
        ``rhos`` each hold one length-``layers`` vector; no method of this class
        reads them.
        """
        def _ones(shape):
            # Float dtype is required: integer tensors cannot require gradients.
            return nn.Parameter(torch.ones(shape, dtype=torch.float32, device=self.device))

        self.weights_vc = nn.ParameterList(
            [_ones(self.check_matrix.shape) for _ in range(self.layers + 1)])
        self.weights_llr = nn.ParameterList(
            [_ones(self.priors.shape) for _ in range(self.layers + 1)])
        self.residual_weights = nn.ParameterList([_ones(self.layers)])
        self.rhos = nn.ParameterList([_ones(self.layers)])

    def unsqueeze_batches(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` with a leading batch dim: 3-D unchanged, 2-D gets one added.

        Raises:
            ValueError: if ``tensor`` is neither 2-D nor 3-D.
        """
        if tensor.dim() == 3:
            return tensor
        if tensor.dim() == 2:
            return torch.unsqueeze(tensor, dim=0)
        raise ValueError(f"expected a 2-D or 3-D tensor, got {tensor.dim()}-D")

    def mu_vc(self, mu_cv, weights_vc, weights_llr):
        """Variable-to-check messages.

        Edge (c, v) receives the weighted prior LLR of v plus the weighted messages
        from every other check c' != c attached to v. Non-edges are zero.

        Args:
            mu_cv: check-to-variable messages, zero off the edges.
            weights_vc: per-edge weights (shape of the check matrix).
            weights_llr: per-variable prior weights (shape of the priors).
        """
        weighted_cv = mu_cv * weights_vc
        prior_term = self.check_matrix * (self.priors * weights_llr) * weights_vc
        # Sum over checks (dim -2) gives each variable's total weighted input.
        total_in = torch.sum(weighted_cv, dim=-2, keepdim=True)
        # Subtract the *same* weighted message that was summed so the result is extrinsic.
        return (prior_term + total_in) * self.check_matrix - weighted_cv

    @staticmethod
    def _exclusive_prod(x, dim=-1):
        """Product along ``dim`` of all entries except the one at each position (no division)."""
        n = x.size(dim)
        ones = torch.ones_like(x.narrow(dim, 0, 1))
        # prefix[k] = x[0] * ... * x[k-1]
        prefix = torch.cumprod(torch.cat([ones, x.narrow(dim, 0, n - 1)], dim=dim), dim=dim)
        # suffix[k] = x[k+1] * ... * x[n-1]
        rev = x.flip(dim)
        suffix = torch.cumprod(torch.cat([ones, rev.narrow(dim, 0, n - 1)], dim=dim), dim=dim).flip(dim)
        return prefix * suffix

    def mu_cv(self, mu_vc, syndrome):
        """Check-to-variable messages via the tanh rule.

        For every edge (c, v): 2 * atanh(prod_{v' != v} tanh(mu_vc[c, v'] / 2)),
        negated when the syndrome bit of check c is 1. Non-edges are zero.

        Args:
            mu_vc: variable-to-check messages (checks on dim -2, variables on dim -1).
            syndrome: 0/1 values per check. A tensor with one fewer dim than the
                messages (e.g. shape (num_checks,)) is broadcast across each check's row.
        """
        edges = self.check_matrix.bool()
        half_tanh = torch.tanh(mu_vc / 2)
        # Non-edges must not affect the product, so treat them as 1 (out of place,
        # because tanh's backward needs its unmodified output).
        half_tanh = torch.where(edges, half_tanh, torch.ones_like(half_tanh))
        # Division-free leave-one-out product: correct even when an edge message is 0.
        extrinsic = self._exclusive_prod(half_tanh, dim=-1)
        limit = 1.0 - self.atanh_clip
        cv_msg = 2 * torch.atanh(extrinsic.clamp(-limit, limit)) * self.check_matrix

        syndrome = torch.as_tensor(syndrome, device=cv_msg.device)
        if syndrome.dim() == cv_msg.dim() - 1:
            syndrome = syndrome.unsqueeze(-1)
        # (1 - 2s) == (-1)^s for s in {0, 1}.
        return cv_msg * (1 - 2 * syndrome)

    def log1pexp(self, x):
        """Numerically stable log(1 + exp(x)) (softplus, linear for x > 50)."""
        return nn.functional.softplus(x, beta=1, threshold=50)

    def marginalise(self, mu_cv, marg_weights_llr=None, marg_weights_vc=None):
        """Posterior LLR per variable: weighted prior plus weighted incoming check messages.

        Args:
            mu_cv: check-to-variable messages.
            marg_weights_llr: prior weights; defaults to the last entry of ``weights_llr``.
            marg_weights_vc: edge weights; defaults to the last entry of ``weights_vc``.
        """
        if marg_weights_llr is None:
            marg_weights_llr = self.weights_llr[-1]
        if marg_weights_vc is None:
            marg_weights_vc = self.weights_vc[-1]
        # Out of place: the caller's mu_cv is still needed by autograd.
        weighted = mu_cv * marg_weights_vc
        return torch.sum(weighted, dim=-2) + self.priors * marg_weights_llr

    def loss_function(self, prediction, error):
        """Sum over rows i of |sin(pi/2 * sum_j (H_perp M)[i, j] * (prediction_j + error_j))|.

        Args:
            prediction: soft error estimate in [0, 1] per variable (optionally batched
                on leading dims).
            error: true error pattern with the same trailing shape.

        Returns:
            The loss summed over rows: a scalar for unbatched input, one value per
            sample for batched input.

        Raises:
            RuntimeError: if ``orthogonal_check_matrix`` or ``M`` has not been set.
        """
        h_perp = getattr(self, "orthogonal_check_matrix", None)
        m = getattr(self, "M", None)
        if h_perp is None or m is None:
            raise RuntimeError(
                "loss_function needs self.orthogonal_check_matrix and self.M to be set")

        recovery = (prediction + error) % 2
        parity_map = torch.matmul(h_perp, m).to(recovery.dtype)
        # One parity per row of H_perp M (handles 1-D and batched recovery alike).
        parities = torch.matmul(recovery, parity_map.transpose(-2, -1))
        # sin is applied per row *before* summing, so each violated row counts separately.
        return torch.sum(torch.abs(torch.sin(np.pi / 2 * parities)), dim=-1)

    def decode(self, error, syndrome, batch_size = 1):
        """Run ``layers - 1`` weighted BP iterations and return the layer-averaged loss.

        Column 0 of the per-layer loss buffer stays zero; the total is divided by
        ``layers``. NOTE(review): the ``batch_size`` argument is unused — the buffer
        has ``self.batch_size`` rows.
        """
        loss_array = torch.zeros(self.batch_size, self.layers, dtype=torch.float32, device=self.device)

        vc_msg = self.check_matrix * self.priors * self.weights_llr[0]

        for i in range(1, self.layers):
            cv_msg = self.mu_cv(vc_msg, syndrome)
            vc_msg = self.mu_vc(cv_msg, self.weights_vc[i], self.weights_llr[i])

            belief = self.marginalise(cv_msg)

            # Priors are LLRs log((1 - p) / p), so P(error) = 1 / (1 + exp(LLR)) = sigmoid(-LLR).
            soft_error = torch.sigmoid(-belief)
            loss_array[:, i] = self.loss_function(soft_error, error)

        return torch.sum(loss_array) / self.layers
            