In [1]:
from matplotlib import pyplot as plt
import pymatching
import numpy as np
from copy import deepcopy
from circuit_generators import *
from sampling_functions import *
from bitpack import pack_bits, unpack_bits
from circuit_partition import *
from utilities_tf import *

# Number of worker nodes
n_worker_nodes = 8

# Surface code specifications
d = 5 # Code distance (passed as `distance` to the circuit generator)
r = 2 # Number of rounds (passed as `rounds` to the circuit generator)
kernel_size = 3 # Distance of the kxk kernels the dxd code is partitioned into

# Error probabilities: a single physical error rate is used for every noise channel
p = 0.01
before_round_data_depolarization = p
after_reset_flip_probability = p
after_clifford_depolarization = p
before_measure_flip_probability = p

use_rotated_z = True
observable_type = "ZL" if use_rotated_z else "XL"

# Signed integer types that numpy actually provides, narrowest first.
# numpy has no 128- or 256-bit integer types, so nothing wider than int64 can be selected.
SIGNED_INT_TYPES = (np.int8, np.int16, np.int32, np.int64)


def select_packed_type(distance):
  """
  Return the narrowest signed integer type whose bit width is at least `distance`.

  Raises:
    RuntimeError: if `distance` exceeds 64, the widest integer numpy offers.
  """
  for int_type in SIGNED_INT_TYPES:
    if distance <= np.iinfo(int_type).bits:
      return int_type
  raise RuntimeError("d is too large.")


def select_index_type(max_index):
  """
  Return the narrowest signed integer type that can represent `max_index`.

  Raises:
    RuntimeError: if `max_index` does not fit into a 64-bit signed integer.
  """
  for int_type in SIGNED_INT_TYPES:
    if max_index <= np.iinfo(int_type).max:
      return int_type
  raise RuntimeError("idx_t is too small.")


# Bit types
binary_t = np.int8 # Could use even less if numpy allowed
packed_t = select_packed_type(d) # Packed bit type
time_t = np.int8

# Measurement index type: r rounds of d**2-1 ancilla measurements plus the final d**2 data qubit measurements
n_all_measurements = r*(d**2-1) + d**2
idx_t = select_index_type(n_all_measurements)


# Call signature for circuit_partition::group_det_bits_kxk
def call_group_det_bits_kxk(det_bits_dxd, data_bits_dxd=None, d=d, r=r, k=kernel_size, use_rotated_z=use_rotated_z, binary_t=binary_t, idx_t=idx_t):
  """
  Forward to group_det_bits_kxk with the notebook configuration as default arguments.

  The defaults are captured when this cell runs, so the cell must be re-run after changing
  d, r, kernel_size, use_rotated_z, binary_t or idx_t.
  """
  return group_det_bits_kxk(det_bits_dxd, d, r, k, use_rotated_z, data_bits_dxd, binary_t, idx_t)




In [2]:
# Number of shots reserved for testing and for training; all of them are sampled in one go
n_test = 5_000_000
n_train = 5_000_000
n_samples = n_test + n_train
decoders = ['pymatching']

# Noise channel strengths forwarded to the circuit generator
noise_params = dict(
  before_round_data_depolarization = before_round_data_depolarization,
  after_reset_flip_probability = after_reset_flip_probability,
  after_clifford_depolarization = after_clifford_depolarization,
  before_measure_flip_probability = before_measure_flip_probability,
)

# Full dxd rotated surface code memory circuit in the chosen basis
test_circuit = get_builtin_circuit(
  f"surface_code:rotated_memory_{'z' if use_rotated_z else 'x'}",
  distance=d,
  rounds=r,
  **noise_params
)

# Extra single-qubit depolarization lines, one list of four entries per kernel parity.
# Entry i is spliced into the kernel circuit template right after the i-th CX/DEPOLARIZE2 layer.
# Each entry is the space-separated target list; None produces a commented-out line (no targets).
depol1_prefix = f"DEPOLARIZE1({after_clifford_depolarization})"
extra_depol1_targets = [
  ("13 25", "5 13 17 19 25", None, "12 15 19"), # parity = (1, 1)
  ("8 13 25", "5 13 17 19 25", "1 14 15", "12 14 15 19"), # parity = (0, 1)
  ("8 25", "17 25", "1 14 15", "14 15"), # parity = (-1, 1)
  ("5 13 25", "5 13 17 19 25", "2 3", "2 12 15 19"), # parity = (1, 0)
  ("1 5 8 13 25", "5 13 17 19 25", "1 2 3 14 15", "2 12 14 15 19"), # parity = (0, 0)
]
kernel_circuit_extra_depol1 = [
  [
    f"#{depol1_prefix}" if targets is None else f"{depol1_prefix} {targets}"
    for targets in layer_targets
  ]
  for layer_targets in extra_depol1_targets
]
# Build one stim circuit per entry of kernel_circuit_extra_depol1.
# The template below hard-codes a 17-qubit layout and differs between entries only through
# the four replace_args strings, which are inserted right after the DEPOLARIZE2 line of each
# of the four CX layers (both in the first round and inside the REPEAT block).
# The REPEAT block runs r-1 times; doubled braces are literal braces in the f-string.
# NOTE(review): `stim` is not imported explicitly in this notebook; presumably one of the
# star imports provides it -- confirm.
kernel_circuits = []
for replace_args in kernel_circuit_extra_depol1:
  kernel_circuit_template = \
  f"""
QUBIT_COORDS(1, 1) 1
QUBIT_COORDS(2, 0) 2
QUBIT_COORDS(3, 1) 3
QUBIT_COORDS(5, 1) 5
QUBIT_COORDS(1, 3) 8
QUBIT_COORDS(2, 2) 9
QUBIT_COORDS(3, 3) 10
QUBIT_COORDS(4, 2) 11
QUBIT_COORDS(5, 3) 12
QUBIT_COORDS(6, 2) 13
QUBIT_COORDS(0, 4) 14
QUBIT_COORDS(1, 5) 15
QUBIT_COORDS(2, 4) 16
QUBIT_COORDS(3, 5) 17
QUBIT_COORDS(4, 4) 18
QUBIT_COORDS(5, 5) 19
QUBIT_COORDS(4, 6) 25
R 1 3 5 8 10 12 15 17 19
X_ERROR({after_reset_flip_probability}) 1 3 5 8 10 12 15 17 19
R 2 9 11 13 14 16 18 25
X_ERROR({after_reset_flip_probability}) 2 9 11 13 14 16 18 25
TICK
DEPOLARIZE1({before_round_data_depolarization}) 1 3 5 8 10 12 15 17 19
H 2 11 16 25
DEPOLARIZE1({after_clifford_depolarization}) 2 11 16 25
TICK
CX 2 3 16 17 11 12 15 14 10 9 19 18
DEPOLARIZE2({after_clifford_depolarization}) 2 3 16 17 11 12 15 14 10 9 19 18
{replace_args[0]}
TICK
CX 2 1 16 15 11 10 8 14 3 9 12 18
DEPOLARIZE2({after_clifford_depolarization}) 2 1 16 15 11 10 8 14 3 9 12 18
{replace_args[1]}
TICK
CX 16 10 11 5 25 19 8 9 17 18 12 13
DEPOLARIZE2({after_clifford_depolarization}) 16 10 11 5 25 19 8 9 17 18 12 13
{replace_args[2]}
TICK
CX 16 8 11 3 25 17 1 9 10 18 5 13
DEPOLARIZE2({after_clifford_depolarization}) 16 8 11 3 25 17 1 9 10 18 5 13
{replace_args[3]}
TICK
H 2 11 16 25
DEPOLARIZE1({after_clifford_depolarization}) 2 11 16 25
TICK
X_ERROR({before_measure_flip_probability}) 2 9 11 13 14 16 18 25
MR 2 9 11 13 14 16 18 25
X_ERROR({after_reset_flip_probability}) 2 9 11 13 14 16 18 25
DETECTOR(0, 4, 0) rec[-4]
DETECTOR(2, 2, 0) rec[-7]
DETECTOR(4, 4, 0) rec[-2]
DETECTOR(6, 2, 0) rec[-5]
REPEAT {r-1} {{
  TICK
  DEPOLARIZE1({before_round_data_depolarization}) 1 3 5 8 10 12 15 17 19
  H 2 11 16 25
  DEPOLARIZE1({after_clifford_depolarization}) 2 11 16 25
  TICK
  CX 2 3 16 17 11 12 15 14 10 9 19 18
  DEPOLARIZE2({after_clifford_depolarization}) 2 3 16 17 11 12 15 14 10 9 19 18
  {replace_args[0]}
  TICK
  CX 2 1 16 15 11 10 8 14 3 9 12 18
  DEPOLARIZE2({after_clifford_depolarization}) 2 1 16 15 11 10 8 14 3 9 12 18
  {replace_args[1]}
  TICK
  CX 16 10 11 5 25 19 8 9 17 18 12 13
  DEPOLARIZE2({after_clifford_depolarization}) 16 10 11 5 25 19 8 9 17 18 12 13
  {replace_args[2]}
  TICK
  CX 16 8 11 3 25 17 1 9 10 18 5 13
  DEPOLARIZE2({after_clifford_depolarization}) 16 8 11 3 25 17 1 9 10 18 5 13
  {replace_args[3]}
  TICK
  H 2 11 16 25
  DEPOLARIZE1({after_clifford_depolarization}) 2 11 16 25
  TICK
  X_ERROR({before_measure_flip_probability}) 2 9 11 13 14 16 18 25
  MR 2 9 11 13 14 16 18 25
  X_ERROR({after_reset_flip_probability}) 2 9 11 13 14 16 18 25
  SHIFT_COORDS(0, 0, 1)
  DETECTOR(2, 0, 0) rec[-8] rec[-16]
  DETECTOR(2, 2, 0) rec[-7] rec[-15]
  DETECTOR(4, 2, 0) rec[-6] rec[-14]
  DETECTOR(6, 2, 0) rec[-5] rec[-13]
  DETECTOR(0, 4, 0) rec[-4] rec[-12]
  DETECTOR(2, 4, 0) rec[-3] rec[-11]
  DETECTOR(4, 4, 0) rec[-2] rec[-10]
  DETECTOR(4, 6, 0) rec[-1] rec[-9]
}}
X_ERROR({before_measure_flip_probability}) 1 3 5 8 10 12 15 17 19
M 1 3 5 8 10 12 15 17 19
DETECTOR(0, 4, 1) rec[-3] rec[-6] rec[-13]
DETECTOR(2, 2, 1) rec[-5] rec[-6] rec[-8] rec[-9] rec[-16]
DETECTOR(4, 4, 1) rec[-1] rec[-2] rec[-4] rec[-5] rec[-11]
DETECTOR(6, 2, 1) rec[-4] rec[-7] rec[-14]
OBSERVABLE_INCLUDE(0) rec[-7] rec[-8] rec[-9]
  """
  # Parse the filled-in template; the list order follows kernel_circuit_extra_depol1
  kernel_circuits.append(stim.Circuit(kernel_circuit_template))


# Sampling for the dxd circuit
sampling_seed = 12345
m_sampler = test_circuit.compile_sampler(seed=sampling_seed)
d_sampler = test_circuit.compile_detector_sampler(seed=sampling_seed)
converter = test_circuit.compile_m2d_converter()
detector_error_model = test_circuit.detector_error_model(decompose_errors=True)

# Draw raw measurement records, then derive detection events and observable flips from them
measurements = m_sampler.sample(n_samples, bit_packed=False)
det_evts, flips = converter.convert(
  measurements=measurements,
  separate_observables=True,
  bit_packed=False
)

# Convert one array at a time so that each original array can be released before the next copy is made
measurements = measurements.astype(binary_t)
det_evts = det_evts.astype(binary_t)
flips = flips.astype(binary_t)

# Fraction of shots in which the observable was flipped
n_flipped = np.sum(flips.reshape(-1), dtype=np.float32)
avg_flips = n_flipped/flips.shape[0]
print(f"Average flip rate for the full circuit: {avg_flips}")

Average flip rate for the full circuit: 0.2297534


In [3]:
# Show the full dxd circuit so that its qubit indices can be compared against the kernel circuits
print(test_circuit)

QUBIT_COORDS(1, 1) 1
QUBIT_COORDS(2, 0) 2
QUBIT_COORDS(3, 1) 3
QUBIT_COORDS(5, 1) 5
QUBIT_COORDS(6, 0) 6
QUBIT_COORDS(7, 1) 7
QUBIT_COORDS(9, 1) 9
QUBIT_COORDS(1, 3) 12
QUBIT_COORDS(2, 2) 13
QUBIT_COORDS(3, 3) 14
QUBIT_COORDS(4, 2) 15
QUBIT_COORDS(5, 3) 16
QUBIT_COORDS(6, 2) 17
QUBIT_COORDS(7, 3) 18
QUBIT_COORDS(8, 2) 19
QUBIT_COORDS(9, 3) 20
QUBIT_COORDS(10, 2) 21
QUBIT_COORDS(0, 4) 22
QUBIT_COORDS(1, 5) 23
QUBIT_COORDS(2, 4) 24
QUBIT_COORDS(3, 5) 25
QUBIT_COORDS(4, 4) 26
QUBIT_COORDS(5, 5) 27
QUBIT_COORDS(6, 4) 28
QUBIT_COORDS(7, 5) 29
QUBIT_COORDS(8, 4) 30
QUBIT_COORDS(9, 5) 31
QUBIT_COORDS(1, 7) 34
QUBIT_COORDS(2, 6) 35
QUBIT_COORDS(3, 7) 36
QUBIT_COORDS(4, 6) 37
QUBIT_COORDS(5, 7) 38
QUBIT_COORDS(6, 6) 39
QUBIT_COORDS(7, 7) 40
QUBIT_COORDS(8, 6) 41
QUBIT_COORDS(9, 7) 42
QUBIT_COORDS(10, 6) 43
QUBIT_COORDS(0, 8) 44
QUBIT_COORDS(1, 9) 45
QUBIT_COORDS(2, 8) 46
QUBIT_COORDS(3, 9) 47
QUBIT_COORDS(4, 8) 48
QUBIT_COORDS(5, 9) 49
QUBIT_COORDS(6, 8) 50
QUBIT_COORDS(7, 9) 51
QUBIT_COORDS(8,

In [4]:
# Print every kernel circuit; index i follows the order of kernel_circuit_extra_depol1
for i, kernel_circuit in enumerate(kernel_circuits):
  print(f"Kernel circuit {i}:")
  print(kernel_circuit)

Kernel circuit 0:
QUBIT_COORDS(1, 1) 1
QUBIT_COORDS(2, 0) 2
QUBIT_COORDS(3, 1) 3
QUBIT_COORDS(5, 1) 5
QUBIT_COORDS(1, 3) 8
QUBIT_COORDS(2, 2) 9
QUBIT_COORDS(3, 3) 10
QUBIT_COORDS(4, 2) 11
QUBIT_COORDS(5, 3) 12
QUBIT_COORDS(6, 2) 13
QUBIT_COORDS(0, 4) 14
QUBIT_COORDS(1, 5) 15
QUBIT_COORDS(2, 4) 16
QUBIT_COORDS(3, 5) 17
QUBIT_COORDS(4, 4) 18
QUBIT_COORDS(5, 5) 19
QUBIT_COORDS(4, 6) 25
R 1 3 5 8 10 12 15 17 19
X_ERROR(0.01) 1 3 5 8 10 12 15 17 19
R 2 9 11 13 14 16 18 25
X_ERROR(0.01) 2 9 11 13 14 16 18 25
TICK
DEPOLARIZE1(0.01) 1 3 5 8 10 12 15 17 19
H 2 11 16 25
DEPOLARIZE1(0.01) 2 11 16 25
TICK
CX 2 3 16 17 11 12 15 14 10 9 19 18
DEPOLARIZE2(0.01) 2 3 16 17 11 12 15 14 10 9 19 18
DEPOLARIZE1(0.01) 13 25
TICK
CX 2 1 16 15 11 10 8 14 3 9 12 18
DEPOLARIZE2(0.01) 2 1 16 15 11 10 8 14 3 9 12 18
DEPOLARIZE1(0.01) 5 13 17 19 25
TICK
CX 16 10 11 5 25 19 8 9 17 18 12 13
DEPOLARIZE2(0.01) 16 10 11 5 25 19 8 9 17 18 12 13
TICK
CX 16 8 11 3 25 17 1 9 10 18 5 13
DEPOLARIZE2(0.01) 16 8 11 3 25 17 1 9

In [5]:
def split_measurements(measurements, d):
  """
  Split raw measurement records into ancilla, observable and data qubit parts.

  The last d**2 columns of `measurements` are taken to be the data qubit measurements
  (measurements on data qubits come last); everything before them is returned as the
  ancilla (detector) bits. Out of the data qubit measurements, the first d columns are
  returned as the observable bits: the logical measurement is taken from the data qubits
  on one boundary of the lattice, and all other equivalent X_L/Z_L operators can be found
  by combining those with ancilla measurements.

  Arguments:
    measurements: 2D array indexed as [shot, measurement].
    d: Code distance.

  Returns:
    (det_bits, obs_bits, data_bits), each a new array with the dtype of `measurements`
    and shapes (n_shots, n_measurements - d**2), (n_shots, d) and (n_shots, d**2).
    Columns keep the order they have in `measurements`.

  Raises:
    ValueError: if `measurements` has fewer than d**2 columns.
  """
  n_measurements = measurements.shape[1]
  n_data = d**2
  if n_data > n_measurements:
    raise ValueError(
      f"Expected at least d**2 = {n_data} measurements per shot, but found only {n_measurements}."
    )
  # Column at which the data qubit measurements start. Plain Python ints are used for the
  # index arithmetic, so the width of the global index type cannot cause an overflow here.
  first_data = n_measurements - n_data

  # Copies (rather than views) are returned so that the caller can modify the results
  # without altering `measurements`.
  det_bits = measurements[:, :first_data].copy()
  obs_bits = measurements[:, first_data:first_data+d].copy()
  data_bits = measurements[:, first_data:].copy()

  return det_bits, obs_bits, data_bits


# Number of measurements per shot, stored in the measurement index type
n_measurements = idx_t(measurements.shape[1])
# Separate ancilla measurements, observable measurements and data qubit measurements
det_bits, obs_bits, data_bits = split_measurements(measurements, d)
# Each row holds the d data qubit measurements selected for the observable in one shot
print(obs_bits)

[[1 0 0 1 0]
 [0 0 0 1 1]
 [0 1 1 1 1]
 ...
 [0 0 0 1 1]
 [0 0 1 1 0]
 [0 0 0 0 1]]


In [6]:
# Partition the dxd measurement records into the overlapping kxk kernels
(
  det_bits_kxk_all,
  data_bits_kxk_all,
  obs_bits_kxk_all,
  kernel_result_translation_map
) = call_group_det_bits_kxk(det_bits, data_bits_dxd=data_bits)
for grouped_bits in (det_bits_kxk_all, data_bits_kxk_all, obs_bits_kxk_all, kernel_result_translation_map):
  print(grouped_bits.shape)

# First shot of the full circuit, for comparison with the per-kernel records below
for full_circuit_bits in (flips, det_bits, data_bits):
  print(full_circuit_bits[0])

kernel_types = get_unique_kernel_types(kernel_size, d)
print(kernel_types)
n_kernels = det_bits_kxk_all.shape[0]
n_kernel_rows = int(np.sqrt(n_kernels))
for k in range(n_kernels):
  # First shot of every kernel: ancilla bits, data bits, observable bits
  for kernel_bits in (det_bits_kxk_all, data_bits_kxk_all, obs_bits_kxk_all):
    print(kernel_bits[k][0])
  # The translation map has one entry per n_kernel_rows kernels; print it alongside the first of them
  kernel_group, idx_in_group = divmod(k, n_kernel_rows)
  if idx_in_group == 0:
    print(kernel_result_translation_map[kernel_group][0])

(9, 10000000, 16)
(9, 10000000, 9)
(9, 10000000, 3)
(3, 10000000, 2)
[0]
[0 1 1 0 0 1 0 0 1 0 1 0 0 1 0 1 0 0 1 0 0 0 0 0 0 1 0 1 0 1 0 1 1 0 1 0 0
 1 0 1 0 1 0 0 0 0 0 0]
[1 0 0 1 0 0 1 0 1 0 1 1 0 0 1 0 0 1 1 1 1 1 0 1 1]
[[[1, 1], [0, 8]], [[0, 1], [1, 7]], [[-1, 1], [2, 6]], [[1, 0], [3, 5]], [[0, 0], [4]]]
[0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1]
[1 0 0 0 1 0 1 1 0]
[0 0 1]
[0 0]
[1 0 0 1 0 1 0 1 1 0 1 0 0 1 0 1]
[1 0 0 1 0 1 0 0 1]
[0 0 1]
[1 0 1 0 0 1 0 1 1 0 1 0 0 1 0 1]
[0 1 0 0 1 0 0 0 1]
[0 1 0]
[1 0 1 0 0 1 0 0 0 0 1 0 1 1 0 1]
[0 0 1 1 1 0 0 1 0]
[1 0 0]
[1 0]
[0 0 1 0 0 1 0 0 1 0 1 0 0 1 0 0]
[1 0 1 1 0 0 0 1 1]
[1 0 1]
[1 0 1 0 0 1 0 0 1 0 1 0 0 1 0 0]
[0 1 0 1 0 0 1 1 1]
[0 1 0]
[0 0 1 0 0 1 0 1 0 0 0 1 0 1 0 1]
[0 1 1 1 0 0 0 1 1]
[1 1 0]
[1 0]
[0 0 0 0 0 1 0 1 0 0 0 0 0 1 0 1]
[1 0 1 0 1 1 1 0 0]
[1 0 1]
[0 0 0 0 0 1 0 1 0 0 0 0 0 1 0 1]
[1 1 0 1 1 1 1 0 0]
[0 1 1]


In [7]:
# Append to the translation map the changes between consecutive cycles (1 where two consecutive bits differ).
# Only the first r columns are used as the starting point so that re-running this cell does not keep
# appending further difference columns to an already extended map.
# NOTE(review): this assumes the map returned by group_det_bits_kxk holds exactly r columns along the
# last axis (one bit per cycle), as the shape printed above suggests -- confirm against circuit_partition.
kernel_result_translation_bits = kernel_result_translation_map[:,:,0:r]
kernel_result_translation_map_f = kernel_result_translation_bits[:,:,1:]
kernel_result_translation_map_b = kernel_result_translation_bits[:,:,0:-1]
print(kernel_result_translation_map_f.shape)
print(kernel_result_translation_map_b.shape)
kernel_result_translation_det_evts = (kernel_result_translation_map_f!=kernel_result_translation_map_b).astype(binary_t)
print(kernel_result_translation_det_evts.shape)
print(kernel_result_translation_det_evts[:,0,:])
kernel_result_translation_map = np.concatenate((kernel_result_translation_bits, kernel_result_translation_det_evts), axis=2)

(3, 10000000, 1)
(3, 10000000, 1)
(3, 10000000, 1)
[[0]
 [1]
 [1]]


In [8]:
# One measurement-to-detection converter per kernel circuit (i.e. per unique kernel type)
converters_kernel = [ kernel_circuit.compile_m2d_converter() for kernel_circuit in kernel_circuits ]

det_evts_kxk_all = []
flips_kxk_all = []
for k in range(n_kernels):
  # Index of the kernel type that lists kernel k; falls back to the first type if none does
  ik = next(
    (i for i, kernel_type in enumerate(kernel_types) if k in kernel_type[1]),
    0
  )
  # A kernel circuit records the ancilla measurements first and the data qubit measurements last
  measurements_kxk = np.concatenate((det_bits_kxk_all[k], data_bits_kxk_all[k]), axis=1).astype(np.bool_)
  det_evts_kxk, flips_kxk = converters_kernel[ik].convert(
    measurements=measurements_kxk,
    separate_observables=True,
    bit_packed=False
  )
  det_evts_kxk_all.append(det_evts_kxk)
  flips_kxk_all.append(flips_kxk)

# Stack the per-kernel results into arrays indexed as [kernel][sample][bit]
det_evts_kxk_all = np.array(det_evts_kxk_all, dtype=binary_t)
flips_kxk_all = np.array(flips_kxk_all, dtype=binary_t)
for stacked in (det_evts_kxk_all, flips_kxk_all):
  print(stacked.shape)
del converters_kernel

(9, 10000000, 16)
(9, 10000000, 1)


In [9]:
# Make sure the data type is np.int32 below, not idx_t!
# (idx_t is sized for measurement indices, not for the much larger number of samples.)
sample_idxs = np.arange(n_samples, dtype=np.int32)
idxs_test, idxs_train = split_data(sample_idxs, test_size = n_test/n_samples, seed = 12345, shuffle = False)

class_bits = flips

# Move the sample axis to the front: [kernel][sample][...] -> [sample][kernel][...]
features_det_bits, features_det_evts, features_translation_map = (
  np.swapaxes(per_kernel, 0, 1)
  for per_kernel in (det_bits_kxk_all, det_evts_kxk_all, kernel_result_translation_map)
)
# Dimensions go as [sample][kernel][cycle bits + detections (n_cycles-1)] before being flattened per sample
features_translation_map = features_translation_map.reshape(features_translation_map.shape[0], -1)

# The last (d**2-1)//2 detection events of the full circuit
n_final_det_evts = (d**2-1)//2
features_final_det_evts = det_evts[:, -n_final_det_evts:]

for features in (features_det_bits, features_det_evts, features_translation_map, features_final_det_evts):
  print(features.shape)

(10000000, 9, 16)
(10000000, 9, 16)
(10000000, 9)
(10000000, 12)


In [10]:
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.models import Model
import tensorflow_probability as tfp





In [14]:
class CNNKernel(Layer):
  """
  Trainable response of a single kernel type.

  The layer holds one weight matrix for the ancilla measurement bits of a kernel and,
  optionally, one for its detection events. The two (optionally quadratically expanded)
  inputs are multiplied by their weights, summed, shifted by a bias when the inputs are
  used linearly, and passed through a sigmoid.
  """
  def __init__(
      self,
      kernel_type, kernel_distance, rounds,
      npol=1,
      do_all_data_qubits = False,
      include_det_evts = True,
      n_remove_last_det_evts = 0,
      **kwargs
    ):
    """
    Arguments:
      kernel_type: Pair of parity labels of the kernel; (0, 0) selects the reduced number of outputs.
      kernel_distance: Distance k of the kernel.
      rounds: Number of measurement rounds.
      npol: Inputs are used linearly if npol<=1, otherwise their pairwise products are used.
      do_all_data_qubits: If True, there is one output per data qubit of the kernel
        (k*(k+1)/2 outputs for the (0, 0) kernel); otherwise there is a single output.
      include_det_evts: Whether detection events are a second input of the layer.
      n_remove_last_det_evts: Number of detection events of the final set that the caller omits from the input.
    """
    super(CNNKernel, self).__init__(**kwargs)
    self.kernel_distance = kernel_distance
    self.rounds = rounds
    self.npol = npol
    self.do_all_data_qubits = do_all_data_qubits
    self.include_det_evts = include_det_evts
    self.n_remove_last_det_evts = n_remove_last_det_evts
    self.n_ancillas = (self.kernel_distance**2 - 1)

    constraint_label = f"{kernel_type[0]}_{kernel_type[1]}"
    num_outputs = 1
    if self.do_all_data_qubits:
      if kernel_type[0]==0 and kernel_type[1]==0:
        num_outputs = self.kernel_distance*(self.kernel_distance + 1)//2
      else:
        num_outputs = self.kernel_distance**2

    ndim1 = self.n_ancillas*rounds # Number of ancilla measurements
    ndim2 = 0 # Number of detector events
    if self.include_det_evts:
      ndim2 += self.n_ancillas//2 + self.n_ancillas*(rounds-1) # Number of detector event bits within each round
      ndim2 += self.n_ancillas//2 - self.n_remove_last_det_evts # Final set of detector events that is kept
    if self.npol>1:
      # Pairwise products: one entry per element of a triangular half (including the diagonal)
      ndim1 = ndim1*(ndim1+1)//2
      ndim2 = ndim2*(ndim2+1)//2

    self.ndims = [ [ ndim1, num_outputs ], [ ndim2, num_outputs ] ]
    self.kernel_weights_det_bits = self.add_weight(
      name=f"CNNkernel{self.kernel_distance}_{constraint_label}_w_det_bits",
      shape=self.ndims[0],
      initializer='zeros',
      trainable=True
    )
    self.kernel_weights_det_evts = None
    if self.include_det_evts:
      self.kernel_weights_det_evts = self.add_weight(
        name=f"CNNkernel{self.kernel_distance}_{constraint_label}_w_det_evts",
        shape=self.ndims[1],
        initializer='zeros',
        trainable=True
      )
    # A bias is only used together with linear inputs
    self.kernel_bias = None
    if self.npol<=1:
      self.kernel_bias = self.add_weight(
        name=f"CNNkernel{self.kernel_distance}_{constraint_label}_b",
        shape=[ num_outputs ],
        initializer='zeros',
        trainable=True
      )
    # The activation is applied regardless of npol
    self.kernel_activation = tf.keras.activations.sigmoid

  def build(self, input_shape):
    # All weights are already created in __init__, so there is nothing left to build.
    pass

  def transform_inputs(self, x):
    """
    Turn a [batch, bits] tensor into the float32 features that multiply the weights.

    For npol<=1 the bits are used as they are. Otherwise the outer product of (x+1) with
    itself is formed per sample, reduced to one triangular half, and each product is
    mapped through a quadratic polynomial.
    """
    if self.npol<=1:
      # Same threshold as in __init__, so that features and weights always have matching sizes
      return tf.cast(x, tf.float32)

    # expand_dims (instead of reshaping with x.shape[0]) keeps this valid when the batch size is not known statically
    x_col = tf.cast(tf.expand_dims(x+1, axis=2), tf.int16) # [batch, bits, 1]
    x_row = tf.cast(tf.expand_dims(x+1, axis=1), tf.int16) # [batch, 1, bits]
    res = tfp.math.fill_triangular_inverse(tf.matmul(x_col, x_row))
    res = (-res*res+res*9-14)/6 # (1, 2, 4) -> (-1, 0, 1)
    return tf.cast(res, tf.float32)

  def evaluate(self, bits, do_evts):
    """
    Multiply the transformed bits with the detection event weights if do_evts is True,
    and with the ancilla measurement weights otherwise.
    """
    w = self.kernel_weights_det_evts if do_evts else self.kernel_weights_det_bits
    return tf.matmul(self.transform_inputs(bits), w)

  def call(self, inputs):
    """
    Arguments:
      inputs: Sequence holding the ancilla measurement bits and, if include_det_evts is set,
        the detection events as its second element.

    Returns:
      Tensor with one row per sample and one column per output of the kernel.
    """
    res = self.evaluate(inputs[0], False)
    if self.include_det_evts:
      res = res + self.evaluate(inputs[1], True)
    if self.kernel_bias is not None:
      res = res + self.kernel_bias
    if self.kernel_activation is not None:
      res = self.kernel_activation(res)
    return res
  

class FullCNNModel(Model):
  """
  Decoder model for the full dxd code assembled from shared kxk kernel layers.

  One CNNKernel is created per unique kernel type and evaluated on every kernel of that
  type. The kernel outputs that refer to the same data qubit are combined, using trainable
  fraction and phase parameters, into one value per data qubit. These values pass through
  a bias-free dense layer and are added to the output of the first dense layer acting on
  the translation map (and, optionally, the final detection events). The remaining layers
  then produce the final sigmoid prediction.

  NOTE(review): `Dense` is not imported explicitly in this notebook; presumably one of the
  star imports provides the Keras layer -- confirm.
  """
  def __init__(
      self,
      obs_type, code_distance, kernel_distance, rounds,
      hidden_specs,
      npol = 1,
      do_all_data_qubits = False,
      extended_kernel_output = True,
      include_det_evts = True,
      include_last_kernel_dets = False,
      include_last_dets = True,
      has_nonuniform_response = False,
      **kwargs
    ):
    """
    Arguments:
      obs_type: "ZL" or "XL"; selects which component of a kernel parity and which shift direction are used.
      code_distance: Distance d of the full code.
      kernel_distance: Distance k of the kernels.
      rounds: Number of measurement rounds.
      hidden_specs: Number of nodes of each hidden dense layer; may be empty.
      npol: Forwarded to CNNKernel (linear inputs if npol<=1, pairwise products otherwise).
      do_all_data_qubits: If True, the model has d**2 outputs instead of one.
      extended_kernel_output: If True, kernels produce one output per kernel data qubit and
        these are modified by the transformed translation coefficients.
      include_det_evts: Whether detection events are passed to the kernels.
      include_last_kernel_dets: Whether (part of) the final detection events of a kernel are passed to it.
      include_last_dets: Whether the final detection events of the full circuit are passed to the dense layers.
      has_nonuniform_response: Stored on the model; not used anywhere in this class.
    """
    super(FullCNNModel, self).__init__(**kwargs)
    self.obs_type = obs_type
    self.code_distance = code_distance
    self.kernel_distance = kernel_distance
    self.kernel_half_distance = kernel_distance//2
    self.n_kernel_last_det_evts = (self.kernel_distance**2-1)//2
    self.nshifts = self.code_distance - self.kernel_distance + 1
    self.rounds = rounds
    self.npol = npol
    self.do_all_data_qubits = do_all_data_qubits
    self.extended_kernel_output = extended_kernel_output
    self.include_det_evts = include_det_evts
    self.include_last_kernel_dets = include_last_kernel_dets
    self.include_last_dets = include_last_dets
    self.has_nonuniform_response = has_nonuniform_response

    # One kernel layer per unique kernel type
    self.cnn_kernels = []
    self.unique_kernel_types = get_unique_kernel_types(self.kernel_distance, code_distance)
    for kernel_type in self.unique_kernel_types:
      n_remove_last_dets = 0
      kernel_parity = kernel_type[0]
      if self.include_det_evts:
        if self.include_last_kernel_dets:
          # Must stay in sync with the trimming of the final detection events in call()
          obs_parity = self._observable_parity(kernel_parity)
          if obs_parity==0:
            n_remove_last_dets = 2
          elif obs_parity==1 and self.code_distance>self.kernel_distance:
            n_remove_last_dets = 1
          elif obs_parity==-1:
            n_remove_last_dets = 1
        else:
          n_remove_last_dets = (self.kernel_distance**2-1)//2

      self.cnn_kernels.append(
        CNNKernel(
          kernel_parity,
          self.kernel_distance,
          self.rounds,
          self.npol,
          self.do_all_data_qubits or self.extended_kernel_output,
          self.include_det_evts,
          n_remove_last_dets
        )
      )

    self.translation_coef_transform = Dense(self.nshifts)
    self.translation_coef_transform_act = tf.keras.layers.Activation('sigmoid')

    # For every data qubit of the full code, collect which kernels contribute to it and through
    # which of their outputs. Entries have the form [[kernel type index, qubit index within kernel], [kernel indices]].
    dqubit_kernel_contribs = [ [] for _ in range(self.code_distance**2) ]
    for shifty in range(self.nshifts):
      for shiftx in range(self.nshifts):
        ikernel = shiftx+shifty*self.nshifts
        ktype = None
        is_symmetric = False
        for iktype, kernel_type in enumerate(self.unique_kernel_types):
          if ikernel in kernel_type[1]:
            ktype = iktype
            if kernel_type[0][0]==0 and kernel_type[0][1]==0:
              is_symmetric = True
            break
        _, _, flip_x, flip_y = get_kernel_parity_flips(self.nshifts, shiftx, shifty)
        for ky in range(-self.kernel_half_distance,self.kernel_half_distance+1):
          iy = ky if not flip_y else -ky
          jy = self.kernel_half_distance+iy
          if shifty+jy<0:
            continue
          for kx in range(-self.kernel_half_distance,self.kernel_half_distance+1):
            ix = kx if not flip_x else -kx
            jx = self.kernel_half_distance+ix
            if shiftx+jx<0:
              continue
            ox = kx
            oy = ky
            if is_symmetric and ox<oy:
              # The (0, 0) kernel only has outputs for one triangular half; map the other half onto it
              ox = -ox
              oy = -oy
            idx_kqubit = (oy+self.kernel_half_distance)*self.kernel_distance + (ox+self.kernel_half_distance)
            if is_symmetric:
              # Compact the row-major index so that it only counts the entries of the triangular half
              idx_kqubit = idx_kqubit - (oy+self.kernel_half_distance)*(oy+self.kernel_half_distance+1)//2
            idx_dqubit = (shiftx+jx) + (shifty+jy)*self.code_distance
            found = False
            for dqkcs in dqubit_kernel_contribs[idx_dqubit]:
              if dqkcs[0][0]==ktype and dqkcs[0][1]==idx_kqubit:
                dqkcs[1].append(ikernel)
                found = True
                break
            if not found:
              dqubit_kernel_contribs[idx_dqubit].append([[ ktype, idx_kqubit ], [ ikernel ]]) # Kernel type index, qubit index within kernel
    for dqkcs in dqubit_kernel_contribs:
      dqkcs.sort(key=lambda x: x[0][0]*self.kernel_distance**2 + x[0][1])

    # Group the data qubits that receive the same set of (kernel type, kernel qubit) contributions.
    # Entries have the form [contribution types, [[data qubit index, kernel indices per contribution type], ...]].
    self.unique_dqubit_kernel_contribs = []
    for idq, dqkcs in enumerate(dqubit_kernel_contribs):
      type_contribs = []
      kernel_idxs = []
      for ddd in dqkcs:
        type_contribs.append(ddd[0])
        kernel_idxs.append(ddd[1])
      found = False
      dq_kernelidx_map = [idq, kernel_idxs]
      for udkc in self.unique_dqubit_kernel_contribs:
        if type_contribs==udkc[0]:
          found = True
          udkc[1].append(dq_kernelidx_map)
          break
      if not found:
        self.unique_dqubit_kernel_contribs.append([type_contribs, [ dq_kernelidx_map ]])

    # Every group gets n-1 fraction parameters and n*(n-1)/2 phase parameters for its n contribution types.
    # They are appended to the group entry, which becomes [types, data qubits, fractions, phases].
    total_nfracs = 0
    total_nphases = 0
    self.frac_params = []
    self.phase_params = []
    for iudkc, udkc in enumerate(self.unique_dqubit_kernel_contribs):
      n_contribs = len(udkc[0])
      nfr = n_contribs-1
      nph = n_contribs*(n_contribs-1)//2
      total_nfracs += nfr
      total_nphases += nph
      if nfr>0:
        udkc.append(
          self.add_weight(
            name=f"TranslationFrac_{iudkc}",
            shape=[ nfr ],
            initializer='zeros',
            trainable=True
          )
        )
      else:
        udkc.append(None)
      if nph>0:
        udkc.append(
          self.add_weight(
            name=f"TranslationPhase_{iudkc}",
            shape=[ nph ],
            initializer='zeros',
            trainable=True
          )
        )
      else:
        udkc.append(None)
    print(f"Total number of fractions: {total_nfracs}")
    print(f"Total number of phases: {total_nphases}")
    self.frac_activation = tf.keras.activations.sigmoid # Keeps every fraction between 0 and 1
    self.phase_activation = tf.keras.activations.tanh # We actually need cos(phi), so there is an activation function

    # Dense layers acting on the translation map (and, optionally, the final detection events)
    self.noutputs_final = 1 if not self.do_all_data_qubits else self.code_distance**2
    self.first_dense_nout = None
    self.upper_layers = []
    for hs in hidden_specs:
      if self.first_dense_nout is None:
        self.first_dense_nout = hs
      self.upper_layers.append(Dense(hs))
      self.upper_layers.append(tf.keras.layers.Activation('relu'))

    if self.first_dense_nout is None: # only happens if there are no hidden layers
      self.first_dense_nout = self.noutputs_final
    # Maps the per-data-qubit values onto the output of the first dense layer
    self.data_qubit_pred_eval_layer = Dense(self.first_dense_nout, use_bias=False)
    self.upper_layers.append(Dense(self.noutputs_final))
    self.upper_layers.append(tf.keras.layers.Activation('sigmoid'))

  def _observable_parity(self, kernel_parity):
    """
    Return the component of kernel_parity that is relevant for the observable type:
    the first one for "ZL", the second one for "XL", and None for anything else.
    """
    if self.obs_type=="ZL":
      return kernel_parity[0]
    if self.obs_type=="XL":
      return kernel_parity[1]
    return None

  def eval_final_data_qubit_pred_layer(self, data_qubit_final_preds):
    # We assume data_qubit_final_preds is flat along axis=1
    return self.data_qubit_pred_eval_layer(data_qubit_final_preds)


  def call(self, all_inputs):
    """
    Arguments:
      all_inputs: Sequence of four tensors with the sample index along the first axis:
        ancilla measurement bits [sample][kernel][bit], detection events [sample][kernel][bit],
        flattened translation map [sample][bit], and final detection events of the full circuit [sample][bit].

    Returns:
      Tensor with the sigmoid outputs of the last dense layer, one row per sample.
    """
    det_bits = all_inputs[0]
    det_evts = all_inputs[1]
    translation_coefs = all_inputs[2]
    final_det_evts = all_inputs[3]

    kernel_outputs = dict()
    predictor_inputs = []

    predictor_inputs.append(tf.cast(translation_coefs, tf.float32))
    if self.include_det_evts and self.include_last_dets:
      predictor_inputs.append(tf.cast(final_det_evts, tf.float32))
    predictor_inputs = tf.concat(predictor_inputs, axis=1)

    translation_coefs_transformed = self.translation_coef_transform_act(self.translation_coef_transform(translation_coefs))
    # [sample][shift] -> [sample][shift][1]; expand_dims also works if the batch size is not known statically
    translation_coefs_transformed = tf.expand_dims(translation_coefs_transformed, axis=2)
    if self.extended_kernel_output:
      translation_coefs_transformed = tf.repeat(translation_coefs_transformed, self.kernel_distance**2, axis=2)

    # Evaluate every kernel with the layer of its type
    for i, cnn_kernel in enumerate(self.cnn_kernels):
      kernel_parity = self.unique_kernel_types[i][0]
      kernel_idxs = self.unique_kernel_types[i][1]
      obs_parity = self._observable_parity(kernel_parity)
      for k in kernel_idxs:
        det_bits_kernel = det_bits[:,k]
        if self.include_det_evts:
          det_evts_kernel = det_evts[:,k,0:-self.n_kernel_last_det_evts]
          if self.include_last_kernel_dets:
            # Must stay in sync with n_remove_last_dets in __init__
            det_evts_kernel_end = det_evts[:,k,-self.n_kernel_last_det_evts:]
            if obs_parity==0:
              det_evts_kernel_end = det_evts_kernel_end[:,1:-1]
            elif obs_parity==1 and self.code_distance>self.kernel_distance:
              det_evts_kernel_end = det_evts_kernel_end[:,:-1]
            elif obs_parity==-1:
              det_evts_kernel_end = det_evts_kernel_end[:,1:]
            det_evts_kernel = tf.concat([det_evts_kernel, det_evts_kernel_end], axis=1)
          kernel_input = [ det_bits_kernel, det_evts_kernel ]
        else:
          kernel_input = [ det_bits_kernel ]
        kernel_output = cnn_kernel(kernel_input)
        if self.extended_kernel_output:
          shift_x = k % self.nshifts
          shift_y = k // self.nshifts
          k_shift = shift_y if self.obs_type=="ZL" else shift_x
          # Interpolates between p (coefficient 0) and 1-p (coefficient 1) as p^(1-t) * (1-p)^t
          kernel_output = tf.math.pow(
              kernel_output,
              1.-translation_coefs_transformed[:,k_shift,0:kernel_output.shape[1]]
            )*tf.math.pow(
              1.-kernel_output,
              translation_coefs_transformed[:,k_shift,0:kernel_output.shape[1]]
            )
        kernel_outputs[k] = [i, kernel_output/(1.-kernel_output)] # [Kernel unique type index, transformed kernel output]

    # Combine the kernel outputs into one value per data qubit
    data_qubit_idxs_preds = []
    for udkc in self.unique_dqubit_kernel_contribs:
      kernel_type_contribs = udkc[0]
      data_qubit_idxs = udkc[1]
      frac_params = udkc[2]
      phase_params = udkc[3]
      phase_values = None
      if phase_params is not None:
        phase_values = self.phase_activation(phase_params)

      # Weight of every contribution type: f_0, (1-f_0)*f_1, ..., (1-f_0)*...*(1-f_last).
      # It depends neither on the data qubit nor on the kernel, so it is computed once per group.
      type_fracs = None
      if frac_params is not None:
        frac_values = self.frac_activation(frac_params)
        type_fracs = []
        for iktype in range(len(kernel_type_contribs)):
          frac = None
          for ifrac in range(min(frac_params.shape[0],iktype+1)):
            frac_tmp = frac_values[ifrac]
            if ifrac!=iktype:
              frac_tmp = 1.-frac_tmp
            frac = frac_tmp if frac is None else frac*frac_tmp
          type_fracs.append(frac)

      for idq_idkqs in data_qubit_idxs:
        idq = idq_idkqs[0]
        idkqs = idq_idkqs[1]
        sum_kouts = None
        sum_inputs = []
        for iktype, idkq in enumerate(idkqs):
          ktype = kernel_type_contribs[iktype]
          # Sum the outputs of all kernels that contribute through the same type and kernel qubit
          kout = None
          for ikq in idkq:
            contrib = kernel_outputs[ikq][1][:,ktype[1]]
            kout = contrib if kout is None else kout + contrib
          # Apply the weight once to the completed sum. Applying it inside the loop above would scale
          # the earlier contributions repeatedly and make the result depend on the order of the kernels.
          if type_fracs is not None:
            kout = kout*type_fracs[iktype]
          if sum_kouts is None:
            sum_kouts = kout
          else:
            sum_kouts = sum_kouts + kout
          sum_inputs.append(kout)
        n_sum_inputs = len(sum_inputs)
        if phase_params is not None:
          if n_sum_inputs*(n_sum_inputs-1)//2!=phase_params.shape[0]:
            raise RuntimeError(f"Number of phase parameters {phase_params.shape[0]} does not match the number of inputs {n_sum_inputs}.")
          # Cross term 2*sqrt(a*b)*cos(phi) for every pair of contributions
          iphase = 0
          for idx_i1 in range(n_sum_inputs):
            for idx_i2 in range(idx_i1+1, n_sum_inputs):
              cos_phase = phase_values[iphase]
              iphase += 1
              sum_kouts = sum_kouts + 2*tf.sqrt(sum_inputs[idx_i1]*sum_inputs[idx_i2])*cos_phase
        sum_kouts = tf.math.log(sum_kouts)
        data_qubit_idxs_preds.append([idq, sum_kouts])
    # Order by data qubit index; the key keeps the tensors out of the comparison
    data_qubit_idxs_preds.sort(key=lambda dqp: dqp[0])

    # Every entry holds one value per sample; place them side by side as [sample][data qubit]
    data_qubit_final_preds = tf.concat(
      [ tf.reshape(dqp[1], shape=(-1,1)) for dqp in data_qubit_idxs_preds ],
      axis=1
    )

    x = predictor_inputs
    for ilayer, ll in enumerate(self.upper_layers):
      x = ll(x)
      if ilayer==0:
        # The data qubit values enter after the first dense layer, before its activation
        x = x + self.eval_final_data_qubit_pred_layer(data_qubit_final_preds)
    return x

In [15]:
def learning_rate_scheduler(epoch, lr):
  """
  Return the learning rate to use for the given epoch.

  The rate passed in is returned unchanged during the first 10 epochs, and is then
  multiplied by 0.99 (epochs 10-19), 0.95 (epochs 20-29) or 0.9 (epoch 30 onwards).
  """
  # (first epoch that is no longer covered, factor applied to lr); None leaves lr untouched
  decay_schedule = ((10, None), (20, 0.99), (30, 0.95))
  for end_epoch, factor in decay_schedule:
    if epoch < end_epoch:
      return lr if factor is None else lr * factor
  return lr * 0.9

In [None]:
def gather_model_inputs(sample_idxs):
  """Return the four model inputs restricted to the given samples (index array or slice)."""
  return [
    features_det_bits[sample_idxs],
    features_det_evts[sample_idxs],
    features_translation_map[sample_idxs],
    features_final_det_evts[sample_idxs]
  ]


for n_nodes in range(100, 150, 50):
  model_dxd = FullCNNModel(
    observable_type, d, kernel_size, r,
    [n_nodes]*2, # Two hidden layers with n_nodes nodes each
    npol=2,
    do_all_data_qubits=False,
    extended_kernel_output=True,
    include_det_evts=True,
    include_last_kernel_dets=False,
    include_last_dets=True
  )
  model_dxd.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
  # Call the model once on a single sample so that its layers are built before the summary is printed
  model_dxd(gather_model_inputs(slice(0, 1)))
  model_dxd.summary()

  val_split = 0.2
  n_epochs = 50
  training_callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    tf.keras.callbacks.LearningRateScheduler(learning_rate_scheduler)
  ]
  history = model_dxd.fit(
    x=gather_model_inputs(idxs_train),
    y=class_bits[idxs_train,:],
    batch_size=10000,
    epochs=n_epochs,
    validation_split=val_split,
    callbacks=training_callbacks
  )

  flips_pred = model_dxd.predict(gather_model_inputs(idxs_test), batch_size=10000)
  # Count the test samples whose thresholded prediction differs from the true flip
  n_mispredicted = (flips[idxs_test]!=(flips_pred>0.5).astype(binary_t)).astype(binary_t).sum()
  print(f"Inaccuracy of the final model on the test data: {n_mispredicted/idxs_test.shape[0]}")

Total number of fractions: 31
Total number of phases: 68
Model: "full_cnn_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cnn_kernel (CNNKernel)      multiple                  1926      
                                                                 
 cnn_kernel_1 (CNNKernel)    multiple                  1926      
                                                                 
 cnn_kernel_2 (CNNKernel)    multiple                  1926      
                                                                 
 cnn_kernel_3 (CNNKernel)    multiple                  1926      
                                                                 
 cnn_kernel_4 (CNNKernel)    multiple                  1284      
                                                                 
 dense (Dense)               multiple                  30        
                                                             

In [None]:
# Reference decoder: PyMatching on the detection events of the test samples
pymatcher = pymatching.Matching.from_detector_error_model(detector_error_model)
flips_pred_pym = pymatcher.decode_batch(
  det_evts[idxs_test,:],
  bit_packed_predictions=False,
  bit_packed_shots=False
)
# One column, so that the predictions line up with flips
flips_pred_pym = flips_pred_pym.astype(binary_t).reshape(-1,1)
n_mispredicted_pym = np.sum((flips_pred_pym!=flips[idxs_test,:]))
print(f"PyMatching error rate for test data set: {n_mispredicted_pym/idxs_test.shape[0]}")

PyMatching error rate for test data set: 0.0316026
