Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Mar 19, 2024
2 parents 8cdcf2f + 478c6d1 commit f4116cc
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 45 deletions.
163 changes: 122 additions & 41 deletions simba/mixins/geometry_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
check_if_valid_input, check_if_valid_rgb_tuple,
check_instance, check_int,
check_iterable_length, check_str,
check_valid_array, check_valid_lst, check_that_column_exist)
check_that_column_exist, check_valid_array,
check_valid_lst)
from simba.utils.data import create_color_palette, create_color_palettes
from simba.utils.enums import Defaults, Formats, GeometryEnum, TextOptions
from simba.utils.errors import CountError, InvalidInputError
Expand Down Expand Up @@ -422,8 +423,9 @@ def crosses(shapes: List[LineString]) -> bool:
return shapes[0].crosses(shapes[1])

@staticmethod
def is_shape_covered(
    shapes: List[Union[LineString, Polygon, MultiPolygon, MultiPoint]]
) -> bool:
    """
    Check if one geometry fully covers another.

    Note: the check is directional — it tests whether ``shapes[1]`` covers
    ``shapes[0]``, not the reverse.

    :param List[Union[LineString, Polygon, MultiPolygon, MultiPoint]] shapes:
        Exactly two geometries; returns whether the second covers the first.
    :return bool: True if ``shapes[1]`` fully covers ``shapes[0]``.

    :example:
    >>> GeometryMixin.is_shape_covered(shapes=[Polygon([[1, 1], [2, 2], [2, 1], [1, 2]]), Polygon([[0, 0], [20, 20], [20, 10], [10, 20]])])
    >>> True
    """
    # Validate we received exactly two geometries of supported shapely types.
    check_valid_lst(
        data=shapes,
        source=GeometryMixin.is_shape_covered.__name__,
        valid_dtypes=(LineString, Polygon, MultiPolygon, MultiPoint),
        exact_len=2,
    )
    return shapes[1].covers(shapes[0])

@staticmethod
Expand Down Expand Up @@ -2397,10 +2404,12 @@ def geometry_histocomparison(
img_1=imgs[0], img_2=imgs[1], method=method, absolute=absolute
)

def multiframe_is_shape_covered(
    self,
    shape_1: List[Polygon],
    shape_2: List[Polygon],
    core_cnt: Optional[int] = -1,
) -> List[bool]:
    """
    For each shape in a time-series of shapes, check if another shape in the
    same time-series fully covers the first shape.

    :param List[Polygon] shape_1: Time-series of geometries to test for being covered.
    :param List[Polygon] shape_2: Time-series of candidate covering geometries; must be the same length as ``shape_1``.
    :param Optional[int] core_cnt: Number of worker processes; -1 uses all available cores.
    :return List[bool]: Per-frame result of ``shape_2[i].covers(shape_1[i])``.
    :raises InvalidInputError: If the two lists differ in length.

    :example:
    >>> shape_1 = GeometryMixin().multiframe_bodyparts_to_polygon(data=data)
    >>> shape_2 = [Polygon([[0, 0], [20, 20], [20, 10], [10, 20]]) for x in range(len(shape_1))]
    >>> GeometryMixin.multiframe_is_shape_covered(shape_1=shape_1, shape_2=shape_2, core_cnt=3)
    """
    check_valid_lst(
        data=shape_1,
        source=GeometryMixin.multiframe_is_shape_covered.__name__,
        valid_dtypes=(LineString, Polygon, MultiPolygon),
    )
    check_valid_lst(
        data=shape_2,
        source=GeometryMixin.multiframe_is_shape_covered.__name__,
        valid_dtypes=(LineString, Polygon, MultiPolygon),
    )
    # The two series are zipped pairwise below, so they must align frame-for-frame.
    if len(shape_1) != len(shape_2):
        raise InvalidInputError(
            msg=f"shape_1 ({len(shape_1)}) and shape_2 ({len(shape_2)}) are unequal length",
            source=GeometryMixin.multiframe_is_shape_covered.__name__,
        )
    check_int(
        name="CORE COUNT",
        value=core_cnt,
        min_value=-1,
        max_value=find_core_cnt()[0],
        raise_error=True,
    )
    if core_cnt == -1:
        core_cnt = find_core_cnt()[0]
    # Each work item is the [covered_candidate, covering_candidate] pair expected
    # by GeometryMixin.is_shape_covered.
    shapes = [list(x) for x in zip(shape_1, shape_2)]
    results = []
    with multiprocessing.Pool(
        core_cnt, maxtasksperchild=Defaults.LARGE_MAX_TASK_PER_CHILD.value
    ) as pool:
        # imap preserves input order, so results[i] corresponds to frame i.
        for cnt, mp_return in enumerate(
            pool.imap(GeometryMixin.is_shape_covered, shapes, chunksize=1)
        ):
            results.append(mp_return)
    pool.join()
    pool.terminate()
    # NOTE(review): trailing line of the original was cut off by the diff hunk;
    # returning the accumulated results matches the declared -> List[bool] —
    # confirm against the repository.
    return results
Expand Down Expand Up @@ -3562,8 +3601,9 @@ def linear_frechet_distance(
return ca[n_p - 1, n_q - 1]

@staticmethod
def simba_roi_to_geometries(
    rectangles_df: pd.DataFrame, circles_df: pd.DataFrame, polygons_df: pd.DataFrame
) -> dict:
    """
    Convert SimBA dataframes holding ROI geometries to a nested dictionary
    holding Shapely polygons.

    The result maps ``{video_name: {shape_name: shapely geometry}}``. Rectangles
    and polygons are built from their vertex ``Tags``; circles are built as a
    buffered point around the ``Center tag`` with the stored ``radius``.

    :param pd.DataFrame rectangles_df: SimBA rectangle ROI definitions ('Video', 'Name', 'Tags' columns required).
    :param pd.DataFrame circles_df: SimBA circle ROI definitions ('Video', 'Name', 'Tags', 'radius' columns required).
    :param pd.DataFrame polygons_df: SimBA polygon ROI definitions ('Video', 'Name', 'Tags' columns required).
    :return dict: Nested dict of Shapely geometries keyed by video name, then shape name.

    :example:
    >>> #config = ConfigReader(config_path='/Users/simon/Desktop/envs/troubleshooting/two_black_animals_14bp/project_folder/project_config.ini')
    >>> #config.read_roi_data()
    >>> #GeometryMixin.simba_roi_to_geometries(rectangles_df=config.rectangles_df, circles_df=config.circles_df, polygons_df=config.polygon_df)
    """

    check_instance(
        source=GeometryMixin.simba_roi_to_geometries.__name__,
        instance=rectangles_df,
        accepted_types=(pd.DataFrame,),
    )
    check_instance(
        source=GeometryMixin.simba_roi_to_geometries.__name__,
        instance=circles_df,
        accepted_types=(pd.DataFrame,),
    )
    check_instance(
        source=GeometryMixin.simba_roi_to_geometries.__name__,
        instance=polygons_df,
        accepted_types=(pd.DataFrame,),
    )
    for i in [rectangles_df, circles_df, polygons_df]:
        check_that_column_exist(
            df=i, column_name=["Video", "Name", "Tags"], file_name=""
        )
    results = {}
    # Rectangles and polygons share the same construction: the 'Tags' cell is a
    # dict of vertex-name -> (x, y); its values form the polygon ring.
    for df in (rectangles_df, polygons_df):
        for video_name in df["Video"].unique():
            if video_name not in results.keys():
                results[video_name] = {}
            video_shapes = df[["Tags", "Name"]][df["Video"] == video_name]
            for shape_name in video_shapes["Name"].unique():
                shape_data = video_shapes[
                    video_shapes["Name"] == shape_name
                ].reset_index(drop=True)
                tags, name = (
                    list(shape_data["Tags"].values[0].values()),
                    shape_data["Name"].values[0],
                )
                results[video_name][name] = Polygon(tags)
    for video_name in circles_df["Video"].unique():
        if video_name not in results.keys():
            results[video_name] = {}
        # BUGFIX: 'radius' must be part of the column selection — the original
        # selected only ['Tags', 'Name'] and then read shape_data['radius'],
        # which raises KeyError.
        video_shapes = circles_df[["Tags", "Name", "radius"]][
            circles_df["Video"] == video_name
        ]
        for shape_name in video_shapes["Name"].unique():
            shape_data = video_shapes[
                video_shapes["Name"] == shape_name
            ].reset_index(drop=True)
            # BUGFIX: keep the Tags dict intact — the original converted it to a
            # list of values and then indexed it with the string key
            # 'Center tag', which raises TypeError.
            tags, name, radius = (
                shape_data["Tags"].values[0],
                shape_data["Name"].values[0],
                shape_data["radius"].values[0],
            )
            results[video_name][name] = Point(tags["Center tag"]).buffer(
                distance=radius
            )
    return results

@staticmethod
Expand Down
14 changes: 10 additions & 4 deletions simba/roi_tools/ROI_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

from simba.mixins.config_reader import ConfigReader
from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
from simba.utils.checks import check_file_exist_and_readable
from simba.utils.enums import ConfigKey, Dtypes
from simba.utils.errors import (BodypartColumnNotFoundError, MissingColumnsError, NoFilesFoundError)
from simba.utils.errors import (BodypartColumnNotFoundError,
MissingColumnsError, NoFilesFoundError)
from simba.utils.printing import stdout_success
from simba.utils.read_write import get_fn_ext, read_config_entry, read_df
from simba.utils.warnings import NoDataFoundWarning
from simba.utils.checks import check_file_exist_and_readable


class ROIAnalyzer(ConfigReader, FeatureExtractionMixin):
Expand Down Expand Up @@ -60,14 +61,19 @@ def __init__(
self.input_folder = os.path.join(self.project_path, "csv", data_path)
self.files_found = glob.glob(self.input_folder + f"/*.{self.file_type}")
if len(self.files_found) == 0:
raise NoFilesFoundError(msg=f"No files in format {self.file_type} found in {self.input_folder}", source=self.__class__.__name__)
raise NoFilesFoundError(
msg=f"No files in format {self.file_type} found in {self.input_folder}",
source=self.__class__.__name__,
)
if file_path is not None:
check_file_exist_and_readable(file_path=file_path)
self.files_found = [file_path]
if self.settings is None:
self.roi_config = dict(self.config.items(ConfigKey.ROI_SETTINGS.value))
if "animal_1_bp" not in self.roi_config.keys():
raise BodypartColumnNotFoundError(msg="Could not find animal_1_bp settings in the project config. Please analyze ROI data FIRST.")
raise BodypartColumnNotFoundError(
msg="Could not find animal_1_bp settings in the project config. Please analyze ROI data FIRST."
)
self.settings = {}
self.settings["threshold"] = read_config_entry(
self.config,
Expand Down

0 comments on commit f4116cc

Please sign in to comment.