Skip to content

Commit

Permalink
Numerically stabilize autocorrelation() (#3114)
Browse files Browse the repository at this point in the history
* Allow autocorrelation() to run without mkl

* Numerically stabilize

* Clarify definition, add test

* Remove xfail_if_not_implemented
  • Loading branch information
fritzo committed Jul 9, 2022
1 parent 36d29ee commit 66defe8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
10 changes: 8 additions & 2 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,18 @@ def autocorrelation(input, dim=0):
# inverse Fourier transform
autocorr = irfft(freqvec_gram, n=M2)

# truncate and normalize the result, then transpose back to original shape
# truncate and normalize the result, setting autocorrelation to 1 for all
# constant channels
autocorr = autocorr[..., :N]
autocorr = autocorr / torch.tensor(
range(N, 0, -1), dtype=input.dtype, device=input.device
)
autocorr = autocorr / autocorr[..., :1]
variance = autocorr[..., :1]
constant = (variance == 0).expand_as(autocorr)
autocorr = autocorr / variance.clamp(min=torch.finfo(variance.dtype).tiny)
autocorr[constant] = 1

# transpose back to original shape
return autocorr.transpose(dim, -1)


Expand Down
23 changes: 21 additions & 2 deletions tests/ops/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,34 @@ def test_statistics_A_ok_with_sample_shape(statistics, sample_shape):

def test_autocorrelation():
x = torch.arange(10.0)
with xfail_if_not_implemented():
actual = autocorrelation(x)
actual = autocorrelation(x)
assert_equal(
actual,
torch.tensor([1, 0.78, 0.52, 0.21, -0.13, -0.52, -0.94, -1.4, -1.91, -2.45]),
prec=0.01,
)


def test_autocorrelation_trivial():
x = torch.zeros(10)
actual = autocorrelation(x)
assert_equal(actual, torch.ones(10), prec=0.01)


def test_autocorrelation_vectorized():
# make a mostly noisy x with a couple constant series
x = torch.randn(3, 4, 5)
x[1, 2] = 0
x[2, 3] = 1

actual = autocorrelation(x, dim=-1)
expected = torch.tensor([[autocorrelation(xij).tolist() for xij in xi] for xi in x])
assert_equal(actual, expected)

assert (actual[1, 2] == 1).all()
assert (actual[2, 3] == 1).all()


def test_autocovariance():
x = torch.arange(10.0)
with xfail_if_not_implemented():
Expand Down

0 comments on commit 66defe8

Please sign in to comment.