Numerical overflow in linfa-clustering GMM E-step: estimate_log_prob_resp uses unstable exp-sum-ln normalization #442

@schliffen

Description

Summary
The GaussianMixtureModel implementation in linfa-clustering still uses a numerically unstable normalization path in estimate_log_prob_resp in release 0.8.1 (and current master).
It computes ln(sum(exp(x))) directly, which overflows when weighted log-probabilities are large positive values.

Affected code
Release 0.8.1:
https://github.com/rust-ml/linfa/blob/0.8.1/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs#L328-L350

Master:
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs#L328-L350

Current logic in estimate_log_prob_resp:

weighted_log_prob.mapv(|x| x.exp()).sum_axis(Axis(1)).mapv(|x| x.ln())

Why this is a bug
The expression above is mathematically correct but numerically unstable.
For large positive x (above roughly 709.78 for f64), exp(x) overflows to inf.
log_prob_norm then becomes inf, and log_resp becomes -inf (or nan once two infinite values meet in a subtraction), which destabilizes the EM updates.

Minimal reproducer of the exact numerical issue
This is the same normalization pattern used in estimate_log_prob_resp, isolated:

use ndarray::{array, Axis};

fn main() {
    // Large positive weighted log-probabilities: exp(1000) overflows f64.
    let weighted_log_prob = array![[1000.0_f64, 999.0_f64]];
    let log_prob_norm = weighted_log_prob
        .mapv(|x| x.exp())
        .sum_axis(Axis(1))
        .mapv(|x| x.ln());
    // insert_axis consumes the array, so clone it to keep log_prob_norm printable.
    let log_resp = &weighted_log_prob - &log_prob_norm.clone().insert_axis(Axis(1));

    println!("log_prob_norm = {:?}", log_prob_norm); // inf
    println!("log_resp = {:?}", log_resp);           // -inf for both entries
}

Expected behavior
Normalization in estimate_log_prob_resp should remain finite and stable for large magnitudes, using row-wise log-sum-exp.
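
For reference, the row-wise log-sum-exp identity can be written as follows. The symbols x_k (the weighted log-probabilities of one row), m (the row maximum) and K (the number of components) are notation introduced here for illustration, not names from the linfa source.

```latex
\ln \sum_{k=1}^{K} e^{x_k}
  = m + \ln \sum_{k=1}^{K} e^{x_k - m},
\qquad m = \max_{k} x_k .
% Every exponent x_k - m is <= 0, so each term lies in [0, 1],
% and the term with x_k = m equals 1, so the sum lies in [1, K].
% Example from the reproducer, x = (1000, 999):
% 1000 + \ln(1 + e^{-1}) \approx 1000.3133
```

With finite inputs the shifted sum is at least 1, so the logarithm is well defined and no intermediate value overflows.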

Proposed fix
Replace direct exp-sum-ln with log-sum-exp:

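let weighted_log_prob = self.estimate_weighted_log_prob(observations);
// Row-wise maximum, used as the shift in log-sum-exp.
// A plain fold avoids relying on ndarray-stats' max(), which returns a Result.
let log_max = weighted_log_prob.map_axis(Axis(1), |row| {
    row.fold(F::neg_infinity(), |acc, &x| acc.max(x))
});
// After the shift every exponent is <= 0, so exp() cannot overflow.
let shifted = &weighted_log_prob - &log_max.clone().insert_axis(Axis(1));
let log_prob_norm =
    shifted.mapv(|x| x.exp()).sum_axis(Axis(1)).mapv(|x| x.ln()) + &log_max;
// Clone because insert_axis consumes its receiver and log_prob_norm is still needed.
let log_resp = weighted_log_prob - log_prob_norm.clone().insert_axis(Axis(1));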
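
One edge case may be worth deciding on before a PR: if every entry of a row is -inf, the row maximum is -inf and the shift computes -inf - (-inf) = nan. The sketch below illustrates one possible guard on plain f64 slices, using only the standard library so it can be run without ndarray. The function names log_sum_exp and log_resp_row are hypothetical and not part of linfa.

```rust
/// Numerically stable ln(sum(exp(x_k))) over one row of weighted
/// log-probabilities. Hypothetical helper, not part of linfa.
fn log_sum_exp(row: &[f64]) -> f64 {
    // Largest entry of the row; NEG_INFINITY for an empty row.
    let max = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    // Guard (an assumption of this sketch): a non-finite maximum is returned
    // as-is, because shifting by it would compute inf - inf = NaN.
    if !max.is_finite() {
        return max;
    }
    // Every shifted exponent is <= 0, so exp() stays within [0, 1].
    let sum: f64 = row.iter().map(|&x| (x - max).exp()).sum();
    max + sum.ln()
}

/// Log-responsibilities for one row: x_k - log_sum_exp(row).
fn log_resp_row(row: &[f64]) -> Vec<f64> {
    let norm = log_sum_exp(row);
    row.iter().map(|&x| x - norm).collect()
}
```

Calling log_sum_exp(&[1000.0, 999.0]) stays finite, unlike the direct exp-sum-ln path. The guard only keeps the normalizer from turning into nan; a row that is entirely -inf still has undefined responsibilities, so the caller would need its own policy for that case.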

Additional context
I am currently using a local patch with exactly this log-sum-exp change, and it prevents overflow in high-dimensional GMM workloads.

Would maintainers accept a PR with:

  • this fix in estimate_log_prob_resp
  • a regression test that stresses large weighted log-probability values to ensure finite log_prob_norm and stable responsibilities?
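
As a rough idea of what such a regression test could assert per observation, here is a minimal, std-only sketch. The helper name row_is_stable and the tolerance parameter are assumptions made for illustration; an actual test would call the GMM code in linfa-clustering and apply checks like these to each row of its output.

```rust
/// Checks the invariants a regression test might assert for one observation:
/// a finite normalizer, no NaN in the log-responsibilities, and
/// responsibilities that sum to 1 within `tol`.
/// Hypothetical helper, not part of linfa.
fn row_is_stable(log_prob_norm: f64, log_resp: &[f64], tol: f64) -> bool {
    if !log_prob_norm.is_finite() {
        return false;
    }
    // Each responsibility is a probability, so its log must not exceed 0
    // (up to the tolerance) and must not be NaN.
    if log_resp.iter().any(|v| v.is_nan() || *v > tol) {
        return false;
    }
    let total: f64 = log_resp.iter().map(|v| v.exp()).sum();
    (total - 1.0).abs() <= tol
}
```

For the reproducer row [1000.0, 999.0], the direct exp-sum-ln output (inf and all -inf) fails this check, while the log-sum-exp output passes it.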
