16
16
from ethoscopy .behavpy_core import behavpy_core
17
17
from ethoscopy .misc .bootstrap_CI import bootstrap
18
18
from ethoscopy .misc .general_functions import concat
19
+ from ethoscopy .misc .hmm_functions import hmm_pct_transition , hmm_mean_length , hmm_pct_state
19
20
20
21
class behavpy_draw (behavpy_core ):
21
22
"""
@@ -849,14 +850,17 @@ def _internal_plot_anticipation_score(self, variable, facet_col, facet_arg, face
849
850
850
851
return grouped_data , palette_dict , facet_labels
851
852
852
- def _internal_plot_hmm_quantify (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
853
- t_bin , func , t_column ):
854
- """ internal method to calculate the average amount of each state for use in plot_hmm_quantify, plotly and seaborn """
855
-
853
+ def _internal_plot_decoder (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
854
+ t_bin , func , t_column , rm = False ):
855
+ """ contains the first part of the internal plotters for HMM quant plots """
856
856
labels , colours = self ._check_hmm_shape (hm = hmm , lab = labels , col = colours )
857
857
facet_arg , facet_labels , h_list , b_list = self ._check_lists_hmm (facet_col , facet_arg , facet_labels , hmm , t_bin )
858
858
859
- data = self .copy (deep = True )
859
+ if rm :
860
+ # remove the first and last bout to reduce errors and also copy the data
861
+ data = self .remove_first_last_bout (variable = variable )
862
+ else :
863
+ data = self .copy (deep = True )
860
864
861
865
# takes subset of data if requested
862
866
if facet_col and facet_arg :
@@ -871,6 +875,15 @@ def _internal_plot_hmm_quantify(self, hmm, variable, labels, colours, facet_col,
871
875
else :
872
876
decoded_data = concat (* [self .__class__ (self ._hmm_decode (data .xmv (facet_col , arg ), h , b , variable , func , t_column , return_type = 'table' ), mdata .meta , check = True ) for arg , h , b in zip (facet_arg , h_list , b_list )])
873
877
878
+ return decoded_data , labels , colours , facet_arg , facet_labels
879
+
880
+ def _internal_plot_hmm_quantify (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
881
+ t_bin , func , t_column ):
882
+ """ internal method to calculate the average amount of each state for use in plot_hmm_quantify, plotly and seaborn """
883
+
884
+ decoded_data , labels , colours , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
885
+ t_bin , func , t_column )
886
+
874
887
# Count each state and find its fraction
875
888
grouped_data = decoded_data .groupby ([decoded_data .index , 'state' ], sort = False ).agg ({'bin' : 'count' })
876
889
grouped_data = grouped_data .join (decoded_data .groupby ('id' , sort = False ).agg ({'previous_state' :'count' }))
@@ -888,3 +901,95 @@ def _internal_plot_hmm_quantify(self, hmm, variable, labels, colours, facet_col,
888
901
grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
889
902
890
903
return grouped_data , labels , colours , facet_labels , palette_dict
904
+
905
+ def _internal_plot_hmm_quantify_length (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
906
+ t_bin , func , t_column ):
907
+ """ internal method to calculate the average length of each state for use in plot_hmm_quantify_length, plotly and seaborn """
908
+
909
+ decoded_data , labels , colours , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
910
+ t_bin , func , t_column )
911
+
912
+ # get each specimens states time series to find lengths
913
+ states = decoded_data .groupby (decoded_data .index , sort = False )['state' ].apply (list )
914
+ df_lengths = []
915
+ for l , id in zip (states , states .index ):
916
+ length = hmm_mean_length (l , delta_t = t_bin )
917
+ length ['id' ] = [id ] * len (length )
918
+ df_lengths .append (length )
919
+
920
+ grouped_data = pd .concat (df_lengths )
921
+ grouped_data .rename (columns = {'mean_length' : 'Length of state bout (mins)' }, inplace = True )
922
+ grouped_data .set_index ('id' , inplace = True )
923
+
924
+ if facet_col :
925
+ palette = self ._get_colours (facet_labels )
926
+ palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
927
+ grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels , hmm_labels = labels )
928
+ else :
929
+ palette = colours
930
+ palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (labels )} # change to grey if control
931
+ hmm_dict = {k : v for k , v in zip (range (len (labels )), labels )}
932
+ grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
933
+
934
+ return grouped_data , labels , colours , facet_labels , palette_dict
935
+
936
+ def _internal_plot_hmm_quantify_length_min_max (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
937
+ t_bin , func , t_column ):
938
+ """ internal method to calculate the average length of each state for use in plot_hmm_quantify_length, plotly and seaborn """
939
+
940
+ decoded_data , labels , colours , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
941
+ t_bin , func , t_column , rm = True )
942
+
943
+ # get each specimens states time series to find lengths
944
+ states = decoded_data .groupby (decoded_data .index , sort = False )['state' ].apply (list )
945
+ df_lengths = []
946
+ for l , id in zip (states , states .index ):
947
+ length = hmm_mean_length (l , delta_t = t_bin , raw = True )
948
+ length ['id' ] = [id ] * len (length )
949
+ df_lengths .append (length )
950
+
951
+ grouped_data = pd .concat (df_lengths )
952
+ grouped_data .rename (columns = {'length_adjusted' : 'Length of state bout (mins)' }, inplace = True )
953
+ grouped_data .set_index ('id' , inplace = True )
954
+
955
+ if facet_col :
956
+ palette = self ._get_colours (facet_labels )
957
+ palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
958
+ grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels , hmm_labels = labels )
959
+ else :
960
+ palette = colours
961
+ palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (labels )} # change to grey if control
962
+ hmm_dict = {k : v for k , v in zip (range (len (labels )), labels )}
963
+ grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
964
+
965
+ return grouped_data , labels , colours , facet_labels , palette_dict
966
+
967
+ def _internal_plot_hmm_quantify_transition (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
968
+ t_bin , func , t_column ):
969
+
970
+ decoded_data , labels , colours , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
971
+ t_bin , func , t_column , rm = True )
972
+
973
+ # get each specimens states time series to find lengths
974
+ states = decoded_data .groupby (decoded_data .index , sort = False )['state' ].apply (list )
975
+ df_list = []
976
+ for l , id in zip (states , states .index ):
977
+ length = hmm_pct_transition (l , total_states = list (range (len (labels ))))
978
+ length ['id' ] = [id ] * len (length )
979
+ df_list .append (length )
980
+
981
+ grouped_data = pd .concat (df_list )
982
+ grouped_data = grouped_data .set_index ('id' ).stack ().reset_index ().set_index ('id' )
983
+ grouped_data .rename (columns = {'level_1' : 'state' , 0 : 'Fraction of transitions into each state' }, inplace = True )
984
+
985
+ if facet_col :
986
+ palette = self ._get_colours (facet_labels )
987
+ palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
988
+ grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels , hmm_labels = labels )
989
+ else :
990
+ palette = colours
991
+ palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (labels )} # change to grey if control
992
+ hmm_dict = {k : v for k , v in zip (range (len (labels )), labels )}
993
+ grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
994
+
995
+ return grouped_data , labels , colours , facet_labels , palette_dict
0 commit comments