In [1]:
import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

sys.path.append(str(Path("..").resolve()))
from data_handling import load_measurements_npz, load_state_npz, MeasurementDataset, MeasurementLoader
from hyper_rbm import SymmetricHyperRBM, train_loop, get_sigmoid_curve, save_model
from wavefunction_overlap import generate_basis_states, calculate_exact_overlap, load_gt_wavefunction


# --- Paths and device ---
data_dir = Path("measurements")      # per-h measurement .npz files used for training
state_dir = Path("state_vectors")    # ground-truth state files used for evaluation
models_dir = Path("models")          # trained checkpoints are written here
models_dir.mkdir(parents=True, exist_ok=True)

device = "cpu"
print(f"Running on: {device}")

# --- System size and sample counts ---
SIDE_LENGTH = 3              # system is SIDE_LENGTH x SIDE_LENGTH (as encoded in file names)
FILE_SAMPLE_COUNT = 20_000   # sample count encoded in the measurement file names
TRAIN_SAMPLE_COUNT = 10_000  # samples used per support point
# NOTE(review): assumes each file holds FILE_SAMPLE_COUNT samples, so training
# cannot request more than that — confirm against MeasurementDataset.
assert TRAIN_SAMPLE_COUNT <= FILE_SAMPLE_COUNT, (
    f"TRAIN_SAMPLE_COUNT ({TRAIN_SAMPLE_COUNT}) exceeds "
    f"FILE_SAMPLE_COUNT ({FILE_SAMPLE_COUNT})"
)

# Transverse-field values: the model is trained on the support points only;
# novel points (currently none) are evaluated but never trained on.
h_support = [1.00, 1.40, 1.50, 2.00, 2.50, 3.00, 3.50, 4.00, 4.50]
h_novel = []
all_h_values = sorted(set(h_support + h_novel))

file_names = [f"tfim_{SIDE_LENGTH}x{SIDE_LENGTH}_h{h:.2f}_{FILE_SAMPLE_COUNT}.npz" for h in h_support]
file_paths = [data_dir / fn for fn in file_names]

# Fail fast with a clear message instead of partway through dataset loading.
missing_files = [str(p) for p in file_paths if not p.exists()]
if missing_files:
    raise FileNotFoundError(f"Missing measurement files: {missing_files}")

print(f"Support points: {h_support}")
print(f"Training on {TRAIN_SAMPLE_COUNT} samples per support point.")

samples_per_support = [TRAIN_SAMPLE_COUNT] * len(file_paths)
dataset = MeasurementDataset(file_paths, load_measurements_npz, ["h"], samples_per_support)

# --- Training hyperparameters ---
N_EPOCHS = 50
BATCH_SIZE = 1024
NUM_HIDDEN = 64
HYPER_NET_WIDTH = 64
K_STEPS = 10
GIBBS_NOISE_FRAC = 0.1
INIT_LR = 1e-2
FINAL_LR = 1e-4

# --- Seeds: one independent training run per seed ---
N_SEEDS = 10
BASE_SEED = 42
all_seeds = np.arange(N_SEEDS) + BASE_SEED  # BASE_SEED .. BASE_SEED + N_SEEDS - 1


def _build_run_config(seed: int, lr_curve_shape: float) -> dict:
    """Return the hyperparameters needed to reproduce one training run.

    Reads the module-level configuration constants; ``seed`` and
    ``lr_curve_shape`` are the per-call values used by ``train_and_eval_seed``.
    """
    return {
        "side_length": SIDE_LENGTH,
        "train_samples": TRAIN_SAMPLE_COUNT,
        "file_samples": FILE_SAMPLE_COUNT,
        "epochs": N_EPOCHS,
        "batch_size": BATCH_SIZE,
        "num_hidden": NUM_HIDDEN,
        "hyper_net_width": HYPER_NET_WIDTH,
        "k_steps": K_STEPS,
        "noise_frac": GIBBS_NOISE_FRAC,
        "init_lr": INIT_LR,
        "final_lr": FINAL_LR,
        # Last argument passed to get_sigmoid_curve; needed to rebuild the LR schedule.
        "lr_curve_shape": lr_curve_shape,
        "h_support": h_support,
        "h_novel": h_novel,
        "seed": seed,
        "device": device,
    }


def _evaluate_overlaps(model, seed: int, split_atol: float) -> list:
    """Compute the exact overlap of ``model`` with the ground truth at every h.

    Iterates over ``all_h_values``, loads the matching ground-truth state from
    ``state_dir`` and labels each h ``"support"`` if it lies within
    ``split_atol`` of an entry of ``h_support``, else ``"novel"``. Prints one
    line per h.

    Returns:
        List of dicts with keys ``h``, ``overlap``, ``split`` and ``seed``.
    """
    all_states = generate_basis_states(dataset.num_qubits, device)
    h_support_arr = np.asarray(h_support)

    rows = []
    for h_val in all_h_values:
        gt_path = state_dir / f"tfim_{SIDE_LENGTH}x{SIDE_LENGTH}_h{h_val:.2f}.npz"
        psi_true = load_gt_wavefunction(gt_path, device)

        # Convert once so the printed value and the stored value are the same float.
        overlap = float(calculate_exact_overlap(model, h_val, psi_true, all_states))
        is_support = np.isclose(h_support_arr, h_val, atol=split_atol).any()
        split = "support" if is_support else "novel"

        rows.append({"h": float(h_val), "overlap": overlap, "split": split, "seed": seed})
        print(f"h={h_val:.2f} ({split:7}) | Overlap = {overlap:.5f}")
    return rows


def train_and_eval_seed(
    seed: int,
    lr_curve_shape: float = 0.005,
    split_atol: float = 1e-3,
) -> pd.DataFrame:
    """Train, evaluate and save one SymmetricHyperRBM for a single seed.

    Uses the module-level ``dataset`` and hyperparameter constants. After
    training, the exact overlap with the ground-truth state is computed for
    every value in ``all_h_values``, and the model, its run config and the
    overlap rows are saved under ``models_dir`` with the seed and a timestamp
    in the filename.

    Args:
        seed: Seeds numpy's global RNG, torch's global RNG, and the
            ``torch.Generator`` passed to the loader and ``train_loop``.
            Numpy integers (e.g. elements of ``np.arange``) are accepted.
        lr_curve_shape: Last argument passed to ``get_sigmoid_curve``.
            NOTE(review): presumably shapes the sigmoid transition — confirm in
            hyper_rbm. Recorded in the saved config.
        split_atol: Absolute tolerance used to decide whether an evaluated h
            counts as a support point.

    Returns:
        DataFrame with one row per evaluated h and columns
        ``h``, ``overlap``, ``split``, ``seed``.
    """
    seed = int(seed)  # plain int for filenames, config and seeding calls
    print("\n" + "=" * 80)
    print(f"Training model with seed {seed}...")
    print("=" * 80)

    np.random.seed(seed)
    torch.manual_seed(seed)
    rng = torch.Generator().manual_seed(seed)

    loader = MeasurementLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, rng=rng
    )

    model = SymmetricHyperRBM(
        num_v=dataset.num_qubits,
        num_h=NUM_HIDDEN,
        hyper_dim=HYPER_NET_WIDTH,
        k=K_STEPS,
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=INIT_LR)
    # Third argument is the total number of batches over the run (epochs x batches/epoch).
    scheduler = get_sigmoid_curve(INIT_LR, FINAL_LR, N_EPOCHS * len(loader), lr_curve_shape)

    model = train_loop(
        model,
        optimizer,
        loader,
        num_epochs=N_EPOCHS,
        lr_schedule_fn=scheduler,
        noise_frac=GIBBS_NOISE_FRAC,
        rng=rng,
    )

    overlap_rows = _evaluate_overlaps(model, seed, split_atol)
    overlap_df = pd.DataFrame(overlap_rows)

    # Seed + timestamp give repeated runs with the same seed distinct filenames.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"hyprbm_tfim_{SIDE_LENGTH}x{SIDE_LENGTH}_{TRAIN_SAMPLE_COUNT}_suscept_seed{seed}_{timestamp}.pt"
    save_model(model, _build_run_config(seed, lr_curve_shape), overlap_rows, models_dir / filename)

    return overlap_df


# Train and evaluate one model per seed, sequentially (no parallelism).
# A comprehension keeps the per-seed loop variables out of the notebook
# namespace; later cells reuse names like `seed` and `df` for other things.
all_fidelities = [
    train_and_eval_seed(seed) for seed in tqdm(all_seeds, desc="Seeds")
]


Running on: cpu
Support points: [1.0, 1.4, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5]
Training on 10000 samples per support point.


Seeds:   0%|          | 0/10 [00:00<?, ?it/s]


Training model with seed 42...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0321     | 0.010000
10     | +0.0086     | 0.009987
20     | +0.0024     | 0.009017
30     | +0.0004     | 0.001092
40     | -0.0100     | 0.000114


Seeds:  10%|█         | 1/10 [01:47<16:04, 107.20s/it]

50     | -0.0060     | 0.000100
h=1.00 (support) | Overlap = 0.99982
h=1.40 (support) | Overlap = 0.99981
h=1.50 (support) | Overlap = 0.99977
h=2.00 (support) | Overlap = 0.99940
h=2.50 (support) | Overlap = 0.99934
h=3.00 (support) | Overlap = 0.99954
h=3.50 (support) | Overlap = 0.99972
h=4.00 (support) | Overlap = 0.99975
h=4.50 (support) | Overlap = 0.99973
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed42_20260123_025126.pt

Training model with seed 43...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0263     | 0.010000
10     | +0.0041     | 0.009987
20     | +0.0054     | 0.009017
30     | -0.0063     | 0.001092
40     | -0.0056     | 0.000114


Seeds:  20%|██        | 2/10 [03:34<14:18, 107.27s/it]

50     | -0.0094     | 0.000100
h=1.00 (support) | Overlap = 0.99977
h=1.40 (support) | Overlap = 0.99964
h=1.50 (support) | Overlap = 0.99953
h=2.00 (support) | Overlap = 0.99846
h=2.50 (support) | Overlap = 0.99733
h=3.00 (support) | Overlap = 0.99757
h=3.50 (support) | Overlap = 0.99826
h=4.00 (support) | Overlap = 0.99881
h=4.50 (support) | Overlap = 0.99906
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed43_20260123_025314.pt

Training model with seed 44...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0172     | 0.010000
10     | +0.0041     | 0.009987
20     | +0.0004     | 0.009017
30     | -0.0020     | 0.001092
40     | +0.0017     | 0.000114


Seeds:  30%|███       | 3/10 [05:14<12:08, 104.06s/it]

50     | -0.0007     | 0.000100
h=1.00 (support) | Overlap = 0.99978
h=1.40 (support) | Overlap = 0.99977
h=1.50 (support) | Overlap = 0.99971
h=2.00 (support) | Overlap = 0.99927
h=2.50 (support) | Overlap = 0.99899
h=3.00 (support) | Overlap = 0.99934
h=3.50 (support) | Overlap = 0.99971
h=4.00 (support) | Overlap = 0.99976
h=4.50 (support) | Overlap = 0.99966
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed44_20260123_025454.pt

Training model with seed 45...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0252     | 0.010000
10     | +0.0042     | 0.009987
20     | +0.0003     | 0.009017
30     | -0.0070     | 0.001092
40     | +0.0039     | 0.000114


Seeds:  40%|████      | 4/10 [06:51<10:08, 101.34s/it]

50     | -0.0063     | 0.000100
h=1.00 (support) | Overlap = 0.99981
h=1.40 (support) | Overlap = 0.99978
h=1.50 (support) | Overlap = 0.99973
h=2.00 (support) | Overlap = 0.99943
h=2.50 (support) | Overlap = 0.99939
h=3.00 (support) | Overlap = 0.99953
h=3.50 (support) | Overlap = 0.99964
h=4.00 (support) | Overlap = 0.99966
h=4.50 (support) | Overlap = 0.99973
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed45_20260123_025631.pt

Training model with seed 46...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0249     | 0.010000
10     | +0.0123     | 0.009987
20     | +0.0019     | 0.009017
30     | -0.0031     | 0.001092
40     | -0.0114     | 0.000114


Seeds:  50%|█████     | 5/10 [08:35<08:30, 102.20s/it]

50     | -0.0069     | 0.000100
h=1.00 (support) | Overlap = 0.99979
h=1.40 (support) | Overlap = 0.99965
h=1.50 (support) | Overlap = 0.99953
h=2.00 (support) | Overlap = 0.99836
h=2.50 (support) | Overlap = 0.99706
h=3.00 (support) | Overlap = 0.99709
h=3.50 (support) | Overlap = 0.99776
h=4.00 (support) | Overlap = 0.99815
h=4.50 (support) | Overlap = 0.99821
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed46_20260123_025815.pt

Training model with seed 47...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0346     | 0.010000
10     | +0.0022     | 0.009987
20     | +0.0045     | 0.009017
30     | -0.0088     | 0.001092
40     | -0.0081     | 0.000114


Seeds:  60%|██████    | 6/10 [10:14<06:44, 101.06s/it]

50     | -0.0066     | 0.000100
h=1.00 (support) | Overlap = 0.99974
h=1.40 (support) | Overlap = 0.99975
h=1.50 (support) | Overlap = 0.99966
h=2.00 (support) | Overlap = 0.99880
h=2.50 (support) | Overlap = 0.99773
h=3.00 (support) | Overlap = 0.99780
h=3.50 (support) | Overlap = 0.99867
h=4.00 (support) | Overlap = 0.99927
h=4.50 (support) | Overlap = 0.99939
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed47_20260123_025954.pt

Training model with seed 48...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0213     | 0.010000
10     | +0.0028     | 0.009987
20     | +0.0047     | 0.009017
30     | -0.0003     | 0.001092
40     | -0.0060     | 0.000114


Seeds:  70%|███████   | 7/10 [11:52<04:59, 99.99s/it] 

50     | +0.0030     | 0.000100
h=1.00 (support) | Overlap = 0.99983
h=1.40 (support) | Overlap = 0.99982
h=1.50 (support) | Overlap = 0.99979
h=2.00 (support) | Overlap = 0.99941
h=2.50 (support) | Overlap = 0.99918
h=3.00 (support) | Overlap = 0.99939
h=3.50 (support) | Overlap = 0.99964
h=4.00 (support) | Overlap = 0.99972
h=4.50 (support) | Overlap = 0.99966
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed48_20260123_030131.pt

Training model with seed 49...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0277     | 0.010000
10     | +0.0035     | 0.009987
20     | +0.0006     | 0.009017
30     | -0.0045     | 0.001092
40     | -0.0074     | 0.000114


Seeds:  80%|████████  | 8/10 [13:37<03:23, 101.62s/it]

50     | -0.0063     | 0.000100
h=1.00 (support) | Overlap = 0.99980
h=1.40 (support) | Overlap = 0.99980
h=1.50 (support) | Overlap = 0.99979
h=2.00 (support) | Overlap = 0.99947
h=2.50 (support) | Overlap = 0.99910
h=3.00 (support) | Overlap = 0.99927
h=3.50 (support) | Overlap = 0.99962
h=4.00 (support) | Overlap = 0.99970
h=4.50 (support) | Overlap = 0.99962
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed49_20260123_030317.pt

Training model with seed 50...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0206     | 0.010000
10     | +0.0057     | 0.009987
20     | +0.0039     | 0.009017
30     | -0.0096     | 0.001092
40     | -0.0155     | 0.000114


Seeds:  90%|█████████ | 9/10 [15:19<01:41, 101.66s/it]

50     | -0.0081     | 0.000100
h=1.00 (support) | Overlap = 0.99983
h=1.40 (support) | Overlap = 0.99975
h=1.50 (support) | Overlap = 0.99967
h=2.00 (support) | Overlap = 0.99896
h=2.50 (support) | Overlap = 0.99829
h=3.00 (support) | Overlap = 0.99851
h=3.50 (support) | Overlap = 0.99913
h=4.00 (support) | Overlap = 0.99950
h=4.50 (support) | Overlap = 0.99962
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed50_20260123_030458.pt

Training model with seed 51...
Epoch  | Loss       | LR        
-----------------------------------
1      | +0.0434     | 0.010000
10     | +0.0007     | 0.009987
20     | +0.0064     | 0.009017
30     | -0.0073     | 0.001092
40     | -0.0049     | 0.000114


Seeds: 100%|██████████| 10/10 [16:51<00:00, 101.13s/it]

50     | -0.0111     | 0.000100
h=1.00 (support) | Overlap = 0.99979
h=1.40 (support) | Overlap = 0.99970
h=1.50 (support) | Overlap = 0.99960
h=2.00 (support) | Overlap = 0.99880
h=2.50 (support) | Overlap = 0.99799
h=3.00 (support) | Overlap = 0.99838
h=3.50 (support) | Overlap = 0.99902
h=4.00 (support) | Overlap = 0.99934
h=4.50 (support) | Overlap = 0.99942
Model saved to: models/hyprbm_tfim_3x3_10000_suscept_seed51_20260123_030630.pt





In [None]:
# --- Aggregate and plot results ---
# One frame with every (h, overlap, split, seed) row from all runs; later cells reuse it.
all_fid_df = pd.concat(all_fidelities, ignore_index=True)

# Mean and sample std (pandas default ddof=1) of the overlap across seeds at
# each h. With a single seed the std is NaN, so fall back to a zero error bar.
fid_stats = (
    all_fid_df.groupby(["h", "split"])["overlap"]
    .agg(mean="mean", std="std")
    .reset_index()
    .fillna({"std": 0.0})
    .sort_values("h")
)

fig, ax = plt.subplots(figsize=(10, 6), dpi=100)

# Line through all mean points regardless of split, drawn behind the markers.
ax.plot(
    fid_stats["h"], fid_stats["mean"],
    "-", color="tab:blue", alpha=0.7, zorder=0, label="Mean",
)

# One errorbar call per split, so each split gets exactly one legend entry:
# filled circles for support points, open diamonds for everything else.
is_support = fid_stats["split"] == "support"
for mask, fmt, face_color, label in (
    (is_support, "o", "tab:blue", "Support"),
    (~is_support, "d", "white", "Novel"),
):
    subset = fid_stats[mask]
    if subset.empty:
        continue
    ax.errorbar(
        subset["h"], subset["mean"], yerr=subset["std"], fmt=fmt,
        color="tab:blue",
        markersize=8, markerfacecolor=face_color, markeredgecolor="tab:blue",
        capsize=4, label=label,
    )

# Keep the 0.95 floor unless an error bar would be clipped below it.
y_floor, y_ceil, y_margin = 0.95, 1.002, 0.005
lowest_bar = (fid_stats["mean"] - fid_stats["std"]).min()
y_low = y_floor if pd.isna(lowest_bar) else min(y_floor, lowest_bar - y_margin)

ax.set_xlabel(r"Transverse Field", fontsize=12)
ax.set_ylabel(r"Overlap", fontsize=12)
ax.set_title("Mean overlap across seeds (error bars: ±1 std)", fontsize=12)
ax.set_ylim(y_low, y_ceil)
ax.grid(True, alpha=0.3)
ax.legend(loc="lower right", fontsize=10)

fig.tight_layout()
plt.show()


# --- Plot all seeds individually ---
fig, ax = plt.subplots(figsize=(10, 6), dpi=100)

for seed_value in sorted(all_fid_df["seed"].unique()):
    seed_df = all_fid_df[all_fid_df["seed"] == seed_value].sort_values("h")

    # The line takes the next colour from the cycle; reuse that colour for this
    # seed's markers so every point can be matched to its seed's line.
    (seed_line,) = ax.plot(
        seed_df["h"], seed_df["overlap"], "-", alpha=0.4, label=f"Seed {seed_value}"
    )
    seed_color = seed_line.get_color()

    support = seed_df[seed_df["split"] == "support"]
    novel = seed_df[seed_df["split"] == "novel"]

    # Filled circles for support points, open diamonds for novel points.
    if len(support):
        ax.plot(
            support["h"], support["overlap"], "o",
            color=seed_color,
            markersize=6, markerfacecolor=seed_color, markeredgecolor=seed_color,
            alpha=0.7,
        )
    if len(novel):
        ax.plot(
            novel["h"], novel["overlap"], "d",
            color=seed_color,
            markersize=6, markerfacecolor="white", markeredgecolor=seed_color,
            alpha=0.7,
        )

# Keep the 0.95 floor unless some seed's overlap would be clipped below it.
y_floor, y_ceil, y_margin = 0.95, 1.002, 0.005
lowest_overlap = all_fid_df["overlap"].min()
y_low = y_floor if pd.isna(lowest_overlap) else min(y_floor, lowest_overlap - y_margin)

ax.set_xlabel(r"Transverse Field", fontsize=12)
ax.set_ylabel(r"Overlap", fontsize=12)
ax.set_title("Overlap per seed", fontsize=12)
ax.set_ylim(y_low, y_ceil)
ax.grid(True, alpha=0.3)
ax.legend(loc="lower right", fontsize=10, ncol=2)

fig.tight_layout()
plt.show()
