I've run into an issue where sampling a (large but straightforward) model with nutpie results in unexpectedly high memory usage compared to sampling the same model with numpyro via PyMC.
Base code:
import numpy as np
import pymc as pm
from scipy.special import expit
N_PERSONS = 60_000
N_ITEMS = 1_000
N_ITEMS_PER_PERSON = 20
RPG = np.random.default_rng(12_1_2025)
theta_known = RPG.normal(0, 1, N_PERSONS)
beta_known = RPG.normal(0, 1, N_ITEMS)
person_idx = np.repeat(np.arange(N_PERSONS), N_ITEMS_PER_PERSON)
item_idx = np.arange(N_ITEMS)
item_idx = np.array(
    [
        RPG.choice(item_idx, size=N_ITEMS_PER_PERSON, replace=False)
        for i in np.arange(N_PERSONS)
    ]
).flatten()
eta = theta_known[person_idx] - beta_known[item_idx]
score = (RPG.random(N_PERSONS * N_ITEMS_PER_PERSON) < expit(eta)) + 0
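As a sanity check on the design logic above, here is the same construction at a much smaller scale (a sketch with toy constants, just confirming each person gets distinct items):

```python
import numpy as np

# Toy-sized version of the constants used above.
n_persons, n_items, n_per_person = 5, 10, 3
rng = np.random.default_rng(0)

person_idx = np.repeat(np.arange(n_persons), n_per_person)
item_idx = np.array(
    [
        rng.choice(np.arange(n_items), size=n_per_person, replace=False)
        for _ in range(n_persons)
    ]
).flatten()

# Each person appears n_per_person times, with no repeated items.
assert person_idx.shape == (n_persons * n_per_person,)
for p in range(n_persons):
    items = item_idx[person_idx == p]
    assert len(set(items.tolist())) == n_per_person
```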
# ==================================================================
# Baseline Rasch Model
# ==================================================================
def build_rasch(
    person_id: np.ndarray,
    item_id: np.ndarray,
    Y: np.ndarray,
    theta_fixed: np.ndarray,
    coords: dict = None,
) -> pm.Model:
    N_obs = len(Y)
    N_persons = len(np.unique(person_id))
    N_items = len(np.unique(item_id))
    if coords is None:
        coords = {
            "obs": np.arange(N_obs),
            "person": np.arange(N_persons),
            "item": np.arange(N_items),
        }
    with pm.Model(coords=coords) as model:
        person_idx = pm.Data("person_idx", person_id, dims="obs")
        item_idx = pm.Data("item_idx", item_id, dims="obs")
        Y_data = pm.Data("Y_data", Y, dims="obs")
        theta = pm.Data("theta_fixed", theta_fixed, dims="person")
        beta = pm.Normal("beta", mu=0, sigma=2, dims="item")
        theta_obs = theta[person_idx]
        beta_obs = beta[item_idx]
        logit_p = pm.Deterministic("logit_p", theta_obs - beta_obs, dims="obs")
        y_obs = pm.Bernoulli("y_obs", logit_p=logit_p, observed=Y_data, dims="obs")
    return model
rasch_model = build_rasch(
    person_id=person_idx,
    item_id=item_idx,
    Y=score,
    theta_fixed=theta_known,
)

Using numpyro to sample:
with rasch_model:
    t_rasch = pm.sample(
        nuts_sampler="numpyro",
        chains=4,
        draws=300,
        tune=500,
        progress_bar=True,
        nuts_sampler_kwargs={"chain_method": "vectorized"},
    )

Yields the following stats:
Mem:
Before sampling: ~3.9 GB
During sampling: ~5.7 GB
After sampling: ~16 GB
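For scale: because `logit_p` is wrapped in `pm.Deterministic` over the full `obs` dimension, every kept draw carries a 1.2M-element vector in the trace. A back-of-envelope sketch (pure arithmetic, assuming float64 storage and only the kept draws):

```python
# Storage implied by keeping `logit_p` in the posterior trace.
N_OBS = 60_000 * 20   # N_PERSONS * N_ITEMS_PER_PERSON observations
CHAINS = 4
DRAWS = 300
BYTES_PER_FLOAT = 8   # float64

logit_p_gb = N_OBS * CHAINS * DRAWS * BYTES_PER_FLOAT / 1e9
print(f"logit_p alone in the posterior: ~{logit_p_gb:.1f} GB")
# → logit_p alone in the posterior: ~11.5 GB
```

I haven't confirmed the exact breakdown of the post-sampling number, but this term alone is in that ballpark.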
After restarting the notebook and clearing memory, running nutpie to sample:
with rasch_model:
    t_rasch = pm.sample(
        nuts_sampler="nutpie",
        chains=4,
        draws=300,
        tune=500,
        progress_bar=True,
        nuts_sampler_kwargs=dict(backend="jax", gradient_backend="jax"),
    )

Yields the following stats:
Mem:
Before sampling: ~3.9 GB
During sampling: ~63 GB
After sampling: ~34 GB
Notes:
Sampling is conducted on GPU (RTX 5000 Ada) via WSL, running nutpie version 0.16.3.
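The memory figures above were read from the system monitor; for anyone replicating, here is a minimal sketch for capturing host-side usage programmatically, assuming `psutil` is installed (this captures CPU RSS only; GPU memory would need `nvidia-smi` or similar):

```python
import os

import psutil  # assumption: installed via `pip install psutil`


def rss_gb() -> float:
    """Host-side resident set size of the current process, in GB."""
    return psutil.Process(os.getpid()).memory_info().rss / 1e9


print(f"RSS before sampling: ~{rss_gb():.2f} GB")
# ... pm.sample(...) here ...
print(f"RSS after sampling: ~{rss_gb():.2f} GB")
```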
Let me know if you're able to replicate the issue or need additional info.
Thanks,