In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

import os
import nept
from core import Rat, combine_rats
from load_data import assign_label
from plotting import plot_behavior, plot_duration

In [None]:
# Project root. Override with the EMI_BICONDITIONAL_DIR environment variable to run the
# notebook outside the original workstation. The raw string keeps the Windows backslashes
# literal instead of relying on "\c"/"\e" being invalid (warned-about) escape sequences.
thisdir = os.environ.get('EMI_BICONDITIONAL_DIR', r'E:\code\emi_biconditional')
data_filepath = os.path.join(thisdir, 'cache', 'data', 'spring2017')
output_filepath = os.path.join(thisdir, 'plots', 'spring2017')

In [None]:
# Session files are the "!"-prefixed entries in data_filepath; the magazine-training
# session is excluded. Entries are stored as full paths.
magazine_session = '!2017-04-14'

sessions = [
    os.path.join(data_filepath, filename)
    for filename in sorted(os.listdir(data_filepath))
    if filename.startswith('!') and filename != magazine_session
]

rats = ['R141', 'R142', 'R143', 'R144', 'R145', 'R146', 'R147', 'R148']
males = ['R141', 'R143', 'R145', 'R147']
females = ['R142', 'R144', 'R146', 'R148']
group1 = ['R141', 'R144', 'R146', 'R147']
group2 = ['R142', 'R143', 'R145', 'R148']

# Per-rat group number, aligned with `rats`. Derived from group1/group2 so the two
# encodings of the group assignment cannot drift apart.
assert sorted(group1 + group2) == sorted(rats), "group1/group2 must partition rats"
groups = [1 if rat_id in group1 else 2 for rat_id in rats]

In [None]:
# Length of the window before each trial start that is labelled "iti" below; it reuses
# the cue duration (same time units as the MedPC epochs).
cue_duration = 10

data = {rat: Rat(rat, group1, group2) for rat in rats}

for session in sessions:
    # `sessions` already holds full paths, so it is passed directly. Joining it onto
    # data_filepath again would double the directory whenever data_filepath is relative.
    rats_data = nept.load_medpc(session, assign_label)

    for rat in rats:
        # "iti" epoch: the cue_duration-long window ending at each trial1..trial4 start.
        trial_starts = np.concatenate(
            [rats_data[rat][trial].starts for trial in ('trial1', 'trial2', 'trial3', 'trial4')])
        rats_data[rat]['iti'] = nept.Epoch(np.vstack([trial_starts - cue_duration, trial_starts]))

    for rat, group in zip(rats, groups):
        data[rat].add_session(**rats_data[rat], group=group)

n_sessions = len(data[rats[0]].sessions)

df = combine_rats(data, rats, n_sessions)

In [None]:
# Duration plot for rat R142 split by outcome. The output path is built from `thisdir`
# rather than a second hard-coded absolute path, so it follows the configured root.
plot_duration(df, ['R142'], filepath=os.path.join(thisdir, 'plots', 'zzz.png'), by_outcome=True)

In [None]:
# Preview the last rows of the combined spring-2017 table returned by combine_rats.
df.tail()

In [None]:
import measurements as m
from core import Experiment, Rat, TrialEpoch


# January 2017 cohort. Epoch specs are handed to core.TrialEpoch; presumably
# start_idx/stop_idx select MedPC event indices and a negative `duration` means a window
# ending at the start event (TODO confirm against core.TrialEpoch).
# NOTE(review): "baseline" is defined twice (before light1 and before light2) — confirm
# that merging both under one name is intended.
expt = Experiment(
    name="201701",
    trial_epochs=[
        TrialEpoch("mags", start_idx=1, stop_idx=2),
        TrialEpoch("baseline", start_idx=4, duration=-10),
        TrialEpoch("baseline", start_idx=6, duration=-10),
        TrialEpoch("light1", start_idx=4, stop_idx=5),
        TrialEpoch("light2", start_idx=6, stop_idx=7),
        TrialEpoch("sound1", start_idx=8, stop_idx=9),
        TrialEpoch("sound2", start_idx=10, stop_idx=11),
        TrialEpoch("trial1", start_idx=12, stop_idx=13),
        TrialEpoch("trial2", start_idx=14, stop_idx=15),
        TrialEpoch("trial3", start_idx=16, stop_idx=17),
        TrialEpoch("trial4", start_idx=18, stop_idx=19),
    ],
    measurements=[m.Duration(), m.Count(), m.Latency(), m.AtLeastOne()],
    rats=[Rat(rat_id, group=group_label) for rat_id, group_label in (
        ('R114', "1"),
        ('R116', "1"),
        ('R117', "2"),
        ('R118', "1"),
        ('R119', "2"),
        ('R120', "1"),
        ('R121', "2"),
    )],
    magazine_session='!2017-01-17',
)


# Cue/trial pairings of the 201701 biconditional design, in the order they are added.
# Group 2 is group 1 with light1 and light2 swapped; the sound pairings are shared.
_BICONDITIONAL_DESIGN = {
    "1": (("light1", "trial1"), ("sound2", "trial1"),
          ("light1", "trial2"), ("sound1", "trial2"),
          ("light2", "trial3"), ("sound1", "trial3"),
          ("light2", "trial4"), ("sound2", "trial4")),
    "2": (("light2", "trial1"), ("sound2", "trial1"),
          ("light2", "trial2"), ("sound1", "trial2"),
          ("light1", "trial3"), ("sound1", "trial3"),
          ("light1", "trial4"), ("sound2", "trial4")),
}


def add_datapoints(session, data, rat):
    """Register one rat's cue epochs with `session` for the 201701 experiment.

    For each (cue, trial) pair of the rat's group, the intersection of the trial epoch
    with the cue epoch is added together with metadata. The "baseline" epoch is then
    added as-is. Rats whose group is neither "1" nor "2" add nothing.

    Parameters
    ----------
    session : object
        Must provide ``add_data(rat_id, epoch, meta)``.
    data : mapping
        Epoch name ("light1", "trial1", "baseline", ...) -> epoch supporting ``intersect``.
    rat : object
        Must provide ``rat_id`` and ``group`` attributes.
    """
    pairs = _BICONDITIONAL_DESIGN.get(rat.group)
    if pairs is None:
        return

    for cue, trial in pairs:
        trial_type = trial[-1]
        meta = {
            "cue_type": cue[:-1],  # e.g. "light1" -> "light"
            "trial_type": trial_type,
            # trial types 2 and 4 are labelled rewarded
            "rewarded": "rewarded" if trial_type in ("2", "4") else "unrewarded",
            "cue": cue,
        }
        session.add_data(rat.rat_id, data[trial].intersect(data[cue]), meta)

    baseline_meta = {
        "cue_type": "baseline",
        "trial_type": "",
        "rewarded": "",
        "cue": "baseline",
    }
    session.add_data(rat.rat_id, data["baseline"], baseline_meta)


# Attach the 201701 cue/trial extraction routine to the experiment, then run plot_all().
setattr(expt, "add_datapoints", add_datapoints)
expt.plot_all()

In [None]:
# NOTE(review): trial_epochs was passed to Experiment as a list; indexing it by name only
# works if Experiment converts it to a name-keyed mapping. The next cell indexes it by 0,
# so confirm which form Experiment exposes.
expt.trial_epochs['mags']

In [None]:
# Inspect the first configured trial epoch. A trailing "." here is a SyntaxError that
# aborts Restart & Run All.
expt.trial_epochs[0]

In [None]:
# Preview the first rows of the experiment's results table (presumably a pandas DataFrame).
expt.df.head()

In [None]:
# Scratch list: the odd numbers 1..17 followed by 18.
b = list(range(1, 18, 2)) + [18]
len(b)

In [None]:
# Scratch list built from runs of repeated values.
a = [5] * 2 + [6] + [10] * 2 + [15] * 2 + [22] * 5
len(a)

In [None]:
# Toy epoch for the binarization experiments below. Following the np.vstack([starts, stops])
# layout used for the ITI epochs earlier, the first row is presumably the starts and the
# second the stops, i.e. intervals (1, 3), (5, 7), (8, 10) — confirm against nept.Epoch.
import nept
mag = nept.Epoch([[1, 5, 8], [3, 7, 10]])

In [None]:
# Inspect the interval starts of the toy epoch (expected [1, 5, 8] if rows are starts/stops).
mag.starts

In [None]:
# Inspect the interval stops of the toy epoch (expected [3, 7, 10] if rows are starts/stops).
mag.stops

In [None]:
def convert_binary(epoch, binsize, gaussian_std=None):
    """Sample an epoch onto a regular time grid as a 0/1 occupancy signal.

    Parameters
    ----------
    epoch : nept.Epoch-like
        Must provide ``start``, ``stop`` and ``contains(t)``.
    binsize : float
        Spacing of the sample grid, in the epoch's time units.
    gaussian_std : float or None
        If given, the binary signal is smoothed with ``nept.gaussian_filter`` using this
        standard deviation (and ``dt=binsize``).

    Returns
    -------
    np.ndarray
        Float array with one value per grid point from ``epoch.start`` up to (approximately,
        because of floating-point ``arange``) ``epoch.stop`` inclusive. Each value is 1.0
        where ``epoch.contains`` is true and 0.0 elsewhere, before optional smoothing.
    """
    # stop + binsize so that epoch.stop itself is on the grid.
    timeline = np.arange(epoch.start, epoch.stop + binsize, binsize)
    binary = np.array([1. if epoch.contains(t) else 0. for t in timeline], dtype=float)
    if gaussian_std is not None:
        binary = nept.gaussian_filter(binary, gaussian_std, dt=binsize)
    return binary

In [None]:
# Same grid convert_binary builds for binsize=0.5 (stop included), so `timeline` lines up
# element-for-element with `mags` below and can serve as its x-axis.
start = mag.start
stop = mag.stop
timeline = np.arange(start, stop + 0.5, 0.5)


In [None]:
# Binarize the toy epoch on a 0.5-unit grid without smoothing (gaussian_std defaults to None).
mags = convert_binary(mag, 0.5)

In [None]:
# Binary occupancy of the toy epoch `mag`, as produced by convert_binary(binsize=0.5).
fig, ax = plt.subplots()
ax.plot(mags)
ax.set(xlabel="Bin index (binsize = 0.5)", ylabel="Inside epoch (0/1)",
       title="convert_binary(mag)")
plt.show()

In [None]:
# Display the binarized signal produced by convert_binary above.
mags

In [None]:
# Second toy epoch; overtime() below builds its grid from yy.start to yy.stop.
# Presumably intervals (1, 2), (5, 7), (8, 10) if rows are starts/stops — confirm.
yy = nept.Epoch([[1, 5, 8], [2, 7, 10]])

In [None]:
def overtime(binary_mags, epoch, binsize=1.):
    """Copy the 1-valued entries of `binary_mags` onto a grid spanning `epoch`.

    The grid is ``np.arange(epoch.start, epoch.stop, binsize)`` (stop excluded). Entry ``i``
    of the result is 1.0 when ``binary_mags[i] == 1`` and 0.0 otherwise.

    NOTE(review): `binary_mags` is indexed positionally and is not resampled. If it was
    built with a different binsize, position ``i`` does not refer to the same time in both
    arrays. It must be at least as long as the grid.

    Parameters
    ----------
    binary_mags : sequence of numbers
        0/1 signal, e.g. the output of ``convert_binary``.
    epoch : object
        Must provide ``start`` and ``stop``.
    binsize : float
        Grid spacing.

    Returns
    -------
    np.ndarray
        Float array of length ``len(np.arange(epoch.start, epoch.stop, binsize))``.
    """
    timeline = np.arange(epoch.start, epoch.stop, binsize)
    new = np.zeros(len(timeline))
    for i in range(len(timeline)):
        if binary_mags[i] == 1:
            new[i] = 1
    return new

In [None]:
# Apply overtime() to the binarized `mags` signal on the 1-unit grid spanned by `yy`.
whatamidoing = overtime(mags, yy, 1.0)

In [None]:
# Display the result of overtime() on the toy epochs.
whatamidoing

In [None]:
# Result of overtime(mags, yy) on yy's 1-unit grid.
fig, ax = plt.subplots()
ax.plot(whatamidoing)
ax.set(xlabel="Bin index (binsize = 1)", ylabel="Inside epoch (0/1)",
       title="overtime(mags, yy)")
plt.show()

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

import os
import numpy as np
import pandas as pd
import seaborn as sns

import nept

from plotting import add_col

import measurements as m
from core import Experiment, Rat, TrialEpoch

In [None]:
# September 2017 cohort, binned analysis: only the magazine and the two light cues are
# defined. `sessionfiles` presumably restricts the analysis to that one session (confirm).
binned_expt = Experiment(
    name="201709",
    cache_key="binned",
    trial_epochs=[
        TrialEpoch("mags", start_idx=1, stop_idx=2),
        TrialEpoch("light1", start_idx=4, stop_idx=5),
        TrialEpoch("light2", start_idx=6, stop_idx=7),
    ],
    measurements=[m.Duration()],
    rats=[Rat(rat_id, group=group_label) for rat_id, group_label in (
        ('R155', "1"),
        ('R156', "2"),
        ('R157', "2"),
        ('R158', "1"),
        ('R159', "2"),
        ('R160', "1"),
        ('R161', "1"),
        ('R162', "2"),
    )],
    magazine_session='!2017-09-20',
    sessionfiles=['!2017-09-21'],
)


def add_datapoints(session, data, rat, binsize=5, cues=("light1", "light2")):
    """Add time-binned data for each cue epoch of one rat (201709 binned experiment).

    NOTE(review): this rebinds the notebook-level name ``add_datapoints`` that the 201701
    experiment cell also uses. Re-running that earlier cell after this one would attach
    this version instead.

    Parameters
    ----------
    session : object
        Must provide ``add_binned_data(rat_id, epoch, binsize=..., info=...)``.
    data : mapping
        Epoch name -> epoch.
    rat : object
        Must provide a ``rat_id`` attribute.
    binsize : number
        Bin width forwarded to ``add_binned_data`` (default 5).
    cues : iterable of str
        Epoch names to add, in order (default light1 then light2).
    """
    for cue in cues:
        session.add_binned_data(rat.rat_id, data[cue], binsize=binsize, info={'cue': cue})

# Attach the binned extraction routine to the 201709 experiment; the analysis itself
# runs in the next cell.
binned_expt.add_datapoints = add_datapoints

In [None]:
binned_raw = binned_expt.analyze()

# (duration, time_start) pairs of rows to discard. Presumably the trailing partial bin of
# trials whose duration is not a multiple of the 5-unit bin (TODO confirm).
TRAILING_BINS = ((155, 150), (245, 240), (275, 270))

# A positional boolean mask (rather than .drop on index labels) removes exactly the
# matching rows, even if analyze() returns a frame with repeated index labels.
drop_mask = np.zeros(len(binned_raw), dtype=bool)
for bin_duration, bin_start in TRAILING_BINS:
    drop_mask = drop_mask | ((binned_raw['duration'].values == bin_duration) &
                             (binned_raw['time_start'].values == bin_start))

# .copy() so later column assignments on binned_df do not hit SettingWithCopy.
binned_df = binned_raw[~drop_mask].copy()

In [None]:
# Relabel durations 155/245/275 as 150/240/270.
# NOTE(review): presumably the start of the last kept bin — confirm this labelling.
binned_df['duration'] = binned_df['duration'].replace({155: 150, 245: 240, 275: 270})

In [None]:
def plot_overtime(df, rats, filepath=None):
    """Facet-plot binned values over time, one panel per ``duration`` value.

    Parameters
    ----------
    df : pandas.DataFrame
        Binned data; the columns used here are ``rat``, ``trial`` (via ``add_col``),
        ``duration``, ``time_start``, ``cue`` and ``value``.
    rats : iterable
        Objects with a ``rat_id`` attribute; only rows whose ``rat`` matches one of these
        ids are plotted.
    filepath : str or None
        If given, save the figure there (creating parent directories) and close it.
        Otherwise show it.
    """
    rat_ids = [rat.rat_id for rat in rats]
    # Plot only the requested rats. .copy() because add_col may add a column to its input.
    rats_df = df[df['rat'].isin(rat_ids)].copy()

    # Presumably builds a per-(rat, trial) "unit" column for tsplot (see plotting.add_col).
    rats_df = add_col(rats_df, "unit", "rat", "trial")
    g = sns.FacetGrid(data=rats_df, col="duration", sharey=False, size=3, aspect=1.)
    g.map_dataframe(sns.tsplot, time="time_start", unit="unit", condition="cue", value="value",
                    err_style="ci_band", ci=68, color="deep")

    for ax in g.axes[0]:
        ax.set_ylabel("Duration in food cup (s)")

    # Legend entries come from the last facet, sorted by label.
    handles, labels = g.axes[0][-1].get_legend_handles_labels()
    sortedhl = sorted(zip(handles, labels), key=lambda x: x[1])
    plt.legend(*zip(*sortedhl), bbox_to_anchor=(1., 1.))

    plt.tight_layout()
    if filepath is not None:
        dirname = os.path.dirname(filepath)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        plt.savefig(filepath, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

In [None]:
# Plot every rat of the binned experiment.
plot_overtime(binned_df, binned_expt.rats)

In [None]:
# Distinct trial durations in the cleaned binned data. `df` is the spring-2017 frame from
# combine_rats, not the binned analysis.
binned_df.duration.unique()

In [None]:
# Sanity check on the cleaned binned data: no (duration 155, time_start 150) rows remain.
leftover = binned_df[(binned_df.duration == 155) & (binned_df.time_start == 150)]
assert leftover.empty, leftover
leftover

In [None]:
# Sorted names of the colormap specs in plt.cm.datad, skipping the reversed "_r" variants.
maps = sorted(cmap_name for cmap_name in plt.cm.datad if not cmap_name.endswith("_r"))

In [None]:
# Display the "tab20" colormap.
# NOTE(review): this runs before the cell below that defines tab20 by hand; on a
# matplotlib without a built-in tab20 this lookup presumably fails.
plt.get_cmap('tab20')

In [None]:
import matplotlib.cm
from matplotlib.colors import ListedColormap

# The 20-colour "tab20" palette as RGB triples in [0, 1] (hex codes alongside).
_tab20_data = (
    (0.12156862745098039, 0.4666666666666667,  0.7058823529411765  ),  # 1f77b4
    (0.6823529411764706,  0.7803921568627451,  0.9098039215686274  ),  # aec7e8
    (1.0,                 0.4980392156862745,  0.054901960784313725),  # ff7f0e
    (1.0,                 0.7333333333333333,  0.47058823529411764 ),  # ffbb78
    (0.17254901960784313, 0.6274509803921569,  0.17254901960784313 ),  # 2ca02c
    (0.596078431372549,   0.8745098039215686,  0.5411764705882353  ),  # 98df8a
    (0.8392156862745098,  0.15294117647058825, 0.1568627450980392  ),  # d62728
    (1.0,                 0.596078431372549,   0.5882352941176471  ),  # ff9896
    (0.5803921568627451,  0.403921568627451,   0.7411764705882353  ),  # 9467bd
    (0.7725490196078432,  0.6901960784313725,  0.8352941176470589  ),  # c5b0d5
    (0.5490196078431373,  0.33725490196078434, 0.29411764705882354 ),  # 8c564b
    (0.7686274509803922,  0.611764705882353,   0.5803921568627451  ),  # c49c94
    (0.8901960784313725,  0.4666666666666667,  0.7607843137254902  ),  # e377c2
    (0.9686274509803922,  0.7137254901960784,  0.8235294117647058  ),  # f7b6d2
    (0.4980392156862745,  0.4980392156862745,  0.4980392156862745  ),  # 7f7f7f
    (0.7803921568627451,  0.7803921568627451,  0.7803921568627451  ),  # c7c7c7
    (0.7372549019607844,  0.7411764705882353,  0.13333333333333333 ),  # bcbd22
    (0.8588235294117647,  0.8588235294117647,  0.5529411764705883  ),  # dbdb8d
    (0.09019607843137255, 0.7450980392156863,  0.8117647058823529  ),  # 17becf
    (0.6196078431372549,  0.8549019607843137,  0.8980392156862745  ),  # 9edae5
)

# Mirror matplotlib's own datad layout for listed colormaps.
plt.cm.datad['tab20'] = {'listed': _tab20_data}

# Adding to datad alone presumably does not make the map reachable through plt.get_cmap:
# matplotlib builds its colormap registry from datad at import time (verify for the
# installed version). Register it explicitly when this matplotlib lacks a built-in tab20.
if 'tab20' not in plt.colormaps():
    matplotlib.cm.register_cmap(cmap=ListedColormap(_tab20_data, name='tab20'))