# MLX Quantum Basics: Apple Silicon for Quantum Computing

**SIIEA Quantum Engineering Curriculum**
- **Curriculum Days:** Year 1, Semester 1A (Days 001-042)
- **License:** CC BY-NC-SA 4.0 | Siiea Innovations, LLC

---

In [None]:
# Hardware detection — adapts simulations to your machine
import sys, os
# Notebooks have no __file__; resolve the repo root relative to the working directory
sys.path.insert(0, os.path.abspath(".."))
try:
    from hardware_config import HARDWARE, get_max_qubits
    print(f"Hardware: {HARDWARE['chip']} | {HARDWARE['memory_gb']} GB | Profile: {HARDWARE['profile']}")
    print(f"Max qubits: {get_max_qubits('safe')} (safe) / {get_max_qubits('max')} (max)")
except ImportError:
    print("hardware_config.py not found — using defaults")
    print("Run setup.sh from the repo root to generate it")

## What is MLX and Why Does It Matter for Quantum Simulation?

**MLX** is Apple's array framework for machine learning, built specifically for Apple Silicon.
Unlike CUDA-based frameworks, which typically copy data between host memory and a
discrete GPU's memory, MLX takes advantage of the **unified memory architecture** of
M-series chips: the CPU and GPU share the same physical memory, so MLX arrays can be
used by either device without explicit transfers.

### Why this matters for quantum computing

| Property | Implication |
|----------|-------------|
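| **Unified memory** | A 30-qubit state vector takes 16 GiB (about 17 GB) in `complex128`, or 8 GiB in `complex64` (the complex type MLX uses), and fits in the memory of a 128 GB MacBook Pro with no CPU-GPU transfers |
| **Lazy evaluation** | MLX builds a compute graph and only materializes results when they are needed (e.g. on `mx.eval` or when printing), which can reduce peak memory |
| **Metal backend** | GPU-accelerated array operations on every Apple Silicon Mac (M1 and later) |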
| **NumPy-like API** | Minimal learning curve for scientists already using NumPy |

In this notebook we will:
1. Confirm MLX is available and benchmark it against NumPy
2. Represent quantum states and gates as MLX arrays
3. Apply gates via matrix-vector products
4. Simulate measurements using the Born rule
5. Profile performance as qubit count grows

In [None]:
# --- MLX detection and environment setup ---
import time
import numpy as np

try:
    import mlx.core as mx
    HAS_MLX = True
    print("MLX available --- Apple Silicon acceleration enabled")
    print(f"MLX version: {mx.__version__ if hasattr(mx, '__version__') else 'unknown'}")
except ImportError:
    HAS_MLX = False
    print("MLX not available --- falling back to NumPy")
    print("Install MLX: pip install mlx  (requires Apple Silicon)")

# Helper: choose backend
def mx_array(data, dtype=None):
    """Create a complex array on the best available backend.

    MLX has no complex128 dtype, so the MLX path stores complex64
    (single precision); the NumPy fallback keeps complex128.
    """
    if HAS_MLX:
        return mx.array(np.asarray(data, dtype=np.complex64 if dtype is None else dtype))
    return np.asarray(data, dtype=np.complex128 if dtype is None else dtype)

def mx_matmul(a, b):
    """Matrix multiply on best available backend."""
    if HAS_MLX:
        return mx.matmul(a, b)
    return np.matmul(a, b)

def to_numpy(arr):
    """Convert any array to NumPy for display/plotting."""
    if HAS_MLX and isinstance(arr, mx.array):
        return np.array(arr)
    return np.asarray(arr)

print("\nBackend ready.")

## MLX vs NumPy: Raw Matrix Performance

Before we touch quantum mechanics, let us measure how MLX's Metal backend compares
with NumPy on the core operation of quantum simulation: **large matrix-vector
and matrix-matrix products**.

We multiply random real `float32` matrices of increasing size and record wall-clock time.
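Single-shot timings are noisy, and the first MLX call on a given shape can include one-time setup cost. A minimal sketch of a more robust measurement, using a hypothetical helper `median_time_ms` (not part of MLX or NumPy): run the operation a few times untimed, then report the median of several timed runs. For MLX, the timed callable must end with `mx.eval(...)`, since MLX is lazy.

```python
import statistics
import time

import numpy as np

def median_time_ms(fn, repeats=5, warmup=2):
    """Hypothetical helper: median wall-clock time of fn() in milliseconds.

    Warm-up calls run first and are discarded, so one-time costs
    (allocation, kernel setup) do not skew the result.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000)
    return statistics.median(samples)

# Example: time a 256 x 256 float32 matmul in NumPy
a_demo = np.random.randn(256, 256).astype(np.float32)
b_demo = np.random.randn(256, 256).astype(np.float32)
t_np = median_time_ms(lambda: a_demo @ b_demo)
print(f"NumPy 256x256 matmul: {t_np:.3f} ms (median of 5 runs)")

# With MLX, force evaluation inside the timed call, e.g.:
#   median_time_ms(lambda: mx.eval(mx.matmul(a_mx, b_mx)))
```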

In [None]:
# --- Benchmark: MLX vs NumPy matrix multiply ---
import time
import numpy as np

sizes = [128, 256, 512, 1024, 2048, 4096]
np_times = []
mlx_times = []

print(f"{'Size':>6} | {'NumPy (ms)':>12} | {'MLX (ms)':>12} | {'Speedup':>8}")
print("-" * 50)

for n in sizes:
    # NumPy (one untimed warm-up run, then a timed run)
    a_np = np.random.randn(n, n).astype(np.float32)
    b_np = np.random.randn(n, n).astype(np.float32)
    _ = a_np @ b_np
    t0 = time.perf_counter()
    _ = a_np @ b_np
    np_t = (time.perf_counter() - t0) * 1000
    np_times.append(np_t)

    # MLX
    if HAS_MLX:
        a_mx = mx.array(a_np)
        b_mx = mx.array(b_np)
        mx.eval(a_mx, b_mx)  # materialize inputs before timing (MLX is lazy)
        mx.eval(mx.matmul(a_mx, b_mx))  # untimed warm-up excludes one-time setup cost
        t0 = time.perf_counter()
        c_mx = mx.matmul(a_mx, b_mx)
        mx.eval(c_mx)  # force computation (MLX is lazy)
        mlx_t = (time.perf_counter() - t0) * 1000
    else:
        mlx_t = np_t  # fallback: same as NumPy

    mlx_times.append(mlx_t)
    speedup = np_t / mlx_t if mlx_t > 0 else 0
    print(f"{n:>6} | {np_t:>12.2f} | {mlx_t:>12.2f} | {speedup:>7.2f}x")

In [None]:
# --- Visualize benchmark results ---
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.plot(sizes, np_times, "o-", label="NumPy", color="#1f77b4", linewidth=2)
ax1.plot(sizes, mlx_times, "s-", label="MLX", color="#ff7f0e", linewidth=2)
ax1.set_xlabel("Matrix size (N x N)")
ax1.set_ylabel("Time (ms)")
ax1.set_title("Matrix Multiply: MLX vs NumPy")
ax1.legend()
ax1.set_yscale("log")
ax1.grid(True, alpha=0.3)

speedups = [n / m if m > 0 else 1 for n, m in zip(np_times, mlx_times)]
ax2.bar(range(len(sizes)), speedups, color="#2ca02c", alpha=0.8)
ax2.set_xticks(range(len(sizes)))
ax2.set_xticklabels(sizes)
ax2.set_xlabel("Matrix size (N x N)")
ax2.set_ylabel("Speedup (x)")
ax2.set_title("MLX Speedup over NumPy")
ax2.axhline(y=1, color="gray", linestyle="--", alpha=0.5)
ax2.grid(True, alpha=0.3, axis="y")

plt.tight_layout()
os.makedirs("mlx_labs", exist_ok=True)  # make sure the output directory exists
plt.savefig("mlx_labs/benchmark_matmul.png", dpi=150, bbox_inches="tight")
plt.show()
print("Saved: mlx_labs/benchmark_matmul.png")

## Quantum States as MLX Arrays

A single-qubit quantum state lives in $\mathbb{C}^2$:

$$|\psi\rangle = \alpha|0\rangle + \beta|1\rangle, \quad |\alpha|^2 + |\beta|^2 = 1$$

An $n$-qubit state lives in $\mathbb{C}^{2^n}$: a vector of $2^n$ complex amplitudes.
We store it as an MLX array of `complex64` values (MLX has no `complex128` type) or,
on the NumPy fallback, as an array of `complex128` values.
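The index of each amplitude is the qubit bit-string read as a binary number. A minimal sketch, using a hypothetical helper `basis_state` and assuming the convention used throughout this notebook that qubit 0 is the leftmost (most significant) bit:

```python
import numpy as np

def basis_state(bits):
    """Hypothetical helper: return |bits> as a complex128 vector of length 2**n.

    bits is a string such as "101"; the leftmost character is qubit 0
    (most significant bit), so "101" maps to index int("101", 2) == 5.
    """
    n = len(bits)
    state = np.zeros(2 ** n, dtype=np.complex128)
    state[int(bits, 2)] = 1.0
    return state

psi_101 = basis_state("101")
print(psi_101.shape)              # (8,)
print(np.flatnonzero(psi_101))    # [5]

# The same vector built as the tensor product |1> (x) |0> (x) |1>
ket0 = np.array([1, 0], dtype=np.complex128)
ket1 = np.array([0, 1], dtype=np.complex128)
print(np.allclose(psi_101, np.kron(np.kron(ket1, ket0), ket1)))  # True
```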

### Standard single-qubit states

| State | Vector | Description |
|-------|--------|-------------|
| $\vert 0\rangle$ | $[1, 0]^T$ | Computational basis zero |
| $\vert 1\rangle$ | $[0, 1]^T$ | Computational basis one |
| $\vert +\rangle$ | $\frac{1}{\sqrt{2}}[1, 1]^T$ | Hadamard state |
| $\vert -\rangle$ | $\frac{1}{\sqrt{2}}[1, -1]^T$ | Hadamard minus state |
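The $|\pm\rangle$ states are exactly what the Hadamard gate produces from the basis states: $|+\rangle = H|0\rangle$ and $|-\rangle = H|1\rangle$. A quick NumPy check (the names here are local to this snippet so it does not overwrite the notebook's `ket_*` variables):

```python
import numpy as np

inv_sqrt2 = 1.0 / np.sqrt(2.0)
H_ref = inv_sqrt2 * np.array([[1, 1], [1, -1]], dtype=np.complex128)

k0 = np.array([1, 0], dtype=np.complex128)
k1 = np.array([0, 1], dtype=np.complex128)
k_plus = np.array([inv_sqrt2, inv_sqrt2], dtype=np.complex128)
k_minus = np.array([inv_sqrt2, -inv_sqrt2], dtype=np.complex128)

print(np.allclose(H_ref @ k0, k_plus))    # True
print(np.allclose(H_ref @ k1, k_minus))   # True
# |+> and |-> are orthogonal: <+|-> = 0 (np.vdot conjugates its first argument)
print(np.isclose(np.vdot(k_plus, k_minus), 0))  # True
```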

In [None]:
# --- Fundamental quantum states as MLX arrays ---
import numpy as np

inv_sqrt2 = 1.0 / np.sqrt(2.0)

# Single-qubit basis states
ket_0 = mx_array([1, 0])
ket_1 = mx_array([0, 1])
ket_plus = mx_array([inv_sqrt2, inv_sqrt2])
ket_minus = mx_array([inv_sqrt2, -inv_sqrt2])

states = {"| 0>": ket_0, "| 1>": ket_1, "| +>": ket_plus, "| ->": ket_minus}

print("Single-Qubit Quantum States")
print("=" * 50)
for name, state in states.items():
    s = to_numpy(state)
    probs = np.abs(s) ** 2
    print(f"  {name}  =  {s}")
    print(f"         P(0) = {probs[0]:.4f},  P(1) = {probs[1]:.4f}")
    print(f"         norm = {np.sum(probs):.6f}")
    print()

## Bell States: Maximally Entangled Two-Qubit States

The four Bell states form an orthonormal basis for $\mathbb{C}^4$:

$$|\Phi^+\rangle = \frac{1}{\sqrt{2}}(|00\rangle + |11\rangle)$$
$$|\Phi^-\rangle = \frac{1}{\sqrt{2}}(|00\rangle - |11\rangle)$$
$$|\Psi^+\rangle = \frac{1}{\sqrt{2}}(|01\rangle + |10\rangle)$$
$$|\Psi^-\rangle = \frac{1}{\sqrt{2}}(|01\rangle - |10\rangle)$$

These are the foundation of quantum teleportation, superdense coding, and
entanglement-based protocols.
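One way to see that these states are *maximally* entangled: reshape the four amplitudes into a $2 \times 2$ matrix $M_{ab} = \langle ab|\psi\rangle$. Its singular values are the Schmidt coefficients; a product state has a single nonzero coefficient, while a maximally entangled two-qubit state has two equal ones, $1/\sqrt{2}$ each. A short NumPy check (the helper name `schmidt_coefficients` is our own, for illustration):

```python
import numpy as np

def schmidt_coefficients(state):
    """Singular values of the 2x2 amplitude matrix of a two-qubit state."""
    m = np.asarray(state, dtype=np.complex128).reshape(2, 2)
    return np.linalg.svd(m, compute_uv=False)

s = 1.0 / np.sqrt(2.0)
phi_plus = np.array([s, 0, 0, s])      # (|00> + |11>)/sqrt(2), entangled
product = np.kron([s, s], [1, 0])      # |+> (x) |0>, a product state

print(np.round(schmidt_coefficients(phi_plus), 4))  # [0.7071 0.7071]
print(np.round(schmidt_coefficients(product), 4))   # [1. 0.]
```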

In [None]:
# --- Bell states as 4-element vectors ---
inv_sqrt2 = 1.0 / np.sqrt(2.0)

bell_phi_plus  = mx_array([inv_sqrt2, 0, 0,  inv_sqrt2])   # |00> + |11>
bell_phi_minus = mx_array([inv_sqrt2, 0, 0, -inv_sqrt2])   # |00> - |11>
bell_psi_plus  = mx_array([0, inv_sqrt2,  inv_sqrt2, 0])   # |01> + |10>
bell_psi_minus = mx_array([0, inv_sqrt2, -inv_sqrt2, 0])   # |01> - |10>

bell_states = {
    "|Phi+>": bell_phi_plus,
    "|Phi->": bell_phi_minus,
    "|Psi+>": bell_psi_plus,
    "|Psi->": bell_psi_minus,
}

print("Bell States (2-qubit entangled states)")
print("=" * 60)
for name, state in bell_states.items():
    s = to_numpy(state)
    probs = np.abs(s) ** 2
    print(f"  {name}  amplitudes: {s}")
    print(f"         P(00)={probs[0]:.3f}  P(01)={probs[1]:.3f}  "
          f"P(10)={probs[2]:.3f}  P(11)={probs[3]:.3f}")
    print()

# Verify orthonormality
print("Orthonormality check (inner products):")
bell_list = list(bell_states.values())
bell_names = list(bell_states.keys())
for i in range(4):
    for j in range(i, 4):
        bi = to_numpy(bell_list[i])
        bj = to_numpy(bell_list[j])
        inner = np.abs(np.vdot(bi, bj))  # vdot conjugates its first argument
        expected = 1.0 if i == j else 0.0
        # MLX stores complex64, so allow single-precision rounding error
        status = "pass" if abs(inner - expected) < 1e-6 else "FAIL"
        print(f"  <{bell_names[i]}|{bell_names[j]}> = {inner:.6f}  [{status}]")

## Quantum Gates as MLX Matrices

Quantum gates are **unitary matrices** ($U^\dagger U = I$). We store them as
MLX arrays for GPU-accelerated gate application.

### Single-qubit gates

$$X = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}, \quad
Y = \begin{pmatrix} 0 & -i \\ i & 0 \end{pmatrix}, \quad
Z = \begin{pmatrix} 1 & 0 \\ 0 & -1 \end{pmatrix}$$

$$H = \frac{1}{\sqrt{2}}\begin{pmatrix} 1 & 1 \\ 1 & -1 \end{pmatrix}$$

### Multi-qubit gates

$$\text{CNOT} = \begin{pmatrix} 1&0&0&0\\0&1&0&0\\0&0&0&1\\0&0&1&0 \end{pmatrix}, \quad
\text{Toffoli} \in \mathbb{C}^{8\times8}$$
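These gates are related by a few identities that make handy sanity checks for any gate library: $HXH = Z$, $HZH = X$, $S^2 = Z$, $T^2 = S$ and $Y = iXZ$. A NumPy sketch that checks them numerically:

```python
import numpy as np

s2 = 1.0 / np.sqrt(2.0)
X = np.array([[0, 1], [1, 0]], dtype=np.complex128)
Y = np.array([[0, -1j], [1j, 0]], dtype=np.complex128)
Z = np.array([[1, 0], [0, -1]], dtype=np.complex128)
H = s2 * np.array([[1, 1], [1, -1]], dtype=np.complex128)
S = np.diag([1, 1j])                          # phase gate
T = np.diag([1, np.exp(1j * np.pi / 4)])      # pi/8 gate

identities = {
    "HXH = Z": (H @ X @ H, Z),
    "HZH = X": (H @ Z @ H, X),
    "S^2 = Z": (S @ S, Z),
    "T^2 = S": (T @ T, S),
    "Y = iXZ": (1j * X @ Z, Y),
}
for name, (lhs, rhs) in identities.items():
    print(f"{name}: {np.allclose(lhs, rhs)}")
```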

In [None]:
# --- Quantum gates as MLX matrices ---
import numpy as np

inv_sqrt2 = 1.0 / np.sqrt(2.0)

# Pauli gates
I_gate = mx_array([[1, 0], [0, 1]])
X_gate = mx_array([[0, 1], [1, 0]])
Y_gate = mx_array([[0, -1j], [1j, 0]])
Z_gate = mx_array([[1, 0], [0, -1]])

# Hadamard gate
H_gate = mx_array([[inv_sqrt2, inv_sqrt2], [inv_sqrt2, -inv_sqrt2]])

# Phase gates
S_gate = mx_array([[1, 0], [0, 1j]])
T_gate = mx_array([[1, 0], [0, np.exp(1j * np.pi / 4)]])

# CNOT (2-qubit)
CNOT_gate = mx_array([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 1],
    [0, 0, 1, 0]
])

# Toffoli (3-qubit)
toffoli_np = np.eye(8, dtype=np.complex128)
toffoli_np[6, 6] = 0
toffoli_np[7, 7] = 0
toffoli_np[6, 7] = 1
toffoli_np[7, 6] = 1
Toffoli_gate = mx_array(toffoli_np)

gates = {
    "I": (I_gate, 1), "X": (X_gate, 1), "Y": (Y_gate, 1),
    "Z": (Z_gate, 1), "H": (H_gate, 1), "S": (S_gate, 1),
    "T": (T_gate, 1), "CNOT": (CNOT_gate, 2), "Toffoli": (Toffoli_gate, 3),
}

print("Quantum Gate Library")
print("=" * 60)
for name, (gate, n_qubits) in gates.items():
    g = to_numpy(gate)
    # Check unitarity: U^dag U = I
    product = g.conj().T @ g
    identity_check = np.allclose(product, np.eye(g.shape[0]), atol=1e-10)
    print(f"  {name:>8}  |  {n_qubits}-qubit  |  {g.shape[0]}x{g.shape[1]}  |  "
          f"Unitary: {'YES' if identity_check else 'NO'}")

## Applying Gates: Matrix-Vector Products

To apply gate $U$ to state $|\psi\rangle$:

$$|\psi'\rangle = U|\psi\rangle$$

This is simply a **matrix-vector product** --- the core operation MLX accelerates.
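Multi-qubit operators such as $H \otimes I$ are built with the Kronecker product, and they act factor by factor on product states: $(A \otimes B)(|a\rangle \otimes |b\rangle) = A|a\rangle \otimes B|b\rangle$. A quick numerical check with random complex matrices and vectors (NumPy's `np.kron` implements $\otimes$):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2))
B = rng.normal(size=(2, 2)) + 1j * rng.normal(size=(2, 2))
va = rng.normal(size=2) + 1j * rng.normal(size=2)
vb = rng.normal(size=2) + 1j * rng.normal(size=2)

lhs = np.kron(A, B) @ np.kron(va, vb)   # build the 4x4 operator, then apply it
rhs = np.kron(A @ va, B @ vb)           # apply each 2x2 factor, then combine
print(np.allclose(lhs, rhs))            # True
```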

### Example: Creating a Bell State

Starting from $|00\rangle$, apply a Hadamard to qubit 0 (the left, most significant qubit), then a CNOT with qubit 0 as control and qubit 1 as target:

$$|00\rangle \xrightarrow{H \otimes I} \frac{1}{\sqrt{2}}(|00\rangle + |10\rangle) \xrightarrow{\text{CNOT}} \frac{1}{\sqrt{2}}(|00\rangle + |11\rangle) = |\Phi^+\rangle$$

In [None]:
# --- Applying gates via matrix-vector product ---
import numpy as np

def apply_gate(gate, state):
    """Apply a quantum gate (matrix) to a state (vector) using MLX."""
    # Reshape state to column vector for matmul
    if HAS_MLX:
        s = mx.reshape(state, (-1, 1))
        result = mx.matmul(gate, s)
        return mx.reshape(result, (-1,))
    else:
        s = np.asarray(state).reshape(-1, 1)
        result = np.asarray(gate) @ s
        return result.flatten()

def tensor_product(a, b):
    """Kronecker product of two matrices/vectors (computed in NumPy, returned on the active backend)."""
    a_np = to_numpy(a)
    b_np = to_numpy(b)
    return mx_array(np.kron(a_np, b_np))

# Start with |00>
psi = tensor_product(ket_0, ket_0)
print("Initial state |00>:", to_numpy(psi))

# Apply H tensor I
H_I = tensor_product(H_gate, I_gate)
psi = apply_gate(H_I, psi)
print("After (H x I):     ", to_numpy(psi))

# Apply CNOT
psi = apply_gate(CNOT_gate, psi)
print("After CNOT:        ", to_numpy(psi))

# Verify it is Bell |Phi+>
expected = to_numpy(bell_phi_plus)
actual = to_numpy(psi)
match = np.allclose(actual, expected, atol=1e-10)
print(f"\nIs this |Phi+>? {match}")
print(f"Probabilities: P(00)={np.abs(actual[0])**2:.4f}, "
      f"P(01)={np.abs(actual[1])**2:.4f}, "
      f"P(10)={np.abs(actual[2])**2:.4f}, "
      f"P(11)={np.abs(actual[3])**2:.4f}")
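The same H-then-CNOT circuit maps each of the four computational basis inputs to a different Bell state: $|00\rangle \to |\Phi^+\rangle$, $|01\rangle \to |\Psi^+\rangle$, $|10\rangle \to |\Phi^-\rangle$, $|11\rangle \to |\Psi^-\rangle$. A self-contained NumPy sketch (complex128 throughout, same qubit ordering as the cell above):

```python
import numpy as np

s = 1.0 / np.sqrt(2.0)
H = s * np.array([[1, 1], [1, -1]], dtype=np.complex128)
I2 = np.eye(2, dtype=np.complex128)
CNOT = np.array([[1, 0, 0, 0],
                 [0, 1, 0, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0]], dtype=np.complex128)
circuit = CNOT @ np.kron(H, I2)   # H on qubit 0 first, then CNOT

bell_targets = {
    "00": np.array([s, 0, 0, s]),    # |Phi+>
    "01": np.array([0, s, s, 0]),    # |Psi+>
    "10": np.array([s, 0, 0, -s]),   # |Phi->
    "11": np.array([0, s, -s, 0]),   # |Psi->
}
for bits, bell in bell_targets.items():
    inp = np.zeros(4, dtype=np.complex128)
    inp[int(bits, 2)] = 1.0
    out = circuit @ inp
    print(bits, np.allclose(out, bell))  # each line ends with True
```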

## Measurement Simulation: The Born Rule

When we measure a quantum state $|\psi\rangle = \sum_i \alpha_i |i\rangle$
in the computational basis, the probability of outcome $i$ is:

$$P(i) = |\alpha_i|^2$$

This is the **Born rule**. We simulate measurement by:
1. Computing probabilities from amplitudes
2. Sampling from the resulting distribution
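Step 2 can also be done by hand with inverse-transform sampling: draw uniform numbers in $[0, 1)$ and find where each one falls in the cumulative distribution. This sketch (the function name `sample_born` is hypothetical) is one way to achieve what `rng.choice(..., p=probs)` does for us in the next cell; it is not a description of NumPy's internals:

```python
import numpy as np

def sample_born(amplitudes, n_shots, seed=0):
    """Hypothetical helper: sample basis-state indices via the Born rule."""
    probs = np.abs(np.asarray(amplitudes, dtype=np.complex128)) ** 2
    cdf = np.cumsum(probs)
    cdf /= cdf[-1]                                    # guard against rounding error
    u = np.random.default_rng(seed).random(n_shots)   # uniform in [0, 1)
    # side="right" picks index i with cdf[i-1] <= u < cdf[i],
    # so zero-probability outcomes are never selected
    return np.searchsorted(cdf, u, side="right")

s = 1.0 / np.sqrt(2.0)
outcomes_demo = sample_born([s, 0, 0, s], n_shots=10_000)
counts_demo = np.bincount(outcomes_demo, minlength=4)
print(counts_demo)   # only indices 0 (|00>) and 3 (|11>) are nonzero
```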

In [None]:
# --- Measurement simulation using Born rule ---
import numpy as np

def measure(state, n_shots=1000, seed=42):
    """Simulate quantum measurement with n_shots samples.

    Returns:
        counts: dict mapping basis state index to count
        probs: array of theoretical probabilities
    """
    amplitudes = to_numpy(state)
    probs = np.abs(amplitudes) ** 2
    probs = probs / probs.sum()  # normalize (floating point safety)

    rng = np.random.default_rng(seed)
    outcomes = rng.choice(len(probs), size=n_shots, p=probs)

    counts = {}
    n_qubits = int(np.log2(len(probs)))
    for outcome in outcomes:
        label = format(outcome, f"0{n_qubits}b")
        counts[label] = counts.get(label, 0) + 1

    return counts, probs

# Measure the Bell state |Phi+>
print("Measuring Bell state |Phi+> (10,000 shots)")
print("=" * 50)
counts, probs = measure(bell_phi_plus, n_shots=10000)

print(f"\nTheoretical probabilities:")
for i, p in enumerate(probs):
    if p > 1e-10:
        print(f"  |{i:02b}> : {p:.4f}")

print(f"\nMeasurement results:")
for state_label in sorted(counts.keys()):
    freq = counts[state_label] / 10000
    print(f"  |{state_label}> : {counts[state_label]:>5} counts  "
          f"({freq:.4f} vs theoretical {probs[int(state_label, 2)]:.4f})")

# Measure |+> state (single qubit)
print("\n\nMeasuring |+> state (10,000 shots)")
print("=" * 50)
counts_plus, probs_plus = measure(ket_plus, n_shots=10000)
for state_label in sorted(counts_plus.keys()):
    freq = counts_plus[state_label] / 10000
    print(f"  |{state_label}> : {counts_plus[state_label]:>5} counts  ({freq:.4f})")

In [None]:
# --- Visualize measurement outcomes ---
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Bell state measurement
labels = sorted(counts.keys())
values = [counts[l] for l in labels]
theoretical = [probs[int(l, 2)] * 10000 for l in labels]
x_pos = range(len(labels))

axes[0].bar([x - 0.15 for x in x_pos], values, 0.3, label="Measured", color="#1f77b4")
axes[0].bar([x + 0.15 for x in x_pos], theoretical, 0.3, label="Theoretical", color="#ff7f0e", alpha=0.7)
axes[0].set_xticks(x_pos)
axes[0].set_xticklabels([f"|{l}>" for l in labels])
axes[0].set_title("Bell State |Phi+> Measurement (10k shots)")
axes[0].set_ylabel("Counts")
axes[0].legend()
axes[0].grid(True, alpha=0.3, axis="y")

# |+> state measurement
labels_p = sorted(counts_plus.keys())
values_p = [counts_plus[l] for l in labels_p]
axes[1].bar(range(len(labels_p)), values_p, color="#2ca02c", alpha=0.8)
axes[1].set_xticks(range(len(labels_p)))
axes[1].set_xticklabels([f"|{l}>" for l in labels_p])
axes[1].set_title("|+> State Measurement (10k shots)")
axes[1].set_ylabel("Counts")
axes[1].axhline(y=5000, color="red", linestyle="--", alpha=0.5, label="Expected")
axes[1].legend()
axes[1].grid(True, alpha=0.3, axis="y")

plt.tight_layout()
os.makedirs("mlx_labs", exist_ok=True)  # make sure the output directory exists
plt.savefig("mlx_labs/measurement_histogram.png", dpi=150, bbox_inches="tight")
plt.show()
print("Saved: mlx_labs/measurement_histogram.png")

## Scaling Benchmark: MLX vs NumPy for Quantum Simulation

The critical question: **how does MLX scale as qubit count grows?**

For $n$ qubits, the state vector has $2^n$ amplitudes. If a single-qubit gate is first
expanded to the full space (e.g. $H \otimes I \otimes \cdots \otimes I$), applying it is a $2^n \times 2^n$ matrix-vector product.

Memory requirements:
- Each `complex128` value = 16 bytes
- State vector: $2^n \times 16$ bytes
- Full gate matrix: $2^{2n} \times 16$ bytes (we will avoid this for large $n$)

| Qubits | State vector | Full matrix |
|--------|-------------|-------------|
| 10 | 16 KB | 16 MB |
| 15 | 512 KB | 16 GB |
| 20 | 16 MB | 16 TB |

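For this benchmark we never build the full matrix. Instead we apply the Hadamard to
qubit 0 directly on the state vector: reshape it into a $2 \times 2^{n-1}$ matrix (row index = value of qubit 0),
multiply by the $2 \times 2$ gate, and flatten back. This needs $O(2^n)$ memory instead of $O(4^n)$.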
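The reshape approach in the benchmark cell below generalizes to any target qubit $k$: view the state as a tensor of shape $(2^k, 2, 2^{n-k-1})$ and contract the gate with the middle axis, so no $2^n \times 2^n$ matrix is ever formed. A NumPy sketch (the function name `apply_1q` is hypothetical; qubit 0 is the most significant bit, as elsewhere in this notebook), checked against the explicit Kronecker-product operator for a small case:

```python
import numpy as np

def apply_1q(gate, state, k, n):
    """Apply a 2x2 gate to qubit k of an n-qubit state vector.

    Qubit 0 is the most significant bit. The state is viewed as a
    (2**k, 2, 2**(n-k-1)) tensor and the gate acts on the middle axis.
    """
    psi = np.asarray(state).reshape(2 ** k, 2, 2 ** (n - k - 1))
    # out[a, i, b] = sum_j gate[i, j] * psi[a, j, b]
    out = np.einsum("ij,ajb->aib", gate, psi)
    return out.reshape(-1)

# Check against the dense operator I (x) ... (x) gate (x) ... (x) I
n = 4
h = np.array([[1, 1], [1, -1]], dtype=np.complex128) / np.sqrt(2)
rng = np.random.default_rng(1)
psi0 = rng.normal(size=2 ** n) + 1j * rng.normal(size=2 ** n)
psi0 /= np.linalg.norm(psi0)

for k in range(n):
    dense = np.eye(1)
    for q in range(n):
        dense = np.kron(dense, h if q == k else np.eye(2))
    print(k, np.allclose(apply_1q(h, psi0, k, n), dense @ psi0))
```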

In [None]:
# --- Qubit scaling benchmark: MLX vs NumPy ---
import time
import numpy as np

# Load hardware limits
try:
    sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath("__file__")), ".."))
    from hardware_config import get_max_qubits
    max_bench_qubits = min(get_max_qubits("demo"), 20)
except ImportError:
    max_bench_qubits = 16

qubit_range = list(range(4, max_bench_qubits + 1, 2))
np_times_q = []
mlx_times_q = []
memory_mb = []

print(f"Benchmarking Hadamard-on-qubit-0 for {qubit_range[0]} to {qubit_range[-1]} qubits")
print(f"{'Qubits':>7} | {'Dim':>10} | {'Mem (MB)':>10} | {'NumPy (ms)':>12} | {'MLX (ms)':>12} | {'Speedup':>8}")
print("-" * 75)

for n in qubit_range:
    dim = 2 ** n
    mem = dim * 16 / (1024 ** 2)  # state vector memory in MB
    memory_mb.append(mem)

    # Build state |00...0> and Hadamard on qubit 0 via direct manipulation
    # Instead of full matrix, we apply H to qubit 0 by reshaping:
    #   state.reshape(2, 2^(n-1)) -> H @ state_reshaped -> flatten

    # NumPy version
    state_np = np.zeros(dim, dtype=np.complex128)
    state_np[0] = 1.0
    h_np = np.array([[1, 1], [1, -1]], dtype=np.complex128) / np.sqrt(2)

    t0 = time.perf_counter()
    reshaped = state_np.reshape(2, dim // 2)
    result_np = (h_np @ reshaped).flatten()
    np_t = (time.perf_counter() - t0) * 1000
    np_times_q.append(np_t)

    # MLX version. MLX has no complex128, so the MLX side runs in complex64;
    # the NumPy side above uses complex128, so this is not a same-precision comparison.
    if HAS_MLX:
        state_mx = mx.array(state_np.astype(np.complex64))
        h_mx = mx.array(h_np.astype(np.complex64))
        mx.eval(state_mx, h_mx)  # materialize inputs before timing
        # Untimed warm-up so one-time setup cost is not counted
        mx.eval(mx.matmul(h_mx, mx.reshape(state_mx, (2, dim // 2))))

        t0 = time.perf_counter()
        reshaped_mx = mx.reshape(state_mx, (2, dim // 2))
        result_mx = mx.matmul(h_mx, reshaped_mx)
        result_mx = mx.reshape(result_mx, (-1,))
        mx.eval(result_mx)
        mlx_t = (time.perf_counter() - t0) * 1000
    else:
        mlx_t = np_t

    mlx_times_q.append(mlx_t)
    speedup = np_t / mlx_t if mlx_t > 0 else 1.0
    print(f"{n:>7} | {dim:>10,} | {mem:>10.2f} | {np_t:>12.3f} | {mlx_t:>12.3f} | {speedup:>7.2f}x")

print(f"\nPeak state vector size: {memory_mb[-1]:.1f} MB for {qubit_range[-1]} qubits")

In [None]:
# --- Scaling visualization ---
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Time comparison
axes[0].plot(qubit_range, np_times_q, "o-", label="NumPy", linewidth=2)
axes[0].plot(qubit_range, mlx_times_q, "s-", label="MLX", linewidth=2)
axes[0].set_xlabel("Number of Qubits")
axes[0].set_ylabel("Time (ms)")
axes[0].set_title("Gate Application Time")
axes[0].legend()
axes[0].set_yscale("log")
axes[0].grid(True, alpha=0.3)

# Speedup
speedups_q = [n / m if m > 0 else 1 for n, m in zip(np_times_q, mlx_times_q)]
axes[1].bar(range(len(qubit_range)), speedups_q, color="#2ca02c", alpha=0.8)
axes[1].set_xticks(range(len(qubit_range)))
axes[1].set_xticklabels(qubit_range)
axes[1].set_xlabel("Number of Qubits")
axes[1].set_ylabel("Speedup (x)")
axes[1].set_title("MLX Speedup Factor")
axes[1].axhline(y=1, color="gray", linestyle="--", alpha=0.5)
axes[1].grid(True, alpha=0.3, axis="y")

# Memory usage
axes[2].bar(range(len(qubit_range)), memory_mb, color="#9467bd", alpha=0.8)
axes[2].set_xticks(range(len(qubit_range)))
axes[2].set_xticklabels(qubit_range)
axes[2].set_xlabel("Number of Qubits")
axes[2].set_ylabel("State Vector Memory (MB)")
axes[2].set_title("Memory Requirements")
axes[2].set_yscale("log")
axes[2].grid(True, alpha=0.3, axis="y")

plt.tight_layout()
os.makedirs("mlx_labs", exist_ok=True)  # make sure the output directory exists
plt.savefig("mlx_labs/qubit_scaling.png", dpi=150, bbox_inches="tight")
plt.show()
print("Saved: mlx_labs/qubit_scaling.png")

## Summary and Key Takeaways

### What we learned

1. **MLX provides GPU acceleration** on Apple Silicon with a NumPy-like API
2. **Quantum states** are complex vectors; $n$ qubits need $2^n$ amplitudes
3. **Quantum gates** are unitary matrices applied via matrix-vector products
4. **Bell states** demonstrate entanglement in a 2-qubit system
5. **Born rule** gives measurement probabilities from amplitudes
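6. **MLX's advantage depends on problem size**: for small arrays, GPU dispatch overhead can make NumPy as fast or faster; any speedup from the Metal GPU and unified memory shows up for larger state vectors, so check the scaling plots on your own machine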

### Key formulas

| Concept | Formula |
|---------|---------|
| State normalization | $\sum_i \lvert\alpha_i\rvert^2 = 1$ |
| Gate application | $\lvert\psi'\rangle = U\lvert\psi\rangle$ |
| Born rule | $P(i) = \lvert\langle i\vert\psi\rangle\rvert^2 = \lvert\alpha_i\rvert^2$ |
| Memory (bytes) | $2^n \times 16$ for complex128 ($2^n \times 8$ for complex64) |
| Tensor product | $(A \otimes B)(\lvert a\rangle \otimes \lvert b\rangle) = A\lvert a\rangle \otimes B\lvert b\rangle$ |

### Next notebook

In **02_large_scale_simulation.ipynb**, we build a full circuit simulator class
using MLX and push to 25+ qubits with GHZ states and Quantum Fourier Transforms.