diff --git a/io/eolearn/io/sentinelhub_process.py b/io/eolearn/io/sentinelhub_process.py index 185928443..97c08847e 100644 --- a/io/eolearn/io/sentinelhub_process.py +++ b/io/eolearn/io/sentinelhub_process.py @@ -38,7 +38,7 @@ from sentinelhub.type_utils import RawTimeIntervalType from eolearn.core import EOPatch, EOTask, FeatureType, FeatureTypeSet -from eolearn.core.utils.parsing import FeatureSpec, FeaturesSpecification +from eolearn.core.utils.parsing import FeatureRenameSpec, FeatureSpec, FeaturesSpecification from eolearn.core.utils.types import Literal LOGGER = logging.getLogger(__name__) @@ -97,26 +97,26 @@ def execute( eopatch = eopatch or EOPatch() - self._check_and_set_eopatch_bbox(bbox, eopatch) + eopatch.bbox = self._extract_bbox(bbox, eopatch) size_x, size_y = self._get_size(eopatch) if time_interval: time_interval = parse_time_interval(time_interval) timestamp = self._get_timestamp(time_interval, eopatch.bbox) + timestamp = [time_point.replace(tzinfo=None) for time_point in timestamp] elif self.data_collection.is_timeless: - timestamp = None + timestamp = None # should this be [] to match next branch in case of a fresh eopatch? else: timestamp = eopatch.timestamp if timestamp is not None: - eop_timestamp = [time_point.replace(tzinfo=None) for time_point in timestamp] - if eopatch.timestamp: - self.check_timestamp_difference(eop_timestamp, eopatch.timestamp) - else: - eopatch.timestamp = eop_timestamp + if not eopatch.timestamp: + eopatch.timestamp = timestamp + elif timestamp != eopatch.timestamp: + raise ValueError("Trying to write data to an existing EOPatch with a different timestamp.") - requests = self._build_requests(eopatch.bbox, size_x, size_y, timestamp, time_interval, geometry) - requests = [request.download_list[0] for request in requests] + sh_requests = self._build_requests(eopatch.bbox, size_x, size_y, timestamp, time_interval, geometry) + requests = [request.download_list[0] for request in sh_requests] LOGGER.debug("Downloading %d requests of type %s", len(requests), str(self.data_collection)) session = None if self.session_loader is None else self.session_loader() @@ -158,45 +158,33 @@ def _add_meta_info(self, eopatch): eopatch.meta_info["time_difference"] = self.time_difference.total_seconds() @staticmethod - def _check_and_set_eopatch_bbox(bbox, eopatch): + def _extract_bbox(bbox: Optional[BBox], eopatch: EOPatch) -> BBox: if eopatch.bbox is None: if bbox is None: raise ValueError("Either the eopatch or the task must provide valid bbox.") - eopatch.bbox = bbox - return + return bbox if bbox is None or eopatch.bbox == bbox: - return + return eopatch.bbox raise ValueError("Either the eopatch or the task must provide bbox, or they must be the same.") - @staticmethod - def check_timestamp_difference(timestamp1, timestamp2): - """Raises an error if the two timestamps are not the same""" - error_msg = "Trying to write data to an existing EOPatch with a different timestamp." - if len(timestamp1) != len(timestamp2): - raise ValueError(error_msg) - - for ts1, ts2 in zip(timestamp1, timestamp2): - if ts1 != ts2: - raise ValueError(error_msg) - def _extract_data(self, eopatch, images, shape): """Extract data from the received images and assign them to eopatch features""" raise NotImplementedError("The _extract_data method should be implemented by the subclass.") def _build_requests( self, - bbox: BBox, + bbox: Optional[BBox], size_x: int, size_y: int, timestamp: Optional[List[dt.datetime]], - time_interval: Optional[Tuple[dt.datetime, dt.datetime]], - geometry: Geometry, + time_interval: Optional[RawTimeIntervalType], + geometry: Optional[Geometry], ) -> List[SentinelHubRequest]: """Build requests""" raise NotImplementedError("The _build_requests method should be implemented by the subclass.") - def _get_timestamp(self, time_interval: Optional[Tuple[dt.datetime, dt.datetime]], bbox: BBox) -> List[dt.datetime]: + def _get_timestamp(self, time_interval: Optional[RawTimeIntervalType], bbox: BBox) -> List[dt.datetime]: """Get the timestamp array needed as a parameter for downloading the images""" raise NotImplementedError("The _get_timestamp method should be implemented by the subclass.") @@ -207,9 +195,9 @@ class SentinelHubEvalscriptTask(SentinelHubInputBaseTask): # pylint: disable=too-many-arguments def __init__( self, - features: Optional[FeaturesSpecification] = None, - evalscript: Optional[str] = None, - data_collection: Optional[DataCollection] = None, + features: FeaturesSpecification, + evalscript: str, + data_collection: DataCollection, size: Optional[Tuple[int, int]] = None, resolution: Optional[Union[float, Tuple[float, float]]] = None, maxcc: Optional[float] = None, @@ -260,9 +248,6 @@ def __init__( self.features = self._parse_and_validate_features(features) self.responses = self._create_response_objects() - - if not evalscript: - raise ValueError("evalscript parameter must not be missing/empty") self.evalscript = evalscript if maxcc and isinstance(maxcc, (int, float)) and (maxcc < 0 or maxcc > 1): @@ -274,10 +259,7 @@ def __init__( self.mosaicking_order = None if mosaicking_order is None else MosaickingOrder(mosaicking_order) self.aux_request_args = aux_request_args - def _parse_and_validate_features(self, features): - if not features: - raise ValueError("features must be defined") - + def _parse_and_validate_features(self, features: FeaturesSpecification) -> List[FeatureRenameSpec]: allowed_features = FeatureTypeSet.RASTER_TYPES.union({FeatureType.META_INFO}) _features = self.parse_renamed_features(features, allowed_feature_types=allowed_features) @@ -301,7 +283,7 @@ def _create_response_objects(self): return responses - def _get_timestamp(self, time_interval, bbox): + def _get_timestamp(self, time_interval: Optional[RawTimeIntervalType], bbox: BBox) -> List[dt.datetime]: """Get the timestamp array needed as a parameter for downloading the images""" if any(feat_type.is_timeless() for feat_type, _, _ in self.features if feat_type.is_raster()): return [] @@ -318,12 +300,12 @@ def _get_timestamp(self, time_interval, bbox): def _build_requests( self, - bbox: BBox, + bbox: Optional[BBox], size_x: int, size_y: int, - timestamp: List[dt.datetime], - time_interval: Optional[Tuple[dt.datetime, dt.datetime]], - geometry: Geometry, + timestamp: Optional[List[dt.datetime]], + time_interval: Optional[RawTimeIntervalType], + geometry: Optional[Geometry], ): """Defines request timestamps and builds requests. In case `timestamp` is either `None` or an empty list it still has to create at least one request in order to obtain back number of bands of responses.""" @@ -399,7 +381,7 @@ class SentinelHubInputTask(SentinelHubInputBaseTask): # pylint: disable=too-many-locals def __init__( self, - data_collection: Optional[DataCollection] = None, + data_collection: DataCollection, size: Optional[Tuple[int, int]] = None, resolution: Optional[Union[float, Tuple[float, float]]] = None, bands_feature: Optional[Tuple[FeatureType, str]] = None, @@ -411,7 +393,7 @@ def __init__( cache_folder: Optional[str] = None, config: Optional[SHConfig] = None, max_threads: Optional[int] = None, - bands_dtype: Union[np.dtype, type] = None, + bands_dtype: Union[None, np.dtype, type] = None, single_scene: bool = False, mosaicking_order: Optional[Union[str, MosaickingOrder]] = None, upsampling: Optional[ResamplingType] = None, @@ -548,7 +530,7 @@ def generate_evalscript(self) -> str: return evalscript - def _get_timestamp(self, time_interval, bbox): + def _get_timestamp(self, time_interval: Optional[RawTimeIntervalType], bbox: BBox) -> List[dt.datetime]: """Get the timestamp array needed as a parameter for downloading the images""" if self.single_scene: return [time_interval[0]] @@ -565,24 +547,24 @@ def _get_timestamp(self, time_interval, bbox): def _build_requests( self, - bbox: BBox, + bbox: Optional[BBox], size_x: int, size_y: int, timestamp: Optional[List[dt.datetime]], - time_interval: Optional[Tuple[dt.datetime, dt.datetime]], - geometry: Geometry, + time_interval: Optional[RawTimeIntervalType], + geometry: Optional[Geometry], ) -> List[SentinelHubRequest]: """Build requests""" if timestamp is None: - dates = [None] + intervals: List[Optional[RawTimeIntervalType]] = [None] elif self.single_scene: - dates = [parse_time_interval(time_interval)] + intervals = [parse_time_interval(time_interval)] else: - dates = [(date - self.time_difference, date + self.time_difference) for date in timestamp] + intervals = [(date - self.time_difference, date + self.time_difference) for date in timestamp] - return [self._create_sh_request(date1, date2, bbox, size_x, size_y, geometry) for date1, date2 in dates] + return [self._create_sh_request(time_interval, bbox, size_x, size_y, geometry) for time_interval in intervals] - def _create_sh_request(self, date_from, date_to, bbox, size_x, size_y, geometry): + def _create_sh_request(self, time_interval, bbox, size_x, size_y, geometry): """Create an instance of SentinelHubRequest""" responses = [ SentinelHubRequest.output_response(band.name, MimeType.TIFF) @@ -594,7 +576,7 @@ def _create_sh_request(self, date_from, date_to, bbox, size_x, size_y, geometry) input_data=[ SentinelHubRequest.input_data( data_collection=self.data_collection, - time_interval=(date_from, date_to), + time_interval=time_interval, mosaicking_order=self.mosaicking_order, maxcc=self.maxcc, upsampling=self.upsampling, @@ -750,7 +732,7 @@ def get_available_timestamps( bbox: BBox, data_collection: DataCollection, *, - time_interval: Optional[Tuple[dt.datetime, dt.datetime]] = None, + time_interval: Optional[RawTimeIntervalType] = None, time_difference: dt.timedelta = dt.timedelta(seconds=-1), # noqa: B008 timestamp_filter: Callable[[List[dt.datetime], dt.timedelta], List[dt.datetime]] = filter_times, maxcc: Optional[float] = None,