Skip to content

Commit

Permalink
Add missing type annotations (#546)
Browse files Browse the repository at this point in the history
add type annotations to snow_mask
add type annotations to geometry_io
  • Loading branch information
jgersak committed Jan 20, 2023
1 parent dfada66 commit c75c93a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions io/eolearn/io/geometry_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
self.clip = clip

@abc.abstractmethod
def _load_vector_data(self, bbox: Optional[BBox]):
def _load_vector_data(self, bbox: Optional[BBox]) -> gpd.GeoDataFrame:
"""Loads vector data given a bounding box"""

def _reproject_and_clip(self, vectors: gpd.GeoDataFrame, bbox: Optional[BBox]) -> gpd.GeoDataFrame:
Expand Down Expand Up @@ -122,7 +122,7 @@ def __init__(

self.fiona_kwargs = kwargs
self._aws_session = None
self._dataset_crs = None
self._dataset_crs: Optional[CRS] = None

super().__init__(feature=feature, reproject=reproject, clip=clip, config=config)

Expand Down Expand Up @@ -163,7 +163,7 @@ def dataset_crs(self) -> Optional[CRS]:

return self._dataset_crs

def _read_crs(self):
def _read_crs(self) -> None:
"""Reads information about CRS from a dataset"""
with fiona.open(self.full_path, **self.fiona_kwargs) as features:
self._dataset_crs = CRS(features.crs)
Expand Down
8 changes: 4 additions & 4 deletions mask/eolearn/mask/snow_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self.undefined_value = undefined_value
self.mask_feature = (FeatureType.MASK, mask_name)

def _apply_dilation(self, snow_masks):
def _apply_dilation(self, snow_masks: np.ndarray) -> np.ndarray:
"""Apply binary dilation for each mask in the series"""
if self.dilation_size:
snow_masks = np.array([binary_dilation(mask, disk(self.dilation_size)) for mask in snow_masks])
Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(
red_params: Tuple[float, float, float, float, float] = (12, 0.3, 0.1, 0.2, 0.040),
ndsi_params: Tuple[float, float, float] = (0.4, 0.15, 0.001),
b10_index: Optional[int] = None,
**kwargs,
**kwargs: Any,
):
"""
:param data_feature: EOPatch feature represented by a tuple in the form of `(FeatureType, 'feature_name')`
Expand Down Expand Up @@ -187,7 +187,7 @@ def __init__(
self.b10_index = b10_index
self._validate_params()

def _validate_params(self):
def _validate_params(self) -> None:
"""Check length of parameters defining threshold values"""
for params, n_params in [(self.dem_params, 2), (self.red_params, 5), (self.ndsi_params, 3)]:
if not isinstance(params, (tuple, list)) or len(params) != n_params:
Expand Down Expand Up @@ -229,7 +229,7 @@ def _adjust_cloud_mask(
).astype(np.uint8)

def _apply_first_pass(
self, bands: np.ndarray, ndsi: np.ndarray, clm: np.ndarray, dem, clm_temp: np.ndarray
self, bands: np.ndarray, ndsi: np.ndarray, clm: np.ndarray, dem: np.ndarray, clm_temp: np.ndarray
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
"""Apply first pass of snow detection"""
snow_mask_pass1 = np.where(
Expand Down

0 comments on commit c75c93a

Please sign in to comment.