In [1]:
from matplotlib import pyplot as plt
import pymatching
import numpy as np
from copy import deepcopy
from circuit_generators import *
from sampling_functions import *
from bitpack import pack_bits, unpack_bits
from circuit_partition import *
from utilities_tf import *


# Number of worker nodes
n_worker_nodes = 8

# Surface code specifications
d = 4
r = 2
kernel_size = 3
p = 0.01
use_rotated_z = True

# Signed integer types that NumPy actually provides, narrowest first.
# (There is no np.int128 / np.int256, so anything wider than 64 bits must be rejected explicitly.)
_SIGNED_INT_TYPES = (np.int8, np.int16, np.int32, np.int64)


def smallest_int_type_for_bits(n_bits):
  """
  Return the narrowest NumPy signed integer type that is at least n_bits wide.

  Arguments:
  - n_bits: Number of bits that need to fit into a single packed word.

  Returns:
  - One of np.int8, np.int16, np.int32, np.int64.

  Raises:
  - RuntimeError if more than 64 bits are requested.
  """
  for int_type in _SIGNED_INT_TYPES:
    if n_bits <= np.iinfo(int_type).bits:
      return int_type
  raise RuntimeError("d is too large.")


def smallest_int_type_for_value(max_value):
  """
  Return the narrowest NumPy signed integer type whose maximum is at least max_value.

  Arguments:
  - max_value: Largest value that needs to be representable.

  Returns:
  - One of np.int8, np.int16, np.int32, np.int64.

  Raises:
  - RuntimeError if max_value does not fit into np.int64.
  """
  for int_type in _SIGNED_INT_TYPES:
    if max_value <= np.iinfo(int_type).max:
      return int_type
  raise RuntimeError("idx_t is too small.")


# Bit types
binary_t = np.int8 # Could use even less if numpy allowed
packed_t = smallest_int_type_for_bits(d) # Packed bit type, sized from d
time_t = np.int8

# Measurement index type: must be able to represent the total number of measurements per shot
n_all_measurements = r*(d**2-1) + d**2
idx_t = smallest_int_type_for_value(n_all_measurements)


# Call signature for circuit_partition::group_det_bits_kxk
def call_group_det_bits_kxk(det_bits_dxd, data_bits_dxd=None, d=d, r=r, k=kernel_size, use_rotated_z=use_rotated_z, binary_t=binary_t, idx_t=idx_t):
  """
  Forward to group_det_bits_kxk with the notebook settings bound as default arguments.
  The defaults are captured when this cell runs; re-run the cell after changing the settings above.
  """
  return group_det_bits_kxk(det_bits_dxd, d, r, k, use_rotated_z, data_bits_dxd, binary_t, idx_t)


# Call signature for bitpack::pack_bits
def call_pack_bits(bits, packed_t=packed_t):
  """
  Forward to pack_bits, passing bits.shape[0] as the bit count and packed_t as the output type.
  """
  return pack_bits(bits, bits.shape[0], packed_t=packed_t)




In [2]:
# Number of events reserved for testing and for training
n_train = 10_000_000
n_test = 10_000_000
n_samples = n_train + n_test
decoders = ['pymatching']

# Full dxd memory circuit in the selected basis; every noise channel uses the same strength p
test_circuit = get_builtin_circuit(
  "surface_code:rotated_memory_" + ("z" if use_rotated_z else "x"),
  rounds=r,
  distance=d,
  after_clifford_depolarization=p,
  after_reset_flip_probability=p,
  before_measure_flip_probability=p,
  before_round_data_depolarization=p,
)

# Hand-written stim circuit for the kernel patch: 9 data qubits and 8 ancillas,
# one initial measurement round followed by REPEAT {r-1} further rounds and a final data-qubit measurement.
# NOTE(review): every noise strength in this text is hard-coded to 0.01 instead of being interpolated from `p`;
#               the two have to be kept in sync by hand if `p` is changed.
# NOTE(review): the qubit layout is fixed to a 3x3 patch and does not follow `kernel_size` -- confirm before changing kernel_size.
# NOTE(review): `stim` is not imported by name in the visible cells; presumably it arrives through one of the star imports.
kernel_circuit = stim.Circuit(
  f"""
QUBIT_COORDS(1, 1) 1
QUBIT_COORDS(2, 0) 2
QUBIT_COORDS(3, 1) 3
QUBIT_COORDS(5, 1) 5
QUBIT_COORDS(1, 3) 8
QUBIT_COORDS(2, 2) 9
QUBIT_COORDS(3, 3) 10
QUBIT_COORDS(4, 2) 11
QUBIT_COORDS(5, 3) 12
QUBIT_COORDS(6, 2) 13
QUBIT_COORDS(0, 4) 14
QUBIT_COORDS(1, 5) 15
QUBIT_COORDS(2, 4) 16
QUBIT_COORDS(3, 5) 17
QUBIT_COORDS(4, 4) 18
QUBIT_COORDS(5, 5) 19
QUBIT_COORDS(4, 6) 25
R 1 3 5 8 10 12 15 17 19
X_ERROR(0.01) 1 3 5 8 10 12 15 17 19
R 2 9 11 13 14 16 18 25
X_ERROR(0.01) 2 9 11 13 14 16 18 25
TICK
DEPOLARIZE1(0.01) 1 3 5 8 10 12 15 17 19
H 2 11 16 25
DEPOLARIZE1(0.01) 2 11 16 25
TICK
CX 2 3 16 17 11 12 15 14 10 9 19 18
DEPOLARIZE2(0.01) 2 3 16 17 11 12 15 14 10 9 19 18
DEPOLARIZE1(0.01) 13 25
TICK
CX 2 1 16 15 11 10 8 14 3 9 12 18
DEPOLARIZE2(0.01) 2 1 16 15 11 10 8 14 3 9 12 18
DEPOLARIZE1(0.01) 5 13 17 19 25
TICK
CX 16 10 11 5 25 19 8 9 17 18 12 13
DEPOLARIZE2(0.01) 16 10 11 5 25 19 8 9 17 18 12 13
#DEPOLARIZE1(0.01)
TICK
CX 16 8 11 3 25 17 1 9 10 18 5 13
DEPOLARIZE2(0.01) 16 8 11 3 25 17 1 9 10 18 5 13
DEPOLARIZE1(0.01) 12 15 19
TICK
H 2 11 16 25
DEPOLARIZE1(0.01) 2 11 16 25
TICK
X_ERROR(0.01) 2 9 11 13 14 16 18 25
MR 2 9 11 13 14 16 18 25
X_ERROR(0.01) 2 9 11 13 14 16 18 25
DETECTOR(0, 4, 0) rec[-4]
DETECTOR(2, 2, 0) rec[-7]
DETECTOR(4, 4, 0) rec[-2]
DETECTOR(6, 2, 0) rec[-5]
REPEAT {r-1} {{
  TICK
  DEPOLARIZE1(0.01) 1 3 5 8 10 12 15 17 19
  H 2 11 16 25
  DEPOLARIZE1(0.01) 2 11 16 25
  TICK
  CX 2 3 16 17 11 12 15 14 10 9 19 18
  DEPOLARIZE2(0.01) 2 3 16 17 11 12 15 14 10 9 19 18
  DEPOLARIZE1(0.01) 13 25
  TICK
  CX 2 1 16 15 11 10 8 14 3 9 12 18
  DEPOLARIZE2(0.01) 2 1 16 15 11 10 8 14 3 9 12 18
  DEPOLARIZE1(0.01) 5 13 17 19 25
  TICK
  CX 16 10 11 5 25 19 8 9 17 18 12 13
  DEPOLARIZE2(0.01) 16 10 11 5 25 19 8 9 17 18 12 13
  #DEPOLARIZE1(0.01)
  TICK
  CX 16 8 11 3 25 17 1 9 10 18 5 13
  DEPOLARIZE2(0.01) 16 8 11 3 25 17 1 9 10 18 5 13
  DEPOLARIZE1(0.01) 12 15 19
  TICK
  H 2 11 16 25
  DEPOLARIZE1(0.01) 2 11 16 25
  TICK
  X_ERROR(0.01) 2 9 11 13 14 16 18 25
  MR 2 9 11 13 14 16 18 25
  X_ERROR(0.01) 2 9 11 13 14 16 18 25
  SHIFT_COORDS(0, 0, 1)
  DETECTOR(2, 0, 0) rec[-8] rec[-16]
  DETECTOR(2, 2, 0) rec[-7] rec[-15]
  DETECTOR(4, 2, 0) rec[-6] rec[-14]
  DETECTOR(6, 2, 0) rec[-5] rec[-13]
  DETECTOR(0, 4, 0) rec[-4] rec[-12]
  DETECTOR(2, 4, 0) rec[-3] rec[-11]
  DETECTOR(4, 4, 0) rec[-2] rec[-10]
  DETECTOR(4, 6, 0) rec[-1] rec[-9]
}}
X_ERROR(0.01) 1 3 5 8 10 12 15 17 19
M 1 3 5 8 10 12 15 17 19
DETECTOR(0, 4, 1) rec[-3] rec[-6] rec[-13]
DETECTOR(2, 2, 1) rec[-5] rec[-6] rec[-8] rec[-9] rec[-16]
DETECTOR(4, 4, 1) rec[-1] rec[-2] rec[-4] rec[-5] rec[-11]
DETECTOR(6, 2, 1) rec[-4] rec[-7] rec[-14]
OBSERVABLE_INCLUDE(0) rec[-7] rec[-8] rec[-9]
  """
)

# Sampling for the dxd circuit
# Compiled helpers for the full circuit: error model, measurement-to-detector converter, and the two samplers
detector_error_model = test_circuit.detector_error_model(decompose_errors=True)
converter = test_circuit.compile_m2d_converter()
d_sampler = test_circuit.compile_detector_sampler(seed=12345)
m_sampler = test_circuit.compile_sampler(seed=12345)

# Draw the raw measurement records, then derive detector events and observable flips from those same shots
measurements = m_sampler.sample(n_samples, bit_packed=False)
det_evts, flips = converter.convert(
  measurements=measurements,
  separate_observables=True,
  bit_packed=False
)

# Store everything with the compact binary type (converted one array at a time)
measurements = measurements.astype(binary_t)
det_evts = det_evts.astype(binary_t)
flips = flips.astype(binary_t)

# Number of recorded flips divided by the number of shots
avg_flips = flips.reshape(-1).sum(dtype=np.float32) / flips.shape[0]
print(f"Average flip rate for the full circuit: {avg_flips}")

# Sampling for the kxk kernel
# Compiled helpers for the kernel circuit: error model, measurement-to-detector converter, and the two samplers
detector_error_model_kernel = kernel_circuit.detector_error_model(decompose_errors=True)
converter_kernel = kernel_circuit.compile_m2d_converter()
d_sampler_kernel = kernel_circuit.compile_detector_sampler(seed=12345)
m_sampler_kernel = kernel_circuit.compile_sampler(seed=12345)

# Draw the raw measurement records, then derive detector events and observable flips from those same shots
measurements_kernel = m_sampler_kernel.sample(n_samples, bit_packed=False)
det_evts_kernel, flips_kernel = converter_kernel.convert(
  measurements=measurements_kernel,
  separate_observables=True,
  bit_packed=False
)

# Store everything with the compact binary type (converted one array at a time)
measurements_kernel = measurements_kernel.astype(binary_t)
det_evts_kernel = det_evts_kernel.astype(binary_t)
flips_kernel = flips_kernel.astype(binary_t)

# Number of recorded flips divided by the number of shots
avg_flips_kernel = flips_kernel.reshape(-1).sum(dtype=np.float32) / flips_kernel.shape[0]
print(f"Average flip rate for the kernel sample: {avg_flips_kernel}")

Average flip rate for the full circuit: 0.19171115
Average flip rate for the kernel sample: 0.15777895


In [3]:
# Show the text of the full dxd circuit for inspection
print(test_circuit)

QUBIT_COORDS(1, 1) 1
QUBIT_COORDS(2, 0) 2
QUBIT_COORDS(3, 1) 3
QUBIT_COORDS(5, 1) 5
QUBIT_COORDS(6, 0) 6
QUBIT_COORDS(7, 1) 7
QUBIT_COORDS(1, 3) 10
QUBIT_COORDS(2, 2) 11
QUBIT_COORDS(3, 3) 12
QUBIT_COORDS(4, 2) 13
QUBIT_COORDS(5, 3) 14
QUBIT_COORDS(6, 2) 15
QUBIT_COORDS(7, 3) 16
QUBIT_COORDS(0, 4) 18
QUBIT_COORDS(1, 5) 19
QUBIT_COORDS(2, 4) 20
QUBIT_COORDS(3, 5) 21
QUBIT_COORDS(4, 4) 22
QUBIT_COORDS(5, 5) 23
QUBIT_COORDS(6, 4) 24
QUBIT_COORDS(7, 5) 25
QUBIT_COORDS(8, 4) 26
QUBIT_COORDS(1, 7) 28
QUBIT_COORDS(2, 6) 29
QUBIT_COORDS(3, 7) 30
QUBIT_COORDS(4, 6) 31
QUBIT_COORDS(5, 7) 32
QUBIT_COORDS(6, 6) 33
QUBIT_COORDS(7, 7) 34
QUBIT_COORDS(2, 8) 38
QUBIT_COORDS(6, 8) 42
R 1 3 5 7 10 12 14 16 19 21 23 25 28 30 32 34
X_ERROR(0.01) 1 3 5 7 10 12 14 16 19 21 23 25 28 30 32 34
R 2 6 11 13 15 18 20 22 24 26 29 31 33 38 42
X_ERROR(0.01) 2 6 11 13 15 18 20 22 24 26 29 31 33 38 42
TICK
DEPOLARIZE1(0.01) 1 3 5 7 10 12 14 16 19 21 23 25 28 30 32 34
H 2 6 13 20 24 31 38 42
DEPOLARIZE1(0.01) 2 6 13 20

In [4]:
# Show the text of the kernel circuit for comparison with the full circuit
print(kernel_circuit)

QUBIT_COORDS(1, 1) 1
QUBIT_COORDS(2, 0) 2
QUBIT_COORDS(3, 1) 3
QUBIT_COORDS(5, 1) 5
QUBIT_COORDS(1, 3) 8
QUBIT_COORDS(2, 2) 9
QUBIT_COORDS(3, 3) 10
QUBIT_COORDS(4, 2) 11
QUBIT_COORDS(5, 3) 12
QUBIT_COORDS(6, 2) 13
QUBIT_COORDS(0, 4) 14
QUBIT_COORDS(1, 5) 15
QUBIT_COORDS(2, 4) 16
QUBIT_COORDS(3, 5) 17
QUBIT_COORDS(4, 4) 18
QUBIT_COORDS(5, 5) 19
QUBIT_COORDS(4, 6) 25
R 1 3 5 8 10 12 15 17 19
X_ERROR(0.01) 1 3 5 8 10 12 15 17 19
R 2 9 11 13 14 16 18 25
X_ERROR(0.01) 2 9 11 13 14 16 18 25
TICK
DEPOLARIZE1(0.01) 1 3 5 8 10 12 15 17 19
H 2 11 16 25
DEPOLARIZE1(0.01) 2 11 16 25
TICK
CX 2 3 16 17 11 12 15 14 10 9 19 18
DEPOLARIZE2(0.01) 2 3 16 17 11 12 15 14 10 9 19 18
DEPOLARIZE1(0.01) 13 25
TICK
CX 2 1 16 15 11 10 8 14 3 9 12 18
DEPOLARIZE2(0.01) 2 1 16 15 11 10 8 14 3 9 12 18
DEPOLARIZE1(0.01) 5 13 17 19 25
TICK
CX 16 10 11 5 25 19 8 9 17 18 12 13
DEPOLARIZE2(0.01) 16 10 11 5 25 19 8 9 17 18 12 13
TICK
CX 16 8 11 3 25 17 1 9 10 18 5 13
DEPOLARIZE2(0.01) 16 8 11 3 25 17 1 9 10 18 5 13
DEPOLA

In [5]:
def split_measurements(measurements, d):
  """
  Split raw measurement records into ancilla, logical-observable, and data-qubit columns.

  The last d**2 columns of each record are treated as the data-qubit measurements (they come last),
  and everything before them as ancilla measurements. Out of the data-qubit columns, the first d are
  returned as the logical-observable bits (the data qubits on one boundary of the lattice); all other
  equivalent X_L/Z_L operators can be obtained by combining them with ancilla measurements.

  Arguments:
  - measurements: 2D array of shape (n_shots, n_measurements).
  - d: Distance of the patch; the number of data qubits is d**2.

  Returns:
  - det_bits: Copy of all columns before the data-qubit block, shape (n_shots, n_measurements - d**2).
  - obs_bits: Copy of the first d data-qubit columns, shape (n_shots, d).
  - data_bits: Copy of the data-qubit columns in their original order, shape (n_shots, d**2).

  Raises:
  - ValueError if measurements is not two-dimensional or has fewer than d**2 columns.
  """
  measurements = np.asarray(measurements)
  if measurements.ndim != 2:
    raise ValueError(f"measurements must be a 2D array, got {measurements.ndim} dimension(s).")

  # Plain Python integers are used for the column arithmetic so that the result
  # cannot overflow or depend on the width of a narrow index type.
  n_measurements = measurements.shape[1]
  n_data = d**2
  if n_data > n_measurements:
    raise ValueError(f"Expected at least d**2 = {n_data} measurements per shot, got {n_measurements}.")
  first_data = n_measurements - n_data

  # Copies are returned (not views) so that the outputs never alias the input array
  det_bits = measurements[:, :first_data].copy()
  obs_bits = measurements[:, first_data:first_data + d].copy()
  data_bits = measurements[:, first_data:].copy()

  return det_bits, obs_bits, data_bits


# Number of measurements per shot in the full circuit, stored with the index type
n_measurements = idx_t(measurements.shape[1])
# Separate the full-circuit records into ancilla, observable, and data-qubit bits
det_bits, obs_bits, data_bits = split_measurements(measurements, d)
print(obs_bits)

# Do the same for the kernels
# (the data-qubit bits of the stand-alone kernel sample are not needed here, so they are discarded)
det_bits_kernel, obs_bits_kernel, _ = split_measurements(measurements_kernel, kernel_size)
print(obs_bits_kernel)

[[0 1 0 0]
 [0 1 1 0]
 [0 1 1 0]
 ...
 [1 1 0 0]
 [0 0 1 0]
 [1 0 0 1]]
[[1 1 1]
 [1 1 0]
 [0 1 1]
 ...
 [0 1 1]
 [0 0 1]
 [0 0 0]]


In [6]:
# Regroup the dxd measurement bits into per-kernel inputs
(det_bits_kxk_all,
 data_bits_kxk_all,
 obs_bits_kxk_all,
 kernel_result_translation_map) = call_group_det_bits_kxk(det_bits, data_bits_dxd=data_bits)
print(
  det_bits_kxk_all.shape,
  data_bits_kxk_all.shape,
  obs_bits_kxk_all.shape,
  kernel_result_translation_map.shape,
  sep="\n"
)

# First event of the full circuit, for reference
print(flips[0], det_bits[0], data_bits[0], sep="\n")

kernel_types = get_unique_kernel_types(kernel_size, d)
print(kernel_types)

# First event as seen by each kernel
n_kernels = len(det_bits_kxk_all)
n_kernel_rows = int(np.sqrt(n_kernels))
for k in range(n_kernels):
  print(det_bits_kxk_all[k][0], data_bits_kxk_all[k][0], obs_bits_kxk_all[k][0], sep="\n")
  if not k % n_kernel_rows:
    print(kernel_result_translation_map[k//n_kernel_rows][0])

# Run each kernel's measurements (ancilla bits followed by data bits) through the kernel circuit's m2d converter
det_evts_kxk_all = list()
flips_kxk_all = list()
for k in range(n_kernels):
  measurements_kxk = np.hstack((det_bits_kxk_all[k], data_bits_kxk_all[k])).astype(bool)
  det_evts_kxk, flips_kxk = converter_kernel.convert(
    measurements=measurements_kxk,
    separate_observables=True,
    bit_packed=False
  )
  det_evts_kxk_all.append(det_evts_kxk)
  flips_kxk_all.append(flips_kxk)
det_evts_kxk_all = np.asarray(det_evts_kxk_all, dtype=binary_t)
flips_kxk_all = np.asarray(flips_kxk_all, dtype=binary_t)
print(det_evts_kxk_all.shape, flips_kxk_all.shape, sep="\n")

(4, 20000000, 16)
(4, 20000000, 9)
(4, 20000000, 3)
(2, 20000000, 2)
[1]
[1 1 1 0 0 0 1 0 0 0 0 1 0 0 0 0 1 0 0 1 0 1 0 0 0 0 1 0 0 1]
[0 1 0 0 1 0 0 1 1 0 0 1 0 1 1 0]
[[[1, 1], [0, 1, 2, 3]]]
[1 1 0 0 0 1 0 1 0 0 0 1 0 1 0 1]
[0 1 0 1 0 0 1 0 0]
[0 1 0]
[0 0]
[1 0 0 1 0 0 0 1 1 1 0 0 0 0 0 1]
[0 0 1 1 0 0 1 0 0]
[1 0 0]
[0 0 1 0 0 1 0 0 0 0 1 0 0 1 0 0]
[0 1 1 1 0 0 1 0 0]
[1 1 0]
[1 1]
[0 0 1 0 0 0 0 0 1 0 1 0 0 0 0 0]
[0 1 1 1 0 0 1 0 0]
[1 1 0]
(4, 20000000, 16)
(4, 20000000, 1)


In [7]:
# Prepare the input for training
class_bits = obs_bits
print(class_bits)

class_bits_packed = np.zeros(class_bits.shape[0], dtype=packed_t)
for iev in range(class_bits.shape[0]):
  class_bits_packed[iev] = call_pack_bits(class_bits[iev])
print(class_bits_packed)

# Make sure the data type is np.int32 below, not idx_t!
idxs_test, idxs_train = split_data(np.arange(n_samples, dtype=np.int32), test_size = n_test/n_samples, seed = 12345, shuffle = False)

features_det_bits = np.swapaxes(det_bits_kxk_all, 0, 1)
features_det_evts = np.swapaxes(det_evts_kxk_all, 0, 1)
features_translation_map = np.swapaxes(kernel_result_translation_map, 0, 1)[:,:,0]
features_final_det_evts = det_evts[:, -((d**2-1)//2):]

print(features_det_bits.shape)
print(features_det_evts.shape)
print(features_translation_map.shape)
print(features_final_det_evts.shape)

[[0 1 0 0]
 [0 1 1 0]
 [0 1 1 0]
 ...
 [1 1 0 0]
 [0 0 1 0]
 [1 0 0 1]]
[2 6 6 ... 3 4 9]
(20000000, 4, 16)
(20000000, 4, 16)
(20000000, 2)
(20000000, 7)


In [8]:
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model

In [9]:
class CNNKernel(Layer):
  """
  Trainable polynomial map from the flattened inputs of one kernel to its outputs.

  The weight tensor has npol input axes of equal length followed by one output axis.
  For npol=1 the layer is a plain affine map, x @ W + b. For npol>1 each input axis
  of the weights is contracted with the same input vector, per event.
  """

  def __init__(self, constraint_label, kernel_distance, rounds, npol=1, do_all_data_qubits = False, include_det_evts = True, include_last_dets = False, **kwargs):
    """
    Arguments:
    - constraint_label: String used in the names of the weight and bias variables.
    - kernel_distance: Distance of the kernel patch.
    - rounds: Number of measurement rounds.
    - npol: Polynomial order, i.e., number of input axes of the weight tensor (must be >= 1).
    - do_all_data_qubits: If True, produce kernel_distance**2 outputs instead of one.
    - include_det_evts: If True, the input also contains the detector event bits of each round.
    - include_last_dets: If True (and include_det_evts is True), the input also contains
      the detector bits evaluated after the data qubit measurements.
    - kwargs: Forwarded to the base Layer constructor (e.g., name).

    Raises:
    - ValueError if npol < 1.
    """
    # Forward the base-layer options instead of silently dropping them
    super(CNNKernel, self).__init__(**kwargs)
    if npol < 1:
      raise ValueError(f"npol must be at least 1, got {npol}.")
    self.kernel_distance = kernel_distance
    self.rounds = rounds
    self.npol = npol
    self.include_det_evts = include_det_evts
    self.include_last_dets = include_last_dets
    self.do_all_data_qubits = do_all_data_qubits
    num_outputs = 1
    if self.do_all_data_qubits:
      num_outputs = self.kernel_distance**2

    self.n_ancillas = (self.kernel_distance**2-1)

    ndim1 = self.n_ancillas*rounds # Number of ancilla measurements
    if include_det_evts:
      ndim1 += self.n_ancillas//2 + self.n_ancillas*(rounds-1) # Number of detector event bits within each round
      if include_last_dets:
        ndim1 += self.n_ancillas//2 # One could also include the detector consistency bits after the measurement of data qubits
    # Weight shape: npol input axes of length ndim1, then the output axis
    self.ndims = []
    for _ in range(self.npol):
      self.ndims.append(ndim1)
    self.ndims.append(num_outputs)

    self.kernel_weights = self.add_weight(
      name=f"CNNkernel{self.kernel_distance}_{constraint_label}_w",
      shape=self.ndims,
      initializer='zeros',
      trainable=True
    )
    self.kernel_bias = self.add_weight(
      name=f"CNNkernel{self.kernel_distance}_{constraint_label}_b",
      shape=[num_outputs],
      initializer='zeros',
      trainable=True
    )

  def build(self, input_shape):
    # All variables are created in __init__, so there is nothing to build here
    pass

  def call(self, inputs):
    """
    Evaluate the kernel on a batch of inputs.

    Arguments:
    - inputs: Tensor of shape (batch, n_inputs); cast to float32 before use.

    Returns:
    - Tensor of shape (batch, n_outputs).
    """
    w = self.kernel_weights
    x = tf.cast(inputs, tf.float32)
    if self.npol == 1:
      return tf.matmul(x, w) + self.kernel_bias
    # tf.matmul would treat the extra weight axes as batch dimensions, so contract explicitly:
    # first (batch, i) with (i, ..., out), then every remaining input axis per event.
    y = tf.einsum("bi,i...->b...", x, w)
    for _ in range(self.npol-1):
      y = tf.einsum("bi,bi...->b...", x, y)
    return y + self.kernel_bias
  

class FullCNNModel(Model):
  """
  Decoder model that applies one shared CNNKernel per unique kernel type to every kernel of that type,
  concatenates the kernel outputs with the translation coefficients and the final detector events,
  and passes the result through two hidden dense layers and a sigmoid predictor.
  """

  def __init__(self, code_distance, kernel_distance, rounds, npol = 1, do_all_data_qubits = False, extended_kernel_output = True, include_det_evts = True, include_last_dets = False, **kwargs):
    """
    Arguments:
    - code_distance: Distance of the full code.
    - kernel_distance: Distance of the kernel patch.
    - rounds: Number of measurement rounds.
    - npol: Polynomial order passed to each CNNKernel.
    - do_all_data_qubits: If True, predict code_distance**2 outputs instead of one.
    - extended_kernel_output: If True, each kernel emits kernel_distance**2 outputs
      even when do_all_data_qubits is False.
    - include_det_evts: If True, detector event bits are appended to the kernel inputs.
    - include_last_dets: If True, the last (kernel_distance**2-1)//2 detector event bits
      of each kernel are kept in the kernel input; otherwise they are dropped.
    - kwargs: Forwarded to the base Model constructor (e.g., name).
    """
    # Forward the base-model options instead of silently dropping them
    super(FullCNNModel, self).__init__(**kwargs)
    self.code_distance = code_distance
    self.kernel_distance = kernel_distance
    self.nshifts = self.code_distance - self.kernel_distance + 1
    self.rounds = rounds
    self.npol = npol
    self.do_all_data_qubits = do_all_data_qubits
    self.extended_kernel_output = extended_kernel_output
    self.include_det_evts = include_det_evts
    self.include_last_dets = include_last_dets

    # One trainable kernel per unique kernel type; kernels of the same type share weights
    self.cnn_kernels = []
    self.unique_kernel_types = get_unique_kernel_types(self.kernel_distance, code_distance)
    for kernel_type in self.unique_kernel_types:
      self.cnn_kernels.append(
        CNNKernel(
          f"{kernel_type[0][0]}_{kernel_type[0][1]}",
          self.kernel_distance,
          self.rounds,
          self.npol,
          self.do_all_data_qubits or self.extended_kernel_output,
          self.include_det_evts,
          self.include_last_dets
        )
      )

    # Dense is referenced through tf.keras.layers, like Activation, so that this class
    # does not depend on a name leaking in from a star import.
    self.hidden_layers = [
      tf.keras.layers.Dense(100),
      tf.keras.layers.Activation('relu'),
      tf.keras.layers.Dense(100),
      tf.keras.layers.Activation('relu'),
    ]

    self.predictors = [
      tf.keras.layers.Dense(1 if not self.do_all_data_qubits else self.code_distance**2),
      tf.keras.layers.Activation('sigmoid')
    ]

  def call(self, all_inputs):
    """
    Evaluate the model.

    Arguments:
    - all_inputs: Sequence of four tensors, in this order:
      [0] ancilla measurement bits, indexed as (event, kernel, bit),
      [1] detector event bits, indexed as (event, kernel, bit),
      [2] translation coefficients, one row per event,
      [3] final detector events of the full code, one row per event.

    Returns:
    - Tensor of sigmoid outputs, one row per event.
    """
    det_bits = all_inputs[0]
    det_evts = all_inputs[1]
    translation_coefs = all_inputs[2]
    final_det_evts = all_inputs[3]
    predictor_inputs = []
    for i, cnn_kernel in enumerate(self.cnn_kernels):
      # Second entry of a kernel type lists the indices of the kernels of that type
      kernel_idxs = self.unique_kernel_types[i][1]
      for k in kernel_idxs:
        det_bits_kernel = det_bits[:,k]
        if self.include_det_evts:
          if self.include_last_dets:
            det_evts_kernel = det_evts[:,k,:]
          else:
            # Drop the trailing detector bits that depend on the data qubit measurements
            det_evts_kernel = det_evts[:,k,0:-((self.kernel_distance**2-1)//2)]
          kernel_input = tf.concat([det_bits_kernel, det_evts_kernel], axis=1)
        else:
          kernel_input = det_bits_kernel
        predictor_inputs.append(cnn_kernel(kernel_input))
    predictor_inputs.append(tf.cast(translation_coefs, tf.float32))
    predictor_inputs.append(tf.cast(final_det_evts, tf.float32))
    x = tf.concat(predictor_inputs, axis=1)
    for ll in self.hidden_layers:
      x = ll(x)
    for ll in self.predictors:
      x = ll(x)
    return x

In [10]:
# Training target: the observable flip of each event
class_bits_dxd = flips

model_dxd = FullCNNModel(
  code_distance=d,
  kernel_distance=kernel_size,
  rounds=r,
  do_all_data_qubits=False,
  extended_kernel_output=True,
  include_det_evts=True,
  include_last_dets=False
)
model_dxd.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Pass a single event through the model once before asking for the summary
tmpv = model_dxd([
  features_det_bits[:1],
  features_det_evts[:1],
  features_translation_map[:1],
  features_final_det_evts[:1]
])
print(tmpv)
model_dxd.summary()

n_epochs = 20
val_split = 0.2
history = model_dxd.fit(
  x=[
    features_det_bits[idxs_train],
    features_det_evts[idxs_train],
    features_translation_map[idxs_train],
    features_final_det_evts[idxs_train]
  ],
  y=class_bits_dxd[idxs_train],
  batch_size=10000,
  epochs=n_epochs,
  validation_split=val_split
)



tf.Tensor([[0.5015393]], shape=(1, 1), dtype=float32)
Model: "full_cnn_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cnn_kernel (CNNKernel)      multiple                  261       
                                                                 
 dense (Dense)               multiple                  4600      
                                                                 
 activation (Activation)     multiple                  0         
                                                                 
 dense_1 (Dense)             multiple                  10100     
                                                                 
 activation_1 (Activation)   multiple                  0         
                                                                 
 dense_2 (Dense)             multiple                  101       
                                                              

In [11]:
# Sigmoid outputs of the trained model for the test events
flip_preds = model_dxd.predict([features_det_bits[idxs_test], features_det_evts[idxs_test], features_translation_map[idxs_test], features_final_det_evts[idxs_test]], batch_size=10000)



In [29]:
# Fraction of test events where the prediction thresholded at 0.5 disagrees with the recorded flip
print(f"Inaccuracy of the final model on the test data: {1.-(flips[idxs_test]==(flip_preds>0.5).astype(binary_t)).astype(binary_t).sum()/idxs_test.shape[0]}")

Inaccuracy of the final model on the test data: 0.04115740000000001


In [28]:
# Reference decoder: matching on the detector error model of the full dxd circuit
pymatcher_dxd = pymatching.Matching.from_detector_error_model(detector_error_model)
predictions_dxd_pym = pymatcher_dxd.decode_batch(
  det_evts[idxs_test],
  bit_packed_shots=False,
  bit_packed_predictions=False
).astype(packed_t).reshape(-1, 1)

# An event counts as incorrect when the predicted observable differs from the recorded flip
incorrect_matches_dxd = predictions_dxd_pym != flips[idxs_test]
incorrect_rate_dxd_pym = incorrect_matches_dxd.sum() / len(incorrect_matches_dxd)
print(f"PyMatching error rate for test data set of the full dxd code: {incorrect_rate_dxd_pym}")

PyMatching error rate for test data set of the full dxd code: 0.0437033
