Skip to content

Commit 7c8219b

Browse files
committed
updated hmm_plots to include multi-hmm no facet plots
1 parent aa54037 commit 7c8219b

File tree

4 files changed

+105
-90
lines changed

4 files changed

+105
-90
lines changed

src/ethoscopy/behavpy_core.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -1408,7 +1408,7 @@ def remove_first_last_bout(self, variable: str) -> "behavpy_core":
14081408
"""
14091409
Remove the first and last bouts of a value per specimen.
14101410
1411-
Used for columns containing continuous runs of boolean values (like 'moving' or 'asleep')
1411+
Used for columns containing continuous runs of categorical integer values, such as bools,
14121412
to remove potentially incomplete bouts at the start and end of recordings. This is useful
14131413
when you are not sure if the starting and ending bouts were cut in two when filtering
14141414
or stopping the experiment.
@@ -1422,16 +1422,11 @@ def remove_first_last_bout(self, variable: str) -> "behavpy_core":
14221422
14231423
Raises:
14241424
KeyError: If variable column not found in data
1425-
TypeError: If variable column does not contain boolean values
1426-
ValueError: If data contains no state changes
14271425
14281426
Example:
14291427
# Remove first/last sleep bouts that may be incomplete
14301428
df = df.remove_first_last_bout('asleep')
14311429
"""
1432-
1433-
if not pd.api.types.is_bool_dtype(self[variable]):
1434-
raise TypeError(f"Column '{variable}' must contain boolean values")
14351430

14361431
def _wrapped_remove_first_last_bout(data: pd.DataFrame) -> pd.DataFrame:
14371432
if len(data) == 0:

src/ethoscopy/behavpy_draw.py

+78-65
Original file line numberDiff line numberDiff line change
@@ -166,45 +166,36 @@ def _check_lists_hmm(self, f_col, f_arg, f_lab, h, b):
166166
else:
167167
f_arg = [None] * len(h)
168168
f_lab = [f'HMM-{i+1}' for i in range(len(h))]
169-
b_list = [b] * len(h)
170-
return f_arg, f_lab, h, b_list
171-
172-
if f_col is not None:
173-
if f_arg is None:
174-
f_arg = list(set(self.meta[f_col].tolist()))
175-
if f_lab is None:
176-
string_args = []
177-
for i in f_arg:
178-
if i not in self.meta[f_col].tolist():
179-
raise KeyError(f'Argument "{i}" is not in the meta column {f_col}')
180-
string_args.append(str(i))
181-
f_lab = string_args
182-
elif len(f_arg) != len(f_lab):
183-
print("The facet labels don't match the length of the variables in the column. Using column variables instead")
184-
f_lab = f_arg
185-
else:
186-
if f_lab is None:
187-
string_args = []
188-
for i in f_arg:
189-
string_args.append(str(i))
190-
f_lab = string_args
191-
elif len(f_arg) != len(f_lab):
192-
print("The facet labels don't match the entered facet arguments in length. Using column variables instead")
193-
f_lab = f_arg
194-
else:
195-
f_arg = [None]
196-
if f_lab is None:
197-
f_lab = ['']
169+
return f_arg, f_lab, h, b
198170

199-
if isinstance(h, list) is False:
200-
h_list = [h]
201-
b_list = [b]
202-
if len(h_list) != len(f_arg):
203-
h_list = [h_list[0]] * len(f_arg)
204-
if len(b_list) != len(f_arg):
205-
b_list = [b_list[0]] * len(f_arg)
171+
else:
172+
h_list = h
173+
b_list = b
206174

207-
return f_arg, f_lab, h_list, b_list
175+
if f_col is None: # if there is no facet column, then return fake lists
176+
f_arg = [None]
177+
f_lab = ['']
178+
return f_arg, f_lab, h_list, b_list
179+
180+
if f_arg is not None: # check if all the facet args are in the meta column
181+
for i in f_arg:
182+
if i not in self.meta[f_col].tolist():
183+
raise KeyError(f'Argument "{i}" is not in the meta column {f_col}')
184+
185+
if f_col is not None and f_arg is not None and f_lab is not None: # if user provides all, just check for length match
186+
if len(f_arg) != len(f_lab):
187+
print("The facet labels don't match the length of the variables in the column. Using column variables names instead")
188+
f_lab = [str(arg) for arg in f_arg]
189+
return f_arg, f_lab, h_list, b_list
190+
191+
if f_col is not None and f_arg is not None and f_lab is None: # if user provides a facet column and args but no labels
192+
f_lab = [str(arg) for arg in f_arg]
193+
return f_arg, f_lab, h_list, b_list
194+
195+
if f_col is not None and f_arg is None and f_lab is None: # if user provides a facet column but no args or labels
196+
f_arg = list(set(self.meta[f_col].tolist()))
197+
f_lab = [str(arg) for arg in f_arg]
198+
return f_arg, f_lab, h_list, b_list
208199

209200
@staticmethod
210201
def _zscore_bootstrap(array:np.array, z_score:bool = True, second_array:np.array = None, min_max:bool = False):
@@ -391,10 +382,11 @@ def _check_grey(name, col, response = False):
391382

392383
# GENERAL PLOT HELPERS
393384

394-
def facet_merge(self, data, facet_col, facet_arg, facet_labels, hmm_labels = None):
385+
@staticmethod
386+
def facet_merge(data, meta, facet_col, facet_arg, facet_labels, hmm_labels = None):
395387
""" An internal method for joining a metadata column with its data for plotting purposes """
396388
# merge the facet_col column and replace with the labels
397-
data = data.join(self.meta[[facet_col]])
389+
data = data.join(meta[[facet_col]])
398390
data[facet_col] = data[facet_col].astype('category')
399391
map_dict = {k : v for k, v in zip(facet_arg, facet_labels)}
400392
data[facet_col] = data[facet_col].map(map_dict)
@@ -559,7 +551,7 @@ def alter_merge(response, mov, tb):
559551
grouped_data['state'] = grouped_data['state'].map(hmm_dict)
560552
grouped_data[''] = grouped_data['has_interacted']
561553
else:
562-
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels, hmm_labels = labels)
554+
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels, hmm_labels = labels) # TODO: change to pass meta as the second argument; facet_merge is now a staticmethod taking (data, meta, ...)
563555
grouped_data[facet_col] = grouped_data[facet_col].astype('str')
564556
grouped_data[facet_col] = grouped_data[facet_col] + " " + grouped_data['has_interacted']
565557

@@ -798,7 +790,7 @@ def get_response(int_data, ptype, time_window_length, resp_col, t_col):
798790
# map stim names and create column to facet by
799791
grouped_data[interaction_id_col] = grouped_data[interaction_id_col].map(map_dict)
800792
if facet_col:
801-
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels)
793+
grouped_data = self.facet_merge(grouped_data, self.meta, facet_col, facet_arg, facet_labels)
802794
grouped_data[facet_col] = grouped_data[facet_col].astype(str) + "-" + grouped_data[interaction_id_col]
803795
else:
804796
facet_col = 'stim_type'
@@ -841,7 +833,7 @@ def _internal_plot_quantify(self, variable, facet_col, facet_arg, facet_labels,
841833
if facet_col:
842834
palette = self._get_colours(facet_labels)
843835
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(facet_labels)} # change to grey if control
844-
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels)
836+
grouped_data = self.facet_merge(grouped_data, self.meta, facet_col, facet_arg, facet_labels)
845837
else:
846838
palette = self._get_colours(variable)
847839
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(variable)} # change to grey if control
@@ -910,7 +902,7 @@ def _internal_plot_day_night(self, variable, facet_col, facet_arg, facet_labels,
910902

911903
palette = self._get_colours(facet_labels)
912904
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(facet_labels)} # change to grey if control
913-
if facet_col: grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels)
905+
if facet_col: grouped_data = self.facet_merge(grouped_data, self.meta,facet_col, facet_arg, facet_labels)
914906

915907
return grouped_data, palette_dict, facet_labels
916908

@@ -933,13 +925,14 @@ def _internal_plot_anticipation_score(self, variable, facet_col, facet_arg, face
933925
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(facet_labels)} # change to grey if control
934926

935927
if facet_col:
936-
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels)
928+
grouped_data = self.facet_merge(grouped_data, self.meta, facet_col, facet_arg, facet_labels)
937929

938930
return grouped_data, palette_dict, facet_labels
939931

940932
def _internal_plot_decoder(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
941933
t_bin, func, t_column, rm=False):
942934
""" contains the first part of the internal plotters for HMM quant plots """
935+
943936
labels, colours = self._check_hmm_shape(hm = hmm, lab = labels, col = colours)
944937
facet_arg, facet_labels, h_list, b_list = self._check_lists_hmm(facet_col, facet_arg, facet_labels, hmm, t_bin)
945938

@@ -954,23 +947,40 @@ def _internal_plot_decoder(self, hmm, variable, labels, colours, facet_col, face
954947
# takes subselection of df that contains the specified facet columns
955948
data = self.xmv(facet_col, facet_arg)
956949

950+
def hmm_list_facet(data, meta, facet_label, ind):
951+
d = data.copy(deep=True)
952+
m = meta.copy(deep=True)
953+
954+
d.id = d.id + f'_{ind}'
955+
m.index = m.index + f'_{ind}'
956+
m['HMM'] = facet_label
957+
return self.__class__(d, m, check=True)
958+
957959
if facet_col is None: # decode the whole dataset
958-
decoded_data = self.__class__(self._hmm_decode(data, hmm, t_bin, variable, func, t_column, return_type='table'), data.meta, check=True)
959-
else:
960-
if isinstance(hmm, list) is False: # if only 1 hmm but is faceted, decode as whole for efficiency
961-
decoded_data = self.__class__(self._hmm_decode(data, hmm, t_bin, variable, func, t_column, return_type='table'), data.meta, check=True)
960+
if isinstance(hmm, list):
961+
decoded_data = concat(*[hmm_list_facet(self._hmm_decode(data, h, b, variable, func, t_column, return_type='table'), data.meta, f, c+1)
962+
for c, (h, b, f) in enumerate(zip(h_list, b_list, facet_labels))])
963+
facet_arg = facet_labels
964+
facet_col = 'HMM'
962965
else:
966+
decoded_data = self.__class__(self._hmm_decode(data, hmm, t_bin, variable, func, t_column, return_type='table'), data.meta, check=True)
967+
else:
968+
if isinstance(hmm, list):
963969
decoded_data = concat(*[self.__class__(self._hmm_decode(data.xmv(facet_col, arg), h, b, variable, func, t_column, return_type='table'),
964970
data.meta, check=True) for arg, h, b in zip(facet_arg, h_list, b_list)])
971+
else: # if only 1 hmm but is faceted, decode as whole for efficiency
972+
decoded_data = self.__class__(self._hmm_decode(data, hmm, t_bin, variable, func, t_column, return_type='table'), data.meta, check=True)
973+
965974

966-
return decoded_data, labels, colours, facet_arg, facet_labels
975+
return decoded_data, labels, colours, facet_col, facet_arg, facet_labels
967976

968977
def _internal_plot_hmm_quantify(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
969978
t_bin, func, t_column):
970979
""" internal method to calculate the average amount of each state for use in plot_hmm_quantify, plotly and seaborn """
971980

972-
decoded_data, labels, colours, facet_arg, facet_labels = self._internal_plot_decoder(hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
973-
t_bin, func, t_column)
981+
decoded_data, labels, colours, facet_col, facet_arg, facet_labels = self._internal_plot_decoder(hmm, variable, labels, colours,
982+
facet_col, facet_arg, facet_labels,
983+
t_bin, func, t_column)
974984

975985
# Count each state and find its fraction
976986
grouped_data = decoded_data.groupby([decoded_data.index, 'state'], sort=False).agg({'bin' : 'count'})
@@ -981,21 +991,22 @@ def _internal_plot_hmm_quantify(self, hmm, variable, labels, colours, facet_col,
981991
if facet_col:
982992
palette = self._get_colours(facet_labels)
983993
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(facet_labels)} # change to grey if control
984-
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels, hmm_labels = labels)
994+
grouped_data = self.facet_merge(grouped_data, decoded_data.meta, facet_col, facet_arg, facet_labels, hmm_labels = labels)
985995
else:
986996
palette = colours
987997
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(labels)} # change to grey if control
988998
hmm_dict = {k : v for k, v in zip(range(len(labels)), labels)}
989999
grouped_data['state'] = grouped_data['state'].map(hmm_dict)
9901000

991-
return grouped_data, labels, colours, facet_labels, palette_dict
1001+
return grouped_data, labels, colours, facet_col, facet_labels, palette_dict
9921002

9931003
def _internal_plot_hmm_quantify_length(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
9941004
t_bin, func, t_column):
9951005
""" internal method to calculate the average length of each state for use in plot_hmm_quantify_length, plotly and seaborn """
9961006

997-
decoded_data, labels, colours, facet_arg, facet_labels = self._internal_plot_decoder(hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
998-
t_bin, func, t_column)
1007+
decoded_data, labels, colours, facet_col, facet_arg, facet_labels = self._internal_plot_decoder(hmm, variable, labels, colours,
1008+
facet_col, facet_arg, facet_labels,
1009+
t_bin, func, t_column)
9991010

10001011
# get each specimens states time series to find lengths
10011012
states = decoded_data.groupby(decoded_data.index, sort=False)['state'].apply(list)
@@ -1012,21 +1023,22 @@ def _internal_plot_hmm_quantify_length(self, hmm, variable, labels, colours, fac
10121023
if facet_col:
10131024
palette = self._get_colours(facet_labels)
10141025
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(facet_labels)} # change to grey if control
1015-
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels, hmm_labels = labels)
1026+
grouped_data = self.facet_merge(grouped_data, decoded_data.meta, facet_col, facet_arg, facet_labels, hmm_labels = labels)
10161027
else:
10171028
palette = colours
10181029
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(labels)} # change to grey if control
10191030
hmm_dict = {k : v for k, v in zip(range(len(labels)), labels)}
10201031
grouped_data['state'] = grouped_data['state'].map(hmm_dict)
10211032

1022-
return grouped_data, labels, colours, facet_labels, palette_dict
1033+
return grouped_data, labels, colours, facet_col, facet_labels, palette_dict
10231034

10241035
def _internal_plot_hmm_quantify_length_min_max(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
10251036
t_bin, func, t_column):
10261037
""" internal method to calculate the average length of each state for use in plot_hmm_quantify_length, plotly and seaborn """
10271038

1028-
decoded_data, labels, colours, facet_arg, facet_labels = self._internal_plot_decoder(hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
1029-
t_bin, func, t_column, rm = True)
1039+
decoded_data, labels, colours, facet_col, facet_arg, facet_labels = self._internal_plot_decoder(hmm, variable, labels, colours,
1040+
facet_col, facet_arg, facet_labels,
1041+
t_bin, func, t_column, rm = True)
10301042

10311043
# get each specimens states time series to find lengths
10321044
states = decoded_data.groupby(decoded_data.index, sort=False)['state'].apply(list)
@@ -1043,21 +1055,22 @@ def _internal_plot_hmm_quantify_length_min_max(self, hmm, variable, labels, colo
10431055
if facet_col:
10441056
palette = self._get_colours(facet_labels)
10451057
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(facet_labels)} # change to grey if control
1046-
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels, hmm_labels = labels)
1058+
grouped_data = self.facet_merge(grouped_data, decoded_data.meta, facet_col, facet_arg, facet_labels, hmm_labels = labels)
10471059
else:
10481060
palette = colours
10491061
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(labels)} # change to grey if control
10501062
hmm_dict = {k : v for k, v in zip(range(len(labels)), labels)}
10511063
grouped_data['state'] = grouped_data['state'].map(hmm_dict)
10521064

1053-
return grouped_data, labels, colours, facet_labels, palette_dict
1065+
return grouped_data, labels, colours, facet_col, facet_labels, palette_dict
10541066

10551067
def _internal_plot_hmm_quantify_transition(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
10561068
t_bin, func, t_column):
10571069
""" An internal method to find the % of transitions into a state that occur per state per individual """
10581070

1059-
decoded_data, labels, colours, facet_arg, facet_labels = self._internal_plot_decoder(hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
1060-
t_bin, func, t_column, rm = True)
1071+
decoded_data, labels, colours, facet_col, facet_arg, facet_labels = self._internal_plot_decoder(hmm, variable, labels, colours,
1072+
facet_col, facet_arg, facet_labels,
1073+
t_bin, func, t_column, rm = True)
10611074

10621075
# get each specimens states time series to find lengths
10631076
states = decoded_data.groupby(decoded_data.index, sort=False)['state'].apply(list)
@@ -1074,11 +1087,11 @@ def _internal_plot_hmm_quantify_transition(self, hmm, variable, labels, colours,
10741087
if facet_col:
10751088
palette = self._get_colours(facet_labels)
10761089
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(facet_labels)} # change to grey if control
1077-
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels, hmm_labels = labels)
1090+
grouped_data = self.facet_merge(grouped_data, decoded_data.meta, facet_col, facet_arg, facet_labels, hmm_labels = labels)
10781091
else:
10791092
palette = colours
10801093
palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(labels)} # change to grey if control
10811094
hmm_dict = {k : v for k, v in zip(range(len(labels)), labels)}
10821095
grouped_data['state'] = grouped_data['state'].map(hmm_dict)
10831096

1084-
return grouped_data, labels, colours, facet_labels, palette_dict
1097+
return grouped_data, labels, colours, facet_col, facet_labels, palette_dict

0 commit comments

Comments
 (0)