Skip to content

Commit dac88c8

Browse files
committed
updates to all plot_hmm quantify methods
1 parent 4003a3f commit dac88c8

File tree

5 files changed

+635
-10616
lines changed

5 files changed

+635
-10616
lines changed

src/ethoscopy/behavpy_draw.py

+110-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ethoscopy.behavpy_core import behavpy_core
1717
from ethoscopy.misc.bootstrap_CI import bootstrap
1818
from ethoscopy.misc.general_functions import concat
19+
from ethoscopy.misc.hmm_functions import hmm_pct_transition, hmm_mean_length, hmm_pct_state
1920

2021
class behavpy_draw(behavpy_core):
2122
"""
@@ -849,14 +850,17 @@ def _internal_plot_anticipation_score(self, variable, facet_col, facet_arg, face
849850

850851
return grouped_data, palette_dict, facet_labels
851852

852-
def _internal_plot_hmm_quantify(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
853-
t_bin, func, t_column):
854-
""" internal method to calculate the average amount of each state for use in plot_hmm_quantify, plotly and seaborn """
855-
853+
def _internal_plot_decoder(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
854+
t_bin, func, t_column, rm=False):
855+
""" contains the first part of the internal plotters for HMM quant plots """
856856
labels, colours = self._check_hmm_shape(hm = hmm, lab = labels, col = colours)
857857
facet_arg, facet_labels, h_list, b_list = self._check_lists_hmm(facet_col, facet_arg, facet_labels, hmm, t_bin)
858858

859-
data = self.copy(deep=True)
859+
if rm:
860+
# remove the first and last bout to reduce errors and also copy the data
861+
data = self.remove_first_last_bout(variable=variable)
862+
else:
863+
data = self.copy(deep=True)
860864

861865
# takes subset of data if requested
862866
if facet_col and facet_arg:
@@ -871,6 +875,15 @@ def _internal_plot_hmm_quantify(self, hmm, variable, labels, colours, facet_col,
871875
else:
872876
decoded_data = concat(*[self.__class__(self._hmm_decode(data.xmv(facet_col, arg), h, b, variable, func, t_column, return_type='table'), mdata.meta, check=True) for arg, h, b in zip(facet_arg, h_list, b_list)])
873877

878+
return decoded_data, labels, colours, facet_arg, facet_labels
879+
880+
def _internal_plot_hmm_quantify(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
881+
t_bin, func, t_column):
882+
""" internal method to calculate the average amount of each state for use in plot_hmm_quantify, plotly and seaborn """
883+
884+
decoded_data, labels, colours, facet_arg, facet_labels = self._internal_plot_decoder(hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
885+
t_bin, func, t_column)
886+
874887
# Count each state and find its fraction
875888
grouped_data = decoded_data.groupby([decoded_data.index, 'state'], sort=False).agg({'bin' : 'count'})
876889
grouped_data = grouped_data.join(decoded_data.groupby('id', sort=False).agg({'previous_state':'count'}))
@@ -888,3 +901,95 @@ def _internal_plot_hmm_quantify(self, hmm, variable, labels, colours, facet_col,
888901
grouped_data['state'] = grouped_data['state'].map(hmm_dict)
889902

890903
return grouped_data, labels, colours, facet_labels, palette_dict
904+
905+
def _internal_plot_hmm_quantify_length(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
                                        t_bin, func, t_column):
    """
    internal method that builds the per-specimen mean bout length of each decoded state,
    shared by the plotly and seaborn versions of the quantify-length plot

    All arguments are forwarded unchanged to _internal_plot_decoder; t_bin is additionally
    handed to hmm_mean_length as delta_t.

    returns:
        a tuple of (grouped_data, labels, colours, facet_labels, palette_dict), where grouped_data is
        indexed by 'id' and carries the column 'Length of state bout (mins)'
    """

    decoded_data, labels, colours, facet_arg, facet_labels = self._internal_plot_decoder(
        hmm, variable, labels, colours, facet_col, facet_arg, facet_labels, t_bin, func, t_column)

    # one list of decoded states per specimen, groups kept in their original (unsorted) order
    state_runs = decoded_data.groupby(decoded_data.index, sort=False)['state'].apply(list)

    per_specimen = []
    for specimen, run in state_runs.items():
        mean_lengths = hmm_mean_length(run, delta_t = t_bin)
        # tag every row with the specimen it came from so the frames can be stacked together
        mean_lengths['id'] = [specimen] * len(mean_lengths)
        per_specimen.append(mean_lengths)

    grouped_data = (
        pd.concat(per_specimen)
        .rename(columns={'mean_length' : 'Length of state bout (mins)'})
        .set_index('id')
    )

    if facet_col:
        palette = self._get_colours(facet_labels)
        # _check_grey swaps the colour for grey when the name marks a control group
        palette_dict = {name : self._check_grey(name, palette[pos])[1] for pos, name in enumerate(facet_labels)}
        grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels, hmm_labels = labels)
    else:
        palette = colours
        palette_dict = {name : self._check_grey(name, palette[pos])[1] for pos, name in enumerate(labels)}
        # without facets the integer states are swapped for their label here
        state_names = dict(enumerate(labels))
        grouped_data['state'] = grouped_data['state'].map(state_names)

    return grouped_data, labels, colours, facet_labels, palette_dict
935+
936+
def _internal_plot_hmm_quantify_length_min_max(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
                                                t_bin, func, t_column):
    """
    internal method to collect the raw (non-averaged) bout lengths of each state per specimen,
    presumably for the min/max quantify-length plots in plotly and seaborn

    Unlike _internal_plot_hmm_quantify_length this asks the decoder to remove the first and last
    bout (rm = True) and calls hmm_mean_length with raw=True, reading its 'length_adjusted' column.

    All arguments are forwarded unchanged to _internal_plot_decoder; t_bin is additionally
    handed to hmm_mean_length as delta_t.

    returns:
        a tuple of (grouped_data, labels, colours, facet_labels, palette_dict), where grouped_data is
        indexed by 'id' and carries the column 'Length of state bout (mins)'
    """

    decoded_data, labels, colours, facet_arg, facet_labels = self._internal_plot_decoder(
        hmm, variable, labels, colours, facet_col, facet_arg, facet_labels, t_bin, func, t_column, rm = True)

    # get each specimen's state time series to find the bout lengths
    states = decoded_data.groupby(decoded_data.index, sort=False)['state'].apply(list)

    df_lengths = []
    for specimen_id, state_list in states.items():
        # NOTE(review): raw=True presumably returns one row per bout rather than a mean - confirm in hmm_functions
        bout_lengths = hmm_mean_length(state_list, delta_t = t_bin, raw=True)
        bout_lengths['id'] = [specimen_id] * len(bout_lengths)
        df_lengths.append(bout_lengths)

    grouped_data = pd.concat(df_lengths)
    grouped_data.rename(columns={'length_adjusted' : 'Length of state bout (mins)'}, inplace=True)
    grouped_data.set_index('id', inplace=True)

    if facet_col:
        palette = self._get_colours(facet_labels)
        palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(facet_labels)} # change to grey if control
        grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels, hmm_labels = labels)
    else:
        palette = colours
        palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(labels)} # change to grey if control
        # without facets the integer states are swapped for their label here
        hmm_dict = dict(enumerate(labels))
        grouped_data['state'] = grouped_data['state'].map(hmm_dict)

    return grouped_data, labels, colours, facet_labels, palette_dict
966+
967+
def _internal_plot_hmm_quantify_transition(self, hmm, variable, labels, colours, facet_col, facet_arg, facet_labels,
                                            t_bin, func, t_column):
    """
    internal method to calculate, per specimen, the fraction of transitions into each state,
    presumably for the quantify-transition plots in plotly and seaborn

    The decoder is asked to remove the first and last bout (rm = True) before the transitions
    are counted with hmm_pct_transition.

    All arguments are forwarded unchanged to _internal_plot_decoder.

    returns:
        a tuple of (grouped_data, labels, colours, facet_labels, palette_dict), where grouped_data is
        a long-format table indexed by 'id' with the columns 'state' and
        'Fraction of transitions into each state'
    """

    decoded_data, labels, colours, facet_arg, facet_labels = self._internal_plot_decoder(
        hmm, variable, labels, colours, facet_col, facet_arg, facet_labels, t_bin, func, t_column, rm = True)

    # get each specimen's state time series to count its transitions
    states = decoded_data.groupby(decoded_data.index, sort=False)['state'].apply(list)

    df_list = []
    for specimen_id, state_list in states.items():
        transitions = hmm_pct_transition(state_list, total_states=list(range(len(labels))))
        transitions['id'] = [specimen_id] * len(transitions)
        df_list.append(transitions)

    grouped_data = pd.concat(df_list)
    # wide (one column per state) -> long; stack() leaves the old column names in 'level_1' and the values in 0
    grouped_data = grouped_data.set_index('id').stack().reset_index().set_index('id')
    grouped_data.rename(columns={'level_1' : 'state', 0 : 'Fraction of transitions into each state'}, inplace=True)

    if facet_col:
        palette = self._get_colours(facet_labels)
        palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(facet_labels)} # change to grey if control
        grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels, hmm_labels = labels)
    else:
        palette = colours
        palette_dict = {name : self._check_grey(name, palette[c])[1] for c, name in enumerate(labels)} # change to grey if control
        # without facets the integer states are swapped for their label here
        hmm_dict = dict(enumerate(labels))
        grouped_data['state'] = grouped_data['state'].map(hmm_dict)

    return grouped_data, labels, colours, facet_labels, palette_dict

0 commit comments

Comments
 (0)