# Introduction to AI for Quantum Error Correction

## Overview
This notebook demonstrates quantum error correction using the 3-qubit bit flip code with realistic noise and then explores various avenues where AI can be useful. 

Please note that this notebook is a work-in-progress. You can check it periodically over the next 2 weeks. You should familiarize yourself with the basics of Quantum Error Correction - some good resources are qBook (qbook.qbraid.com) or Coursera. You can also check out NVIDIA's QEC work [here](https://github.com/NVIDIA/cuda-q-academic/tree/main/qec101). 

Now, let's start with a hands-on example with the 3-qubit bit flip code.

## The 3-Qubit Bit Flip Code
- **Logical states**: |0⟩ → |000⟩, |1⟩ → |111⟩  
- **Stabilizers**: g₁ = ZZI, g₂ = IZZ
- **Error correction**: Can fix any single bit flip error

## Syndrome Table
| Syndrome (s₁s₂) | Error Location | Correction |
|-----------------|----------------|------------|
| 00              | None           | Do nothing |
| 10              | Qubit 0        | Apply X₀   |
| 11              | Qubit 1        | Apply X₁   |
| 01              | Qubit 2        | Apply X₂   |

Here s₁ (s₂) is 1 when g₁ (g₂) measures −1 and 0 when it measures +1.
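The syndrome table can be reproduced from the stabilizers alone. Writing g₁ = ZZI and g₂ = IZZ as the rows of a binary parity-check matrix H (a 1 wherever the stabilizer has a Z), a bit-flip pattern e produces the syndrome s = He mod 2. A minimal NumPy sketch; the names `H` and `syndrome` are illustrative, not from any library:

```python
import numpy as np

# Rows are the stabilizers g1 = ZZI and g2 = IZZ; a 1 marks a qubit the Z acts on.
H = np.array([[1, 1, 0],
              [0, 1, 1]])

def syndrome(error):
    """Return the syndrome bits (s1, s2) for a length-3 bit-flip pattern."""
    return tuple(int(b) for b in H @ np.asarray(error) % 2)

for label, error in [("None", [0, 0, 0]), ("Qubit 0", [1, 0, 0]),
                     ("Qubit 1", [0, 1, 0]), ("Qubit 2", [0, 0, 1])]:
    s1, s2 = syndrome(error)
    print(f"{s1}{s2}  {label}")
```

Note that the all-ones pattern [1, 1, 1] also gives syndrome 00: it is the logical X operator, which this code cannot detect.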




In [None]:
import numpy as np
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit_aer import Aer
from qiskit_aer.noise import NoiseModel, depolarizing_error
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
import matplotlib.pyplot as plt

# =====================================================
# Decoder Implementation
# =====================================================

class BitFlipDecoder:
    """Simple lookup table decoder for 3-qubit bit flip code.

    Qiskit prints classical bits little-endian: the string reads
    'syndrome[1] syndrome[0]', i.e. 's2 s1' in the notation of the syndrome
    table above. The keys below use Qiskit's order.
    """
    
    def __init__(self):
        self.syndrome_table = {
            '00': None,  # No correction needed
            '01': 0,     # s1=1, s2=0: correct qubit 0
            '11': 1,     # s1=1, s2=1: correct qubit 1
            '10': 2      # s1=0, s2=1: correct qubit 2
        }
    
    def decode(self, syndrome_string):
        """Return which qubit to correct (None if no correction needed)."""
        return self.syndrome_table.get(syndrome_string, None)
    
    def get_error_description(self, syndrome_string):
        """Get human-readable error description (keys in Qiskit bit order)."""
        corrections = {
            '00': 'No error detected',
            '01': 'Error on qubit 0',
            '11': 'Error on qubit 1',
            '10': 'Error on qubit 2'
        }
        return corrections.get(syndrome_string, 'Unknown syndrome')

# =====================================================
# Circuit Building Functions
# =====================================================

def create_error_correction_circuit(logical_state=0, introduce_error=None):
    """
    Create a complete error correction circuit.
    
    Args:
        logical_state (int): 0 or 1 for logical |0⟩ or |1⟩
        introduce_error (int or None): Which qubit to flip (for testing)
    
    Returns:
        QuantumCircuit: Complete circuit with encoding, error, syndrome measurement
    """
    # Registers
    data = QuantumRegister(3, 'data')
    ancilla = QuantumRegister(2, 'ancilla')
    syndrome = ClassicalRegister(2, 'syndrome')
    
    qc = QuantumCircuit(data, ancilla, syndrome)
    
    # Encode logical state
    if logical_state == 1:
        qc.x(data[0])
        qc.x(data[1])
        qc.x(data[2])
    
    # Introduce test error
    if introduce_error is not None:
        qc.x(data[introduce_error])
    
    # Measure stabilizers
    # g1 = ZZI
    qc.cx(data[0], ancilla[0])
    qc.cx(data[1], ancilla[0])
    
    # g2 = IZZ
    qc.cx(data[1], ancilla[1])
    qc.cx(data[2], ancilla[1])
    
    qc.measure(ancilla, syndrome)
    
    return qc

def create_noisy_circuit(logical_state=0, noise_prob=0.01):
    """
    Create circuit with realistic depolarizing noise.
    
    Args:
        logical_state (int): 0 or 1
        noise_prob (float): Probability of depolarizing error per gate
    
    Returns:
        tuple: (circuit, noise_model)
    """
    data = QuantumRegister(3, 'data')
    ancilla = QuantumRegister(2, 'ancilla')
    syndrome = ClassicalRegister(2, 'syndrome')
    
    qc = QuantumCircuit(data, ancilla, syndrome)
    
    # Encode logical state
    if logical_state == 1:
        qc.x(data[0])
        qc.x(data[1]) 
        qc.x(data[2])
    
    # Add barriers to separate encoding from syndrome measurement
    qc.barrier()
    
    # Measure stabilizers
    qc.cx(data[0], ancilla[0])
    qc.cx(data[1], ancilla[0])
    qc.cx(data[1], ancilla[1])
    qc.cx(data[2], ancilla[1])
    qc.measure(ancilla, syndrome)
    
    # No correction gates are added: this demo only analyzes the syndrome distribution.
    # (On hardware, X gates would be conditioned on the measured syndrome bits.)
    qc.barrier()
    
    # Create noise model (gate errors only; measurement errors are not modeled)
    noise_model = NoiseModel()
    
    # Add depolarizing error to X gates (the only single-qubit gates in this circuit)
    error_1q = depolarizing_error(noise_prob, 1)
    noise_model.add_all_qubit_quantum_error(error_1q, ['x'])
    
    # Add depolarizing error to CNOT gates  
    error_2q = depolarizing_error(noise_prob, 2)
    noise_model.add_all_qubit_quantum_error(error_2q, ['cx'])
    
    return qc, noise_model

# =====================================================
# Testing and Analysis Functions
# =====================================================

def test_decoder_without_noise():
    """Test decoder on all possible single-bit errors without noise."""
    decoder = BitFlipDecoder()
    
    print("Testing decoder without noise:")
    print("=" * 40)
    
    results = {}
    
    for error_pos in [None, 0, 1, 2]:
        qc = create_error_correction_circuit(logical_state=0, introduce_error=error_pos)
        
        # Run simulation
        simulator = Aer.get_backend('aer_simulator')
        pm = generate_preset_pass_manager(backend=simulator, optimization_level=1)
        transpiled_qc = pm.run(qc)
        job = simulator.run(transpiled_qc, shots=1000)
        result = job.result()
        counts = result.get_counts()
        
        # Get most common syndrome
        syndrome = max(counts.keys(), key=counts.get)
        correction = decoder.decode(syndrome)
        description = decoder.get_error_description(syndrome)
        
        results[error_pos] = {
            'syndrome': syndrome,
            'correction': correction,
            'description': description,
            'counts': counts
        }
        
        if error_pos is None:
            print(f"No error:     syndrome {syndrome} → {description}")
        else:
            print(f"Error on q{error_pos}: syndrome {syndrome} → {description}")
    
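    all_correct = all(r['correction'] == pos for pos, r in results.items())
    if all_correct:
        print("\nAll single errors decoded correctly!\n")
    else:
        print("\nSome syndromes were decoded incorrectly; check the bit ordering.\n")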
    return results

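def test_decoder_with_noise(noise_probs=(0.001, 0.01, 0.05, 0.1)):
    """Measure how often a non-trivial syndrome appears under depolarizing gate noise.

    This tracks syndrome statistics only: no correction is applied and the data
    qubits are not measured, so the result is not a logical error rate.
    """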
    
    print("Testing decoder with depolarizing noise:")
    print("=" * 50)
    
    results = {}
    
    for noise_prob in noise_probs:
        print(f"\nNoise probability: {noise_prob}")
        print("-" * 30)
        
        # Test both logical states
        for logical_state in [0, 1]:
            qc, noise_model = create_noisy_circuit(logical_state, noise_prob)
            
            # Run noisy simulation
            simulator = Aer.get_backend('aer_simulator')
            pm = generate_preset_pass_manager(backend=simulator, optimization_level=1)
            transpiled_qc = pm.run(qc)
            job = simulator.run(transpiled_qc, shots=1000, noise_model=noise_model)
            result = job.result()
            counts = result.get_counts()
            
            # Analyze syndrome distribution
            syndrome_counts = {}
            for bitstring, count in counts.items():
                syndrome = bitstring[-2:]  # Last 2 bits are syndrome
                syndrome_counts[syndrome] = syndrome_counts.get(syndrome, 0) + count
            
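            # Fraction of shots with a non-trivial syndrome (a detection rate,
            # not a logical error rate: no correction is applied here)
            total_shots = sum(syndrome_counts.values())
            no_error_rate = syndrome_counts.get('00', 0) / total_shots
            error_rate = 1 - no_error_rate
            
            print(f"  Logical |{logical_state}⟩: {error_rate:.1%} of shots with a non-trivial syndrome")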
            
            # Store detailed results
            key = (noise_prob, logical_state)
            results[key] = {
                'syndrome_counts': syndrome_counts,
                'error_rate': error_rate,
                'total_shots': total_shots
            }
    
    return results

def analyze_noise_performance(results):
    """Summarize syndrome statistics vs noise level with a text-based plot."""
    print("\nNoise Performance Analysis:")
    print("=" * 40)
    
    noise_levels = sorted({noise_prob for noise_prob, _ in results})
    error_rates_0 = []
    error_rates_1 = []
    
    for noise_prob in noise_levels:
        rate_0 = results[(noise_prob, 0)]['error_rate']
        rate_1 = results[(noise_prob, 1)]['error_rate']
        error_rates_0.append(rate_0)
        error_rates_1.append(rate_1)
        
        print(f"Noise {noise_prob:5.1%}: non-trivial syndromes |0⟩ {rate_0:.1%}, |1⟩ {rate_1:.1%}")
    
    # Create simple text-based visualization
    print("\nNon-trivial Syndrome Rate vs Noise Level (logical |0⟩):")
    print("(Each * represents ~5% of shots)")
    for i, noise_prob in enumerate(noise_levels):
        stars = int(error_rates_0[i] * 20)  # Scale for visualization
        print(f"{noise_prob:5.1%} |{'*' * stars}")
    
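    print("\n💡 Observations:")
    print("- The non-trivial syndrome rate should grow with the gate error probability")
    print("- Both logical states should give similar rates, since the CNOT errors dominate")
    print("- Syndrome rates alone do not show logical protection; that needs data-qubit measurement after correction")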

# =====================================================
# Main Execution
# =====================================================

if __name__ == "__main__":
    # Test decoder without noise
    clean_results = test_decoder_without_noise()
    
    # Test decoder with noise
    noisy_results = test_decoder_with_noise()
    
    # Analyze performance
    analyze_noise_performance(noisy_results)

So with this example:

✅ **WHAT WE LEARNED:**

• Built a working stabilizer decoder using lookup tables

• Tested error detection without noise (perfect performance)

• Simulated realistic depolarizing noise on quantum gates

• Measured how often non-trivial syndromes appear as gate noise increases

🔬 **KEY INSIGHTS:**

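• Under independent bit-flip noise, the code helps when the physical error rate is low, because a logical error needs at least two flips

• As noise increases, more shots show non-trivial syndromes, and multi-qubit errors start to fool the lookup decoder

• Both logical |0⟩ and |1⟩ show similar syndrome rates

• The 3-qubit code only corrects bit flips, not phase flips, so fault-tolerant quantum computers need codes that handle both (and larger distances)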
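To make the "low noise" point concrete, consider a simpler model than the circuit above (an assumption for this sketch): each data qubit flips independently with probability p, syndrome measurement is perfect, and the lookup decoder is applied. Decoding fails exactly when two or three qubits flip, so p_L = 3p²(1 − p) + p³ = 3p² − 2p³, which is below p whenever p < 1/2. The sketch checks this formula against a Monte Carlo estimate; the function names are illustrative.

```python
import numpy as np

H = np.array([[1, 1, 0], [0, 1, 1]])  # g1 = ZZI, g2 = IZZ
# Correction for syndrome index 2*s1 + s2: 00 -> none, 01 -> X2, 10 -> X0, 11 -> X1
CORRECTION = np.array([[0, 0, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]])

def exact_logical_error_rate(p):
    """Lookup decoding fails iff two or three qubits flip: 3p^2(1-p) + p^3."""
    return 3 * p**2 - 2 * p**3

def simulated_logical_error_rate(p, shots=100_000, seed=0):
    """Monte Carlo estimate under i.i.d. bit flips with perfect syndrome measurement."""
    rng = np.random.default_rng(seed)
    errors = (rng.random((shots, 3)) < p).astype(int)
    syndromes = errors @ H.T % 2
    residual = (errors + CORRECTION[2 * syndromes[:, 0] + syndromes[:, 1]]) % 2
    # After correction the residual is either 000 (success) or 111 (a logical X)
    return residual.all(axis=1).mean()

for p in (0.01, 0.05, 0.1, 0.3):
    print(f"p = {p:.2f}: exact p_L = {exact_logical_error_rate(p):.4f}, "
          f"simulated p_L = {simulated_logical_error_rate(p):.4f}")
```

The p < 1/2 break-even point holds only for this idealized model, not for the circuit-level noise simulated earlier.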

🚀 **NEXT STEPS:**

• Try the 7-qubit Steane code, which corrects any single-qubit error (bit flips and phase flips)

• Implement maximum likelihood decoding for better performance

• Study surface codes for scalable quantum computing

## Why AI for Decoding?

If this works so well, why explore AI-based decoders? Lookup-table decoders work well for small codes like our 3-qubit example, but they hit fundamental limits:

1. Exponential scaling: A lookup table needs one entry per possible syndrome, i.e. 2^(number of stabilizers) entries. Our 3-qubit code has 2 stabilizers and 4 entries, but a distance-d rotated surface code has d² − 1 stabilizers, so at d = 17 a single round already has 2^288 possible syndromes, far beyond what any table can store.

2. Complex noise: Real quantum hardware has correlated errors, crosstalk, and time-dependent noise that simple models can't capture.

3. Real-time constraints: Quantum states decohere quickly, so decoders must keep pace with repeated rounds of syndrome measurement, running in microseconds rather than milliseconds.
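A quick back-of-the-envelope check of that scaling, assuming the rotated surface-code layout (d² data qubits and d² − 1 stabilizers). It counts a single round only; real decoders usually see several rounds at once, which makes the input larger still.

```python
def syndrome_bits(d):
    """Stabilizer count of a distance-d rotated surface code (d*d data qubits)."""
    return d * d - 1

for d in (3, 5, 7, 11, 17):
    bits = syndrome_bits(d)
    print(f"d = {d:2d}: {bits:3d} syndrome bits -> {2**bits:.1e} possible syndromes per round")
```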

## Neural Network Decoders
### Feedforward Networks
The simplest approach treats decoding as a classification problem:
- **Input**: syndrome vector (e.g., [1,0,1,1,0])
- **Output**: most likely error pattern or correction to apply
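A minimal sketch of how such (input, output) training pairs could be generated for the 3-qubit code, assuming independent bit flips with probability `p` (a simplification of the circuit noise above). The targets are per-qubit error indicators, matching a network with one sigmoid output per qubit; `make_dataset` is a hypothetical helper name.

```python
import numpy as np

H = np.array([[1, 1, 0], [0, 1, 1]])  # parity checks for g1 = ZZI, g2 = IZZ

def make_dataset(num_samples, p, seed=0):
    """Return (syndromes, errors): inputs of shape (N, 2) and per-qubit targets of shape (N, 3)."""
    rng = np.random.default_rng(seed)
    errors = (rng.random((num_samples, 3)) < p).astype(np.float32)  # 1 = qubit flipped
    syndromes = (errors @ H.T) % 2                                   # s = He mod 2 per sample
    return syndromes.astype(np.float32), errors

X, Y = make_dataset(10_000, p=0.05)
print(X.shape, Y.shape)  # (10000, 2) (10000, 3)
```

These arrays could be converted with `torch.from_numpy(...)` to train the `SyndromeDecoder` in the next cell, e.g. with `nn.BCEWithLogitsLoss` if the final sigmoid is dropped.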

In [None]:
# Conceptual example: a feedforward syndrome decoder
import torch
import torch.nn as nn

class SyndromeDecoder(nn.Module):
    def __init__(self, syndrome_length, num_qubits):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(syndrome_length, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, num_qubits)  # One logit per qubit
        )
    
    def forward(self, syndrome):
        # Sigmoid turns the logits into the probability that each qubit has an error.
        # (For training, returning raw logits with nn.BCEWithLogitsLoss is more stable.)
        return torch.sigmoid(self.network(syndrome))


This has its advantages: fast inference, learns from data, can handle complex noise. But it also has its challenges: we need lots of training data and it may not generalize to new noise patterns.
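To see both "learns from data" and the data requirement without any neural-network machinery, here is a hedged sketch of the simplest learned decoder: sample (error, syndrome) pairs and, for each syndrome seen, keep the error pattern observed most often. This is a counting baseline, not the network above, and it assumes independent bit flips with probability `p`; `learn_decoder_from_samples` is a hypothetical name.

```python
from collections import Counter, defaultdict
import numpy as np

H = np.array([[1, 1, 0], [0, 1, 1]])  # g1 = ZZI, g2 = IZZ

def learn_decoder_from_samples(num_samples, p, seed=0):
    """Map each observed syndrome (s1, s2) to the error pattern seen most often with it."""
    rng = np.random.default_rng(seed)
    errors = (rng.random((num_samples, 3)) < p).astype(int)
    syndromes = errors @ H.T % 2
    seen = defaultdict(Counter)
    for e, s in zip(errors, syndromes):
        seen[tuple(int(b) for b in s)][tuple(int(b) for b in e)] += 1
    return {s: counts.most_common(1)[0][0] for s, counts in seen.items()}

table = learn_decoder_from_samples(20_000, p=0.1)
print(table)
```

With enough samples this recovers the lookup table. With very few samples, syndromes that never occur are simply missing from the learned table, which is the data-hunger problem in miniature; a neural network instead has to generalize to them.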


### Recurrent Neural Networks (RNNs)
For surface codes and other topological codes, syndromes are measured in repeated rounds, so error signatures are correlated in time as well as in space. RNNs can process this sequence of rounds:

In [None]:
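import torch
import torch.nn as nn

# Process the syndrome history as a sequence, one time step per measurement round
class TopologicalDecoder(nn.Module):
    def __init__(self, num_stabilizers=4):
        super().__init__()
        # Input shape: (rounds, batch, num_stabilizers), since nn.LSTM is not batch-first by default
        self.lstm = nn.LSTM(input_size=num_stabilizers, hidden_size=64, num_layers=2)
        self.classifier = nn.Linear(64, 2)  # Two-class logits, e.g. logical error / no logical error
    
    def forward(self, syndrome_sequence):
        output, _ = self.lstm(syndrome_sequence)
        return self.classifier(output[-1])  # Use the hidden state after the final round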
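Recurrent decoders are usually fed one time step per round of syndrome measurement. A common preprocessing step, used here as an assumption rather than something the cell above requires, is to convert raw syndromes into detection events: the XOR of each round with the previous one, so a persistent error shows up only in the round where it appeared. A NumPy sketch for the 3-qubit code with a hand-picked, hypothetical error schedule:

```python
import numpy as np

H = np.array([[1, 1, 0], [0, 1, 1]])  # g1 = ZZI, g2 = IZZ

def syndrome_rounds(flip_schedule, num_rounds):
    """Raw syndromes per round; flip_schedule maps round index -> qubit flipped before that round."""
    error = np.zeros(3, dtype=int)
    rounds = []
    for t in range(num_rounds):
        if t in flip_schedule:
            error[flip_schedule[t]] ^= 1   # errors persist until corrected
        rounds.append(H @ error % 2)
    return np.array(rounds)             # shape (num_rounds, num_stabilizers)

def detection_events(rounds):
    # Compare each round with the previous one (the round before the first is all zeros)
    previous = np.vstack([np.zeros_like(rounds[:1]), rounds[:-1]])
    return rounds ^ previous

raw = syndrome_rounds({2: 0}, num_rounds=5)  # qubit 0 flips before round 2
events = detection_events(raw)
print(raw.tolist())
print(events.tolist())
```

Stacking such arrays as `(rounds, batch, stabilizers)` matches the default (non-batch-first) input layout of `nn.LSTM`, with `input_size` equal to the number of stabilizers.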

### Graph Neural Networks (GNNs)

This is where things get really interesting. Stabilizer codes have a natural graph structure:
- **Code graph**: qubits as nodes, with stabilizers as constraints linking the qubits they act on
- **Syndrome graph**: the active syndrome measurements (stabilizers reporting −1) form patterns on that graph

In [None]:
import torch
import torch_geometric.nn as pyg

class GraphDecoder(torch.nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.conv1 = pyg.GCNConv(num_features, 64)
        self.conv2 = pyg.GCNConv(64, 32)
        self.classifier = torch.nn.Linear(32, 1)
    
    def forward(self, x, edge_index):
        # x: node features (syndrome values)
        # edge_index: graph connectivity
        x = torch.relu(self.conv1(x, edge_index))
        x = torch.relu(self.conv2(x, edge_index))
        return torch.sigmoid(self.classifier(x))

This GNN approach has its advantages too! It can naturally handle the irregular geometry of quantum codes, can process variable-size syndrome patterns, and can learn local error correlations while maintaining global consistency.
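One way to build the code graph as GNN input, as a sketch in which the node ordering and features are assumptions: qubits and stabilizers both become nodes, an edge joins each stabilizer to every qubit it acts on, and the measured syndrome bits become features on the stabilizer nodes. PyTorch Geometric expects `edge_index` as a 2 × E integer array that lists both directions of each undirected edge.

```python
import numpy as np

H = np.array([[1, 1, 0], [0, 1, 1]])  # g1 = ZZI, g2 = IZZ

def tanner_graph(H, syndrome):
    """Return (x, edge_index): qubit nodes 0..n-1, then stabilizer nodes n..n+m-1."""
    m, n = H.shape
    stab, qubit = np.nonzero(H)                 # one (stabilizer, qubit) pair per edge
    src = np.concatenate([qubit, stab + n])
    dst = np.concatenate([stab + n, qubit])
    edge_index = np.stack([src, dst])           # shape (2, 2 * num_edges)
    # Node feature: syndrome bit on stabilizer nodes, 0 on qubit nodes
    x = np.concatenate([np.zeros(n), np.asarray(syndrome, dtype=float)])[:, None]
    return x, edge_index

x, edge_index = tanner_graph(H, syndrome=[1, 0])
print(edge_index)
```

Converting with `torch.from_numpy(x).float()` and `torch.from_numpy(edge_index).long()` would give tensors in the form `GraphDecoder.forward(x, edge_index)` takes, with `num_features=1`.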

### Reinforcement Learning Decoders

RL treats decoding as a sequential decision problem:

- **Agent**: the decoder
- **Environment**: the quantum code with its syndrome
- **Actions**: apply corrections to specific qubits
- **Reward**: success when the logical information is preserved

In [None]:
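import numpy as np

class RLDecoder:
    """Decoding loop driven by a learned policy.

    `policy` is a hypothetical object with choose_action(state) -> qubit index
    (e.g. a trained DQN); `H` is the code's binary parity-check matrix.
    """
    def __init__(self, policy, H, max_steps=10):
        self.policy = policy
        self.H = np.asarray(H)
        self.max_steps = max_steps  # Stop if the policy never clears the syndrome

    def is_decoded(self, state):
        # Trivial syndrome; whether the logical state survived is what the reward checks
        return not np.any(state)

    def apply_correction(self, state, action):
        # Flipping qubit `action` toggles every stabilizer that acts on it
        return (state + self.H[:, action]) % 2

    def decode(self, syndrome):
        state = np.asarray(syndrome) % 2
        corrections = []
        while not self.is_decoded(state) and len(corrections) < self.max_steps:
            action = self.policy.choose_action(state)
            corrections.append(action)
            state = self.apply_correction(state, action)
        return corrections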

The RL approach can learn correction strategies directly from the reward signal and can, in principle, handle partial information and uncertainty and adapt to changing noise conditions.
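To make the agent/environment/reward loop concrete, here is a toy tabular Q-learning sketch for the 3-qubit code, a deliberately simplified stand-in for a DQN; all names and hyperparameters are illustrative. Each episode starts from one random bit flip. The state is the syndrome plus the step count (keeping the step count separates states reached after a wrong move from fresh ones, which keeps this toy easy to reason about), an action flips one qubit, and the episode ends with reward +1 if the error is fully removed, or −1 if the agent instead completes a logical flip (|000⟩ → |111⟩) or runs out of steps.

```python
import numpy as np

H = np.array([[1, 1, 0], [0, 1, 1]])  # g1 = ZZI, g2 = IZZ

def syndrome_index(residual):
    s = H @ residual % 2
    return int(2 * s[0] + s[1])       # 0..3, read as the bits s1 s2

def train_q_decoder(episodes=3000, max_steps=3, alpha=0.1, gamma=0.9, epsilon=0.2, seed=0):
    rng = np.random.default_rng(seed)
    Q = np.zeros((4, max_steps, 3))   # Q[syndrome index, step, qubit to flip]
    for _ in range(episodes):
        residual = np.zeros(3, dtype=int)
        residual[rng.integers(3)] = 1     # environment: one random bit flip
        for t in range(max_steps):
            s = syndrome_index(residual)
            if rng.random() < epsilon:    # epsilon-greedy exploration
                a = int(rng.integers(3))
            else:
                a = int(np.argmax(Q[s, t]))
            residual[a] ^= 1              # action: apply X to qubit a
            s_next = syndrome_index(residual)
            if s_next == 0:               # syndrome cleared: success unless it was a logical flip
                target, done = (1.0 if residual.sum() == 0 else -1.0), True
            elif t == max_steps - 1:      # out of steps
                target, done = -1.0, True
            else:
                target, done = gamma * Q[s_next, t + 1].max(), False
            Q[s, t, a] += alpha * (target - Q[s, t, a])
            if done:
                break
    return Q

Q = train_q_decoder()
# Greedy first move per syndrome index (1 -> "01", 2 -> "10", 3 -> "11")
print({s: int(np.argmax(Q[s, 0])) for s in (1, 2, 3)})
```

The greedy first moves should recover the lookup table (syndrome 10 → qubit 0, 11 → qubit 1, 01 → qubit 2). For this tiny code RL is overkill; the point is the shape of the loop, which a DQN scales up by replacing the table `Q` with a network.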

### Transformer Architectures

Recent work applies attention mechanisms to quantum error correction:

In [None]:
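import torch
import torch.nn as nn

class TransformerDecoder(nn.Module):
    def __init__(self, syndrome_length, d_model=256):
        super().__init__()
        self.embedding = nn.Linear(1, d_model)
        # Learned position embedding: without it, attention cannot tell stabilizers apart
        self.position = nn.Parameter(torch.zeros(syndrome_length, d_model))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead=8, batch_first=True),
            num_layers=6
        )
        self.output = nn.Linear(d_model, 1)
    
    def forward(self, syndrome):
        # syndrome: (batch, syndrome_length); treat each bit as one token
        x = self.embedding(syndrome.unsqueeze(-1)) + self.position
        x = self.transformer(x)
        return self.output(x).squeeze(-1)  # One logit per syndrome position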

This attention-based approach would naturally capture long-range correlations in error patterns, which is crucial for large quantum codes.
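The mechanism behind that claim is scaled dot-product self-attention: every position computes softmax(QKᵀ/√d) weights over all positions, so the distance between two syndrome bits in the sequence does not limit how strongly they can interact. A single-head NumPy sketch with random, untrained projection matrices (purely illustrative; the toy embedding is an assumption):

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over x of shape (L, d_model)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])          # (L, L): every position vs every position
    scores -= scores.max(axis=-1, keepdims=True)      # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)    # softmax over the key positions
    return weights @ v, weights

rng = np.random.default_rng(0)
L, d_model, d_head = 8, 16, 4
syndrome = rng.integers(0, 2, size=L)                 # hypothetical 8-bit syndrome
# Toy embedding: bit value times a shared vector, plus a per-position vector
x = syndrome[:, None] * rng.normal(size=(1, d_model)) + rng.normal(size=(L, d_model))
w_q, w_k, w_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
out, weights = self_attention(x, w_q, w_k, w_v)
print(weights.round(2))
```

Each row of `weights` shows how much one syndrome position draws on every other one. Each head of `nn.TransformerEncoderLayer` computes this internally, plus an output projection, residual connections, normalization, and a feedforward block; in a trained model the projections are learned rather than random.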


### Hybrid Classical-Quantum Approaches

Some cutting-edge research uses quantum neural networks for decoding:

1. Quantum autoencoders: Learn compressed representations of error syndromes
2. Variational quantum eigensolvers: Find optimal corrections by minimizing energy functions
3. Quantum approximate optimization: Frame decoding as combinatorial optimization
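The optimization framing can be stated without any quantum hardware: find the lowest-weight bit-flip pattern e with He = s (mod 2). The sketch below solves it by brute force, which only works for tiny codes; a QAOA-style approach could instead encode the same objective (a weight term plus penalties for violated parity checks) in a cost Hamiltonian. The brute-force solver is a classical stand-in, not a quantum method.

```python
from itertools import product
import numpy as np

H = np.array([[1, 1, 0], [0, 1, 1]])  # g1 = ZZI, g2 = IZZ

def min_weight_decode(H, syndrome):
    """Search all 2**n bit-flip patterns for the lightest one that reproduces the syndrome."""
    n = H.shape[1]
    target = np.asarray(syndrome)
    best = None
    for bits in product((0, 1), repeat=n):
        e = np.array(bits)
        if np.array_equal(H @ e % 2, target) and (best is None or e.sum() < best.sum()):
            best = e
    return best

for s in [(0, 0), (1, 0), (1, 1), (0, 1)]:
    print(s, min_weight_decode(H, s))
```

The search cost grows as 2ⁿ in the number of qubits, which is one reason heuristic and learned decoders are of interest.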


Recent work from Google (AlphaQubit, from Google DeepMind and Google Quantum AI) reported a neural decoder that was more accurate than leading conventional decoders on experimental surface-code data, in part because AI decoders can learn correlations that analytical decoders miss and can be fine-tuned on data from real hardware. They may also scale well to larger codes. It's not easy though! Having sufficient, high-quality training data remains a challenge. Furthermore, models trained on one noise model may fail on real hardware, and latency is a concern: deep networks can be slower than real-time correction requires.


## So What's Next?

Now, armed with all these possibilities, explore how you can use existing AI architectures to improve quantum error correction. You could start with a simple depolarizing noise channel, explore resources from Google & NVIDIA, and keep an eye out for updates to this notebook in the near future!