Skip to content

Commit 7a3352c

Browse files
committed
refactored and added plot_habituation to seaborn and plotly
1 parent 7e1bb2b commit 7a3352c

File tree

4 files changed

+3685
-302
lines changed

4 files changed

+3685
-302
lines changed

src/ethoscopy/behavpy_draw.py

+80-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from colour import Color
88
from math import sqrt, floor, ceil
99
from scipy.stats import zscore
10+
from functools import partial
1011

1112
#fig to img
1213
import io
@@ -560,9 +561,9 @@ def _bouts_response(self, mov_df, hmm, variable, response_col, labels, colours,
560561

561562
hmm_dict = {k : v for k, v in zip(range(len(labels)), labels)}
562563
grouped_data['state'] = grouped_data['state'].map(hmm_dict)
563-
grouped_data['label_col'] = grouped_data['state'] + " " + grouped_data['has_interacted']
564+
grouped_data['label_col'] = grouped_data['state'] + "-" + grouped_data['has_interacted']
564565
# create the order of plotting and double the colours to assign grey to false stimuli
565-
h_order = [f'{lab} {ty}' for lab in labels for ty in ["Spon. Mov.", "True Stimulus"]]
566+
h_order = [f'{lab}-{ty}' for lab in labels for ty in ["Spon. Mov.", "True Stimulus"]]
566567
palette = [x for xs in [[col, col] for col in colours] for x in xs]
567568
palette_dict = {name : self._check_grey(name, palette[c], response = True)[1] for c, name in enumerate(h_order)} # change to grey if control
568569

@@ -619,24 +620,25 @@ def _internal_plot_response_overtime(self, t_bin_hours, response_col, interactio
619620
# takes subset of data if requested
620621
if facet_col and facet_arg:
621622
data = self.xmv(facet_col, facet_arg)
622-
h_order = [f'{lab}-{ty}' for lab in facet_labels for ty in ["Spon. Mov.", "True Stimulus"]]
623623
else:
624624
data = self.copy(deep=True)
625625

626626
if len(set(data[interaction_id_col])) == 1: # if only stimulus type in the dataset
627627
# get colours
628628
palette = self._get_colours(facet_labels)
629+
h_order = [f'{lab}-{ty}' for lab in facet_labels for ty in ["True Stimulus"]]
630+
629631
# find the average response per hour per specimen
630632
data = data.bin_time(response_col, (60*60) * t_bin_hours, function = 'mean', t_column = t_column)
631633
if facet_col and facet_arg:
632634
data.meta['new_facet'] = data.meta[facet_col] + '-' + 'True Stimulus'
633635
else:
634636
data.meta['new_facet'] = '-True Stimulus'
635-
h_order = [f'{lab}-{ty}' for lab in facet_labels for ty in ["True Stimulus"]]
636637

637638
else:
638639
# get colours and double them to change to grey later
639640
palette = [x for xs in [[col, col] for col in self._get_colours(facet_labels)] for x in xs]
641+
h_order = [f'{lab}-{ty}' for lab in facet_labels for ty in ["Spon. Mov.", "True Stimulus"]]
640642

641643
# filter into two stimulus and find average per hour per specimen
642644
data1 = self.__class__(data[data[interaction_id_col]==1].bin_time(response_col, (60*60) * t_bin_hours, function = func, t_column = t_column), data.meta)
@@ -656,7 +658,6 @@ def _internal_plot_response_overtime(self, t_bin_hours, response_col, interactio
656658
else:
657659
data1.meta['new_facet'] = '-True Stimulus'
658660
meta2['new_facet'] = '-Spon. Mov.'
659-
h_order = [f'{lab}-{ty}' for lab in facet_labels for ty in ["Spon. Mov.", "True Stimulus"]]
660661

661662
data = concat(data1, self.__class__(data2, meta2))
662663

@@ -666,4 +667,77 @@ def _internal_plot_response_overtime(self, t_bin_hours, response_col, interactio
666667
df = self.__class__(grouped_data, data.meta)
667668
df.rename(columns={'mean' : 'Response Rate'}, inplace=True)
668669

669-
return df, h_order, palette
670+
return df, h_order, palette
671+
672+
def _internal_plot_habituation(self, plot_type, t_bin_hours, response_col, interaction_id_col, facet_col, facet_arg, facet_labels, x_limit, t_column):
673+
""" An internal method to curate and analyse the data for both plotly and seaborn versions of plot_habituation """
674+
675+
facet_arg, facet_labels = self._check_lists(facet_col, facet_arg, facet_labels)
676+
677+
plot_choice = {'time' : f'Hours {t_bin_hours} post first stimulus', 'number' : 'Stimulus number post first'}
678+
679+
if plot_type not in plot_choice.keys():
680+
raise KeyError(f'activity argument must be one of {*plot_choice.keys(),}')
681+
682+
data_summary = {
683+
"mean" : (response_col, 'mean'),
684+
"count" : (response_col, 'count'),
685+
'ci' : (response_col, bootstrap),
686+
"stim_count" : ('stim_count', 'sum')
687+
}
688+
map_dict = {1 : 'True Stimulus', 2 : 'Spon. Mov.'}
689+
690+
# takes subset of data if requested
691+
if facet_col and facet_arg:
692+
data = self.xmv(facet_col, facet_arg)
693+
else:
694+
data = self.copy(deep=True)
695+
696+
def get_response(int_data, ptype, time_window_length, resp_col, t_col):
697+
# bin the responses per amount of hours given and find the mean per specimen
698+
if ptype == 'time':
699+
hour_secs = time_window_length * 60 * 60
700+
int_data[plot_choice[plot_type]] = int_data[t_col].map(lambda t: hour_secs * floor(t / hour_secs))
701+
min_hour = int_data[plot_choice[plot_type]].min()
702+
int_data[plot_choice[plot_type]] = (int_data[plot_choice[plot_type]] - min_hour) / hour_secs
703+
gb = int_data.groupby(plot_choice[plot_type]).agg(**{
704+
'has_responded' : (resp_col, 'mean'),
705+
'stim_count' : (resp_col, 'count')
706+
})
707+
return gb
708+
# Sort the responses by time, assign int according to place in the list, return as dataframe
709+
elif ptype == 'number':
710+
int_data = int_data.sort_values(t_col)
711+
int_data['n_stim'] = list(range(1, len(int_data)+1))
712+
return pd.DataFrame(data = {'has_responded' : int_data['has_responded'].tolist(), plot_choice[plot_type] : int_data['n_stim'].tolist(),
713+
'stim_count' : [1] * len(int_data)}).set_index(plot_choice[plot_type])
714+
715+
grouped_data = data.groupby([data.index, interaction_id_col]).apply(partial(get_response, ptype=plot_type, time_window_length=t_bin_hours,
716+
resp_col=response_col, t_col=t_column), include_groups=False)
717+
grouped_data = self.__class__(grouped_data.reset_index().set_index('id'), data.meta, check=True)
718+
719+
# reduce dataset to the maximum value of the True stimulus (reduces computation time)
720+
if x_limit is False:
721+
x_max = np.nanmax(grouped_data[grouped_data[interaction_id_col] == 1][plot_choice[plot_type]])
722+
else:
723+
x_max = x_limit
724+
725+
grouped_data = grouped_data[grouped_data[plot_choice[plot_type]] <= x_max]
726+
# map stim names and create column to facet by
727+
grouped_data[interaction_id_col] = grouped_data[interaction_id_col].map(map_dict)
728+
if facet_col:
729+
grouped_data = self.facet_merge(grouped_data, facet_col, facet_arg, facet_labels)
730+
grouped_data[facet_col] = grouped_data[facet_col].astype(str) + "-" + grouped_data[interaction_id_col]
731+
else:
732+
facet_col = 'stim_type'
733+
grouped_data[facet_col] = "-" + grouped_data[interaction_id_col]
734+
735+
grouped_final = grouped_data.groupby([facet_col, plot_choice[plot_type]]).agg(**data_summary).reset_index(level=1)
736+
grouped_final[['y_max', 'y_min']] = pd.DataFrame(grouped_final['ci'].tolist(), index = grouped_final.index)
737+
grouped_final.drop('ci', axis = 1, inplace = True)
738+
739+
palette = [x for xs in [[col, col] for col in self._get_colours(facet_labels)] for x in xs]
740+
h_order = [f'{lab}-{ty}' for lab in facet_labels for ty in ["Spon. Mov.", "True Stimulus"]]
741+
palette_dict = {name : self._check_grey(name, palette[c], response = True)[1] for c, name in enumerate(h_order)} # change to grey if control
742+
743+
return grouped_final, h_order, palette_dict, x_max, plot_choice[plot_type]

0 commit comments

Comments
 (0)