In [46]:
# reload magic
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [47]:
import numpy as np
import torch
from mldec.models import baselines
from mldec.utils import evaluation
from mldec.datasets import reps_toric_code_data
import stim


In [48]:

    # data_val, triv_val, stim_data_val, observable_flips_val = reps_toric_code_data.sample_dataset(n_data, validation_dataset_config, device)
    # TIMINGS for repetitions=9:
    #  - 1e5 data: 16s
    #  - 1e6 data: 160s (2:40)
    # TIMINGS for repetitions=3:
    #  - 1e5 data: 9s
    #  - 1e6 data: 90s (1:30)

In [49]:
import pymatching

In [50]:
def count_logical_errors(circuit: stim.Circuit, num_shots: int) -> int:
    """Sample ``circuit`` and count the shots that PyMatching decodes incorrectly.

    Args:
        circuit: Stim circuit to sample; its detector error model is also used
            to configure the decoder.
        num_shots: Number of shots to sample.

    Returns:
        The number of shots for which the decoder's predicted observable flips
        differ from the sampled observable flips in at least one observable.
    """
    # Sample detection events together with the true observable flips.
    sampler = circuit.compile_detector_sampler()
    detection_events, observable_flips = sampler.sample(num_shots, separate_observables=True)

    # Configure a matching decoder from the circuit's detector error model.
    detector_error_model = circuit.detector_error_model(decompose_errors=True)
    matcher = pymatching.Matching.from_detector_error_model(detector_error_model)

    # Decode all shots in one batch.
    predictions = matcher.decode_batch(detection_events)

    # A shot is a logical error when any of its observables is mispredicted.
    # Compared row-wise in NumPy instead of a Python loop over every shot.
    mismatched = np.asarray(predictions) != np.asarray(observable_flips)
    return int(np.count_nonzero(mismatched.any(axis=1)))

In [52]:
# Estimate the PyMatching logical error rate on a generated rotated
# surface-code memory experiment for increasing numbers of shots.
validation_dataset_config = {
    'p': 0.004,
    'repetitions': 3,
    'code_size': 3,
    'beta': 1,
}

# The circuit depends only on the config, so it is built once rather than once
# per shot count. Keys are indexed directly so that a missing entry fails here
# instead of silently handing None to stim.
p_noise = validation_dataset_config['p']
circuit = stim.Circuit.generated(
    "surface_code:rotated_memory_z",
    rounds=validation_dataset_config['repetitions'],
    distance=validation_dataset_config['code_size'],
    after_clifford_depolarization=p_noise,
    after_reset_flip_probability=p_noise,
    before_measure_flip_probability=p_noise,
    before_round_data_depolarization=p_noise,
)

for n_test in [1e4, 1e5, 1e6, 1e7]:
    n_errs = count_logical_errors(circuit, int(n_test))
    # Fraction of shots that were decoded incorrectly.
    print(n_errs / n_test)

0.0093
0.01114
0.011199
0.0113531


In [45]:
# Compare PyMatching against the project's CyclesMinimumWeightPerfectMatching
# baseline on shots drawn from the project's own sampler.
validation_dataset_config = {
    'p': 0.001,
    'repetitions': 3,
    'code_size': 3,
    'beta': 1,
}
device = torch.device("cpu")
sampler, detector_coordinates, detector_error_model = reps_toric_code_data.make_sampler(validation_dataset_config)

# The decoders use the detector error model returned alongside `sampler`.
# It is deliberately NOT re-derived from the global `circuit` of another cell:
# that circuit is generated with a different `p`, so a model built from it
# would not describe the noise being sampled here.
# NOTE(review): assumes make_sampler returns a model that PyMatching accepts
# (i.e. decomposed into graphlike errors) — confirm in reps_toric_code_data.
matcher = pymatching.Matching.from_detector_error_model(detector_error_model)

for n_test in [1e4, 1e5, 1e6,]:
    print(n_test)
    n_test = int(n_test)

    # sample detection events and observable flips
    stim_data, observable_flips = sampler.sample(shots=n_test, separate_observables=True)

    # Shots with an all-zero syndrome are set aside; they are counted as
    # correctly decoded in the accuracy printed at the end of the iteration.
    non_empty_indices = (np.sum(stim_data, axis=1) != 0)
    triv_val = int(np.count_nonzero(~non_empty_indices))
    stim_data_val = stim_data[non_empty_indices, :]
    observable_flips_val = observable_flips[non_empty_indices]

    # Run the PyMatching decoder on the non-trivial shots.
    predictions = matcher.decode_batch(stim_data_val)

    # Count the shots with at least one mispredicted observable; the rate is
    # reported relative to ALL sampled shots, not only the non-trivial ones.
    num_errors = int(np.count_nonzero(np.any(predictions != observable_flips_val, axis=1)))
    print(num_errors / n_test)

    # Same evaluation with the project's baseline decoder.
    mwpm_decoder = baselines.CyclesMinimumWeightPerfectMatching(detector_error_model)
    minimum_weight_correct = evaluation.evaluate_mwpm(stim_data_val, observable_flips_val, mwpm_decoder)
    print(1 - (minimum_weight_correct + triv_val) / n_test)

    # minimum_weight_val_acc = (minimum_weight_val_acc_nontrivial + triv_val) / n_test
    # print("mwpm logical err rate: {}".format(1 - minimum_weight_val_acc))
    # print("nontrivial mwpm err rate: {}".format(1 - minimum_weight_val_acc_nontrivial / (n_test - triv_val)))
    # print()
    # print("nontrivial mwpm acc: {}".format(minimum_weight_val_acc_nontrivial / (n_test - triv_val)))
    # print("mwpm logical acc: {}".format(minimum_weight_val_acc))
    # print()


10000.0
0.0006
[0.0006]
100000.0
0.00063
[0.00063]
1000000.0
0.00081
[0.00081]


In [42]:
# Number of non-trivial (non-zero syndrome) shots currently held in
# `stim_data_val`, as left behind by whichever cell last assigned it.
len(stim_data_val)

157488

In [44]:
# Cross-check two ways of counting mispredicted shots for the baseline decoder:
# an array-level sum and an explicit per-shot tally. Both should agree.
predictions = mwpm_decoder.predict(stim_data_val)
mismatch_mask = observable_flips_val != predictions
print(sum(mismatch_mask))

# Per-shot tally over the same comparison.
j = 0
for i in range(len(stim_data_val)):
    j += 1 if mismatch_mask[i] else 0
print(j )



[789]
789


mwpm logical err rate: 0.00039999999999995595
nontrivial mwpm err rate: 0.00039999999999995595

nontrivial mwpm acc: 0.9996
mwpm logical acc: 0.9996


In [35]:
# Check what kind of object the accuracy helper returned for the
# minimum-weight decoder (relies on `minimum_weight_val_acc` from another cell).
type(minimum_weight_val_acc)

float

In [10]:
# Draw `n` values from a normal distribution centred on `p`.
np.random.seed(222)  # fixed seed so the printed sample is reproducible
p, var, n = 0.01, 0.01, 9
# NOTE: np.random.normal takes a standard deviation as its `scale` argument,
# so `var` is used as a standard deviation here despite its name.
p_samp = np.random.normal(loc=p, scale=var, size=n)
print(p_samp)

[0.02963425 0.0127577  0.01458658 0.02001265 0.00236165 0.01721928
 0.0009453  0.02001873 0.00479258]


In [21]:
from mldec.datasets import toric_code_data
# sample virtual XY
config = {'n': 9, 'var': 0.03, 'p': 0.05, 'beta': 1.75}
# The system size is read from `config` instead of an ambient global `n`,
# which other cells rebind to different values; this keeps the size argument
# and the config in agreement regardless of execution order.
X, Y, probs = toric_code_data.create_dataset_training(config['n'], config, cache=True)
# NOTE(review): 1994 is presumably the number of virtual samples to draw —
# confirm against toric_code_data.sample_virtual_XY.
Xb, Yb, weightsb, histb = toric_code_data.sample_virtual_XY(probs.numpy(), 1994, config['n'], config, cache=True)


  Xb_tensor = torch.tensor(X_full, dtype=torch.float32)
  Yb_tensor = torch.tensor(Y_full, dtype=torch.float32)


In [22]:
# Number of rows in the virtually sampled inputs `Xb`.
len(Xb)

1024

In [None]:

from mldec.datasets import toy_problem_data
import torch
from mldec.utils import evaluation

In [3]:
import numpy as np

# Fifteen length-4 rows: five filled with 0.1, five with 0.2, five with 0.3.
rows = [level * np.ones(4) for level in (0.1, 0.2, 0.3) for _ in range(5)]
# Walk the stack column by column (equivalent to zipping the rows together),
# so the 15-value pattern 0.1 x5, 0.2 x5, 0.3 x5 repeats once per column.
observable_flips_list = list(np.stack(rows).T.ravel())
print(observable_flips_list)


[0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3, 0.3, 0.3]


In [142]:
# Toy-problem setup: build the full weighted dataset, draw a virtual sample of
# `batch_size` from it, and fetch the uniform "good examples" reference set.
n = 8
batch_size = 10000
dataset_config = dict(
    p=0.15,
    alpha=0.33,
    pcm=toy_problem_data.repetition_pcm(n),
    sos_eos=(0, 0),
)

# Full dataset with its weights; the sampler below takes the weights as NumPy.
X, Y, weights = toy_problem_data.create_dataset_training(n, dataset_config)
weights_np = weights.numpy()

# Virtual sample; the down-sampled weights are also kept as a float32 tensor
# for the accuracy evaluation in a later cell.
virtual_sample = toy_problem_data.sample_virtual_XY(weights_np, batch_size, n, dataset_config)
Xb, Yb, weightsb, downsampled_weights = virtual_sample
downsampled_weights_tensor = torch.tensor(downsampled_weights, dtype=torch.float32)

Xgood, Ygood, weightsgood = toy_problem_data.uniform_over_good_examples(n, dataset_config)

In [None]:
from mldec.models.baselines import RepetitionCodeLookupTable, RepetitionCodeMinimumWeight



# Reference lookup-table decoder trained on the uniform "good examples" set;
# a later cell reports its accuracy as the "optimal" baseline.
# Relies on `n`, `Xgood`, `Ygood`, `weightsgood` from the dataset-setup cell.
mld = RepetitionCodeLookupTable(n)
mld.train_on_histogram(Xgood, Ygood, weightsgood)

# Minimum-weight baseline decoder, built from the full dataset X, Y.
minimum_weight_decoder = RepetitionCodeMinimumWeight(n)
minimum_weight_decoder.make_decoder(X, Y)

In [153]:
# Shape sanity check of the sampled batch against the full dataset.
print(Xb.shape)
# NOTE(review): `Ypred` is not assigned in any visible cell, so this line
# depends on leftover kernel state and will raise NameError on a fresh
# Restart & Run All — define it here or drop it from the print.
print(X.shape, Y.shape, Ypred.shape)

torch.Size([120, 7])
torch.Size([256, 7]) torch.Size([256, 10]) torch.Size([120, 10])


In [162]:
# Lookup-table decoder trained on the full input/label set X, Y, but weighted
# by the down-sampled weights returned by sample_virtual_XY.
lookup = RepetitionCodeLookupTable(n)
lookup.train_on_histogram(X, Y, downsampled_weights)


In [163]:


# Every decoder is scored on the same X, Y; only the weighting differs.
weighted_acc = lambda decoder, w: evaluation.weighted_accuracy(decoder, X, Y, w)

# Training accuracy: weighted by the down-sampled weights the lookup table was trained on.
train_acc = weighted_acc(lookup, downsampled_weights_tensor)
# Validation accuracy: the same decoder weighted by the full-dataset weights.
val_acc = weighted_acc(lookup, weights)
# Optimal reference: the lookup table trained on the good examples.
opt_val_acc = weighted_acc(mld, weights)
# Minimum-weight baseline on the full-dataset weights.
minimum_weight_val_acc = weighted_acc(minimum_weight_decoder, weights)

for label, accuracy in (
    ("lookup train:", train_acc),
    ("lookup test:", val_acc),
    ("minimum weight:", minimum_weight_val_acc),
    ("optimal:", opt_val_acc),
):
    print(label, accuracy)


lookup train: 0.9994999170303345
lookup test: 0.9970260858535767
minimum weight: 0.997134804725647
optimal: 0.9988806843757629
