In [1]:
import os
from matplotlib import pyplot as plt
import pymatching
import numpy as np
from copy import deepcopy
from circuit_generators import *
from sampling_functions import *
from bitpack import pack_bits, unpack_bits
from circuit_partition import *
from utilities_tf import *
from CNNModel import *


# Number of worker nodes
n_worker_nodes = 8

# Surface code specifications
d = 5
r = 5
kernel_size = 3

# Error probabilities
p = 0.01
before_round_data_depolarization = p
after_reset_flip_probability = p
after_clifford_depolarization = p
before_measure_flip_probability = p

use_rotated_z = True
observable_type = "ZL" if use_rotated_z else "XL"

# Bit types
binary_t = np.int8 # Could use even less if numpy allowed
packed_t = np.int8 # Packed bit type
if d<=8:
  pass
elif d>8 and d<=16:
  packed_t = np.int16
elif d>16 and d<=32:
  packed_t = np.int32
elif d>32 and d<=64:
  packed_t = np.int64
elif d>64 and d<=128:
  packed_t = np.int128
elif d>128 and d<=256:
  packed_t = np.int256
else:
  raise RuntimeError("d is too large.")
time_t = np.int8

# Measurement index type
idx_t = np.int8
n_all_measurements = r*(d**2-1) + d**2
if n_all_measurements > np.iinfo(idx_t).max:
  idx_t = np.int16
if n_all_measurements > np.iinfo(idx_t).max:
  idx_t = np.int32
if n_all_measurements > np.iinfo(idx_t).max:
  idx_t = np.int64
if n_all_measurements > np.iinfo(idx_t).max:
  idx_t = np.int128
if n_all_measurements > np.iinfo(idx_t).max:
  idx_t = np.int256
if n_all_measurements > np.iinfo(idx_t).max:
  raise RuntimeError("idx_t is too small.")

# Call signature for circuit_partition::group_det_bits_kxk
call_group_det_bits_kxk = lambda det_bits_dxd, data_bits_dxd=None, d=d, r=r, k=kernel_size, use_rotated_z=use_rotated_z, binary_t=binary_t, idx_t=idx_t: group_det_bits_kxk(det_bits_dxd, d, r, k, use_rotated_z, data_bits_dxd, binary_t, idx_t)




Configure the sampling run, load previously saved samples if they exist, and otherwise generate and save them.

In [3]:
n_test = 5000000
n_train = 5000000
stim_sampler_seed = 12345
rnd_seed = 12345
n_samples = n_test + n_train

config_name = f"surface_code_d{d}_r{r}_k{kernel_size}_RCNN"
saved_data_dir = f"saved_data/{config_name}/obs_{observable_type}_prob_{str(p).replace('.','p')}"
saved_data_dir = f"{saved_data_dir}_N_{n_samples}_seed_{stim_sampler_seed}"
# Paths of the cached sample arrays, keyed by array name.
saved_data_files = {
  key: f"{saved_data_dir}/{key}.npy" for key in ("measurements", "det_evts", "flips")
}
# Reuse the cache only when every array is present. The directory is created below
# regardless, so its mere existence (e.g. after an interrupted run) does not mean
# the data was written.
has_saved_data = all(os.path.exists(path) for path in saved_data_files.values())
os.makedirs(saved_data_dir, exist_ok=True)


def _save_npy_atomically(path, array):
  """Save `array` to the .npy file `path` without ever leaving a truncated file at `path`.

  The array is first written to a sibling temporary file (which keeps a .npy
  suffix so np.save does not append another), then moved into place.
  """
  tmp_path = f"{path[:-len('.npy')]}.partial.npy"
  np.save(tmp_path, array)
  os.replace(tmp_path, path)


def _make_rotated_memory_circuit(distance):
  """Build the rotated surface-code memory circuit at `distance` with this notebook's rounds and noise settings."""
  return get_builtin_circuit(
    "surface_code:rotated_memory_"+('z' if use_rotated_z else 'x'),
    distance=distance,
    rounds=r,
    before_round_data_depolarization = before_round_data_depolarization,
    after_reset_flip_probability = after_reset_flip_probability,
    after_clifford_depolarization = after_clifford_depolarization,
    before_measure_flip_probability = before_measure_flip_probability
  )


decoders = ['pymatching']
test_circuit = _make_rotated_memory_circuit(d)
kernel_circuit = _make_rotated_memory_circuit(kernel_size)

# Sampling for the dxd circuit
m_sampler = test_circuit.compile_sampler(seed=stim_sampler_seed)
d_sampler = test_circuit.compile_detector_sampler(seed=stim_sampler_seed)
converter = test_circuit.compile_m2d_converter()
detector_error_model = test_circuit.detector_error_model(decompose_errors=True)

if not has_saved_data:
  measurements = m_sampler.sample(n_samples, bit_packed=False)
  det_evts, flips = converter.convert(measurements=measurements, separate_observables=True, bit_packed=False)
  measurements = measurements.astype(binary_t)
  det_evts = det_evts.astype(binary_t)
  flips = flips.astype(binary_t)

  _save_npy_atomically(saved_data_files["measurements"], measurements)
  _save_npy_atomically(saved_data_files["det_evts"], det_evts)
  _save_npy_atomically(saved_data_files["flips"], flips)
else:
  measurements = np.load(saved_data_files["measurements"])
  det_evts = np.load(saved_data_files["det_evts"])
  flips = np.load(saved_data_files["flips"])

avg_flips = np.sum(flips.reshape(-1,), dtype=np.float32)/flips.shape[0]
print(f"Average flip rate for the full circuit: {avg_flips}")

Average flip rate for the full circuit: 0.3546183


In [3]:
# Show the full d x d circuit (text form) for inspection.
print(str(test_circuit))

QUBIT_COORDS(1, 1) 1
QUBIT_COORDS(2, 0) 2
QUBIT_COORDS(3, 1) 3
QUBIT_COORDS(5, 1) 5
QUBIT_COORDS(6, 0) 6
QUBIT_COORDS(7, 1) 7
QUBIT_COORDS(9, 1) 9
QUBIT_COORDS(1, 3) 12
QUBIT_COORDS(2, 2) 13
QUBIT_COORDS(3, 3) 14
QUBIT_COORDS(4, 2) 15
QUBIT_COORDS(5, 3) 16
QUBIT_COORDS(6, 2) 17
QUBIT_COORDS(7, 3) 18
QUBIT_COORDS(8, 2) 19
QUBIT_COORDS(9, 3) 20
QUBIT_COORDS(10, 2) 21
QUBIT_COORDS(0, 4) 22
QUBIT_COORDS(1, 5) 23
QUBIT_COORDS(2, 4) 24
QUBIT_COORDS(3, 5) 25
QUBIT_COORDS(4, 4) 26
QUBIT_COORDS(5, 5) 27
QUBIT_COORDS(6, 4) 28
QUBIT_COORDS(7, 5) 29
QUBIT_COORDS(8, 4) 30
QUBIT_COORDS(9, 5) 31
QUBIT_COORDS(1, 7) 34
QUBIT_COORDS(2, 6) 35
QUBIT_COORDS(3, 7) 36
QUBIT_COORDS(4, 6) 37
QUBIT_COORDS(5, 7) 38
QUBIT_COORDS(6, 6) 39
QUBIT_COORDS(7, 7) 40
QUBIT_COORDS(8, 6) 41
QUBIT_COORDS(9, 7) 42
QUBIT_COORDS(10, 6) 43
QUBIT_COORDS(0, 8) 44
QUBIT_COORDS(1, 9) 45
QUBIT_COORDS(2, 8) 46
QUBIT_COORDS(3, 9) 47
QUBIT_COORDS(4, 8) 48
QUBIT_COORDS(5, 9) 49
QUBIT_COORDS(6, 8) 50
QUBIT_COORDS(7, 9) 51
QUBIT_COORDS(8,

In [4]:
# Show the k x k kernel circuit (text form) for comparison with the d x d one.
print(str(kernel_circuit))

QUBIT_COORDS(1, 1) 1
QUBIT_COORDS(2, 0) 2
QUBIT_COORDS(3, 1) 3
QUBIT_COORDS(5, 1) 5
QUBIT_COORDS(1, 3) 8
QUBIT_COORDS(2, 2) 9
QUBIT_COORDS(3, 3) 10
QUBIT_COORDS(4, 2) 11
QUBIT_COORDS(5, 3) 12
QUBIT_COORDS(6, 2) 13
QUBIT_COORDS(0, 4) 14
QUBIT_COORDS(1, 5) 15
QUBIT_COORDS(2, 4) 16
QUBIT_COORDS(3, 5) 17
QUBIT_COORDS(4, 4) 18
QUBIT_COORDS(5, 5) 19
QUBIT_COORDS(4, 6) 25
R 1 3 5 8 10 12 15 17 19
X_ERROR(0.01) 1 3 5 8 10 12 15 17 19
R 2 9 11 13 14 16 18 25
X_ERROR(0.01) 2 9 11 13 14 16 18 25
TICK
DEPOLARIZE1(0.01) 1 3 5 8 10 12 15 17 19
H 2 11 16 25
DEPOLARIZE1(0.01) 2 11 16 25
TICK
CX 2 3 16 17 11 12 15 14 10 9 19 18
DEPOLARIZE2(0.01) 2 3 16 17 11 12 15 14 10 9 19 18
TICK
CX 2 1 16 15 11 10 8 14 3 9 12 18
DEPOLARIZE2(0.01) 2 1 16 15 11 10 8 14 3 9 12 18
TICK
CX 16 10 11 5 25 19 8 9 17 18 12 13
DEPOLARIZE2(0.01) 16 10 11 5 25 19 8 9 17 18 12 13
TICK
CX 16 8 11 3 25 17 1 9 10 18 5 13
DEPOLARIZE2(0.01) 16 8 11 3 25 17 1 9 10 18 5 13
TICK
H 2 11 16 25
DEPOLARIZE1(0.01) 2 11 16 25
TICK
X_ERROR(0.

In [5]:
def split_measurements(measurements, d):
  """Split a raw measurement record into non-data, observable and data-qubit columns.

  The last d**2 columns of `measurements` are treated as the final data-qubit
  measurements (they come last in the record); everything before them is
  returned as `det_bits`. Contiguous column slices are used, so no index dtype
  (and no module-level idx_t) is involved.

  Args:
    measurements: 2D array of shape (n_shots, n_measurements).
    d: Code distance.

  Returns:
    det_bits: columns [0, n_measurements - d**2), all measurements before the data block.
    obs_bits: the first d data-qubit columns, used as the logical-operator readout
      (the logical measurements are taken on one boundary row of the lattice; other
      equivalent X_L/Z_L operators follow from ancilla measurements combined with these).
    data_bits: all d**2 data-qubit columns, in measurement order.
    All three are independent copies (not views of `measurements`).

  Raises:
    ValueError: if `measurements` has fewer than d**2 columns.
  """
  n_data = d**2
  n_measurements = measurements.shape[1]
  if n_data > n_measurements:
    raise ValueError(
      f"Expected at least d**2={n_data} measurement columns, got {n_measurements}."
    )
  # Index of the first data-qubit measurement.
  first_data = n_measurements - n_data

  det_bits = measurements[:, :first_data].copy()
  obs_bits = measurements[:, first_data:first_data + d].copy()
  data_bits = measurements[:, first_data:].copy()

  return det_bits, obs_bits, data_bits


# Number of measurement columns in the raw record, stored with the compact index type.
n_measurements = idx_t(np.shape(measurements)[1])
# Split the d x d record into non-data (detector) bits, observable bits and data-qubit bits.
det_bits, obs_bits, data_bits = split_measurements(measurements, d)
print(obs_bits)

[[0 1 0 0 1]
 [0 1 1 1 1]
 [1 0 0 0 1]
 ...
 [0 0 1 0 1]
 [0 1 0 1 0]
 [0 0 0 1 0]]


In [6]:
# Regroup the d x d detector bits (and data-qubit bits) into k x k kernel patches;
# the last two return values are not used here.
det_bits_kxk_all, data_bits_kxk_all, _, _ = call_group_det_bits_kxk(
  det_bits,
  data_bits_dxd=data_bits,
)
kernel_types = get_unique_kernel_types(kernel_size, d)
# Keep only the last (d**2-1)//2 detector-event columns.
n_final_det_evts = (d**2 - 1)//2
final_det_evts = det_evts[:, -n_final_det_evts:]
det_evts_kxk_all_translated = translate_det_bits_to_det_evts(
  observable_type, kernel_size, det_bits_kxk_all, final_det_evts
)

In [7]:
# Make sure the data type is np.int32 below, not idx_t!
idxs_test, idxs_train = split_data(np.arange(n_samples, dtype=np.int32), test_size = n_test/n_samples, seed = rnd_seed, shuffle = False)

class_bits = flips
features_det_bits = np.swapaxes(det_bits_kxk_all, 0, 1)
features_det_evts = np.swapaxes(det_evts_kxk_all_translated, 0, 1)

print(class_bits.shape)
print(features_det_bits.shape)
print(features_det_evts.shape)

(10000000, 1)
(10000000, 9, 40)
(10000000, 9, 40)


In [11]:
def learning_rate_scheduler(epoch, lr):
  """Learning-rate schedule passed to tf.keras.callbacks.LearningRateScheduler.

  Epochs below 10 ignore `lr` and follow a linear ramp down from 0.01 to 0.001
  (0.001 * (10 - epoch)). From epoch 10 on, the incoming `lr` is decayed
  multiplicatively: x0.9 for epochs 10-19, x0.8 for 20-29, x0.65 from 30 on.
  """
  if epoch >= 10:
    if epoch >= 30:
      decay = 0.65
    elif epoch >= 20:
      decay = 0.8
    else:
      decay = 0.9
    return lr * decay
  return 0.001*(10-epoch)

In [12]:
n_nodes = 100
# Two hidden layers of n_nodes units each, passed as a list of widths.
model_dxd = FullRCNNModel(
  observable_type, d, kernel_size, r,
  [n_nodes] * 2,
  npol=2,
  stop_round=None,
  has_nonuniform_response=False,
  do_all_data_qubits=False,
  return_all_rounds=False,
)
model_dxd.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# Run a single sample through the model, presumably so the subclassed model
# builds its weights before summary() reports parameter counts.
model_dxd([features_det_bits[0:1], features_det_evts[0:1]])
model_dxd.summary()


Number of unique contributions: 13
Total number of fractions: 28
Total number of phases: 62
Number of unique contributions: 13
Total number of fractions: 28
Total number of phases: 62
Number of unique contributions: 13
Total number of fractions: 28
Total number of phases: 62
Number of unique contributions: 13
Total number of fractions: 28
Total number of phases: 62

Model: "full_rcnn_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rcnn_initial_state_kernel_  multiple                  8848      
 combiner (RCNNInitialState                                      
 KernelCombiner)                                                 
                                                                 
 rcnn_lead_in_kernel_combin  multiple                  11091     
 er (RCNNLeadInKernelCombin                                      
 er)                                                             
           

In [13]:
val_split = 0.2
n_epochs = 50

# Train on the training partition. validation_split holds out a fraction of it for
# validation (Keras takes it from the end of the arrays; see Model.fit docs); early
# stopping watches val_loss and restores the best weights.
history = model_dxd.fit(
  x=[features_det_bits[idxs_train], features_det_evts[idxs_train]],
  y=class_bits[idxs_train, :],
  batch_size=10000,
  epochs=n_epochs,
  validation_split=val_split,
  callbacks=[
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    tf.keras.callbacks.LearningRateScheduler(learning_rate_scheduler),
  ],
)

flips_pred = model_dxd.predict(
  [features_det_bits[idxs_test], features_det_evts[idxs_test]],
  batch_size=10000,
)
# A prediction counts as a flip when the model output exceeds 0.5.
mispredicted = flips[idxs_test] != (flips_pred > 0.5).astype(binary_t)
print(f"Inaccuracy of the final model on the test data: {mispredicted.astype(binary_t).sum()/idxs_test.shape[0]}")

Epoch 1/50




In [None]:
pymatcher = pymatching.Matching.from_detector_error_model(detector_error_model)
# Decode every shot exactly once. The test-set predictions are a row subset of
# these, so the decoder does not need to run a second time on det_evts[idxs_test].
flips_pred_pym_all = pymatcher.decode_batch(det_evts, bit_packed_predictions=False, bit_packed_shots=False).astype(binary_t).reshape(-1,1)
print(f"PyMatching error rate for the full data set: {np.sum((flips_pred_pym_all!=flips))/flips_pred_pym_all.shape[0]}")
# flips_pred_pym keeps its previous meaning: PyMatching predictions on the test set.
flips_pred_pym = flips_pred_pym_all[idxs_test,:]
print(f"PyMatching error rate for test data set: {np.sum((flips_pred_pym!=flips[idxs_test,:]))/flips_pred_pym.shape[0]}")

PyMatching error rate for the full data set: 0.0836916
PyMatching error rate for test data set: 0.0836274
