Skip to content

Commit

Permalink
add types to rest of easily typeable content (#485)
Browse files Browse the repository at this point in the history
  • Loading branch information
zigaLuksic committed Oct 4, 2022
1 parent 318e4d7 commit 2548e4e
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 50 deletions.
12 changes: 4 additions & 8 deletions core/eolearn/core/core_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
file in the root directory of this source tree.
"""
import copy
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast

import fs
Expand Down Expand Up @@ -40,18 +40,18 @@ def __init__(self, features: FeaturesSpecification = ...):
"""
self.features = features

def execute(self, eopatch):
def execute(self, eopatch: EOPatch) -> EOPatch:
return eopatch.copy(features=self.features)


class DeepCopyTask(CopyTask):
"""Makes a deep copy of the given EOPatch."""

def execute(self, eopatch):
def execute(self, eopatch: EOPatch) -> EOPatch:
return eopatch.copy(features=self.features, deep=True)


class IOTask(EOTask, metaclass=ABCMeta):
class IOTask(EOTask, metaclass=ABCMeta): # noqa B024
"""An abstract Input/Output task that can handle a path and a filesystem object."""

def __init__(
Expand Down Expand Up @@ -82,10 +82,6 @@ def filesystem(self) -> FS:

return unpickle_fs(self._pickled_filesystem)

@abstractmethod
def execute(self, *eopatches, **kwargs):
"""Implement execute function"""


class SaveTask(IOTask):
"""Saves the given EOPatch to a filesystem."""
Expand Down
100 changes: 60 additions & 40 deletions core/eolearn/core/eodata_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@
"""
from __future__ import annotations

import datetime as dt
import functools
import itertools as it
import warnings
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Dict, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast

import numpy as np
import pandas as pd
from geopandas import GeoDataFrame

from sentinelhub import BBox

from .constants import FeatureType
from .exceptions import EORuntimeWarning
from .utils.parsing import FeatureParser, FeatureSpec, FeaturesSpecification
Expand All @@ -29,12 +31,14 @@
if TYPE_CHECKING:
from .eodata import EOPatch

OperationInputType = Union[Literal[None, "concatenate", "min", "max", "mean", "median"], Callable]


def merge_eopatches(
*eopatches: EOPatch,
features: FeaturesSpecification = ...,
time_dependent_op: Union[Literal[None, "concatenate", "min", "max", "mean", "median"], Callable] = None,
timeless_op: Union[Literal[None, "concatenate", "min", "max", "mean", "median"], Callable] = None,
time_dependent_op: OperationInputType = None,
timeless_op: OperationInputType = None,
) -> Dict[FeatureSpec, Any]:
"""Merge features of given EOPatches into a new EOPatch.
Expand Down Expand Up @@ -69,7 +73,7 @@ def merge_eopatches(

feature_parser = FeatureParser(features)
all_features = {feature for eopatch in eopatches for feature in feature_parser.get_features(eopatch)}
eopatch_content = {}
eopatch_content: Dict[FeatureSpec, object] = {}

timestamps, order_mask_per_eopatch = _merge_timestamps(eopatches, reduce_timestamps)
optimize_raster_temporal = _check_if_optimize(eopatches, time_dependent_op)
Expand All @@ -92,6 +96,7 @@ def merge_eopatches(
eopatch_content[feature] = timestamps

if feature_type is FeatureType.META_INFO:
feature_name = cast(str, feature_name) # parser makes sure of it
eopatch_content[feature] = _select_meta_info_feature(eopatches, feature_name)

if feature_type is FeatureType.BBOX:
Expand All @@ -100,34 +105,36 @@ def merge_eopatches(
return eopatch_content


def _parse_operation(operation_input, is_timeless):
def _parse_operation(operation_input: OperationInputType, is_timeless: bool) -> Callable:
"""Transforms operation's instruction (i.e. an input string) into a function that can be applied to a list of
arrays. If the input already is a function it returns it.
"""
if isinstance(operation_input, Callable):
return operation_input

try:
return {
None: _return_if_equal_operation,
"concatenate": functools.partial(np.concatenate, axis=-1 if is_timeless else 0),
"mean": functools.partial(np.nanmean, axis=0),
"median": functools.partial(np.nanmedian, axis=0),
"min": functools.partial(np.nanmin, axis=0),
"max": functools.partial(np.nanmax, axis=0),
}[operation_input]
except KeyError as exception:
raise ValueError(f"Merge operation {operation_input} is not supported") from exception


def _return_if_equal_operation(arrays):
defaults: Dict[Optional[str], Callable] = {
None: _return_if_equal_operation,
"concatenate": functools.partial(np.concatenate, axis=-1 if is_timeless else 0),
"mean": functools.partial(np.nanmean, axis=0),
"median": functools.partial(np.nanmedian, axis=0),
"min": functools.partial(np.nanmin, axis=0),
"max": functools.partial(np.nanmax, axis=0),
}
if operation_input in defaults:
return defaults[operation_input] # type: ignore[index]

if isinstance(operation_input, Callable): # type: ignore[arg-type] #mypy 0.981 has issues with callable
return cast(Callable, operation_input)
raise ValueError(f"Merge operation {operation_input} is not supported")


def _return_if_equal_operation(arrays: np.ndarray) -> bool:
"""Checks if arrays are all equal and returns first one of them. If they are not equal it raises an error."""
if _all_equal(arrays):
return arrays[0]
raise ValueError("Cannot merge given arrays because their values are not the same.")


def _merge_timestamps(eopatches, reduce_timestamps):
def _merge_timestamps(
eopatches: Sequence[EOPatch], reduce_timestamps: bool
) -> Tuple[List[dt.datetime], List[np.ndarray]]:
"""Merges together timestamps from EOPatches. It also prepares a list of masks, one for each EOPatch, how
timestamps should be ordered and joined together.
"""
Expand All @@ -138,11 +145,11 @@ def _merge_timestamps(eopatches, reduce_timestamps):
return [], [np.array([], dtype=np.int32) for _ in range(len(eopatches))]

if reduce_timestamps:
all_timestamps, order_mask = np.unique(all_timestamps, return_inverse=True)
all_timestamps = all_timestamps.tolist()
unique_timestamps, order_mask = np.unique(all_timestamps, return_inverse=True) # type: ignore[call-overload]
ordered_timestamps = unique_timestamps.tolist()
else:
order_mask = np.argsort(all_timestamps)
all_timestamps = sorted(all_timestamps)
order_mask = np.argsort(all_timestamps) # type: ignore[arg-type]
ordered_timestamps = sorted(all_timestamps)

order_mask = order_mask.tolist()

Expand All @@ -152,18 +159,24 @@ def _merge_timestamps(eopatches, reduce_timestamps):
for eopatch_timestamps in timestamps_per_eopatch
]

return all_timestamps, order_mask_per_eopatch
return ordered_timestamps, order_mask_per_eopatch


def _check_if_optimize(eopatches, operation_input):
def _check_if_optimize(eopatches: Sequence[EOPatch], operation_input: OperationInputType) -> bool:
"""Checks whether optimisation of `_merge_time_dependent_raster_feature` is possible"""
if operation_input not in [None, "mean", "median", "min", "max"]:
return False
timestamp_list = [eopatch.timestamp for eopatch in eopatches]
return _all_equal(timestamp_list)


def _merge_time_dependent_raster_feature(eopatches, feature, operation, order_mask_per_eopatch, optimize):
def _merge_time_dependent_raster_feature(
eopatches: Sequence[EOPatch],
feature: FeatureSpec,
operation: Callable,
order_mask_per_eopatch: Sequence[np.ndarray],
optimize: bool,
) -> np.ndarray:
"""Merges numpy arrays of a time-dependent raster feature with a given operation and masks on how to order and join
time raster's time slices.
"""
Expand Down Expand Up @@ -200,7 +213,12 @@ def _merge_time_dependent_raster_feature(eopatches, feature, operation, order_ma
return np.array(split_arrays)


def _extract_and_join_time_dependent_feature_values(eopatches, feature, order_mask_per_eopatch, optimize):
def _extract_and_join_time_dependent_feature_values(
eopatches: Sequence[EOPatch],
feature: FeatureSpec,
order_mask_per_eopatch: Sequence[np.ndarray],
optimize: bool,
) -> Tuple[np.ndarray, np.ndarray]:
"""Collects feature arrays from EOPatches that have them and joins them together. It also joins together
corresponding order masks.
"""
Expand All @@ -225,12 +243,14 @@ def _extract_and_join_time_dependent_feature_values(eopatches, feature, order_ma
return np.concatenate(arrays, axis=0), np.concatenate(order_masks)


def _is_strictly_increasing(array):
def _is_strictly_increasing(array: np.ndarray) -> bool:
"""Checks if a 1D array of values is strictly increasing."""
return (np.diff(array) > 0).all()
return (np.diff(array) > 0).all().astype(bool)


def _merge_timeless_raster_feature(eopatches, feature, operation):
def _merge_timeless_raster_feature(
eopatches: Sequence[EOPatch], feature: FeatureSpec, operation: Callable
) -> np.ndarray:
"""Merges numpy arrays of a timeless raster feature with a given operation."""
arrays = _extract_feature_values(eopatches, feature)

Expand All @@ -246,7 +266,7 @@ def _merge_timeless_raster_feature(eopatches, feature, operation):
) from exception


def _merge_vector_feature(eopatches, feature):
def _merge_vector_feature(eopatches: Sequence[EOPatch], feature: FeatureSpec) -> GeoDataFrame:
"""Merges GeoDataFrames of a vector feature."""
dataframes = _extract_feature_values(eopatches, feature)

Expand All @@ -266,7 +286,7 @@ def _merge_vector_feature(eopatches, feature):
return merged_dataframe


def _select_meta_info_feature(eopatches, feature_name):
def _select_meta_info_feature(eopatches: Sequence[EOPatch], feature_name: str) -> Any:
"""Selects a value for a meta info feature of a merged EOPatch. By default, the value is the first one."""
values = _extract_feature_values(eopatches, (FeatureType.META_INFO, feature_name))

Expand All @@ -280,7 +300,7 @@ def _select_meta_info_feature(eopatches, feature_name):
return values[0]


def _get_common_bbox(eopatches):
def _get_common_bbox(eopatches: Sequence[EOPatch]) -> Optional[BBox]:
"""Makes sure that all EOPatches, which define a bounding box and CRS, define the same ones."""
bboxes = [eopatch.bbox for eopatch in eopatches if eopatch.bbox is not None]

Expand All @@ -292,13 +312,13 @@ def _get_common_bbox(eopatches):
raise ValueError("Cannot merge EOPatches because they are defined for different bounding boxes.")


def _extract_feature_values(eopatches, feature):
def _extract_feature_values(eopatches: Sequence[EOPatch], feature: FeatureSpec) -> List[Any]:
"""A helper function that extracts a feature values from those EOPatches where a feature exists."""
feature_type, feature_name = feature
return [eopatch[feature] for eopatch in eopatches if feature_name in eopatch[feature_type]]


def _all_equal(values):
def _all_equal(values: Union[Sequence[Any], np.ndarray]) -> bool:
"""A helper function that checks if all values in a given list are equal to each other."""
first_value = values[0]

Expand Down
11 changes: 9 additions & 2 deletions core/eolearn/core/eotask.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@
from typing import Any, Dict, Iterable, Type, TypeVar, Union

from .constants import FeatureType
from .utils.parsing import FeatureParser, parse_feature, parse_features, parse_renamed_feature, parse_renamed_features
from .utils.parsing import (
FeatureParser,
FeaturesSpecification,
parse_feature,
parse_features,
parse_renamed_feature,
parse_renamed_features,
)
from .utils.types import EllipsisType

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,7 +77,7 @@ def execute(self, *eopatches, **kwargs):

@staticmethod
def get_feature_parser(
features, allowed_feature_types: Union[Iterable[FeatureType], EllipsisType] = ...
features: FeaturesSpecification, allowed_feature_types: Union[Iterable[FeatureType], EllipsisType] = ...
) -> FeatureParser:
"""See :class:`FeatureParser<eolearn.core.utilities.FeatureParser>`."""
return FeatureParser(features, allowed_feature_types=allowed_feature_types)
Expand Down

0 comments on commit 2548e4e

Please sign in to comment.