# Load functions

In [8]:
import pymatching
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

from soft_info import get_repcode_IQ_map, llh_ratio

def process_string(input_str, verbose=False):
    """Convert a space-separated count key into a flat 0/1 NumPy array.

    The key is reversed character by character and split on single spaces.
    The final chunk (after reversal) is reduced to the parity of each pair of
    adjacent bits; all earlier chunks are concatenated unchanged and placed in
    front of those parities.

    Example: ``"10 011 110"`` -> reversed ``"011 110 01"`` ->
    ``[0, 1, 1, 1, 1, 0, 1]``.

    NOTE(review): the leading chunks are passed through as raw bits (not XORed
    round-to-round), and the final parities are not combined with the last
    leading chunk. If the result is fed to a decoder built from a stim
    ``repetition_code:memory`` detector model, confirm that the count key
    already contains detection events; otherwise the syndrome will not match
    the detector definitions.

    Args:
        input_str: count key made of digit characters separated by spaces.
        verbose: if True, print every intermediate value.

    Returns:
        np.ndarray with one entry per bit of the leading chunks, followed by one
        parity per adjacent pair of bits in the final chunk.
    """
    flipped = input_str[::-1]
    if verbose:
        print("Reversed str:", flipped)

    # str.split always yields at least one element, so this unpacking is safe.
    *leading_chunks, final_chunk = flipped.split(" ")
    if verbose:
        print("Count str:", final_chunk)

    # Parity of every neighbouring pair in the final chunk.
    pair_parity = "".join(
        str((int(left) + int(right)) % 2)
        for left, right in zip(final_chunk, final_chunk[1:])
    )
    if verbose:
        print("XOR result:", pair_parity)

    joined_leading = "".join(leading_chunks)
    if verbose:
        print("First part:", joined_leading)

    result = np.array([int(ch) for ch in joined_leading + pair_parity])
    if verbose:
        print("Numpy list:", result)

    return result


def reweight_edges_to_one(matching: pymatching.Matching):
    """Set the weight of every edge of ``matching`` to 1, in place.

    Boundary edges (those whose second node is ``None``) are rewritten with
    ``add_boundary_edge``; all others with ``add_edge``. Fault ids and error
    probabilities are carried over (defaulting to an empty set and -1.0), and
    ``merge_strategy="replace"`` overwrites the existing edge.
    """
    for node_a, node_b, attrs in matching.edges():
        edge_kwargs = dict(
            weight=1,
            fault_ids=attrs.get('fault_ids', set()),
            error_probability=attrs.get('error_probability', -1.0),
            merge_strategy="replace",
        )
        if node_b is not None:
            matching.add_edge(node_a, node_b, **edge_kwargs)
        else:
            matching.add_boundary_edge(node_a, **edge_kwargs)


def soft_reweight_pymatching(matching : pymatching.Matching,  d : int, T : int, IQ_data, 
                             kde_dict: dict, layout : list, scaler_dict : dict,
                             p_data : float = None, p_meas : float = None, common_measure = None,
                             verbose : bool = False):
    """Reweight the edges of a repetition-code matching graph in place using IQ soft information.

    Edges are classified by the index difference between their endpoints. This
    assumes detectors are numbered round by round with ``d - 1`` detectors per
    round, and that the first node of an edge is always the smaller index
    (TODO confirm for other graph sources):

    * boundary edge (second node ``None``): ``-log(p_data / (1 - p_data))``
    * data edge (``tgt == src + 1``): the same data weight
    * time edge (``tgt == src + d - 1``): 0 plus an LLH term from ``IQ_data[src]``
    * mixed edge (``tgt == src + d``): data weight plus the same LLH term

    Edges matching none of these patterns are left unchanged. For ``d == 2`` the
    data and time patterns coincide and such edges are treated as data edges.

    Args:
        matching: graph to modify in place.
        d: code distance; ``d - 1`` is the index stride between rounds.
        T: number of rounds, forwarded to ``get_repcode_IQ_map``.
        IQ_data: per-shot IQ points indexed by node index (see comment below).
        kde_dict: layout qubit index -> ``(kde_0, kde_1)`` passed to ``llh_ratio``.
        layout: qubit layout for ``get_repcode_IQ_map``; needed if any
            time/mixed edge is present.
        scaler_dict: layout qubit index -> scaler passed to ``llh_ratio``.
        p_data: data-qubit flip probability (default 6.836e-3, Sherbrooke median).
        p_meas: measurement flip probability; currently unused (see TODOs).
        common_measure: if given, every new weight is rounded to a multiple of it.
        verbose: print the weight assigned to each edge.

    Raises:
        ValueError: a time/mixed edge is present but ``layout`` is None.
    """

    def _snap(weight):
        # Round to the nearest multiple of common_measure (no-op when it is None).
        if common_measure is None:
            return weight
        return round(weight / common_measure) * common_measure

    p_data = p_data if p_data is not None else 6.836e-3  # Sherbrooke median
    p_meas = p_meas if p_meas is not None else 0  # not used yet, see TODOs below

    data_weight = -np.log(p_data / (1 - p_data))
    qubit_mapping = get_repcode_IQ_map(layout, T) if layout is not None else None

    for src_node, tgt_node, edge_data in matching.edges():
        if verbose:
            print("\nEdge:", (src_node, tgt_node))
        fault_ids = edge_data.get('fault_ids', set())
        error_probability = edge_data.get('error_probability', -1.0)

        # Reset for every edge: otherwise a previous time/mixed edge would leak
        # its LLH term into a later data edge (or the flag could be unbound).
        has_time_component = False

        if tgt_node is None:  # boundary edges always have None as second node
            # Boundary edge (logical on it)
            new_weight = _snap(data_weight)
            matching.add_boundary_edge(src_node, weight=new_weight, fault_ids=fault_ids, 
                                       error_probability=error_probability, merge_strategy="replace")
            if verbose:
                print("Boundary edge weight: ", new_weight)
            continue

        if tgt_node == src_node + 1:  # first node is always the smaller one
            # Data edge
            new_weight = _snap(data_weight)
            if verbose:
                print("Data edge weight: ", new_weight)
        elif tgt_node == src_node + (d - 1):
            # Time edge
            # TODO implement adding a new edge for hard meas flip
            new_weight = 0  # -np.log(p_meas / (1 - p_meas))
            has_time_component = True
            if verbose:
                print("Time edge weight: ", new_weight)
        elif tgt_node == src_node + d:
            # Mixed (diagonal) edge
            # TODO implement adding a new DIAG edge for hard meas flip
            new_weight = data_weight  # - np.log(p_meas / (1 - p_meas))
            has_time_component = True
            if verbose:
                print("Mixed edge weight: ", new_weight)
        else:
            # Unknown geometry: keep the existing weight instead of reusing a
            # weight computed for a previous edge.
            if verbose:
                print("Unrecognised edge, weight left unchanged")
            continue

        if has_time_component:
            if qubit_mapping is None:
                raise ValueError("layout is required to reweight time/mixed edges")
            # IQ_data is assumed to follow the node ordering
            # [link_0, link_1, ..., link_0, link_1, ..., code_qubit_0, ...]
            # so a node index is also its IQ index (TODO confirm).
            IQ_point = IQ_data[src_node]
            layout_qubit_idx = qubit_mapping[src_node]
            kde_0, kde_1 = kde_dict.get(layout_qubit_idx, (None, None))
            scaler = scaler_dict.get(layout_qubit_idx, None)
            llh_weight = llh_ratio(IQ_point, kde_0, kde_1, scaler)

            if verbose:
                print("LLH weight: ", llh_weight)

            new_weight = _snap(new_weight + llh_weight)

        # Update the edge weight
        matching.add_edge(src_node, tgt_node, weight=new_weight, fault_ids=fault_ids, 
                          error_probability=error_probability, merge_strategy="replace")



def draw_matching_graph(matching, d, T):
    """Plot a repetition-code matching graph on a grid that is ``d - 1`` columns wide.

    Node ``n`` is drawn at ``(n % (d - 1), -(n // (d - 1)))``. Internal edges are
    labelled with their weights. Boundary edges are drawn as short horizontal
    stubs with their weight, on the left for column 0 and on the right for column
    ``d - 2``; boundary edges in other columns are not drawn. Edges carrying
    fault ids are red, all other edges black, for internal and boundary edges
    alike.

    Args:
        matching: object exposing ``edges()`` -> ``(src, tgt_or_None, data)``
            triples, as ``pymatching.Matching`` does.
        d: code distance; each row holds ``d - 1`` nodes.
        T: number of rounds (unused; kept for API compatibility).
    """
    width = d - 1

    def _grid_pos(node):
        return (node % width, -(node // width))

    G = nx.Graph()
    pos = {}
    boundary_edges = []

    for src_node, tgt_node, edge_data in matching.edges():
        pos[src_node] = _grid_pos(src_node)
        if tgt_node is None:
            boundary_edges.append((src_node, edge_data))
            continue
        # Targets also need a position, in case they never appear as a source.
        pos[tgt_node] = _grid_pos(tgt_node)
        # Store the colour on the edge so it stays aligned with G.edges(),
        # whose order is not necessarily the insertion order.
        G.add_edge(src_node, tgt_node, weight=edge_data['weight'],
                   color='r' if edge_data.get('fault_ids') else 'k')

    edge_colors = [color for _, _, color in G.edges(data='color')]
    nx.draw(G, pos, with_labels=True, node_color='white', edge_color=edge_colors, font_weight='bold', node_size=700, font_size=18)

    edge_weights = nx.get_edge_attributes(G, 'weight')
    labels = {k: f"{v:.2f}" for k, v in edge_weights.items()}
    nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)

    for src_node, edge_data in boundary_edges:
        x_src, y_plot = pos[src_node]  # y is already negated for plotting
        color = 'r' if edge_data.get('fault_ids') else 'k'
        weight_text = f"{edge_data.get('weight'):.2f}"
        if x_src == 0:
            plt.plot([x_src, x_src - 0.5], [y_plot, y_plot], color=color)
            plt.text(x_src - 0.3, y_plot + 0.05, weight_text)
        elif x_src == d - 2:
            plt.plot([x_src, x_src + 0.5], [y_plot, y_plot], color=color)
            plt.text(x_src + 0.2, y_plot + 0.05, weight_text)

    nx.draw_networkx_nodes(G, pos, node_color='skyblue', node_size=700)

    plt.show()


# Load data

In [9]:
# IPython magics: enable autoreload and reload imported modules automatically
# before executing code (mode 2), so edits to local modules are picked up.
%reload_ext autoreload
%autoreload 2

In [10]:
from result_saver import SaverProvider
# Provider used below to retrieve job results (retrieve_job), backends
# (get_backend) and to build the readout KDEs (get_KDEs).
provider = SaverProvider()

In [16]:
import numpy as np
from Scratch import metadata_loader

# Target code distance. Hard-coded on purpose (rather than max(md.distance)),
# so that the d=30 runs are selected.
max_distance = 30

md = metadata_loader(_extract=True, _drop_inutile=True)
md = md[md["job_status"] == "JobStatus.DONE"]
md = md[md["notebook_name"] == "bigger_rep_codes"]
md = md[md["distance"] == max_distance]
md = md.sort_values(by='backend_name', ascending=False)

# Keep the first two jobs. The decoding cells read memories for logical 0 and
# logical 1. NOTE(review): sorting by backend alone does not guarantee one job
# per logical state; confirm when more than two jobs match.
md = md.head(2)

md

Unnamed: 0,creation_date,notebook_name,backend_name,job_id,tags,shots,tags_xp,sampled_state,num_qubits,job_status,extra,optimization_level,code,distance,rounds,logical,layout,descr
76,2023-10-29 14:47:58.814875+01:00,bigger_rep_codes,ibm_sherbrooke,cmz653m3r3vg008wf9j0,[],1111.0,,,,JobStatus.DONE,,,RepetitionCodeCircuit,30.0,30,1,_is_hex=True,Run bigger Repetition codes
75,2023-10-29 14:47:43.903639+01:00,bigger_rep_codes,ibm_sherbrooke,cmz64zvvcq70008qdxcg,[],1111.0,,,,JobStatus.DONE,,,RepetitionCodeCircuit,30.0,30,0,_is_hex=True,Run bigger Repetition codes


In [17]:
# Raw IQ memory of each selected job, keyed by the job's `logical` metadata value.
memories = {
    f"mmr_log_{logical}": provider.retrieve_job(job_id).result().get_memory()
    for job_id, logical in zip(md.job_id, md.logical)
}

memories

{'mmr_log_1': array([[ -9827335.+1.37223950e+07j, -53977778.+7.99818970e+07j,
         -79071209.-9.83478680e+07j, ...,  -8432334.-7.54241200e+06j,
          -5175264.-8.50117500e+06j,   6540639.-5.34872300e+06j],
        [-10489210.+2.02035840e+07j, -83693508.+7.02712640e+07j,
         -68739961.-8.00257260e+07j, ...,  -6971682.-7.36254800e+06j,
          -4029180.-5.24985200e+06j,  12069718.-1.06493800e+07j],
        [-14341855.+1.95327160e+07j, -72534674.+6.55941800e+07j,
         -67845350.-1.14758427e+08j, ...,  -6694234.-5.57350800e+06j,
         -10807815.-1.09097600e+07j,   7684249.-1.06278620e+07j],
        ...,
        [ -8225117.+9.95315900e+06j, -71635045.+6.52791880e+07j,
         -62033555.-8.20048330e+07j, ...,   6253861.-8.49423500e+06j,
           5556042.-8.86801300e+06j,   8736877.-1.35674410e+07j],
        [-16130944.+1.36266200e+07j, -57588956.+5.27301930e+07j,
         -58264550.-6.76114050e+07j, ...,  -7844011.-9.03883700e+06j,
          -8234982.-9.29126300e+06j

# Decode

In [18]:
import numpy as np
import stim
import pymatching

from soft_info import get_repcode_layout, get_KDEs

# Code parameters: distance d and number of rounds T, both equal to max_distance
d=max_distance
T=max_distance
# Repetition-code qubit layout on ibm_sherbrooke (_is_hex=True; presumably the heavy-hex variant)
layout = get_repcode_layout(distance=max_distance, backend=provider.get_backend("ibm_sherbrooke"), _is_hex=True)

# Per-qubit readout KDEs and scalers, later used as kde_dict / scaler_dict by get_counts
kde_dict, scaler_dict = get_KDEs(provider, 'ibm_sherbrooke', layout, bandwidths=0.2, plot=False)

# The stim circuit only supplies the detector graph; its noise level (0.1)
# determines the default edge weights of `matching` unless they are reweighted.
circuit = stim.Circuit.generated("repetition_code:memory",
                                 distance=d,
                                 rounds=T,
                                 after_clifford_depolarization=0.1)

model = circuit.detector_error_model(decompose_errors=True)
matching = pymatching.Matching.from_detector_error_model(model)

# IQ memory of the logical-0 run
memory = memories['mmr_log_0']
#draw_matching_graph(matching, d, T)

In [19]:
from tqdm import tqdm

from soft_info import get_counts

VERBOSE = False

# Expected observable. `memory` is hard-coded to the logical-0 run, so the
# correct prediction is "no flip".
actual_observables = np.array([[False]]) # hardcoded, can be retrieved
num_errors = 0

# Every edge gets weight 1. This does not depend on the shot, so do it once
# instead of once per shot inside the loop.
reweight_edges_to_one(matching)

i = 0
w_idx_lst = []  # 1-based indices of wrongly decoded shots
for i, IQ_data in enumerate(tqdm(memory), start=1):
    counts = get_counts([IQ_data], kde_dict, scaler_dict, layout, T, verbose=False)
    count_key = next(iter(counts.keys()))

    # Per-shot soft reweighting (disabled for this run):
    #soft_reweight_pymatching(matching, d, T, IQ_data, kde_dict, layout, scaler_dict, common_measure=0.01, verbose=False)

    array_processed_string = process_string(count_key, verbose=False)

    predicted_observables = matching.decode(array_processed_string)

    # Compare with the expected observable instead of a hard-coded [0].
    # Row 0 can become a per-shot index with multiple syndromes/observables.
    if np.array_equal(actual_observables[0, :], predicted_observables):
        continue

    w_idx_lst.append(i)
    num_errors += 1

    if VERBOSE:
        print("Count key:", count_key)
        print("process_string:", array_processed_string)
        draw_matching_graph(matching, d, T)
        matched_edges = matching.decode_to_edges_array(array_processed_string)
        print("matched_edges: ", matched_edges)
        print("Estimated flip:", predicted_observables)

print("Num errors:", num_errors)

100%|██████████| 1111/1111 [12:02<00:00,  1.54it/s]

Num errors: 534





In [20]:
from tqdm import tqdm

from soft_info import get_counts

# Fresh graph with the weights from the detector error model (discards any
# reweighting applied to `matching` earlier).
matching = pymatching.Matching.from_detector_error_model(model)

VERBOSE = False

# Expected observable. `memory` is hard-coded to the logical-0 run, so the
# correct prediction is "no flip".
actual_observables = np.array([[False]]) # hardcoded, can be retrieved
num_errors = 0

i = 0
w_idx_lst = []  # 1-based indices of wrongly decoded shots
for i, IQ_data in enumerate(tqdm(memory), start=1):
    counts = get_counts([IQ_data], kde_dict, scaler_dict, layout, T, verbose=False)
    count_key = next(iter(counts.keys()))

    # Reweighting hooks (disabled: decode with the detector-error-model weights):
    #soft_reweight_pymatching(matching, d, T, IQ_data, kde_dict, layout, scaler_dict, common_measure=0.01, verbose=False)
    #reweight_edges_to_one(matching)

    array_processed_string = process_string(count_key, verbose=False)

    predicted_observables = matching.decode(array_processed_string)

    # Compare with the expected observable instead of a hard-coded [0].
    # Row 0 can become a per-shot index with multiple syndromes/observables.
    if np.array_equal(actual_observables[0, :], predicted_observables):
        continue

    w_idx_lst.append(i)
    num_errors += 1

    if VERBOSE:
        print("Count key:", count_key)
        print("process_string:", array_processed_string)
        draw_matching_graph(matching, d, T)
        matched_edges = matching.decode_to_edges_array(array_processed_string)
        print("matched_edges: ", matched_edges)
        print("Estimated flip:", predicted_observables)

print("Num errors:", num_errors)

100%|██████████| 1111/1111 [11:39<00:00,  1.59it/s]

Num errors: 550



