7
7
from colour import Color
8
8
from math import sqrt , floor , ceil
9
9
from scipy .stats import zscore
10
+ from functools import partial
10
11
11
12
#fig to img
12
13
import io
@@ -560,9 +561,9 @@ def _bouts_response(self, mov_df, hmm, variable, response_col, labels, colours,
560
561
561
562
hmm_dict = {k : v for k , v in zip (range (len (labels )), labels )}
562
563
grouped_data ['state' ] = grouped_data ['state' ].map (hmm_dict )
563
- grouped_data ['label_col' ] = grouped_data ['state' ] + " " + grouped_data ['has_interacted' ]
564
+ grouped_data ['label_col' ] = grouped_data ['state' ] + "- " + grouped_data ['has_interacted' ]
564
565
# create the order of plotting and double the colours to assign grey to false stimuli
565
- h_order = [f'{ lab } { ty } ' for lab in labels for ty in ["Spon. Mov." , "True Stimulus" ]]
566
+ h_order = [f'{ lab } - { ty } ' for lab in labels for ty in ["Spon. Mov." , "True Stimulus" ]]
566
567
palette = [x for xs in [[col , col ] for col in colours ] for x in xs ]
567
568
palette_dict = {name : self ._check_grey (name , palette [c ], response = True )[1 ] for c , name in enumerate (h_order )} # change to grey if control
568
569
@@ -619,24 +620,25 @@ def _internal_plot_response_overtime(self, t_bin_hours, response_col, interactio
619
620
# takes subset of data if requested
620
621
if facet_col and facet_arg :
621
622
data = self .xmv (facet_col , facet_arg )
622
- h_order = [f'{ lab } -{ ty } ' for lab in facet_labels for ty in ["Spon. Mov." , "True Stimulus" ]]
623
623
else :
624
624
data = self .copy (deep = True )
625
625
626
626
if len (set (data [interaction_id_col ])) == 1 : # if only stimulus type in the dataset
627
627
# get colours
628
628
palette = self ._get_colours (facet_labels )
629
+ h_order = [f'{ lab } -{ ty } ' for lab in facet_labels for ty in ["True Stimulus" ]]
630
+
629
631
# find the average response per hour per specimen
630
632
data = data .bin_time (response_col , (60 * 60 ) * t_bin_hours , function = 'mean' , t_column = t_column )
631
633
if facet_col and facet_arg :
632
634
data .meta ['new_facet' ] = data .meta [facet_col ] + '-' + 'True Stimulus'
633
635
else :
634
636
data .meta ['new_facet' ] = '-True Stimulus'
635
- h_order = [f'{ lab } -{ ty } ' for lab in facet_labels for ty in ["True Stimulus" ]]
636
637
637
638
else :
638
639
# get colours and double them to change to grey later
639
640
palette = [x for xs in [[col , col ] for col in self ._get_colours (facet_labels )] for x in xs ]
641
+ h_order = [f'{ lab } -{ ty } ' for lab in facet_labels for ty in ["Spon. Mov." , "True Stimulus" ]]
640
642
641
643
# filter into two stimulus and find average per hour per specimen
642
644
data1 = self .__class__ (data [data [interaction_id_col ]== 1 ].bin_time (response_col , (60 * 60 ) * t_bin_hours , function = func , t_column = t_column ), data .meta )
@@ -656,7 +658,6 @@ def _internal_plot_response_overtime(self, t_bin_hours, response_col, interactio
656
658
else :
657
659
data1 .meta ['new_facet' ] = '-True Stimulus'
658
660
meta2 ['new_facet' ] = '-Spon. Mov.'
659
- h_order = [f'{ lab } -{ ty } ' for lab in facet_labels for ty in ["Spon. Mov." , "True Stimulus" ]]
660
661
661
662
data = concat (data1 , self .__class__ (data2 , meta2 ))
662
663
@@ -666,4 +667,77 @@ def _internal_plot_response_overtime(self, t_bin_hours, response_col, interactio
666
667
df = self .__class__ (grouped_data , data .meta )
667
668
df .rename (columns = {'mean' : 'Response Rate' }, inplace = True )
668
669
669
- return df , h_order , palette
670
+ return df , h_order , palette
671
+
672
def _internal_plot_habituation(self, plot_type, t_bin_hours, response_col, interaction_id_col, facet_col, facet_arg, facet_labels, x_limit, t_column):
    """ An internal method to curate and analyse the data for both plotly and seaborn versions of plot_habituation

        Args:
            plot_type (str): 'time' bins each specimen's responses into windows of `t_bin_hours` hours counted
                from that group's earliest window; 'number' ranks each response by its order in `t_column`.
            t_bin_hours (int): width of the time window in hours (only used when plot_type is 'time').
            response_col (str): name of the column holding the response values that are averaged.
            interaction_id_col (str): name of the column holding the stimulus type; values 1 and 2 are
                renamed to 'True Stimulus' and 'Spon. Mov.' respectively.
            facet_col (str | None): metadata column to facet by, falsy for no faceting.
            facet_arg (list | None): values of `facet_col` to keep.
            facet_labels (list | None): display labels for the facets.
            x_limit (int | float | False): maximum x value to keep; when False the maximum x value found
                for the true stimulus (interaction id 1) is used.
            t_column (str): name of the column holding the timestamps (the 'time' branch treats them as seconds).

        Returns:
            tuple: (summary dataframe indexed by the facet column, plotting order of the facet names,
                dict of facet name to colour, the x limit applied, the x axis column name)

        Raises:
            KeyError: if `plot_type` is not one of 'time' or 'number'.
    """

    facet_arg, facet_labels = self._check_lists(facet_col, facet_arg, facet_labels)

    plot_choice = {'time' : f'Hours {t_bin_hours} post first stimulus', 'number' : 'Stimulus number post first'}

    if plot_type not in plot_choice:
        raise KeyError(f'plot_type argument must be one of {tuple(plot_choice)}')

    # name of the x axis column, resolved once and shared with the nested helper below
    x_col = plot_choice[plot_type]

    data_summary = {
        "mean" : (response_col, 'mean'),
        "count" : (response_col, 'count'),
        'ci' : (response_col, bootstrap),
        "stim_count" : ('stim_count', 'sum')
    }
    map_dict = {1 : 'True Stimulus', 2 : 'Spon. Mov.'}

    # takes subset of data if requested
    if facet_col and facet_arg:
        data = self.xmv(facet_col, facet_arg)
    else:
        data = self.copy(deep = True)

    def get_response(int_data, ptype, time_window_length, resp_col, t_col):
        """ Summarise one specimen / stimulus type group, returned indexed by the x axis column """
        # bin the responses per amount of hours given and find the mean per specimen
        if ptype == 'time':
            hour_secs = time_window_length * 60 * 60
            # work on a copy, pandas does not support mutating the group handed to apply
            int_data = int_data.copy()
            binned = int_data[t_col].map(lambda t: hour_secs * floor(t / hour_secs))
            # shift so the first window of the group is 0, then convert seconds to window units
            int_data[x_col] = (binned - binned.min()) / hour_secs
            # the output keeps the name of the response column so the later summary can find it
            return int_data.groupby(x_col).agg(**{
                resp_col : (resp_col, 'mean'),
                'stim_count' : (resp_col, 'count')
            })
        # Sort the responses by time, assign int according to place in the list, return as dataframe
        elif ptype == 'number':
            int_data = int_data.sort_values(t_col)
            n_rows = len(int_data)
            return pd.DataFrame(data = {
                resp_col : int_data[resp_col].tolist(),
                x_col : list(range(1, n_rows + 1)),
                'stim_count' : [1] * n_rows
            }).set_index(x_col)

    grouped_data = data.groupby([data.index, interaction_id_col]).apply(partial(get_response, ptype = plot_type, time_window_length = t_bin_hours,
                                                    resp_col = response_col, t_col = t_column), include_groups = False)
    # NOTE(review): assumes the index of the data is named 'id' - confirm against the class definition
    grouped_data = self.__class__(grouped_data.reset_index().set_index('id'), data.meta, check = True)

    # reduce dataset to the maximum value of the True stimulus (reduces computation time)
    if x_limit is False:
        x_max = np.nanmax(grouped_data[grouped_data[interaction_id_col] == 1][x_col])
    else:
        x_max = x_limit

    grouped_data = grouped_data[grouped_data[x_col] <= x_max]
    # map stim names and create column to facet by
    grouped_data[interaction_id_col] = grouped_data[interaction_id_col].map(map_dict)
    if facet_col:
        grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels)
        grouped_data[facet_col] = grouped_data[facet_col].astype(str) + "-" + grouped_data[interaction_id_col]
    else:
        facet_col = 'stim_type'
        grouped_data[facet_col] = "-" + grouped_data[interaction_id_col]

    grouped_final = grouped_data.groupby([facet_col, x_col]).agg(**data_summary).reset_index(level = 1)
    # split the pair returned by bootstrap into two columns
    # NOTE(review): presumably bootstrap returns (upper, lower) in that order - confirm against its definition
    grouped_final[['y_max', 'y_min']] = pd.DataFrame(grouped_final['ci'].tolist(), index = grouped_final.index)
    grouped_final.drop('ci', axis = 1, inplace = True)

    # double the colours so each facet has an entry for both stimulus types
    palette = [x for xs in [[col, col] for col in self._get_colours(facet_labels)] for x in xs]
    h_order = [f'{lab}-{ty}' for lab in facet_labels for ty in ["Spon. Mov.", "True Stimulus"]]
    palette_dict = {name : self._check_grey(name, palette[c], response = True)[1] for c, name in enumerate(h_order)} # change to grey if control

    return grouped_final, h_order, palette_dict, x_max, x_col
0 commit comments