Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimal pymc example #7281

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,140 @@ Features
* Simple extensibility
- Transparent support for missing value imputation


Linear Regression Example
==========================


Plant growth can be influenced by multiple factors, and understanding these relationships is crucial for optimizing agricultural practices.

Imagine we conduct an experiment to predict the growth of a plant based on different environmental variables.

.. code-block:: python

import pymc as pm

# Taking draws from a normal distribution
seed = 42
x_dist = pm.Normal.dist(shape=(100, 3))
x_data = pm.draw(x_dist, random_seed=seed)

# Independent Variables:
# Sunlight Hours: Number of hours the plant is exposed to sunlight daily.
# Water Amount: Daily water amount given to the plant (in milliliters).
# Soil Nitrogen Content: Percentage of nitrogen content in the soil.


# Dependent Variable:
# Plant Growth (y): Measured as the increase in plant height (in centimeters) over a certain period.


# Define coordinate values for all dimensions of the data
coords={
"trial": range(100),
"features": ["sunlight hours", "water amount", "soil nitrogen"],
}

# Define generative model
with pm.Model(coords=coords) as generative_model:
x = pm.Data("x", x_data, dims=["trial", "features"])

# Model parameters
betas = pm.Normal("betas", dims="features")
sigma = pm.HalfNormal("sigma")

# Linear model
mu = x @ betas

# Likelihood
# Assuming we measure deviation of each plant from baseline
plant_growth = pm.Normal("plant growth", mu, sigma, dims="trial")


# Generating data from model by fixing parameters
fixed_parameters = {
"betas": [5, 20, 2],
"sigma": 0.5,
}
with pm.do(generative_model, fixed_parameters) as synthetic_model:
idata = pm.sample_prior_predictive(random_seed=seed) # Sample from prior predictive distribution.
synthetic_y = idata.prior["plant growth (z-scored)"].sel(draw=0, chain=0)


# Infer parameters conditioned on observed data
with pm.observe(generative_model, {"plant growth (z-scored)": synthetic_y}) as inference_model:
idata = pm.sample(random_seed=seed)

summary = pm.stats.summary(idata, var_names=["betas", "sigma"]))
print(summary)


From the summary, we can see that the mean of the inferred parameters are very close to the fixed parameters

===================== ====== ===== ======== ========= =========== ========= ========== ========== =======
Params mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
===================== ====== ===== ======== ========= =========== ========= ========== ========== =======
betas[sunlight hours] 4.972 0.054 4.866 5.066 0.001 0.001 3003 1257 1
betas[water amount] 19.963 0.051 19.872 20.062 0.001 0.001 3112 1658 1
betas[soil nitrogen] 1.994 0.055 1.899 2.107 0.001 0.001 3221 1559 1
sigma 0.511 0.037 0.438 0.575 0.001 0 2945 1522 1
===================== ====== ===== ======== ========= =========== ========= ========== ========== =======

.. code-block:: python

# Simulate new data conditioned on inferred parameters
new_x_data = pm.draw(
pm.Normal.dist(shape=(3, 3)),
random_seed=seed,
)
new_coords = coords | {"trial": [0, 1, 2]}

with inference_model:
pm.set_data({"x": new_x_data}, coords=new_coords)
idata = pm.sample_posterior_predictive(
idata,
predictions=True,
extend_inferencedata=True,
random_seed=seed,
)

pm.stats.summary(idata.predictions, kind="stats")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we show all the summaries outputs? Why only the first?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for starters it's TMI and can scare people off. Convergence diagnostics is more advanced than what we want to demo here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I thought you meant more columns, but you meant more rows?

Copy link
Member

@ricardoV94 ricardoV94 May 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean't every time the have pm.stats.summary, we should show the output. I already removed the extra convergence columns with kind="stats". Right now it's only showing for the first usage


The new data conditioned on inferred parameters would look like:

========================== ====== ===== ======== =========
Output mean sd hdi_3% hdi_97%
========================== ====== ===== ======== =========
plant growth (z-scored)[0] 14.21 0.509 13.232 15.144
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is still the old name ("z-scored").

plant growth (z-scored)[1] 24.43 0.518 23.347 25.32
plant growth (z-scored)[2] -6.743 0.515 -7.778 -5.834
========================== ====== ===== ======== =========

.. code-block:: python

# Simulate new data, under a scenario where the first beta is zero
with pm.do(
inference_model,
{inference_model["betas"]: inference_model["betas"] * [0, 1, 1]},
) as plant_growth_model:
new_predictions = pm.sample_posterior_predictive(
idata,
predictions=True,
random_seed=seed,
)

pm.stats.summary(new_predictions, kind="stats")

The new data, under the above scenario would look like:

========================== ====== ===== ======== =========
Output mean sd hdi_3% hdi_97%
========================== ====== ===== ======== =========
plant growth (z-scored)[0] 14.153 0.509 13.181 15.096
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also needs updated name.

plant growth (z-scored)[1] 23.85 0.517 22.915 24.878
plant growth (z-scored)[2] -7.302 0.515 -8.315 -6.374
========================== ====== ===== ======== =========

Getting started
===============

Expand Down
Loading