In [2]:
import sys
from pathlib import Path

# Add parent directory (case_studies/) to sys.path to discover restricted_boltzmann_machine
sys.path.append(str(Path.cwd().parent))

In [3]:
from pathlib import Path
import pickle, itertools, math, collections
import numpy as np
import jax, jax.numpy as jnp
import optax
from jax.random import PRNGKey

from rbm.rbm import RBM
from rbm.pcd_trainer import RBMTrainState, train_rbm
from rbm.cosine_annealing_sampler import get_cosine_schedule

DATA_DIR   = Path("./data")
MODELS_DIR = Path("./models"); MODELS_DIR.mkdir(exist_ok=True)

FILE = DATA_DIR / "w_vanilla_20_20000.txt"
assert FILE.exists(), f"{FILE} missing – run the generator first."

print("JAX devices:", jax.devices())

JAX devices: [CpuDevice(id=0)]


In [4]:
# --- helper to map 'Z','z'  -> 1,0 -------------------------------------------
def to_binary(line: str) -> np.ndarray:
    return np.fromiter((c == "Z" for c in line.strip()), dtype=np.float32, count=20)

with open(FILE) as f:
    data = np.stack([to_binary(l) for l in f])

N_TOTAL, N_VIS = data.shape
print(f"Loaded {N_TOTAL} samples of length {N_VIS}.")

# quick inspection: show first 5 strings and site-occupancy histogram ----------
print("First 5 samples (as 0/1 arrays):\n", data[:5])

occ = data.sum(0)            # how often each spin is 'up'
print("Site-occupancy counts (should all be ≈ N_TOTAL / N_VIS):")
print(occ.astype(int))

Loaded 20000 samples of length 20.
First 5 samples (as 0/1 arrays):
 [[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
Site-occupancy counts (should all be ≈ N_TOTAL / N_VIS):
[ 975  989 1003 1021  991 1017  997  985 1021 1005  998 1000  989 1029
  985  965 1026 1030  972 1002]


In [10]:
def make_loader(arr: np.ndarray, batch_size: int, rng_key):
    """Yield shuffled mini‐batches of shape (batch_size, N_VIS) as (data, None)."""
    idx = jax.random.permutation(rng_key, len(arr))
    arr = arr[idx]
    for i in range(0, len(arr), batch_size):
        batch = arr[i:i+batch_size]
        yield (batch, None)

In [11]:
# ----- global hyper-parameters ------------------------------------------------
N_HID        = 20          # M = N
K_STEPS      = 1           # PCD-1
BATCH_SIZE   = 128
LR           = 1e-3
N_EPOCHS     = 40
PCD_RESET    = 75
WEIGHT_DECAY = 1e-5
LR_DECAY     = 0.95

def train_rbm_subset(subset_size: int, rng_key):
    """Return (state, rng_key) after training on a random subset."""
    sk1, sk2 = jax.random.split(rng_key)
    subset   = jax.random.choice(sk1, data, (subset_size,), replace=False)
    loader   = list(make_loader(subset, BATCH_SIZE, sk2))

    # ---- scheduler & optimiser
    lr_sched = optax.exponential_decay(LR, len(loader),
                                       decay_rate=LR_DECAY, staircase=True)

    # optax.adam() has no weight_decay kwarg → apply L2 via add_decayed_weights
    opt      = optax.chain(
        optax.add_decayed_weights(WEIGHT_DECAY),
        optax.adam(lr_sched)
    )

    rbm      = RBM(n_visible=N_VIS, n_hidden=N_HID, k=K_STEPS)
    dummy    = jnp.ones((BATCH_SIZE, N_VIS), dtype=jnp.float32)
    rng_key, sk = jax.random.split(rng_key)
    v_pers   = jax.random.bernoulli(sk, 0.5, shape=dummy.shape).astype(jnp.float32)
    params   = rbm.init(rng_key, dummy, v_pers, rng_key)["params"]

    state    = RBMTrainState.create(apply_fn=rbm.apply, params=params, tx=opt)

    state, metrics, rng_key = train_rbm(state,
                                        loader,
                                        N_EPOCHS,
                                        rng_key,
                                        pcd_reset=PCD_RESET,
                                        scheduler=lr_sched)
    return state, rng_key

In [12]:
rng = PRNGKey(0)

states = {}
for ns in (50, 1_000, 20_000):
    print(f"\n=== Training on N_s = {ns} ===")
    state, rng = train_rbm_subset(ns, rng)
    states[ns] = state
print("\nFinished all trainings.")


=== Training on N_s = 50 ===
Epoch [1/40] – FE-loss: 0.0196
Epoch [2/40] – FE-loss: -0.0958
Epoch [3/40] – FE-loss: -0.1614
Epoch [4/40] – FE-loss: -0.2887
Epoch [5/40] – FE-loss: -0.3620
Epoch [6/40] – FE-loss: -0.4129
Epoch [7/40] – FE-loss: -0.5079
Epoch [8/40] – FE-loss: -0.5443
Epoch [9/40] – FE-loss: -0.6414
Epoch [10/40] – FE-loss: -0.6745
Epoch [11/40] – FE-loss: -0.7724
Epoch [12/40] – FE-loss: -0.8179
Epoch [13/40] – FE-loss: -0.7842
Epoch [14/40] – FE-loss: -0.8428
Epoch [15/40] – FE-loss: -0.8816
Epoch [16/40] – FE-loss: -1.0139
Epoch [17/40] – FE-loss: -1.0035
Epoch [18/40] – FE-loss: -0.9721
Epoch [19/40] – FE-loss: -1.0556
Epoch [20/40] – FE-loss: -1.1099
Epoch [21/40] – FE-loss: -1.0768
Epoch [22/40] – FE-loss: -1.1358
Epoch [23/40] – FE-loss: -1.1823
Epoch [24/40] – FE-loss: -1.2042
Epoch [25/40] – FE-loss: -1.2695
Epoch [26/40] – FE-loss: -1.2569
Epoch [27/40] – FE-loss: -1.2001
Epoch [28/40] – FE-loss: -1.2535
Epoch [29/40] – FE-loss: -1.2973
Epoch [30/40] – FE-loss

In [13]:
def save_state(state, subset_size: int):
    name = f"rbm_w_vanilla_20_{subset_size}.pkl"
    path = MODELS_DIR / name
    with open(path, "wb") as f:
        pickle.dump({"params": state.params,
                     "visible": N_VIS,
                     "hidden": N_HID}, f)
    print("✓ saved", path)

for ns, st in states.items():
    save_state(st, ns)

✓ saved models/rbm_w_vanilla_20_50.pkl
✓ saved models/rbm_w_vanilla_20_1000.pkl
✓ saved models/rbm_w_vanilla_20_20000.pkl
