# MED-PC Data Processing Notebook

## Importing the Python Libraries

In [1]:
import sys
import glob
from collections import defaultdict
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
from medpc2excel.medpc_read import medpc_read

In [3]:
# Make the project-local packages in ../../src (extract, processing) importable in the next cell.
# NOTE(review): this relative path assumes the notebook runs from its own folder — confirm.
src_directory = '../../src'
sys.path.append(src_directory)

In [4]:
import extract.dataframe
import processing.tone
import extract.metadata

In [5]:
# Enlarge the default inline figure size to 10 x 6 inches.
plt.rcParams.update({"figure.figsize": (10, 6)})

## Getting the Metadata from all the files

In [6]:
all_med_pc_file = glob.glob("./data/timestamp_dataframes/*.txt")

In [7]:
all_med_pc_file[:10]

['./data/timestamp_dataframes/2022-05-06_12h59m_Subject 3.4 (2).txt',
 './data/timestamp_dataframes/2022-05-06_08h37m_Subject 2.3.txt',
 './data/timestamp_dataframes/2022-05-10_14h40m_Subject 4.3 (3).txt',
 './data/timestamp_dataframes/2022-05-06_12h59m_Subject 4.3 (3).txt',
 './data/timestamp_dataframes/2022-05-04_08h43m_Subject 2.3.txt',
 './data/timestamp_dataframes/2022-05-03_12h52m_Subject 2.1.txt',
 './data/timestamp_dataframes/2022-05-04_10h11m_Subject 1.2.txt',
 './data/timestamp_dataframes/2022-05-06_08h37m_Subject 1.1.txt',
 './data/timestamp_dataframes/2022-05-03_13h19m_Subject 1.2.txt',
 './data/timestamp_dataframes/2022-05-03_12h52m_Subject 2.4.txt']

In [8]:
# Read the header metadata of every MED-PC log file.
# NOTE(review): presumably returns {file_path: {field: value}} (File, Start Date, Subject, Group, Box, MSN, ...),
# judging by how it is turned into a DataFrame in the next cells — confirm in extract/metadata.py.
file_path_to_meta_data = extract.metadata.get_all_med_pc_meta_data_from_files(list_of_files=all_med_pc_file)

## Making a Dataframe out of the Metadata

In [9]:
metadata_df = pd.DataFrame.from_dict(file_path_to_meta_data, orient="index")
metadata_df = metadata_df.reset_index()

In [10]:
metadata_df.head()

Unnamed: 0,index,File,Start Date,End Date,Subject,Experiment,Group,Box,Start Time,End Time,MSN
0,./data/timestamp_dataframes/2022-05-06_12h59m_...,C:\MED-PC\Data\2022-05-06_12h59m_Subject 3.4 (...,05/06/22,05/06/22,3.4 (2),Pilot of Pilot,Cage 4,1,12:59:58,14:02:38,levelNP_CS_reward_laserepochON1st_noshock
1,./data/timestamp_dataframes/2022-05-06_08h37m_...,C:\MED-PC\Data\2022-05-06_08h37m_Subject 2.3.txt,05/06/22,05/06/22,2.3,Pilot of Pilot,Cage 1,1,08:37:09,09:53:25,levelNP_CS_reward_laserepochON1st_noshock
2,./data/timestamp_dataframes/2022-05-10_14h40m_...,C:\MED-PC\Data\2022-05-10_14h40m_Subject 4.3 (...,05/10/22,05/10/22,4.3 (3),Pilot of Pilot,Cage 4,2,14:40:24,15:43:18,levelNP_CS_reward_laserepochON1st_noshock
3,./data/timestamp_dataframes/2022-05-06_12h59m_...,C:\MED-PC\Data\2022-05-06_12h59m_Subject 4.3 (...,05/06/22,05/06/22,4.3 (3),Pilot of Pilot,Cage 4,2,12:59:58,14:02:38,levelNP_CS_reward_laserepochON1st_noshock
4,./data/timestamp_dataframes/2022-05-04_08h43m_...,C:\MED-PC\Data\2022-05-04_08h43m_Subject 2.3.txt,05/04/22,05/04/22,2.3,Pilot of Pilot,Cage 1,3,08:43:11,09:54:22,levelNP_CS_reward_laserepochON1st_noshock


In [11]:
# Non-null metadata entries per subject; each row is one log file, so this counts files per subject.
metadata_df.groupby("Subject").agg("count")

Unnamed: 0_level_0,index,File,Start Date,End Date,Experiment,Group,Box,Start Time,End Time,MSN
Subject,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1.1,9,9,9,9,9,9,9,9,9,9
1.2,9,9,9,9,9,9,9,9,9,9
1.3,9,9,9,9,9,9,9,9,9,9
1.4,9,9,9,9,9,9,9,9,9,9
2.1,10,10,10,10,10,10,10,10,10,10
2.2,9,9,9,9,9,9,9,9,9,9
2.3,9,9,9,9,9,9,9,9,9,9
2.4,9,9,9,9,9,9,9,9,9,9
3.1 (1),8,8,8,8,8,8,8,8,8,8
3.2 (2),8,8,8,8,8,8,8,8,8,8


In [12]:
# Distinct subject IDs, in order of first appearance.
pd.unique(metadata_df["Subject"])

array(['3.4 (2)', '2.3', '4.3 (3)', '2.1', '1.2', '1.1', '2.4', '4.1 (1)',
       '4.4 (4)', '3.3 (4)', '3.2 (2)', '1.4', '1.3', '3.1 (1)', '2.2',
       '4.2 (3)'], dtype=object)

## Inputting all the MED-PC log files

- **Please make sure that the corresponding `.mpc` file (i.e. the MED-PC script that was run to create the log file) is also in the same folder.**

In [13]:
# Parse every log file into one combined timestamp dataframe: one column per MED-PC array plus
# date_key, subject_key and file_path columns (see the output of the next cell).
# Files that fail to parse are reported ("Invalid Formatting for file: ...") and the remaining
# files are still combined, as the output below shows.
concatted_medpc_df = extract.dataframe.get_medpc_dataframe_from_list_of_files(medpc_files=all_med_pc_file)

Traceback (most recent call last):
  File "/home/riwata/Projects/med_pc_repo/results/2022_05_02_log_processing/../../src/extract/dataframe.py", line 71, in get_medpc_dataframe_from_list_of_files
    ts_df, medpc_log = medpc_read(file=file_path, override=True, replace=False)
  File "/home/riwata/Projects/med_pc_repo/bin/conda_environments/env/med_pc_env/lib/python3.9/site-packages/medpc2excel/medpc_read.py", line 114, in medpc_read
    temp += re.split('\s+',d.split(':')[1])
IndexError: list index out of range

Invalid Formatting for file: ./data/timestamp_dataframes/2022-05-03_13h19m_Subject 1.2.txt
Traceback (most recent call last):
  File "/home/riwata/Projects/med_pc_repo/results/2022_05_02_log_processing/../../src/extract/dataframe.py", line 71, in get_medpc_dataframe_from_list_of_files
    ts_df, medpc_log = medpc_read(file=file_path, override=True, replace=False)
  File "/home/riwata/Projects/med_pc_repo/bin/conda_environments/env/med_pc_env/lib/python3.9/site-packages/medpc2exce

In [14]:
# Preview of the combined dataframe; many trailing rows are NaN in most columns (see output).
concatted_medpc_df

Unnamed: 0,(P)Portentry,(Q)USdelivery,(R)UStime,(W)ITIvalues,(S)CSpresentation,(N)Portexit,(K)CStype,(B)shockintensity,date_key,subject_key,file_path
0,12.34,64.0,399.0,0.0,60.01,12.39,1.0,0.0,20220506,3.4 (2),./data/timestamp_dataframes/2022-05-06_12h59m_...
1,14.60,144.0,399.0,0.0,140.01,14.79,1.0,0.0,20220506,3.4 (2),./data/timestamp_dataframes/2022-05-06_12h59m_...
2,23.95,234.0,399.0,0.0,230.01,24.88,1.0,0.0,20220506,3.4 (2),./data/timestamp_dataframes/2022-05-06_12h59m_...
3,31.83,314.0,399.0,0.0,310.01,31.90,1.0,0.0,20220506,3.4 (2),./data/timestamp_dataframes/2022-05-06_12h59m_...
4,31.99,389.0,399.0,0.0,385.01,32.09,1.0,0.0,20220506,3.4 (2),./data/timestamp_dataframes/2022-05-06_12h59m_...
...,...,...,...,...,...,...,...,...,...,...,...
2536,,,,,,,1.0,,20220508,1.3,./data/timestamp_dataframes/2022-05-08_11h14m_...
2537,,,,,,,1.0,,20220508,1.3,./data/timestamp_dataframes/2022-05-08_11h14m_...
2538,,,,,,,1.0,,20220508,1.3,./data/timestamp_dataframes/2022-05-08_11h14m_...
2539,,,,,,,1.0,,20220508,1.3,./data/timestamp_dataframes/2022-05-08_11h14m_...


## Getting the Latency for Each Test Session

In [15]:
# Sanity check: total number of rows across all successfully parsed log files.
len(concatted_medpc_df)

In [17]:
concatted_first_porty_entry_dataframe = processing.tone.get_concatted_first_porty_entry_dataframe(concatted_medpc_df=concatted_medpc_df)

In [32]:
concatted_first_porty_entry_dataframe.reset_index(drop="True")

Unnamed: 0,current_tone_time,first_port_entry_after_tone,file_path,date_key,subject_key,latency,latency_adjusted,latency_less_than_10_seconds
0,60.01,69.00,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),8.99,8.99,1
1,140.01,148.27,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),8.26,8.26,1
2,230.01,231.91,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),1.90,1.90,1
3,310.01,320.97,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),10.96,10.96,0
4,385.01,394.75,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),9.74,9.74,1
...,...,...,...,...,...,...,...,...
4944,3160.01,3161.89,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,1.88,1.88,1
4945,3255.01,3255.09,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,0.08,0.08,1
4946,3345.01,3351.67,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,6.66,6.66,1
4947,3425.01,3431.15,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,6.14,6.14,1


In [24]:
concatted_first_porty_entry_dataframe["latency"] = concatted_first_porty_entry_dataframe["first_port_entry_after_tone"] - concatted_first_porty_entry_dataframe["current_tone_time"]

In [25]:
concatted_first_porty_entry_dataframe["latency_adjusted"] = concatted_first_porty_entry_dataframe["latency"].apply(lambda x: 30 if x >= 30 else x)

In [26]:
# Trials whose first port entry came 25 s or more after tone onset.
concatted_first_porty_entry_dataframe.loc[concatted_first_porty_entry_dataframe["latency"].ge(25)]

Unnamed: 0,current_tone_time,first_port_entry_after_tone,file_path,date_key,subject_key,latency,latency_adjusted
8,750.01,777.83,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),27.82,27.82
12,1150.01,1224.44,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),74.43,30.00
23,2130.01,2192.23,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),62.22,30.00
28,2585.01,2615.82,./data/timestamp_dataframes/2022-05-10_14h40m_...,20220510,4.3 (3),30.81,30.00
10,940.01,982.59,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,4.3 (3),42.58,30.00
...,...,...,...,...,...,...,...
23,2130.01,2158.80,./data/timestamp_dataframes/2022-05-09_09h48m_...,20220509,1.1,28.79,28.79
3,310.01,351.06,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,41.05,30.00
5,485.01,517.08,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,32.07,30.00
12,1150.01,1182.36,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,32.35,30.00


In [27]:
# 1 when the first port entry came within 10 s of tone onset (exactly 10 s counts, despite the
# column name), otherwise 0; NaN latencies compare False and become 0.
concatted_first_porty_entry_dataframe["latency_less_than_10_seconds"] = (concatted_first_porty_entry_dataframe["latency"] <= 10).astype(int)

In [30]:
# Trials whose first port entry came 30 s or more after tone onset (capped in latency_adjusted).
concatted_first_porty_entry_dataframe.loc[concatted_first_porty_entry_dataframe["latency"].ge(30)]

Unnamed: 0,current_tone_time,first_port_entry_after_tone,file_path,date_key,subject_key,latency,latency_adjusted,latency_less_than_10_seconds
12,1150.01,1224.44,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),74.43,30.0,0
23,2130.01,2192.23,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,3.4 (2),62.22,30.0,0
28,2585.01,2615.82,./data/timestamp_dataframes/2022-05-10_14h40m_...,20220510,4.3 (3),30.81,30.0,0
10,940.01,982.59,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,4.3 (3),42.58,30.0,0
32,2985.01,3023.64,./data/timestamp_dataframes/2022-05-06_12h59m_...,20220506,4.3 (3),38.63,30.0,0
...,...,...,...,...,...,...,...,...
35,3255.01,3306.80,./data/timestamp_dataframes/2022-05-07_13h54m_...,20220507,4.3 (3),51.79,30.0,0
3,310.01,351.06,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,41.05,30.0,0
5,485.01,517.08,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,32.07,30.0,0
12,1150.01,1182.36,./data/timestamp_dataframes/2022-05-08_11h14m_...,20220508,1.3,32.35,30.0,0


In [None]:
# Latency table used by every cell below (it must be defined here; nothing earlier creates it).
# One row per tone presentation; reset the index because the concatenated per-session frames
# repeat index values (see the outputs above).
combined_latency_df = concatted_first_porty_entry_dataframe.reset_index(drop=True)
combined_latency_df.to_csv("./data/latency_dataframes/all_latencies.csv")
combined_latency_df.to_excel("./data/latency_dataframes/all_latencies.xlsx")


In [None]:
metadata_df["Group"]

In [None]:
metadata_df["Group_processed"] = metadata_df["Group"].apply(lambda x: x.strip("Cage").strip())

In [None]:
subject_to_cage_dict = dict(zip(metadata_df["Subject"], metadata_df["Group_processed"]))

In [None]:
subject_to_cage_dict

In [None]:
grouped_averaged_latency_df = combined_latency_df.groupby(["subject_key", "date_key"]).mean()

In [None]:
grouped_averaged_latency_df = grouped_averaged_latency_df.reset_index()

In [None]:
grouped_averaged_latency_df["cage"] = grouped_averaged_latency_df["subject_key"].map(subject_to_cage_dict)

In [None]:
grouped_averaged_latency_df["date_int"] = grouped_averaged_latency_df["date_key"].astype(int)

In [None]:
grouped_averaged_latency_df

In [None]:
for cage in grouped_averaged_latency_df["cage"].unique():
    fig, ax = plt.subplots()

    cage_df = grouped_averaged_latency_df[grouped_averaged_latency_df["cage"] == cage]
    for subject in cage_df["subject_key"].unique():
        subject_df = cage_df[cage_df["subject_key"] == subject]
        
        
        
        ax.plot(subject_df["date_int"] - subject_df["date_int"].min() + 1, subject_df["latency_adjusted"], '-o', label=subject)
    
    ax.set_xlabel("The Days After the First Session")
    ax.set_ylabel("Adjusted Average Latency of First Entry to Tone Onset")
    ax.set_title("Latency of Port Entry to Tone: Cage {}".format(cage))

    ax.set_ylim(0, 30)
    ax.legend()
    plt.savefig("./data/plots/average_latency_plots/average_port_entry_latency_cage_{}_date_20220503_20220510.png".format(cage))
    

In [None]:
for cage in grouped_averaged_latency_df["cage"].unique():
    fig, ax = plt.subplots()

    cage_df = grouped_averaged_latency_df[grouped_averaged_latency_df["cage"] == cage]
    for subject in cage_df["subject_key"].unique():
        subject_df = cage_df[cage_df["subject_key"] == subject]
        
        
        
        ax.plot(subject_df["date_int"] - subject_df["date_int"].min() + 1, subject_df["latency_less_than_10_seconds"], '-o', label=subject)
    
    ax.set_xlabel("The Days After the First Session")
    ax.set_ylabel("Proportion of Latencies")
    ax.set_title("Less Than 10 Seconds Latencies from Tone Onset: Cage {}".format(cage))

    ax.set_ylim(0, 1)
    ax.legend()
    plt.savefig("./data/plots/proportion_of_latencies_less_than_10_seconds/less_than_10_seconds_latency_proportion_cage_{}_date_20220503_20220510.png".format(cage))
    

In [None]:
grouped_averaged_latency_df.merge(metadata_df, left_on="subject_key", right_on='Subject')

In [None]:
metadata_df["Subject"]

In [None]:
latency_pivot_plot = combined_latency_df.pivot_table(
        values='latency_adjusted', 
        index=['subject_key'], 
        columns='date_key', 
        aggfunc=np.mean)

In [None]:
latency_pivot_plot

In [None]:
latency_pivot_plot.to_csv("./data/latency_dataframes/adjusted_latency_pivot_table.csv")
latency_pivot_plot.to_excel("./data/latency_dataframes/adjusted_latency_pivot_table.xlsx")

In [None]:
latency_pivot_plot.plot(y=["20220503", "20220504", "20220505", "20220506"], kind="bar", ylabel="Latency", xlabel="Subjects", title="Adjusted Latency to Port from Tone Onset")

In [None]:
less_than_10_latency_pivot_plot = combined_latency_df.pivot_table(
        values='latency_less_than_10_seconds', 
        index=['subject_key'], 
        columns='date_key', 
        aggfunc=np.mean)

In [None]:
less_than_10_latency_pivot_plot

In [None]:
less_than_10_latency_pivot_plot.to_csv("./data/latency_dataframes/less_than_10_latency_pivot_table.csv")
less_than_10_latency_pivot_plot.to_excel("./data/latency_dataframes/less_than_10_latency_pivot_table.xlsx")

In [None]:
less_than_10_latency_pivot_plot.plot(y=["20220503", "20220504", "20220505", "20220506", "20220507", "20220508", "20220510"], kind="bar", ylabel="Ratio of Latencies less than 10 Seconds", xlabel="Subjects", title="Ratio of Latencies Less than 10 Seconds Over Time")

# Combining the Plots

In [None]:
latency_and_metadata_df = latency_pivot_plot.join(other=metadata_df.set_index("Subject"))

In [None]:
latency_and_metadata_df["Group"].unique()

In [None]:
latency_and_metadata_df["Group_processed"] = latency_and_metadata_df["Group"].apply(lambda x: x.strip("Cage").strip())

In [None]:
latency_and_metadata_df

In [None]:
for group in latency_and_metadata_df["Group_processed"].unique():
    print(latency_and_metadata_df[latency_and_metadata_df["Group_processed"] == group])
    break

In [None]:
latency_pivot_plot_resetted_index = latency_pivot_plot.reset_index()

In [None]:
metadata_df

# Getting the port entry precision

## 1. Get all the numbers that are within the duration

### 1.1 Processing the Dataframe to remove all rows with NaNs

In [None]:
example_med_pc_df = file_path_to_med_pc_data["./data/timestamp_dataframes/2022-05-03_14h49m_Subject 3.2 (2).txt"]["med_pc_df"]
example_med_pc_df = example_med_pc_df.dropna(subset=("(P)Portentry", "(N)Portexit"))

In [None]:
example_med_pc_df

### 1.2 Making All the Times Into Whole Numbers

In [None]:
def scale_time_to_whole_number(time, multiplier=100):
    """
    Function used to convert times that are floats into whole numbers by scaling it. i.e. from 71.36 to 7136
    This is used with pandas.DataFrame.apply/pandas.Series.apply to convert a column of float times to integer times.

    Args:
        time: float
            - The time in seconds that something is happening
        multiplier: int
            - Scale factor; the default of 100 turns seconds into 10 ms increments
    Returns:
        int:
            - Converted whole number time, rounded to the nearest increment
            - 0 for NaN and for values that cannot be converted (e.g. None, strings, infinity)
    """
    try:
        if np.isnan(time):
            return 0
        # Round instead of truncating: binary floats such as 0.29 are stored as 0.28999...,
        # so int(0.29 * 100) would give 28 and shift the time by one increment.
        return int(round(time * multiplier))
    except (TypeError, ValueError, OverflowError):
        # np.isnan raises TypeError for non-numeric input (None, strings);
        # converting infinity to int raises OverflowError.
        return 0

In [None]:
example_med_pc_df["port_entry_scaled"] = example_med_pc_df["(P)Portentry"].apply(lambda x: scale_time_to_whole_number(x))
example_med_pc_df["port_exit_scaled"] = example_med_pc_df["(N)Portexit"].apply(lambda x: scale_time_to_whole_number(x))
example_med_pc_df["tone_start_scaled"] = example_med_pc_df["(S)CSpresentation"].apply(lambda x: scale_time_to_whole_number(x))

In [None]:
example_med_pc_df.head(n=25)

In [None]:
def get_all_port_entry_increments(port_entry_scaled, port_exit_scaled):
    """
    Gets all the numbers that are in the duration of the port entry and port exit times.
    i.e. If the port entry was 7136 and port exit was 7142, we'd get [7136, 7137, 7138, 7139, 7140, 7141, 7142]
    This is done for all port entry and port exit times pairs between two Pandas Series

    Args:
        port_entry_scaled: Pandas Series (or any iterable of ints)
            - A column from a MED-PC Dataframe that has all the port entry times scaled
            (usually with the scale_time_to_whole_number function)
        port_exit_scaled: Pandas Series (or any iterable of ints)
            - A column from a MED-PC Dataframe that has all the port exit times scaled
            (usually with the scale_time_to_whole_number function)
    Returns:
        Numpy array:
            - 1D Numpy Array of all the numbers that are in the duration of all the port entry and port exit times
            - An empty integer array when there are no entry/exit pairs
            - A pair whose exit is before its entry contributes nothing
    """
    all_port_entry_ranges = [
        np.arange(port_entry, port_exit + 1)
        for port_entry, port_exit in zip(port_entry_scaled, port_exit_scaled)
    ]
    # np.concatenate raises on an empty list (a session without port entries).
    if not all_port_entry_ranges:
        return np.array([], dtype=int)
    return np.concatenate(all_port_entry_ranges)

In [None]:
example_port_entry_times = get_all_port_entry_increments(port_entry_scaled=example_med_pc_df["port_entry_scaled"], port_exit_scaled=example_med_pc_df["port_exit_scaled"])

In [None]:
example_port_entry_times[:10]

## 2. Make a set and see which numbers are in that set

### 2.1 Getting all the numbers from 0 to the time of the last tone plus 2000 (i.e. 20 seconds)

In [None]:
example_valid_tone_times = processing.tone.get_valid_tones(tone_pd_series=example_med_pc_df["tone_start_scaled"]).astype(int)

In [None]:
# Using the last tone and adding 2000(or 20 seconds to it)
example_experiment_interval = np.arange(example_valid_tone_times.max() + 2001)

In [None]:
example_experiment_interval

### 2.2 Getting a mask of all the numbers that are within a port entry and port exit time

In [None]:
example_port_entry_mask = np.isin(example_experiment_interval, example_port_entry_times)

In [None]:
example_port_entry_mask

In [None]:
example_experiment_interval[np.isin(example_experiment_interval, example_port_entry_times)]

### 2.3 Or just using a function to do all of this for us

In [None]:
def get_inside_port_mask(max_time, inside_port_numbers):
    """
    Gets a mask of all the times that the subject is inside the port.
    First a range of numbers from 0 up to (but not including) max_time is created, so that the
    index of every element equals the time increment it represents.
    Then, a mask is created by seeing which numbers are within the inside port duration.

    Args:
        max_time: int
            - Length of the range, i.e. one more than the largest time increment covered.
                - Usually this will be based on the number for the last tone played.
            - We recommend adding 2001 if you are just using the number for the last tone played
                - This is because we are looking 20 seconds before and after.
                - And 20 seconds becomes 2000 when scaled with our method.
        inside_port_numbers: Numpy Array
            - All the increments of the duration that the subject is within the port
    Returns:
        session_time_increments: Numpy Array
            - Range of numbers from 0 to max_time - 1 (index == time increment)
        inside_port_mask: Numpy Array
            - inside_port_mask[t] is True if the subject is in the port at time increment t
    """
    # Start at 0, not 1: callers slice this mask with absolute time increments
    # (mask[tone - 2000: tone + 2000]), which is only correct when index == time.
    session_time_increments = np.arange(max_time)
    inside_port_mask = np.isin(session_time_increments, inside_port_numbers)
    return session_time_increments, inside_port_mask

In [None]:
max_time = example_valid_tone_times.max() + 2001

In [None]:
example_experiment_interval, example_port_entry_mask = get_inside_port_mask(max_time=max_time, inside_port_numbers=example_port_entry_times)

In [None]:
example_port_entry_mask

In [None]:
example_experiment_interval[example_port_entry_mask]

## 3. Find the Overlap between the Tone Times and the Port Entries

### 3.1 Calculating the probability that the subject is in the port for each time increment relative to tone onset, averaged across all tones

In [None]:
tone_time_to_mask = defaultdict(dict)
example_all_tone_time_masks = []
for index, tone_start in example_valid_tone_times.iteritems():
    tone_start_int = int(tone_start)
#     print(tone_start_int)
#     print(example_port_entry_mask[tone_start_int - 2000: tone_start_int + 2000])  
    example_all_tone_time_masks.append(example_port_entry_mask[tone_start_int - 2000: tone_start_int + 2000])
    tone_time_to_mask[tone_start_int] = example_port_entry_mask[tone_start_int - 2000: tone_start_int + 2000]
np.stack(example_all_tone_time_masks)

In [None]:
tone_time_to_mask

In [None]:
stacked_example_all_tone_time_masks = np.stack(example_all_tone_time_masks)

In [None]:
mean_example_all_tone_time_masks = stacked_example_all_tone_time_masks.mean(axis=0)

In [None]:
mean_example_all_tone_time_masks

### 3.2 Doing it with a function

In [None]:
def get_inside_port_probability_averages_for_all_increments(tone_times, inside_port_mask, before_tone_duration=2000, after_tone_duration=2000):
    """
    Calculates the average probability that a subject is in the port around tone onset.
    For every tone, a window from `before_tone_duration` increments before to `after_tone_duration`
    increments after the tone is cut out of the mask; the windows are then averaged element-wise,
    i.e. each output element is the fraction of tones for which the subject was in the port at that
    offset from tone onset (e.g. 10.01 seconds after the tone).

    Args:
        tone_times: list or Pandas Series
            - The scaled times (increments) at which the tone has played
        inside_port_mask: Numpy Array
            - The mask where the subject is in the port based on the index being the time increment
        before_tone_duration: int
            - The number of increments before the tone to be analyzed
        after_tone_duration: int
            - The number of increments after the tone to be analyzed
    Returns:
        Numpy Array
            - Length before_tone_duration + after_tone_duration; the averages of the probabilities
              that the subject is inside the port for all increments.
            - Increments of a window that fall outside the mask (before time 0 or after its end) are
              ignored for that tone; an increment no tone covers is NaN (numpy warns in that case).
    Raises:
        ValueError: if tone_times is empty.
    """
    inside_port_mask = np.asarray(inside_port_mask)
    window_length = before_tone_duration + after_tone_duration
    windows = []
    for tone_start in tone_times:
        tone_start_int = int(tone_start)
        window_start = tone_start_int - before_tone_duration
        window_end = tone_start_int + after_tone_duration
        # Plain slicing would wrap around for a negative start and silently shorten at the end;
        # pad out-of-range increments with NaN so every window keeps its alignment to tone onset.
        window = np.full(window_length, np.nan)
        source_start = max(window_start, 0)
        source_end = min(window_end, len(inside_port_mask))
        if source_end > source_start:
            window[source_start - window_start: source_end - window_start] = inside_port_mask[source_start:source_end]
        windows.append(window)
    if not windows:
        raise ValueError("tone_times is empty; there are no tone windows to average")
    return np.nanmean(np.stack(windows), axis=0)

In [None]:
# Same computation as the manual loop above, via the helper with its default 2000-increment windows.
get_inside_port_probability_averages_for_all_increments(
    inside_port_mask=example_port_entry_mask,
    tone_times=example_valid_tone_times,
)

### 3.3 Plotting all the probabilities

In [None]:
# Offset (s) of each of the 4000 increments from tone onset: -20.00, -19.99, ..., 19.99.
# np.linspace(-20, 20, 4000) would space the points 40/3999 s apart, so the x-values would not
# line up with the 10 ms increments (e.g. no point would sit exactly at 0).
seconds_from_tone_onset = np.arange(-2000, 2000) / 100
plt.plot(seconds_from_tone_onset, mean_example_all_tone_time_masks)
plt.xlabel("Seconds from the start of the tone")
plt.ylabel("Probability Inside Port")
plt.title("Probability Inside Port for 10ms Increments 20 Seconds Before and After Tone")

# 4. Plotting for Multiple Training Sessions

### 4.1 Combining Dataframes

In [None]:
file_path_to_med_pc_data_for_probability = defaultdict(dict)

for key, value in file_path_to_med_pc_data.items():
    
    valid_tones = processing.tone.get_valid_tones(tone_pd_series= value["med_pc_df"]["(S)CSpresentation"])
    if not valid_tones.empty:
        file_path_to_med_pc_data_for_probability[key]["med_pc_df"] = value["med_pc_df"]

        file_path_to_med_pc_data_for_probability[key]["med_pc_df"]["date_key"] = value["date_key"]
        file_path_to_med_pc_data_for_probability[key]["med_pc_df"]["subject_key"] = value["subject_key"]
        file_path_to_med_pc_data_for_probability[key]["med_pc_df"]["file_path"] = key 
    else:
        print("Skipped {}".format(key))

In [None]:
all_med_pc_df = []
for key, value in file_path_to_med_pc_data_for_probability.items():
    all_med_pc_df.append(value["med_pc_df"])

In [None]:
combined_med_pc_df = pd.concat(all_med_pc_df)

In [None]:
combined_med_pc_df

### 4.2 Get the port probability for one mouse

In [None]:
combined_med_pc_df["subject_key"].unique()

In [None]:
example_one_subject_all_days = combined_med_pc_df[combined_med_pc_df["subject_key"] == "4.1 (1)"]

In [None]:
example_one_subject_all_days

In [None]:
def _session_inside_port_probability(session_df):
    """
    Average inside-port probability around tone onset for ONE MED-PC session (one log file).

    Args:
        session_df: pandas DataFrame
            - Rows of a single log file with "(P)Portentry", "(N)Portexit" and "(S)CSpresentation" columns (seconds)
    Returns:
        Numpy Array:
            - Probability of being inside the port for each increment from 2000 increments before
              to 2000 increments after tone onset, averaged over the session's valid tones
    """
    session_df = session_df.copy()
    ### Scaling the timestamps into whole-number increments
    session_df["port_entry_scaled"] = session_df["(P)Portentry"].apply(scale_time_to_whole_number)
    session_df["port_exit_scaled"] = session_df["(N)Portexit"].apply(scale_time_to_whole_number)
    session_df["tone_start_scaled"] = session_df["(S)CSpresentation"].apply(scale_time_to_whole_number)
    ### All the increments spent inside the port
    entry_times = get_all_port_entry_increments(port_entry_scaled=session_df["port_entry_scaled"], port_exit_scaled=session_df["port_exit_scaled"])
    ### Valid tone times, and a mask long enough for the last tone's post-tone window
    valid_tone_times = processing.tone.get_valid_tones(tone_pd_series=session_df["tone_start_scaled"]).astype(int)
    max_time = valid_tone_times.max() + 2001
    _, entry_mask = get_inside_port_mask(max_time=max_time, inside_port_numbers=entry_times)
    ### Average over all tones of this session
    return get_inside_port_probability_averages_for_all_increments(tone_times=valid_tone_times, inside_port_mask=entry_mask)


subject_to_date_to_average_probability = defaultdict(dict)
for subject in combined_med_pc_df["subject_key"].unique():
    one_subject_all_days = combined_med_pc_df[combined_med_pc_df["subject_key"] == subject]

    for date in one_subject_all_days["date_key"].unique():
        one_day_df = one_subject_all_days[one_subject_all_days["date_key"] == date]
        # Process each log file on its own: timestamps are relative to the start of their session,
        # so pooling two sessions recorded on the same date would overlay their port entries on a
        # single timeline. Sessions sharing a date are then averaged with equal weight.
        session_curves = [
            _session_inside_port_probability(session_df)
            for _, session_df in one_day_df.groupby("file_path", sort=False)
        ]
        subject_to_date_to_average_probability[subject][date] = np.mean(session_curves, axis=0)

In [None]:
subject_to_date_to_average_probability["3.4 (2)"]

In [None]:
subject_to_date_to_average_probability["3.3 (4)"]

In [None]:
"b", "g", "r", "c", "m", "y", "k"

In [None]:
subject_to_cage_dict = dict(zip(metadata_df["Subject"], metadata_df["Group"]))


In [None]:
subject_to_cage_dict

In [None]:
combined_med_pc_df

In [None]:

all_colors = ["k", "b", "m", "c", "g", "y", "darkorange", "r"]
# Offset (s) of each of the 4000 increments from tone onset: -20.00 ... 19.99 in 0.01 s steps
# (np.linspace(-20, 20, 4000) would misalign the points with the increments).
seconds_from_tone_onset = np.arange(-2000, 2000) / 100

for subject_id in combined_med_pc_df["subject_key"].unique():
    fig, ax = plt.subplots()
    # Accept both "Cage 4" and "4" style values in subject_to_cage_dict.
    cage = str(subject_to_cage_dict[subject_id]).replace("Cage", "").strip()
    ax.set_xlabel("Seconds from the start of the tone")
    ax.set_ylabel("Probability Inside Port")
    ax.set_title("Probability Inside Port Before/After Tone for Subject: {} in Cage {}".format(subject_id, cage))

    # Dates are plotted in sorted order, so the legend is already sorted by date.
    for color_index, date_key in enumerate(sorted(subject_to_date_to_average_probability[subject_id])):
        # Cycle through the colors so subjects with more sessions than colors don't raise IndexError.
        color = all_colors[color_index % len(all_colors)]
        ax.plot(seconds_from_tone_onset, subject_to_date_to_average_probability[subject_id][date_key], label=date_key, color=color)
    ax.set_ylim(0, 1)
    ax.legend()
    fig.savefig("./data/plots/probability_inside_port/probability_inside_port_cage_{}_subject_{}_date_20220503_20220510.png".format(cage, subject_id))

# TODO: 5. Licking Specificity Average

In [None]:
# For each subject and date, average the inside-port probability curve over two index ranges:
# [0, 2000) stored as "-20_to_0" and [2000, 3000) stored as "0_to_10".
subject_to_date_to_licking_specificty = defaultdict(lambda: defaultdict(dict))
for subject, all_dates in subject_to_date_to_average_probability.items():
    for date, averages in all_dates.items():
        window_means = subject_to_date_to_licking_specificty[subject][date]
        window_means["-20_to_0"] = averages[:2000].mean()
        window_means["0_to_10"] = averages[2000:3000].mean()

In [None]:
licking_specifity_df = pd.DataFrame.from_dict({(i,j): subject_to_date_to_licking_specificty[i][j] 
                           for i in subject_to_date_to_licking_specificty.keys() 
                           for j in subject_to_date_to_licking_specificty[i].keys()},
                       orient='index')


In [None]:
licking_specifity_df = licking_specifity_df.reset_index()
licking_specifity_df = licking_specifity_df.rename(columns={"level_0": "subject", "level_1": "date"})

In [None]:
licking_specifity_df["date_int"] = licking_specifity_df["date"].astype(int)
licking_specifity_df["cage"] = licking_specifity_df["subject"].map(subject_to_cage_dict)

In [None]:
licking_specifity_df

In [None]:
licking_specifity_df = licking_specifity_df.sort_values(by=["subject", "date_int"])

In [None]:
all_colors = ["b", "g", "y", "r"]

for cage in licking_specifity_df["cage"].unique():
    fig, ax = plt.subplots()
    # Accept both "Cage 4" and "4" style values so titles/file names never read "Cage Cage 4".
    cage_label = str(cage).replace("Cage", "").strip()

    cage_df = licking_specifity_df[licking_specifity_df["cage"] == cage]
    for color_index, subject in enumerate(cage_df["subject"].unique()):
        subject_df = cage_df[cage_df["subject"] == subject]
        # Cycle through the colors so cages with more subjects than colors don't raise IndexError.
        color = all_colors[color_index % len(all_colors)]
        # Days since the subject's first session, computed on real dates so month boundaries work
        # (raw YYYYMMDD integers give 20220601 - 20220531 == 70).
        session_dates = pd.to_datetime(subject_df["date_int"].astype(str), format="%Y%m%d")
        days_since_first_session = (session_dates - session_dates.min()).dt.days + 1

        ax.plot(days_since_first_session, subject_df["-20_to_0"], '--', color=color, label="{} at -20s to 0s".format(subject))
        ax.plot(days_since_first_session, subject_df["0_to_10"], '-', color=color, label="{} at 0s to 10s".format(subject))

    ax.set_xlabel("The Days After the First Session")
    ax.set_ylabel("Average Licking Specificity Probability")
    ax.set_title("Licking Specificity Probability Before and After Tone Onset: Cage {}".format(cage_label))

    ax.set_ylim(0, 1)
    ax.legend()
    fig.savefig("./data/plots/licking_specifity/licking_specifity_cage_{}_date_20220503_20220510.png".format(cage_label))
    

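Licking specificity $= \overline{P}_{\text{in port}}(0 \le t < 10\,\mathrm{s}) - \overline{P}_{\text{in port}}(-20\,\mathrm{s} \le t < 0)$, where $t$ is the time relative to tone onset (the difference of the `0_to_10` and `-20_to_0` columns above).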