Skip to content

Commit

Permalink
fix: inverse_transform fails with single mode and normalized pcs (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicrie committed Feb 4, 2024
1 parent b5d04cc commit 9ea7547
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 20 deletions.
42 changes: 29 additions & 13 deletions tests/models/test_eof.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,34 +444,50 @@ def test_transform_nan_feature(dim, mock_data_array):


@pytest.mark.parametrize(
"dim",
"dim, normalized",
[
(("time",)),
(("lat", "lon")),
(("lon", "lat")),
(("time",), True),
(("lat", "lon"), True),
(("lon", "lat"), True),
(("time",), False),
(("lat", "lon"), False),
(("lon", "lat"), False),
],
)
def test_inverse_transform(dim, mock_data_array):
def test_inverse_transform(dim, mock_data_array, normalized):
"""Test inverse_transform method in EOF class."""

# instantiate the EOF class with necessary parameters
eof = EOF(n_modes=3, standardize=True)
eof = EOF(n_modes=20, standardize=True)

# fit the EOF model
eof.fit(mock_data_array, dim=dim)
scores = eof.scores(normalized=normalized)

# Test with single mode
scores = eof.data["scores"].sel(mode=1)
reconstructed_data = eof.inverse_transform(scores)
assert isinstance(reconstructed_data, xr.DataArray)
scores_selection = scores.sel(mode=1)
X_rec_1 = eof.inverse_transform(scores_selection)
assert isinstance(X_rec_1, xr.DataArray)

# Test with single mode as list
scores_selection = scores.sel(mode=[1])
X_rec_1_list = eof.inverse_transform(scores_selection)
assert isinstance(X_rec_1_list, xr.DataArray)

# Single mode and list should be equal
xr.testing.assert_allclose(X_rec_1, X_rec_1_list)

# Test with all modes
scores = eof.data["scores"]
reconstructed_data = eof.inverse_transform(scores)
assert isinstance(reconstructed_data, xr.DataArray)
X_rec = eof.inverse_transform(scores, normalized=normalized)
assert isinstance(X_rec, xr.DataArray)

# Check that the reconstructed data has the same dimensions as the original data
assert set(reconstructed_data.dims) == set(mock_data_array.dims)
assert set(X_rec.dims) == set(mock_data_array.dims)

# Reconstructed data should be close to the original data
orig_dim_order = mock_data_array.dims
X_rec = X_rec.transpose(*orig_dim_order)
xr.testing.assert_allclose(mock_data_array, X_rec)


@pytest.mark.parametrize(
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_opa.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,10 @@ def test_inverse_transform(dim, mock_data_array, opa_model):

# fit the EOF model
opa_model.fit(mock_data_array, dim=dim)
scores = opa_model.scores()

with pytest.raises(NotImplementedError):
opa_model.inverse_transform(1)
opa_model.inverse_transform(scores)

# # Test with scalar
# mode = 1
Expand Down
13 changes: 10 additions & 3 deletions xeofs/models/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,13 @@ def fit_transform(
return self.fit(data, dim, weights).transform(data, **kwargs)

def inverse_transform(
self, scores: DataObject, normalized: bool = True
self, scores: DataArray, normalized: bool = True
) -> DataObject:
"""Reconstruct the original data from transformed data.
Parameters
----------
scores: DataObject
scores: DataArray
Transformed data to be reconstructed. This could be a subset
of the `scores` data of a fitted model, or unseen data. Must
have a 'mode' dimension.
Expand All @@ -303,8 +303,15 @@ def inverse_transform(
"""
if normalized:
scores = scores * self.data["norms"]
norms = self.data["norms"].sel(mode=scores.mode)
scores = scores * norms
data_reconstructed = self._inverse_transform_algorithm(scores)

# Reconstructing the data using a single mode introduces a
# redundant "mode" coordinate
if "mode" in data_reconstructed.coords:
data_reconstructed = data_reconstructed.drop_vars("mode")

return self.preprocessor.inverse_transform_data(data_reconstructed)

@abstractmethod
Expand Down
6 changes: 3 additions & 3 deletions xeofs/models/eof.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def _transform_algorithm(self, data: DataObject) -> DataArray:

return projections

def _inverse_transform_algorithm(self, scores: DataObject) -> DataArray:
def _inverse_transform_algorithm(self, scores: DataArray) -> DataArray:
"""Reconstruct the original data from transformed data.
Parameters
----------
scores: DataObject
scores: DataArray
Transformed data to be reconstructed. This could be a subset
of the `scores` data of a fitted model, or unseen data. Must
have a 'mode' dimension.
Expand All @@ -147,7 +147,7 @@ def _inverse_transform_algorithm(self, scores: DataObject) -> DataArray:
# Reconstruct the data
comps = self.data["components"].sel(mode=scores.mode)

reconstructed_data = xr.dot(comps.conj(), scores)
reconstructed_data = xr.dot(comps.conj(), scores, dims="mode")
reconstructed_data.name = "reconstructed_data"

# Enforce real output
Expand Down

0 comments on commit 9ea7547

Please sign in to comment.