# Grover's Algorithm for Markov Chain Next-Character Search

This notebook demonstrates genuine quantum computing on Markov chain data:

1. Build a classical bigram transition matrix from Shakespeare
2. Reduce to 8 most frequent characters (3 qubits) for step-by-step visualization
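3. Use **Grover's algorithm** to amplify the most likely next character using O(sqrt(N)) oracle queries instead of the O(N) lookups of a classical scan (note: the oracle here is built from the classically computed argmax, so this demonstrates amplitude amplification rather than an end-to-end speedup)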
4. Visualize amplitude amplification step-by-step
5. Compare quantum measurement distribution vs classical probabilities
6. Use Superstaq to optimize the circuit (falling back to Cirq's built-in single-qubit gate merging when Superstaq is unavailable) and show depth/gate reduction
7. Scale to the **full 65-character vocabulary** using 7 qubits

In [None]:
!pip install cirq cirq-superstaq numpy matplotlib seaborn -q

import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cirq

sys.path.insert(0, '.')
from model import CharVocab

print(f"Cirq version: {cirq.__version__}")

## Section 2: Classical Bigram Markov Chain

In [None]:
# Load the whole corpus into memory as a single UTF-8 string.
with open('data/shakespeare.txt', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# CharVocab comes from the local `model` module; later cells rely on its
# `vocab_size`, `encode`, `idx_to_char` and `char_to_idx` members.
vocab = CharVocab(text)

print(f"Corpus size: {len(text):,} characters")
print(f"Vocabulary size: {vocab.vocab_size}")

In [None]:
# Build full bigram count matrix: bigram_counts[a, b] is the number of times
# index b immediately follows index a in the encoded corpus.
V = vocab.vocab_size
bigram_counts = np.zeros((V, V), dtype=np.int64)

encoded = vocab.encode(text)

# Vectorized accumulation over all adjacent pairs. `np.add.at` is the
# unbuffered form of `+=`, so repeated (a, b) pairs are each counted; this
# replaces a Python-level loop over every character of the corpus.
# NOTE(review): assumes vocab.encode returns a 1-D sequence of integer
# indices in [0, V) -- the original per-element indexing implied the same.
encoded_ids = np.asarray(encoded, dtype=np.int64)
np.add.at(bigram_counts, (encoded_ids[:-1], encoded_ids[1:]), 1)

# Normalize each row to transition probabilities P(next | current).
row_sums = bigram_counts.sum(axis=1, keepdims=True)
row_sums[row_sums == 0] = 1  # rows with no observations stay all-zero instead of dividing by zero
bigram_probs = bigram_counts / row_sums

print(f"Bigram count matrix shape: {bigram_counts.shape}")
print(f"Total bigrams: {bigram_counts.sum():,}")

In [None]:
# Identify top 8 most frequent characters.
# Occurrences per vocabulary index are accumulated in one vectorized call
# (unbuffered add, so repeated indices are each counted) instead of a
# Python-level loop over the whole corpus.
char_freq = np.zeros(V, dtype=np.int64)
np.add.at(char_freq, np.asarray(encoded, dtype=np.int64), 1)

# Indices sorted by descending frequency; keep the first eight.
top8_indices = np.argsort(char_freq)[::-1][:8]
top8_chars = [vocab.idx_to_char[i] for i in top8_indices]

def display_label(ch):
    """Return a printable label for a character.

    Newline, space and tab are mapped to the visible tokens ``\\n``, ``SPC``
    and ``\\t`` respectively; every other value is returned unchanged.
    """
    replacements = (('\n', '\\n'), (' ', 'SPC'), ('\t', '\\t'))
    for raw, shown in replacements:
        if ch == raw:
            return shown
    return ch

top8_labels = list(map(display_label, top8_chars))

# Each of the eight characters is assigned the 3-bit basis state whose binary
# value equals its position in the frequency ranking (0 = most frequent).
print("Top 8 characters (3-qubit encoding):")
print("-" * 40)
for i, label in enumerate(top8_labels):
    bits = f"{i:03b}"
    print(f"  |{bits}> = '{label}'  (freq: {char_freq[top8_indices[i]]:,})")

In [None]:
# Restrict the bigram counts to the top-8 characters and renormalize each row
# over just those eight successors (all-zero rows are left as zeros).
sub_counts = bigram_counts[np.ix_(top8_indices, top8_indices)]
sub_row_sums = sub_counts.sum(axis=1, keepdims=True)
sub_row_sums = np.where(sub_row_sums == 0, 1, sub_row_sums)
sub_probs = sub_counts / sub_row_sums

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (matrix, annotation format, colormap, title) for the left and right panels.
heatmap_panels = [
    (sub_counts, 'd', 'Blues', 'Bigram Counts (Top 8 Characters)'),
    (sub_probs, '.3f', 'Reds', 'Transition Probabilities (Renormalized)'),
]
for panel_ax, (matrix, number_fmt, colormap, panel_title) in zip(axes, heatmap_panels):
    sns.heatmap(matrix, ax=panel_ax, annot=True, fmt=number_fmt,
                xticklabels=top8_labels, yticklabels=top8_labels, cmap=colormap)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Next character')
    panel_ax.set_ylabel('Current character')

plt.tight_layout()
plt.show()

## Section 3: Quantum Search Problem Setup

In [None]:
# Choose the query character; it must be one of the top-8 characters,
# otherwise list.index raises ValueError.
query_char = 't'
query_idx_in_top8 = top8_chars.index(query_char)

# The search target is the classical argmax of the query's transition row.
transition_row = sub_probs[query_idx_in_top8]
target_idx = int(transition_row.argmax())
target_char = top8_chars[target_idx]
target_bits = f"{target_idx:03b}"

print(f"Query character: '{display_label(query_char)}' (index {query_idx_in_top8})")
print(f"\nTransition probabilities from '{display_label(query_char)}':")
print("-" * 40)
for i, (label, prob) in enumerate(zip(top8_labels, transition_row)):
    if i == target_idx:
        marker = ' <-- TARGET (argmax)'
    else:
        marker = ''
    print(f"  |{format(i, '03b')}> '{label}': {prob:.4f}{marker}")

print(f"\nGrover will amplify: |{target_bits}> = '{display_label(target_char)}'")
print(f"Classical probability: {transition_row[target_idx]:.4f}")

In [None]:
# Bar chart of the classical transition row, with the argmax (the state the
# oracle will mark) highlighted in red.
fig, ax = plt.subplots(figsize=(8, 4))

colors = ['#1976d2'] * 8
colors[target_idx] = '#d32f2f'

bit_labels = []
for i, l in enumerate(top8_labels):
    bit_labels.append(f"|{format(i, '03b')}>\n'{l}'")

ax.bar(bit_labels, transition_row, color=colors)
ax.set_ylabel('Transition probability')
ax.set_title(f"Classical bigram probabilities from '{display_label(query_char)}'\n"
             f"Target: |{target_bits}> = '{display_label(target_char)}'")
plt.tight_layout()
plt.show()

## Section 4: Build Grover Circuit in Cirq

In [None]:
# Three line qubits: one per bit of the 3-bit character encoding.
q0, q1, q2 = qubits = cirq.LineQubit.range(3)

def make_oracle(qubits, target_bits):
    """Build the 3-qubit phase oracle that flips the sign of |target_bits>.

    Qubits whose target bit is '0' are wrapped in X gates so that the target
    state is mapped onto |111>, where a CCZ (realized as H-Toffoli-H on the
    last qubit) applies the phase flip; the same X gates then undo the mapping.

    Args:
        qubits: exactly three qubits; ``target_bits[i]`` refers to ``qubits[i]``.
        target_bits: string of '0'/'1' characters naming the marked state.

    Returns:
        A list of operations in application order.
    """
    first, second, last = qubits

    def x_bracket():
        # One X per position whose target bit is '0'.
        return [cirq.X(qubits[pos]) for pos, bit in enumerate(target_bits) if bit == '0']

    return [
        *x_bracket(),
        cirq.H(last),
        cirq.TOFFOLI(first, second, last),
        cirq.H(last),
        *x_bracket(),
    ]

def make_diffusion(qubits):
    """Build the 3-qubit Grover diffusion operator (reflection about |+>^n).

    The sequence is H on all, X on all, CCZ (as H-Toffoli-H on the last
    qubit), X on all, H on all.

    Args:
        qubits: exactly three qubits.

    Returns:
        A list of operations in application order.
    """
    first, second, last = qubits

    def layer(gate):
        # Apply the same single-qubit gate to every qubit, in order.
        return [gate(q) for q in qubits]

    return [
        *layer(cirq.H),
        *layer(cirq.X),
        cirq.H(last),
        cirq.TOFFOLI(first, second, last),
        cirq.H(last),
        *layer(cirq.X),
        *layer(cirq.H),
    ]

# List the gate sequence of each building block for the chosen target.
print("Oracle ops for target", target_bits)
oracle_ops = make_oracle(qubits, target_bits)
for oracle_op in oracle_ops:
    print(" ", oracle_op)

print("\nDiffusion ops:")
diffusion_ops = make_diffusion(qubits)
for diffusion_op in diffusion_ops:
    print(" ", diffusion_op)

In [None]:
# Number of Grover iterations for the 3-qubit search.
# For N=8 states and a single marked state, round(pi/4 * sqrt(8)) evaluates to 2.
n_iterations = round(0.25 * np.pi * np.sqrt(8))
print(f"Optimal Grover iterations for N=8: {n_iterations}")

def build_grover_circuit(qubits, target_bits, iterations, measure=True):
    """Assemble the Grover search circuit for the marked state |target_bits>.

    Args:
        qubits: the qubits to act on (passed through to make_oracle and
            make_diffusion).
        target_bits: bitstring of the marked state.
        iterations: number of oracle + diffusion rounds to append.
        measure: when True, append a joint measurement under key 'result'.

    Returns:
        The assembled cirq.Circuit.
    """
    grover = cirq.Circuit()
    # Start in the uniform superposition.
    grover.append(cirq.H.on_each(*qubits))
    # One Grover round = phase oracle followed by the diffusion operator.
    for _round in range(iterations):
        grover.append(make_oracle(qubits, target_bits))
        grover.append(make_diffusion(qubits))
    if not measure:
        return grover
    grover.append(cirq.measure(*qubits, key='result'))
    return grover

# Full measured circuit used for the sampled runs below.
grover_circuit = build_grover_circuit(
    qubits=qubits,
    target_bits=target_bits,
    iterations=n_iterations,
    measure=True,
)
print(f"\nGrover circuit ({n_iterations} iterations):")
print(grover_circuit)

In [None]:
# Circuit statistics: count every non-measurement operation, and take the
# number of moments minus the final measurement moment as the depth.
all_ops = [*grover_circuit.all_operations()]
gate_count = sum(1 for op in all_ops if not isinstance(op.gate, cirq.MeasurementGate))
depth = len(grover_circuit) - 1

print(f"Circuit depth (excluding measurement): {depth}")
print(f"Total gates (excluding measurement): {gate_count}")
print(f"Qubits: {len(qubits)}")

In [None]:
# Build circuits for step-by-step amplitude tracking.
# We simulate state vectors after init, 1 iter, 2 iters, and 3 iters (over-rotation).
#
# The simulator is seeded so that the sampled measurement histograms produced
# by `simulator.run(...)` in later cells are reproducible under
# Restart & Run All. The measurement-free state-vector simulations in this
# cell do not involve sampling.
SIMULATOR_SEED = 42
simulator = cirq.Simulator(seed=SIMULATOR_SEED)

amplitude_snapshots = []  # amplitude_snapshots[k] = basis-state probabilities after k iterations
labels_for_plot = [f"|{format(i, '03b')}>" for i in range(8)]

for n_iter in range(4):  # 0, 1, 2, 3 iterations
    circuit_sv = build_grover_circuit(qubits, target_bits, n_iter, measure=False)
    result = simulator.simulate(circuit_sv)
    # Born rule: probability = |amplitude|^2 for each basis state.
    probs = np.abs(result.final_state_vector) ** 2
    amplitude_snapshots.append(probs)
    target_prob = probs[target_idx]
    print(f"After {n_iter} iteration(s): P(target |{target_bits}>) = {target_prob:.4f}")

## Section 5: Simulate and Analyze

In [None]:
# Sample the measured Grover circuit.
# The shot count is named once so that the header and the percentages below
# stay correct if it is changed (previously the percentage was `count / 100`,
# which is only right for exactly 10,000 shots).
n_shots = 10_000
result = simulator.run(grover_circuit, repetitions=n_shots)
counts = result.histogram(key='result')

print(f"Measurement results ({n_iterations} Grover iterations, {n_shots:,} shots):")
print("-" * 50)
for state in sorted(counts.keys()):
    bits = format(state, '03b')
    label = display_label(top8_chars[state]) if state < len(top8_chars) else '?'
    marker = ' <-- TARGET' if state == target_idx else ''
    percent = 100 * counts[state] / n_shots
    print(f"  |{bits}> '{label}': {counts[state]:,} ({percent:.1f}%){marker}")

# The most frequently measured state is Grover's answer; compare it with the
# classical argmax chosen earlier.
most_common = max(counts, key=counts.get)
print(f"\nMost common measurement: |{format(most_common, '03b')}> = '{display_label(top8_chars[most_common])}'")
print(f"Classical argmax:        |{target_bits}> = '{display_label(target_char)}'")
print(f"Match: {most_common == target_idx}")

In [None]:
# Amplitude evolution plot: one panel per iteration count (0 to 3).
fig, axes = plt.subplots(1, 4, figsize=(18, 4), sharey=True)
titles = ['Initialization\n(0 iterations)',
          'After 1 iteration',
          'After 2 iterations\n(optimal)',
          'After 3 iterations\n(over-rotation)']

# The bar colours are identical in every panel: target in red, rest in blue.
colors = ['#1976d2'] * 8
colors[target_idx] = '#d32f2f'

for panel_ax, snapshot, panel_title in zip(axes, amplitude_snapshots, titles):
    panel_ax.bar(labels_for_plot, snapshot, color=colors)
    panel_ax.set_title(panel_title)
    panel_ax.set_ylim(0, 1.05)
    # Dashed reference line at the uniform probability 1/N.
    panel_ax.axhline(y=1/8, color='gray', linestyle='--', alpha=0.5, label='1/N')
    panel_ax.tick_params(axis='x', rotation=45)
    # Annotate the target bar with its probability.
    target_p = snapshot[target_idx]
    panel_ax.text(target_idx, target_p + 0.03, f'{target_p:.2f}',
                  ha='center', fontsize=9, fontweight='bold', color='#d32f2f')

axes[0].set_ylabel('Probability')
fig.suptitle(f"Grover Amplitude Amplification: target |{target_bits}> = '{display_label(target_char)}'",
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Quantum vs Classical comparison: turn the shot histogram into empirical
# frequencies and plot them next to the classical transition row.
total_shots = sum(counts.values())
measurement_probs = np.zeros(8)
for state, count in counts.items():
    if state >= 8:
        continue
    measurement_probs[state] = count / total_shots

width = 0.35
x = np.arange(8)
left_edges = x - width/2
right_edges = x + width/2

fig, ax = plt.subplots(figsize=(10, 5))
bars1 = ax.bar(left_edges, transition_row, width, label='Classical bigram P(next|current)',
               color='#1976d2', alpha=0.8)
bars2 = ax.bar(right_edges, measurement_probs, width, label='Grover measurement frequency',
               color='#d32f2f', alpha=0.8)

tick_labels = [f"|{format(i, '03b')}>\n'{l}'" for i, l in enumerate(top8_labels)]
ax.set_xticks(x)
ax.set_xticklabels(tick_labels)
ax.set_ylabel('Probability')
ax.set_title(f"Quantum vs Classical: next character after '{display_label(query_char)}'\n"
             f"Grover amplifies argmax to near-certainty")
ax.legend()
plt.tight_layout()
plt.show()

## Section 6: Superstaq Circuit Optimization

In [None]:
# Try Superstaq optimization, fall back to cirq gate merging if unavailable
optimized_circuit = None
optimization_method = None

# Built once, without measurement, and shared by both optimization paths.
circuit_to_optimize = build_grover_circuit(qubits, target_bits, n_iterations, measure=False)

try:
    # Imported here on purpose: cirq_superstaq is optional, and any failure
    # (missing package, service/credential or network error) must trigger the
    # local fallback below rather than abort the notebook.
    import cirq_superstaq as css
    service = css.Service()
    compiled = service.compile(
        circuit_to_optimize,
        target="ss_unconstrained_simulator"
    )
    optimized_circuit = compiled.circuit
    optimization_method = "Superstaq"
    print("Superstaq optimization successful!")
except Exception as e:
    # Deliberate best-effort: report why Superstaq was skipped, then optimize locally.
    print(f"Superstaq unavailable ({type(e).__name__}: {e})")
    print("Falling back to cirq gate merging optimization...")
    # Cirq transformers return a NEW circuit and leave their input untouched,
    # so every result must be re-assigned (the previous code discarded the
    # results of the two cleanup passes, making them no-ops).
    optimized_circuit = cirq.merge_single_qubit_gates_to_phxz(circuit_to_optimize)
    optimized_circuit = cirq.drop_negligible_operations(optimized_circuit)
    optimized_circuit = cirq.drop_empty_moments(optimized_circuit)
    optimization_method = "cirq gate merging"
    print("Gate merging optimization complete.")

print(f"\nOptimization method: {optimization_method}")

In [None]:
# Compare original vs optimized: depth is the number of moments, gate count
# is the number of operations; both circuits are measurement-free.
original_no_meas = build_grover_circuit(qubits, target_bits, n_iterations, measure=False)
orig_ops = [*original_no_meas.all_operations()]
opt_ops = [*optimized_circuit.all_operations()]

orig_depth, opt_depth = len(original_no_meas), len(optimized_circuit)
orig_gates, opt_gates = len(orig_ops), len(opt_ops)

print(f"Original circuit:  depth={orig_depth}, gates={orig_gates}")
print(f"Optimized circuit: depth={opt_depth}, gates={opt_gates}")
print(f"Depth reduction:   {orig_depth - opt_depth} ({(1 - opt_depth/orig_depth)*100:.1f}%)")
print(f"Gate reduction:    {orig_gates - opt_gates} ({(1 - opt_gates/orig_gates)*100:.1f}%)")

In [None]:
# Visualization: side-by-side bar charts of depth and gate count,
# original vs optimized.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

categories = ['Original', f'Optimized\n({optimization_method})']
bar_colors = ['#1976d2', '#388e3c']

# (values, y-axis label, title) for the left and right panels.
comparison_panels = [
    ([orig_depth, opt_depth], 'Circuit Depth', 'Circuit Depth Comparison'),
    ([orig_gates, opt_gates], 'Gate Count', 'Gate Count Comparison'),
]
for panel_ax, (values, y_label, panel_title) in zip(axes, comparison_panels):
    panel_ax.bar(categories, values, color=bar_colors)
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    # Write each value just above its bar.
    for i, v in enumerate(values):
        panel_ax.text(i, v + 0.3, str(v), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

## Section 7: 7-Qubit Full Vocabulary (65 Characters)

In [None]:
# Scale to the full vocabulary: 7 qubits give 2^7 = 128 basis states, enough
# to hold every vocabulary index; the remaining states are unused padding.
n_qubits_full = 7
N_full = 1 << n_qubits_full   # 128
V_full = vocab.vocab_size     # 65

# Same query character as before, now indexed in the full vocabulary; the
# search target is again the classical argmax of its transition row.
query_full_idx = vocab.char_to_idx[query_char]
transition_row_full = bigram_probs[query_full_idx]
target_idx_full = int(transition_row_full.argmax())
target_char_full = vocab.idx_to_char[target_idx_full]
target_bits_full = f"{target_idx_full:0{n_qubits_full}b}"

print(f"7-qubit full vocabulary: {V_full} characters in {N_full}-dim Hilbert space")
print(f"Query: '{display_label(query_char)}' (idx {query_full_idx})")
print(f"Target: |{target_bits_full}> = '{display_label(target_char_full)}' (idx {target_idx_full})")
print(f"Target probability: {transition_row_full[target_idx_full]:.4f}")

# Print every character label with its qubit state and transition probability.
print(f"\nFull vocabulary encoding ({V_full} chars):")
print("-" * 50)
for i in range(V_full):
    ch = vocab.idx_to_char[i]
    bits = f"{i:0{n_qubits_full}b}"
    if i == target_idx_full:
        marker = ' <-- TARGET'
    else:
        marker = ''
    if i < len(transition_row_full):
        prob_str = f'{transition_row_full[i]:.4f}'
    else:
        prob_str = ''
    print(f"  |{bits}> = '{display_label(ch)}'  P={prob_str}{marker}")
print(f"  (states {V_full}-{N_full-1} are unused padding)")

# Generalized oracle and diffusion for any qubit count
def make_oracle_nq(qubits, target_bits):
    """Build an n-qubit phase oracle that flips the sign of |target_bits>.

    Qubits whose target bit is '0' are wrapped in X gates so the target state
    is mapped onto the all-ones state, where a Z on the last qubit controlled
    by all the others applies the phase flip.

    Args:
        qubits: sequence of qubits; ``target_bits[i]`` refers to ``qubits[i]``.
        target_bits: string of '0'/'1' characters naming the marked state.

    Returns:
        A list of operations in application order.
    """
    zero_positions = [pos for pos, bit in enumerate(target_bits) if bit == '0']

    oracle_ops = [cirq.X(qubits[pos]) for pos in zero_positions]
    *controls, last = qubits
    oracle_ops.append(cirq.Z(last).controlled_by(*controls))
    oracle_ops += [cirq.X(qubits[pos]) for pos in zero_positions]
    return oracle_ops

def make_diffusion_nq(qubits):
    """Build the n-qubit Grover diffusion operator (reflection about |+>^n).

    The sequence is H on all, X on all, a Z on the last qubit controlled by
    all the others, X on all, H on all.

    Args:
        qubits: sequence of qubits to act on.

    Returns:
        A list of operations in application order.
    """
    def layer(gate):
        # Apply the same single-qubit gate to every qubit, in order.
        return [gate(q) for q in qubits]

    multi_controlled_z = lambda: cirq.Z(qubits[-1]).controlled_by(*qubits[:-1])

    return [
        *layer(cirq.H),
        *layer(cirq.X),
        multi_controlled_z(),
        *layer(cirq.X),
        *layer(cirq.H),
    ]

# Optimal iterations for a single marked state among N = N_full basis states.
# After k iterations the success probability is sin^2((2k + 1) * theta) with
# theta = arcsin(1 / sqrt(N)), so the best integer k is floor(pi / (4 * theta)).
# The common shortcut round(pi/4 * sqrt(N)) over-rotates here: for N=128 it
# gives 9 (~98.8% success) whereas the exact optimum is 8 (~99.6% success).
theta_full = np.arcsin(1 / np.sqrt(N_full))
n_iter_full = int(np.floor(np.pi / (4 * theta_full)))
print(f"\nOptimal Grover iterations for N={N_full}: {n_iter_full}")

def build_grover_circuit_nq(qubits, target_bits, iterations, measure=True):
    """Assemble an n-qubit Grover circuit from make_oracle_nq / make_diffusion_nq.

    Args:
        qubits: the qubits to act on.
        target_bits: bitstring of the marked state.
        iterations: number of oracle + diffusion rounds to append.
        measure: when True, append a joint measurement under key 'result'.

    Returns:
        The assembled cirq.Circuit.
    """
    circuit = cirq.Circuit()
    circuit.append(cirq.H.on_each(*qubits))
    for _ in range(iterations):
        circuit.append(make_oracle_nq(qubits, target_bits))
        circuit.append(make_diffusion_nq(qubits))
    if measure:
        circuit.append(cirq.measure(*qubits, key='result'))
    return circuit


# Build 7-qubit Grover circuit (measured variant, for sampling)
qubits_full = cirq.LineQubit.range(n_qubits_full)
circuit_full = build_grover_circuit_nq(qubits_full, target_bits_full, n_iter_full)

# Simulate. The shot count is named once so the header and the percentages
# below stay correct if it is changed (previously hard-coded as `count / 100`).
n_shots_full = 10_000
result_full = simulator.run(circuit_full, repetitions=n_shots_full)
counts_full = result_full.histogram(key='result')

# Show top-10 results
sorted_counts_full = sorted(counts_full.items(), key=lambda x: x[1], reverse=True)
print(f"\nTop-10 measurement results (7-qubit, {n_iter_full} iterations, {n_shots_full:,} shots):")
print("-" * 60)
for state, count in sorted_counts_full[:10]:
    bits = format(state, f'0{n_qubits_full}b')
    # States at or above V_full do not correspond to any vocabulary character.
    if state < V_full:
        label = display_label(vocab.idx_to_char[state])
    else:
        label = f'pad({state})'
    marker = ' <-- TARGET' if state == target_idx_full else ''
    percent = 100 * count / n_shots_full
    print(f"  |{bits}> '{label}': {count:,} ({percent:.1f}%){marker}")

# The most frequently measured state is Grover's answer.
most_common_full = max(counts_full, key=counts_full.get)
most_common_bits_full = format(most_common_full, f'0{n_qubits_full}b')
if most_common_full < V_full:
    most_common_label_full = display_label(vocab.idx_to_char[most_common_full])
else:
    most_common_label_full = 'pad'
print(f"\nGrover result: |{most_common_bits_full}> = '{most_common_label_full}'")
print(f"Classical argmax: |{target_bits_full}> = '{display_label(target_char_full)}'")
print(f"Match: {most_common_full == target_idx_full}")

# Verify success probability via state vector simulation of the same
# circuit without the final measurement.
circuit_sv_full = build_grover_circuit_nq(qubits_full, target_bits_full, n_iter_full, measure=False)
sv_result_full = simulator.simulate(circuit_sv_full)
probs_full = np.abs(sv_result_full.final_state_vector) ** 2
print(f"Theoretical success probability: {probs_full[target_idx_full]:.4%}")

## Section 8: Summary

In [None]:
# Collect the headline results of both experiments for the summary banner.
grover_result_3q = display_label(top8_chars[most_common])
classical_answer = display_label(target_char)
if most_common_full < V_full:
    grover_result_7q = display_label(vocab.idx_to_char[most_common_full])
else:
    grover_result_7q = f'pad({most_common_full})'
classical_answer_7q = display_label(target_char_full)

banner = "=" * 60
summary_lines = [
    banner,
    "  GROVER'S ALGORITHM FOR MARKOV CHAIN SEARCH",
    banner,
    f"  Corpus:              {len(text):,} characters (Shakespeare)",
    f"  Vocabulary:          {V_full} characters",
    f"  Query character:     '{display_label(query_char)}'",
    "",
    "  --- 3-Qubit (8 characters) ---",
    f"  Classical argmax:    '{classical_answer}'",
    f"  Grover result:       '{grover_result_3q}'",
    f"  Grover iterations:   {n_iterations} (vs 8 classical comparisons)",
    f"  Circuit depth:       {depth}",
    f"  Success probability: {amplitude_snapshots[n_iterations][target_idx]:.1%}",
    "",
    f"  --- 7-Qubit (full {V_full}-char vocabulary) ---",
    f"  Hilbert space:       {N_full} states ({V_full} chars + {N_full - V_full} padding)",
    f"  Classical argmax:    '{classical_answer_7q}'",
    f"  Grover result:       '{grover_result_7q}'",
    f"  Grover iterations:   {n_iter_full} (vs {N_full} classical comparisons)",
    f"  Success probability: {probs_full[target_idx_full]:.1%}",
    "",
    "  Speedup: O(sqrt(N)) vs O(N) -- quadratic advantage",
    banner,
]
for summary_line in summary_lines:
    print(summary_line)