Skip to content

Commit

Permalink
Nanfiltering in TDigestTask (#667)
Browse files Browse the repository at this point in the history
  • Loading branch information
meengel committed May 24, 2023
1 parent 96b4e7a commit 1ef8d73
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions ml_tools/eolearn/ml_tools/tdigest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
The module provides an EOTask for the computation of a T-Digest representation of an EOPatch.
Requires installation of `eolearn.ml_tools[TDIGEST]`.
Copyright (c) 2017- Sinergise and contributors
For the full list of contributors, see the CREDITS file in the root directory of this source tree.
Expand All @@ -9,7 +8,7 @@
"""
from functools import partial
from itertools import product
from typing import Any, Callable, Dict, Generator, Iterable, List, Literal, Tuple
from typing import Any, Callable, Dict, Generator, Iterable, List, Literal, Tuple, Union

import numpy as np
import tdigest as td
Expand All @@ -34,8 +33,9 @@ def __init__(
self,
in_feature: FeaturesSpecification,
out_feature: FeaturesSpecification,
mode: Literal["standard", "timewise", "monthly", "total"] = "standard",
mode: Union[Literal["standard", "timewise", "monthly", "total"], Callable] = "standard",
pixelwise: bool = False,
filternan: bool = False,
):
"""
:param in_feature: The input feature to compute the T-Digest representation for.
Expand All @@ -46,14 +46,19 @@ def __init__(
* `'monthly'` computes the T-Digest representation for each band accumulating the timestamps per month.
* | `'total'` computes the total T-Digest representation of the whole feature accumulating all timestamps,
| bands and pixels. Cannot be used with `pixelwise=True`.
* | A callable computes a custom T-Digest representation in place of the built-in modes. It is called with the
| keyword arguments `input_array`, `timestamps`, `shape`, `pixelwise` and `filternan`.
:param pixelwise: Whether to compute the T-Digest representation per pixel (`True`) or accumulated over all pixels
(`False`). `pixelwise=True` cannot be used with `mode='total'`.
:param filternan: Whether to filter out NaN values before computing the T-Digest.
"""

self.mode = mode

self.pixelwise = pixelwise

self.filternan = filternan

if self.pixelwise and self.mode == "total":
raise ValueError("Total mode does not support pixelwise=True.")

Expand All @@ -78,8 +83,8 @@ def execute(self, eopatch: EOPatch) -> EOPatch:
for in_feature_, out_feature_, shape in _looper(
in_feature=self.in_feature, out_feature=self.out_feature, eopatch=eopatch
):
eopatch[out_feature_] = _processing_function[self.mode](
input_array=eopatch[in_feature_], timestamps=eopatch.timestamps, shape=shape, pixelwise=self.pixelwise
eopatch[out_feature_] = _processing_function.get(self.mode, self.mode)(
input_array=eopatch[in_feature_], timestamps=eopatch.timestamps, shape=shape, pixelwise=self.pixelwise, filternan=self.filternan
)

return eopatch
Expand All @@ -95,6 +100,9 @@ def _is_input_ftype(feature_type: FeatureType, mode: ModeTypes) -> bool:


def _is_output_ftype(feature_type: FeatureType, mode: ModeTypes, pixelwise: bool) -> bool:
if callable(mode):
return True

if mode == "standard":
return feature_type == (FeatureType.DATA_TIMELESS if pixelwise else FeatureType.SCALAR_TIMELESS)

Expand All @@ -112,36 +120,36 @@ def _looper(
yield in_feature_, out_feature_, shape


def _process_standard(input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, **_: Any) -> np.ndarray:
def _process_standard(input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, filternan: bool, **_: Any) -> np.ndarray:
if pixelwise:
array = np.empty(shape[-3:], dtype=object)
for i, j, k in product(range(shape[-3]), range(shape[-2]), range(shape[-1])):
array[i, j, k] = _get_tdigest(input_array[..., i, j, k])
array[i, j, k] = _get_tdigest(input_array[..., i, j, k], filternan)

else:
array = np.empty(shape[-1], dtype=object)
for k in range(shape[-1]):
array[k] = _get_tdigest(input_array[..., k])
array[k] = _get_tdigest(input_array[..., k], filternan)

return array


def _process_timewise(input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, **_: Any) -> np.ndarray:
def _process_timewise(input_array: np.ndarray, shape: np.ndarray, pixelwise: bool, filternan: bool, **_: Any) -> np.ndarray:
if pixelwise:
array = np.empty(shape, dtype=object)
for time_, i, j, k in product(range(shape[0]), range(shape[1]), range(shape[2]), range(shape[3])):
array[time_, i, j, k] = _get_tdigest(input_array[time_, i, j, k])
array[time_, i, j, k] = _get_tdigest(input_array[time_, i, j, k], filternan)

else:
array = np.empty(shape[[0, -1]], dtype=object)
for time_, k in product(range(shape[0]), range(shape[-1])):
array[time_, k] = _get_tdigest(input_array[time_, ..., k])
array[time_, k] = _get_tdigest(input_array[time_, ..., k], filternan)

return array


def _process_monthly(
input_array: np.ndarray, timestamps: Iterable, shape: np.ndarray, pixelwise: bool, **_: Any
input_array: np.ndarray, timestamps: Iterable, shape: np.ndarray, pixelwise: bool, filternan: bool, **_: Any
) -> np.ndarray:
midx = []
for month_ in range(12):
Expand All @@ -150,18 +158,18 @@ def _process_monthly(
if pixelwise:
array = np.empty([12, *shape[1:]], dtype=object)
for month_, i, j, k in product(range(12), range(shape[1]), range(shape[2]), range(shape[3])):
array[month_, i, j, k] = _get_tdigest(input_array[midx[month_], i, j, k])
array[month_, i, j, k] = _get_tdigest(input_array[midx[month_], i, j, k], filternan)

else:
array = np.empty([12, shape[-1]], dtype=object)
for month_, k in product(range(12), range(shape[-1])):
array[month_, k] = _get_tdigest(input_array[midx[month_], ..., k])
array[month_, k] = _get_tdigest(input_array[midx[month_], ..., k], filternan)

return array


def _process_total(input_array: np.ndarray, **_: Any) -> np.ndarray:
return _get_tdigest(input_array)
def _process_total(input_array: np.ndarray, filternan: bool, **_: Any) -> np.ndarray:
return _get_tdigest(input_array, filternan)


_processing_function: Dict[str, Callable] = {
Expand All @@ -172,7 +180,8 @@ def _process_total(input_array: np.ndarray, **_: Any) -> np.ndarray:
}


def _get_tdigest(values: np.ndarray) -> td.TDigest:
def _get_tdigest(values: np.ndarray, filternan: bool) -> td.TDigest:
result = td.TDigest()
result.batch_update(values.flatten())
values_ = values.flatten()
result.batch_update(values_[~np.isnan(values_)] if filternan else values_)
return result

0 comments on commit 1ef8d73

Please sign in to comment.