In [1]:
import os
import sys
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append(str(Path("..").resolve()))
from data_handling import load_measurements_npz, MeasurementDataset, MeasurementLoader

# --- DEVICE SETUP ---
#if torch.backends.mps.is_available():
#    device = torch.device("mps")
#    print(f"Success: Device set to Apple MPS")
#elif torch.cuda.is_available():
#    device = torch.device("cuda")
#    print(f"Success: Device set to CUDA")
#else:
device = torch.device("cpu")
print(f"Warning: Device set to CPU")

# --- FIXED TIMING HELPER ---
def sync_device(d):
    if d.type == 'mps':
        try:
            # Modern PyTorch (2.0+)
            torch.mps.synchronize()
        except AttributeError:
            # Fallback for Older PyTorch (1.12/1.13)
            # Creating a dummy tensor and moving it to CPU forces a sync
            torch.zeros(1, device=d).cpu()
    elif d.type == 'cuda':
        torch.cuda.synchronize()

data_dir = Path("measurements")
print(f"Data resides in: {data_dir.resolve()}")

Data resides in: /Users/Tonni/Desktop/master-code/neural-quantum-tomo/experiments_consolidate/xxz_square_4x4/measurements


In [2]:
# --- JIT COMPILED SAMPLER ---
@torch.jit.script
def gibbs_sampling_loop(v: torch.Tensor, W: torch.Tensor,
                        b_mod: torch.Tensor, c_mod: torch.Tensor,
                        k: int, T: float) -> torch.Tensor:
    curr_v = v
    for _ in range(k):
        # Positive Phase
        h_logits = curr_v @ W + c_mod
        p_h = torch.sigmoid(h_logits / T)
        h = torch.bernoulli(p_h)
        # Negative Phase
        v_logits = h @ W.t() + b_mod
        p_v = torch.sigmoid(v_logits / T)
        curr_v = torch.bernoulli(p_v)
    return curr_v

class Conditioner(nn.Module):
    def __init__(self, num_visible: int, num_hidden: int, cond_dim: int, hidden_width: int):
        super().__init__()
        self.fc1 = nn.Linear(cond_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, 2 * (num_visible + num_hidden))

    def forward(self, cond: torch.Tensor):
        x = torch.tanh(self.fc1(cond))
        x = self.fc2(x)
        N, H = self.fc2.out_features // 4, self.fc2.out_features // 4 - (self.fc2.out_features // 4) # Simplified logic
        # Re-calc split sizes correctly based on input
        # Note: In prev code, split sizes were fixed variables.
        # Here we rely on correct initialization.
        return torch.split(x, [x.shape[-1]//4]*4, dim=-1)
        # CAREFUL: Ensure init logic in RBM matches this split.
        # Restoring exact original logic for safety:
        # N, H logic relies on outer class knowing dims.
        # See usage below.

class ConditionalRBM(nn.Module):
    def __init__(self, num_visible: int, num_hidden: int, cond_dim: int,
                 conditioner_width: int = 64, k: int = 1, T: float = 1.0):
        super().__init__()
        self.num_visible = num_visible
        self.num_hidden = num_hidden
        self.k = k
        self.T = T

        self.W = nn.Parameter(torch.empty(num_visible, num_hidden))
        self.b = nn.Parameter(torch.zeros(num_visible))
        self.c = nn.Parameter(torch.zeros(num_hidden))

        # We need this to handle the split correctly
        self.conditioner = nn.Sequential(
            nn.Linear(cond_dim, conditioner_width),
            nn.Tanh(),
            nn.Linear(conditioner_width, 2 * (num_visible + num_hidden))
        )
        self.initialize_weights()

    def initialize_weights(self, w_mean=0.0, w_std=0.1, bias_val=0.0):
        nn.init.normal_(self.W, mean=w_mean, std=w_std)
        nn.init.constant_(self.b, bias_val)
        nn.init.constant_(self.c, bias_val)

    def _compute_effective_biases(self, cond: torch.Tensor):
        # Manual split matching original logic
        params = self.conditioner(cond)
        gamma_b, beta_b, gamma_c, beta_c = torch.split(params, [self.num_visible, self.num_visible, self.num_hidden, self.num_hidden], dim=-1)

        b_mod = (1.0 + gamma_b) * self.b.unsqueeze(0) + beta_b
        c_mod = (1.0 + gamma_c) * self.c.unsqueeze(0) + beta_c
        return b_mod, c_mod

    @staticmethod
    def _free_energy(v, W, b, c):
        v = v.to(dtype=W.dtype, device=W.device)
        return -(v * b).sum(dim=-1) - F.softplus(v @ W + c).sum(dim=-1)

    def forward(self, batch, aux_vars):
        t0 = time.perf_counter()

        values, _, cond = batch
        v_data = values.to(dtype=self.W.dtype, device=self.W.device)
        cond = cond.to(v_data.device, dtype=v_data.dtype)

        # 1. Bias Prep
        b_mod, c_mod = self._compute_effective_biases(cond)

        sync_device(v_data.device)
        t_prep = time.perf_counter() - t0
        t1 = time.perf_counter()

        # 2. Gibbs Sampling (The Heavy Part)
        # Init with 0.5 bernoulli
        v_start = torch.bernoulli(torch.full_like(v_data, 0.5))

        # JIT CALL
        v_model = gibbs_sampling_loop(v_start, self.W, b_mod, c_mod, self.k, self.T)
        v_model = v_model.detach()

        sync_device(v_data.device)
        t_sampling = time.perf_counter() - t1
        t2 = time.perf_counter()

        # 3. Loss Calc
        l2_strength = float(aux_vars.get("l2_strength", 0.0))
        l2_reg = (self.b.unsqueeze(0) - b_mod).pow(2).sum() + (self.c.unsqueeze(0) - c_mod).pow(2).sum()

        fe_data = self._free_energy(v_data, self.W, b_mod, c_mod)
        fe_model = self._free_energy(v_model, self.W, b_mod, c_mod)
        fe_diff = fe_data - fe_model
        loss = fe_diff.mean() + l2_strength * l2_reg

        sync_device(v_data.device)
        t_loss = time.perf_counter() - t2

        return loss, {
            "time_prep": t_prep,
            "time_sampling": t_sampling,
            "time_loss": t_loss
        }

    # Generate / score methods omitted for brevity as they aren't part of training loop profiling
    @torch.no_grad()
    def log_score(self, v, cond):
        b_mod, c_mod = self._compute_effective_biases(cond)
        return -0.5 * self._free_energy(v, self.W, b_mod, c_mod) / self.T

In [3]:
def compute_cxx(samples: torch.Tensor, pairs: List[Tuple[int, int]],
                log_score_fn: Callable[[torch.Tensor], torch.Tensor]) -> Tuple[float, float]:

    # Fully vectorized Cxx calculation
    B, N = samples.shape
    num_pairs = len(pairs)
    device = samples.device

    with torch.no_grad():
        # 1. Score original samples
        log_scores_orig = log_score_fn(samples) # Shape: (B)

        # 2. Vectorized Flip: Create a batch that contains ALL flipped versions
        # We repeat samples 'num_pairs' times: (num_pairs * B, N)
        samples_expanded = samples.repeat(num_pairs, 1)

        # Calculate indices to flip in the flat expanded batch
        us = torch.tensor([p[0] for p in pairs], device=device)
        vs = torch.tensor([p[1] for p in pairs], device=device)

        # Arithmetic to map (pair_idx, batch_idx) -> flat_idx
        batch_indices = torch.arange(B, device=device).unsqueeze(0).expand(num_pairs, B).flatten()
        pair_offsets = torch.arange(num_pairs, device=device).unsqueeze(1).expand(num_pairs, B).flatten() * B
        flat_indices = pair_offsets + batch_indices

        u_flat = us.unsqueeze(1).expand(num_pairs, B).flatten()
        v_flat = vs.unsqueeze(1).expand(num_pairs, B).flatten()

        # Perform the flips
        flipped_samples = samples_expanded.clone()
        flipped_samples[flat_indices, u_flat] = 1.0 - flipped_samples[flat_indices, u_flat]
        flipped_samples[flat_indices, v_flat] = 1.0 - flipped_samples[flat_indices, v_flat]

        # 3. Score ALL flipped versions in one massive GPU kernel launch
        log_scores_flip = log_score_fn(flipped_samples)

        # 4. Reshape and compute ratios
        log_scores_flip = log_scores_flip.view(num_pairs, B)
        log_ratios = log_scores_flip - log_scores_orig.unsqueeze(0)
        ratios = torch.exp(log_ratios)

        sample_cxx = ratios.mean(dim=0) # Average over pairs

        total_cxx = sample_cxx.mean().item()
        total_cxx_err = sample_cxx.std(unbiased=True).item() / math.sqrt(B)

        return total_cxx, total_cxx_err

# Keep monitor_cxx as is, just ensure it uses the optimized compute_cxx above
def monitor_cxx(model, ds, pair_indices, device, seed: int):
    model.eval()
    rng = torch.Generator(device='cpu').manual_seed(seed)

    # Sample a subset for monitoring
    num_samples = len(ds)
    n_monitor = min(1000, num_samples)
    indices = torch.randint(0, num_samples, (n_monitor,), generator=rng)

    samples = torch.as_tensor(ds.values[indices], device=device).to(dtype=torch.float32)
    cond = torch.as_tensor(ds.system_params[indices], device=device, dtype=torch.float32)

    scorer = lambda v: model.log_score(v, cond)

    cxx_val, _ = compute_cxx(samples, pair_indices, scorer)
    model.train()
    return cxx_val

In [4]:
def train_step(model, optimizer, batch, aux_vars):
    # Time Data Loading (implicit since we are here)
    t0 = time.perf_counter()

    optimizer.zero_grad(set_to_none=True)

    # Time Forward
    loss, aux = model(batch, aux_vars)

    # Time Backward
    sync_device(next(model.parameters()).device)
    t_fwd_end = time.perf_counter()

    loss.backward()
    sync_device(next(model.parameters()).device)
    t_bwd_end = time.perf_counter()

    optimizer.step()
    sync_device(next(model.parameters()).device)
    t_opt_end = time.perf_counter()

    # Pack timings
    timings = {
        "forward_total": t_fwd_end - t0,
        "backward": t_bwd_end - t_fwd_end,
        "optimizer": t_opt_end - t_bwd_end,
        "sampling_internal": aux.get("time_sampling", 0.0)
    }
    return loss.detach(), timings

def train_one_epoch_profile(model, optimizer, loader, epoch_idx, l2_strength):
    model.train()
    total_loss = 0.0

    # Storage for timings
    history = {"forward_total": [], "backward": [], "optimizer": [], "sampling_internal": []}

    print(f"\n{'Batch':<6} | {'Loss':<8} | {'Fwd (s)':<8} | {'Samp (s)':<8} | {'Bwd (s)':<8} | {'Opt (s)':<8}")
    print("-" * 65)

    start_epoch = time.perf_counter()

    for i, batch in enumerate(loader):
        # Stop after 20 batches for quick profiling
        if i >= 20:
            print(">> Stopping early for profiling.")
            break

        aux_vars = {"l2_strength": l2_strength}
        loss, times = train_step(model, optimizer, batch, aux_vars)

        total_loss += float(loss)

        for k, v in times.items():
            history[k].append(v)

        # Log every 5 batches
        if (i + 1) % 5 == 0:
            print(f"{i+1:<6} | {float(loss):.4f}   | "
                  f"{times['forward_total']:.4f}   | {times['sampling_internal']:.4f}   | "
                  f"{times['backward']:.4f}   | {times['optimizer']:.4f}")

    avg_loss = total_loss / (i + 1)

    # Stats
    print("\n--- TIMING BREAKDOWN (Avg per batch) ---")
    print(f"Sampling (JIT) : {sum(history['sampling_internal'])/len(history['sampling_internal']):.4f} sec")
    print(f"Forward Total  : {sum(history['forward_total'])/len(history['forward_total']):.4f} sec")
    print(f"Backward       : {sum(history['backward'])/len(history['backward']):.4f} sec")
    print(f"Optimizer      : {sum(history['optimizer'])/len(history['optimizer']):.4f} sec")
    print("----------------------------------------")

    return avg_loss

In [7]:
# DATA LOADING

SIDE_LENGTH = 4
FILE_SAMPLES = 5_000_000
TRAIN_SAMPLES = 50_000  # beyond 100k per file it gets slow

# construct filenames dynamically based on support deltas
delta_support = [0.40, 0.60, 0.80, 0.90, 0.95, 1.00, 1.05, 1.10, 1.40, 2.00]
file_names = [f"xxz_{SIDE_LENGTH}x{SIDE_LENGTH}_delta{d:.2f}_{FILE_SAMPLES}.npz" for d in delta_support]
file_paths = [data_dir / fn for fn in file_names]
samples_per_file = [TRAIN_SAMPLES] * len(file_paths)

diag_indices = [k * (SIDE_LENGTH + 1) for k in range(SIDE_LENGTH)]
corr_pairs = list(zip(diag_indices, diag_indices[1:]))

print(f"System Size       : {SIDE_LENGTH}x{SIDE_LENGTH} ({SIDE_LENGTH**2} qubits)")
print(f"Training Samples  : {TRAIN_SAMPLES} per file (from {FILE_SAMPLES} total)")
print(f"Support Deltas    : {delta_support}")
print(f"Correlation Pairs : {corr_pairs} (main diagonal neighbors)")

ds = MeasurementDataset(file_paths, load_fn=load_measurements_npz,
                        system_param_keys=["delta"], samples_per_file=samples_per_file)

print(f"Samples Shape     : {tuple(ds.values.shape)}")
print(f"Conditions Shape  : {tuple(ds.system_params.shape)}")

System Size       : 4x4 (16 qubits)
Training Samples  : 50000 per file (from 5000000 total)
Support Deltas    : [0.4, 0.6, 0.8, 0.9, 0.95, 1.0, 1.05, 1.1, 1.4, 2.0]
Correlation Pairs : [(0, 5), (5, 10), (10, 15)] (main diagonal neighbors)
Samples Shape     : (500000, 16)
Conditions Shape  : (500000, 1)


In [6]:
# --- CONFIGURATION ---
batch_size        = 1024
num_visible       = ds.num_qubits
num_hidden        = 64
conditioner_width = 64

# Start with minimal steps to verify speed
k_steps           = 25
l2_strength       = 1e-4
lr                = 1e-2

SEED = 42
torch.manual_seed(SEED)

print(f"Running on: {device}")
print(f"Gibbs Steps (k): {k_steps}")

# Data Loader
loader_rng = torch.Generator(device="cpu").manual_seed(SEED)
loader = MeasurementLoader(dataset=ds, batch_size=batch_size, shuffle=True, drop_last=False, rng=loader_rng)

# Model
model = ConditionalRBM(num_visible=num_visible, num_hidden=num_hidden, cond_dim=ds.system_params.shape[1],
                       conditioner_width=conditioner_width, k=k_steps, T=1.0)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# --- RUN PROFILING ---
print("Starting Profiling Run (Max 20 batches)...")
train_one_epoch_profile(model, optimizer, loader, epoch_idx=0, l2_strength=l2_strength)

Running on: cpu
Gibbs Steps (k): 25
Starting Profiling Run (Max 20 batches)...

Batch  | Loss     | Fwd (s)  | Samp (s) | Bwd (s)  | Opt (s) 
-----------------------------------------------------------------
5      | 0.5218   | 0.3588   | 0.3380   | 0.0304   | 0.0008
10     | -0.0008   | 0.2982   | 0.2832   | 0.0212   | 0.0004
15     | -0.3683   | 0.2374   | 0.2274   | 0.0320   | 0.0006
20     | -0.9718   | 0.4291   | 0.3936   | 0.0173   | 0.0004
>> Stopping early for profiling.

--- TIMING BREAKDOWN (Avg per batch) ---
Sampling (JIT) : 0.3022 sec
Forward Total  : 0.3191 sec
Backward       : 0.0248 sec
Optimizer      : 0.0009 sec
----------------------------------------


0.0014516938300359818