In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import math
import numpy as np
import scipy
import os
import nept

In [None]:
# Figures produced by this notebook are written to ./plots/experience.
thisdir = os.getcwd()
output_filepath = os.path.join(thisdir, "plots", "experience")
# exist_ok avoids the check-then-create race and keeps the cell idempotent.
os.makedirs(output_filepath, exist_ok=True)

In [None]:
import info.r191_exp05 as info

In [None]:
# Root of the raw data tree. Defaults to E:\data\data-experience; set the
# EXPERIENCE_DATA_ROOT environment variable to use another location without
# editing the notebook.
data_root = os.environ.get("EXPERIENCE_DATA_ROOT", os.path.join("E:\\", "data", "data-experience"))
data_filepath = os.path.join(data_root, info.rat_id, "RR1", info.rat_id+"_"+info.date+"_recording")

In [None]:
def get_exp_trials(events, task_times=None, run_phases=("run1", "run2")):
    """Build an Epoch of trials from the session's trial events.

    The start time of each run phase is added to the recorded "trial_start"
    events (presumably because the first trial of a run has no trial_start
    event of its own -- TODO confirm), and starts/stops are paired in
    time order.

    Parameters
    ----------
    events : dict
        Event times keyed by label; must contain "trial_start" and
        "trial_end".
    task_times : dict, optional
        Maps phase name -> object with a ``.start`` attribute. Defaults to
        ``info.task_times`` for the current session.
    run_phases : sequence of str, optional
        Phases whose start times are added as trial starts.

    Returns
    -------
    nept.Epoch
        One epoch per trial.

    Raises
    ------
    ValueError
        If the number of trial starts and trial ends differ.
    """
    if task_times is None:
        task_times = info.task_times

    run_starts = [task_times[phase].start for phase in run_phases]
    starts = np.sort(np.append(events["trial_start"], run_starts))
    # Sort stops as well, so that index i of starts and stops refers to the
    # same trial.
    stops = np.sort(events["trial_end"])

    if len(starts) != len(stops):
        raise ValueError(
            "Got {} trial starts but {} trial ends".format(len(starts), len(stops)))

    return nept.Epoch(starts, stops)

In [None]:
# Load the session's event file (.nev), keeping only the labels listed in
# info.event_labels.
event_filename = "{}_Events.nev".format(info.date)
event_path = os.path.join(data_filepath, event_filename)
events = nept.load_events(event_path, info.event_labels)

In [None]:
# Trial epochs for the run phases, built from the loaded events by get_exp_trials.
trials = get_exp_trials(events)

In [None]:
# Display the per-trial durations as a quick sanity check.
trials.durations

In [None]:
# Read the hand-recorded trial labels, one per line, e.g. "North +" / "North -".
txt_filepath = os.path.join(data_filepath, info.date+"_experience.txt")
# The `with` block closes the file on exit, so no explicit close() is needed.
with open(txt_filepath) as f:
    trial_types = f.read().splitlines()

In [None]:
# Keep only the trial labels: drop the first 5 lines and the line at index 41.
# NOTE(review): presumably a header block and one spurious entry for this
# session -- confirm against the raw text file.
# Not idempotent: re-running this cell slices the already-sliced labels again.
first_kept_line = 5
dropped_line = 41
trial_types = np.concatenate([
    trial_types[first_kept_line:dropped_line],
    trial_types[dropped_line + 1:],
])

In [None]:
# Number of trial labels; must equal trials.n_epochs for labels and
# durations to be paired index-by-index later on.
len(trial_types)

In [None]:
# Number of trial epochs built from the event file.
trials.n_epochs

In [None]:
# Pair each trial's label with its duration. Labels come from the text log
# and durations from the recorded events, so check that they line up first:
# a count mismatch would silently mislabel trials.
if len(trial_types) != trials.n_epochs:
    raise ValueError(
        "Found {} trial labels but {} trial epochs".format(len(trial_types), trials.n_epochs))
letssee = list(zip(trial_types, trials.durations))

In [None]:
# Display the (label, duration) pairs.
letssee

In [None]:
# Split trial durations by arm and outcome: "<Arm> +" labels go to rewarded,
# "<Arm> -" labels go to unrewarded, and any other label is ignored.
arms = ["North", "South", "East", "West"]

# Build both dicts from `arms` so they always have the same keys.
rewarded = {arm: [] for arm in arms}
unrewarded = {arm: [] for arm in arms}

# Map each recognised label straight to the list it belongs in, instead of
# comparing every trial against every arm.
label_to_bucket = {}
for arm in arms:
    label_to_bucket[arm + " +"] = rewarded[arm]
    label_to_bucket[arm + " -"] = unrewarded[arm]

for label, duration in letssee:
    bucket = label_to_bucket.get(label)
    if bucket is not None:
        bucket.append(duration)

In [None]:
# Mean North-arm trial duration for rewarded vs. unrewarded trials
# (np.mean of an empty list gives nan with a RuntimeWarning).
np.mean(rewarded["North"]), np.mean(unrewarded["North"])

In [None]:
# All North-arm trial durations: rewarded trials first, then unrewarded.
north = list(rewarded["North"])
north.extend(unrewarded["North"])
print(north)

In [None]:
# Mean trial duration per arm, pooling rewarded and unrewarded trials.
# An arm with no trials gives np.mean([]) -> nan (with a RuntimeWarning),
# which leaves that bar empty.
latencies = {arm: np.mean(rewarded[arm] + unrewarded[arm])
             for arm in ["North", "South", "East", "West"]}
print(latencies)

# Bar order, bar colours, and the outcome label shown under each arm.
# NOTE(review): the arm -> outcome mapping is hard-coded for this session;
# confirm it matches the session's configuration.
arms = ["West", "East", "South", "North"]
colors = ["#0570b0ff", "#74a9cfff", "#3690c0ff", "#74c476ff"]
arm_to_outcome = ["high", "low", "medium", "control"]

mean_latencies = [latencies[arm] for arm in arms]
x_coordinates = np.arange(len(arms))

fig, ax = plt.subplots()
ax.bar(x_coordinates, mean_latencies, align='center', color=colors)
ax.set_xticks(x_coordinates)
ax.set_xticklabels(arm_to_outcome)
ax.set_ylabel("Latency (s)")
ax.set_title("Latency by arm\n(Example session from Rat1)")

# Hide the top/right frame lines and keep ticks only on the left/bottom axes.
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

fig.savefig(os.path.join(output_filepath, "example_mean_latency.png"))
plt.close(fig)

In [None]:
# Load the LFP channel named by info.lfp_swr_filename; it is used below for SWR detection.
lfp = nept.load_lfp(os.path.join(data_filepath, info.lfp_swr_filename))

In [None]:
# Load tracked position for the session, passing the per-axis pxl_to_cm
# factors from info as an [x, y] list.
position_filename = "{}_VT1.nvt".format(info.date)
xy_scale = [info.pxl_to_cm[axis] for axis in ("x", "y")]
position = nept.load_position(
    os.path.join(data_filepath, position_filename),
    pxl_to_cm=xy_scale,
)

In [None]:
# Plot every tracked position sample recorded during run1.
phase = 'run1'
sliced_position = position.time_slice(info.task_times[phase].start, info.task_times[phase].stop)

fig, ax = plt.subplots()
ax.plot(sliced_position.x, sliced_position.y, ".", ms=1)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Position during " + phase)
plt.show()

In [None]:
# Plot every tracked position sample recorded during run2.
phase = 'run2'
sliced_position = position.time_slice(info.task_times[phase].start, info.task_times[phase].stop)

fig, ax = plt.subplots()
ax.plot(sliced_position.x, sliced_position.y, ".", ms=1)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Position during " + phase)
plt.show()

In [None]:
# Restrict the LFP to the rest2 epoch before detecting SWRs.
rest2 = info.task_times["rest2"]
sliced_lfp = lfp.time_slice(rest2.start, rest2.stop)

# Find SWRs
z_thresh = 2.0
# NOTE(review): power_thresh is defined but never passed to the detector
# below -- confirm whether it was meant to be.
power_thresh = 3.0
merge_thresh = 0.02
min_length = 0.05

swr_params = dict(
    fs=info.fs,
    thresh=(140.0, 250.0),  # filter band; presumably Hz -- TODO confirm
    z_thresh=z_thresh,
    merge_thresh=merge_thresh,
    min_length=min_length,
)
swrs = nept.detect_swr_hilbert(sliced_lfp, **swr_params)

In [None]:
# Plot the first few detected SWRs, each highlighted inside a short window
# of the surrounding LFP.
n_examples = min(10, len(swrs.starts))  # fewer than 10 SWRs must not IndexError
buffer = 0.1  # context on each side, in the units of lfp.time (presumably s)

for i in range(n_examples):
    swr_lfp = sliced_lfp.time_slice(swrs.starts[i], swrs.stops[i])
    if len(swr_lfp.time) == 0:
        continue  # no LFP samples fall inside this epoch; nothing to show

    window_start = swr_lfp.time[0] - buffer
    window_stop = swr_lfp.time[-1] + buffer
    # Draw only the context window; the x-limits crop to it anyway, and
    # plotting the whole rest2 trace per figure is needlessly expensive.
    context = sliced_lfp.time_slice(window_start, window_stop)

    fig, ax = plt.subplots()
    ax.plot(context.time, context.data)
    ax.plot(swr_lfp.time, swr_lfp.data)
    ax.set_xlim(window_start, window_stop)
    ax.set_xlabel("Time")
    ax.set_title("SWR {}".format(i))
    plt.show()