In [None]:
import os
import numpy as np

import nept
from load_data import assign_occset_label
from plotting import plot_behavior, plot_duration

In [None]:
# Project root. Defaults to E:\code\emi_biconditional; set the EMI_BICONDITIONAL_DIR
# environment variable to run the notebook from a different checkout.
thisdir = os.environ.get('EMI_BICONDITIONAL_DIR', 'E:\\code\\emi_biconditional')
# MedPC session files are read from here; figures are written under output_filepath.
data_filepath = os.path.join(thisdir, 'cache', 'data', 'spring2017')
output_filepath = os.path.join(thisdir, 'plots', 'fall2017')

In [None]:
# Session ID (file name under data_filepath), presumably of the magazine-training session.
# NOTE(review): not referenced by any other cell in this notebook — TODO confirm it is still needed.
magazine_session = '!2017-04-14'

In [None]:
# MedPC session files (under data_filepath) to load, in chronological order.
sessions = ['!2017-05-30', '!2017-05-31', '!2017-06-01']

rats = ['R141', 'R142', 'R143', 'R144', 'R145', 'R146', 'R147', 'R148']
males = ['R141', 'R143', 'R145', 'R147']
females = ['R142', 'R144', 'R146', 'R148']
group1 = ['R141', 'R144', 'R146', 'R147']
group2 = ['R142', 'R143', 'R145', 'R148']

# Per-rat group number aligned index-for-index with `rats`. It is derived from
# group1/group2 so the two encodings of group membership cannot drift apart.
groups = [1 if rat in group1 else 2 for rat in rats]

# Every rat must belong to exactly one sex and to exactly one group.
assert sorted(males + females) == sorted(rats), 'males/females do not partition rats'
assert sorted(group1 + group2) == sorted(rats), 'group1/group2 do not partition rats'

In [None]:
cue_duration = 10.

# One nept.Rat accumulator per animal; every session loaded below is appended to it.
data = {rat: nept.Rat(rat, group1, group2) for rat in rats}


def _adjacent_windows(trial_epochs, before):
    """Build a 2-row array [starts; stops] of cue_duration-long windows beside trial epochs.

    Parameters
    ----------
    trial_epochs : list
        Epoch-like objects exposing ``.starts`` / ``.stops`` arrays. Windows from all of
        them are concatenated in list order.
    before : bool
        True  -> each window ends at an epoch start (the cue_duration preceding it).
        False -> each window begins at an epoch stop (the cue_duration following it).

    Returns
    -------
    numpy.ndarray
        ``np.vstack([window_starts, window_stops])``.
    """
    if before:
        window_starts = [t for epoch in trial_epochs for t in epoch.starts - cue_duration]
        window_stops = [t for epoch in trial_epochs for t in epoch.starts]
    else:
        window_starts = [t for epoch in trial_epochs for t in epoch.stops]
        window_stops = [t for epoch in trial_epochs for t in epoch.stops + cue_duration]
    return np.vstack([window_starts, window_stops])


# Epoch keys forwarded to add_long_feature_session (the derived pre/post windows are not).
feature_keys = ['mags', 'pellets', 'lights1', 'lights2', 'sounds1', 'trial1', 'trial2']

for session in sessions:
    # NOTE: rats_data is rebound every iteration, so after this loop it holds only sessions[-1].
    rats_data = nept.load_medpc(os.path.join(data_filepath, session), assign_occset_label)

    for rat in rats:
        rat_data = rats_data[rat]
        # Baseline windows just before every trial (both trial types).
        rat_data['pre_cs'] = nept.Epoch(
            _adjacent_windows([rat_data['trial1'], rat_data['trial2']], before=True))
        # Windows just after each trial: trial2 is labelled rewarded, trial1 unrewarded.
        rat_data['post_rewarded'] = nept.Epoch(
            _adjacent_windows([rat_data['trial2']], before=False))
        rat_data['post_unrewarded'] = nept.Epoch(
            _adjacent_windows([rat_data['trial1']], before=False))

    for rat, group in zip(rats, groups):
        features = {key: rats_data[rat][key] for key in feature_keys}
        data[rat].add_long_feature_session(group=group, **features)

In [None]:
# combine_rats takes a single session count for all rats. Check that every rat really
# has the same number of sessions before trusting rats[0]'s count.
session_counts = {rat: len(data[rat].sessions) for rat in rats}
n_sessions = session_counts[rats[0]]
assert len(set(session_counts.values())) == 1, \
    'rats have different session counts: {}'.format(session_counts)
print('n_sessions:', n_sessions)

In [None]:
# Combined table for all rats and sessions (presumably long-format: one 'measure'/'value'
# pair per row, as plot_behavior expects); preview the first 20 rows.
df = nept.combine_rats(data, rats, n_sessions)
df.head(20)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FuncFormatter

# seaborn.set_style replaces the whole axes style, so a preceding set_style("white")
# would have no lasting effect; "ticks" alone gives the intended look.
sns.set_style("ticks")

def _shaded_spans(change_sessions, last_session):
    """Turn a flat list of session boundaries into (start, end) spans to shade.

    Boundaries are consumed in consecutive pairs, so (a, b) becomes (a, b - 1). A trailing
    unpaired boundary c becomes (c, last_session).
    Examples: [5] -> [(5, last)], [5, 9] -> [(5, 8)], [5, 9, 12] -> [(5, 8), (12, last)].
    """
    spans = [(start, stop - 1)
             for start, stop in zip(change_sessions[0::2], change_sessions[1::2])]
    if len(change_sessions) % 2 == 1:
        spans.append((change_sessions[-1], last_session))
    return spans


def plot_behavior(df, rats, filepath=None, only_sound=False, by_outcome=False, change_sessions=None, xlim=None):
    """Plot per-session behaviour for a subset of rats, one facet per measure.

    NOTE: this definition shadows the ``plot_behavior`` imported from ``plotting`` in
    the first cell.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'rat', 'measure', 'session', 'trial', 'value', plus
        'condition' (default) or 'rewarded' (when ``by_outcome``).
    rats : list of str
        Rat IDs to include.
    filepath : str, optional
        If given, save the figure there and close it; otherwise show it.
    only_sound : bool
        Use the shorter colour palette. Ignored when ``by_outcome`` is set.
    by_outcome : bool
        Draw one line per value of 'rewarded' instead of per 'condition'.
    change_sessions : list of int, optional
        Session boundaries to shade grey. Consecutive pairs (a, b) shade a..b-1. A
        trailing unpaired boundary shades from there to the last session.
    xlim : tuple, optional
        x-axis limits applied to the labelled facets.
    """
    if change_sessions is None:
        change_sessions = []

    rats_df = df[df['rat'].isin(rats)]

    if by_outcome:
        # The outcome palette takes precedence over only_sound.
        colours = ["#9970ab", "#d6604d", "#1b7837", "#2166ac", 'k', '#fe9929', '#f768a1']
        condition = "rewarded"
    elif only_sound:
        colours = ["#4393c3", "#b2182b", "#d6604d", "#2166ac", 'k', '#fe9929', '#f768a1']
        condition = "condition"
    else:
        colours = ["#9970ab", "#4393c3", "#762a83", "#b2182b", "#5aae61",
                   "#d6604d", "#1b7837", "#2166ac", 'k', '#fe9929', '#f768a1']
        condition = "condition"

    # NOTE: `size=` and `sns.tsplot` belong to the old seaborn API (tsplot was later
    # removed and `size` renamed to `height`); this requires an old seaborn release.
    g = sns.FacetGrid(data=rats_df, col="measure", sharey=False, size=3, aspect=1.)
    g.map_dataframe(sns.tsplot, time="session", unit="trial", condition=condition, value="value",
                    err_style="ci_band", ci=68, color=colours)
    g.set_axis_labels("Session", "Value")

    # Sessions are whole numbers. Set integer tick labels on every facet explicitly
    # rather than on whichever axes pyplot considers "current".
    integer_ticks = FuncFormatter(lambda x, _: '{:d}'.format(int(x)))
    for ax in g.axes.flat:
        ax.xaxis.set_major_formatter(integer_ticks)

    spans = _shaded_spans(change_sessions, rats_df['session'].max())
    # NOTE(review): labels are assigned to facets in order; assumes the 'measure'
    # facet order matches this list — TODO confirm.
    measure_labels = ["Duration in food cup (s)",
                      "# of entries",
                      "Latency to first entry (s)",
                      "Percent responses"]
    for ax, label in zip(g.axes[0], measure_labels):
        ax.set_title("")
        ax.set_ylabel(label)
        for start, end in spans:
            ax.axvspan(start, end, color='#cccccc', alpha=0.3)
        if xlim is not None:
            ax.set_xlim(xlim)

    plt.tight_layout()
    # The legend goes on pyplot's current axes (the last facet drawn), anchored at its top-right.
    plt.legend(bbox_to_anchor=(1., 1.))
    if filepath is not None:
        plt.savefig(filepath, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

In [None]:
# Set to True to write each figure to output_filepath instead of displaying it inline.
save_plots = False

for rat in ['R142']:
    filepath = os.path.join(output_filepath, rat + '_outcome_behavior.png') if save_plots else None
    plot_behavior(df, [rat], filepath=filepath, by_outcome=True)

In [None]:
# R142's trial2 epochs. rats_data is rebound on every pass of the loading loop, so this
# is from the last session loaded (sessions[-1]) only.
trial2 = rats_data['R142']['trial2']

In [None]:
# R142's lights1 epochs (last loaded session only, since rats_data is rebound per session).
lights1 = rats_data['R142']['lights1']

In [None]:
# R142's lights2 epochs (last loaded session only, since rats_data is rebound per session).
lights2 = rats_data['R142']['lights2']

In [None]:
# R142's sounds1 epochs (last loaded session only, since rats_data is rebound per session).
sounds1 = rats_data['R142']['sounds1']

In [None]:
# R142's magazine ('mags') epochs (last loaded session only, since rats_data is rebound
# per session). Used by the intersection cell below.
mags = rats_data['R142']['mags']

In [None]:
# R142's magazine epochs restricted to the overlap of each trial type with each cue.
# Each combination is computed exactly once.
r142_data = rats_data['R142']


def _mags_during(trial_key, cue_key):
    """Return `mags` intersected with the overlap of R142's trial epoch and cue epoch."""
    return mags.intersect(r142_data[trial_key].intersect(r142_data[cue_key]))


trial1light1 = _mags_during('trial1', 'lights1')
trial1light2 = _mags_during('trial1', 'lights2')
trial1sound1 = _mags_during('trial1', 'sounds1')
trial2light1 = _mags_during('trial2', 'lights1')
trial2light2 = _mags_during('trial2', 'lights2')
trial2sound1 = _mags_during('trial2', 'sounds1')

In [None]:
# FIXME(review): `trialsound` is not defined in any cell of this notebook, so this raises
# NameError on Restart & Run All. Presumably one of the trial*sound1 epochs computed above
# (e.g. trial2sound1) was intended — confirm and rename.
trialsound.durations

In [None]:
# FIXME(review): `trialsound` is not defined in any cell of this notebook, so this raises
# NameError on Restart & Run All — confirm which trial*sound1 epoch was intended.
trialsound.starts

In [None]:
# FIXME(review): `yy` is not defined in any cell of this notebook, so this raises
# NameError on Restart & Run All — leftover from interactive exploration? Confirm or remove.
yy.durations

In [None]:
# Overlay epoch start/stop times on a single row (y = 1), presumably to check by eye how
# trial2 lines up with the light/sound epochs. The view is zoomed to x in [2000, 2200].
# FIXME(review): `triallight` and `trialsound` are not defined in any cell of this
# notebook, so this cell raises NameError on Restart & Run All. Presumably two of the
# trial*light*/trial*sound1 epochs computed above were meant — confirm.
plt.plot(trial2.starts, np.ones(len(trial2.starts)), 'o', ms=10, color='b')
plt.plot(trial2.stops, np.ones(len(trial2.stops)), 'o', ms=10, color='b')

# NOTE(review): triallight starts use ms=6 but its stops use ms=10 — confirm intentional.
plt.plot(triallight.starts, np.ones(len(triallight.starts)), 'o', ms=6, color='m')
plt.plot(triallight.stops, np.ones(len(triallight.stops)), 'o', ms=10, color='m')

plt.plot(trialsound.starts, np.ones(len(trialsound.starts)), 'o', ms=5, color='c')
plt.plot(trialsound.stops, np.ones(len(trialsound.stops)), 'o', ms=5, color='c')

# plt.plot(mags.starts, np.zeros(len(mags.starts)), '.')

plt.xlim(2000, 2200)
plt.show()

In [None]:
# Print `.durations` for each item yielded by iterating `yy`.
# FIXME(review): `yy` is not defined in any cell of this notebook, so this raises
# NameError on Restart & Run All. Presumably it was an nept.Epoch — confirm what it should refer to.
for trial in yy:
    print(trial.durations)