In [None]:
def linear_mean_variance_hurdle_gamma_model(
    x: np.ndarray,
    y_pred: np.ndarray,
    y: np.ndarray,
    prior_sd: float = 0.1,
) -> pm.Model:
    """Build a hurdle-gamma regression of observed on predicted values.

    All linear predictors use the raw (unstandardized) inputs:

    * ``psi   = invlogit(intercept_p + beta_p_y * y_pred + beta_p_x * x)`` -- P(y > 0)
    * ``mu    = 1e-6 + exp(intercept_mu + beta_mu * y_pred)`` -- mean of the positive part
    * ``sigma = 1e-6 + exp(intercept_sigma + beta_sigma * x)`` -- sd of the positive part

    Parameters
    ----------
    x : np.ndarray
        Covariate entering P(y > 0) and the dispersion (days_out in this notebook).
    y_pred : np.ndarray
        Forecast values entering P(y > 0) and the mean.
    y : np.ndarray
        Observed outcomes (zeros allowed; handled by the hurdle).
    prior_sd : float, default 0.1
        Standard deviation of the Normal(0, prior_sd) prior on every intercept and
        slope. NOTE(review): inputs are not standardized, so one prior_sd is applied
        to coefficients of differently scaled covariates -- revisit if that matters.

    Returns
    -------
    pm.Model
        Model with data containers ``"x_data"``, ``"y_data"``, ``"y_pred_data"``
        and the observed ``"y_obs"`` HurdleGamma likelihood.
    """
    with pm.Model() as model:
        # Register the inputs as named data containers.
        x_data = pm.Data("x_data", x)
        y_data = pm.Data('y_data', y)
        y_pred_data = pm.Data("y_pred_data", y_pred)

        # P(Y > 0): logistic in the forecast and in x.
        intercept_p = pm.Normal("intercept_p", 0, prior_sd)
        beta_p_y = pm.Normal("beta_p_y", 0, prior_sd)
        beta_p_x = pm.Normal('beta_p_x', 0, prior_sd)
        p_nonzero = pm.math.invlogit(intercept_p + beta_p_y * y_pred_data + beta_p_x * x_data)

        # Mean of the positive part: log-linear in the forecast.
        # The 1e-6 floor keeps mu strictly positive.
        intercept_mu = pm.Normal("intercept_mu", 0, prior_sd)
        beta_mu = pm.Normal("beta_mu", 0, prior_sd)
        y_mu = 1e-6 + pm.math.exp(intercept_mu + beta_mu * y_pred_data)

        # Sd of the positive part: log-linear in x, floored the same way.
        intercept_sigma = pm.Normal("intercept_sigma", 0, prior_sd)
        beta_sigma = pm.Normal("beta_sigma", 0, prior_sd)
        y_sigma = 1e-6 + pm.math.exp(intercept_sigma + beta_sigma * x_data)

        pm.HurdleGamma(
            "y_obs",
            psi=p_nonzero,
            mu=y_mu,
            sigma=y_sigma,
            observed=y_data,
        )

    return model

In [None]:
x = model_data['days_out'].to_numpy()
y_pred = model_data['precip_pred'].to_numpy()
y = model_data['precip_obs'].to_numpy()

weather_gamma_model = linear_mean_variance_hurdle_gamma_model(x, y_pred, y)

# Sanity-check the model's initial point before sampling (see PyMC `Model.debug`).
# Sampling happens in the next cell.
weather_gamma_model.debug()

In [None]:
# Fixed sampler seed so Restart & Run All reproduces the same posterior draws.
SAMPLING_SEED = 42

with weather_gamma_model:
    idata_gamma = pm.sample(random_seed=SAMPLING_SEED)

In [None]:
import numpy as np
from scipy.stats import gamma
from scipy.special import expit
import arviz as az

class HurdleGammaDistribution:
    """Hurdle-gamma distribution: a point mass at zero plus a gamma for positive values.

    With probability ``1 - psi`` an outcome is exactly 0; otherwise it follows
    ``Gamma(shape=alpha, scale=theta)``. ``psi``, ``alpha`` and ``theta`` may be
    scalars or numpy arrays that broadcast against each other (and against ``x``
    in :meth:`pdf`).
    """

    def __init__(self, psi, alpha, theta):
        # psi: probability of a non-zero outcome; alpha/theta: gamma shape/scale.
        self.psi = np.asarray(psi)
        self.alpha = np.asarray(alpha)
        self.theta = np.asarray(theta)

    def pdf(self, x):
        """Evaluate the mixed density at ``x``.

        Returns ``1 - psi`` where ``x == 0`` (mass of the zero atom),
        ``psi * gamma.pdf(x; alpha, theta)`` where ``x > 0``, and 0 elsewhere
        (negative or NaN ``x``). Parameters are broadcast element-wise against
        ``x``, so array-valued parameters with the same shape as ``x`` work.
        """
        x = np.asarray(x)
        # Broadcast parameters and x to one common shape so the same boolean
        # masks can index all of them consistently.
        psi, alpha, theta, x_b = np.broadcast_arrays(self.psi, self.alpha, self.theta, x)
        out = np.zeros(x_b.shape, dtype=float)

        is_zero = x_b == 0
        out[is_zero] = 1.0 - psi[is_zero]

        pos = x_b > 0
        if np.any(pos):
            out[pos] = psi[pos] * gamma.pdf(x_b[pos], a=alpha[pos], scale=theta[pos])
        return out

    def rvs(self, size=None, random_state=None):
        """Draw samples from the hurdle-gamma distribution.

        Parameters
        ----------
        size : None, int or tuple of ints
            ``None`` draws one sample per broadcast parameter location; an int
            ``n`` gives shape ``(n,)``. Parameters must broadcast to the shape.
        random_state : None, int or np.random.Generator
            Passed to ``np.random.default_rng``.

        Returns
        -------
        np.ndarray of float
            Zeros where the Bernoulli hurdle failed, gamma draws elsewhere.
        """
        rng = np.random.default_rng(random_state)
        if size is None:
            shape = np.broadcast(self.psi, self.alpha, self.theta).shape
        elif np.isscalar(size):
            # tuple(n) raises TypeError for an int; treat it as a 1-D sample size.
            shape = (size,)
        else:
            shape = tuple(size)
        # Bernoulli hurdle: True where the outcome is non-zero.
        psi_b = np.broadcast_to(self.psi, shape)
        mask = rng.random(size=shape) < psi_b
        samples = np.zeros(shape, dtype=float)
        if np.any(mask):
            a = np.broadcast_to(self.alpha, shape)[mask]
            scale = np.broadcast_to(self.theta, shape)[mask]
            samples[mask] = gamma.rvs(a=a, scale=scale, random_state=rng, size=a.shape)
        return samples

def linear_hurdle_gamma_distribution(idata, y_pred, x):
    """
    Build a hurdle-gamma predicted distribution using posterior means from idata.

    This is a plug-in estimate: each coefficient is replaced by its posterior mean
    over all chains and draws, so parameter uncertainty is not propagated.
    y_pred and x may be scalars or arrays (they are broadcast against each other).
    ``beta_p_x`` is optional and treated as 0 when absent from the posterior; every
    other coefficient must be present (indexing the posterior raises otherwise).
    Returns a HurdleGammaDistribution instance with methods pdf(...) and rvs(...).
    """
    posterior = idata.posterior

    def _posterior_mean(name):
        # Full-precision mean over every dimension. Computed directly instead of via
        # az.summary, which rounds its reported statistics by default (round_to) and
        # also computes diagnostics that are not needed here.
        return float(posterior[name].mean())

    intercept_p = _posterior_mean("intercept_p")
    beta_p_y = _posterior_mean("beta_p_y")
    beta_p_x = _posterior_mean("beta_p_x") if "beta_p_x" in posterior else 0.0

    intercept_mu = _posterior_mean("intercept_mu")
    beta_mu = _posterior_mean("beta_mu")

    intercept_sigma = _posterior_mean("intercept_sigma")
    beta_sigma = _posterior_mean("beta_sigma")

    y_pred = np.asarray(y_pred)
    x = np.asarray(x)

    # Same linear predictors as linear_mean_variance_hurdle_gamma_model (raw inputs).
    psi = expit(intercept_p + beta_p_y * y_pred + beta_p_x * x)
    mu = 1e-6 + np.exp(intercept_mu + beta_mu * y_pred)
    sigma = 1e-6 + np.exp(intercept_sigma + beta_sigma * x)

    # Gamma with mean mu and sd sigma: shape = (mu/sigma)^2, scale = sigma^2/mu.
    alpha = (mu / sigma) ** 2
    theta = (sigma ** 2) / mu

    return HurdleGammaDistribution(psi=psi, alpha=alpha, theta=theta)

In [None]:
def predict_hurdle_gamma(
    idata,
    x_new: np.ndarray,
    y_pred_new: np.ndarray,
    x_train: np.ndarray,
    y_pred_train: np.ndarray,
    n_draws: int | None = None,
    random_seed: int | None = None,
):
    """
    Manual posterior predictive for the linear_mean_variance_hurdle_gamma_model.
    """
    rng = np.random.default_rng(random_seed)
    post = idata.posterior

    def _flatten(name):
        v = post[name].values  # (chain, draw, ...)
        return v.reshape(-1, *v.shape[2:])

    intercept_p = _flatten("intercept_p")       # (S,)
    beta_p_y = _flatten("beta_p_y")
    # beta_p_x = _flatten("beta_p_x")
    intercept_mu = _flatten("intercept_mu")
    beta_mu = _flatten("beta_mu")
    intercept_sigma = _flatten("intercept_sigma")
    beta_sigma = _flatten("beta_sigma")

    S_total = intercept_p.shape[0]
    if n_draws is not None and n_draws < S_total:
        idx = rng.choice(S_total, size=n_draws, replace=False)
        intercept_p = intercept_p[idx]
        beta_p_y = beta_p_y[idx]
        # beta_p_x = beta_p_x[idx]
        intercept_mu = intercept_mu[idx]
        beta_mu = beta_mu[idx]
        intercept_sigma = intercept_sigma[idx]
        beta_sigma = beta_sigma[idx]

    x_mean = x_train.mean()
    x_sd = x_train.std()
    y_pred_mean = y_pred_train.mean()
    y_pred_sd = y_pred_train.std()

    x_s = x_new#(x_new - x_mean) / x_sd
    y_pred_s = y_pred_new# (y_pred_new - y_pred_mean) / y_pred_sd

    x_s = np.atleast_1d(x_s)[None, :]         # (1, N_new)
    y_pred_s = np.atleast_1d(y_pred_s)[None, :]

    intercept_p = intercept_p[:, None]        # (S, 1)
    beta_p_y = beta_p_y[:, None]
    # beta_p_x = beta_p_x[:, None]
    intercept_mu = intercept_mu[:, None]
    beta_mu = beta_mu[:, None]
    intercept_sigma = intercept_sigma[:, None]
    beta_sigma = beta_sigma[:, None]

    lin_p = intercept_p + beta_p_y * y_pred_s # + beta_p_x * x_s
    psi = expit(lin_p)

    eta_mu = intercept_mu + beta_mu * y_pred_s
    mu = 1e-6 + np.exp(eta_mu)

    eta_sigma = intercept_sigma + beta_sigma * x_s
    sigma = 1e-6 + np.exp(eta_sigma)

    alpha = (mu / sigma) ** 2
    theta = (sigma**2) / mu

    bern = rng.binomial(1, psi)
    gamma_samples = rng.gamma(shape=alpha, scale=theta)

    y_ppc = bern * gamma_samples

    return y_ppc, psi, mu, sigma


In [None]:
max_pred_days_out = 3
y_pred_new = np.repeat(5.0, max_pred_days_out)
x_new = np.arange(max_pred_days_out)

# squeeze=False keeps `axes` an array even when only one panel is requested.
fig, axes = plt.subplots(1, len(x_new), figsize=(12, 4), squeeze=False)
axes = axes.ravel()
for ax, x_val, y_pred_val in zip(axes, x_new, y_pred_new):
    # plot_observed_vs_predicted takes no axes argument here, so presumably it
    # draws on the current axes -- make this panel current first.
    plt.sca(ax)
    plot_observed_vs_predicted(
        idata_gamma,
        model_data,
        y_pred=float(y_pred_val),
        x=int(x_val),
        bw=2.5,
        title=f"Days Out: {x_val}, Predicted Precipitation: {y_pred_val} mm",
        dist_func=linear_hurdle_gamma_distribution
    )

# align y-axis across subplots
ymax = max(ax.get_ylim()[1] for ax in axes)
for ax in axes:
    ax.set_ylim(0, ymax)

plt.tight_layout()

In [None]:
max_pred_days_out = 3
y_pred_new = np.repeat(5.0, max_pred_days_out)
x_new = np.arange(max_pred_days_out)

# Seed the posterior-predictive draws so the comparison plots below are reproducible.
PPC_SEED = 42

y_ppc, psi, mu, sigma = predict_hurdle_gamma(
    idata_gamma,
    x_new=x_new,
    y_pred_new=y_pred_new,
    x_train=x,
    y_pred_train=y_pred,
    random_seed=PPC_SEED,
)

In [None]:
# Presumably the band of predicted precipitation (mm) whose observations are
# compared against the simulated draws -- TODO confirm against
# build_pred_actual_precip_df.
PRED_MIN_MM = 3
PRED_MAX_MM = 7

full_data = build_pred_actual_precip_df(
    model_data=model_data,
    x_new=x_new,
    y_ppc=y_ppc,
    pred_min=PRED_MIN_MM,
    pred_max=PRED_MAX_MM,
)

In [None]:
# squeeze=False keeps `axes` indexable even for a single panel.
fig, axes = plt.subplots(1, len(x_new), figsize=(12, 5), squeeze=False)
# Pair each panel with its days-out value rather than indexing axes by the value
# itself, which only works when x_new is exactly 0, 1, 2, ...
for ax, days_out in zip(axes.ravel(), x_new):
    plot_model_vs_actual_distributions(full_data, days_out, ax)
    ax.set_title(f'{days_out} Days Out')