Of course. Let's break down the model you've described and derive the complete conditional distributions needed for a Gibbs sampler.

This is a very interesting model that combines a Gamma-Poisson factorization (often used in topic modeling or recommender systems) with a Bayesian logistic regression model, linked by the latent factor `θ_i`.

### 1. Model Specification

First, let's formalize the model based on your description and the provided diagram.

**Data:**
*   `x_ij`: Count data for user `i` and item `j` (for `i=1..n`, `j=1..p`).
*   `y_ik`: Binary outcome for user `i` and task `k` (for `i=1..n`, `k=1..κ`).
*   `x_i^aux`: Auxiliary feature vector for user `i`.

**Latent Variables & Priors:**
*   **Poisson Rate Factors:**
    *   `θ_i`: User-specific factor. `θ_i ~ Gamma(α_θ, ξ_i)` (shape `α_θ`, rate `ξ_i`)
    *   `β_j`: Item-specific factor. `β_j ~ Gamma(α_β, η_j)` (shape `α_β`, rate `η_j`)
*   **Hyperpriors:**
    *   `ξ_i ~ Gamma(a_ξ, b_ξ)`
    *   `η_j ~ Gamma(a_η, b_η)`
*   **Logistic Regression Coefficients:**
    *   `γ_k`: Coefficients for auxiliary features. `γ_k ~ N(μ_γ, Σ_γ)`
    *   `ν_k`: Coefficients for latent factors `θ_i`. This has a **Spike-and-Slab** prior. Let's define this as:
        *   `s_k ~ Bernoulli(π_ν)` (The inclusion indicator: `s_k = 1` selects the slab, `s_k = 0` the spike at zero)
        *   `w_k ~ N(μ_w, σ_w^2)` (The "slab" value)
        *   `ν_k = s_k * w_k`

**Likelihoods:**
1.  **Count Data:** `x_ij | θ_i, β_j ~ Poisson(θ_i * β_j)`
    *   *Note: You mentioned `θ^T β`. Given the indices, `θ_i * β_j` is the standard interpretation for this type of model, where both are scalars.*
2.  **Binary Outcome:** `y_ik | θ_i, γ_k, ν_k, x_i^aux ~ Bernoulli(σ(z_{ik}))`
    *   where `z_{ik} = γ_k^T x_i^aux + ν_k θ_i`
    *   and `σ(z) = 1 / (1 + exp(-z))` is the sigmoid function.
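
To make the specification concrete, here is a minimal sketch that draws one synthetic data set from the scalar-factor model above. The function name `simulate_hybrid_data`, the hyperparameter values, the zero-mean unit-variance prior on `γ_k`, and the standard-normal auxiliary features are all illustrative assumptions, not part of your description.

```python
import numpy as np
from scipy.special import expit

def simulate_hybrid_data(n=50, p=20, n_tasks=3, q=4, seed=0):
    """Draw one synthetic data set from the scalar-factor model (illustrative values)."""
    rng = np.random.default_rng(seed)
    # Hyperparameter values below are assumptions chosen only for illustration
    alpha_theta, alpha_beta = 1.0, 1.0
    a_xi, b_xi, a_eta, b_eta = 2.0, 2.0, 2.0, 2.0
    pi_nu, mu_w, sigma_w = 0.5, 0.0, 1.0

    # NumPy's gamma takes (shape, scale), so scale = 1 / rate
    xi = rng.gamma(a_xi, 1.0 / b_xi, size=n)
    eta = rng.gamma(a_eta, 1.0 / b_eta, size=p)
    theta = rng.gamma(alpha_theta, 1.0 / xi)        # theta_i ~ Gamma(alpha_theta, rate xi_i)
    beta = rng.gamma(alpha_beta, 1.0 / eta)         # beta_j ~ Gamma(alpha_beta, rate eta_j)

    X = rng.poisson(np.outer(theta, beta))          # x_ij ~ Poisson(theta_i * beta_j)

    X_aux = rng.normal(size=(n, q))                 # assumed auxiliary features
    gamma = rng.normal(size=(n_tasks, q))           # assumed N(0, I) prior on gamma_k
    s = rng.binomial(1, pi_nu, size=n_tasks)
    w = rng.normal(mu_w, sigma_w, size=n_tasks)
    nu = s * w                                      # spike-and-slab coefficient

    logits = X_aux @ gamma.T + np.outer(theta, nu)  # z_ik, shape (n, n_tasks)
    Y = rng.binomial(1, expit(logits))
    return {"X": X, "Y": Y, "X_aux": X_aux, "theta": theta, "beta": beta,
            "gamma": gamma, "nu": nu, "xi": xi, "eta": eta}
```

Simulated data like this is handy for checking that a sampler recovers known parameters before running it on real counts.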

### 2. Gibbs Sampling: Deriving the Complete Conditionals

The core idea of Gibbs sampling is to iteratively sample each latent variable from its distribution conditioned on the current values of all other variables and the observed data. This conditional distribution is proportional to all terms in the full joint probability that involve the variable in question.

---

#### Conditional for `β_j`

*   **Depends on:** Its parent `η_j` and its children `x_ij` (for all `i=1..n`).
*   **Derivation:**
    `p(β_j | ...)` ∝ `p(β_j | η_j) * ∏_i p(x_ij | θ_i, β_j)`
    `∝ [β_j^(α_β-1) * exp(-η_j β_j)] * ∏_i [(θ_i β_j)^(x_ij) * exp(-θ_i β_j)]`
    `∝ β_j^(α_β-1) * exp(-η_j β_j) * (β_j)^(∑_i x_ij) * exp(-β_j ∑_i θ_i)`
    `∝ β_j^((α_β + ∑_i x_ij) - 1) * exp(-(η_j + ∑_i θ_i) β_j)`
    This is the kernel of a Gamma distribution.
*   **Result:**
    **`β_j | ... ~ Gamma(shape = α_β + ∑_{i=1}^n x_ij, rate = η_j + ∑_{i=1}^n θ_i)`**
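
As a small sketch of this update (the function name `sample_beta` and its argument layout are assumptions for illustration), all `β_j` can be drawn in one vectorised call. Note that NumPy parameterises the Gamma by scale, so the rate has to be inverted.

```python
import numpy as np

def sample_beta(X, theta, eta, alpha_beta, rng):
    """Draw every beta_j from Gamma(alpha_beta + sum_i x_ij, rate = eta_j + sum_i theta_i)."""
    shape = alpha_beta + X.sum(axis=0)   # one shape per item j, length p
    rate = eta + theta.sum()             # one rate per item j, length p
    return rng.gamma(shape, 1.0 / rate)  # NumPy uses scale = 1 / rate
```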

---

#### Conditional for `η_j`

*   **Depends on:** Its child `β_j` and its own hyperparameters `a_η, b_η`.
*   **Derivation:**
    `p(η_j | ...)` ∝ `p(η_j) * p(β_j | η_j)`
    `∝ [η_j^(a_η-1) * exp(-b_η η_j)] * [η_j^(α_β) * exp(-η_j β_j)]`
    `∝ η_j^((a_η + α_β) - 1) * exp(-(b_η + β_j) η_j)`
    This is also a Gamma kernel.
*   **Result:**
    **`η_j | ... ~ Gamma(shape = a_η + α_β, rate = b_η + β_j)`**

---

#### Conditional for `ξ_i`

*   This is perfectly symmetric to the derivation for `η_j`.
*   **Depends on:** Its child `θ_i` and its own hyperparameters `a_ξ, b_ξ`.
*   **Result:**
    **`ξ_i | ... ~ Gamma(shape = a_ξ + α_θ, rate = b_ξ + θ_i)`**
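
Because the `η_j` and `ξ_i` updates share one form, a single helper can serve both. This is a sketch under the scalar-factor model; `sample_rate_hyper` is a hypothetical name.

```python
import numpy as np

def sample_rate_hyper(child, alpha_child, a, b, rng):
    """Draw rate hyperparameters from Gamma(a + alpha_child, rate = b + child).

    For eta_j pass child=beta, alpha_child=alpha_beta; for xi_i pass
    child=theta, alpha_child=alpha_theta.
    """
    shape = a + alpha_child
    rate = b + np.asarray(child, dtype=float)
    return rng.gamma(shape, 1.0 / rate)  # NumPy uses scale = 1 / rate
```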

---

#### Conditional for `θ_i`

*   This is the most complex variable as it links the two likelihoods.
*   **Depends on:** Its parent `ξ_i`, its children `x_ij` (for all `j`), and its children `y_ik` (for all `k`), plus the co-parents of `y_ik` (`γ_k`, `ν_k`).
*   **Derivation:**
    `p(θ_i | ...)` ∝ `p(θ_i | ξ_i) * [∏_j p(x_ij | θ_i, β_j)] * [∏_k p(y_ik | θ_i, γ_k, ν_k)]`
    `∝ [θ_i^(α_θ-1) exp(-ξ_i θ_i)] * [∏_j (θ_i β_j)^(x_ij) exp(-θ_i β_j)] * [∏_k σ(γ_k^T x_i^aux + ν_k θ_i)^(y_ik) * (1-σ(...))^(1-y_ik)]`
    `∝ θ_i^((α_θ + ∑_j x_ij) - 1) * exp(-(ξ_i + ∑_j β_j) θ_i) * ∏_k p(y_ik | θ_i, ...)`
*   **Result:** This distribution is **not a standard form** because `θ_i` is inside the sigmoid function in the logistic regression term. You cannot sample from it directly.
    **`p(θ_i | ...) ∝ θ_i^((α_θ + ∑_j x_ij) - 1) * exp(-(ξ_i + ∑_j β_j) θ_i) * ∏_{k=1}^κ [σ(γ_k^T x_i^aux + ν_k θ_i)]^(y_ik) [1 - σ(γ_k^T x_i^aux + ν_k θ_i)]^(1-y_ik)`**
    *   **Action:** You must use a sampling method like a **Metropolis-Hastings** step or a **Slice Sampler** to draw a new value for `θ_i`.
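
One possible Metropolis-Hastings step for a scalar `θ_i` is sketched below. It uses a multiplicative log-normal random walk so proposals stay positive; the function names, the `offsets_i` argument (holding `γ_k^T x_i^aux` for each task `k`), and the step size are assumptions for illustration.

```python
import numpy as np

def log_cond_theta(theta_i, x_i, beta, xi_i, alpha_theta, y_i, offsets_i, nu):
    """Unnormalised log conditional of a scalar theta_i > 0."""
    log_gamma_part = ((alpha_theta + x_i.sum() - 1.0) * np.log(theta_i)
                      - (xi_i + beta.sum()) * theta_i)
    logits = offsets_i + nu * theta_i  # z_ik for k = 1..kappa
    log_bern = np.sum(y_i * logits - np.logaddexp(0.0, logits))
    return log_gamma_part + log_bern

def mh_step_theta(theta_i, step, rng, **kw):
    """One Metropolis-Hastings step with a log-normal random-walk proposal."""
    prop = theta_i * np.exp(step * rng.normal())
    # Hastings correction for the asymmetric proposal:
    # q(current | proposal) / q(proposal | current) = proposal / current
    log_ratio = (log_cond_theta(prop, **kw) - log_cond_theta(theta_i, **kw)
                 + np.log(prop) - np.log(theta_i))
    return prop if np.log(rng.random()) < log_ratio else theta_i
```

When all `ν_k` are zero the logistic term is constant, so the chain should reproduce the plain Gamma conditional, which gives a quick correctness check.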

---

#### Conditional for `γ_k`

*   **Depends on:** Its prior `(μ_γ, Σ_γ)` and its children `y_ik` (for all `i`).
*   **Derivation:**
    `p(γ_k | ...)` ∝ `p(γ_k) * ∏_i p(y_ik | γ_k, ...)`
    `∝ N(γ_k | μ_γ, Σ_γ) * ∏_i σ(γ_k^T x_i^aux + ν_k θ_i)^(y_ik) * (1-σ(...))^(1-y_ik)`
*   **Result:** This is the standard posterior for Bayesian logistic regression. It is also **not a standard form**.
    **`p(γ_k | ...) ∝ exp(-1/2 (γ_k - μ_γ)^T Σ_γ⁻¹ (γ_k - μ_γ)) * ∏_{i=1}^n [σ(z_{ik})^(y_ik) * (1-σ(z_{ik}))^(1-y_ik)]`**
    *   **Action:** Use a **Metropolis-Hastings** step (e.g., with a multivariate Normal proposal distribution) or see the note on Polya-Gamma augmentation below.
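
A random-walk Metropolis step for `γ_k` could look like the sketch below. The isotropic Normal proposal, the function names, and passing the prior as a precision matrix `Sigma_inv` are illustrative choices, and `theta` is the vector of scalar factors.

```python
import numpy as np

def log_post_gamma(gamma_k, X_aux, y_k, theta, nu_k, mu, Sigma_inv):
    """Unnormalised log conditional of gamma_k."""
    diff = gamma_k - mu
    log_prior = -0.5 * diff @ Sigma_inv @ diff
    logits = X_aux @ gamma_k + nu_k * theta  # z_ik for i = 1..n
    return log_prior + np.sum(y_k * logits - np.logaddexp(0.0, logits))

def mh_step_gamma(gamma_k, step, rng, **kw):
    """One random-walk Metropolis step; the proposal is symmetric, so no correction."""
    prop = gamma_k + step * rng.normal(size=gamma_k.shape)
    log_ratio = log_post_gamma(prop, **kw) - log_post_gamma(gamma_k, **kw)
    return prop if np.log(rng.random()) < log_ratio else gamma_k
```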

---

#### Conditionals for Spike-and-Slab `ν_k` (sampling `s_k` and `w_k`)

We sample the binary indicator `s_k` and the slab value `w_k` sequentially.

**1. Conditional for `s_k` (the "spike")**
*   This is a Bernoulli variable. We need to find the probability of it being 1 vs 0.
*   `p(s_k=1 | ...)` ∝ `p(s_k=1) * p(Y | s_k=1, ...)` = `π_ν * ∏_i p(y_ik | ν_k = w_k, ...)`
*   `p(s_k=0 | ...)` ∝ `p(s_k=0) * p(Y | s_k=0, ...)` = `(1-π_ν) * ∏_i p(y_ik | ν_k = 0, ...)`
*   **Derivation:**
    Let `L_1 = π_ν * ∏_i σ(γ_k^T x_i^aux + w_k θ_i)^(y_ik) * (1-σ(...))^(1-y_ik)`
    Let `L_0 = (1-π_ν) * ∏_i σ(γ_k^T x_i^aux)^(y_ik) * (1-σ(...))^(1-y_ik)`
    The conditional probability of `s_k` being 1 is `P(s_k=1 | ...) = L_1 / (L_1 + L_0)`.
*   **Result:**
    **`s_k | ... ~ Bernoulli(p = L_1 / (L_1 + L_0))`**
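
In practice `L_1` and `L_0` are products of many small probabilities, so it is safer to work with their logarithms. A minimal sketch follows, where `prob_s_k` and `base_logits` (the vector of `γ_k^T x_i^aux` over users) are hypothetical names.

```python
import numpy as np
from scipy.special import expit

def bernoulli_loglik(y, logits):
    """Sum of Bernoulli log-probabilities, using logaddexp to avoid overflow."""
    return np.sum(y * logits - np.logaddexp(0.0, logits))

def prob_s_k(y_k, base_logits, theta, w_k, pi_nu):
    """P(s_k = 1 | ...) = L_1 / (L_1 + L_0), computed as sigmoid(log L_1 - log L_0)."""
    log_L1 = np.log(pi_nu) + bernoulli_loglik(y_k, base_logits + w_k * theta)
    log_L0 = np.log1p(-pi_nu) + bernoulli_loglik(y_k, base_logits)
    return expit(log_L1 - log_L0)
```

A draw of `s_k` is then `rng.binomial(1, prob_s_k(...))`.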

**2. Conditional for `w_k` (the "slab")**
*   This is conditioned on the value of `s_k` we just sampled.
*   **If `s_k = 0`:** `w_k` does not appear in the likelihood. We just sample it from its prior.
    *   **Result:** **`w_k | s_k=0, ... ~ N(μ_w, σ_w^2)`**
*   **If `s_k = 1`:** `ν_k = w_k`, so `w_k` is in the likelihood.
    *   `p(w_k | s_k=1, ...)` ∝ `p(w_k) * ∏_i p(y_ik | ν_k=w_k, ...)`
    *   `∝ N(w_k | μ_w, σ_w^2) * ∏_i σ(γ_k^T x_i^aux + w_k θ_i)^(y_ik) * (1-σ(...))^(1-y_ik)`
    *   **Result:** Again, this is **not a standard form**.
        **`p(w_k | s_k=1, ...) ∝ exp(-(w_k - μ_w)² / (2σ_w²)) * ∏_{i=1}^n [σ(z_{ik})^(y_ik) * (1-σ(z_{ik}))^(1-y_ik)]`** where `z_{ik}` now includes `w_k`.
    *   **Action:** Use a **Metropolis-Hastings** step to sample `w_k`.
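
The two branches can be combined in one helper, sketched here with assumed names (`sample_w_k`, `base_logits` for `γ_k^T x_i^aux`) and a symmetric Normal random-walk proposal.

```python
import numpy as np

def log_cond_w(w_k, y_k, base_logits, theta, mu_w, sigma_w):
    """Unnormalised log conditional of w_k when s_k = 1."""
    log_prior = -0.5 * ((w_k - mu_w) / sigma_w) ** 2
    logits = base_logits + w_k * theta
    return log_prior + np.sum(y_k * logits - np.logaddexp(0.0, logits))

def sample_w_k(w_k, s_k, step, rng, *, y_k, base_logits, theta, mu_w, sigma_w):
    """Draw w_k from its prior if s_k = 0, otherwise take one Metropolis step."""
    if s_k == 0:
        # w_k is absent from the likelihood, so its conditional is the prior
        return rng.normal(mu_w, sigma_w)
    prop = w_k + step * rng.normal()
    args = (y_k, base_logits, theta, mu_w, sigma_w)
    log_ratio = log_cond_w(prop, *args) - log_cond_w(w_k, *args)
    return prop if np.log(rng.random()) < log_ratio else w_k
```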

### 3. A More Advanced (and Efficient) Alternative: Polya-Gamma Augmentation

The non-conjugacy in the logistic regression parts (`θ_i`, `γ_k`, `w_k`) forces Metropolis-Hastings steps inside the Gibbs sweep, which can mix slowly and need per-parameter step-size tuning. A powerful technique for removing most of that non-conjugacy is **Polya-Gamma data augmentation**.

The key idea is that a logistic likelihood term can be written as a scale mixture of Gaussians in `z`, with a Polya-Gamma mixing distribution:
`σ(z)^a * (1-σ(z))^b ∝ exp(cz) * ∫ exp(-ωz²/2) * p(ω|a+b, 0) dω` where `c = (a-b)/2` and `p(ω|a+b, 0)` is the density of `PG(a+b, 0)`.

By introducing an auxiliary Polya-Gamma variable `ω_{ik}` for each observation `y_{ik}`, the posteriors for `θ_i`, `γ_k`, and `w_k` become **conditionally Gaussian**, which allows for direct (Gibbs) sampling instead of Metropolis-Hastings.
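
The identity behind the augmentation can be checked numerically without a Polya-Gamma sampler. For `ω ~ PG(b, 0)` the Laplace transform is `E[exp(-ωt)] = cosh(√(t/2))^(-b)`, so the integral has a closed form and the proportionality constant is `2^(-(a+b))`. The sketch below compares both sides; the function names are illustrative.

```python
import numpy as np

def logistic_term(z, a, b):
    """sigma(z)^a * (1 - sigma(z))^b, evaluated via logs for stability."""
    return np.exp(-a * np.logaddexp(0.0, -z) - b * np.logaddexp(0.0, z))

def pg_mixture_form(z, a, b):
    """2^-(a+b) * exp(c z) * E[exp(-omega z^2 / 2)] with omega ~ PG(a+b, 0)."""
    c = (a - b) / 2.0
    laplace = np.cosh(z / 2.0) ** (-(a + b))  # closed-form PG Laplace transform
    return 2.0 ** (-(a + b)) * np.exp(c * z) * laplace
```

For a single Bernoulli observation, `a = y_ik` and `b = 1 - y_ik`, which gives `c = y_ik - 1/2` and a `PG(1, 0)` mixing distribution.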

If you implement this:
1.  **Augment:** For each `i, k`, you sample `ω_{ik} ~ PG(1, z_{ik})`.
2.  **Sample `γ_k`:** The conditional for `γ_k` becomes a Multivariate Normal.
3.  **Sample `ν_k` (or `w_k`):** The conditional for `w_k` (when `s_k=1`) becomes a Normal distribution.
4.  **Sample `θ_i`:** The conditional for `θ_i` becomes a product of a Gamma-like term and a Gaussian term. This is still not a standard form, but it's much easier to handle (e.g., via a custom slice sampler or a very efficient Metropolis step) than the original form.
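
Step 2 can be sketched as follows, assuming the `ω_{ik}` for task `k` have already been drawn by some Polya-Gamma sampler and are passed in as `omega`. The function name and the `offset` argument (holding `ν_k θ_i` for each user) are illustrative. With `κ_i = y_ik - 1/2` and `Ω = diag(ω)`, the conditional is Normal with precision `X^T Ω X + Σ_γ⁻¹` and mean equal to the inverse precision times `X^T (κ - Ω·offset) + Σ_γ⁻¹ μ_γ`.

```python
import numpy as np

def sample_gamma_given_omega(X_aux, y_k, offset, omega, mu, Sigma_inv, rng):
    """Draw gamma_k from its Gaussian conditional given Polya-Gamma variables."""
    kappa = y_k - 0.5
    precision = X_aux.T @ (omega[:, None] * X_aux) + Sigma_inv
    linear = X_aux.T @ (kappa - omega * offset) + Sigma_inv @ mu
    mean = np.linalg.solve(precision, linear)
    # If precision = L L^T, then solving L^T u = eps gives u with covariance precision^-1
    chol = np.linalg.cholesky(precision)
    noise = np.linalg.solve(chol.T, rng.normal(size=mean.shape))
    return mean + noise
```

The update for `w_k` when `s_k = 1` has the same structure, with `θ_i` playing the role of the single covariate and `γ_k^T x_i^aux` as the offset.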

This is a more advanced implementation, but it is a widely used approach for efficient inference in models with logistic likelihoods. If you are serious about implementing this, I highly recommend reading the original paper by Polson, Scott, and Windle (2013), "Bayesian Inference for Logistic Models Using Pólya-Gamma Latent Variables."

In [None]:
import numpy as np
from numpy.random import Generator, default_rng
from scipy.special import expit, gammaln, logsumexp
from scipy.stats import gamma as gamma_dist
from scipy.stats import norm as norm_dist

class BayesianHybridModelSampler:
    """
    Gibbs sampler for a hybrid Poisson Factorization and Logistic Regression model.

    This model links latent factors from count data (Poisson factorization) to a
    binary outcome (Logistic Regression).

    The factorization part `x_ij ~ Poisson(theta_i @ beta_j.T)` is handled via
    data augmentation, introducing latent counts `z_ijl` to restore conjugacy.

    The logistic regression part `y_ik ~ Bernoulli(sigmoid(..))` involves non-conjugate
    posteriors, which are handled with Metropolis-Hastings steps.
    """

    def __init__(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        X_aux: np.ndarray,
        n_latent_dims: int,
        *,
        # Gamma shape hyperparameters
        alpha_theta: float = 0.3,
        alpha_beta: float = 0.3,
        alpha_xi: float = 0.3,
        alpha_eta: float = 0.3,
        # Gamma rate hyperparameters
        lambda_xi: float = 0.3,
        lambda_eta: float = 0.3,
        # Spike-and-Slab hyperparameters for nu (upsilon)
        pi_nu: float = 0.5,           # Prior probability of inclusion
        sigma_slab_sq: float = 1.0,   # Variance of the "slab"
        # Prior variance for gamma
        sigma_gamma_sq: float = 1.0,
        seed: int | None = None,
    ) -> None:
        """
        Initializes the sampler with data and hyperparameters.

        Args:
            X: Count data matrix of shape (n_users, p_items).
            Y: Binary outcome matrix of shape (n_users, k_tasks).
            X_aux: Auxiliary features matrix of shape (n_users, q_features).
            n_latent_dims: The number of latent dimensions, d.
            alpha_*: Shape parameters for Gamma priors.
            lambda_*: Rate parameters for Gamma hyperpriors.
            pi_nu: Prior probability for the slab component in the spike-and-slab prior.
            sigma_slab_sq: Variance for the slab component.
            sigma_gamma_sq: Prior variance for the gamma coefficients.
            seed: Random seed for reproducibility.
        """
        self.rng: Generator = default_rng(seed)

        # -- Data Dimensions --
        self.X = np.asarray(X, dtype=np.int32)
        self.Y = np.asarray(Y, dtype=np.int32)
        self.X_aux = np.asarray(X_aux)

        self.n, self.p = self.X.shape
        self.k = self.Y.shape[1]
        self.d = n_latent_dims
        self.q = self.X_aux.shape[1]

        # -- Hyperparameters --
        self.alpha_theta = alpha_theta
        self.alpha_beta = alpha_beta
        self.alpha_xi = alpha_xi
        self.alpha_eta = alpha_eta
        self.lambda_xi = lambda_xi
        self.lambda_eta = lambda_eta
        self.pi_nu = pi_nu
        self.sigma_slab_sq = sigma_slab_sq
        self.sigma_gamma_sq = sigma_gamma_sq

        # -- Initialize Parameters --
        self._init_params()

    def _init_params(self) -> None:
        """Initialise latent variables and parameters."""
        self.theta = self.rng.gamma(1.0, 1.0, size=(self.n, self.d))
        self.beta = self.rng.gamma(1.0, 1.0, size=(self.p, self.d))
        self.xi = self.rng.gamma(self.alpha_xi, scale=1.0/self.lambda_xi, size=self.n)
        self.eta = self.rng.gamma(self.alpha_eta, scale=1.0/self.lambda_eta, size=self.p)

        # Latent counts for Poisson data augmentation
        self.z = np.zeros((self.n, self.p, self.d), dtype=np.int32)
        # We need an initial update to populate z based on initial theta and beta
        self._update_latent_counts_z()

        # Logistic Regression Parameters
        self.gamma = self.rng.normal(0.0, np.sqrt(self.sigma_gamma_sq), size=(self.k, self.q))
        
        # Spike-and-Slab parameters (nu is upsilon)
        # s_k is the binary "spike" selector, w_k is the "slab" value
        self.s_nu = self.rng.binomial(1, self.pi_nu, size=(self.k, self.d))
        self.w_nu = self.rng.normal(0.0, np.sqrt(self.sigma_slab_sq), size=(self.k, self.d))
        self.nu = self.s_nu * self.w_nu

    # ------------------------------------------------------------------
    # --- Poisson Factorization Updates (with Data Augmentation) ---
    # ------------------------------------------------------------------
    def _update_latent_counts_z(self) -> None:
        """Sample the latent counts z_ijl using a Multinomial distribution."""
        # Calculate rates for each latent dimension component
        rates = self.theta[:, np.newaxis, :] * self.beta[np.newaxis, :, :] # Shape (n, p, d)
        total_rates = np.sum(rates, axis=2)
        
        # Avoid division by zero for cases where total_rate is 0
        total_rates[total_rates == 0] = 1.0
        
        # Calculate multinomial probabilities
        probs = rates / total_rates[..., np.newaxis]

        # Sample from Multinomial for each (i, j)
        for i in range(self.n):
            for j in range(self.p):
                if self.X[i, j] > 0:
                    self.z[i, j, :] = self.rng.multinomial(n=self.X[i, j], pvals=probs[i, j, :])
                else:
                    self.z[i, j, :] = 0

    def _update_beta(self) -> None:
        """Update beta_jl using its conjugate Gamma posterior."""
        # Sum over the 'n' dimension of the latent counts z
        z_sum_over_i = np.sum(self.z, axis=0)  # Shape (p, d)
        
        shape = self.alpha_beta + z_sum_over_i
        rate = self.eta[:, np.newaxis] + np.sum(self.theta, axis=0)
        
        # NumPy's Generator.gamma takes (shape, scale), so pass scale = 1 / rate
        self.beta = self.rng.gamma(shape, scale=1.0 / rate)

    def _update_eta(self) -> None:
        """Update eta_j (hyper-priors for beta)."""
        shape = self.alpha_eta + self.d * self.alpha_beta
        rate = self.lambda_eta + np.sum(self.beta, axis=1)
        self.eta = self.rng.gamma(shape, scale=1.0 / rate)

    def _update_xi(self) -> None:
        """Update xi_i (hyper-priors for theta)."""
        shape = self.alpha_xi + self.d * self.alpha_theta
        rate = self.lambda_xi + np.sum(self.theta, axis=1)
        self.xi = self.rng.gamma(shape, scale=1.0 / rate)

    # ------------------------------------------------------------------
    # --- Non-Conjugate Updates (Metropolis-Hastings) ---
    # ------------------------------------------------------------------
    @staticmethod
    def _log_likelihood_bernoulli(y: np.ndarray, logits: np.ndarray) -> float:
        """Bernoulli log-likelihood; logaddexp avoids overflow in log(1 + exp(logits))."""
        return np.sum(y * logits - np.logaddexp(0.0, logits))

    def _log_posterior_theta_i(self, i: int, theta_i: np.ndarray) -> float:
        """Calculate log posterior for a single theta_i vector."""
        # Log-prior from Gamma distributions
        log_prior = np.sum(gamma_dist.logpdf(theta_i, a=self.alpha_theta, scale=1.0/self.xi[i]))
        
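        # Log-likelihood from Poisson factorization part (using latent z).
        # z_ijl ~ Poisson(theta_il * beta_jl), so the z * log(theta) term must be
        # included; it is the only count-dependent term that varies with theta_i.
        log_lik_poisson = np.sum(
            self.z[i, :, :] * (np.log(theta_i) + np.log(self.beta)) - self.beta * theta_i
        )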
        
        # Log-likelihood from Logistic Regression part
        logits = self.X_aux[i] @ self.gamma.T + theta_i @ self.nu.T # Shape (k,)
        log_lik_logistic = self._log_likelihood_bernoulli(self.Y[i, :], logits)
        
        return log_prior + log_lik_poisson + log_lik_logistic

    def _update_theta(self, step_size: float = 0.05) -> None:
        """Update theta using Metropolis-Hastings due to non-conjugacy."""
        for i in range(self.n):
            current_theta_i = self.theta[i]
            # Use a log-normal proposal to keep theta positive
            proposal_theta_i = self.rng.lognormal(np.log(current_theta_i), step_size)
            
            log_p_curr = self._log_posterior_theta_i(i, current_theta_i)
            log_p_prop = self._log_posterior_theta_i(i, proposal_theta_i)
            
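            # Hastings correction for the asymmetric log-normal proposal:
            # q(current | proposal) / q(proposal | current) = proposal / current
            log_p_curr += np.sum(np.log(current_theta_i))
            log_p_prop += np.sum(np.log(proposal_theta_i))
            
            # Compare on the log scale to avoid overflow in exp()
            if np.log(self.rng.random()) < log_p_prop - log_p_curr:
                self.theta[i] = proposal_theta_i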
                
    def _log_posterior_gamma_k(self, k: int, gamma_k: np.ndarray) -> float:
        """Calculate log posterior for a single gamma_k vector."""
        log_prior = norm_dist.logpdf(gamma_k, 0, np.sqrt(self.sigma_gamma_sq)).sum()
        logits = self.X_aux @ gamma_k + self.theta @ self.nu[k]
        log_lik = self._log_likelihood_bernoulli(self.Y[:, k], logits)
        return log_prior + log_lik

    def _update_gamma(self, step_size: float = 0.05) -> None:
        """Update gamma using Metropolis-Hastings."""
        for k in range(self.k):
            current_gamma_k = self.gamma[k]
            proposal_gamma_k = self.rng.normal(current_gamma_k, step_size)
            
            log_p_curr = self._log_posterior_gamma_k(k, current_gamma_k)
            log_p_prop = self._log_posterior_gamma_k(k, proposal_gamma_k)
            
            if self.rng.random() < np.exp(log_p_prop - log_p_curr):
                self.gamma[k] = proposal_gamma_k

    # ------------------------------------------------------------------
    # --- Spike-and-Slab Updates for nu (upsilon) ---
    # ------------------------------------------------------------------
    def _update_s_nu(self) -> None:
        """Update each spike selector s_kl, conditioning on the other selectors of task k."""
        for k in range(self.k):
            base_logits = self.X_aux @ self.gamma[k]
            for l in range(self.d):
                # Logits with nu_kl = 0 (spike), other dimensions held at their current state
                s_other = self.s_nu[k].copy()
                s_other[l] = 0
                logits_0 = base_logits + self.theta @ (s_other * self.w_nu[k])
                # Logits with nu_kl = w_kl (slab)
                logits_1 = logits_0 + self.theta[:, l] * self.w_nu[k, l]

                # Log-posterior weights for s_kl = 1 (slab) and s_kl = 0 (spike)
                log_post_1 = np.log(self.pi_nu) + self._log_likelihood_bernoulli(self.Y[:, k], logits_1)
                log_post_0 = np.log(1 - self.pi_nu) + self._log_likelihood_bernoulli(self.Y[:, k], logits_0)

                # P(s_kl = 1 | ...) = sigmoid(log_post_1 - log_post_0); expit avoids overflow
                prob_s1 = expit(log_post_1 - log_post_0)
                self.s_nu[k, l] = self.rng.binomial(1, prob_s1)

    def _log_posterior_w_nu_k(self, k: int, w_nu_k: np.ndarray) -> float:
        """Calculate log posterior for a single w_nu_k vector."""
        log_prior = norm_dist.logpdf(w_nu_k, 0, np.sqrt(self.sigma_slab_sq)).sum()
        
        # The likelihood term only exists if the slab is active (s_nu=1)
        nu_k = self.s_nu[k] * w_nu_k
        logits = self.X_aux @ self.gamma[k] + self.theta @ nu_k
        log_lik = self._log_likelihood_bernoulli(self.Y[:, k], logits)
        
        return log_prior + log_lik

    def _update_w_nu(self, step_size: float = 0.05) -> None:
        """Update the slab values w_k for nu_k using Metropolis-Hastings."""
        for k in range(self.k):
            current_w_k = self.w_nu[k]
            proposal_w_k = self.rng.normal(current_w_k, step_size)
            
            log_p_curr = self._log_posterior_w_nu_k(k, current_w_k)
            log_p_prop = self._log_posterior_w_nu_k(k, proposal_w_k)
            
            if self.rng.random() < np.exp(log_p_prop - log_p_curr):
                self.w_nu[k] = proposal_w_k

    # ------------------------------------------------------------------
    def step(self) -> None:
        """Run a single full Gibbs iteration."""
        # 1. Update latent counts for Poisson part
        self._update_latent_counts_z()
        
        # 2. Update conjugate parameters
        self._update_beta()
        self._update_eta()
        self._update_xi()
        
        # 3. Update non-conjugate parameters via M-H
        self._update_theta()
        self._update_gamma()

        # 4. Update Spike-and-Slab parameters
        self._update_s_nu()
        self._update_w_nu()
        self.nu = self.s_nu * self.w_nu

    def run(self, n_iter: int, n_burnin: int = 100) -> dict:
        """Run the sampler and return parameter traces."""
        # Burn-in phase
        print(f"Running burn-in for {n_burnin} iterations...")
        for _ in range(n_burnin):
            self.step()

        # Sampling phase
        print(f"Running sampling for {n_iter} iterations...")
        traces = {
            "theta": np.zeros((n_iter, *self.theta.shape)),
            "beta": np.zeros((n_iter, *self.beta.shape)),
            "gamma": np.zeros((n_iter, *self.gamma.shape)),
            "nu": np.zeros((n_iter, *self.nu.shape)),
            "s_nu": np.zeros((n_iter, *self.s_nu.shape)),
        }

        for t in range(n_iter):
            self.step()
            traces["theta"][t] = self.theta
            traces["beta"][t] = self.beta
            traces["gamma"][t] = self.gamma
            traces["nu"][t] = self.nu
            traces["s_nu"][t] = self.s_nu
            if (t + 1) % 100 == 0:
                print(f"  ...iteration {t+1}/{n_iter}")

        return traces

In [None]:
import numpy as np
from numpy.random import Generator, default_rng
from scipy.optimize import minimize
from scipy.special import expit, gammaln, logsumexp, log_expit
from scipy.stats import gamma as gamma_dist
from scipy.stats import norm as norm_dist

class LogSpaceBayesianHybridModelSampler:
    """
    Gibbs sampler for a hybrid model, implemented for numerical stability in log-space.

    This implementation prioritizes numerical stability by performing key calculations
    in the log domain to prevent underflow from multiplying small probabilities.
    """

    def __init__(
        self,
        X: np.ndarray,
        Y: np.ndarray,
        X_aux: np.ndarray,
        n_latent_dims: int,
        *,
        alpha_theta: float = 0.3,
        alpha_beta: float = 0.3,
        alpha_xi: float = 0.3,
        alpha_eta: float = 0.3,
        lambda_xi: float = 0.3,
        lambda_eta: float = 0.3,
        pi_nu: float = 0.5,
        sigma_slab_sq: float = 1.0,
        sigma_gamma_sq: float = 1.0,
        seed: int | None = None,
    ) -> None:
        self.rng: Generator = default_rng(seed)
        self.X = np.asarray(X, dtype=np.int32)
        self.Y = np.asarray(Y, dtype=np.int32)
        self.X_aux = np.asarray(X_aux)
        self.n, self.p = self.X.shape
        self.k = self.Y.shape[1]
        self.d = n_latent_dims
        self.q = self.X_aux.shape[1]
        self.alpha_theta = alpha_theta
        self.alpha_beta = alpha_beta
        self.alpha_xi = alpha_xi
        self.alpha_eta = alpha_eta
        self.lambda_xi = lambda_xi
        self.lambda_eta = lambda_eta
        self.pi_nu = pi_nu
        self.sigma_slab_sq = sigma_slab_sq
        self.sigma_gamma_sq = sigma_gamma_sq
        self._init_params()

    # def _init_params(self) -> None:
    #     """Initialise latent variables and parameters."""
    #     self.theta = self.rng.gamma(1.0, 1.0, size=(self.n, self.d))
    #     self.beta = self.rng.gamma(1.0, 1.0, size=(self.p, self.d))
    #     self.xi = self.rng.gamma(self.alpha_xi, scale=1.0/self.lambda_xi, size=self.n)
    #     self.eta = self.rng.gamma(self.alpha_eta, scale=1.0/self.lambda_eta, size=self.p)
    #     self.z = np.zeros((self.n, self.p, self.d), dtype=np.int32)
    #     self._update_latent_counts_z() # Initial population of z
    #     self.gamma = self.rng.normal(0.0, np.sqrt(self.sigma_gamma_sq), size=(self.k, self.q))
    #     self.s_nu = self.rng.binomial(1, self.pi_nu, size=(self.k, self.d))
    #     self.w_nu = self.rng.normal(0.0, np.sqrt(self.sigma_slab_sq), size=(self.k, self.d))
    #     self.nu = self.s_nu * self.w_nu

    def _init_params(self) -> None:
        """
        Initialise latent variables and parameters by sampling from their priors,
        following the model's generative process.
        """
        # 1. Initialize the top-level parent variables first.
        #    xi ~ Gamma(alpha_xi, lambda_xi)
        #    eta ~ Gamma(alpha_eta, lambda_eta)
        self.xi = self.rng.gamma(self.alpha_xi, scale=1.0 / self.lambda_xi, size=self.n)
        self.eta = self.rng.gamma(self.alpha_eta, scale=1.0 / self.lambda_eta, size=self.p)

        # 2. Now initialize the children using the sampled parents as their parameters.
        #    theta_i ~ Gamma(alpha_theta, xi_i)
        #    beta_j ~ Gamma(alpha_beta, eta_j)
        #    Note: We sample each row with its corresponding rate from xi/eta.
        self.theta = np.zeros((self.n, self.d))
        for i in range(self.n):
            self.theta[i, :] = self.rng.gamma(self.alpha_theta, scale=1.0 / self.xi[i], size=self.d)

        self.beta = np.zeros((self.p, self.d))
        for j in range(self.p):
            self.beta[j, :] = self.rng.gamma(self.alpha_beta, scale=1.0 / self.eta[j], size=self.d)

        # 3. Initialize the logistic regression parameters from their priors.
        #    gamma_k ~ Normal(0, sigma_gamma_sq)
        self.gamma = self.rng.normal(0.0, np.sqrt(self.sigma_gamma_sq), size=(self.k, self.q))

        #    Initialize spike-and-slab components from their priors.
        #    s_nu_k ~ Bernoulli(pi_nu)
        #    w_nu_k ~ Normal(0, sigma_slab_sq)
        self.s_nu = self.rng.binomial(1, self.pi_nu, size=(self.k, self.d))
        self.w_nu = self.rng.normal(0.0, np.sqrt(self.sigma_slab_sq), size=(self.k, self.d))
        #    Combine them to get the initial nu (upsilon).
        self.nu = self.s_nu * self.w_nu

        # 4. Finally, populate the latent counts `z` based on the initial `theta` and `beta`.
        self.z = np.zeros((self.n, self.p, self.d), dtype=np.int32)
        self._update_latent_counts_z()

    # ------------------------------------------------------------------
    # --- Log-Space and Conjugate Updates ---
    # ------------------------------------------------------------------

    def _update_latent_counts_z(self) -> None:
        """Sample latent counts z_ijl in log-space for numerical stability."""
        with np.errstate(divide='ignore'): # Ignore log(0) warnings
            log_theta = np.log(self.theta)
            log_beta = np.log(self.beta)

        # Calculate log-rates for each latent dimension component
        log_rates = log_theta[:, np.newaxis, :] + log_beta[np.newaxis, :, :] # Shape (n, p, d)
        
        # Normalize in log-space using logsumexp
        log_total_rates = logsumexp(log_rates, axis=2)
        log_probs = log_rates - log_total_rates[..., np.newaxis]

        # Convert back to probabilities only for the final sampling step
        probs = np.exp(log_probs)

        for i in range(self.n):
            for j in range(self.p):
                if self.X[i, j] > 0:
                    # Renormalize to handle any minor floating point inaccuracies
                    pvals = probs[i, j, :] / np.sum(probs[i, j, :])
                    self.z[i, j, :] = self.rng.multinomial(n=self.X[i, j], pvals=pvals)
                else:
                    self.z[i, j, :] = 0

    def _update_beta(self) -> None:
        """Update beta_jl. This calculation does not require log-space."""
        z_sum_over_i = np.sum(self.z, axis=0)
        shape = self.alpha_beta + z_sum_over_i
        rate = self.eta[:, np.newaxis] + np.sum(self.theta, axis=0)
        self.beta = self.rng.gamma(shape, scale=1.0 / rate)

    def _update_eta(self) -> None:
        """Update eta_j. This calculation does not require log-space."""
        shape = self.alpha_eta + self.d * self.alpha_beta
        rate = self.lambda_eta + np.sum(self.beta, axis=1)
        self.eta = self.rng.gamma(shape, scale=1.0 / rate)

    def _update_xi(self) -> None:
        """Update xi_i. This calculation does not require log-space."""
        shape = self.alpha_xi + self.d * self.alpha_theta
        rate = self.lambda_xi + np.sum(self.theta, axis=1)
        self.xi = self.rng.gamma(shape, scale=1.0 / rate)

    @staticmethod
    def _log_likelihood_bernoulli(y: np.ndarray, logits: np.ndarray) -> float:
        """Numerically stable Bernoulli log-likelihood using log_expit."""
        log_p1 = log_expit(logits)  # log(sigmoid(logits))
        log_p0 = log_expit(-logits) # log(1 - sigmoid(logits))
        return np.sum(y * log_p1 + (1 - y) * log_p0)

    # ------------------------------------------------------------------
    # --- Non-Conjugate Metropolis-Hastings Updates (already in log-space) ---
    # ------------------------------------------------------------------

    # def _log_posterior_theta_i(self, i: int, theta_i: np.ndarray) -> float:
    #     """Calculate log posterior for a single theta_i vector."""
    #     log_prior_gamma = np.sum(gamma_dist.logpdf(theta_i, a=self.alpha_theta, scale=1.0/self.xi[i]))
    #     with np.errstate(divide='ignore'):
    #         log_lik_poisson = np.sum(self.z[i, :, :] * np.log(self.beta) - self.beta * theta_i[:, np.newaxis].T)
    #     logits = self.X_aux[i] @ self.gamma.T + theta_i @ self.nu.T
    #     log_lik_logistic = self._log_likelihood_bernoulli(self.Y[i, :], logits)
    #     return log_prior_gamma + log_lik_poisson + log_lik_logistic

    # def _update_theta(self, step_size: float = 0.05) -> None:
    #     """Update theta using Metropolis-Hastings. Calculation is in log-space."""
    #     for i in range(self.n):
    #         current_theta_i = self.theta[i]
    #         proposal_theta_i = self.rng.lognormal(np.log(current_theta_i), step_size)
    #         log_p_curr = self._log_posterior_theta_i(i, current_theta_i) + np.sum(np.log(current_theta_i))
    #         log_p_prop = self._log_posterior_theta_i(i, proposal_theta_i) + np.sum(np.log(proposal_theta_i))
    #         if self.rng.random() < np.exp(log_p_prop - log_p_curr):
    #             self.theta[i] = proposal_theta_i

    # def _log_posterior_gamma_k(self, k: int, gamma_k: np.ndarray) -> float:
    #     """Calculate log posterior for a single gamma_k vector."""
    #     log_prior = norm_dist.logpdf(gamma_k, 0, np.sqrt(self.sigma_gamma_sq)).sum()
    #     logits = self.X_aux @ gamma_k + self.theta @ self.nu[k]
    #     log_lik = self._log_likelihood_bernoulli(self.Y[:, k], logits)
    #     return log_prior + log_lik

    # def _update_gamma(self, step_size: float = 0.05) -> None:
    #     """Update gamma using Metropolis-Hastings. Calculation is in log-space."""
    #     for k in range(self.k):
    #         current_gamma_k = self.gamma[k]
    #         proposal_gamma_k = self.rng.normal(current_gamma_k, step_size)
    #         log_p_curr = self._log_posterior_gamma_k(k, current_gamma_k)
    #         log_p_prop = self._log_posterior_gamma_k(k, proposal_gamma_k)
    #         if self.rng.random() < np.exp(log_p_prop - log_p_curr):
    #             self.gamma[k] = proposal_gamma_k

    # # ------------------------------------------------------------------
    # # --- Spike-and-Slab Updates for nu (upsilon) (already in log-space) ---
    # # ------------------------------------------------------------------

    # def _update_s_nu(self) -> None:
    #     """Update the spike selectors s_k. Calculation is in log-space."""
    #     for k in range(self.k):
    #         logits_0 = self.X_aux @ self.gamma[k] # Logits when nu_k=0
    #         logits_1 = logits_0 + self.theta @ self.w_nu[k] # Logits when nu_k=w_k
    #         log_post_1 = np.log(self.pi_nu) + self._log_likelihood_bernoulli(self.Y[:, k], logits_1)
    #         log_post_0 = np.log(1 - self.pi_nu) + self._log_likelihood_bernoulli(self.Y[:, k], logits_0)
            
    #         # Stable calculation of P(s=1) = sigmoid(log_post_1 - log_post_0)
    #         prob_s1 = expit(log_post_1 - log_post_0)
    #         self.s_nu[k] = self.rng.binomial(1, prob_s1, size=self.d)

    # def _log_posterior_w_nu_k(self, k: int, w_nu_k: np.ndarray) -> float:
    #     """Calculate log posterior for a single w_nu_k vector."""
    #     log_prior = norm_dist.logpdf(w_nu_k, 0, np.sqrt(self.sigma_slab_sq)).sum()
    #     nu_k = self.s_nu[k] * w_nu_k
    #     logits = self.X_aux @ self.gamma[k] + self.theta @ nu_k
    #     log_lik = self._log_likelihood_bernoulli(self.Y[:, k], logits)
    #     return log_prior + log_lik

    # def _update_w_nu(self, step_size: float = 0.05) -> None:
    #     """Update the slab values w_k. Calculation is in log-space."""
    #     for k in range(self.k):
    #         current_w_k = self.w_nu[k]
    #         proposal_w_k = self.rng.normal(current_w_k, step_size)
    #         log_p_curr = self._log_posterior_w_nu_k(k, current_w_k)
    #         log_p_prop = self._log_posterior_w_nu_k(k, proposal_w_k)
    #         if self.rng.random() < np.exp(log_p_prop - log_p_curr):
    #             self.w_nu[k] = proposal_w_k

    def _update_theta(self) -> None:
        """Move each theta_i to its conditional posterior mode (the center of a
        Laplace approximation); no Gaussian draw around the mode is taken."""
        for i in range(self.n):
            # The function to minimize is the *negative* log posterior
            def objective_func(theta_i):
                # Ensure positivity during optimization
                if np.any(theta_i <= 0):
                    return np.inf
                return -self._log_posterior_theta_i(i, theta_i)

            # Use the current value as the starting point for the optimization
            initial_guess = self.theta[i]

            # Find the mode of the posterior by minimizing the negative log posterior
            result = minimize(
                fun=objective_func,
                x0=initial_guess,
                method='Nelder-Mead', # A gradient-free method, good for stability
            )

            if result.success:
                self.theta[i] = result.x

    def _update_gamma(self) -> None:
        """Move each gamma_k to its conditional posterior mode (the center of a
        Laplace approximation); no Gaussian draw around the mode is taken."""
        for k in range(self.k):
            def objective_func(gamma_k):
                return -self._log_posterior_gamma_k(k, gamma_k)

            initial_guess = self.gamma[k]
            result = minimize(fun=objective_func, x0=initial_guess, method='BFGS') # A gradient-based method

            if result.success:
                self.gamma[k] = result.x
    
    # --- Spike-and-Slab Updates for nu (upsilon) ---

    # _update_s_nu stays a Gibbs step: each selector has a two-point conditional.
    def _update_s_nu(self) -> None:
        """Update the spike selectors s_k. Calculation is in log-space."""
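        for k in range(self.k):
            base_logits = self.X_aux @ self.gamma[k]
            for d in range(self.d):
                # Condition on the other selectors of task k; toggle only component d
                s_on, s_off = self.s_nu[k].copy(), self.s_nu[k].copy()
                s_on[d], s_off[d] = 1, 0
                logits_1 = base_logits + self.theta @ (s_on * self.w_nu[k])
                logits_0 = base_logits + self.theta @ (s_off * self.w_nu[k])
                log_post_1 = np.log(self.pi_nu) + self._log_likelihood_bernoulli(self.Y[:, k], logits_1)
                log_post_0 = np.log(1 - self.pi_nu) + self._log_likelihood_bernoulli(self.Y[:, k], logits_0)
                # Stable calculation of P(s=1) = sigmoid(log_post_1 - log_post_0)
                self.s_nu[k, d] = self.rng.binomial(1, expit(log_post_1 - log_post_0))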

    def _update_w_nu(self) -> None:
        """Move the slab values w_k to their conditional posterior mode when any
        selector is active; otherwise draw them from the prior."""
        for k in range(self.k):
            # Only update w_nu_k if at least one of its components is active
            if np.any(self.s_nu[k] == 1):
                def objective_func(w_nu_k):
                    return -self._log_posterior_w_nu_k(k, w_nu_k)

                initial_guess = self.w_nu[k]
                result = minimize(fun=objective_func, x0=initial_guess, method='BFGS')

                if result.success:
                    self.w_nu[k] = result.x
            else:
                # With every selector in s_nu[k] at 0, the likelihood does not
                # involve w_nu[k], so its conditional is the prior: sample from it
                self.w_nu[k] = self.rng.normal(0.0, np.sqrt(self.sigma_slab_sq), size=self.d)

    # ------------------------------------------------------------------
    # --- Sampler Execution ---
    # ------------------------------------------------------------------
    def step(self) -> None:
        """Run a single full iteration, updating every parameter block once."""
        self._update_latent_counts_z()
        self._update_beta()
        self._update_eta()
        self._update_xi()
        self._update_theta()
        self._update_gamma()
        self._update_s_nu()
        self._update_w_nu()
        self.nu = self.s_nu * self.w_nu

    def run(self, n_iter: int, n_burnin: int = 100) -> dict:
        """Run the sampler and return parameter traces."""
        print(f"Running burn-in for {n_burnin} iterations...")
        for i in range(n_burnin):
            self.step()
            if (i + 1) % 100 == 0:
                print(f"  ...burn-in iteration {i+1}/{n_burnin}")

        print(f"Running sampling for {n_iter} iterations...")
        traces = {
            "theta": np.zeros((n_iter, *self.theta.shape)),
            "beta": np.zeros((n_iter, *self.beta.shape)),
            "gamma": np.zeros((n_iter, *self.gamma.shape)),
            "nu": np.zeros((n_iter, *self.nu.shape)),
            "s_nu": np.zeros((n_iter, *self.s_nu.shape)),
        }
        for t in range(n_iter):
            self.step()
            # Each trace key matches an attribute name; slice assignment copies the values
            for name in traces:
                traces[name][t] = getattr(self, name)
            if (t + 1) % 100 == 0:
                print(f"  ...sampling iteration {t+1}/{n_iter}")
        return traces