In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import pyabf
import seaborn as sns
import shelve
import sys

In [6]:
%matplotlib inline

In [3]:
def get_dif_dff(frame: int, roi: int, dff: np.ndarray):
    """
    Frame-to-frame change of the DF/F trace of one ROI.

    :param frame: index of the frame at which the difference starts
    :param roi: row index of the ROI in ``dff``
    :param dff: 2-D array indexed as ``dff[roi, frame]``
    :return: ``dff[roi, frame + 1] - dff[roi, frame]``, or ``np.nan`` when one of the
        two positions does not exist (e.g. ``frame`` is the last frame)

    """
    try:
        return dff[roi, frame + 1] - dff[roi, frame]
    except (IndexError, KeyError):
        # NumPy signals an out-of-range position with IndexError; KeyError is kept for
        # label-indexed containers. ``np.nan`` is a float constant, not a callable.
        return np.nan


def get_abs_df_dff(frame: int, roi: int, dff: np.ndarray):
    """
    Absolute frame-to-frame change of the DF/F trace of one ROI.

    :param frame: index of the frame at which the difference starts
    :param roi: row index of the ROI in ``dff``
    :param dff: 2-D array indexed as ``dff[roi, frame]``
    :return: ``abs(dff[roi, frame + 1] - dff[roi, frame])``, or ``np.nan`` when one of
        the two positions does not exist (e.g. ``frame`` is the last frame)

    """
    try:
        return abs(dff[roi, frame + 1] - dff[roi, frame])
    except (IndexError, KeyError):
        # NumPy signals an out-of-range position with IndexError; KeyError is kept for
        # label-indexed containers. ``np.nan`` is a float constant, not a callable.
        return np.nan


# Get behavior params


def get_bout_number(frame: int, df_bout: pd.DataFrame, df_frame: pd.DataFrame, fps_ci: float, fps_bh: float):
    """
    Bout number that dominates the behavior samples covered by one calcium-imaging frame.

    The imaging frame is converted to a window of behavior rows,
    ``int(frame / fps_ci * fps_bh)`` up to (excluding) ``int((frame + 1) / fps_ci * fps_bh)``,
    and ``df_frame.BoutNumber`` is inspected inside that window.

    :param frame: calcium-imaging frame index
    :param df_bout: not used here; kept so all bout helpers share the same signature
    :param df_frame: per-behavior-sample table with a ``BoutNumber`` column (NaN outside bouts)
    :param fps_ci: calcium-imaging frame rate
    :param fps_bh: behavior frame rate
    :return: the bout number (Python ``int``) that occupies the most samples of the window,
        or ``np.nan`` when the window contains no bout. When several bouts occupy the same
        number of samples, the highest bout number is returned.

    """
    start, end = int((frame / fps_ci) * fps_bh), int(((frame + 1) / fps_ci) * fps_bh)

    # Samples outside any bout are NaN and must not take part in the vote.
    in_bout = df_frame.BoutNumber[start:end].dropna()
    if in_bout.empty:
        return np.nan

    # Count samples per bout over *all* bouts in the window (not just the first two) and
    # break ties deterministically instead of relying on set iteration order.
    counts = in_bout.value_counts()
    dominant = counts.index[counts == counts.max()]

    # Callers test ``isinstance(bout, int)``, so a built-in int must be returned.
    return int(max(dominant))


def get_bout_type(frame: int, df_bout: pd.DataFrame, df_frame: pd.DataFrame, fps_ci: float, fps_bh: float):
    """
    Category of the bout that overlaps a calcium-imaging frame.

    :param frame: calcium-imaging frame index
    :param df_bout: per-bout table; its ``Cat`` column is read positionally
    :param df_frame: per-behavior-sample table, forwarded to ``get_bout_number``
    :param fps_ci: calcium-imaging frame rate
    :param fps_bh: behavior frame rate
    :return: ``df_bout.Cat`` at the position of the bout, or ``np.nan`` when
        ``get_bout_number`` does not return an ``int``

    """
    bout_idx = get_bout_number(frame, df_bout, df_frame, fps_ci, fps_bh)

    # Anything that is not a built-in int (i.e. the NaN marker) means "no bout here".
    if not isinstance(bout_idx, int):
        return np.nan

    return df_bout.Cat.iloc[bout_idx]


def get_bout_amp(frame: int, df_bout: pd.DataFrame, df_frame: pd.DataFrame, fps_ci: float, fps_bh: float):
    """
    Maximum bend amplitude of the bout that overlaps a calcium-imaging frame.

    :param frame: calcium-imaging frame index
    :param df_bout: per-bout table; its ``Max_Bend_Amp`` column is read positionally
    :param df_frame: per-behavior-sample table, forwarded to ``get_bout_number``
    :param fps_ci: calcium-imaging frame rate
    :param fps_bh: behavior frame rate
    :return: ``df_bout.Max_Bend_Amp`` at the position of the bout, or ``np.nan`` when
        ``get_bout_number`` does not return an ``int``

    """
    bout_idx = get_bout_number(frame, df_bout, df_frame, fps_ci, fps_bh)

    # Anything that is not a built-in int (i.e. the NaN marker) means "no bout here".
    if not isinstance(bout_idx, int):
        return np.nan

    return df_bout.Max_Bend_Amp.iloc[bout_idx]


def get_bout_duration(frame: int, df_bout: pd.DataFrame, df_frame: pd.DataFrame, fps_ci: float, fps_bh: float):
    """
    Duration of the bout that overlaps a calcium-imaging frame.

    :param frame: calcium-imaging frame index
    :param df_bout: per-bout table; its ``Bout_Duration`` column is read positionally
    :param df_frame: per-behavior-sample table, forwarded to ``get_bout_number``
    :param fps_ci: calcium-imaging frame rate
    :param fps_bh: behavior frame rate
    :return: ``df_bout.Bout_Duration`` at the position of the bout, or ``np.nan`` when
        ``get_bout_number`` does not return an ``int``

    """
    bout_idx = get_bout_number(frame, df_bout, df_frame, fps_ci, fps_bh)

    # Anything that is not a built-in int (i.e. the NaN marker) means "no bout here".
    if not isinstance(bout_idx, int):
        return np.nan

    return df_bout.Bout_Duration.iloc[bout_idx]


def get_frame_state(frame: int, roi: int, dff: np.ndarray, noise: np.ndarray):
    """

    Classify how the DF/F of one ROI evolves from ``frame`` to ``frame + 1``.

    The label has two parts. The first says whether the ROI is "up" (``dff[roi, frame]`` strictly above
    ``noise[roi]``) or "down". The second says whether the next value stays within ``noise[roi]`` of the
    current one, rises by at least ``noise[roi]``, or falls.

    :param frame: index of the current frame; ``frame + 1`` must also exist in ``dff``
    :param roi: index used both for the row of ``dff`` and for the entry of ``noise``
    :param dff: 2-D array indexed as ``dff[roi, frame]``
    :param noise: per-ROI tolerance, indexed as ``noise[roi]``
    :return: one of 'up_and_stays_up', 'up_and_goes_up', 'up_and_goes_down',
        'down_and_stays_down', 'down_and_goes_up', 'down_and_goes_down'

    """

    labels = {
        (True, 'stays'): 'up_and_stays_up',
        (True, 'rises'): 'up_and_goes_up',
        (True, 'falls'): 'up_and_goes_down',
        (False, 'stays'): 'down_and_stays_down',
        (False, 'rises'): 'down_and_goes_up',
        (False, 'falls'): 'down_and_goes_down',
    }

    current = dff[roi, frame]
    tolerance = noise[roi]
    following = dff[roi, frame + 1]

    was_up = bool(current > tolerance)

    # The "stays" band is tested first, so a value exactly on its upper edge counts as staying.
    if current - tolerance <= following <= current + tolerance:
        trend = 'stays'
    elif following >= current + tolerance:
        trend = 'rises'
    else:
        trend = 'falls'

    return labels[(was_up, trend)]


def get_dif_start_end(syl: int, roi: int, df_syllabus: pd.DataFrame, dff: np.ndarray):
    """

    Returns, for a given behavior syllabus and a given ROI, the differential DF/F between start and end of a syllabus.
    If start and end frame are the same for a given syllabus, the differential is taken between the end frame and
    the frame just before start.

    :param syl: positional row of the syllabus in ``df_syllabus``
    :param roi: row index of the ROI in ``dff``
    :param df_syllabus: table with ``start`` and ``end`` columns holding frame indices
    :param dff: 2-D array indexed as ``dff[roi, frame]``
    :return: ``dff[roi, end] - dff[roi, start]``; ``dff[roi, end] - dff[roi, start - 1]`` when start equals end;
        ``np.nan`` when start equals end and no frame precedes start

    """

    start, end = df_syllabus['start'].iloc[syl].item(), df_syllabus['end'].iloc[syl].item()

    if start != end:
        return dff[roi, int(end)] - dff[roi, int(start)]

    previous = int(start) - 1
    if previous < 0:
        # A negative index would silently wrap around to the last frame of the recording.
        return np.nan

    return dff[roi, int(end)] - dff[roi, previous]


def get_max_dff(syl: int, roi: int, df_syllabus: pd.DataFrame, dff: np.ndarray):
    """

    Returns, for a given behavior syllabus and a given ROI, the max DF/F reached around a syllabus.
    The search window runs from one frame before start up to and including the end frame; NaN values are ignored.

    :param syl: positional row of the syllabus in ``df_syllabus``
    :param roi: row index of the ROI in ``dff``
    :param df_syllabus: table with ``start`` and ``end`` columns holding frame indices
    :param dff: 2-D array indexed as ``dff[roi, frame]``
    :return: ``np.nanmax`` of ``dff[roi]`` over the window described above

    """

    start, end = df_syllabus['start'].iloc[syl].item(), df_syllabus['end'].iloc[syl].item()

    # Clamp at 0: for start == 0 the slice would otherwise begin at -1, which selects
    # nothing and makes np.nanmax raise on the empty window.
    first = max(int(start) - 1, 0)

    return np.nanmax(dff[roi, first:int(end) + 1])


def get_max_dff_norm(syl: int, roi: int, df_syllabus: pd.DataFrame, dff: np.ndarray):
    """

    Returns, for a given behavior syllabus and a given ROI, the max DF/F reached during a syllabus, normalised by
    baseline signal before syllabus.

    The baseline is the NaN-ignoring median of the three frames ``start - 5`` to ``start - 3``. The peak is
    ``dff[roi, end]`` when start equals end, otherwise the NaN-ignoring max from ``start - 1`` to ``end`` included.

    :param syl: positional row of the syllabus in ``df_syllabus``
    :param roi: row index of the ROI in ``dff``
    :param df_syllabus: table with ``start`` and ``end`` columns holding frame indices
    :param dff: 2-D array indexed as ``dff[roi, frame]``
    :return: peak minus baseline, or ``np.nan`` when fewer than 5 frames precede the syllabus so that the
        baseline window is not fully available

    """

    baseline_first, baseline_last = 5, 2  # window is dff[start - 5 : start - 2]

    start, end = int(df_syllabus['start'].iloc[syl].item()), int(df_syllabus['end'].iloc[syl].item())

    if start < baseline_first:
        # Negative slice bounds would wrap around and either select nothing or pick frames
        # from the end of the recording as "baseline".
        return np.nan

    updated_baseline = np.nanmedian(dff[roi, start - baseline_first:start - baseline_last])

    if start == end:
        peak = dff[roi, end]
    else:
        peak = np.nanmax(dff[roi, start - 1:end + 1])

    return peak - updated_baseline


def get_recruitment(syl: int, roi: int, df_syllabus: pd.DataFrame, dff: np.ndarray, noise, cells):
    """

    Returns, for a given behavior syllabus and a given ROI, whether the ROI was recruited (1) or not (0).

    The baseline is the NaN-ignoring median of the three frames ``start - 5`` to ``start - 3``. The baseline-subtracted
    peak is ``dff[roi, end] - baseline`` when start equals end, otherwise the NaN-ignoring max from ``start - 1`` to
    ``end`` included, minus the baseline. The ROI counts as recruited when that value reaches
    ``3 * noise + baseline``, where the noise entry is looked up at the position of ``roi`` inside ``cells``.

    :param syl: positional row of the syllabus in ``df_syllabus``
    :param roi: ROI identifier; row index in ``dff`` and value searched for in ``cells``
    :param df_syllabus: table with ``start`` and ``end`` columns holding frame indices
    :param dff: 2-D array indexed as ``dff[roi, frame]``
    :param noise: noise levels ordered like ``cells``
    :param cells: ROI identifiers, in the same order as ``noise``
    :return: 1 if recruited, else 0. Also 0 when fewer than 5 frames precede the syllabus (no baseline available)
        or when ``roi`` is not listed in ``cells``

    """

    baseline_first, baseline_last = 5, 2  # window is dff[start - 5 : start - 2]

    start, end = int(df_syllabus['start'].iloc[syl].item()), int(df_syllabus['end'].iloc[syl].item())

    if start < baseline_first:
        # Negative slice bounds would wrap around and take the "baseline" from the end of
        # the recording; without a baseline the ROI cannot be called recruited.
        return 0

    updated_baseline = np.nanmedian(dff[roi, start - baseline_first:start - baseline_last])

    if start == end:
        max_dff = dff[roi, end] - updated_baseline
    else:
        max_dff = np.nanmax(dff[roi, start - 1:end + 1]) - updated_baseline

    # Resolve the position of the ROI explicitly so the comparison below works on scalars
    # (an empty or multi-element array has no unambiguous truth value).
    positions = np.where(np.asarray(cells) == roi)[0]
    if positions.size == 0:
        return 0

    # NOTE(review): max_dff is already baseline-subtracted, yet the baseline is added to the
    # threshold again — kept as in the original; confirm this is intended.
    threshold = 3 * np.asarray(noise)[positions[0]] + updated_baseline

    return 1 if max_dff >= threshold else 0


def get_syl_amp(frame: int, df_frame: pd.DataFrame, fps_ci: float, fps_bh: float):
    """
    Signed tail angle of largest magnitude during one calcium-imaging frame.

    :param frame: calcium-imaging frame index
    :param df_frame: per-behavior-sample table with a ``Tail_angle`` column
    :param fps_ci: calcium-imaging frame rate
    :param fps_bh: behavior frame rate
    :return: whichever of the NaN-ignoring minimum and maximum of ``Tail_angle`` in the window has the
        larger absolute value; the minimum wins when both magnitudes are equal

    """
    first = int((frame / fps_ci) * fps_bh)
    last = int(((frame + 1) / fps_ci) * fps_bh)
    window = df_frame.Tail_angle[first:last]

    lowest = np.nanmin(window)
    highest = np.nanmax(window)

    if abs(highest) > abs(lowest):
        return highest
    return lowest


def get_syl_type(frame: int, df_frame: pd.DataFrame, fps_ci: float, fps_bh: float):
    """
    Label one calcium-imaging frame from the amplitude returned by ``get_syl_amp``.

    :param frame: calcium-imaging frame index
    :param df_frame: per-behavior-sample table, forwarded to ``get_syl_amp``
    :param fps_ci: calcium-imaging frame rate
    :param fps_bh: behavior frame rate
    :return: 'F' when the absolute amplitude is strictly below 20, otherwise 'S'

    """
    amplitude = abs(get_syl_amp(frame, df_frame, fps_ci, fps_bh))

    return 'F' if amplitude < 20 else 'S'


def get_syl_side(syl, df_syllabus):
    """
    Side label of one syllabus.

    :param syl: value looked up in the ``syl`` column of ``df_syllabus`` (must match exactly one row)
    :param df_syllabus: table with ``syl``, ``type`` and ``max_ta`` columns
    :return: 'F' when the syllabus type is 'F'; otherwise 'ipsi' for a negative ``max_ta`` and 'contra' for
        any other value

    """
    this_syl = df_syllabus.syl == syl

    # 'F' syllabuses carry no side information; max_ta is not even read for them.
    if df_syllabus.loc[this_syl, 'type'].item() == 'F':
        return 'F'

    amplitude = df_syllabus.loc[this_syl, 'max_ta'].item()

    return 'ipsi' if amplitude < 0 else 'contra'


def get_cells_group(df_summary, fishlabel, plane, F, cells, stat):
    """
    Assign every cell of one imaging plane to an anatomical group and a body side.

    The row of ``df_summary`` matching ``fishlabel`` and ``plane`` provides the limits. ``direction`` selects
    which coordinate separates the sides and which one separates the regions (0: side from y, region from x;
    otherwise: side from x, region from y). Cells listed in ``lateral_bulbar`` (comma-separated integers) are
    labelled 'bulbar_lateral_ipsi' / 'bulbar_lateral_contra'; the remaining cells are labelled 'spinal_cord',
    'bulbar_medial' or 'pontine' by comparing their position with ``sc_bulbar`` and ``bulbar_pontine``.

    :param df_summary: summary table with columns ``fishlabel``, ``plane``, ``lateral_bulbar``, ``direction``,
        ``midline``, ``sc_bulbar`` and ``bulbar_pontine``
    :param fishlabel: value matched against ``df_summary.fishlabel``
    :param plane: value matched against ``df_summary.plane``
    :param F: array whose first dimension gives the total number of ROIs (only ``F.shape[0]`` is used)
    :param cells: iterable of ROI indices to classify
    :param stat: per-ROI container passed on to ``get_pos_x`` / ``get_pos_y``
    :return: tuple ``(cells_group, side, bulbar_lateral, bulbar_medial, pontine, spinal_cord)``; the first two
        are arrays of length ``F.shape[0]`` (entries of unclassified ROIs stay NaN), the others are lists of
        cell indices. ``bulbar_lateral`` is the user-provided list, whether or not those cells are in ``cells``.

    """
    group_plane = df_summary[(df_summary.fishlabel == fishlabel) & (df_summary.plane == plane)]

    def limit(column):
        # Read lazily so that a limit left blank in the csv only matters when it is needed.
        # ``.item()`` replaces ``int(Series)``, which pandas deprecates.
        return int(group_plane[column].item())

    # One slot per ROI; slots of ROIs that are not in ``cells`` stay NaN.
    n_rois = F.shape[0]
    cells_group = [np.nan] * n_rois
    side = [np.nan] * n_rois

    # read in the csv files the cells that were input by user as lateral bulbar cells
    try:
        bl_input = list(group_plane['lateral_bulbar'])[0].split(',')
        bulbar_lateral = [int(entry) for entry in bl_input]
    except (AttributeError, ValueError):
        # AttributeError: the csv field is not a string (e.g. empty -> NaN).
        # ValueError: the field holds something that is not a list of integers.
        print('No bulbar lateral cells found from user, or mistyped in the csv file.')
        bulbar_lateral = list()  # bulbar lateral will be empty

    bulbar_medial = []
    spinal_cord = []
    pontine = []

    rostro_caudal = group_plane['direction'].item() == 0

    for cell in cells:

        x_pos, y_pos = get_pos_x(cell, stat), get_pos_y(cell, stat)

        if rostro_caudal:  # fish was positioned rostro-caudal direction
            side[cell] = 'ipsi' if y_pos < limit('midline') else 'contra'
        else:  # fish was positioned left_right
            side[cell] = 'ipsi' if x_pos < limit('midline') else 'contra'

        # cells flagged by the user take precedence over the position-based grouping
        if cell in bulbar_lateral:
            # Compare the side of *this* cell: comparing the whole ``side`` list to a string
            # is always False and labelled every lateral cell as contra.
            if side[cell] == 'ipsi':
                cells_group[cell] = 'bulbar_lateral_ipsi'
            else:
                cells_group[cell] = 'bulbar_lateral_contra'
            continue

        # else, get cell group by position relative to the limits defined by user
        if rostro_caudal:
            if x_pos > limit('sc_bulbar'):
                region, members = 'spinal_cord', spinal_cord
            elif limit('bulbar_pontine') < x_pos <= limit('sc_bulbar'):
                region, members = 'bulbar_medial', bulbar_medial
            else:
                region, members = 'pontine', pontine
        else:
            if y_pos < limit('sc_bulbar'):
                region, members = 'spinal_cord', spinal_cord
            elif limit('bulbar_pontine') > y_pos >= limit('sc_bulbar'):
                region, members = 'bulbar_medial', bulbar_medial
            else:
                region, members = 'pontine', pontine

        members.append(cell)
        cells_group[cell] = region

    print(
        'Spinal cord: {} cells.\nBulbar medial: {} cells.'.format(str(len(spinal_cord)), str(len(bulbar_medial))))
    print(
        'Bulbar lateral: {} cells.\n Pontine: {} cells.'.format(str(len(bulbar_lateral)), str(len(pontine))))

    return np.array(cells_group), np.array(side), bulbar_lateral, bulbar_medial, pontine, spinal_cord


def get_pos_x(cell_number, stat):
    """
    First coordinate of a cell's ``'med'`` entry.

    :param cell_number: index of the cell in ``stat``
    :param stat: per-cell container; each entry exposes a ``'med'`` pair
    :return: ``stat[cell_number]['med'][0]`` (described in the original notes as the middle position on
        the short axis — TODO confirm axis convention)

    """
    centre = stat[cell_number]['med']
    return centre[0]


def get_pos_y(cell_number, stat):
    """
    Second coordinate of a cell's ``'med'`` entry.

    :param cell_number: index of the cell in ``stat``
    :param stat: per-cell container; each entry exposes a ``'med'`` pair
    :return: ``stat[cell_number]['med'][1]`` (described in the original notes as the middle position of the
        cell mask on the long axis — TODO confirm axis convention)

    """
    centre = stat[cell_number]['med']
    return centre[1]

In [5]:

# Select the recording to analyse from the experiment summary table.
# NOTE(review): absolute cluster path — consider moving it to a configurable data directory.
df_summary = pd.read_csv('/network/lustre/iss01/wyart/analyses/2pehaviour/MLR_analyses/data_summary_BH.csv')
fishlabel, plane = '210121_F04', '70um_bh'
print(fishlabel, plane)

# Row of the summary table describing this fish / plane (reused for every lookup below).
plane_mask = (df_summary.fishlabel == fishlabel) & (df_summary.plane == plane)
output_path = df_summary.loc[plane_mask, 'output_path'].item()

# Results of the calcium analysis. The context manager closes the shelf even when one of
# the keys is missing, instead of leaving it open after the exception.
with shelve.open(output_path + '/shelve_calciumAnalysis.out') as shelve_out:
    cells = shelve_out['cells']
    dff = shelve_out['dff_f_lp_inter']
    tail_angle = shelve_out['tail_angle']
    noise = shelve_out['noise_f_lp']
    stat = shelve_out['stat']
    fps_ci = shelve_out['fps']
    fps_bh = shelve_out['fps_beh']

# Behavior tables. NOTE: unpickling can execute arbitrary code — only load trusted files.
df_bouts = pd.read_pickle(output_path + '/dataset/df_bout')
df_frame = pd.read_pickle(output_path + '/dataset/df_frame')


side_lim = df_summary.loc[plane_mask, 'midline'].item()

# cells group

cells_group, side, bulbar_lateral, bulbar_medial, pontine, spinal_cord = get_cells_group(df_summary, fishlabel, plane,
                                                                                         dff, cells, stat)
# Position of every analysed cell, in the same order as ``cells``.
cells_x_pos = np.array(pd.Series(cells).apply(get_pos_x, args=(stat,)))
cells_y_pos = np.array(pd.Series(cells).apply(get_pos_y, args=(stat,)))


210121_F04 70um_bh
Spinal cord: 0 cells.
Bulbar medial: 27 cells.
Bulbar lateral: 4 cells.
 Pontine: 2 cells.


  after removing the cwd from sys.path.


In [7]:
# Long-format table: one row per (roi, frame) pair, rois in the order of ``cells``.
n_frames = dff.shape[1]
n_rows = len(cells) * n_frames

df = pd.DataFrame({'fishlabel': [fishlabel] * n_rows,
                   'plane': [plane] * n_rows,
                   'roi': np.repeat(cells, n_frames),
                   'frame': np.tile(np.arange(n_frames), len(cells)),
                   'dff': dff[cells, :].flatten(),
                   'dif_dff': [np.nan] * n_rows,
                   # object dtype: this column is filled with string labels below
                   'frame_state': np.full(n_rows, np.nan, dtype=object),
                   'bout_type': [np.nan] * n_rows,
                   'max_ta': [np.nan] * n_rows,
                   'max_ta_bout': [np.nan] * n_rows,
                   'duration_bout': [np.nan] * n_rows,
                   'roi_group': np.repeat(cells_group[cells], n_frames),
                   'roi_side': np.repeat(side[cells], n_frames),
                   'roi_x_pos': np.repeat(cells_x_pos, n_frames),
                   'roi_y_pos': np.repeat(cells_y_pos, n_frames)})

# Remove frames without behavior
#
binary_bh = np.load(output_path + '/dataset/tail_angle_binary.npy')
to_keep = np.where(binary_bh != 0)[0]
print(len(to_keep))
# Explicit copy: the columns assigned below must land in this table, not in a view of the full one.
df = df[df.frame.isin(to_keep)].copy()
print(df.shape)

#
# FIRST: easy path of looking only at bout types and bout duration

bout_args = (df_bouts, df_frame, fps_ci, fps_bh)
df['bout_type'] = df.frame.apply(get_bout_type, args=bout_args)
df['max_ta_bout'] = df.frame.apply(get_bout_amp, args=bout_args)
df['duration_bout'] = df.frame.apply(get_bout_duration, args=bout_args)

# ``get_frame_state`` looks the noise up by ROI id. When ``noise`` holds one entry per element
# of ``cells`` instead of one per row of ``dff``, spread it over a full-length array once.
# Waiting for an IndexError to do so silently used the wrong entry for low ROI ids.
noise_arr = np.asarray(noise)
if noise_arr.shape[0] == dff.shape[0]:
    noise_by_roi = noise_arr
else:
    noise_by_roi = np.zeros(dff.shape[0])
    noise_by_roi[cells] = noise_arr

for roi in cells:
    # Evaluate the helpers on this ROI's rows only, rather than on every row of the table.
    roi_rows = df.roi == roi
    roi_frames = df.loc[roi_rows, 'frame']
    df.loc[roi_rows, 'dif_dff'] = roi_frames.apply(get_dif_dff, args=(roi, dff))
    df.loc[roi_rows, 'frame_state'] = roi_frames.apply(get_frame_state, args=(roi, dff, noise_by_roi))

fig, axs = plt.subplots(1, 2, sharey=True)
axs[0].set_title('Bulbar lateral')
axs[1].set_title('Bulbar medial')
sns.scatterplot(data=df[df.roi_group.isin(['bulbar_lateral_ipsi', 'bulbar_lateral_contra'])], x='duration_bout',
                y='dif_dff', hue='bout_type', ax=axs[0])
sns.scatterplot(data=df[df.roi_group == 'bulbar_medial'], x='duration_bout', y='dif_dff', hue='bout_type', ax=axs[1])



97
(3104, 15)


RecursionError: maximum recursion depth exceeded

In [9]:
# NOTE(review): leftover scratch cell — `i` is not defined at notebook level, so evaluating
# it raises the NameError shown in this cell's output; delete the cell before sharing.
i

NameError: name 'i' is not defined