Skip to content

Commit

Permalink
add small epsilon to eigenvalues to prevent nans
Browse files Browse the repository at this point in the history
  • Loading branch information
9q9q committed Dec 19, 2023
1 parent ef5eb03 commit 4412aca
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sparsecoding/data/transforms/whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ class Whitener(object):
Eigenvalues of `self.covariance`.
eigenvectors : Tensor, shape [D, D]
Eigenvectors of `self.covariance`.
epsilon : float
Prevents division by zero.
"""

def __init__(
self,
data,
epsilon=1e-10,
):
self.D = data.shape[1]
self.epsilon = epsilon

with torch.no_grad():
data = data.T # [D, N]
Expand Down Expand Up @@ -83,7 +87,7 @@ def whiten(

whitened_data = (
self.eigenvectors
@ torch.diag(1. / torch.sqrt(self.eigenvalues))
@ torch.diag(1. / torch.sqrt(self.eigenvalues + self.epsilon))
@ self.eigenvectors.T
@ centered_data
) # [D, N]
Expand Down Expand Up @@ -124,7 +128,7 @@ def unwhiten(

unwhitened_data = (
self.eigenvectors
@ torch.diag(torch.sqrt(self.eigenvalues))
@ torch.diag(torch.sqrt(self.eigenvalues + self.epsilon))
@ self.eigenvectors.T
@ whitened_data
) + self.mean.reshape(self.D, 1) # [D, N]
Expand Down
7 changes: 7 additions & 0 deletions tests/data/transforms/test_whiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def test_whitener(self):
atol=1e-3
), "Unwhitened data should be equal to input data."

def test_zero_div(self):
data = torch.Tensor([[1, 0], [2, 0]])
whitener = Whitener(data)

assert not torch.any(torch.isnan(
whitener.whiten(data),
)), "If an eigenvalue is 0, should not get NaNs."

if __name__ == "__main__":
unittest.main()

0 comments on commit 4412aca

Please sign in to comment.