In [4]:
import numpy as np
import os
import qec
import pymatching, stim
from BiasedErasure.main_code.LogicalCircuit import LogicalCircuit 
import random
from BiasedErasure.delayed_erasure_decoders.HeraldedCircuit_FREE_LD import HeraldedCircuit_FREE_LD
import math
import time
import sinter
import copy


In [23]:
def generate_ghz_state(d, order, num_logicals, QEC_cycles, entangling_gate_error_rate, entangling_gate_loss_rate, logical_basis='X', biased_pres_gates = False):
    """Build a LogicalCircuit that prepares and measures a logical GHZ state.

    ``num_logicals`` distance-``d`` rotated surface codes are prepared in logical
    |0>, rotated (logical 0 only), Hadamard-ed, and then, for every
    ``(control, target)`` pair in ``order``, a transversal CZ followed by a
    Hadamard on the target is appended. After each pair, ``QEC_cycles`` rounds of
    stabilizer measurement are appended. Gate and loss noise are switched on
    (scale factor 1) only around those QEC rounds; everything else is appended
    with ``loss_noise_scale_factor`` and ``gate_noise_scale_factor`` set to 0.

    Args:
        d: code distance, used for both dimensions of each RotatedSurfaceCode.
        order: sequence of (control, target) logical-qubit indices, applied in turn.
        num_logicals: number of logical qubits in the GHZ state.
        QEC_cycles: number of stabilizer-measurement rounds after each pair.
        entangling_gate_error_rate: forwarded to LogicalCircuit unchanged.
        entangling_gate_loss_rate: forwarded to LogicalCircuit unchanged.
        logical_basis: 'X' -> a single observable, the product of every logical
            X operator; 'Z' -> num_logicals - 1 observables, one per neighbouring
            Z_i Z_{i+1} parity.
        biased_pres_gates: forwarded as ``with_cnot`` to measure_stabilizers.

    Returns:
        The constructed LogicalCircuit.

    Raises:
        AssertionError: if ``logical_basis`` is not 'X' or 'Z'.
        ValueError: if a pair in ``order`` uses an index outside
            ``range(num_logicals)`` or pairs a logical qubit with itself.
    """
    assert logical_basis in ['X', 'Z']

    # Reject malformed pairs up front rather than failing deep inside the circuit library.
    for control, target in order:
        if control == target or not (0 <= control < num_logicals and 0 <= target < num_logicals):
            raise ValueError(
                f"invalid (control, target) pair {(control, target)} for {num_logicals} logical qubits")

    logical_qubits = [qec.surface_code.RotatedSurfaceCode(d, d) for _ in range(num_logicals)]

    lc = LogicalCircuit(logical_qubits, initialize_circuit=False,
                        loss_noise_scale_factor=1, spam_noise_scale_factor=0, measurement_error_rate=0,
                        gate_noise_scale_factor=1, idle_noise_scale_factor=0, reset_error_rate=0,
                        entangling_gate_error_rate=entangling_gate_error_rate, entangling_gate_loss_rate=entangling_gate_loss_rate)
    # Disable gate/loss noise; it is re-enabled only around the QEC rounds below.
    lc.loss_noise_scale_factor = 0
    lc.gate_noise_scale_factor = 0

    # First prepare all logical qubits in |0>
    lc.append(qec.surface_code.prepare_zero, list(range(num_logicals)))

    # Rotate logical 0 only (noiselessly).
    # NOTE(review): presumably so that every transversal CZ in `order` acts between
    # codes of opposite orientation once the Hadamards flip orientations — confirm.
    lc.append(qec.surface_code.rotate_code, 0)

    # Hadamard all qubits
    lc.append(qec.surface_code.global_h, list(range(num_logicals)), move_duration=0)

    # Entangle: with the global H above, CZ followed by H on a not-yet-entangled
    # target acts as CNOT(control -> target).
    for control, target in order:
        lc.append(qec.surface_code.global_cz, [control, target], move_duration=0)
        lc.append(qec.surface_code.global_h, target, move_duration=0)

        # Noisy QEC rounds after every gate.
        lc.loss_noise_scale_factor = 1
        lc.gate_noise_scale_factor = 1
        for _ in range(QEC_cycles):
            lc.append_from_stim_program_text("""TICK""")  # starting a QEC round
            lc.append(qec.surface_code.measure_stabilizers, list(range(num_logicals)), order='fowler',
                      with_cnot=biased_pres_gates, compare_with_previous=True)
            lc.append_from_stim_program_text("""TICK""")
        lc.loss_noise_scale_factor = 0
        lc.gate_noise_scale_factor = 0

    if logical_basis == 'X':
        lc.append(qec.surface_code.measure_x, list(range(num_logicals)), observable_include=False)
        lc.append('MOVE_TO_NO_NOISE', lc.qubit_indices, 0)
        # Product of all logical X operators, joined with MPP '*' combiners.
        global_x = []
        for logical in lc.logical_qubits:
            global_x += logical.logical_x_operator + [stim.GateTarget(stim.target_combiner())]
        lc.append('MPP', global_x[:-1])  # drop the trailing combiner
        lc.append('OBSERVABLE_INCLUDE', [stim.target_rec(-1)], 0)

    elif logical_basis == 'Z':
        lc.append(qec.surface_code.measure_z, list(range(num_logicals)), observable_include=False)
        lc.append('MOVE_TO_NO_NOISE', lc.qubit_indices, 0)
        # One observable per neighbouring parity Z_i * Z_{i+1}.
        for index in range(num_logicals - 1):
            lc.append('MPP', lc.logical_qubits[index].logical_z_operator + [stim.GateTarget(stim.target_combiner())] +
                      lc.logical_qubits[index + 1].logical_z_operator)
            lc.append('OBSERVABLE_INCLUDE', [stim.target_rec(-1)], lc.num_observables)

    return lc

In [25]:
# Physical entangling-gate error rate; converted to per-qubit rates by
# set_1q_biased_erasure_channel. Same value as the literal 10e-2.
# NOTE(review): confirm 1e-2 was not the intended value.
phys_err = 0.1
erasure_ratio = 1  # share of the per-qubit error that becomes loss; the rest is Pauli noise
bias_ratio = 0.5   # pz = bias_ratio * (px + py) for the Pauli part

# 1q loss and bias channels:
def set_1q_biased_erasure_channel(phys_err, bias_ratio, erasure_ratio):
    """Split an entangling-gate error rate into per-qubit Pauli and loss rates.

    The per-qubit rate ``p1 = 1 - sqrt(1 - phys_err)`` satisfies
    ``1 - (1 - p1)**2 == phys_err``. It is divided as px == py and
    pz = bias_ratio * (px + py), so px + py + pz == p1. A fraction
    ``erasure_ratio`` of p1 becomes loss, and the remainder stays Pauli.

    Returns:
        (pauli_rates, loss_rate): ``pauli_rates`` is the numpy array
        ``(1 - erasure_ratio) * [px, py, pz]``; ``loss_rate`` is
        ``erasure_ratio * p1``.
    """
    per_qubit = 1 - np.sqrt(1 - phys_err)
    p_x = per_qubit / (2 * (1 + bias_ratio))
    p_y = p_x
    p_z = bias_ratio * (p_x + p_y)
    pauli_fraction = 1 - erasure_ratio
    return pauli_fraction * np.array([p_x, p_y, p_z]), erasure_ratio * per_qubit


# Accumulators (not populated by this cell).
fidelities = []
z_fidelities_1 = []
x_fidelities_1 = []
z_fidelities_2 = []
x_fidelities_2 = []


# ---- Simulation configuration ----
# logical_basis = 'Z'
distance = 3
num_logicals = 2  # NOTE(review): circuits below are built with `n` logicals; this value no longer sizes them
cycles = distance  # QEC rounds appended after every entangling gate
loss_detection_freq = 1
biased_pres_gates = True  # NOTE(review): unused — generate_ghz_state is called with biased_pres_gates=False below; confirm which is intended
biased_erasure = False
SSR = True
printing = False
decoder = 'MWPM'  # 'MLE' -> qec Gurobi MLE decoder; any other value -> pymatching via sinter

num_shots = 100

output_dir = "/Users/gefenbaranes/Documents/results"  # NOTE(review): absolute local path; not used in this cell
n_vec = np.arange(2, 8, 1)  # not used: the loop below runs n = 2 only


def _count_failed_shots(observable_flips, predicted_flips):
    """Count shots where the decoder mispredicted at least one observable.

    For 2-D (shots, observables) inputs each shot counts at most once, no matter
    how many of its observables are wrong, so the count never exceeds the number
    of shots. Inputs of any other rank fall back to counting individual mismatches.
    """
    mismatches = np.logical_xor(observable_flips, predicted_flips)
    if mismatches.ndim == 2:
        return int(np.count_nonzero(np.any(mismatches, axis=1)))
    return int(np.sum(mismatches))


for n in [2]:
    # Star order: logical 0 controls every other logical.
    order_1 = [(0, i) for i in range(1, n)]
    # Chain order: logical i-1 controls logical i.
    order_2 = [(i - 1, i) for i in range(1, n)]

    orders = [order_1, order_2]
    if printing:
        print('GHZ state of {} logical qubits with d = {}'.format(n, distance))
    for (o, order) in enumerate(orders):
        for bias in [1/2]:
            for logical_basis in ['X', 'Z']:
                # Build the channel from the swept `bias` value.
                entangling_gate_error_rate, entangling_gate_loss_rate = set_1q_biased_erasure_channel(phys_err, bias, erasure_ratio)
                start_time = time.time()
                lc = generate_ghz_state(distance, order, num_logicals=n, QEC_cycles=cycles,
                                        entangling_gate_error_rate=entangling_gate_error_rate,
                                        entangling_gate_loss_rate=entangling_gate_loss_rate,
                                        logical_basis=logical_basis,
                                        biased_pres_gates=False)  # NOTE(review): overrides the biased_pres_gates config above

                if printing:
                    print(f"\n {logical_basis} logical Circuit after noise: \n {lc}")
                    print(f"potential lost qubits: {lc.potential_lost_qubits}\n")
                    print(f"potential lost qubits loss probabilities: {lc.loss_probabilities}\n")

                if erasure_ratio > 0:
                    # Gather qubits from all n logicals of this circuit (not the fixed num_logicals).
                    ancilla_qubits = [qubit for k in range(n) for qubit in lc.logical_qubits[k].measure_qubits]
                    data_qubits = [qubit for k in range(n) for qubit in lc.logical_qubits[k].data_qubits]
                    # One circuit shot per sampled loss pattern.
                    loss_sampling = 1
                    num_loss_shots = math.ceil(num_shots / loss_sampling)
                    # One boolean row per loss shot; column j is True when potential_lost_qubits[j]
                    # is sampled as lost (assumes loss_probabilities is aligned with potential_lost_qubits).
                    loss_detection_events_all_shots = np.random.rand(num_loss_shots, len(lc.potential_lost_qubits)) < lc.loss_probabilities
                    loss_detection_class = HeraldedCircuit_FREE_LD(circuit=lc, biased_erasure=biased_erasure,
                                                                    basis=logical_basis, distance=distance, erasure_ratio=erasure_ratio,
                                                                    phys_error=phys_err, ancilla_qubits=ancilla_qubits, data_qubits=data_qubits,
                                                                    SSR=SSR, cycles=cycles, printing=printing, loss_detection_freq=loss_detection_freq)
                    num_errors = 0
                    for loss_detection_events in loss_detection_events_all_shots:
                        if printing:
                            print(f"loss detection events for this shot: {loss_detection_events} \n")

                        new_circuit, new_lossless_circuit = loss_detection_class.heralded_new_circuit(loss_detection_events)

                        # Sample measurements from the heralded circuit and convert them with the lossless one.
                        measurements = new_circuit.compile_sampler().sample(loss_sampling)
                        converter = new_lossless_circuit.compile_m2d_converter()
                        new_detection_events, new_observable_flips = converter.convert(measurements=measurements, separate_observables=True)
                        if printing:
                            print("\n New Heralded Circuit (for measurements):")
                            print(new_circuit)
                            print("\n New Lossless Circuit:")
                            print(new_lossless_circuit)

                        if decoder == "MLE":
                            detector_error_model = new_lossless_circuit.detector_error_model(decompose_errors=False, approximate_disjoint_errors=True, ignore_decomposition_failures=True, allow_gauge_detectors=False)
                            prediction = qec.correlated_decoders.mle.decode_gurobi_with_dem(dem=detector_error_model, detector_shots=new_detection_events)
                        else:
                            detector_error_model = new_lossless_circuit.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True, ignore_decomposition_failures=True, allow_gauge_detectors=False)
                            prediction = sinter.predict_observables(
                                    dem=detector_error_model,
                                    dets=new_detection_events,
                                    decoder='pymatching',)
                        # A shot fails if any observable is mispredicted (Z basis has n-1 observables).
                        num_errors_loss_shot = _count_failed_shots(new_observable_flips, prediction)
                        num_errors += num_errors_loss_shot

                else:
                    sampler = lc.compile_detector_sampler()
                    detection_events, observable_flips = sampler.sample(num_shots, separate_observables=True)
                    detector_error_model = lc.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True, ignore_decomposition_failures=True)
                    # Run the decoder.
                    prediction = sinter.predict_observables(
                        dem=detector_error_model,
                        dets=detection_events,
                        decoder='pymatching',
                    )
                    num_errors = _count_failed_shots(observable_flips, prediction)

                logical_error = num_errors / num_shots

                if logical_basis == 'X':
                    logical_error_X = logical_error
                elif logical_basis == 'Z':
                    logical_error_Z = logical_error

                print(f"For d={distance}, basis = {logical_basis}, physical error = {phys_err}, logical error is {logical_error:.2e}. Saving this result!")
    # f.write(f'{distance} {phys_err} {bias_ratio} {erasure_ratio} {num_errors} {num_shots} {time.time()-start_time}\n')
    # f.close()
# print(f"For phys error = {phys_err}, bias = {bias_ratio}, cycles = {cycles}: Logical bias ratio = {logical_error_Z / logical_error_X}")
            
# TODO: known problems with the XZZX code variant:
# 1. Detectors don't propagate backwards.

KeyboardInterrupt: 