# In [247]:
import pandas as pd
import numpy as np

# Input file locations: the longitudinally registered traces CSV and the
# per-session stage timestamps CSV.
lr_traces, timestamps = (
    '/Users/rufusmitchell-heggs/Desktop/H0466_ALL_corrected.csv',
    '/Users/rufusmitchell-heggs/Desktop/timestamps.csv',
)

def lr_data_correction(traces_or_events, timestamps):
    """Labels and corrects timings for longitudinally registered csv file of multiple sessions/stages
    
    INPUT
    -----
    traces_or_events = .csv file location for longitudinally registered events or traces
    timestamps = .csv file for timestamps of manually identified stage start and endings
    
    timestamps table format:
    ---------------------------
    |session|pre  |sam  | cho |
    ---------------------------
    | N01   |12701|21496|30611|
    ---------------------------
    
    OUPUT:
    -----
    corrected_data = A datafame containing labelled sessions and stages with corrected timings
    """
    #Read in lr_trace file location and make minor corrections
    lr_traces = pd.read_csv(lr_traces)
    lr_traces = lr_traces.drop(lr_traces.index[0])
    lr_traces = lr_traces.reset_index(drop=True)
    lr_traces = lr_traces.rename(columns={" ": "Time"})

    #Read in timestamp info
    timestamps = pd.read_csv(timestamps)
    sessions = list(timestamps['session'])
    
    #Identify start and end frames for all sessions
    all_data = list(traces_or_events['Time'].astype(float))
    session_starts = [0]
    session_ends = []
    for i in range(len(all_data)):
        if i + 1 < len(all_data):
            if abs(all_data[i+1] - all_data[i]) > 1 :
                session_starts.append(i+1)
                session_ends.append(i)
    session_ends.append(len(traces_or_events))
        
    # Save each session and each stage as a list 
    indiv_sessions = []
    for sesh, start, end in zip(sessions, session_starts, session_ends):
        
        #Isolate individual sessions
        indiv_session = traces_or_events[start:end]
        indiv_session = indiv_session.reset_index(drop=True)

        #isolate indiviudal stages within a session
        pre, sam, cho = np.array(timestamps[timestamps['session'].str.contains(sesh)])[0][1:].astype(int)
        stages = list(('PRE',) * pre) + list(('SAM',) * (sam-pre)) + list(('CHO',) * (cho-sam))    
        
        #Correct timings and add column showing stage
        stage_timings = [np.arange(0, pre, 1)*0.05006,np.arange(0, (sam-pre), 1)*0.05006,np.arange(0, (cho-sam), 1)*0.05006]
        stage_timings = [item for sublist in stage_timings for item in sublist]
        indiv_session.insert(loc=0, column='stage', value=stages)
        indiv_session['Time'] = stage_timings
        
        # Insert column showing session
        indiv_session.insert(loc=0, column='Session', value=list((sesh,) * len(indiv_session)))
        
        indiv_sessions.append(indiv_session)
    
    #Concatenate all sessions into single table
    corrected_data = pd.concat(indiv_sessions)
    return corrected_data

# Run the correction on the configured input files.
lr_data_traces = lr_data_correction(traces_or_events=lr_traces,
                                    timestamps=timestamps)


# (notebook traceback fragment) interactivity=interactivity, compiler=compiler, result=result)
