# Gaussian Boson Sampling (GBS) benchmark

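Expanded notebook: computes exact output-pattern probabilities with the Strawberry Fields (SF) Fock backend, draws samples with both the Gaussian and the Fock backends, scores each sample set by its KL divergence from the exact distribution, and plots the resulting total photon-count distributions.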

In [None]:
# Setup and imports
import json
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import strawberryfields as sf
from strawberryfields.ops import Sgate, Interferometer, MeasureFock
from thewalrus import hafnian

from scripts.quantum_photonics.run_gbs import build_interferometer, compute_exact_probs_fock, sample_gbs, kl_between_empirical_and_exact
# Record the library versions this run used.
for pkg_name, pkg in (('numpy', np), ('strawberryfields', sf)):
    print(pkg_name, pkg.__version__)

In [None]:
# Configuration: circuit size, squeezing, sampling budget and Fock cutoff
modes, squeezing, shots, cutoff = 4, 0.6, 200, 6
rng_seed = 42
# Seed NumPy's global RNG; presumably relied on by the samplers below (confirm).
np.random.seed(rng_seed)
# The same seed fixes the random interferometer.
U = build_interferometer(modes, seed=rng_seed)
print('U shape', U.shape)

In [None]:
# Exact probabilities (Fock backend), truncated at `cutoff` photons per mode
exact = compute_exact_probs_fock(modes=modes, squeezing=squeezing, U=U, cutoff=cutoff)
# Keep (pattern, probability) pairs with positive probability, ordered by pattern.
positive_items = [(pattern, prob) for pattern, prob in exact.items() if prob > 0]
positive_items.sort()
nonzero = positive_items[:10]
print('example nonzero probs', nonzero[:5])

In [None]:
# Sampling and KL calculation: draw `shots` samples from each simulator backend
# and score each sample set against the exact distribution computed above.
samples_gauss = sample_gbs(modes=modes, squeezing=squeezing, U=U, backend='gaussian', shots=shots)
samples_fock = sample_gbs(modes=modes, squeezing=squeezing, U=U, backend='fock', shots=shots, cutoff=cutoff)
kl_gauss, kl_fock = (
    kl_between_empirical_and_exact(sample_set, exact, cutoff=cutoff)
    for sample_set in (samples_gauss, samples_fock)
)
print('KL(gaussian || exact)=', kl_gauss)
print('KL(fock || exact)=', kl_fock)

In [None]:
# Visualize total photon counts
def totals(samples):
    """Return the total photon number (sum over modes) of each sample, as ints."""
    return [int(sum(s)) for s in samples]


totals_gauss = totals(samples_gauss)
totals_fock = totals(samples_fock)
max_total = max(totals_gauss + totals_fock, default=0)
# Totals are integers: centre one unit-wide bin on each value 0..max_total.
# (range(0, max+1) as edges merged the two largest totals into the last bin
# and gave a single, unusable edge when every total was 0.)
bin_edges = np.arange(max_total + 2) - 0.5

fig, ax = plt.subplots()
ax.hist(totals_gauss, bins=bin_edges, alpha=0.6, label='gaussian')
ax.hist(totals_fock, bins=bin_edges, alpha=0.6, label='fock')
ax.set(title='Total photon counts (gaussian vs fock)',
       xlabel='total photon number', ylabel='number of samples')
ax.legend()
plt.show()

In [None]:
# Save results summary
out = {'modes': modes, 'squeezing': squeezing, 'shots': shots,
       'kl_gauss': float(kl_gauss), 'kl_fock': float(kl_fock)}
results_path = Path('bundles') / 'v23_toe_finish' / 'v23' / 'gbs_benchmark_results.json'
# Create the bundle directory on a fresh checkout. write_text opens and closes
# the file (the bare open(...).write(...) never closed its handle).
results_path.parent.mkdir(parents=True, exist_ok=True)
results_path.write_text(json.dumps(out, indent=2))
print('Saved', results_path)

# GBS threshold sweep results (extended)

This section loads the extended threshold-sweep results. It displays the plot of Jensen-Shannon (JS) divergence versus $\eta$, and a table of the six runs with the largest JS divergence.


In [None]:
import json
from pathlib import Path
import matplotlib.pyplot as plt
from IPython.display import Image, display

# Artifacts written by the extended threshold sweep. They may be missing or
# only partially written while the sweep is still running.
repo = Path.cwd()
bundle_dir = repo / 'bundles' / 'v23_toe_finish' / 'v23'
png = bundle_dir / 'gbs_threshold_js_vs_eta_extended.png'
json_f = bundle_dir / 'gbs_threshold_sweep_extended.json'

if png.exists():
    display(Image(str(png)))
else:
    print('Plot not found yet:', png)

# Robust JSON loading: report (rather than raise on) a missing, unreadable,
# empty or truncated results file.
data = None
if not json_f.exists():
    print('JSON results not found yet:', json_f)
else:
    try:
        with open(json_f) as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        print('Could not load JSON results (yet):', json_f, 'error:', e)
    else:
        # Expect a list of per-run records that each carry a 'js' score. Skip anything
        # else, instead of crashing in the sort below.
        records = [r for r in data if isinstance(r, dict) and 'js' in r] if isinstance(data, list) else []
        if records:
            import pandas as pd  # only needed when there is something to tabulate
            top = sorted(records, key=lambda r: r['js'], reverse=True)[:6]
            df = pd.DataFrame(top)
            display(df)
        else:
            print('JSON present but empty or without js records yet:', json_f)


# --- TDA prototype: references and quick summary
# Topological Data Analysis (TDA) prototype for GBS samples
# References: Nicolau, Levine & Carlsson (2011); Rabadan & Blumberg, Topological Data Analysis for Genomics and Evolution; Emmett & Rabadan (2014); Benjamin et al. (2022)

# The following cells compute small-sample persistence diagrams from point clouds of
# threshold (click / no-click) detection patterns. They then compare Wasserstein distances
# between conditions with the Jensen-Shannon divergence used elsewhere in the notebook.


In [None]:
# TDA prototype: imports and helpers
try:
    from ripser import ripser
    from persim import plot_diagrams, wasserstein
except ImportError as e:
    # Only a genuinely missing/broken install gets the install hint. Any other
    # exception raised while importing now surfaces unchanged instead of being
    # mislabelled as "missing packages".
    print('Missing TDA packages (ripser/persim):', e)
    print('Install with `pip install ripser persim giotto-tda`')
    raise

import numpy as np
import matplotlib.pyplot as plt
from scripts.quantum_photonics.run_gbs import build_interferometer, sample_gbs, compute_threshold_probs
from collections import Counter


def persistence_from_pointcloud(X, maxdim=1):
    """Return the ripser persistence diagrams of point cloud X (n x d).

    One diagram is produced per homology dimension 0..maxdim, exactly as found
    under the 'dgms' key of ``ripser(X, maxdim=maxdim)``.
    """
    diagrams = ripser(X, maxdim=maxdim)['dgms']
    return diagrams


def betti_curve_from_diagram(dgm, ts):
    """Compute the Betti curve of a single persistence diagram.

    Parameters
    ----------
    dgm : array-like of (birth, death) pairs
        Diagram for one homology dimension (e.g. one entry of ripser's 'dgms').
        Non-finite deaths (essential classes, or NaN) are treated as +inf.
    ts : array-like of float
        Filtration values at which to evaluate the curve.

    Returns
    -------
    numpy.ndarray of int
        ``counts[i]`` = number of intervals with ``birth <= ts[i] < death``.
    """
    pairs = np.asarray(dgm, dtype=float).reshape(-1, 2)
    # Drop intervals with a non-finite birth as whole pairs. Filtering births
    # alone left births/deaths misaligned (or silently mis-broadcast).
    pairs = pairs[np.isfinite(pairs[:, 0])]
    births = pairs[:, 0]
    deaths = np.where(np.isfinite(pairs[:, 1]), pairs[:, 1], np.inf)
    # Column vector of filtration values -> (len(ts), n_intervals) alive mask.
    t = np.asarray(ts, dtype=float).reshape(-1, 1)
    alive = (births <= t) & (deaths > t)
    return alive.sum(axis=1)



In [None]:
# Quick run: compute persistence for small (modes, eta) grid and compare to JS
from scipy.stats import entropy

modes_list = [2, 3, 4]
squeezing = 0.6
shots = 500
etas = [0.4, 0.6, 0.8, 1.0]
n_filtration_steps = 50  # resolution of the Betti curves
results_tda = []


def _js_divergence(p, q, eps=1e-12):
    """Jensen-Shannon divergence (natural log) between probability vectors p and q.

    Entries are clamped to `eps` so both KL terms stay finite.
    """
    mix = 0.5 * (p + q)
    return 0.5 * (entropy(np.maximum(p, eps), np.maximum(mix, eps))
                  + entropy(np.maximum(q, eps), np.maximum(mix, eps)))


for modes in modes_list:
    U = build_interferometer(modes, seed=42)
    # The filtration grid depends on `modes`, so it is stored with every result.
    # The plot below previously reused the last loop value (modes=4) for all curves.
    ts = np.linspace(0, np.sqrt(modes) + 1, n_filtration_steps)
    for eta in etas:
        # NOTE(review): `eta` is not passed to sample_gbs, so the samples are identical in
        # distribution for every eta, while the theory probabilities below are computed with
        # eta. Confirm whether sample_gbs accepts a loss/transmission argument.
        samples_gauss = sample_gbs(modes=modes, squeezing=squeezing, U=U, backend='gaussian', shots=shots)
        thresholds = (np.array(samples_gauss) > 0).astype(int)
        # ripser accepts duplicate points, so the raw 0/1 patterns are used as the point cloud.
        dgms = persistence_from_pointcloud(thresholds.astype(float), maxdim=1)
        betti0 = betti_curve_from_diagram(dgms[0], ts)
        betti1 = betti_curve_from_diagram(dgms[1], ts)
        # JS divergence between empirical and theoretical threshold-pattern distributions.
        counts = Counter(tuple(row) for row in thresholds)
        th_probs = compute_threshold_probs(modes=modes, squeezings=[squeezing] * modes, U=U, eta=eta)
        all_patterns = sorted(th_probs.keys())
        p_emp = np.array([counts.get(p, 0) / shots for p in all_patterns])
        p_th = np.array([th_probs.get(p, 0.0) for p in all_patterns])
        js = _js_divergence(p_emp, p_th)
        results_tda.append({'modes': modes, 'eta': eta, 'js': float(js), 'dgms': dgms,
                            'ts': ts.tolist(), 'betti0': betti0.tolist(), 'betti1': betti1.tolist()})
        print('modes', modes, 'eta', eta, 'js', js)

# Plot example Betti curves (H0/H1), each against its own filtration grid
plot_modes = 3
fig, ax = plt.subplots()
for r in (res for res in results_tda if res['modes'] == plot_modes):
    ax.plot(r['ts'], r['betti0'], label=f"eta={r['eta']} H0")
    ax.plot(r['ts'], r['betti1'], linestyle='--', label=f"eta={r['eta']} H1")
ax.legend()
ax.set(title=f'Betti curves (modes={plot_modes})', xlabel='filtration value (approx)', ylabel='Betti count')
plt.show()

# Compute pairwise Wasserstein distances between diagrams across etas for a fixed mode
def _adjacent_wasserstein(results, dim=1):
    """Wasserstein distances between persistence diagrams at neighbouring etas.

    `results` are `results_tda` entries for a single `modes` value. They are ordered
    by 'eta' and each consecutive pair is compared on its dimension-`dim` diagram.
    Returns {(eta_a, eta_b): distance}.
    """
    ordered = sorted(results, key=lambda r: r['eta'])
    return {
        (a['eta'], b['eta']): wasserstein(a['dgms'][dim], b['dgms'][dim], matching=False)
        for a, b in zip(ordered, ordered[1:])
    }


mode = 3
ws = _adjacent_wasserstein([r for r in results_tda if r['modes'] == mode])
print('Wasserstein distances (H1) between neighboring etas for modes=3:', ws)

# Scatter: Wasserstein (avg over neighbouring etas) vs JS (mean across etas), one point per modes
w_vals = []
js_vals = []
for m in modes_list:
    rs = [r for r in results_tda if r['modes'] == m]
    ds = list(_adjacent_wasserstein(rs).values())
    w_vals.append(np.mean(ds) if ds else 0.0)
    js_vals.append(np.mean([r['js'] for r in rs]))

fig, ax = plt.subplots()
ax.scatter(w_vals, js_vals)
for m, w_mean, js_mean in zip(modes_list, w_vals, js_vals):
    ax.annotate(f'modes={m}', (w_mean, js_mean))
ax.set(xlabel='avg Wasserstein (H1)', ylabel='avg JS', title='TDA Wasserstein vs JS (small-sample prototype)')
ax.grid(True)
plt.show()

# Save small results to bundles for review
repo = Path.cwd()
out = repo / 'bundles' / 'v23_toe_finish' / 'v23' / 'gbs_threshold_tda_proto.json'
# Create the bundle directory if needed. write_text closes the file (the bare
# open(...).write(...) left the handle open).
out.parent.mkdir(parents=True, exist_ok=True)
summary = [{'modes': r['modes'], 'eta': r['eta'], 'js': r['js']} for r in results_tda]
out.write_text(json.dumps(summary, indent=2))
print('Wrote', out)
