
fix: use float64 accumulators in PearsonCorrelation to prevent catastrophic cancellation#3740

Open
tejasae-afk wants to merge 1 commit into pytorch:master from
tejasae-afk:fix/pearson-correlation-float64-accumulators

Conversation

@tejasae-afk

Fixes #3662.

The naive E[X²] - (E[X])² formula for variance is mathematically correct but numerically unstable. When values have large magnitude relative to their variance — for example an offset of 1e8 with small inter-sample differences — both E[X²] and (E[X])² are around 1e16 while their difference (the actual variance) is O(1). In float32, the unit in the last place at that scale is roughly 10^9, so the variance is completely lost and the metric returns 0.0.
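The cancellation can be reproduced standalone, without the metric class (a minimal sketch in plain tensor arithmetic, not ignite's code): the two-pass formula subtracts the mean first and keeps every term O(1), while rounding the naive formula's two ~1e16 terms to float32 leaves nothing of their O(1) difference.

```python
import torch

# Data with a large common offset; float64 represents these integers exactly.
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float64) + 1e8

# Two-pass formula: subtracting the mean first keeps all terms small.
var_two_pass = ((x - x.mean()) ** 2).mean()  # 2.0, exact

# Naive formula with its two ~1e16 terms rounded to float32: the ulp at
# that magnitude is ~1e9, so both terms round to the same float32 value
# and the O(1) variance cancels away completely.
mean_of_sq = (x ** 2).mean().to(torch.float32)
sq_of_mean = (x.mean() ** 2).to(torch.float32)
var_naive_f32 = mean_of_sq - sq_of_mean  # 0.0
```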

The fix is to accumulate in float64 on devices that support it. MPS does not support float64, so it falls back to float32 and retains the previous behaviour. The public return type (Python float) is unchanged.

```python
import torch
from ignite.metrics.regression import PearsonCorrelation

offset = 1e8
y_pred = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.float64) + offset
y      = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0], dtype=torch.float64) + offset

m = PearsonCorrelation()
m.update((y_pred, y))

# before this PR (float32 accumulators):
m.compute()  # 0.0  (wrong; expected 1.0)

# after this PR (float64 accumulators):
m.compute()  # 1.0  ✓
```

A new test `test_numerical_stability_large_offset` covers this case. All existing non-distributed tests pass.

You know the codebase far better than I do; happy to adjust if a different approach (e.g., a shared Welford-based utility across PearsonCorrelation, R2Score, and FID, as discussed in the issue) is preferred.
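For reference, the Welford-style alternative mentioned above would track running means and centered co-moments in a single pass. This is a hypothetical sketch (not code from this PR), shown scalar-by-scalar for clarity; because every update works with residuals like `x - mean` rather than raw squares, no ~1e16 terms ever appear, regardless of accumulator width.

```python
class WelfordCovariance:
    """Single-pass running means and centered (co)moments, Welford-style."""

    def __init__(self) -> None:
        self.n = 0
        self.mean_x = 0.0
        self.mean_y = 0.0
        self.m2x = 0.0  # running sum of (x - mean_x)^2
        self.m2y = 0.0  # running sum of (y - mean_y)^2
        self.cxy = 0.0  # running sum of (x - mean_x) * (y - mean_y)

    def update(self, x: float, y: float) -> None:
        self.n += 1
        dx = x - self.mean_x          # residual against the OLD mean
        self.mean_x += dx / self.n
        dy = y - self.mean_y
        self.mean_y += dy / self.n
        # Standard Welford update: old residual times new residual.
        self.m2x += dx * (x - self.mean_x)
        self.m2y += dy * (y - self.mean_y)
        self.cxy += dx * (y - self.mean_y)

    def pearson(self) -> float:
        return self.cxy / (self.m2x * self.m2y) ** 0.5
```

Feeding it the 1e8-offset data from the example above yields a correlation of 1.0 without any dtype widening.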

fix: use float64 accumulators in PearsonCorrelation to prevent catastrophic cancellation

The naive E[X²] - (E[X])² formula loses all precision when values have
large magnitude relative to their variance: both terms are ~μ² ≈ 1e16
while their difference (the variance) is ~O(1), which falls below
float32's unit in the last place at that scale.

Switch all five accumulators and the incoming batches to float64 on
non-MPS devices. MPS does not support float64, so it keeps the
previous float32 behaviour. The final result is still returned as a
Python float, so the public API is unchanged.

Fixes pytorch#3662

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Labels

module: metrics Metrics module
