In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib
import math
import numpy as np
import scipy
from collections import Counter
import os
import nept

In [None]:
thisdir = os.getcwd()
# Output directory for saved figures: ./plots/experience under the working directory.
output_filepath = os.path.join(thisdir, "plots", "experience")
# exist_ok=True creates the directory if needed without a separate
# os.path.exists check (which races with concurrent creation).
os.makedirs(output_filepath, exist_ok=True)

In [None]:
import info.r191_exp01 as r191d1
import info.r191_exp02 as r191d2
import info.r191_exp03 as r191d3
import info.r191_exp04 as r191d4
import info.r191_exp05 as r191d5

import info.r192_exp01 as r192d1
import info.r192_exp02 as r192d2
import info.r192_exp03 as r192d3
import info.r192_exp04 as r192d4
import info.r192_exp05 as r192d5

# Session info modules to analyse, in order: rat r191 days 1-5, then rat r192 days 1-5.
infos = [
    r191d1, r191d2, r191d3, r191d4, r191d5,
    r192d1, r192d2, r192d3, r192d4, r192d5,
]

In [None]:
def get_exp_trials(info, events):
    """Build the trial epochs for one session.

    The start times of the run1 and run2 phases are added to the recorded
    ``trial_start`` events. If that leaves more starts than recorded
    ``trial_end`` events, the run2 stop time is appended as an extra stop.
    Starts and stops are both sorted so that they pair up in time order.

    Parameters
    ----------
    info : module-like
        Session info; must provide ``task_times["run1"]`` and
        ``task_times["run2"]`` with ``.start`` / ``.stop`` attributes.
    events : dict
        Event timestamps keyed by label; must contain ``"trial_start"`` and
        ``"trial_end"``.

    Returns
    -------
    nept.Epoch
        Epoch built from the sorted starts and sorted stops.
    """
    run1 = info.task_times["run1"]
    run2 = info.task_times["run2"]

    starts = np.sort(np.append(events["trial_start"], [run1.start, run2.start]))

    stops = np.asarray(events["trial_end"])
    if len(starts) > len(stops):
        # Presumably the last trial of run2 has no recorded trial_end; close it
        # at the end of the run2 phase. NOTE(review): confirm with the task design.
        stops = np.append(stops, [run2.stop])
    # Sort in every case (not only when run2.stop was appended) so that
    # starts and stops are paired consistently by time.
    stops = np.sort(stops)

    return nept.Epoch(starts, stops)

In [None]:
# Raw recordings live under <DATA_ROOT>/<rat_id>/RR1/<rat_id>_<date>_recording.
# Set the EXPERIENCE_DATA_DIR environment variable to read them from elsewhere;
# the default is the original hard-coded location on drive E:.
DATA_ROOT = os.environ.get("EXPERIENCE_DATA_DIR", os.path.join("E:\\", "data", "data-experience"))

trial_duration = []  # per-session mean trial duration (s)
phase_duration = []  # run1 and run2 durations (s) of every session
probe_choices = []   # outcome label (via info.arm_to_outcome) of every probe choice
probe_arms = []      # arm name of every probe choice
for info in infos:
    data_filepath = os.path.join(DATA_ROOT, info.rat_id, "RR1", info.rat_id + "_" + info.date + "_recording")

    event_filename = info.date + "_Events.nev"
    events = nept.load_events(os.path.join(data_filepath, event_filename), info.event_labels)

    trials = get_exp_trials(info, events)
    trial_duration.append(np.mean(trials.durations))
    phase_duration.extend([info.task_times['run1'].durations[0],
                           info.task_times['run2'].durations[0]])

    probe_choices.extend(info.arm_to_outcome[choice] for choice in info.probe_choice)
    probe_arms.extend(info.probe_choice)

# This is the mean of the per-session means: every session is weighted equally,
# regardless of how many trials it contains.
print("Mean trial duration:", np.mean(trial_duration), "s")
print("Mean phase duration:", np.mean(phase_duration) / 60, "min")

In [None]:
probe_counts = Counter(probe_choices)
print(probe_counts)

# Each outcome is expressed as a fraction of its pair: High vs Low and
# Medium vs Control. A pair that was never chosen gets NaN instead of
# raising ZeroDivisionError.
outcome_pairs = [("High", "Low"), ("Medium", "Control")]
probe_proportions = dict()
for pair in outcome_pairs:
    pair_total = sum(probe_counts[outcome] for outcome in pair)
    for outcome in pair:
        probe_proportions[outcome] = probe_counts[outcome] / pair_total if pair_total else float("nan")
print(probe_proportions)

# n in the title = number of distinct rats across the loaded sessions.
n_rats = len({session.rat_id for session in infos})

outcomes = ["High", "Low", "Medium", "Control"]
colors = ["#0570b0ff", "#74a9cfff", "#3690c0ff", "#74c476ff"]
frequencies = [probe_proportions[outcome] for outcome in outcomes]
x_coordinates = np.arange(len(outcomes))

fig, ax = plt.subplots()
ax.bar(x_coordinates, frequencies, align='center', color=colors)
ax.set_xticks(x_coordinates)
ax.set_xticklabels(outcomes)
ax.set_ylabel("Choice proportion")
ax.set_title("Probe choice by outcome (n={})".format(n_rats))

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

# Save before plt.show(): after showing, the figure may be closed (e.g. by the
# inline backend), so a later plt.savefig could write a blank image.
save_figure = False  # set True to write the PNG into output_filepath
if save_figure:
    fig.savefig(os.path.join(output_filepath, "probe_outcome_proportions.png"))
plt.show()

In [None]:
probe_arm_counts = Counter(probe_arms)
print(probe_arm_counts)

# Each arm is expressed as a fraction of its pair: north vs south and
# east vs west. A pair that was never chosen gets NaN instead of raising
# ZeroDivisionError.
arm_pairs = [("north", "south"), ("east", "west")]
probe_arm_proportions = dict()
for pair in arm_pairs:
    pair_total = sum(probe_arm_counts[arm] for arm in pair)
    for arm in pair:
        probe_arm_proportions[arm] = probe_arm_counts[arm] / pair_total if pair_total else float("nan")
print(probe_arm_proportions)

# n in the title = number of distinct rats across the loaded sessions.
n_rats = len({session.rat_id for session in infos})

arms = ["north", "south", "east", "west"]
colors = ["#fc4e2a", "#fc4e2a", "#fc8d59", "#fc8d59"]  # one colour per pair
frequencies = [probe_arm_proportions[arm] for arm in arms]
x_coordinates = np.arange(len(arms))

fig, ax = plt.subplots()
ax.bar(x_coordinates, frequencies, align='center', color=colors)
ax.set_xticks(x_coordinates)
ax.set_xticklabels(arms)
ax.set_ylabel("Choice proportion")
ax.set_title("Probe choice by arm (n={})".format(n_rats))

ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')

# Save before plt.show(): after showing, the figure may be closed (e.g. by the
# inline backend), so a later plt.savefig could write a blank image.
save_figure = False  # set True to write the PNG into output_filepath
if save_figure:
    fig.savefig(os.path.join(output_filepath, "probe_arm_proportions.png"))
plt.show()