Skip to content

Commit 3de2d9f

Browse files
committed
Refactored plot_response_over_activity for plotly and added it to seaborn
1 parent cdb2916 commit 3de2d9f

File tree

4 files changed

+919
-198
lines changed

4 files changed

+919
-198
lines changed

src/ethoscopy/behavpy_draw.py

+70-21
Original file line number | Diff line number | Diff line change
@@ -433,7 +433,7 @@ def _ap_score(total, small):
433433

434434
return ant_df
435435

436-
def hmm_response(self, mov_df, hmm, variable, response_col, labels, colours, facet_col, facet_arg, t_bin, facet_labels, func, t_column):
436+
def _hmm_response(self, mov_df, hmm, variable, response_col, labels, colours, facet_col, facet_arg, t_bin, facet_labels, func, t_column):
437437

438438
data_summary = {
439439
"%s_mean" % response_col : (response_col, 'mean'),
@@ -502,57 +502,106 @@ def alter_merge(response, mov, tb):
502502

503503
return grouped_data, palette_dict, h_order
504504

505-
def hmm_bouts_response(self, mov_df, hmm, variable, response_col, labels, colours, x_limit, t_bin, func, t_column):
505+
def _bouts_response(self, mov_df, hmm, variable, response_col, labels, colours, x_limit, t_bin, func, t_col):
506506

507507
data_summary = {
508508
"mean" : (response_col, 'mean'),
509509
"count" : (response_col, 'count'),
510510
"ci" : (response_col, bootstrap),
511511
}
512-
513-
# copy and decode the dataset
514512
data = self.copy(deep=True)
515-
mdata = mov_df
516-
mdata = self.__class__(self._hmm_decode(mdata, hmm, t_bin, variable, func, t_column, return_type='table'), mdata.meta, check=True)
513+
mdata = mov_df.copy(deep=True)
514+
515+
if hmm is not False:
516+
# copy and decode the dataset
517+
mdata = self.__class__(self._hmm_decode(mdata, hmm, t_bin, variable, func, t_col, return_type='table'), mdata.meta, check=True)
518+
var, newT, m_var_1, m_var_2 = 'state', 'bin', 'moving', 'previous_moving'
519+
else:
520+
mdata = mdata.bin_time(variable, t_bin, function = func, t_column = t_col)
521+
var, newT, m_var_1, m_var_2 = f'{variable}_{func}', f'{t_col}_bin', 'activity_count', 'previous_activity_count'
517522

518523
# take the states and time per specimen and find the runs of states
519-
st_gb = mdata.groupby('id')['state'].apply(np.array)
520-
time_gb = mdata.groupby('id')['bin'].apply(np.array)
524+
st_gb = mdata.groupby('id')[var].apply(np.array)
525+
time_gb = mdata.groupby('id')[newT].apply(np.array)
521526
all_runs = []
522527
for m, t, ids in zip(st_gb, time_gb, st_gb.index):
523528
spec_run = self._find_runs(m, t, ids)
529+
524530
all_runs.append(spec_run)
525531
# take the arrays and make a dataframe for merging
526532
counted_df = pd.concat([pd.DataFrame(specimen) for specimen in all_runs])
527-
# _find_runs returns the column of interest as 'moving', so changing them for better clarity
528-
counted_df.rename(columns = {'moving' : 'state', 'previous_moving' : 'previous_state'}, inplace = True)
529533

530534
# change the time column to reflect the timing of counted_df
531-
data['t'] = data['interaction_t'].map(lambda t: t_bin * floor(t / t_bin))
535+
data[t_col] = data['interaction_t'].map(lambda t: t_bin * floor(t / t_bin))
532536
data.reset_index(inplace = True)
533537

534538
# merge the two dataframes on the id and time column and check the response is in the same time bin or the next
535-
merged = pd.merge(counted_df, data, how = 'inner', on = ['id', 't'])
539+
merged = pd.merge(counted_df, data, how = 'inner', on = ['id', t_col])
536540
merged['t_check'] = merged.interaction_t + merged.t_rel
537541
merged['t_check'] = merged['t_check'].map(lambda t: t_bin * floor(t / t_bin))
538-
merged['previous_state'] = np.where(merged['t_check'] > merged['t'], merged['state'], merged['previous_state'])
542+
# change both previous if the interaction to stimulus happens in the next time bin
543+
merged['previous_activity_count'] = np.where(merged['t_check'] > merged[t_col], merged['activity_count'], merged['previous_activity_count'])
544+
merged['previous_moving'] = np.where(merged['t_check'] > merged[t_col], merged['moving'], merged['previous_moving'])
539545
merged = merged[merged['previous_activity_count'] <= x_limit]
540-
541-
grouped_data = merged.groupby(['previous_state', 'previous_activity_count', 'has_interacted']).agg(**data_summary)
546+
merged.dropna(subset = ['previous_moving', 'previous_activity_count'], inplace=True)
547+
merged['previous_activity_count'] = merged['previous_activity_count'].astype(int)
548+
# groupby the columns of interest, and find the mean and bootstrapped 95% CIs
549+
grouped_data = merged.groupby(['previous_moving', 'previous_activity_count', 'has_interacted']).agg(**data_summary)
542550
grouped_data = grouped_data.reset_index()
543551
grouped_data[['y_max', 'y_min']] = pd.DataFrame(grouped_data['ci'].tolist(), index = grouped_data.index)
544552
grouped_data.drop('ci', axis = 1, inplace = True)
545-
grouped_data['state'] = grouped_data['previous_state']
546-
553+
grouped_data['moving'] = grouped_data['previous_moving']
547554
map_dict = {1 : 'True Stimulus', 2 : 'Spon. Mov.'}
548555
grouped_data['has_interacted'] = grouped_data['has_interacted'].map(map_dict)
556+
557+
if hmm is False:
558+
grouped_data['facet_col'] = [labels] * len(grouped_data)
559+
return grouped_data
560+
549561
hmm_dict = {k : v for k, v in zip(range(len(labels)), labels)}
550562
grouped_data['state'] = grouped_data['state'].map(hmm_dict)
551563
grouped_data['label_col'] = grouped_data['state'] + " " + grouped_data['has_interacted']
552-
564+
# create the order of plotting and double the colours to assign grey to false stimuli
553565
h_order = [f'{lab} {ty}' for lab in labels for ty in ["Spon. Mov.", "True Stimulus"]]
554-
palette = colours
555-
palette = [x for xs in [[col, col] for col in palette] for x in xs]
566+
palette = [x for xs in [[col, col] for col in colours] for x in xs]
556567
palette_dict = {name : self._check_grey(name, palette[c], response = True)[1] for c, name in enumerate(h_order)} # change to grey if control
557568

558-
return grouped_data, palette_dict, h_order
569+
return grouped_data, palette_dict, h_order
570+
571+
def _internal_bout_activity(self, mov_df, activity, variable, response_col, facet_col, facet_arg, facet_labels, x_limit, t_bin, t_column):
572+
""" The beginning code for plot_response_over_activity for both plotly and seaborn """
573+
574+
facet_arg, facet_labels = self._check_lists(facet_col, facet_arg, facet_labels)
575+
576+
activity_choice = {'inactive' : 0, 'active' : 1, 'both' : (0, 1)}
577+
if activity not in activity_choice.keys():
578+
raise KeyError(f'activity argument must be one of {*activity_choice.keys(),}')
579+
if activity == 'both' and facet_col is not None:
580+
print('When plotting both inactive and active runs you can not use facet_col. Reverted to None')
581+
facet_col, facet_arg, facet_labels = None, [None], ['inactive', 'active']
582+
583+
if facet_col and facet_arg:
584+
rdata = self.xmv(facet_col, facet_arg)
585+
# iterate over the filters and call the analysing function
586+
dfs = [rdata._bouts_response(mov_df=mov_df.xmv(facet_col, arg), hmm = False,
587+
variable=variable, response_col=response_col, labels=lab, colours=[],
588+
x_limit=x_limit, t_bin=t_bin, func='max', t_col=t_column) for arg, lab in zip(facet_arg, facet_labels)]
589+
grouped_data = pd.concat(dfs)
590+
else:
591+
grouped_data = self._bouts_response(mov_df=mov_df, hmm = False,
592+
variable=variable, response_col=response_col, labels=[], colours=[],
593+
x_limit=x_limit, t_bin=t_bin, func='max', t_col=t_column)
594+
inverse_dict = {v: k for k, v in activity_choice.items()}
595+
grouped_data['facet_col'] = grouped_data['previous_moving'].map(inverse_dict)
596+
597+
# Get colours and labels, syncing them together and replacing False Stimuli with a grey colour
598+
grouped_data['label_col'] = grouped_data['facet_col'] + " " + grouped_data['has_interacted']
599+
palette = [x for xs in [[col, col] for col in self._get_colours(facet_labels)] for x in xs]
600+
h_order = [f'{lab} {ty}' for lab in facet_labels for ty in ["Spon. Mov.", "True Stimulus"]]
601+
palette_dict = {name : self._check_grey(name, palette[c], response = True)[1] for c, name in enumerate(h_order)} # change to grey if control
602+
603+
# If not both filter the dataset
604+
if isinstance(activity_choice[activity], int):
605+
grouped_data = grouped_data[grouped_data['previous_moving'] == activity_choice[activity]]
606+
607+
return grouped_data, h_order, palette_dict, activity_choice[activity]

0 commit comments

Comments
 (0)