@@ -166,45 +166,36 @@ def _check_lists_hmm(self, f_col, f_arg, f_lab, h, b):
166
166
else :
167
167
f_arg = [None ] * len (h )
168
168
f_lab = [f'HMM-{ i + 1 } ' for i in range (len (h ))]
169
- b_list = [b ] * len (h )
170
- return f_arg , f_lab , h , b_list
171
-
172
- if f_col is not None :
173
- if f_arg is None :
174
- f_arg = list (set (self .meta [f_col ].tolist ()))
175
- if f_lab is None :
176
- string_args = []
177
- for i in f_arg :
178
- if i not in self .meta [f_col ].tolist ():
179
- raise KeyError (f'Argument "{ i } " is not in the meta column { f_col } ' )
180
- string_args .append (str (i ))
181
- f_lab = string_args
182
- elif len (f_arg ) != len (f_lab ):
183
- print ("The facet labels don't match the length of the variables in the column. Using column variables instead" )
184
- f_lab = f_arg
185
- else :
186
- if f_lab is None :
187
- string_args = []
188
- for i in f_arg :
189
- string_args .append (str (i ))
190
- f_lab = string_args
191
- elif len (f_arg ) != len (f_lab ):
192
- print ("The facet labels don't match the entered facet arguments in length. Using column variables instead" )
193
- f_lab = f_arg
194
- else :
195
- f_arg = [None ]
196
- if f_lab is None :
197
- f_lab = ['' ]
169
+ return f_arg , f_lab , h , b
198
170
199
- if isinstance (h , list ) is False :
200
- h_list = [h ]
201
- b_list = [b ]
202
- if len (h_list ) != len (f_arg ):
203
- h_list = [h_list [0 ]] * len (f_arg )
204
- if len (b_list ) != len (f_arg ):
205
- b_list = [b_list [0 ]] * len (f_arg )
171
+ else :
172
+ h_list = h
173
+ b_list = b
206
174
207
- return f_arg , f_lab , h_list , b_list
175
+ if f_col is None : # is no facet column, then return fake lists
176
+ f_arg = [None ]
177
+ f_lab = ['' ]
178
+ return f_arg , f_lab , h_list , b_list
179
+
180
+ if f_arg is not None : # check if all the facet args are in the meta column
181
+ for i in f_arg :
182
+ if i not in self .meta [f_col ].tolist ():
183
+ raise KeyError (f'Argument "{ i } " is not in the meta column { f_col } ' )
184
+
185
+ if f_col is not None and f_arg is not None and f_lab is not None : # is user provides all, just check for length match
186
+ if len (f_arg ) != len (f_lab ):
187
+ print ("The facet labels don't match the length of the variables in the column. Using column variables names instead" )
188
+ f_lab = [str (arg ) for arg in f_arg ]
189
+ return f_arg , f_lab , h_list , b_list
190
+
191
+ if f_col is not None and f_arg is not None and f_lab is None : # if user provides a facet column and args but no labels
192
+ f_lab = [str (arg ) for arg in f_arg ]
193
+ return f_arg , f_lab , h_list , b_list
194
+
195
+ if f_col is not None and f_arg is None and f_lab is None : # if user provides a facet column but no args or labels
196
+ f_arg = list (set (self .meta [f_col ].tolist ()))
197
+ f_lab = [str (arg ) for arg in f_arg ]
198
+ return f_arg , f_lab , h_list , b_list
208
199
209
200
@staticmethod
210
201
def _zscore_bootstrap (array :np .array , z_score :bool = True , second_array :np .array = None , min_max :bool = False ):
@@ -391,10 +382,11 @@ def _check_grey(name, col, response = False):
391
382
392
383
# GENERAL PLOT HELPERS
393
384
394
- def facet_merge (self , data , facet_col , facet_arg , facet_labels , hmm_labels = None ):
385
+ @staticmethod
386
+ def facet_merge (data , meta , facet_col , facet_arg , facet_labels , hmm_labels = None ):
395
387
""" A internal method for joining a metadata column with its data for plotting purposes """
396
388
# merge the facet_col column and replace with the labels
397
- data = data .join (self . meta [[facet_col ]])
389
+ data = data .join (meta [[facet_col ]])
398
390
data [facet_col ] = data [facet_col ].astype ('category' )
399
391
map_dict = {k : v for k , v in zip (facet_arg , facet_labels )}
400
392
data [facet_col ] = data [facet_col ].map (map_dict )
@@ -559,7 +551,7 @@ def alter_merge(response, mov, tb):
559
551
grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
560
552
grouped_data ['' ] = grouped_data ['has_interacted' ]
561
553
else :
562
- grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels , hmm_labels = labels )
554
+ grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels , hmm_labels = labels ) # chnage to have meta given as arg
563
555
grouped_data [facet_col ] = grouped_data [facet_col ].astype ('str' )
564
556
grouped_data [facet_col ] = grouped_data [facet_col ] + " " + grouped_data ['has_interacted' ]
565
557
@@ -798,7 +790,7 @@ def get_response(int_data, ptype, time_window_length, resp_col, t_col):
798
790
# map stim names and create column to facet by
799
791
grouped_data [interaction_id_col ] = grouped_data [interaction_id_col ].map (map_dict )
800
792
if facet_col :
801
- grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels )
793
+ grouped_data = self .facet_merge (grouped_data , self . meta , facet_col , facet_arg , facet_labels )
802
794
grouped_data [facet_col ] = grouped_data [facet_col ].astype (str ) + "-" + grouped_data [interaction_id_col ]
803
795
else :
804
796
facet_col = 'stim_type'
@@ -841,7 +833,7 @@ def _internal_plot_quantify(self, variable, facet_col, facet_arg, facet_labels,
841
833
if facet_col :
842
834
palette = self ._get_colours (facet_labels )
843
835
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
844
- grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels )
836
+ grouped_data = self .facet_merge (grouped_data , self . meta , facet_col , facet_arg , facet_labels )
845
837
else :
846
838
palette = self ._get_colours (variable )
847
839
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (variable )} # change to grey if control
@@ -910,7 +902,7 @@ def _internal_plot_day_night(self, variable, facet_col, facet_arg, facet_labels,
910
902
911
903
palette = self ._get_colours (facet_labels )
912
904
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
913
- if facet_col : grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels )
905
+ if facet_col : grouped_data = self .facet_merge (grouped_data , self . meta , facet_col , facet_arg , facet_labels )
914
906
915
907
return grouped_data , palette_dict , facet_labels
916
908
@@ -933,13 +925,14 @@ def _internal_plot_anticipation_score(self, variable, facet_col, facet_arg, face
933
925
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
934
926
935
927
if facet_col :
936
- grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels )
928
+ grouped_data = self .facet_merge (grouped_data , self . meta , facet_col , facet_arg , facet_labels )
937
929
938
930
return grouped_data , palette_dict , facet_labels
939
931
940
932
def _internal_plot_decoder (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
941
933
t_bin , func , t_column , rm = False ):
942
934
""" contains the first part of the internal plotters for HMM quant plots """
935
+
943
936
labels , colours = self ._check_hmm_shape (hm = hmm , lab = labels , col = colours )
944
937
facet_arg , facet_labels , h_list , b_list = self ._check_lists_hmm (facet_col , facet_arg , facet_labels , hmm , t_bin )
945
938
@@ -954,23 +947,40 @@ def _internal_plot_decoder(self, hmm, variable, labels, colours, facet_col, face
954
947
# takes subselection of df that contains the specified facet columns
955
948
data = self .xmv (facet_col , facet_arg )
956
949
950
+ def hmm_list_facet (data , meta , facet_label , ind ):
951
+ d = data .copy (deep = True )
952
+ m = meta .copy (deep = True )
953
+
954
+ d .id = d .id + f'_{ ind } '
955
+ m .index = m .index + f'_{ ind } '
956
+ m ['HMM' ] = facet_label
957
+ return self .__class__ (d , m , check = True )
958
+
957
959
if facet_col is None : # decode the whole dataset
958
- decoded_data = self .__class__ (self ._hmm_decode (data , hmm , t_bin , variable , func , t_column , return_type = 'table' ), data .meta , check = True )
959
- else :
960
- if isinstance (hmm , list ) is False : # if only 1 hmm but is faceted, decode as whole for efficiency
961
- decoded_data = self .__class__ (self ._hmm_decode (data , hmm , t_bin , variable , func , t_column , return_type = 'table' ), data .meta , check = True )
960
+ if isinstance (hmm , list ):
961
+ decoded_data = concat (* [hmm_list_facet (self ._hmm_decode (data , h , b , variable , func , t_column , return_type = 'table' ), data .meta , f , c + 1 )
962
+ for c , (h , b , f ) in enumerate (zip (h_list , b_list , facet_labels ))])
963
+ facet_arg = facet_labels
964
+ facet_col = 'HMM'
962
965
else :
966
+ decoded_data = self .__class__ (self ._hmm_decode (data , hmm , t_bin , variable , func , t_column , return_type = 'table' ), data .meta , check = True )
967
+ else :
968
+ if isinstance (hmm , list ):
963
969
decoded_data = concat (* [self .__class__ (self ._hmm_decode (data .xmv (facet_col , arg ), h , b , variable , func , t_column , return_type = 'table' ),
964
970
data .meta , check = True ) for arg , h , b in zip (facet_arg , h_list , b_list )])
971
+ else : # if only 1 hmm but is faceted, decode as whole for efficiency
972
+ decoded_data = self .__class__ (self ._hmm_decode (data , hmm , t_bin , variable , func , t_column , return_type = 'table' ), data .meta , check = True )
973
+
965
974
966
- return decoded_data , labels , colours , facet_arg , facet_labels
975
+ return decoded_data , labels , colours , facet_col , facet_arg , facet_labels
967
976
968
977
def _internal_plot_hmm_quantify (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
969
978
t_bin , func , t_column ):
970
979
""" internal method to calculate the average amount of each state for use in plot_hmm_quantify, plotly and seaborn """
971
980
972
- decoded_data , labels , colours , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
973
- t_bin , func , t_column )
981
+ decoded_data , labels , colours , facet_col , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours ,
982
+ facet_col , facet_arg , facet_labels ,
983
+ t_bin , func , t_column )
974
984
975
985
# Count each state and find its fraction
976
986
grouped_data = decoded_data .groupby ([decoded_data .index , 'state' ], sort = False ).agg ({'bin' : 'count' })
@@ -981,21 +991,22 @@ def _internal_plot_hmm_quantify(self, hmm, variable, labels, colours, facet_col,
981
991
if facet_col :
982
992
palette = self ._get_colours (facet_labels )
983
993
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
984
- grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels , hmm_labels = labels )
994
+ grouped_data = self .facet_merge (grouped_data , decoded_data . meta , facet_col , facet_arg , facet_labels , hmm_labels = labels )
985
995
else :
986
996
palette = colours
987
997
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (labels )} # change to grey if control
988
998
hmm_dict = {k : v for k , v in zip (range (len (labels )), labels )}
989
999
grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
990
1000
991
- return grouped_data , labels , colours , facet_labels , palette_dict
1001
+ return grouped_data , labels , colours , facet_col , facet_labels , palette_dict
992
1002
993
1003
def _internal_plot_hmm_quantify_length (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
994
1004
t_bin , func , t_column ):
995
1005
""" internal method to calculate the average length of each state for use in plot_hmm_quantify_length, plotly and seaborn """
996
1006
997
- decoded_data , labels , colours , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
998
- t_bin , func , t_column )
1007
+ decoded_data , labels , colours , facet_col , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours ,
1008
+ facet_col , facet_arg , facet_labels ,
1009
+ t_bin , func , t_column )
999
1010
1000
1011
# get each specimens states time series to find lengths
1001
1012
states = decoded_data .groupby (decoded_data .index , sort = False )['state' ].apply (list )
@@ -1012,21 +1023,22 @@ def _internal_plot_hmm_quantify_length(self, hmm, variable, labels, colours, fac
1012
1023
if facet_col :
1013
1024
palette = self ._get_colours (facet_labels )
1014
1025
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
1015
- grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels , hmm_labels = labels )
1026
+ grouped_data = self .facet_merge (grouped_data , decoded_data . meta , facet_col , facet_arg , facet_labels , hmm_labels = labels )
1016
1027
else :
1017
1028
palette = colours
1018
1029
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (labels )} # change to grey if control
1019
1030
hmm_dict = {k : v for k , v in zip (range (len (labels )), labels )}
1020
1031
grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
1021
1032
1022
- return grouped_data , labels , colours , facet_labels , palette_dict
1033
+ return grouped_data , labels , colours , facet_col , facet_labels , palette_dict
1023
1034
1024
1035
def _internal_plot_hmm_quantify_length_min_max (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
1025
1036
t_bin , func , t_column ):
1026
1037
""" internal method to calculate the average length of each state for use in plot_hmm_quantify_length, plotly and seaborn """
1027
1038
1028
- decoded_data , labels , colours , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
1029
- t_bin , func , t_column , rm = True )
1039
+ decoded_data , labels , colours , facet_col , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours ,
1040
+ facet_col , facet_arg , facet_labels ,
1041
+ t_bin , func , t_column , rm = True )
1030
1042
1031
1043
# get each specimens states time series to find lengths
1032
1044
states = decoded_data .groupby (decoded_data .index , sort = False )['state' ].apply (list )
@@ -1043,21 +1055,22 @@ def _internal_plot_hmm_quantify_length_min_max(self, hmm, variable, labels, colo
1043
1055
if facet_col :
1044
1056
palette = self ._get_colours (facet_labels )
1045
1057
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
1046
- grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels , hmm_labels = labels )
1058
+ grouped_data = self .facet_merge (grouped_data , decoded_data . meta , facet_col , facet_arg , facet_labels , hmm_labels = labels )
1047
1059
else :
1048
1060
palette = colours
1049
1061
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (labels )} # change to grey if control
1050
1062
hmm_dict = {k : v for k , v in zip (range (len (labels )), labels )}
1051
1063
grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
1052
1064
1053
- return grouped_data , labels , colours , facet_labels , palette_dict
1065
+ return grouped_data , labels , colours , facet_col , facet_labels , palette_dict
1054
1066
1055
1067
def _internal_plot_hmm_quantify_transition (self , hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
1056
1068
t_bin , func , t_column ):
1057
1069
""" An internal method to find the % of transtions into a state occur per state per individual """
1058
1070
1059
- decoded_data , labels , colours , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours , facet_col , facet_arg , facet_labels ,
1060
- t_bin , func , t_column , rm = True )
1071
+ decoded_data , labels , colours , facet_col , facet_arg , facet_labels = self ._internal_plot_decoder (hmm , variable , labels , colours ,
1072
+ facet_col , facet_arg , facet_labels ,
1073
+ t_bin , func , t_column , rm = True )
1061
1074
1062
1075
# get each specimens states time series to find lengths
1063
1076
states = decoded_data .groupby (decoded_data .index , sort = False )['state' ].apply (list )
@@ -1074,11 +1087,11 @@ def _internal_plot_hmm_quantify_transition(self, hmm, variable, labels, colours,
1074
1087
if facet_col :
1075
1088
palette = self ._get_colours (facet_labels )
1076
1089
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (facet_labels )} # change to grey if control
1077
- grouped_data = self .facet_merge (grouped_data , facet_col , facet_arg , facet_labels , hmm_labels = labels )
1090
+ grouped_data = self .facet_merge (grouped_data , decoded_data . meta , facet_col , facet_arg , facet_labels , hmm_labels = labels )
1078
1091
else :
1079
1092
palette = colours
1080
1093
palette_dict = {name : self ._check_grey (name , palette [c ])[1 ] for c , name in enumerate (labels )} # change to grey if control
1081
1094
hmm_dict = {k : v for k , v in zip (range (len (labels )), labels )}
1082
1095
grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
1083
1096
1084
- return grouped_data , labels , colours , facet_labels , palette_dict
1097
+ return grouped_data , labels , colours , facet_col , facet_labels , palette_dict
0 commit comments