In [None]:
from typing import Dict, List

import pymatching
import numpy as np

from soft_info import get_repcode_IQ_map, llh_ratio


def _log_odds_weight(p: float) -> float:
    """Return the matching weight ``-log(p / (1 - p))`` for an error probability ``p``.

    ``p == 0`` yields ``+inf``; numpy's divide-by-zero warning for that case is
    silenced because it is an expected input (``p_meas`` defaults to 0).
    """
    with np.errstate(divide="ignore"):
        return -np.log(p / (1 - p))


def _round_to_common_measure(weight, common_measure):
    """Snap ``weight`` to the nearest multiple of ``common_measure``.

    ``weight`` is returned unchanged when ``common_measure`` is None, or when
    ``weight`` is not finite (``round`` raises OverflowError/ValueError on inf/nan).
    """
    if common_measure is None or not np.isfinite(weight):
        return weight
    return round(weight / common_measure) * common_measure


def _replace_edge_weight(matching, src_node, tgt_node, edge_attrs, weight):
    """Store ``weight`` on edge ``(src_node, tgt_node)`` of ``matching``.

    The attribute dicts handed out by ``Matching.edges()`` are copies in
    PyMatching 2, so assigning into them alone does not change the graph.
    The dict is still updated (harmless where it is live), and the edge is
    re-added with ``merge_strategy="replace"``, carrying over its fault ids
    and error probability.
    """
    edge_attrs['weight'] = weight
    error_probability = edge_attrs.get('error_probability')
    if error_probability is not None and not 0 <= error_probability <= 1:
        # NOTE(review): out-of-range values are taken to mean "unset" (PyMatching
        # reports unset probabilities as a negative number) — confirm for the installed version.
        error_probability = None
    matching.add_edge(src_node, tgt_node,
                      fault_ids=edge_attrs.get('fault_ids'),
                      weight=weight,
                      error_probability=error_probability,
                      merge_strategy="replace")


def soft_reweight_pymatching(matching : pymatching.Matching,  d : int, T : int, IQ_data,
                             kde_dict: Dict, layout : List[int], scaler_dict : Dict,
                             p_data : float = None, p_meas : float = None, common_measure = None):
    """Re-weight the edges of a repetition-code matching graph, in place, using soft IQ information.

    Edges are classified by the index difference between their two nodes:

    * boundary edge (second node is None): left untouched;
    * ``+1``: data edge, weight ``-log(p_data / (1 - p_data))``;
    * ``+d``: time edge, prior ``-log(p_meas / (1 - p_meas))``;
    * ``+d+1``: time-and-data edge, prior is the sum of the two above;
    * any other difference: the edge's current weight is used as the prior.

    Every non-data, non-boundary edge then gets
    ``llh_ratio(IQ_data[src_node], kde_0, kde_1, scaler)`` added to its prior.
    The KDEs and scaler are looked up for the link qubit ``src_node % d``,
    mapped through ``get_repcode_IQ_map(layout, T)``.

    NOTE(review): a distance-d repetition code usually has d-1 link qubits per
    round; the ``+d`` / ``% d`` strides assume ``d`` detector nodes per round.
    Confirm this against how ``matching`` was built.

    Args:
        matching: Matching graph whose edge weights are replaced.
        d: Node stride between rounds; also the modulus used to get the link index.
        T: Forwarded to ``get_repcode_IQ_map``.
        IQ_data: Indexed by detector node (``IQ_data[src_node]``).
        kde_dict: Maps a layout qubit index to ``(kde_0, kde_1)``; a missing key
            passes ``(None, None)`` to ``llh_ratio``.
        layout: Forwarded to ``get_repcode_IQ_map``. It may only be None if the
            graph has no edges other than data and boundary edges.
        scaler_dict: Maps a layout qubit index to a scaler; a missing key passes None.
        p_data: Data-error probability; defaults to 6.836e-3 (Sherbrooke median).
        p_meas: Measurement-error probability; defaults to 0, which gives time
            and time-and-data edges an infinite prior weight.
        common_measure: If given, finite weights are rounded to the nearest
            multiple of it.

    Returns:
        None. ``matching`` is modified in place.

    Raises:
        ValueError: ``layout`` is None but some edge needs soft information.
    """
    p_data = p_data if p_data is not None else 6.836e-3  # Sherbrooke median
    p_meas = p_meas if p_meas is not None else 0

    qubit_mapping = get_repcode_IQ_map(layout, T) if layout is not None else None

    # The prior weights do not depend on the edge, so compute them once.
    data_weight = _log_odds_weight(p_data)
    meas_weight = _log_odds_weight(p_meas)

    # edges() returns a snapshot list, so re-adding edges inside the loop is safe.
    for src_node, tgt_node, edge_attrs in matching.edges():
        if tgt_node is None:
            continue  # boundary edge: keep its existing weight

        if tgt_node == src_node + 1:
            # Data edge: prior only, no soft information involved.
            new_weight = _round_to_common_measure(data_weight, common_measure)
            _replace_edge_weight(matching, src_node, tgt_node, edge_attrs, new_weight)
            continue

        if tgt_node == src_node + d:
            prior_weight = meas_weight  # time edge
        elif tgt_node == src_node + d + 1:
            prior_weight = data_weight + meas_weight  # time-and-data edge
        else:
            prior_weight = edge_attrs['weight']  # unclassified edge: keep current weight as prior

        if qubit_mapping is None:
            raise ValueError(
                "layout is required to re-weight time edges with soft information"
            )

        # IQ data is ordered like the detector nodes; src_node % d picks the link-qubit register.
        iq_point = IQ_data[src_node]
        layout_qubit_idx = qubit_mapping[src_node % d]
        kde_0, kde_1 = kde_dict.get(layout_qubit_idx, (None, None))
        scaler = scaler_dict.get(layout_qubit_idx, None)

        new_weight = prior_weight + llh_ratio(iq_point, kde_0, kde_1, scaler)
        new_weight = _round_to_common_measure(new_weight, common_measure)
        _replace_edge_weight(matching, src_node, tgt_node, edge_attrs, new_weight)






