Skip to content

Commit 7e1bb2b

Browse files
committed
refactored and added plot_response_overtime to plotly and seaborn
1 parent 3de2d9f commit 7e1bb2b

File tree

6 files changed

+318
-776
lines changed

6 files changed

+318
-776
lines changed

src/ethoscopy/behavpy_core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _check_lists(self, f_col, f_arg, f_lab):
125125
string_args = []
126126
for i in f_arg:
127127
if i not in self.meta[f_col].tolist():
128-
print(self.meta[f_col].tolist())
128+
# print(set(self.meta[f_col].tolist()))
129129
raise KeyError(f'Argument "{i}" is not in the meta column {f_col}')
130130
string_args.append(str(i))
131131
if f_lab is None:

src/ethoscopy/behavpy_draw.py

+64-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ethoscopy.behavpy_core import behavpy_core
1616
from ethoscopy.misc.bootstrap_CI import bootstrap
17+
from ethoscopy.misc.static_functions import concat
1718

1819
class behavpy_draw(behavpy_core):
1920
"""
@@ -526,7 +527,6 @@ def _bouts_response(self, mov_df, hmm, variable, response_col, labels, colours,
526527
all_runs = []
527528
for m, t, ids in zip(st_gb, time_gb, st_gb.index):
528529
spec_run = self._find_runs(m, t, ids)
529-
530530
all_runs.append(spec_run)
531531
# take the arrays and make a dataframe for merging
532532
counted_df = pd.concat([pd.DataFrame(specimen) for specimen in all_runs])
@@ -604,4 +604,66 @@ def _internal_bout_activity(self, mov_df, activity, variable, response_col, face
604604
if isinstance(activity_choice[activity], int):
605605
grouped_data = grouped_data[grouped_data['previous_moving'] == activity_choice[activity]]
606606

607-
return grouped_data, h_order, palette_dict, activity_choice[activity]
607+
return grouped_data, h_order, palette_dict, activity_choice[activity]
608+
609+
def _internal_plot_response_overtime(self, t_bin_hours, response_col, interaction_id_col, facet_col, facet_arg, facet_labels, func, t_column):
    """ An internal method to curate and analyse the data for both plotly and seaborn versions of plot_response_overtime

    The response column is binned per specimen with bin_time and every specimen is tagged, in a new
    metadata column called 'new_facet', with '<facet label>-True Stimulus' or '<facet label>-Spon. Mov.'.
    When the interaction id column holds more than one distinct value, rows with id 1 (true stimulus)
    and id 2 (false stimulus) are binned separately and the false stimulus rows are re-indexed with an
    '_2' suffix so both sets can live in one dataframe.

    Args:
        t_bin_hours (int): The size of the time bins in hours.
        response_col (str): The name of the column containing the response per interaction.
        interaction_id_col (str): The name of the column containing the interaction type id, 1 or 2.
        facet_col (str | None): The metadata column to facet by.
        facet_arg (list | None): The values of facet_col to keep.
        facet_labels (list | None): The labels to display for each entry of facet_arg.
        func (str): The function name handed to bin_time when binning the response column.
        t_column (str): The name of the column containing the time data (in seconds).

    Returns:
        df: An object of the same class as self with the per specimen, per time bin columns
            'Response Rate' and 'count', and metadata containing 'new_facet'.
        h_order (list): The 'new_facet' values in plotting order.
        palette (list): One colour per entry of h_order, after passing through _check_grey.
    """

    bin_seconds = (60*60) * t_bin_hours
    # the aggregation below reads the column written by bin_time
    # NOTE(review): assumes bin_time names its output '<column>_<function>' - confirm against bin_time
    binned_col = f'{response_col}_{func}'
    data_summary = {
        "mean" : (binned_col, 'mean'),
        "count" : (binned_col, 'count')
    }

    facet_arg, facet_labels = self._check_lists(facet_col, facet_arg, facet_labels)
    faceted = bool(facet_col and facet_arg)

    # takes subset of data if requested
    if faceted:
        data = self.xmv(facet_col, facet_arg)
        # h_order is built from facet_labels, so the metadata values must be translated
        # to their labels or 'new_facet' will not match h_order when custom labels are given
        label_map = dict(zip(facet_arg, facet_labels))
    else:
        data = self.copy(deep=True)

    def _facet_names(meta, suffix):
        """ Build the 'new_facet' values for a metadata table and a stimulus type suffix """
        if faceted:
            return meta[facet_col].map(label_map) + '-' + suffix
        return '-' + suffix

    if len(set(data[interaction_id_col])) == 1: # if only one stimulus type in the dataset
        # get colours
        palette = self._get_colours(facet_labels)
        # find the average response per time bin per specimen
        # func (not a hard-coded 'mean') so the output column matches data_summary
        data = data.bin_time(response_col, bin_seconds, function = func, t_column = t_column)
        data.meta['new_facet'] = _facet_names(data.meta, 'True Stimulus')
        h_order = [f'{lab}-True Stimulus' for lab in facet_labels]

    else:
        # get colours and double them to change to grey later
        palette = [x for xs in [[col, col] for col in self._get_colours(facet_labels)] for x in xs]

        # filter into the two stimulus types and find the average per time bin per specimen
        data1 = self.__class__(data[data[interaction_id_col]==1].bin_time(response_col, bin_seconds, function = func, t_column = t_column), data.meta)
        data2 = data[data[interaction_id_col]==2].bin_time(response_col, bin_seconds, function = func, t_column = t_column)

        # change the id of the false stimuli so they do not clash with the true stimuli ids
        meta2 = data.meta.copy(deep=True)
        meta2['ref_id'] = meta2.index + '_2'
        map_dict = meta2['ref_id'].to_dict()
        meta2.rename(columns={'ref_id' : 'id'}, inplace=True)
        meta2.index = meta2['id']
        data2.index = data2.index.map(map_dict)

        data1.meta['new_facet'] = _facet_names(data1.meta, 'True Stimulus')
        meta2['new_facet'] = _facet_names(meta2, 'Spon. Mov.')
        h_order = [f'{lab}-{ty}' for lab in facet_labels for ty in ["Spon. Mov.", "True Stimulus"]]

        data = concat(data1, self.__class__(data2, meta2))

    palette = [self._check_grey(name, palette[c], response = True)[1] for c, name in enumerate(h_order)] # change to grey if control

    grouped_data = data.groupby([data.index, 't_bin']).agg(**data_summary).reset_index(level=1)
    df = self.__class__(grouped_data, data.meta)
    df.rename(columns={'mean' : 'Response Rate'}, inplace=True)

    return df, h_order, palette

src/ethoscopy/behavpy_plotly.py

+46-123
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def plot_overtime(self, variable:str, wrapped:bool = False, facet_col:None|str =
315315
for data, name, col in zip(d_list, facet_labels, col_list):
316316
upper, trace, lower, t_min, t_max = self._generate_overtime_plot(data = data, name = name, col = col, var = variable,
317317
avg_win = int((avg_window * 60)/self[t_column].diff().median()), wrap = wrapped, day_len = day_length,
318-
light_off= lights_off, t_col = t_column, canvas = 'plotly')
318+
light_off= lights_off, t_col = t_column)
319319
if upper is None:
320320
continue
321321

@@ -1264,138 +1264,61 @@ def plot_response_over_activity(self, mov_df, activity, variable = 'moving', res
12641264

12651265
return fig
12661266

1267-
def plot_response_overtime(self, bin_time = 1, wrapped = False, response_col = 'has_responded', int_id_col = 'has_interacted', facet_col = None, facet_arg = None, facet_labels = None, title = '', day_length = 24, lights_off = 12, secondary = True, t_column = 't', grids = False):
1268-
"""
1269-
A plotting function for AGO or mAGO datasets that have been loaded with the analysing function stimulus_response.
1270-
A plot to view the response rate to a puff over the time of day. Interactions will be binned to a users input (default is 1 hour) and plotted over a ZT hours x-axis. The plot can be the full length of an experiment or wrapped to a singular day.
1267+
def plot_response_overtime(self, t_bin_hours = 1, wrapped = False, response_col = 'has_responded', interaction_id_col = 'has_interacted', facet_col = None, facet_arg = None, facet_labels = None, day_length = 24, lights_off = 12, func = 'mean', t_column = 't', title = '', grids = False):
1268+
""" A plotting function for AGO or mAGO datasets that have been loaded with the analysing function stimulus_response.
1269+
Generate a plot which shows how the response rate changes over a day (wrapped) or the course of the experiment.
1270+
If false stimuli are given and represented in the interaction_id column, they will be plotted separately.
12711271
12721272
Args:
1273-
bin_time (int, optional): The number of hours you want to bin the response rate to, default is 1 (hour).
1273+
t_bin_hours (int, optional): The number of hours you want to bin the response rate to per specimen. Default is 1 (hour).
12741274
wrapped (bool, optional): If true the data is augmented to represent one day, combining data of the same time on consecutive days.
1275-
num_dtick (int, optional): The dtick for the x-axis (the number spacing) for when plot_type 'number is chosen. Default is 10.
1276-
response_col (str, optional): The name of the column that contains the boolean response data.
1277-
int_id_col (str, optional): The name of the column conataining the id for the interaction type, which should be either 1 (true interaction) or 2 (false interaction). Default 'has_interacted'.
1275+
response_col (str, optional): The name of the column that has the responses per interaction. Must be a column of bools. Default is 'has_responded'.
1276+
interaction_id_col (str, optional): The name of the column containing the id for the interaction type, which should be either 1 (true interaction) or 2 (false interaction). Default 'has_interacted'.
12781277
facet_col (str, optional): The name of the column to use for faceting, must be from the metadata. Default is None.
12791278
facet_arg (list, optional): The arguments to use for faceting. If None then all distinct groups will be used. Default is None.
12801279
facet_labels (list, optional): The labels to use for faceting, these will be what appear on the plot. If None the labels will be those from the metadata. Default is None.
1281-
title (str, optional): The title of the plot. Default is an empty string.
12821280
day_length (int, optional): The length in hours the experimental day is. Default is 24.
12831281
lights_off (int, optional): The time point when the lights are turned off in an experimental day, assuming 0 is lights on. Must be number between 0 and day_length. Default is 12.
1284-
sceondary (bool, optional): If true then a secondary y-axis is added that contains either the puff cound for 'time' or percentage of flies recieving the puff in 'number'. Default is True
1285-
t_column (str, optional): The name of column containing the timing data (in seconds). Default is 't'
1282+
func (str, optional): When binning the time what function to apply to the variable column. Default is 'mean'.
1283+
t_column (str, optional): The name of column containing the timing data (in seconds). Default is 't'.
1284+
title (str, optional): The title of the plot. Default is an empty string.
12861285
grids (bool, optional): true/false whether the resulting figure should have grids. Default is False
1287-
1286+
12881287
Returns:
12891288
fig (plotly.figure.Figure): Figure object of the plot.
1290-
"""
1291-
1292-
facet_arg, facet_labels = self._check_lists(facet_col, facet_arg, facet_labels)
1293-
1294-
if facet_col is not None:
1295-
d_list = [self.xmv(facet_col, arg) for arg in facet_arg]
1296-
else:
1297-
d_list = [self.copy(deep = True)]
1298-
facet_labels = ['']
1299-
1300-
fig = make_subplots(specs=[[{ "secondary_y" : True}]])
1301-
1302-
max_var = []
1303-
y_range, dtick = self._check_boolean(list(self[response_col]))
1304-
if y_range is not False:
1305-
max_var.append(1)
1306-
1307-
if secondary is False:
1308-
fig = go.Figure()
1309-
else:
1310-
fig = make_subplots(specs=[[{ "secondary_y" : True}]])
1311-
self._plot_ylayout(fig, yrange = False, t0 = 0, dtick = False, ylabel = 'Puff Count', title = title, secondary = True, xdomain = 'x1', grid = grids)
1312-
1313-
self._plot_ylayout(fig, yrange = y_range, t0 = 0, dtick = dtick, ylabel = 'Response Rate', title = title, secondary = False, grid = grids)
1314-
1315-
col_list = self._get_colours(d_list)
1316-
self._plot_xlayout(fig, xrange = False, t0 = 0, dtick = day_length/4, xlabel = 'ZT (Hours)')
1317-
1318-
def get_hourly_response(data, time_window_length):
1319-
data['bin_time'] = data[t_column].map(lambda t: time_window_length * floor(t / time_window_length))
1320-
gb = data.groupby(['bin_time', 'has_interacted']).agg(**{
1321-
'response_rate' : (response_col, 'mean'),
1322-
'puff_count' : (response_col, 'count')
1323-
1324-
})
1325-
return gb
1326-
1327-
max_x = []
1328-
min_t = []
1329-
max_t = []
1330-
1331-
for data, name, col in zip(d_list, facet_labels, col_list):
1332-
1333-
if len(data) == 0:
1334-
print(f'Group {name} has no values and cannot be plotted')
1335-
continue
1336-
1337-
if wrapped is True:
1338-
data[t_column] = data[t_column] % (60*60*day_length)
1339-
data[t_column] = data[t_column] / (60*60)
1340-
1341-
min_t.append(int(lights_off * floor(data[t_column].min() / lights_off)))
1342-
max_t.append(int(12 * ceil(data[t_column].max() / 12)) )
1343-
1344-
if len(list(set(data.has_interacted))) == 1:
1345-
loop_itr = list(set(data.has_interacted))
1346-
else:
1347-
loop_itr = [2, 1]
1348-
1349-
for q in loop_itr:
1350-
1351-
if q == 1:
1352-
qcol = col
1353-
lab = name
1354-
elif q == 2:
1355-
qcol = 'grey'
1356-
lab = f'{name} Spon. Mov'
1357-
1358-
tdf = data[data[int_id_col] == q].reset_index()
1359-
rdf = tdf.groupby('id', group_keys = False).apply(partial(get_hourly_response, time_window_length = bin_time))
1360-
1361-
filt_gb = rdf.groupby('bin_time').agg(**{
1362-
'mean' : ('response_rate', 'mean'),
1363-
'count' : ('puff_count', 'sum'),
1364-
'ci' : ('response_rate', bootstrap)
1365-
})
1366-
filt_gb[['y_max', 'y_min']] = pd.DataFrame(filt_gb['ci'].tolist(), index = filt_gb.index)
1367-
filt_gb.drop('ci', axis = 1, inplace = True)
1368-
filt_gb.reset_index(inplace = True)
1369-
1370-
max_x.append(np.nanmax(filt_gb['bin_time']))
1371-
1372-
upper, trace, lower = self._plot_line(df = filt_gb, x_col = 'bin_time', name = lab, marker_col = qcol)
1373-
fig.add_trace(upper)
1374-
fig.add_trace(trace)
1375-
fig.add_trace(lower)
1376-
1377-
if secondary is True:
1378-
fig.add_trace(
1379-
go.Scatter(
1380-
legendgroup = lab,
1381-
x = filt_gb['bin_time'],
1382-
y = filt_gb['count'],
1383-
mode = 'lines',
1384-
name = f'{lab} count',
1385-
line = dict(
1386-
dash = 'longdashdot',
1387-
shape = 'spline',
1388-
color = qcol
1389-
),
1390-
),
1391-
secondary_y = True
1392-
)
1393-
# Light-Dark annotaion bars
1394-
bar_shapes, min_bar = circadian_bars(np.nanmin(min_t), np.nanmax(max_t), max_y = np.nanmax(max_var), day_length = day_length, lights_off = lights_off)
1395-
fig.update_layout(shapes=list(bar_shapes.values()))
1396-
fig['layout']['xaxis']['range'] = [1, np.nanmax(max_t)]
1397-
1398-
return fig
1289+
1290+
Notes:
1291+
This function must be called on a behavpy dataframe that is populated with data loaded with the stimulus_response
1292+
analysing function. It should contain columns such as 'has_responded' and 'has_interacted'.
1293+
"""
1294+
df, h_order, palette = self._internal_plot_response_overtime(t_bin_hours=t_bin_hours, response_col=response_col, interaction_id_col=interaction_id_col,
1295+
facet_col=facet_col, facet_arg=facet_arg, facet_labels=facet_labels, func=func, t_column=t_column)
1296+
1297+
return df.plot_overtime(variable='Response Rate', wrapped=wrapped, facet_col='new_facet', facet_arg=h_order, facet_labels=h_order,
1298+
avg_window=5, day_length=day_length, lights_off=lights_off, title=title, grids=grids, t_column='t_bin',
1299+
col_list = palette)
1300+
1301+
# Possibly add the puff count on the secondary axis in the future
1302+
1303+
# fig = make_subplots(specs=[[{ "secondary_y" : True}]])
1304+
# self._plot_ylayout(fig, yrange = False, t0 = 0, dtick = False, ylabel = 'Puff Count', title = title, secondary = True, xdomain = 'x1', grid = grids)
1305+
1306+
# if secondary is True:
1307+
# fig.add_trace(
1308+
# go.Scatter(
1309+
# legendgroup = lab,
1310+
# x = filt_gb['bin_time'],
1311+
# y = filt_gb['count'],
1312+
# mode = 'lines',
1313+
# name = f'{lab} count',
1314+
# line = dict(
1315+
# dash = 'longdashdot',
1316+
# shape = 'spline',
1317+
# color = qcol
1318+
# ),
1319+
# ),
1320+
# secondary_y = True
1321+
# )
13991322

14001323
# Ploty Periodograms
14011324

@@ -1579,7 +1502,7 @@ def plot_periodogram(self, facet_col = None, facet_arg = None, facet_labels = No
15791502
if 'baseline' in name.lower() or 'control' in name.lower() or 'ctrl' in name.lower():
15801503
col = 'grey'
15811504

1582-
upper, trace, lower, _, _ = self._generate_overtime_plot(data = data, name = name, col = col, var = power_var, avg_win = False, wrap = False, day_len = False, light_off = False, canvas = 'plotly', t_col = period_var)
1505+
upper, trace, lower, _, _ = self._generate_overtime_plot(data = data, name = name, col = col, var = power_var, avg_win = False, wrap = False, day_len = False, light_off = False, t_col = period_var)
15831506
fig.add_trace(upper)
15841507
fig.add_trace(trace)
15851508
fig.add_trace(lower)

0 commit comments

Comments
 (0)