Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update _get_weights() method for SpatialAccessor and TemporalAccessor #252

Merged
merged 3 commits into from
Jun 15, 2022

Conversation

tomvothecoder
Copy link
Collaborator

@tomvothecoder tomvothecoder commented Jun 7, 2022

Description

Summary of Changes

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • My changes generate no new warnings
  • Any dependent changes have been merged and published in downstream modules

If applicable:

  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass with my changes (locally and CI/CD build)
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have noted that this is a breaking change for a major release (fix or feature that would cause existing functionality to not work as expected)

@tomvothecoder tomvothecoder self-assigned this Jun 7, 2022
@tomvothecoder tomvothecoder added this to In progress in v0.3.0 via automation Jun 7, 2022
elif isinstance(weights, xr.DataArray):
dv_weights = weights

self._validate_weights(dv, axis, dv_weights)
dataset[dv.name] = self._averager(dv, axis, dv_weights)
return dataset

def get_weights(
Copy link
Collaborator

@pochedls pochedls Jun 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this is the only line that changed in this file (aside from calls to _get_weights)? Github is highlighting this whole function...probably because it got moved?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this method was moved and renamed. Nothing else was changed in this file.

Comment on lines 306 to 307
# FIXME: ValueError when domain bounds contains lower bound larger than
# upper bound
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should throw an error? I'm kind of confused why it wasn't before based on this check. Should this test have region bounds such that lon_bounds=np.array([350, 20]) or something like that?

I think region bounds should be able to accept a larger right hand bound, but I don't think this should be true for domain bounds.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up deleting this test because it is faulty and we already have a test for raising a ValueError if the domain bounds has a larger lower bound vs. upper bound.

Comment on lines -769 to -880
),
"lon": xr.DataArray(
name="lon_wts",
data=np.array([1, 2, 3, 4]),
coords={"lon": self.ds.lon},
dims=["lon"],
),
}

def test_weights_for_single_axis_are_identical(self):
axis_weights = self.axis_weights
del axis_weights["lon"]

result = self.ds.spatial._combine_weights(axis_weights=self.axis_weights)
expected = self.axis_weights["lat"]

assert result.identical(expected)

def test_weights_for_multiple_axis_is_the_product_of_matrix_multiplication(self):
result = self.ds.spatial._combine_weights(axis_weights=self.axis_weights)
expected = xr.DataArray(
data=np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12], [4, 8, 12, 16]]),
coords={"lat": self.ds.lat, "lon": self.ds.lon},
dims=["lat", "lon"],
)

assert result.identical(expected)


class TestAverager:
@pytest.fixture(autouse=True)
def setup(self):
self.ds = generate_dataset(cf_compliant=True, has_bounds=True)

@requires_dask
def test_chunked_weighted_avg_over_lat_and_lon_axes(self):
ds = self.ds.copy().chunk(2)

weights = xr.DataArray(
data=np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12], [4, 8, 12, 16]]),
coords={"lat": ds.lat, "lon": ds.lon},
dims=["lat", "lon"],
)

result = ds.spatial._averager(ds.ts, axis=["X", "Y"], weights=weights)
expected = xr.DataArray(
name="ts", data=np.ones(15), coords={"time": ds.time}, dims=["time"]
)

assert result.identical(expected)

def test_weighted_avg_over_lat_axis(self):
weights = xr.DataArray(
name="lat_wts",
data=np.array([1, 2, 3, 4]),
coords={"lat": self.ds.lat},
dims=["lat"],
)

result = self.ds.spatial._averager(self.ds.ts, axis=["Y"], weights=weights)
expected = xr.DataArray(
name="ts",
data=np.ones((15, 4)),
coords={"time": self.ds.time, "lon": self.ds.lon},
dims=["time", "lon"],
)

assert result.identical(expected)

def test_weighted_avg_over_lon_axis(self):
weights = xr.DataArray(
name="lon_wts",
data=np.array([1, 2, 3, 4]),
coords={"lon": self.ds.lon},
dims=["lon"],
)

result = self.ds.spatial._averager(self.ds.ts, axis=["X"], weights=weights)
expected = xr.DataArray(
name="ts",
data=np.ones((15, 4)),
coords={"time": self.ds.time, "lat": self.ds.lat},
dims=["time", "lat"],
)

assert result.identical(expected)

def test_weighted_avg_over_lat_and_lon_axis(self):
weights = xr.DataArray(
data=np.array([[1, 2, 3, 4], [2, 4, 6, 8], [3, 6, 9, 12], [4, 8, 12, 16]]),
coords={"lat": self.ds.lat, "lon": self.ds.lon},
dims=["lat", "lon"],
)

result = self.ds.spatial._averager(self.ds.ts, axis=["X", "Y"], weights=weights)
expected = xr.DataArray(
name="ts", data=np.ones(15), coords={"time": self.ds.time}, dims=["time"]
)

assert result.identical(expected)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kind of remember you commenting on this...is this all removed because they are private methods (and we don't need to test them)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for TestGetWeights?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, we should test the public method which will cover the private methods (implementation details).

No need to review any of these changes here so I am just porting some tests over from the private methods to the public methods.

Copy link
Collaborator

@pochedls pochedls left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no objection to making get_weights() public, though I didn't fully understand the the changes to the unit tests. Should I view the changes in tests/ as a reorganization or something that I should review more carefully?

I think the error for test_weights_for_region_in_lon_domain_with_both_spanning_p_meridian is correct and we should represent the problematic line as ds.lon_bnds.data[:] = np.array([[-1, 1], [1, 90], [90, 180], [180, 359]]).

@tomvothecoder
Copy link
Collaborator Author

I have no objection to making get_weights() public, though I didn't fully understand the the changes to the unit tests. Should I view the changes in tests/ as a reorganization or something that I should review more carefully?

I addressed this comment in one of the PR review comments.

I think the error for test_weights_for_region_in_lon_domain_with_both_spanning_p_meridian is correct and we should represent the problematic line as ds.lon_bnds.data[:] = np.array([[-1, 1], [1, 90], [90, 180], [180, 359]]).

I ended up removing tests where the domain bounds has a larger left hand value than right hand value because they are supposed to throw an error, rather than pass. There's already a test to check that an error is thrown for this case.

I kept test_weights_for_region_in_lon_domain_with_region_spanning_p_meridian since regions can have a larger lower bound vs. upper bound.

tests/test_spatial.py Show resolved Hide resolved
tests/test_spatial.py Outdated Show resolved Hide resolved
tests/test_spatial.py Outdated Show resolved Hide resolved
tests/test_spatial.py Outdated Show resolved Hide resolved
@codecov-commenter
Copy link

codecov-commenter commented Jun 7, 2022

Codecov Report

Merging #252 (e15f4bb) into main (1fdc8a9) will not change coverage.
The diff coverage is 100.00%.

❗ Current head e15f4bb differs from pull request most recent head 0beb116. Consider uploading reports for the commit 0beb116 to get more accurate results

@@            Coverage Diff            @@
##              main      #252   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files            9         8    -1     
  Lines          742       736    -6     
=========================================
- Hits           742       736    -6     
Impacted Files Coverage Δ
xcdat/utils.py 100.00% <ø> (ø)
xcdat/__init__.py 100.00% <100.00%> (ø)
xcdat/axis.py 100.00% <100.00%> (ø)
xcdat/bounds.py 100.00% <100.00%> (ø)
xcdat/dataset.py 100.00% <100.00%> (ø)
xcdat/spatial.py 100.00% <100.00%> (ø)
xcdat/temporal.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update b2be983...0beb116. Read the comment docs.

@tomvothecoder tomvothecoder force-pushed the feature/251-public-weight-methods branch 2 times, most recently from 3c6c5a3 to 2d82d14 Compare June 9, 2022 19:30
@tomvothecoder
Copy link
Collaborator Author

tomvothecoder commented Jun 10, 2022

Hi @pochedls, @lee1043, and @chengzhuzhang, instead of making _get_weights() public for temporal averaging so the user can get weights, I implemented a keep_weights keyword argument to the temporal averaging APIs. This gives users the choice with whether or not they want to couple time_weights with its related output. Let me know what you think of this design choice.

Option Make _get_weights() public Option to keep time_weights in the output xr.Dataset
Pros * User can calculate weights independently of temporal averaging, which makes it faster to get weights instead of running temporal averaging first (do scientists normally calculate weights independently? does CDAT have a method for this?) * Makes sense to couple the weights to its related averaging output
* Implementation is straightforward, just add the time_weights DataArray to the final output Dataset
* User has the option to include the time_weights in the output (default keep_weights=False)
Cons * User needs to provide the same arguments as temporal averaging APIs (data_var, mode, freq, season_config) which makes the method less intuitive to implement and use
* Weights are decoupled from the output, which removes context on how they should be used for averaging
* We don't accept weights in our temporal averaging APIs right now (does it make sense to?)
* For group_average and climatology we must append "_original" to the name of time_weights time dimension/coordinates since datasets share dimensions and coordinates. Otherwise, it will use the reduced time
dimension/coordinates (not correct, will result in np.nan weights)

Example for ds.temporal.climatology():

xcdat/xcdat/temporal.py

Lines 332 to 339 in 44620c0

def climatology(
self,
data_var: str,
freq: Frequency,
weighted: bool = True,
keep_weights: bool = False,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
):

xcdat/xcdat/temporal.py

Lines 742 to 745 in 44620c0

if keep_weights:
ds = self._keep_weights(ds)
return ds

Example for ds.temporal.departures()

xcdat/xcdat/temporal.py

Lines 465 to 472 in 44620c0

def departures(
self,
data_var: str,
freq: Frequency,
weighted: bool = True,
keep_weights: bool = False,
season_config: SeasonConfigInput = DEFAULT_SEASON_CONFIG,
) -> xr.Dataset:

xcdat/xcdat/temporal.py

Lines 624 to 629 in 44620c0

if weighted and keep_weights:
self._weights = ds_climo.time_weights
ds_departs = self._keep_weights(ds_departs)
return ds_departs

_keep_weights() method:

xcdat/xcdat/temporal.py

Lines 1369 to 1402 in 44620c0

def _keep_weights(self, ds: xr.Dataset) -> xr.Dataset:
"""Keep the weights in the dataset.
Parameters
----------
ds : xr.Dataset
The dataset.
Returns
-------
xr.Dataset
The dataset with the weights used for averaging.
"""
# Append "_original" to the name of the weights` time coordinates to
# avoid conflict with the grouped time coordinates in the Dataset (can
# have a different shape).
if self._mode in ["group_average", "climatology"]:
self._weights = self._weights.rename(
{self._dim_name: f"{self._dim_name}_original"}
)
# Only keep the original time coordinates, not the ones labeled
# by group.
self._weights = self._weights.drop_vars(self._labeled_time.name)
# Strip "_original" from the name of the weights` time coordinates
# because the final departures Dataset has the original time coordinates
# restored after performing grouped subtraction.
elif self._mode == "departures":
self._weights = self._weights.rename(
{f"{self._dim_name}_original": self._dim_name}
)
ds[self._weights.name] = self._weights
return ds

If this looks good, we can implement keep_weights with the spatial averaging API as well if it makes sense to.

@chengzhuzhang
Copy link
Collaborator

chengzhuzhang commented Jun 10, 2022

Hi @tomvothecoder Thank you for the discussion. I don't recall that I have a use case that requires to get/keep the weights to pass down to another operation. With the exception that, they are useful to be examed during xcdat development/validation. So I don't have strong opinion on the implementation. I will defer to Steve and Jiwoo.

@pochedls
Copy link
Collaborator

I like keep_weights. I didn't totally understand the downside for keep_weights: why can't we just overwrite the original time_weights? Are both the full time series and the climatology time series saved in the dataset when calculating a climatology?

I think it would be good to keep .get_weights in spatial averaging, though I don't object to adding a keep_weights in that routine. One thing that might confuse people is that the weights we generate for spatial averaging are 2D (we never explicitly generate a time resolved weighting matrix).

@tomvothecoder
Copy link
Collaborator Author

I like keep_weights. I didn't totally understand the downside for keep_weights: why can't we just overwrite the original time_weights?

There is only one time_wts generated using the time_bounds, which is saved in the Dataset when keep_weights=True. I'm not sure what you mean by overwriting the original time_wts.

Keeping weights with Climatology

  • time_wts uses the original time coordinates, so it CANNOT share the reduced time dimension and coordinates from the climatology data variable.
    • "time_original" is the original time dimension (1919 coord pts)
    • "time" is the reduced time dimension (4 coord pts)
ds.temporal.climatology("tas", freq="season", weighted=True, keep_weights=True)

<xarray.Dataset>
Dimensions:        (lon: 320, nv: 2, lat: 160, bnds: 2, time: 4,
                    time_original: 1919)
Coordinates:
  * lon            (lon) float64 0.0 1.125 2.25 3.375 ... 356.6 357.8 358.9
  * lat            (lat) float64 -89.14 -88.03 -86.91 ... 86.91 88.03 89.14
  * time           (time) object 0001-01-01 00:00:00 ... 0001-10-01 00:00:00
  * time_original  (time_original) datetime64[ns] 1850-01-16T12:00:00 ... 200...
Dimensions without coordinates: nv, bnds
Data variables:
    lon_bounds     (lon, nv) float64 dask.array<chunksize=(320, 2), meta=np.ndarray>
    lat_bounds     (lat, nv) float64 dask.array<chunksize=(160, 2), meta=np.ndarray>
    lon_bnds       (lon, bnds) float64 -0.5625 0.5625 0.5625 ... 358.3 359.4
    lat_bnds       (lat, bnds) float64 -89.7 -88.59 -88.59 ... 88.59 88.59 89.7
    tas            (time, lat, lon) float64 dask.array<chunksize=(1, 160, 320), meta=np.ndarray>
    time_wts       (time_original) float64 0.002038 0.002038 ... 0.002083

Keeping weights with Departures

  • time_wts uses the original "time" dimension and coordinates, which are maintained from the departures data variable so they can be shared.
  • Notice there is no "time_original" dim or coords
ds.temporal.departures("tas", freq="season", weighted=True, keep_weights=True)

<xarray.Dataset>
Dimensions:     (lon: 320, nv: 2, lat: 160, time: 1919, bnds: 2)
Coordinates:
  * lon         (lon) float64 0.0 1.125 2.25 3.375 ... 355.5 356.6 357.8 358.9
  * lat         (lat) float64 -89.14 -88.03 -86.91 -85.79 ... 86.91 88.03 89.14
  * time        (time) datetime64[ns] 1850-01-16T12:00:00 ... 2009-11-16
Dimensions without coordinates: nv, bnds
Data variables:
    lon_bounds  (lon, nv) float64 dask.array<chunksize=(320, 2), meta=np.ndarray>
    lat_bounds  (lat, nv) float64 dask.array<chunksize=(160, 2), meta=np.ndarray>
    time_bnds   (time, bnds) datetime64[ns] 1850-01-01T18:00:00 ... 2009-12-0...
    tas         (time, lat, lon) float64 dask.array<chunksize=(2, 160, 320), meta=np.ndarray>
    lon_bnds    (lon, bnds) float64 -0.5625 0.5625 0.5625 ... 358.3 358.3 359.4
    lat_bnds    (lat, bnds) float64 -89.7 -88.59 -88.59 ... 88.59 88.59 89.7
    time_wts    (time) float64 0.002038 0.002038 0.002059 ... 0.002083 0.002083

Are both the full time series and the climatology time series saved in the dataset when calculating a climatology?
Only the climatology time series is stored in the climatology output.

I think it would be good to keep .get_weights in spatial averaging, though I don't object to adding a keep_weights in that routine.

Sounds good, I added keep_weights as a keyword argument to ds.spatial.average().

One thing that might confuse people is that the weights we generate for spatial averaging are 2D (we never explicitly generate a time resolved weighting matrix).

Hmmm, we can address this if users open up an issue about it, unless you think it is worth addressing now.

@pochedls
Copy link
Collaborator

Got it – thanks Tom.

- Add kwarg `keep_weights` to both methods
- Make `_get_weights()` a public method for `SpatialAccessor`
- Delete redundant tests for private methods
- Delete faulty tests with domain bounds:
@tomvothecoder tomvothecoder force-pushed the feature/251-public-weight-methods branch from 225df77 to 634dca4 Compare June 15, 2022 15:53
@tomvothecoder tomvothecoder changed the title Make SpatialAccessor._get_weights() and TemporalAccessor._get_weights() public Update _get_weights() method for SpatialAccessor and TemporalAccessor Jun 15, 2022
xcdat/temporal.py Outdated Show resolved Hide resolved
@tomvothecoder tomvothecoder merged commit 350ac8b into main Jun 15, 2022
v0.3.0 automation moved this from In progress to Done Jun 15, 2022
@tomvothecoder tomvothecoder deleted the feature/251-public-weight-methods branch June 15, 2022 16:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
type: enhancement New enhancement request
Projects
No open projects
Development

Successfully merging this pull request may close these issues.

[Feature]: Update _get_weights() method for SpatialAccessor and TemporalAccessor
4 participants