In [1]:
import os
import numpy as np
import statistics
import math

import behavior_analysis_helper_functions as helper

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
sns.set_theme()
colors = ["#fd7f6f", "#7eb0d5", "#b2e061", "#bd7ebe", "#ffb55a", "#ffee65", "#beb9db", "#fdcce5", "#8bd3c7"]

## Session

In [None]:
def load_session_meta(data_folder, dir_name, file_name):
    path = os.path.join(data_folder, dir_name, file_name)
    session_meta = pd.read_csv(path, nrows=1)
    return session_meta

In [None]:
def load_session(data_folder, dir_name, file_name):
    path = os.path.join(data_folder, dir_name, file_name)
    session_df = pd.read_csv(path, skiprows=3)
    return session_df

In [None]:
def get_session_basic(session_df):
    total_blocks = session_df.block_num.max()
    total_trials = session_df.session_trial_num.max()
    total_reward = round(session_df.reward_size.sum(), 2)
    total_time = round((session_df.session_time.max() - session_df.session_time.min()), 2)
    session_basic = [total_blocks, total_trials, total_reward, total_time]
    return session_basic

In [3]:
# get a list of all folder names in data folder
cohort = 'cohort_1'
cohort_folder = f'/Users/rebekahzhang/Documents/shuler_lab/behavior_data/{cohort}'
data_folder = os.path.join(cohort_folder, 'behavior_data_full_clean')
print(data_folder)
dir_list = os.listdir(data_folder)
dir_list.remove('.DS_Store')

/Users/rebekahzhang/Documents/shuler_lab/behavior_data/cohort_1/behavior_data_full_clean


In [4]:
# make a table of session info with date, mouse, folder name, and file name
date_list = []
mouse_list = []
filename_list = []
for f in dir_list:
    date_list.append(f[0:10])
    mouse = f[-5:]
    mouse_list.append(mouse)
    filename_list.append(f'data_{mouse}_{f[0:19]}.txt')
session_log = pd.DataFrame({'date': date_list, 'mouse': mouse_list, 
                            'dir': dir_list, 'filename': filename_list})

In [5]:
# get the type of training from session meta data for each session
training_list = []
for dir_name, file_name in zip(session_log.dir, session_log.filename):
    session_meta = helper.load_session_meta(data_folder, dir_name, file_name)
    training = session_meta.training.tolist()[0]
    training_list.append(training)
session_log['training'] = training_list

AttributeError: module 'behavior_analysis_helper_functions' has no attribute 'load_session_meta'

In [None]:
# get a df of only regular sessions, and sort by date time
training_session_log = session_log.loc[session_log.training == 'regular'].sort_values('dir').reset_index()
training_session_log.head()

In [None]:
# add columns of basic info of number of blocks, trials, total rewards, and total time to each training session
session_basics_columns = ['num_blocks', 'num_trials', 'rewards', 'time']
column_names = training_session_log.columns.values.tolist() + session_basics_columns
training_session_log = training_session_log.reindex(columns=column_names)
# add basic info of num
for dir_name, file_name in zip(session_log.dir, session_log.filename):
    session = load_session(data_folder, dir_name, file_name)
    session_basic = get_session_basic(session)
    training_session_log.loc[training_session_log.dir == dir_name, session_basics_columns] = session_basic

In [None]:
# no background in keys prior to 2023-02-20, filtered out sessions prior to that date
# no need to run if dataset is clean
training_session_log = training_session_log.loc[training_session_log.date > '2023-02-20']
training_session_log.head()

In [None]:
# prints mouse names in data base, check for weird ones and delete from data base
mouse_list = session_log.mouse.unique().tolist()
mouse_list.sort()
print(mouse_list)

In [None]:
# session screening
# should print nothing if all trials are reg and long enough
#### can upgrade to deleting wrong ones from data folder

# prints short sessions, to be deleted from dataset folder
short_session = training_session_log.loc[(training_session_log['training'] == 'regular') & 
                                         (training_session_log['num_trials'] < 100)] 
print(short_session.dir)

# prints dates of the same mouse with multiple sessions
for d in session_log.date.unique().tolist():
    session_of_the_day = session_log.loc[session_log['date'] == d]
    for mouse in mouse_list:
        count = len(session_of_the_day.loc[session_of_the_day['mouse'] == mouse])
        if count > 1:
            print(d, mouse)   

## Per session analysis 

In [None]:
# makes a list of 0 to total trial number, used to loop in the session
def generate_total_trial_list(session_log, dir_name):
    current_session = session_log.loc[session_log.dir == dir_name]
    total_trial_list = range(int(current_session.num_trials.tolist()[0]) + 1)
    return total_trial_list

In [None]:
# makes an empty df with each row being a trial, and each column with trial info
# trial number is added to the df
all_trials_column_names = ['trial_num', 'block_num', 'start_time', 'end_time', 'bg_repeats', 'blk_bg_avg',
                'bg_length', 'reward_size', 'miss_trial', 'time_waited', 'num_consumption_lick']

def generate_all_trials_df(column_names, total_trial_list):
    all_trials = pd.DataFrame(columns=column_names)
    all_trials['trial_num'] = total_trial_list
    return all_trials

In [None]:
# gets 5 basic things about the trial
# takes raw data of each trial as input
def get_trial_basics(trial):
    block_num = trial.loc[(trial['key'] == 'trial') & (trial['value'] == 1), 'block_num'].iloc[0]
    start_time = trial.loc[(trial['key'] == 'trial') & (trial['value'] == 1), 'session_time'].iloc[0]
    end_time = trial.loc[(trial['key'] == 'trial') & (trial['value'] == 0), 'session_time'].iloc[0]
    bg_repeat = trial['key'].value_counts()['background']
    blk_bg_avg = float(trial.loc[(trial['key'] == 'trial') & (trial['value'] == 1), 'time_bg'].iloc[0])
    return [block_num, start_time, end_time, bg_repeat, blk_bg_avg]

In [None]:
# gets total time in background
# takes raw data of each trial as input
def get_trial_bg_length(trial):
    bg_start_idx = trial.index[(trial['key'] == 'trial') & (trial['value'] == 1)].tolist()
    bg_end_idx = trial.index[(trial['key'] == 'wait') & (trial['value'] == 1)].tolist()
    trial_bg = trial.loc[bg_start_idx[0] : bg_end_idx[0]]
    trial_bg_length = trial_bg.session_time.max() - trial_bg.session_time.min()
    return [trial_bg_length]

In [None]:
# gets 3 values about trial performance
# takes trial raw data as input
def get_trial_performance(trial):
    wait_start_time = trial.loc[(trial['key'] == 'wait') & (trial['value'] == 1), 'session_time'].iloc[0]
    if 'in_consumption' in trial.state.unique() :
        miss_trial = False
        reward = trial.loc[trial['key'] == 'reward', 'reward_size'].iloc[0]
        consumption_start_time = trial.loc[trial['state'] == 'in_consumption', 'session_time'].iloc[0]
        time_waited = consumption_start_time - wait_start_time
    else :
        miss_trial = True
        reward = math.nan
        time_waited = math.nan
    return [reward, miss_trial, time_waited]

In [None]:
# gets the number of consumption licks of each trial
# takes raw data as input 
def get_num_consumption_licks(trial):
    consumption = trial.loc[trial['state'] == 'in_consumption']
    num_consumption_lick = len(consumption.loc[(consumption['key'] == 'lick') & (trial['value'] == 1)])
    return [num_consumption_lick]

In [None]:
# runs individual functions and consolidate all info to one long list to be added to session log
def get_trial_data(trial):
    trial_basics = get_trial_basics(trial)
    trial_bg_length = get_trial_bg_length(trial)
    trial_performance = get_trial_performance(trial)
    num_consumption_lick = get_num_consumption_licks(trial)
    trial_data = [trial_basics + trial_bg_length + trial_performance + num_consumption_lick]
    return trial_data

In [None]:
%%time
# for each session, generates all trials df and saves it in the raw data folder
# skips if all trials already exists in folder
for dir_name, file_name in zip(training_session_log.dir, training_session_log.filename):
    filename = f'{dir_name}_all_trials.csv'
    path = os.path.join(data_folder, dir_name, filename)
    if os.path.isfile(path):
        continue
    
    session = load_session(data_folder, dir_name, file_name)
    total_trial_list = generate_total_trial_list(training_session_log, dir_name)
    all_trials = generate_all_trials_df(all_trials_column_names, total_trial_list)
    for i in total_trial_list:
        trial = session.loc[session['session_trial_num'] == i]
        trial_data = get_trial_data(trial)
        all_trials.loc[all_trials.trial_num == i, all_trials_column_names[1:]] = trial_data
        
    all_trials.to_csv(path)

## adding session info based on per trial performance to master log

In [None]:
def load_all_trials(data_folder, dir_name):
    filename = f'{dir_name}_all_trials.csv'
    path = os.path.join(data_folder, dir_name, filename)
    all_trials_df = pd.read_csv(path)
    return all_trials_df

In [None]:
def select_good_trials(all_trials):
    good_trials = all_trials.loc[(all_trials['miss_trial'] == False) & (all_trials['bg_repeats'] == 1)]
    return good_trials

In [None]:
def get_session_performance(all_trials):
    num_miss_trials = all_trials.miss_trial.values.sum()
    good_trials = all_trials.loc[(all_trials['miss_trial'] == False) & (all_trials['bg_repeats'] == 1)]
    num_good_trials = len(good_trials)
    return [num_miss_trials, num_good_trials]

In [None]:
def get_session_mistakes(all_trials):
    num_bg_repeats_mean = all_trials.bg_repeats.mean()
    num_bg_repeats_med = all_trials.bg_repeats.median()
    num_bg_repeats_std = all_trials.bg_repeats.std()
    return [num_bg_repeats_mean, num_bg_repeats_med, num_bg_repeats_std]

In [None]:
def get_session_time_waited(all_trials):
    tw_mean = all_trials.time_waited.mean()
    tw_med = all_trials.time_waited.median()
    tw_std = all_trials.time_waited.std()
    return [tw_mean, tw_med, tw_std]

### number of days in training

In [None]:
for mouse in mouse_list:
    total_days = sum(training_session_log.mouse == mouse)
    training_session_log.loc[training_session_log.mouse == mouse, 'days'] = list(range(total_days))

### engagement of each session
miss trials and good trials, and the proportions

In [None]:
training_session_log['miss_trials'] = ''
training_session_log['good_trials'] = ''

In [None]:
for dir_name in training_session_log.dir:
    all_trials = load_all_trials(data_folder, dir_name)
    session_performance = get_session_performance(all_trials)
    training_session_log.loc[training_session_log.dir == dir_name, 
                             ['miss_trials', 'good_trials']] = session_performance

In [None]:
training_session_log['p_miss'] = training_session_log.miss_trials/training_session_log.num_trials
training_session_log['p_good'] = training_session_log.good_trials/training_session_log.num_trials
training_session_log['p_rest'] = 1 - training_session_log.p_good - training_session_log.p_miss

In [None]:
# plot engagement in session. each mouse has a plot
for mouse in mouse_list:
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    to_plot = session_mouse.loc[:, ['days', 'p_good','p_miss','p_rest']]
    fig = plt.figure()
    ax = to_plot.plot.bar(x='days', stacked=True, color = colors[0:3])
    ax.set_title(mouse)
    ax.set_xlabel('Days in Training')
    ax.set_ylabel('Normalized Proportion')
    ax.legend(loc='upper right', bbox_to_anchor=(1.25, 1))
    plt.plot(to_plot.days, to_plot.p_good, color = 'k', linewidth=1.5)
    plt.savefig(f'engagement_{mouse}', bbox_inches='tight')

In [None]:
# plot percent good trials in session. each mouse is a line
fig = plt.figure()
for i, mouse in enumerate(mouse_list):
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    plt.plot(session_mouse.days, session_mouse.p_good, color=colors[i], label=mouse)
plt.title('Proportion of Good Trials ')
plt.xlabel('Days in Training')
plt.legend(loc="upper left")
plt.xlim([0, 11])
plt.ylim([0, 1])
fig.savefig('%_good_trials.png', bbox_inches='tight')

In [None]:
# plot percent missed trials in session. each mouse is a line
fig = plt.figure()
for i, mouse in enumerate(mouse_list):
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    plt.plot(session_mouse.days, session_mouse.p_miss, color=colors[i], label=mouse)
plt.title('Proportion of Missed Trials')
plt.xlabel('Days in Training')
plt.legend(loc="upper right")
plt.xlim([0, 11])
plt.ylim([0, 0.2])
fig.savefig('%_missed_trials.png', bbox_inches='tight')

### performance of each session
reward rate, bg_repeats

In [None]:
training_session_log['reward_rate'] = training_session_log.rewards/training_session_log.time

In [None]:
# plot the change in reward rate over days. each mouse is a line 
fig = plt.figure()
for i, mouse in enumerate(mouse_list):
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    plt.plot(session_mouse.days, session_mouse.reward_rate, color=colors[i], label=mouse)

plt.xlabel('Days in Training')
plt.ylabel('Reward Rate (ul/s)')
plt.legend(loc="upper left")
plt.xlim([0, 11])
plt.ylim([0, 0.4])
fig.savefig('reward_rate.png', bbox_inches='tight')

In [None]:
training_session_log['num_bg_repeats_mean'] = ''
training_session_log['num_bg_repeats_median'] = ''
training_session_log['num_bg_repeats_stdev'] = ''

In [None]:
for dir_name in training_session_log.dir:
    all_trials = load_all_trials(data_folder, dir_name)
    session_mistakes = get_session_mistakes(all_trials)
    training_session_log.loc[training_session_log.dir == dir_name, 
                             ['num_bg_repeats_mean', 'num_bg_repeats_median', 
                              'num_bg_repeats_stdev']] = session_mistakes

In [None]:
# plots the number of bg repeats across days for each mouse. one mouse per plot. 
for mouse in mouse_list:
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    
    fig = plt.figure()
    plt.errorbar(session_mouse.days, session_mouse.num_bg_repeats_mean, session_mouse.num_bg_repeats_stdev, 
                 linestyle='None', marker='o', color=colors[0], label='mean')
    plt.scatter(session_mouse.days, session_mouse.num_bg_repeats_median, color=colors[1], label='median')
    plt.title(mouse)
    plt.xlabel('Days in Training')
    plt.ylabel('Number of BG Repeats')
    plt.legend(loc='upper right')
    plt.savefig(f'bg_repeats_{mouse}', bbox_inches='tight')

In [None]:
# plot percent mean bg repeats in session for all mice. each mouse is a line
fig = plt.figure()
for i, mouse in enumerate(mouse_list):
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    plt.plot(session_mouse.days, session_mouse.num_bg_repeats_mean, color=colors[i], label=mouse)
plt.xlabel('Days in Training')
plt.ylabel('Number of Mean BG Repeats')
plt.legend(loc='upper right')
plt.savefig(f'bg_repeats_mean', bbox_inches='tight')

In [None]:
# plot percent stdev bg repeats in session for all mice. each mouse is a line
fig = plt.figure()
for i, mouse in enumerate(mouse_list):
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    plt.plot(session_mouse.days, session_mouse.num_bg_repeats_stdev, color=colors[i], label=mouse)
plt.xlabel('Days in Training')
plt.ylabel('Stdev of BG Repeats')
plt.legend(loc='upper right')
plt.savefig(f'bg_repeats_stdev', bbox_inches='tight')

### wait behavior of each session

In [None]:
training_session_log['tw_mean'] = ''
training_session_log['tw_median'] = ''
training_session_log['tw_stdev'] = ''

In [None]:
for dir_name in training_session_log.dir:
    all_trials = load_all_trials(data_folder, dir_name)
    session_time_waited = get_session_time_waited(all_trials)
    training_session_log.loc[training_session_log.dir == dir_name, 
                             ['tw_mean', 'tw_median', 'tw_stdev']] = session_time_waited

In [None]:
# plot time waited across training days for each mouse. one plot per mouse. 
for mouse in mouse_list:
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    
    fig = plt.figure()
    plt.errorbar(session_mouse.days, session_mouse.tw_mean, session_mouse.tw_stdev, 
                 linestyle='None', marker='o', color=colors[0], label='mean')
    plt.scatter(session_mouse.days, session_mouse.tw_median, color=colors[1], label='median')
    plt.title(mouse)
    plt.xlabel('Days in Training')
    plt.ylabel('Time Waited (s)')
    plt.legend(loc='upper right')
    plt.savefig(f'tw_{mouse}', bbox_inches='tight')

In [None]:
# plot percent mean time waited in session. each mouse is a line.
fig = plt.figure()
for i, mouse in enumerate(mouse_list):
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    plt.plot(session_mouse.days, session_mouse.tw_mean, color=colors[i], label=mouse)

plt.xlabel('Days in Training')
plt.ylabel('Mean Time Waited (s)')
plt.legend(loc='upper right')
plt.savefig(f'tw', bbox_inches='tight')

In [None]:
# time waited for only good trials 
training_session_log['tw_good_mean'] = ''
training_session_log['tw_good_median'] = ''
training_session_log['tw_good_stdev'] = ''
for dir_name in training_session_log.dir:
    all_trials = load_all_trials(data_folder, dir_name)
    good_trials = select_good_trials(all_trials)
    session_time_waited = get_session_time_waited(good_trials)
    training_session_log.loc[training_session_log.dir == dir_name, 
                             ['tw_good_mean', 'tw_good_median', 'tw_good_stdev']] = session_time_waited

In [None]:
# plot time waited for only good trials in session. one plot per mouse
for mouse in mouse_list:
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    
    fig = plt.figure()
    plt.errorbar(session_mouse.days, session_mouse.tw_good_mean, session_mouse.tw_good_stdev, 
                 linestyle='None', marker='o', color=colors[0], label='mean')
    plt.scatter(session_mouse.days, session_mouse.tw_good_median, color=colors[1], label='median')
    plt.title(mouse)
    plt.xlabel('Days in Training')
    plt.ylabel('Time Waited (s)')
    plt.legend(loc='upper right')
    plt.savefig(f'tw_good_{mouse}', bbox_inches='tight')

In [None]:
# plot time waited for only good trials in session. each mouse is a line
fig = plt.figure()
for i, mouse in enumerate(mouse_list):
    session_mouse = training_session_log.loc[training_session_log.mouse == mouse]
    plt.plot(session_mouse.days, session_mouse.tw_good_mean, color=colors[i], label=mouse)

plt.xlabel('Days in Training')
plt.ylabel('Mean Time Waited (s)')
plt.legend(loc='upper right')
plt.savefig(f'tw_good', bbox_inches='tight')

In [None]:
# plots all trials of time bg vs time wait of each session
for dir_name in training_session_log.dir:
    all_trials = load_all_trials(data_folder, dir_name)
    fig = plt.figure()
    plt.scatter(all_trials.bg_length, all_trials.time_waited, color=colors[0])
    days_trained = training_session_log.loc[training_session_log.dir == dir_name, 'days'].tolist()[0]
    plt.title (f'Day {days_trained}')
    plt.xlabel('Time in Background (s)')
    plt.ylabel('Time Waited (s)')
    plt.savefig(f'tw_scatter_{dir_name}', bbox_inches='tight')

In [None]:
# plots only good trials of time bg vs time wait of each session
for dir_name in training_session_log.dir:
    all_trials = load_all_trials(data_folder, dir_name)
    good_trials = select_good_trials(all_trials)
    fig = plt.figure()
    plt.scatter(good_trials.bg_length, good_trials.time_waited, color=colors[0])
    days_trained = training_session_log.loc[training_session_log.dir == dir_name, 'days'].tolist()[0]
    plt.title (f'Day {days_trained}')
    plt.xlabel('Time in Background (s)')
    plt.ylabel('Time Waited (s)')
    plt.savefig(f'tw_good_scatter_{dir_name}', bbox_inches='tight')

In [None]:
sns.scatterplot(x="bg_length",
                    y="time_waited",
                    hue="block_num",
                    data=example_all_trials)

In [None]:
example_good_trials = select_good_trials(all_trials)
sns.scatterplot(x="bg_length",
                    y="time_waited",
                    hue="block_num",
                    data=example_good_trials)

### saves the master log with analyzed info

In [None]:
# save master log with analyzed data
filename = 'all_sessions.csv'
path = os.path.join(cohort_folder, filename)
training_session_log.to_csv(path)

## block based analysis

In [None]:
example_all_trials = load_all_trials(data_folder, training_session_log.dir[0])
example_all_trials.head()

In [None]:
# makes a list of 0 to total block number, used to loop in the session
def generate_total_block_list(session_log, dir_name):
    current_session = session_log.loc[session_log.dir == dir_name]
    total_block_list = range(int(current_session.num_blocks.tolist()[0]) + 1)
    return total_block_list

In [None]:
# makes an empty df with each row being a trial, and each column with trial info
# trial number is added to the df
all_trials_column_names = ['block_num', 'block_type', 'start_time', 'end_time', 'bg_repeats', 'blk_bg_avg',
                'bg_length', 'reward_size', 'miss_trial', 'time_waited', 'num_consumption_lick']

def generate_all_blocks_df(column_names, total_block_list):
    all_blocks = pd.DataFrame(columns=column_names)
    all_blocks['block_num'] = total_block_list
    return all_blocks

In [None]:
for dir_name in training_session_log.dir:
    all_trials = load_all_trials(data_folder, dir_name)

## lick analysis across sessions

In [None]:
lick_start = session.loc[(session['key'] == 'lick') & (session['value'] == 1)]
lick_start_times = lick_start['session_time'].tolist()
lick_end = session.loc[(session['key'] == 'lick') & (session['value'] == 0)]
lick_end_times = lick_end['session_time'].tolist()
lick_times = [end - start for end, start in zip(lick_end_times, lick_start_times)]
lick_time_min = min(lick_times)
lick_time_max = max(lick_times)
lick_time_med = statistics.median(lick_times)
lick_time_avg = statistics.mean(lick_times)
lick_time_std = statistics.stdev(lick_times)

In [None]:
for dir_name, file_name in zip(training_session_log.dir, training_session_log.filename):
    session = load_session(data_folder, dir_name, file_name)

## Furture implementation

In [None]:
# delete short sessions. doesnt work right now because of permission issue
short_session = session_log.loc[(session_log['training'] == 'regular') & (session_log['total_trials'] < 100)] 
print(short_session)
# for dir_name in short_session.dir:
#     path = os.path.join(data_folder, dir_name)
#     os.remove(path)

In [None]:
# print out sessions with the same mouse running more than once
mouse_list = session_log.mouse.unique().tolist()
print(mouse_list)
for d in session_log.date.unique().tolist():
    session_of_the_day = session_log.loc[session_log['date'] == d]
    for mouse in mouse_list:
        count = len(session_of_the_day.loc[session_of_the_day['mouse'] == mouse])
        if count > 1:
            print(d, mouse)   

In [None]:
# search for background in key
for dir_name, file_name in zip(session_log.dir, session_log.filename):
    session = load_session(data_folder, dir_name, file_name)
    keys = session['key'].unique().tolist()
    if "background" not in keys:
        print(dir_name)