Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Apr 11, 2024
2 parents c86dfa4 + 11fcc16 commit e1f6963
Show file tree
Hide file tree
Showing 14 changed files with 1,241 additions and 535 deletions.
9 changes: 7 additions & 2 deletions simba/mixins/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ def get_all_clf_names(self) -> List[str]:
)
return model_names

def insert_column_headers_for_outlier_correction(self, data_df: pd.DataFrame, new_headers: List[str], filepath: str) -> pd.DataFrame:
def insert_column_headers_for_outlier_correction(
self, data_df: pd.DataFrame, new_headers: List[str], filepath: str
) -> pd.DataFrame:
"""
Helper to insert new column headers onto a dataframe.
Expand All @@ -350,7 +352,10 @@ def insert_column_headers_for_outlier_correction(self, data_df: pd.DataFrame, ne
difference = int(len(data_df.columns) - len(new_headers))
bp_missing = int(abs(difference) / 3)
if difference < 0:
raise DataHeaderError(msg=f"SIMBA ERROR: SimBA expects {len(new_headers)} columns of data inside the files within project_folder/csv/input_csv directory. However, within file {filepath} file, SimBA found {len(data_df.columns)} columns. Thus, there is {abs(difference)} missing data columns in the imported data, which may represent {int(bp_missing)} bodyparts if each body-part has an x, y and p value. Either revise the SimBA project pose-configuration with {bp_missing} less body-part, or include {bp_missing} more body-part in the imported data", source=self.__class__.__name__,)
raise DataHeaderError(
msg=f"SIMBA ERROR: SimBA expects {len(new_headers)} columns of data inside the files within project_folder/csv/input_csv directory. However, within file {filepath} file, SimBA found {len(data_df.columns)} columns. Thus, there is {abs(difference)} missing data columns in the imported data, which may represent {int(bp_missing)} bodyparts if each body-part has an x, y and p value. Either revise the SimBA project pose-configuration with {bp_missing} less body-part, or include {bp_missing} more body-part in the imported data",
source=self.__class__.__name__,
)
else:
raise DataHeaderError(
msg=f"SIMBA ERROR: SimBA expects {len(new_headers)} columns of data inside the files within project_folder/csv/input_csv directory. However, within file {filepath} file, SimBA found {len(data_df.columns)} columns. Thus, there is {abs(difference)} more data columns in the imported data than anticipated, which may represent {int(bp_missing)} bodyparts if each body-part has an x, y and p value. Either revise the SimBA project pose-configuration with {bp_missing} more body-part, or include {bp_missing} less body-part in the imported data",
Expand Down
23 changes: 15 additions & 8 deletions simba/mixins/feature_extraction_supplement_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
check_if_filepath_list_is_empty,
check_instance, check_str,
check_that_column_exist, check_valid_array,
check_valid_lst, check_valid_dataframe)
check_valid_dataframe, check_valid_lst)
from simba.utils.data import detect_bouts
from simba.utils.errors import CountError, InvalidInputError
from simba.utils.printing import SimbaTimer, stdout_success
Expand Down Expand Up @@ -626,7 +626,9 @@ def find_path_loops(data: np.ndarray) -> Dict[Tuple[int], List[int]]:
return {k: v for k, v in seen_dedup.items() if len(v) > 1}

@staticmethod
def sequential_lag_analysis(data: pd.DataFrame, criterion: str, target: str, time_window: float, fps: float):
def sequential_lag_analysis(
data: pd.DataFrame, criterion: str, target: str, time_window: float, fps: float
):
"""
Perform sequential lag analysis to determine the temporal relationship between two events.
Expand Down Expand Up @@ -668,13 +670,19 @@ def sequential_lag_analysis(data: pd.DataFrame, criterion: str, target: str, tim
value=time_window,
min_value=10e-6,
)
check_valid_dataframe(df=data, source=f'{FeatureExtractionSupplemental.sequential_lag_analysis.__name__} data',
valid_dtypes=(np.float32, np.float64, np.int64, np.int32, float, int),
required_fields=[criterion, target])
check_valid_dataframe(
df=data,
source=f"{FeatureExtractionSupplemental.sequential_lag_analysis.__name__} data",
valid_dtypes=(np.float32, np.float64, np.int64, np.int32, float, int),
required_fields=[criterion, target],
)

bouts = detect_bouts(data_df=data, target_lst=[criterion, target], fps=fps)
if len(bouts) == 0:
raise CountError(msg=f"No events of behaviors {criterion} and {target} detected in data.", source=FeatureExtractionSupplemental.sequential_lag_analysis)
raise CountError(
msg=f"No events of behaviors {criterion} and {target} detected in data.",
source=FeatureExtractionSupplemental.sequential_lag_analysis,
)
criterion_starts = bouts["Start_frame"][bouts["Event"] == criterion].values
target_starts = bouts["Start_frame"][bouts["Event"] == target].values
preceding_cnt, proceeding_cnt = 0, 0
Expand Down Expand Up @@ -760,8 +768,7 @@ def distance_and_velocity(
return movement, np.mean(v)



#df = read_df(file_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/targets_inserted/Together_1.csv', file_type='csv')
# df = read_df(file_path='/Users/simon/Desktop/envs/simba/troubleshooting/two_black_animals_14bp/project_folder/csv/targets_inserted/Together_1.csv', file_type='csv')
#
# df = pd.DataFrame(np.random.randint(0, 2, (100, 2)), columns=['Attack', 'Sniffing'])
# FeatureExtractionSupplemental.sequential_lag_analysis(data=df, criterion='Attack', target='Sniffing', fps=5, time_window=2.0)
11 changes: 9 additions & 2 deletions simba/mixins/geometry_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2905,7 +2905,12 @@ def adjust_geometries(
:example:
>>> shapes = GeometryMixin().adjust_geometries(geometries=shapes, shift=(0, 333))
"""
check_valid_lst(data=geometries, source=f"{GeometryMixin().adjust_geometries.__name__} geometries", valid_dtypes=(Polygon,), min_len=1)
check_valid_lst(
data=geometries,
source=f"{GeometryMixin().adjust_geometries.__name__} geometries",
valid_dtypes=(Polygon,),
min_len=1,
)
results = []
for shape_cnt, shape in enumerate(geometries):
results.append(
Expand Down Expand Up @@ -3410,7 +3415,9 @@ def cumsum_animal_geometries_grid(
return np.cumsum(img_arr, axis=0) / fps

@staticmethod
def hausdorff_distance(geometries: List[List[Union[Polygon, LineString]]]) -> np.ndarray:
def hausdorff_distance(
geometries: List[List[Union[Polygon, LineString]]]
) -> np.ndarray:
"""
The Hausdorff distance measure of the similarity between time-series sequential geometries. It is defined as the maximum of the distances
from each point in one set to the nearest point in the other set.
Expand Down
162 changes: 103 additions & 59 deletions simba/mixins/plotting_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
import os
import random
import shutil
import plotly.graph_objs as go
import plotly.io as pio
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union

import cv2
Expand All @@ -17,10 +14,13 @@
import numpy as np
import pandas as pd
import PIL
import plotly.graph_objs as go
import plotly.io as pio
import seaborn as sns
from matplotlib import cm
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from numba import njit
from PIL import Image

try:
from typing import Literal
Expand Down Expand Up @@ -1827,13 +1827,18 @@ def continuous_scatter(
fig, ax = plt.subplots()
if bg_clr is not None:
if bg_clr not in get_named_colors():
raise InvalidInputError(msg=f'bg_clr {bg_clr} is not a valid named color. Options: {get_named_colors()}', source=PlottingMixin.continuous_scatter.__name__)
raise InvalidInputError(
msg=f"bg_clr {bg_clr} is not a valid named color. Options: {get_named_colors()}",
source=PlottingMixin.continuous_scatter.__name__,
)
fig.set_facecolor(bg_clr)
if not show_box:
plt.axis("off")
plt.xlabel(columns[0])
plt.ylabel(columns[1])
plot = ax.scatter(data[columns[0]], data[columns[1]], c=data[columns[2]], s=size, cmap=palette)
plot = ax.scatter(
data[columns[0]], data[columns[1]], c=data[columns[2]], s=size, cmap=palette
)
cbar = fig.colorbar(plot)
cbar.set_label(columns[2], loc="center")
if title is not None:
Expand Down Expand Up @@ -2053,46 +2058,69 @@ def line_plot(
return plot

@staticmethod
def make_line_plot(data: List[np.ndarray],
colors: List[str],
show_box: Optional[bool] = True,
width: Optional[int] = 640,
height: Optional[int] = 480,
line_width: Optional[int] = 6,
font_size: Optional[int] = 8,
bg_clr: Optional[str] = None,
x_lbl_divisor: Optional[float] = None,
title: Optional[str] = None,
y_lbl: Optional[str] = None,
x_lbl: Optional[str] = None,
y_max: Optional[int] = -1,
line_opacity: Optional[int] = 0.0,
save_path: Optional[Union[str, os.PathLike]] = None):

check_valid_lst(data=data, source=PlottingMixin.make_line_plot.__name__, valid_dtypes=(np.ndarray, list,))
check_valid_lst(data=colors, source=PlottingMixin.make_line_plot.__name__, valid_dtypes=(str,), exact_len=len(data))
def make_line_plot(
data: List[np.ndarray],
colors: List[str],
show_box: Optional[bool] = True,
width: Optional[int] = 640,
height: Optional[int] = 480,
line_width: Optional[int] = 6,
font_size: Optional[int] = 8,
bg_clr: Optional[str] = None,
x_lbl_divisor: Optional[float] = None,
title: Optional[str] = None,
y_lbl: Optional[str] = None,
x_lbl: Optional[str] = None,
y_max: Optional[int] = -1,
line_opacity: Optional[int] = 0.0,
save_path: Optional[Union[str, os.PathLike]] = None,
):

check_valid_lst(
data=data,
source=PlottingMixin.make_line_plot.__name__,
valid_dtypes=(
np.ndarray,
list,
),
)
check_valid_lst(
data=colors,
source=PlottingMixin.make_line_plot.__name__,
valid_dtypes=(str,),
exact_len=len(data),
)
clr_dict = get_color_dict()
matplotlib.font_manager._get_font.cache_clear()
plt.close("all")
fig, ax = plt.subplots()
if bg_clr is not None: fig.set_facecolor(bg_clr)
if not show_box: plt.axis("off")
if bg_clr is not None:
fig.set_facecolor(bg_clr)
if not show_box:
plt.axis("off")
for i in range(len(data)):
line_clr = clr_dict[colors[i]][::-1]
line_clr = tuple(x / 255 for x in line_clr)
flat_data = data[i].flatten()
plt.plot(flat_data, color=line_clr, linewidth=line_width, alpha=line_opacity)
plt.plot(
flat_data, color=line_clr, linewidth=line_width, alpha=line_opacity
)
max_x = max([len(x) for x in data])
if y_max == -1: y_max = max([np.max(x) for x in data])
if y_max == -1:
y_max = max([np.max(x) for x in data])
y_ticks_locs = y_lbls = np.round(np.linspace(0, y_max, 10), 2)
x_ticks_locs = x_lbls = np.linspace(0, max_x, 5)
if x_lbl_divisor is not None: x_lbls = np.round((x_lbls / x_lbl_divisor), 1)
if y_lbl is not None: plt.ylabel(y_lbl)
if x_lbl is not None: plt.xlabel(x_lbl)
if x_lbl_divisor is not None:
x_lbls = np.round((x_lbls / x_lbl_divisor), 1)
if y_lbl is not None:
plt.ylabel(y_lbl)
if x_lbl is not None:
plt.xlabel(x_lbl)
plt.xticks(x_ticks_locs, x_lbls, rotation="horizontal", fontsize=font_size)
plt.yticks(y_ticks_locs, y_lbls, fontsize=font_size)
plt.ylim(0, y_max)
if title is not None: plt.suptitle(title, x=0.5, y=0.92, fontsize=font_size + 4)
if title is not None:
plt.suptitle(title, x=0.5, y=0.92, fontsize=font_size + 4)
buffer_ = io.BytesIO()
plt.savefig(buffer_, format="png")
buffer_.seek(0)
Expand All @@ -2103,28 +2131,29 @@ def make_line_plot(data: List[np.ndarray],
img = cv2.resize(img, (width, height))
if save_path is not None:
cv2.imwrite(save_path, img)
stdout_success(msg=f'Line plot saved at {save_path}')
stdout_success(msg=f"Line plot saved at {save_path}")
else:
return img

@staticmethod
def make_line_plot_plotly(data: List[np.ndarray],
colors: List[str],
show_box: Optional[bool] = True,
show_grid: Optional[bool] = False,
width: Optional[int] = 640,
height: Optional[int] = 480,
line_width: Optional[int] = 6,
font_size: Optional[int] = 8,
bg_clr: Optional[str] = 'white',
x_lbl_divisor: Optional[float] = None,
title: Optional[str] = None,
y_lbl: Optional[str] = None,
x_lbl: Optional[str] = None,
y_max: Optional[int] = -1,
line_opacity: Optional[int] = 0.5,
save_path: Optional[Union[str, os.PathLike]] = None):

def make_line_plot_plotly(
data: List[np.ndarray],
colors: List[str],
show_box: Optional[bool] = True,
show_grid: Optional[bool] = False,
width: Optional[int] = 640,
height: Optional[int] = 480,
line_width: Optional[int] = 6,
font_size: Optional[int] = 8,
bg_clr: Optional[str] = "white",
x_lbl_divisor: Optional[float] = None,
title: Optional[str] = None,
y_lbl: Optional[str] = None,
x_lbl: Optional[str] = None,
y_max: Optional[int] = -1,
line_opacity: Optional[int] = 0.5,
save_path: Optional[Union[str, os.PathLike]] = None,
):
"""
Create a line plot using Plotly.
Expand Down Expand Up @@ -2167,20 +2196,35 @@ def tick_formatter(x):

fig = go.Figure()
clr_dict = get_color_dict()
if y_max == -1: y_max = max([np.max(i) for i in data])
if y_max == -1:
y_max = max([np.max(i) for i in data])
for i in range(len(data)):
line_clr = clr_dict[colors[i]]
line_clr = f'rgba({line_clr[0]}, {line_clr[1]}, {line_clr[2]}, {line_opacity})'
fig.add_trace(go.Scatter(y=data[i].flatten(), mode='lines', line=dict(color=line_clr, width=line_width)))
line_clr = (
f"rgba({line_clr[0]}, {line_clr[1]}, {line_clr[2]}, {line_opacity})"
)
fig.add_trace(
go.Scatter(
y=data[i].flatten(),
mode="lines",
line=dict(color=line_clr, width=line_width),
)
)

if not show_box:
fig.update_layout(width=width, height=height, title=title, xaxis_visible=False, yaxis_visible=False,
showlegend=False)
fig.update_layout(
width=width,
height=height,
title=title,
xaxis_visible=False,
yaxis_visible=False,
showlegend=False,
)
else:
if fig['layout']['xaxis']['tickvals'] is None:
if fig["layout"]["xaxis"]["tickvals"] is None:
tickvals = [i for i in range(data[0].shape[0])]
else:
tickvals = fig['layout']['xaxis']['tickvals']
tickvals = fig["layout"]["xaxis"]["tickvals"]
if x_lbl_divisor is not None:
ticktext = [tick_formatter(x) for x in tickvals]
else:
Expand All @@ -2193,7 +2237,7 @@ def tick_formatter(x):
title=x_lbl,
tickvals=tickvals,
ticktext=ticktext,
tickmode='auto',
tickmode="auto",
tick0=0,
dtick=10,
tickfont=dict(size=font_size),
Expand All @@ -2205,14 +2249,14 @@ def tick_formatter(x):
range=[0, y_max],
showgrid=show_grid,
),
showlegend=False
showlegend=False,
)

if bg_clr is not None:
fig.update_layout(plot_bgcolor=bg_clr)
if save_path is not None:
pio.write_image(fig, save_path)
stdout_success(msg=f'Line plot saved at {save_path}')
stdout_success(msg=f"Line plot saved at {save_path}")
else:
img_bytes = fig.to_image(format="png")
img = PIL.Image.open(io.BytesIO(img_bytes))
Expand Down

0 comments on commit e1f6963

Please sign in to comment.