Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
# Conflicts:
#	simba/mixins/statistics_mixin.py
#	simba/unsupervised/dbcv_calculator.py
  • Loading branch information
sronilsson committed Mar 26, 2024
2 parents 06e3d47 + c4c5a26 commit 58a3fb1
Show file tree
Hide file tree
Showing 15 changed files with 807 additions and 274 deletions.
7 changes: 5 additions & 2 deletions simba/mixins/geometry_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,10 @@ def geometry_video(
accepted_types=(tuple,),
)
if len(size) != 2:
raise InvalidInputError(msg=f"Size has to be 2 values, got {len(size)}", source=GeometryMixin.geometry_video.__name__,)
raise InvalidInputError(
msg=f"Size has to be 2 values, got {len(size)}",
source=GeometryMixin.geometry_video.__name__,
)
for i in size:
check_instance(
source=GeometryMixin.geometry_video.__name__,
Expand Down Expand Up @@ -1056,7 +1059,7 @@ def geometry_video(

video_writer.write(frm_img.astype(np.uint8))
if verbose:
print(f'Geometry frame complete ({frm_cnt+1} / {len(shapes)})')
print(f"Geometry frame complete ({frm_cnt+1} / {len(shapes)})")
video_writer.release()
timer.stop_timer()
stdout_success(
Expand Down
214 changes: 148 additions & 66 deletions simba/mixins/plotting_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@
from typing_extensions import Literal

import simba
from simba.utils.checks import check_file_exist_and_readable, check_instance, check_that_column_exist, check_valid_array, check_if_dir_exists, check_str
from simba.utils.checks import (check_file_exist_and_readable,
check_if_dir_exists, check_instance, check_str,
check_that_column_exist, check_valid_array)
from simba.utils.enums import Formats, Options, TextOptions
from simba.utils.lookups import get_color_dict, get_named_colors, get_categorical_palettes
from simba.utils.errors import InvalidInputError
from simba.utils.lookups import (get_categorical_palettes, get_color_dict,
get_named_colors)
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import get_fn_ext, read_frm_of_video
from simba.utils.errors import InvalidInputError


class PlottingMixin(object):
Expand Down Expand Up @@ -1963,33 +1966,56 @@ def rotate_img(img: np.ndarray):
return np.ascontiguousarray(np.fliplr(rotated_image).astype(np.uint8))

@staticmethod
def continuous_scatter(data: Union[np.ndarray, pd.DataFrame],
columns: Optional[List[str]] = ('X', 'Y', 'Cluster'),
palette: Optional[str] = 'magma',
show_box: Optional[bool] = False,
size: Optional[int] = 10,
title: Optional[str] = None,
save_path: Optional[Union[str, os.PathLike]] = None):

""" Create a 2D scatterplot with a continuous legend """
def continuous_scatter(
data: Union[np.ndarray, pd.DataFrame],
columns: Optional[List[str]] = ("X", "Y", "Cluster"),
palette: Optional[str] = "magma",
show_box: Optional[bool] = False,
size: Optional[int] = 10,
title: Optional[str] = None,
save_path: Optional[Union[str, os.PathLike]] = None,
):
"""Create a 2D scatterplot with a continuous legend"""

check_instance(source=f'{PlottingMixin.continuous_scatter.__name__} data', instance=data, accepted_types=(np.ndarray, pd.DataFrame))
check_instance(
source=f"{PlottingMixin.continuous_scatter.__name__} data",
instance=data,
accepted_types=(np.ndarray, pd.DataFrame),
)
if isinstance(data, pd.DataFrame):
check_that_column_exist(df=data, column_name=columns, file_name=PlottingMixin.continuous_scatter.__name__)
check_that_column_exist(
df=data,
column_name=columns,
file_name=PlottingMixin.continuous_scatter.__name__,
)
data = data[list(columns)]
else:
check_valid_array(data=data, source=PlottingMixin.continuous_scatter.__name__, accepted_ndims=(2,), max_axis_1=len(columns), min_axis_1=len(columns))
check_valid_array(
data=data,
source=PlottingMixin.continuous_scatter.__name__,
accepted_ndims=(2,),
max_axis_1=len(columns),
min_axis_1=len(columns),
)
data = pd.DataFrame(data, columns=list(columns))

fig, ax = plt.subplots()
if not show_box: plt.axis('off')
if not show_box:
plt.axis("off")
plt.xlabel(columns[0])
plt.ylabel(columns[1])
plot = ax.scatter(data[columns[0]], data[columns[1]], c=data[columns[2]], s=size, cmap=palette)
plot = ax.scatter(
data[columns[0]], data[columns[1]], c=data[columns[2]], s=size, cmap=palette
)
cbar = fig.colorbar(plot)
cbar.set_label(columns[2], loc="center")
if title is not None:
plt.title(title, ha="center", fontsize=15, bbox={"facecolor": "orange", "alpha": 0.5, "pad": 0})
plt.title(
title,
ha="center",
fontsize=15,
bbox={"facecolor": "orange", "alpha": 0.5, "pad": 0},
)
if save_path is not None:
check_if_dir_exists(in_dir=os.path.dirname(save_path))
fig.savefig(save_path)
Expand All @@ -1998,39 +2024,66 @@ def continuous_scatter(data: Union[np.ndarray, pd.DataFrame],
return plot

@staticmethod
def categorical_scatter(data: Union[np.ndarray, pd.DataFrame],
columns: Optional[List[str]] = ('X', 'Y', 'Cluster'),
palette: Optional[str] = 'Set1',
show_box: Optional[bool] = False,
size: Optional[int] = 10,
title: Optional[str] = None,
save_path: Optional[Union[str, os.PathLike]] = None):

""" Create a 2D scatterplot with a categorical legend """
def categorical_scatter(
data: Union[np.ndarray, pd.DataFrame],
columns: Optional[List[str]] = ("X", "Y", "Cluster"),
palette: Optional[str] = "Set1",
show_box: Optional[bool] = False,
size: Optional[int] = 10,
title: Optional[str] = None,
save_path: Optional[Union[str, os.PathLike]] = None,
):
"""Create a 2D scatterplot with a categorical legend"""
cmaps = get_categorical_palettes()
if palette not in cmaps: raise InvalidInputError(msg=f'{palette} is not a valid palette. Accepted options: {cmaps}.', source=PlottingMixin.categorical_scatter.__name__)
check_instance(source=f'{PlottingMixin.categorical_scatter.__name__} data', instance=data, accepted_types=(np.ndarray, pd.DataFrame))
if palette not in cmaps:
raise InvalidInputError(
msg=f"{palette} is not a valid palette. Accepted options: {cmaps}.",
source=PlottingMixin.categorical_scatter.__name__,
)
check_instance(
source=f"{PlottingMixin.categorical_scatter.__name__} data",
instance=data,
accepted_types=(np.ndarray, pd.DataFrame),
)
if isinstance(data, pd.DataFrame):
check_that_column_exist(df=data, column_name=columns, file_name=PlottingMixin.categorical_scatter.__name__)
check_that_column_exist(
df=data,
column_name=columns,
file_name=PlottingMixin.categorical_scatter.__name__,
)
data = data[list(columns)]
else:
check_valid_array(data=data, source=PlottingMixin.categorical_scatter.__name__, accepted_ndims=(2,), max_axis_1=len(columns), min_axis_1=len(columns))
check_valid_array(
data=data,
source=PlottingMixin.categorical_scatter.__name__,
accepted_ndims=(2,),
max_axis_1=len(columns),
min_axis_1=len(columns),
)
data = pd.DataFrame(data, columns=list(columns))

if not show_box: plt.axis('off')
if not show_box:
plt.axis("off")
pct_x = np.percentile(data[columns[0]].values, 25)
pct_y = np.percentile(data[columns[1]].values, 25)
plt.xlim(data[columns[0]].min() - pct_x, data[columns[0]].max() + pct_x)
plt.ylim(data[columns[1]].min() - pct_y, data[columns[1]].max() + pct_y)

plot = sns.scatterplot(data=data,
x=columns[0],
y=columns[1],
hue=columns[2],
palette=sns.color_palette(palette, len(data[columns[2]].unique())),
s=size)
plot = sns.scatterplot(
data=data,
x=columns[0],
y=columns[1],
hue=columns[2],
palette=sns.color_palette(palette, len(data[columns[2]].unique())),
s=size,
)
if title is not None:
plt.title(title, ha="center", fontsize=15, bbox={"facecolor": "orange", "alpha": 0.5, "pad": 0})
plt.title(
title,
ha="center",
fontsize=15,
bbox={"facecolor": "orange", "alpha": 0.5, "pad": 0},
)
if save_path is not None:
check_if_dir_exists(in_dir=os.path.dirname(save_path))
plt.savefig(save_path)
Expand All @@ -2039,14 +2092,15 @@ def categorical_scatter(data: Union[np.ndarray, pd.DataFrame],
return plot

@staticmethod
def joint_plot(data: Union[np.ndarray, pd.DataFrame],
columns: Optional[List[str]] = ('X', 'Y', 'Cluster'),
palette: Optional[str] = 'Set1',
kind: Optional[str] = 'scatter',
size: Optional[int] = 10,
title: Optional[str] = None,
save_path: Optional[Union[str, os.PathLike]] = None):

def joint_plot(
data: Union[np.ndarray, pd.DataFrame],
columns: Optional[List[str]] = ("X", "Y", "Cluster"),
palette: Optional[str] = "Set1",
kind: Optional[str] = "scatter",
size: Optional[int] = 10,
title: Optional[str] = None,
save_path: Optional[Union[str, os.PathLike]] = None,
):
"""
:example:
>>> x = np.hstack([np.random.normal(loc=10, scale=4, size=(100, 2)), np.random.randint(0, 1, size=(100, 1))])
Expand All @@ -2055,36 +2109,64 @@ def joint_plot(data: Union[np.ndarray, pd.DataFrame],
"""

cmaps = get_categorical_palettes()
if palette not in cmaps: raise InvalidInputError(
msg=f'{palette} is not a valid palette. Accepted options: {cmaps}', source=PlottingMixin.joint_plot.__name__)
check_instance(source=f'{PlottingMixin.joint_plot.__name__} data', instance=data, accepted_types=(np.ndarray, pd.DataFrame))
check_str(name=f'{PlottingMixin.joint_plot.__name__} kind', value=kind, options=('kde', 'reg', 'hist', 'scatter'))
if palette not in cmaps:
raise InvalidInputError(
msg=f"{palette} is not a valid palette. Accepted options: {cmaps}",
source=PlottingMixin.joint_plot.__name__,
)
check_instance(
source=f"{PlottingMixin.joint_plot.__name__} data",
instance=data,
accepted_types=(np.ndarray, pd.DataFrame),
)
check_str(
name=f"{PlottingMixin.joint_plot.__name__} kind",
value=kind,
options=("kde", "reg", "hist", "scatter"),
)
if isinstance(data, pd.DataFrame):
check_that_column_exist(df=data, column_name=columns, file_name=PlottingMixin.joint_plot.__name__)
check_that_column_exist(
df=data,
column_name=columns,
file_name=PlottingMixin.joint_plot.__name__,
)
data = data[list(columns)]
else:
check_valid_array(data=data, source=PlottingMixin.joint_plot.__name__, accepted_ndims=(2,), max_axis_1=len(columns),min_axis_1=len(columns))
check_valid_array(
data=data,
source=PlottingMixin.joint_plot.__name__,
accepted_ndims=(2,),
max_axis_1=len(columns),
min_axis_1=len(columns),
)
data = pd.DataFrame(data, columns=list(columns))

pct_x = np.percentile(data[columns[0]].values, 10)
pct_y = np.percentile(data[columns[1]].values, 10)
plot = sns.jointplot(data=data,
x=columns[0],
y=columns[1],
hue=columns[2],
xlim=(data[columns[0]].min() - pct_x, data[columns[0]].max() + pct_x),
ylim=(data[columns[1]].min() - pct_y, data[columns[1]].max() + pct_y),
palette=sns.color_palette(palette, len(data[columns[2]].unique())),
kind=kind,
marginal_ticks=False,
s=size)
plot = sns.jointplot(
data=data,
x=columns[0],
y=columns[1],
hue=columns[2],
xlim=(data[columns[0]].min() - pct_x, data[columns[0]].max() + pct_x),
ylim=(data[columns[1]].min() - pct_y, data[columns[1]].max() + pct_y),
palette=sns.color_palette(palette, len(data[columns[2]].unique())),
kind=kind,
marginal_ticks=False,
s=size,
)

if title is not None:
plot.fig.suptitle(title, va='baseline', ha='center', fontsize=15,
bbox={"facecolor": "orange", "alpha": 0.5, "pad": 0})
plot.fig.suptitle(
title,
va="baseline",
ha="center",
fontsize=15,
bbox={"facecolor": "orange", "alpha": 0.5, "pad": 0},
)
if save_path is not None:
check_if_dir_exists(in_dir=os.path.dirname(save_path))
plot.savefig(save_path)
plt.close("all")
else:
return plot
return plot

0 comments on commit 58a3fb1

Please sign in to comment.