In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import os
import numpy as np
import scipy
import seaborn as sns
import pandas as pd

from loading_data import get_data
from utils_maze import get_xyedges, get_trial_idx, get_zones
from utils_plotting import plot_proportions, plot_bydurations, plot_bytrial

import vdmlab as vdm

# Plot style for every figure in this notebook. A preceding
# sns.set_style('white') was redundant: 'ticks' configures the same white
# background and additionally draws tick marks, overriding the earlier call.
sns.set_style('ticks')

In [None]:
# Directories for cached (pickled) data and for generated figures.
# The defaults are the original absolute E:/ drive paths; set the environment
# variables EMI_SHORTCUT_PICKLE_DIR / EMI_SHORTCUT_OUTPUT_DIR to run elsewhere
# without editing the notebook.
pickle_filepath = os.environ.get('EMI_SHORTCUT_PICKLE_DIR', 'E:/code/emi_shortcut/cache/pickled')
output_filepath = os.environ.get('EMI_SHORTCUT_OUTPUT_DIR', 'E:/code/emi_shortcut/plots/intermediate')

In [None]:
from run import days1234_infos, days5678_infos

In [None]:
def _times_in_window(times, t_start, t_stop):
    """Return the entries of `times` lying strictly between t_start and t_stop, in order."""
    return [t for t in times if t_start < t < t_stop]


def combine_behavior(infos):
    """Collect phase-3 trial information pooled over a set of sessions.

    For each session the data is loaded with `get_data`. Position is sliced to
    the 'phase3' task window. Only 'feeder1'/'feeder2' events strictly inside
    that window are kept. Position is split into trajectory zones with
    `get_zones`, and trials are identified with `get_trial_idx`.

    Parameters
    ----------
    infos : iterable
        Session info objects. Each must provide `session_id` and
        `task_times['phase3'].start` / `.stop`.

    Returns
    -------
    durations : dict
        Maps 'u', 'shortcut' and 'novel' to a list of trial durations
        (stop_trials[i] - start_trials[i]) pooled across all sessions.
    trials : list
        The `get_trial_idx` result of every session, in the order of `infos`.
    n_sessions : int
        Number of sessions processed.
    """
    durations = dict(u=[], shortcut=[], novel=[])
    trials = []

    for info in infos:
        print(info.session_id)

        events, position, spikes, lfp, lfp_theta = get_data(info)

        t_start = info.task_times['phase3'].start
        t_stop = info.task_times['phase3'].stop

        sliced_pos = position.time_slice(t_start, t_stop)

        # Feeder events on the window boundaries themselves are excluded.
        feeder1_times = _times_in_window(events['feeder1'], t_start, t_stop)
        feeder2_times = _times_in_window(events['feeder2'], t_start, t_stop)

        path_pos = get_zones(info, sliced_pos)

        trials_idx = get_trial_idx(path_pos['u'].time, path_pos['shortcut'].time, path_pos['novel'].time,
                                   feeder1_times, feeder2_times, t_stop)
        trials.append(trials_idx)

        # NOTE(review): assumes trial[0] is an index into 'start_trials' /
        # 'stop_trials' for each entry of trials_idx[key] -- confirm in get_trial_idx.
        for key in durations:
            for trial in trials_idx[key]:
                idx = trial[0]
                durations[key].append(trials_idx['stop_trials'][idx] - trials_idx['start_trials'][idx])

    # Exactly one get_trial_idx result is stored per processed session.
    n_sessions = len(trials)
    return durations, trials, n_sessions

In [None]:
# Pool behaviour from the early (days 1-4) and later (days 5-8) sessions into
# long-format tables for plotting.
# Previously this cell iterated over an undefined name `numbers`, so it only
# ran with leftover kernel state. The trajectory keys are now explicit (they
# are the keys of the `durations` dict built by combine_behavior), which also
# fixes the row order.
trajectories = ['u', 'shortcut', 'novel']

# (session infos, label written to the 'time' column)
session_groups = [(days1234_infos, 'early 1-4'),
                  (days5678_infos, 'later 5-8')]

total_n_sessions = 0

durations_together = dict(trajectory=[], value=[], time=[])
trials_together = dict(trajectory=[], value=[], time=[])

for infos, label in session_groups:
    durations, trials, n_sessions = combine_behavior(infos)
    total_n_sessions += n_sessions

    # One row per trial: the duration of that trial on a given trajectory.
    for key in trajectories:
        for val in durations[key]:
            durations_together['trajectory'].append(key)
            durations_together['value'].append(val)
            durations_together['time'].append(label)

    # One row per (trajectory, session): fraction of the session's trials
    # that are listed under that trajectory.
    for key in trajectories:
        for trial in trials:
            n_trials = len(trial['start_trials'])
            if n_trials == 0:
                # Proportion is undefined for a session without trials;
                # skip it rather than raising ZeroDivisionError.
                continue
            trials_together['trajectory'].append(key)
            trials_together['value'].append(len(trial[key]) / float(n_trials))
            trials_together['time'].append(label)

df_durations = pd.DataFrame(data=durations_together)
df_trials = pd.DataFrame(data=trials_together)

In [None]:
# Box plot of trial durations per trajectory, early vs later sessions.
fliersize = 3
flierprops = dict(marker='o', markersize=fliersize, linestyle='none')
colour = ['#bf812d', '#35978f']

# Explicit category orders: the tick labels below are written by hand and the
# palette is a plain list, so both must line up with a fixed category order
# rather than the order in which rows happened to be appended.
trajectory_order = ['u', 'shortcut', 'novel']
time_order = ['early 1-4', 'later 5-8']

fig, ax = plt.subplots(figsize=(8, 5))
sns.boxplot(x="trajectory", y="value", hue="time", data=df_durations,
            order=trajectory_order, hue_order=time_order, palette=colour,
            flierprops=flierprops, ax=ax)
ax.set(xticklabels=['U', 'Shortcut', 'Novel'])
ax.set_ylabel('Duration of trial (s)')
ax.set_xlabel('(sessions=' + str(total_n_sessions) + ')')
ax.set_ylim(0, 120)  # fixed y-range; any durations above 120 are not visible
sns.despine(ax=ax, left=False)

# savefig fails if the target directory does not exist yet.
if not os.path.isdir(output_filepath):
    os.makedirs(output_filepath)
fig.savefig(os.path.join(output_filepath, 'early-late_durations.png'), transparent=True)
plt.close(fig)

In [None]:
# Bar plot of the proportion of trials per trajectory, early vs later sessions.
colour = ['#bf812d', '#35978f']

# Explicit category orders so the hand-written tick labels and the list
# palette always match the data, even if a category has no rows.
trajectory_order = ['u', 'shortcut', 'novel']
time_order = ['early 1-4', 'later 5-8']

fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(x="trajectory", y="value", hue="time", data=df_trials,
            order=trajectory_order, hue_order=time_order, palette=colour, ax=ax)
ax.set(xticklabels=['U', 'Shortcut', 'Novel'])
ax.set_ylabel('Proportion of trials')
ax.set_xlabel('(sessions=' + str(total_n_sessions) + ')')
sns.despine(ax=ax, left=False)

# savefig fails if the target directory does not exist yet.
if not os.path.isdir(output_filepath):
    os.makedirs(output_filepath)
fig.savefig(os.path.join(output_filepath, 'early-late_proportions.png'), transparent=True)
plt.close(fig)