Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Mar 20, 2024
2 parents 1a804a9 + 1960bc4 commit 68bfeb9
Showing 1 changed file with 42 additions and 13 deletions.
55 changes: 42 additions & 13 deletions simba/mixins/geometry_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,12 @@ def buffer_shape(
)

@staticmethod
def compute_pct_shape_overlap(shapes: List[Union[Polygon, LineString]], denominator: Optional[Literal['difference', 'shape_1', 'shape_2']] = 'difference') -> int:
def compute_pct_shape_overlap(
shapes: List[Union[Polygon, LineString]],
denominator: Optional[
Literal["difference", "shape_1", "shape_2"]
] = "difference",
) -> int:
"""
Compute the percentage of overlap between two shapes.
Expand Down Expand Up @@ -345,12 +350,23 @@ def compute_pct_shape_overlap(shapes: List[Union[Polygon, LineString]], denomina
val=len(shapes),
exact_accepted_length=2,
)
check_str(name=GeometryMixin.compute_pct_shape_overlap.__name__, value=denominator, options=('difference', 'shape_1', 'shape_2'))
check_str(
name=GeometryMixin.compute_pct_shape_overlap.__name__,
value=denominator,
options=("difference", "shape_1", "shape_2"),
)
if shapes[0].intersects(shapes[1]):
intersection = shapes[0].intersection(shapes[1])
if denominator == 'difference':
return np.round((intersection.area / ((shapes[0].area + shapes[1].area) - intersection.area) * 100), 2)
elif denominator == 'shape_1':
if denominator == "difference":
return np.round(
(
intersection.area
/ ((shapes[0].area + shapes[1].area) - intersection.area)
* 100
),
2,
)
elif denominator == "shape_1":
return np.round((intersection.area / shapes[0].area) * 100, 2)
else:
return np.round((intersection.area / shapes[1].area) * 100, 2)
Expand Down Expand Up @@ -1424,8 +1440,8 @@ def multiframe_compute_pct_shape_overlap(
video_name: Optional[str] = None,
verbose: Optional[bool] = False,
animal_names: Optional[Tuple[str]] = None,
denominator: Optional[Literal['difference', 'shape_1', 'shape_2']] = difference) -> List[float]:

denominator: Optional[Literal["difference", "shape_1", "shape_2"]] = difference,
) -> List[float]:
"""
Compute the percentage overlap between corresponding Polygons in two lists.
Expand Down Expand Up @@ -1467,10 +1483,22 @@ def multiframe_compute_pct_shape_overlap(
msg=f"shape_1 and shape_2 contains more than 1 dtype {input_dtypes}",
source=GeometryMixin.multiframe_compute_pct_shape_overlap.__name__,
)
check_instance(source=GeometryMixin.multiframe_compute_pct_shape_overlap.__name__, instance=shape_1[0], accepted_types=(LineString, Polygon))
data, results, timer = (np.column_stack((shape_1, shape_2)), [], SimbaTimer(start=True))
with multiprocessing.Pool(core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value) as pool:
constants = functools.partial(GeometryMixin.compute_pct_shape_overlap, denominator=denominator)
check_instance(
source=GeometryMixin.multiframe_compute_pct_shape_overlap.__name__,
instance=shape_1[0],
accepted_types=(LineString, Polygon),
)
data, results, timer = (
np.column_stack((shape_1, shape_2)),
[],
SimbaTimer(start=True),
)
with multiprocessing.Pool(
core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value
) as pool:
constants = functools.partial(
GeometryMixin.compute_pct_shape_overlap, denominator=denominator
)
for cnt, result in enumerate(pool.imap(constants, data, chunksize=1)):
if verbose:
if not video_name and not animal_names:
Expand Down Expand Up @@ -1605,7 +1633,9 @@ def multiframe_shape_distance(
with multiprocessing.Pool(
core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value
) as pool:
constants = functools.partial(GeometryMixin.shape_distance, pixels_per_mm=pixels_per_mm, unit=unit)
constants = functools.partial(
GeometryMixin.shape_distance, pixels_per_mm=pixels_per_mm, unit=unit
)
for cnt, result in enumerate(pool.imap(constants, data, chunksize=1)):
results.append(result)

Expand Down Expand Up @@ -3724,4 +3754,3 @@ def filter_low_p_bps_for_shapes(x: np.ndarray, p: np.ndarray, threshold: float):
# polygon_1 = GeometryMixin().bodyparts_to_polygon(np.array([[0, 100],[100, 100],[0, 0],[100, 0]]))
# polygon_2 = GeometryMixin().bodyparts_to_polygon(np.array([[25, 75],[75, 75],[25, 25],[75, 25]]))
# y = GeometryMixin().compute_pct_shape_overlap(shapes=[polygon_1, polygon_2], denominator='shape_2')

0 comments on commit 68bfeb9

Please sign in to comment.