In [84]:
import sys
import math
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd

sys.path.append(str(Path("..").resolve()))
from data_handling import load_state_npz  # if you need it elsewhere

# Device string used further down for checkpoint loading and tensor creation.
device = "cpu"
# Relative directory; only used here to report where it resolves to
# (relative to the current working directory).
models_dir = Path("models")

print(f"Running on: {device}")
print(f"Looking for models in: {models_dir.resolve()}")


Running on: cpu
Looking for models in: /Users/Tonni/Desktop/master-code/neural-quantum-tomo/experiments_polish/tfim_4x4/models


In [85]:
class Conditioner(nn.Module):
    """Small two-layer network turning a conditioning vector into four chunks.

    The network is ``Linear -> tanh -> Linear``; its output of width
    ``2 * (num_visible + num_hidden)`` is cut along the last dimension into
    pieces of width ``(num_visible, num_visible, num_hidden, num_hidden)``.

    Args:
        num_visible: width of the first two output chunks.
        num_hidden: width of the last two output chunks.
        cond_dim: size of the last dimension of the conditioning input.
        hidden_width: width of the intermediate tanh layer.
    """

    def __init__(self, num_visible: int, num_hidden: int, cond_dim: int, hidden_width: int):
        super().__init__()
        self.num_visible = num_visible
        self.num_hidden = num_hidden
        total_units = num_visible + num_hidden
        # Layer names are part of the state-dict keys; keep them stable.
        self.fc1 = nn.Linear(cond_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, 2 * total_units)

    def forward(self, cond: torch.Tensor):
        """Return a 4-tuple of tensors split from the network output.

        Args:
            cond: tensor whose last dimension has size ``cond_dim``.

        Returns:
            Tuple of four tensors with last-dimension sizes
            ``(num_visible, num_visible, num_hidden, num_hidden)``.
        """
        activated = self.fc1(cond).tanh()
        flat_out = self.fc2(activated)
        chunk_sizes = [self.num_visible] * 2 + [self.num_hidden] * 2
        return flat_out.split(chunk_sizes, dim=-1)


class ConditionalRBM(nn.Module):
    """Restricted Boltzmann machine whose biases are modulated by a condition.

    A ``Conditioner`` network maps the conditioning vector to scale/shift terms
    that are applied to the visible bias ``b`` and hidden bias ``c``. The
    coupling matrix ``W`` is shared across all conditions.

    Args:
        num_visible: number of visible units (columns of the input ``v``).
        num_hidden: number of hidden units.
        cond_dim: size of the conditioning vector.
        conditioner_width: hidden width of the conditioner network.
        gibbs_k: number of Gibbs steps performed by ``generate``.
        T: temperature dividing the activations in the Gibbs steps and the
            free-energy expression.
    """

    def __init__(self, num_visible: int, num_hidden: int, cond_dim: int,
                 conditioner_width: int = 64, gibbs_k: int = 10, T: float = 1.0):
        super().__init__()
        self.num_visible = num_visible
        self.num_hidden = num_hidden
        self.gibbs_k = gibbs_k
        self.T = T
        self.W = nn.Parameter(torch.empty(num_visible, num_hidden))
        # torch.empty returns uninitialised memory (possibly NaN/inf). Draw small
        # random couplings so a model built without a checkpoint is still usable;
        # load_state_dict simply overwrites these values.
        nn.init.normal_(self.W, mean=0.0, std=0.01)
        self.b = nn.Parameter(torch.zeros(num_visible))
        self.c = nn.Parameter(torch.zeros(num_hidden))
        self.conditioner = Conditioner(num_visible, num_hidden, cond_dim, conditioner_width)

    def _free_energy(self, v: torch.Tensor, b_mod: torch.Tensor, c_mod: torch.Tensor) -> torch.Tensor:
        """Free energy combining ``v`` with its globally flipped copy ``1 - v``.

        For each configuration the usual RBM free energy
        ``-(v . b) - sum(softplus(v W + c))`` is evaluated both for ``v`` and for
        ``1 - v`` (using ``(1 - v) W = W.sum(0) - v W``), and the two are merged
        with ``-T * logsumexp([-fe_v, -fe_flipped] / T)``. The result is therefore
        identical for ``v`` and ``1 - v``.

        Args:
            v: configurations, last dimension ``num_visible``; cast to the dtype
                and device of ``W``.
            b_mod: effective visible bias, broadcastable against ``v``.
            c_mod: effective hidden bias, broadcastable against ``v @ W``.

        Returns:
            Tensor with the last dimension of ``v`` reduced away.
        """
        v = v.to(dtype=self.W.dtype, device=self.W.device)
        v_W = v @ self.W
        W_sum = self.W.sum(dim=0)

        linear_v = v_W + c_mod
        linear_flip = W_sum.unsqueeze(0) - v_W + c_mod

        # Use nn.functional explicitly instead of the module-level alias ``F``:
        # the notebook later rebinds ``F`` to a NumPy array, which would break
        # any call made after that cell has run.
        term2_v = nn.functional.softplus(linear_v).sum(dim=-1)
        term2_f = nn.functional.softplus(linear_flip).sum(dim=-1)
        term1_v = -(v * b_mod).sum(dim=-1)
        term1_f = -((1.0 - v) * b_mod).sum(dim=-1)

        fe_v = term1_v - term2_v
        fe_flipped = term1_f - term2_f

        stacked = torch.stack([-fe_v, -fe_flipped], dim=-1)
        return -self.T * torch.logsumexp(stacked / self.T, dim=-1)

    def _compute_effective_biases(self, cond: torch.Tensor):
        """Return ``(b_mod, c_mod)``: the biases after conditioning.

        Each bias is transformed as ``(1 + gamma) * bias + beta`` where the
        ``gamma``/``beta`` terms come from the conditioner. Broadcasting handles
        both a single 1-D condition and a batch of conditions.
        """
        gamma_b, beta_b, gamma_c, beta_c = self.conditioner(cond)
        b_mod = (1.0 + gamma_b) * self.b + beta_b
        c_mod = (1.0 + gamma_c) * self.c + beta_c
        return b_mod, c_mod

    def _gibbs_step(self, v: torch.Tensor, b_mod: torch.Tensor,
                    c_mod: torch.Tensor, rng: torch.Generator):
        """One block-Gibbs sweep: sample hidden given ``v``, then visible given hidden."""
        p_h = torch.sigmoid((v @ self.W + c_mod) / self.T)
        h = torch.bernoulli(p_h, generator=rng)
        p_v = torch.sigmoid((h @ self.W.t() + b_mod) / self.T)
        return torch.bernoulli(p_v, generator=rng)

    def log_score(self, v: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Return ``-0.5 * free_energy(v | cond) / T`` for each configuration.

        Args:
            v: configurations, last dimension ``num_visible``.
            cond: conditioning vector(s), last dimension ``cond_dim``.
        """
        b_mod, c_mod = self._compute_effective_biases(cond)
        return -0.5 * self._free_energy(v, b_mod, c_mod) / self.T

    @torch.no_grad()
    def generate(self, cond: torch.Tensor, n_samples: int, rng: torch.Generator):
        """Draw ``n_samples`` visible configurations with ``gibbs_k`` Gibbs steps.

        Args:
            cond: a 1-D condition (expanded to every sample) or a batch of
                conditions with one row per sample.
            n_samples: number of chains / returned rows.
            rng: generator driving all Bernoulli draws.
                # NOTE(review): presumably must live on the same device as the
                # model parameters -- confirm for non-CPU use.

        Returns:
            Tensor of shape ``(n_samples, num_visible)`` with entries 0 or 1.
        """
        # Follow the model's own parameters rather than a notebook-level global,
        # so the method keeps working if the model is moved or reused elsewhere.
        param_device = self.W.device
        param_dtype = self.W.dtype

        cond = cond.to(device=param_device, dtype=param_dtype)
        if cond.dim() == 1:
            cond = cond.expand(n_samples, -1)

        b_mod, c_mod = self._compute_effective_biases(cond)

        # Chains start from independent fair coin flips.
        start_probs = torch.full(
            (n_samples, self.num_visible), 0.5, device=param_device, dtype=param_dtype
        )
        v = torch.bernoulli(start_probs, generator=rng)

        for _ in range(self.gibbs_k):
            v = self._gibbs_step(v, b_mod, c_mod, rng)

        return v


In [86]:
def load_model(checkpoint_path=None):
    """Load a trained ``ConditionalRBM`` checkpoint and return ``(model, config)``.

    Args:
        checkpoint_path: path to a ``.pt`` checkpoint containing the keys
            ``"config"`` and ``"model_state_dict"``. When omitted, the
            checkpoint this notebook was last run with is used.

    Returns:
        ``(model, config)`` where ``model`` is in eval mode on ``device`` and
        ``config`` is the dictionary stored in the checkpoint.
    """
    if checkpoint_path is None:
        checkpoint_path = "./models/crbm_tfim_4x4_50000_suscept_20251211_161844.pt"
        # Other checkpoints that were tried with this notebook:
        #   ./models/crbm_tfim_4x4_100000_suscept_20251211_165059.pt
        #   ./models/crbm_tfim_4x4_50000_suscept_20251212_202305.pt
    latest_path = Path(checkpoint_path)

    print(f"Loading checkpoint: {latest_path.name}")
    # NOTE(review): torch.load unpickles the file; only open checkpoints from a
    # trusted source.
    checkpoint = torch.load(latest_path, map_location=device)
    config = checkpoint["config"]
    state_dict = checkpoint["model_state_dict"]

    # The first conditioner layer is nn.Linear(cond_dim, width), so its weight
    # has shape (width, cond_dim). Reading the width from the stored tensor lets
    # checkpoints trained with a non-default conditioner width load correctly.
    fc1_weight = state_dict.get("conditioner.fc1.weight")
    conditioner_width = int(fc1_weight.shape[0]) if fc1_weight is not None else 64

    model = ConditionalRBM(
        num_visible=config["num_visible"],
        num_hidden=config["num_hidden"],
        cond_dim=1,
        conditioner_width=conditioner_width,
        gibbs_k=config.get("gibbs_k", 10),
    ).to(device)

    model.load_state_dict(state_dict)
    model.eval()
    return model, config

# Load the checkpoint once; `model` and `config` are reused by the cells below.
model, config = load_model()
# Field values stored in the checkpoint config under "h_support"
# (empty list if the key is absent).
h_support = config.get("h_support", [])
# Side length reported below, taken as the integer square root of the number
# of visible units.
GEN_SIDE_LENGTH = int(math.sqrt(model.num_visible))

print(f"Model Loaded. System Size: {GEN_SIDE_LENGTH}x{GEN_SIDE_LENGTH}")
print(f"Support Points: {h_support}")


Loading checkpoint: crbm_tfim_4x4_50000_suscept_20251211_161844.pt
Model Loaded. System Size: 4x4
Support Points: [1.0, 1.2, 1.4, 2.0, 2.5, 3.0, 4.0, 5.0, 6.0, 7.0]


In [87]:
def all_binary_states(num_visible: int, device="cpu", dtype=torch.float32):
    """Enumerate every 0/1 configuration of ``num_visible`` units.

    Row ``k`` holds the binary digits of the integer ``k``, least-significant
    bit in column 0.

    Args:
        num_visible: number of units (columns).
        device: device on which the tensor is created.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(2**num_visible, num_visible)`` with entries 0 or 1.
    """
    n_states = 2**num_visible
    codes = torch.arange(n_states, device=device, dtype=torch.int64).unsqueeze(1)
    bit_positions = torch.arange(num_visible, device=device)
    # Shift each code right by every bit position and keep the lowest bit.
    digits = (codes >> bit_positions) & 1
    return digits.to(dtype=dtype)

# Match the enumeration dtype to the model parameters.
dtype = next(model.parameters()).dtype
# Full enumeration of visible configurations: one row per configuration.
all_v = all_binary_states(model.num_visible, device=device, dtype=dtype)
print("All configs:", all_v.shape)  # (65536, 16) for 4x4


All configs: torch.Size([65536, 16])


In [88]:
@torch.no_grad()
def rbm_state_vector_batched(model, all_v, h_val: float, batch_size: int = 8192):
    """Build the L2-normalised vector ``exp(log_score)`` over all rows of ``all_v``.

    ``model.log_score`` is evaluated in chunks of ``batch_size`` rows with a
    condition column filled with ``h_val``. The largest score is subtracted
    before exponentiating for numerical stability, and the result is divided by
    its Euclidean norm.

    Args:
        model: object exposing ``parameters()`` and ``log_score(v, cond)``.
        all_v: 2-D tensor of configurations, one per row.
        h_val: scalar written into every entry of the condition column.
        batch_size: number of rows scored per call.

    Returns:
        1-D NumPy array with one non-negative entry per row of ``all_v``.
    """
    param_dtype = next(model.parameters()).dtype
    field = float(h_val)
    total_rows = all_v.shape[0]

    # Pass 1: score every chunk and track the running maximum.
    peak = -1e30
    scored = []
    for offset in range(0, total_rows, batch_size):
        rows = all_v[offset:offset + batch_size]
        cond = torch.full((rows.shape[0], 1), field, device=rows.device, dtype=param_dtype)
        scores = model.log_score(rows, cond)  # (batch,)
        scored.append(scores.cpu())
        peak = max(peak, float(scores.max().cpu()))

    # Pass 2: shifted exponentials, squared norm accumulated chunk by chunk.
    amplitudes = [torch.exp(chunk - peak) for chunk in scored]
    squared_norm = 0.0
    for amp in amplitudes:
        squared_norm += float(torch.sum(amp * amp))

    normalised = torch.cat(amplitudes, dim=0) / math.sqrt(squared_norm)
    return normalised.numpy()  # real, non-negative entries

# Field grid for the fidelity surface: 25 evenly spaced values from 1.0 to 4.0
# (both endpoints included).
h_values_surf = np.linspace(1.0, 4.0, 25)


In [89]:
# One normalised state vector per field value on the surface grid.
psis = []
for h in tqdm(h_values_surf, desc="Building RBM states"):
    psis.append(rbm_state_vector_batched(model, all_v, h, batch_size=8192))

# Stack the per-field vectors into a single 2-D array (one row per field value).
psis = np.stack(psis, axis=0)  # (Nh, 2^N)
print("psis:", psis.shape)


Building RBM states: 100%|██████████| 25/25 [00:01<00:00, 23.00it/s]

psis: (25, 65536)





In [90]:
# overlaps S_ij = <psi_i | psi_j>
S = psis @ psis.T

# Squared overlap between every pair of states. This is deliberately NOT bound
# to the name `F`: that name is the notebook's alias for torch.nn.functional,
# and rebinding it to an array breaks ConditionalRBM._free_energy (F.softplus)
# for every model call made after this cell.
fidelity = (np.abs(S) ** 2).astype(np.float64)

# `logF` is the name the plotting cell reads. Despite the name it holds the
# plain fidelity, not its logarithm.
logF = fidelity


In [93]:
import numpy as np
import plotly.graph_objects as go

def _get(npz, *keys):
    for k in keys:
        if k in npz.files:
            return npz[k]
    raise KeyError(f"None of {keys} found. Available keys: {npz.files}")

# --- in-memory surface (fidelity computed earlier in this notebook) ---
h_mem = np.asarray(h_values_surf if "h_values_surf" in globals() else h_values, dtype=float)
F_mem = np.asarray(logF, dtype=float)  # NOTE: variable is called logF, values are plain fidelity

# --- file surface (reference) ---
# The context manager closes the .npz archive once both arrays are copied out.
with np.load("tfim_4x4_fidelity_surface.npz") as D:
    h_file = np.asarray(_get(D, "h_values", "h"), dtype=float)
    F_file = np.asarray(_get(D, "F", "logF"), dtype=float)

# require same grid for direct difference
if not (h_mem.shape == h_file.shape and np.allclose(h_mem, h_file) and F_mem.shape == F_file.shape):
    raise ValueError(
        f"Grids/shapes don't match.\n"
        f"h_mem: {h_mem.shape}, h_file: {h_file.shape}, F_mem: {F_mem.shape}, F_file: {F_file.shape}\n"
        f"Interpolate one surface onto the other's grid before comparing."
    )

h = h_mem
H1, H2 = np.meshgrid(h, h)

# Vertical gap between the lowest surface point and the flat difference carpet.
# The surfaces are drawn at their true values and the carpet is pushed down by
# this gap, so the z-axis ticks read actual fidelities.
CARPET_GAP = 0.5
z_min = float(min(np.min(F_mem), np.min(F_file)))
z_max = float(max(np.max(F_mem), np.max(F_file)))
z_floor = z_min - CARPET_GAP

# difference carpet: symmetric colour range so zero difference sits at the
# centre of the diverging colour scale
dF = F_mem - F_file
absmax = float(np.max(np.abs(dF))) + 1e-12

fig = go.Figure()

# file surface
fig.add_trace(go.Surface(
    x=H1, y=H2, z=F_file,
    colorscale="Viridis",
    showscale=False,
    opacity=0.90,
    name="file fidelity"
))

# in-memory surface, more transparent so the reference stays visible beneath it
fig.add_trace(go.Surface(
    x=H1, y=H2, z=F_mem,
    colorscale="Magma",
    showscale=False,
    opacity=0.55,
    name="memory fidelity"
))

# difference "carpet" (flat z, colored by dF)
fig.add_trace(go.Surface(
    x=H1, y=H2, z=np.full_like(dF, z_floor),
    surfacecolor=dF,
    colorscale="RdBu",
    cmin=-absmax, cmax=absmax,
    showscale=True,
    opacity=0.95,
    name="ΔF (mem-file)",
    colorbar=dict(title="ΔF", len=0.6)
))

fig.update_layout(
    title="TFIM 4x4 - Fidelity Surfaces + ΔF Color Carpet",
    margin=dict(l=0, r=0, b=0, t=40),
    scene=dict(
        xaxis_title="h",
        yaxis_title="h'",
        zaxis_title="Fidelity F(h,h')",
        zaxis=dict(range=[z_floor, z_max])
    ),
)

fig.show()