In [1]:
import time
import numpy as np
from math import lgamma
import statsmodels.api as sm
from statsmodels.gam.api import GLMGam, BSplines
from statsmodels.genmod.families import Poisson, Binomial, NegativeBinomial
from sklearn.base import BaseEstimator, RegressorMixin
from tqdm import tqdm


class ZIGAMEstimator(BaseEstimator, RegressorMixin):
    def __init__(self, dist='nb', df=None, degree=3, alpha=None, max_iter=100, tol=1e-6):
        """
        Zero-Inflated Generalized Additive Model (ZI-GAM) Estimator.

        Parameters
        ----------
        dist : str
            Distribution for the count process: 'poisson' or 'nb' (negative binomial).
            Negative binomial ('nb') is the default.
        df : int or list, optional
            Degrees of freedom (number of spline basis functions) per feature.
            If None, a default is chosen based on sample size.
        degree : int or list, optional
            Degree(s) of the B-spline basis (default is 3 for cubic splines).
        alpha : float or list, optional
            Smoothing penalty for the spline terms. If None, a default value is used.
        max_iter : int
            Maximum number of EM iterations.
        tol : float
            Convergence tolerance on the change in log-likelihood.
        """
        self.dist = dist.lower()
        self.df = df
        self.degree = degree
        self.alpha = alpha
        self.max_iter = max_iter
        self.tol = tol

        # Attributes set during fitting
        self._bs_count = None      # Spline basis for count component
        self._bs_infl = None       # Spline basis for inflation component
        self.logistic_res_ = None  # Fitted GLMGam result for logistic (structural zero) model
        self.count_res_ = None     # Fitted GLMGam result for count model
        self.dispersion_ = None    # Fitted dispersion parameter (only for NB)
        self.n_iter_ = 0           # Number of EM iterations performed
        self.loglikelihood_ = None # Final log-likelihood value
        self._r = None             # Internal storage of responsibilities (P(structural zero))

    def _initialize_basis(self, X):
        """Create B-spline basis objects for both count and inflation components."""
        X = np.asarray(X)
        n_samples, n_features = X.shape

        # Determine degrees of freedom per feature
        if self.df is None:
            df = [min(10, max(4, n_samples // 5)) for _ in range(n_features)]
        elif isinstance(self.df, (int, float)):
            df = [int(self.df)] * n_features
        else:
            df = list(self.df)
            if len(df) != n_features:
                raise ValueError("Length of df list must equal number of features.")

        # Determine spline degree per feature
        if isinstance(self.degree, (int, float)):
            degree = [int(self.degree)] * n_features
        else:
            degree = list(self.degree)
            if len(degree) != n_features:
                raise ValueError("Length of degree list must equal number of features.")

        # Create BSplines objects (both components use the same basis)
        self._bs_count = BSplines(X, df=df, degree=degree, include_intercept=False)
        self._bs_infl = BSplines(X, df=df, degree=degree, include_intercept=False)

    def _initialize_parameters(self, X, y):
        """
        Perform an initial fit to obtain a starting value for the EM algorithm.

        This involves:
          - Fitting an initial Poisson GAM (ignoring zero inflation) to obtain a baseline μ.
          - Estimating the overall excess zero rate.
          - Initializing the responsibilities r (P(structural zero | y = 0)).
        """
        n_samples = len(y)
        init_exog = sm.add_constant(np.zeros((n_samples, 1)))

        try:
            init_model = GLMGam(
                y,
                exog=init_exog,
                smoother=self._bs_count,
                alpha=np.zeros(len(self._bs_count.penalty_matrices)),
                family=Poisson()
            )
        except Exception:
            init_model = GLMGam(
                y,
                exog=init_exog,
                smoother=self._bs_count,
                alpha=np.full(len(self._bs_count.penalty_matrices), 1e-6),
                family=Poisson()
            )

        init_res = init_model.fit()
        mu_init = init_res.predict()
        p_zero_pois = np.exp(-mu_init)
        p_zero_obs = np.mean(y == 0)
        avg_p_zero_pois = np.mean(p_zero_pois)

        if p_zero_obs > avg_p_zero_pois:
            pi_init = (p_zero_obs - avg_p_zero_pois) / (1 - avg_p_zero_pois)
        else:
            pi_init = 0.0

        pi_init = np.clip(pi_init, 0.0, 0.99)
        r = np.zeros(n_samples)
        r[y == 0] = pi_init
        r[y > 0] = 0.0
        self._r = r.copy()
        return mu_init

    def _e_step(self, X, y):
        """
        E-step: Update responsibilities for zero observations.

        For each y_i == 0, compute:
            r_i = π(X_i) / [π(X_i) + (1-π(X_i)) * P_count(0 | μ(X_i))].
        """
        n_samples = len(y)
        pi_pred = self.logistic_res_.predict()
        mu_pred = self.count_res_.predict()

        if self.dist.startswith('nb'):
            p0_count = np.power(1.0 + self.dispersion_ * mu_pred, -1.0 / self.dispersion_)
        else:
            p0_count = np.exp(-mu_pred)

        r_new = np.zeros(n_samples)

        for i in range(n_samples):
            if y[i] == 0:
                denom = pi_pred[i] + (1 - pi_pred[i]) * p0_count[i]
                r_new[i] = pi_pred[i] / denom if denom > 0 else 0.0
            else:
                r_new[i] = 0.0
        self._r = r_new.copy()

    def _fit_logistic(self, X, r):
        """
        M-step: Fit the logistic GAM for the structural zero (inflation) component.

        This uses a Binomial family with a logit link.
        """
        n_samples = len(r)
        Exog = sm.add_constant(np.zeros((n_samples, 1)))
        logistic_model = GLMGam(
            r,
            exog=Exog,
            smoother=self._bs_infl,
            alpha=np.ones(len(self._bs_infl.penalty_matrices)) * 1.0,
            family=Binomial()
        )
        logistic_res = logistic_model.fit()
        return logistic_res

    def _fit_count(self, X, y, r):
        """
        M-step: Fit the GAM for the count component using weights (1 - r).

        For negative binomial, the current dispersion parameter is used.
        """
        n_samples = len(y)
        Exog = sm.add_constant(np.zeros((n_samples, 1)))
        weights = 1 - r

        if self.dist.startswith('nb'):
            nb_family = NegativeBinomial(alpha=self.dispersion_ if self.dispersion_ is not None else 1.0)
            count_model = GLMGam(
                y,
                exog=Exog,
                smoother=self._bs_count,
                alpha=np.ones(len(self._bs_count.penalty_matrices)) * 1.0,
                family=nb_family
            )
        else:
            count_model = GLMGam(
                y,
                exog=Exog,
                smoother=self._bs_count,
                alpha=np.ones(len(self._bs_count.penalty_matrices)) * 1.0,
                family=Poisson()
            )

        count_res = count_model.fit(weights=weights)
        return count_res

    def _update_dispersion(self, y, mu_pred, weights):
        """
        Update the dispersion parameter for the negative binomial distribution using a Pearson estimator.
        """
        resid = y - mu_pred
        num = np.sum(weights * ((resid**2 / np.maximum(mu_pred, 1e-8)) - 1))
        den = np.sum(weights * mu_pred)
        new_dispersion = max(1e-8, num / np.maximum(den, 1e-8))
        return new_dispersion

    def _compute_loglik(self, y, logistic_res, count_res, dispersion):
        """
        Compute the log-likelihood for the mixture model.
        """
        n_samples = len(y)
        pi_pred = logistic_res.predict()
        mu_pred = count_res.predict()
        log_pmf = np.zeros(n_samples)

        if self.dist.startswith('nb'):
            size = 1.0 / dispersion
            for i in range(n_samples):
                y_i = int(y[i])
                if y_i == 0:
                    log_pmf[i] = size * np.log(size / (size + mu_pred[i]))
                else:
                    log_comb = lgamma(y_i + size) - lgamma(size) - lgamma(y_i + 1)
                    log_p = size * np.log(size / (size + mu_pred[i]))
                    log_q = y_i * np.log(mu_pred[i] / (size + mu_pred[i]))
                    log_pmf[i] = log_comb + log_p + log_q
        else:
            for i in range(n_samples):
                y_i = int(y[i])
                log_pmf[i] = -mu_pred[i] + y_i * (np.log(mu_pred[i]) if mu_pred[i] > 0 else -np.inf) - lgamma(y_i + 1)

        pmf = np.exp(log_pmf)
        lik = np.where(y == 0, pi_pred + (1 - pi_pred) * pmf, (1 - pi_pred) * pmf)
        loglik = np.sum(np.log(np.clip(lik, 1e-12, None)))
        return loglik

    def fit(self, X, y):
        """
        Fit the Zero-Inflated GAM model using an EM algorithm.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Predictor variables.
        y : array-like, shape (n_samples,)
            Count response variable.

        Returns
        -------
        self : object
            The fitted model.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        n_samples, n_features = X.shape

        # Initialize spline bases
        self._initialize_basis(X)
        # Initialize parameters and responsibilities
        self._initialize_parameters(X, y)
        # Initialize dispersion if using Negative Binomial
        if self.dist.startswith('nb'):
            self.dispersion_ = 1.0
        else:
            self.dispersion_ = 0.0

        loglik_old = -np.inf
        start_time = time.time()
        print("Starting EM iterations...")
        pbar = tqdm(total=self.max_iter, desc="EM iterations", ncols=80)

        for iteration in range(self.max_iter):
            # M-step: fit logistic and count GAMs
            self.logistic_res_ = self._fit_logistic(X, self._r)
            self.count_res_ = self._fit_count(X, y, self._r)
            # E-step: update responsibilities
            self._e_step(X, y)
            # Update dispersion for NB
            if self.dist.startswith('nb'):
                mu_pred = self.count_res_.predict()
                weights = 1 - self._r
                self.dispersion_ = self._update_dispersion(y, mu_pred, weights)
            # Compute log-likelihood
            loglik = self._compute_loglik(y, self.logistic_res_, self.count_res_, self.dispersion_)
            elapsed = time.time() - start_time
            pbar.set_description(f"Iter {iteration+1:3d}: logL = {loglik:.4f}, {elapsed:.2f}s")
            pbar.update(1)
            if loglik - loglik_old < self.tol:
                pbar.set_description("Convergence reached.")
                break
            loglik_old = loglik

        pbar.close()
        self.n_iter_ = iteration + 1
        self.loglikelihood_ = loglik_old
        total_time = time.time() - start_time
        print(f"EM algorithm converged after {self.n_iter_} iterations in {total_time:.2f} seconds.")
        return self

    def predict(self, X):
        """
        Predict the expected count for each observation using
        E[Y] = (1 - π(X)) * μ(X).

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)

        Returns
        -------
        y_pred : array, shape (n_samples,)
            Predicted expected counts.
        """
        X = np.asarray(X)
        n_samples = X.shape[0]
        # For new data, construct the exogenous variable as in training
        new_exog = sm.add_constant(np.zeros((n_samples, 1)))
        # Use the smoother part with new data via exog_smooth
        pi_pred = self.logistic_res_.predict(exog=new_exog, exog_smooth=X)
        mu_pred = self.count_res_.predict(exog=new_exog, exog_smooth=X)
        return (1 - pi_pred) * mu_pred

    def score(self, X, y):
        """
        Compute the average log-likelihood per observation on the provided data.

        This likelihood-based score is more appropriate for count models,
        including zero-inflated models, because it measures the probability
        of observing the data given the fitted model parameters.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Predictor variables.
        y : array-like, shape (n_samples,)
            Observed counts.

        Returns
        -------
        score : float
            Average log-likelihood per observation. Higher values (i.e. less negative)
            indicate a better fit.
        """
        y = np.asarray(y)
        # Compute total log-likelihood on the provided data using the current parameters
        total_loglik = self._compute_loglik(y, self.logistic_res_, self.count_res_, self.dispersion_)
        # Return average log likelihood per sample
        return total_loglik / len(y)

    def get_params(self, deep=True):
        """
        Return the parameters of the estimator.

        Parameters
        ----------
        deep : bool, default=True
            If True, will return the parameters for this estimator and contained subobjects.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        return {
            'dist': self.dist,
            'df': self.df,
            'degree': self.degree,
            'alpha': self.alpha,
            'max_iter': self.max_iter,
            'tol': self.tol
        }

    def set_params(self, **parameters):
        """
        Set the parameters of the estimator.

        Returns
        -------
        self : object
            Returns self.
        """
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

In [2]:
def generate_dataset(n_samples=500, seed=42):
    """
    Generate a synthetic zero-inflated count dataset with one predictor.

    The count component follows a negative binomial distribution with true mean
    μ(X) = exp(1.5 + 0.5*sin(X)) and the structural zero probability is given by
    π(X) = logistic(-1.0 + 0.5*cos(2*X)).

    The negative binomial sampling is performed using:
        r = 1/alpha (with alpha set to 0.5) and p = r / (r + μ).

    Parameters
    ----------
    n_samples : int
        Number of samples to generate.
    seed : int
        Random seed.

    Returns
    -------
    X : array-like, shape (n_samples, 1)
        The predictor values.
    y : array-like, shape (n_samples,)
        The generated counts.
    """
    np.random.seed(seed)
    X = np.linspace(0, 10, n_samples)[:, None]
    f_count = 1.5 + 0.5 * np.sin(X).ravel()     # affects log(μ)
    mu_true = np.exp(f_count)
    f_infl = -1.0 + 0.5 * np.cos(2 * X).ravel()   # affects logit(π)
    pi_true = 1 / (1 + np.exp(-f_infl))

    alpha_true = 0.5      # dispersion parameter
    r = 1.0 / alpha_true  # number of successes
    y = np.empty(n_samples, dtype=int)
    for i in range(n_samples):
        if np.random.rand() < pi_true[i]:
            y[i] = 0
        else:
            p = r / (r + mu_true[i])
            y[i] = np.random.negative_binomial(r, p)
    return X, y

In [3]:
# Generate the dataset
X, y = generate_dataset(n_samples=10000)

In [4]:
# Instantiate and fit the ZIGAMEstimator
model = ZIGAMEstimator(dist='nb', max_iter=100, tol=1e-4)
model.fit(X, y)
print(f"EM iterations: {model.n_iter_}")
print(f"Final log-likelihood: {model.loglikelihood_:.2f}")
print("Average log-likelihood:", model.score(X, y))

Starting EM iterations...


Convergence reached.:  15%|███                 | 15/100 [00:19<01:53,  1.33s/it]

EM algorithm converged after 15 iterations in 19.99 seconds.
EM iterations: 15
Final log-likelihood: -23288.97
Average log-likelihood: -2.328898932863436





In [5]:
# Predict on new data
X_new = np.linspace(0, 10, 5)[:, None]
y_pred = model.predict(X_new)

print("Predicted expected counts for new data:")
for X_i, y_i in zip(X_new.ravel(), y_pred):
    print(f"X = {X_i:.1f} --> E[Y] = {y_i:.2f}")

Predicted expected counts for new data:
X = 0.0 --> E[Y] = 2.76
X = 2.5 --> E[Y] = 4.27
X = 5.0 --> E[Y] = 2.10
X = 7.5 --> E[Y] = 5.30
X = 10.0 --> E[Y] = 2.16
