High Memory Usage During Sampling #265

@Wmuntean

Description

I've run into an issue where sampling a large but otherwise straightforward model with nutpie results in unexpectedly high memory usage compared to the numpyro sampler run through PyMC.

Base code:

import numpy as np
import pymc as pm
from scipy.special import expit

N_PERSONS = 60_000
N_ITEMS = 1_000
N_ITEMS_PER_PERSON = 20
RPG = np.random.default_rng(12_1_2025)

# Simulate person abilities and item difficulties
theta_known = RPG.normal(0, 1, N_PERSONS)
beta_known = RPG.normal(0, 1, N_ITEMS)
person_idx = np.repeat(np.arange(N_PERSONS), N_ITEMS_PER_PERSON)

# Each person answers a random subset of items, sampled without replacement
item_idx = np.arange(N_ITEMS)
item_idx = np.array(
    [
        RPG.choice(item_idx, size=N_ITEMS_PER_PERSON, replace=False)
        for _ in np.arange(N_PERSONS)
    ]
).flatten()

# Bernoulli responses under the Rasch model
eta = theta_known[person_idx] - beta_known[item_idx]
score = (RPG.random(N_PERSONS * N_ITEMS_PER_PERSON) < expit(eta)).astype(int)
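
This simulates responses under the standard Rasch form, where person i's probability of answering item j correctly is:

P(Y_ij = 1) = expit(theta_i - beta_j) = 1 / (1 + exp(-(theta_i - beta_j)))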


# ==================================================================
# Baseline Rasch Model
# ==================================================================
def build_rasch(
    person_id: np.ndarray,
    item_id: np.ndarray,
    Y: np.ndarray,
    theta_fixed: np.ndarray,
    coords: dict | None = None,
) -> pm.Model:

    N_obs = len(Y)
    N_persons = len(np.unique(person_id))
    N_items = len(np.unique(item_id))

    if coords is None:
        coords = {
            "obs": np.arange(N_obs),
            "person": np.arange(N_persons),
            "item": np.arange(N_items),
        }

    with pm.Model(coords=coords) as model:
        person_idx = pm.Data("person_idx", person_id, dims="obs")
        item_idx = pm.Data("item_idx", item_id, dims="obs")
        Y_data = pm.Data("Y_data", Y, dims="obs")
        theta = pm.Data("theta_fixed", theta_fixed, dims="person")

        # Item difficulties are the only free parameters (abilities are fixed data)
        beta = pm.Normal("beta", mu=0, sigma=2, dims="item")

        theta_obs = theta[person_idx]
        beta_obs = beta[item_idx]
        # One logit per observation, stored in the trace at every draw
        logit_p = pm.Deterministic("logit_p", theta_obs - beta_obs, dims="obs")
        y_obs = pm.Bernoulli("y_obs", logit_p=logit_p, observed=Y_data, dims="obs")
    return model

rasch_model = build_rasch(
    person_id=person_idx,
    item_id=item_idx,
    Y=score,
    theta_fixed=theta_known,
)
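
For scale: the model has only 1,000 free parameters (beta), but 1.2 million observations, and logit_p is stored in the trace at every draw. A rough back-of-the-envelope sketch of what that alone contributes to the post-sampling footprint:

N_OBS = N_PERSONS * N_ITEMS_PER_PERSON  # 1_200_000 observations
chains, draws = 4, 300

# logit_p stored as float64 for every kept draw in every chain
logit_p_gb = chains * draws * N_OBS * 8 / 1e9
print(f"logit_p in trace: ~{logit_p_gb:.1f} GB")  # ~11.5 GB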

Using numpyro to sample:

with rasch_model:
    t_rasch = pm.sample(
        nuts_sampler="numpyro",
        chains=4,
        draws=300,
        tune=500,
        progress_bar=True,
        nuts_sampler_kwargs={"chain_method": "vectorized"},
    )

Yields the following memory stats:

Before sampling: ~3.9 GB
During sampling: ~5.7 GB
After sampling: ~16 GB
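
A minimal sketch of one way these before/during/after numbers can be read, assuming psutil is installed (the exact measurement method isn't shown here):

import os
import psutil

def rss_gb() -> float:
    """Resident set size of the current process, in GB."""
    return psutil.Process(os.getpid()).memory_info().rss / 1e9

print(f"Before sampling: ~{rss_gb():.1f} GB")
# ... run pm.sample(...), then:
print(f"After sampling: ~{rss_gb():.1f} GB")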

After restarting the notebook to clear memory, sampling with nutpie:

with rasch_model:
    t_rasch = pm.sample(
        nuts_sampler="nutpie",
        chains=4,
        draws=300,
        tune=500,
        progress_bar=True,
        nuts_sampler_kwargs=dict(backend="jax", gradient_backend="jax"),
    )

Yields the following memory stats:

Before sampling: ~3.9 GB
During sampling: ~63 GB
After sampling: ~34 GB

Notes:
Sampling is conducted on a GPU (RTX 5000 Ada) via WSL, running nutpie version 0.16.3.
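
For reproduction, a small sketch for collecting the relevant package versions and confirming the active JAX device (illustrative only):

from importlib.metadata import version
import jax

for pkg in ("pymc", "numpyro", "nutpie", "jax"):
    print(pkg, version(pkg))
print("JAX devices:", jax.devices())  # should list the GPU, e.g. CudaDevice(id=0)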

Let me know if you're able to replicate the issue or need additional info.

Thanks,
