In [1]:
from contextlib import contextmanager
from matplotlib.patches import Ellipse, Circle
from sklearn.linear_model import LinearRegression
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
from scipy.signal import savgol_filter
import neo
import os
import astropy.convolution.convolve

In [2]:
def cart2pol(x, y):
    """Convert Cartesian coordinates to polar coordinates.

    Parameters
    ----------
    x, y : scalar or array-like supporting ``**`` and numpy ufuncs
        Cartesian coordinates.

    Returns
    -------
    tuple
        ``(rho, phi)`` where ``rho`` is the Euclidean distance from the
        origin and ``phi`` is ``np.arctan2(y, x)`` in radians.
    """
    return np.sqrt(x**2 + y**2), np.arctan2(y, x)

def pol2cart(rho, phi):
    """Convert polar coordinates to Cartesian coordinates.

    Parameters
    ----------
    rho : scalar or array-like
        Radial distance.
    phi : scalar or array-like
        Angle in radians (passed directly to ``np.cos`` / ``np.sin``).

    Returns
    -------
    tuple
        ``(x, y)`` with ``x = rho * cos(phi)`` and ``y = rho * sin(phi)``.
    """
    return rho * np.cos(phi), rho * np.sin(phi)

def get_relative_ang(px, py, heading_angle,target_x,target_y):
    """Angle between the heading and the direction to a target.

    Parameters
    ----------
    px, py : scalar or array-like
        Current position(s).
    heading_angle : scalar or array-like
        Heading in radians; must broadcast against the per-position bearing.
    target_x, target_y : scalar or array-like
        Target position(s); broadcast against the positions.

    Returns
    -------
    numpy.ndarray
        ``heading_angle - arctan2(dy, dx)`` wrapped into ``[-pi, pi)``,
        where ``(dx, dy)`` points from the position to the target.
    """
    # Row 0 holds dx, row 1 holds dy; reshape lets scalars and arrays mix.
    difference_vector = np.array([target_x,target_y]).reshape([2,-1]) - np.array([px,py]).reshape([2,-1])
    bearing = np.arctan2(difference_vector[1], difference_vector[0])  # arctan2 takes (y, x)

    # Wrap to [0, 2*pi) first, then fold the upper half onto [-pi, 0).
    # Vectorised, so it works for any broadcast shape without a Python loop.
    relative_ang = np.remainder(heading_angle - bearing, 2*np.pi)
    return np.where(relative_ang < np.pi, relative_ang, relative_ang - 2*np.pi)

def plt_config(title=None, xlim=None, ylim=None, xlabel=None, ylabel=None, colorbar=False, sci=False):
    """Apply common settings to the current matplotlib figure via ``plt``.

    Parameters
    ----------
    title, xlim, ylim, xlabel, ylabel : optional
        When not ``None``, forwarded to the ``plt`` function of the same name.
    colorbar : bool or str
        A string adds a colorbar with that label; any other truthy value
        adds a colorbar with a default label.
    sci : str or False
        When a string, it is used as the ``axis`` argument of
        ``plt.ticklabel_format`` with scientific notation.
    """
    # Explicit mapping instead of eval(); `is not None` also accepts
    # array-valued limits, for which `!= None` is ambiguous.
    settings = {'title': title, 'xlim': xlim, 'ylim': ylim, 'xlabel': xlabel, 'ylabel': ylabel}
    for field, value in settings.items():
        if value is not None:
            getattr(plt, field)(value)

    if isinstance(sci, str):
        plt.ticklabel_format(style='sci', axis=sci, scilimits=(0,0))

    if isinstance(colorbar, str):
        plt.colorbar(label=colorbar)
    elif colorbar:
        # Raw string: same characters as before, without the invalid-escape warning.
        plt.colorbar(label=r'$Number\ of\ Entries$')

@contextmanager
def initiate_plot(dimx=24, dimy=9, dpi=100):
    """Context manager that opens a figure and shows it on exit.

    On entry it sets the global ``figure.figsize`` rcParam to
    ``(dimx, dimy)`` and stores a new ``plt.figure(dpi=dpi)`` in the
    module-level name ``fig``. On normal exit it calls ``plt.show()``.
    There is no ``try``/``finally``, so ``plt.show()`` is skipped if the
    body raises.
    """
    global fig
    figure_size = (dimx, dimy)
    plt.rcParams['figure.figsize'] = figure_size
    fig = plt.figure(dpi=dpi)
    yield
    plt.show()

# For new monkey data

## Marker
Event codes found in the `marker` channel:

- 4 is juice,
- 1 is file start,
- 2 is trial start,
- 3 is trial end,
- 8 is perturbation start.

In [3]:
class smr_extractor(object):
    """Load analog channels and marker events from Spike2 files.

    Files are taken from ``<folder_path>/<task>``; every directory entry
    whose name contains ``'smr'`` is used, in sorted order.
    """

    def __init__(self, folder_path, task='gain'):
        self.folder_path = os.path.join(folder_path, task)
        matching = []
        for entry in os.listdir(self.folder_path):
            if 'smr' in entry:
                matching.append(os.path.join(self.folder_path, entry))
        self.full_path_file_names = sorted(matching)

    def extract_data(self):
        """Read every file and return its signals and markers.

        Returns
        -------
        tuple
            ``(frames, markers, sampling_rate)``: one DataFrame per file
            (one column per analog channel plus a final ``'Time'`` column),
            one ``{'labels', 'values'}`` dict per file taken from the event
            channel named ``'marker'``, and the sampling rate of the first
            analog signal of the first file.
        """
        frames = []
        markers = []

        for file_number, path in enumerate(self.full_path_file_names):
            segment = neo.io.Spike2IO(filename=path).read_segment()
            signals = segment.analogsignals

            # The sampling rate is only read from the first file.
            if file_number == 0:
                smr_sampling_rate = signals[0].sampling_rate

            # Channels may differ in length; truncate all to the shortest.
            n_samples = min([signal.size for signal in signals])

            # One column per analog channel plus a trailing time column.
            table = np.ones((n_samples, segment.size['analogsignals'] + 1))
            column_names = []
            for column, signal in enumerate(signals):
                table[:, column] = signal.as_array()[:n_samples].T
                column_names.append(signal.annotations['channel_names'][0])

            table[:, -1] = signals[0].times[:n_samples]
            column_names.append('Time')

            frames.append(pd.DataFrame(table, columns=column_names))

            # Locate the event channel called 'marker' (first match).
            marker_position = [position for position, event in enumerate(segment.events)
                               if event.name == 'marker'][0]
            marker_event = segment.events[marker_position]
            marker_labels = marker_event.get_labels().astype('int')
            marker_values = marker_event.as_array()
            markers.append({'labels': marker_labels, 'values': marker_values})

        return frames, markers, smr_sampling_rate

In [4]:
class log_extractor(object):
    """Parse per-trial settings from the text logs that accompany smr files.

    Parameters
    ----------
    files : list of str
        smr file paths; each log path is obtained by replacing every
        ``'smr'`` substring with ``'log'``.
    marker_list : list of dict
        One dict per file with a ``'labels'`` array. The number of label-3
        (trial end) markers caps how many trials are read from each log.
    """

    def __init__(self,files,marker_list):

        self.full_path_file_names = [file.replace('smr','log') for file in files]

        self.num_of_trials = [(value['labels']==3).sum() for value in marker_list] # usually need to remove the last trial

        self.gain_V = []
        self.gain_W = []
        self.isFullOn = []
        self.isRewarded = []
        self.FFX = []
        self.FFY = []

    def extract_data(self):
        """Read every log file and append its per-trial values to the instance.

        Returns
        -------
        tuple
            ``(gain_V, gain_W, isFullOn, isRewarded, FFX, FFY)``. ``FFY`` is
            stored with its sign flipped relative to the log. ``isRewarded``
            is 0 when the logged reward duration is 0, otherwise 1.
        """
        fields = ('gain_V', 'gain_W', 'isFullOn', 'isRewarded', 'FFX', 'FFY')

        for index, file_name in enumerate(self.full_path_file_names):
            max_trials = self.num_of_trials[index]
            trial_count = 0
            # Collect in plain lists and concatenate once per file, instead of
            # re-allocating every array with np.hstack for each parsed line.
            parsed = {field: [] for field in fields}

            with open(file_name,'r',encoding='UTF-8') as content:
                log_content = content.readlines()

            for line in log_content:
                # The most recent gain settings apply to the trials that follow.
                if 'Joy Stick Max Velocity' in line:
                    gain_V = float(line.split(': ')[1])
                if 'Joy Stick Max Angular Velocity' in line:
                    gain_W = float(line.split(': ')[1])

                # Each 'Firefly Full On' line counts as the start of a new trial.
                if 'Firefly Full On' in line:
                    trial_count += 1
                    if trial_count > max_trials:
                        continue

                    parsed['isFullOn'].append(float(line.split(': ')[1]))
                    parsed['gain_V'].append(gain_V)
                    parsed['gain_W'].append(gain_W)

                if 'Reward Duration' in line:
                    if trial_count > max_trials:
                        continue

                    reward_duration = float(line.split(': ')[1])
                    parsed['isRewarded'].append(0 if reward_duration == 0 else 1)

                if 'Position x/y(cm)' in line:
                    if trial_count > max_trials:
                        continue

                    coordinates = line.split(': ')[1].split(' ')
                    parsed['FFX'].append(float(coordinates[0]))
                    parsed['FFY'].append(-float(coordinates[1]))

            # Attributes stay untouched (plain lists) when a file adds nothing,
            # which matches what line-by-line stacking produced.
            for field in fields:
                if parsed[field]:
                    setattr(self, field, np.hstack([getattr(self, field), parsed[field]]))

        return self.gain_V, self.gain_W, self.isFullOn, self.isRewarded, self.FFX, self.FFY

In [None]:
# process smr
def segment_trials(Channel_signal_output,marker_list,isFullOn,isRewarded,FFX,FFY,gain_V,gain_W):
    """Cut continuous session recordings into one row per usable trial.

    Parameters
    ----------
    Channel_signal_output : list of pandas.DataFrame
        One frame per session with at least the columns ``AngularV``,
        ``ForwardV``, ``MonkeyX``, ``MonkeyY``, ``MonkeyYa`` and ``Time``.
    marker_list : list of dict
        One ``{'labels', 'values'}`` dict per session; label 2 marks trial
        starts and label 3 marks trial ends.
    isFullOn, isRewarded, FFX, FFY, gain_V, gain_W : indexable
        Per-trial values, indexed by a trial counter that runs across all
        sessions and also advances for skipped trials.

    Returns
    -------
    pandas.DataFrame
        One row per kept trial, all columns of object dtype except ``ep``
        (the row number). Time-series columns hold one array per trial.

    Notes
    -----
    A trial is skipped when any of the following holds:

    * it has fewer than 301 samples (the smoothing window length);
    * the smoothed forward velocity never exceeds 10;
    * its ``Time`` span exceeds 3;
    * the detected stop index is not after the start index.

    Units are not visible in this function.
    NOTE(review): the thresholds look like cm/s and seconds - confirm.
    """
    COLUMNS = ['ep','gain_V','gain_W','px', 'py', 'p_heading', 'real_v','real_w','time',
              'FFX','FFY','real_relative_radius','real_relative_angle',
               'isFullOn','isRewarded','action_v','action_w']
    kept_trials = []  # one list of len(COLUMNS) values per kept trial
    trial_counter = 0

    for session_index, session_data in enumerate(Channel_signal_output):
        marker = marker_list[session_index]
        trial_start_times = marker['values'][marker['labels']==2]
        trial_stop_times = marker['values'][marker['labels']==3]

        # remove head (before the first trial start) and tail (after the last trial end)
        session_data = session_data[session_data.Time >= trial_start_times[0]]
        session_data = session_data[session_data.Time <= trial_stop_times[-1]]

        # segment trials: how many start / stop markers each sample has passed
        trial_index_start = np.digitize(session_data.Time, trial_start_times)
        trial_index_stop = np.digitize(session_data.Time, trial_stop_times)

        for trial_number in np.unique(trial_index_start):
            # Samples after this trial's start marker but not yet past its stop
            # marker. Pure boolean logic replaces np.subtract(..., dtype=np.float),
            # which fails on NumPy >= 1.24 where the np.float alias is gone.
            after_start = trial_index_start == trial_number
            trial_index_temp = after_start & ~(trial_index_stop == trial_number)
            session_data_temp = session_data[trial_index_temp]

            # skip trials too short for the 301-sample smoothing window
            if session_data_temp['AngularV'].values.size < 301:
                trial_counter += 1
                continue

            # Angular velocity is negated before smoothing.
            monkey_real_w_temp = savgol_filter(-session_data_temp['AngularV'].values,301,1)
            monkey_real_v_temp = savgol_filter(session_data_temp['ForwardV'].values,301,1)
            trial_time = session_data_temp['Time'].values

            # skip trials without movement, and trials whose Time span exceeds 3
            if not np.any(monkey_real_v_temp > 10) or (trial_time[-1] - trial_time[0]) > 3:
                trial_counter += 1
                continue

            # Y position and yaw are negated before smoothing; yaw is offset by 90.
            monkey_px_temp = savgol_filter(session_data_temp['MonkeyX'].values,301,1)
            monkey_py_temp = savgol_filter(-session_data_temp['MonkeyY'].values,301,1)
            monkey_p_heading_temp = savgol_filter(-session_data_temp['MonkeyYa'].values,301,1) + 90

            # 2. set trial start time as 0
            monkey_trial_T_temp = trial_time - trial_time[0]

            # 3. calculate start and stop action index
            # 3.1 first sample where |v| or |w| exceeds 0.5, plus 5 samples.
            # The v > 10 check above guarantees at least one candidate.
            moving_v = np.flatnonzero(np.abs(monkey_real_v_temp) > 0.5)
            moving_w = np.flatnonzero(np.abs(monkey_real_w_temp) > 0.5)
            move_T_idx_temp = np.nanmin([moving_v[0] if moving_v.size else np.nan,
                                         moving_w[0] if moving_w.size else np.nan])
            start_v_idx = int(move_T_idx_temp+5)

            # 3.2 sample among the first 400 that is closest to the origin (0,-32.5)
            start_x_idx = np.linalg.norm(np.vstack([monkey_px_temp[:400],
                                      monkey_py_temp[:400]+32.5]),axis=0).argmin()

            # 3.3 get start index
            start_idx = max(start_v_idx,start_x_idx+1)

            # 3.4 stop index: last sample with v > 0.5. The guard tests v > 1
            # while the index uses 0.5; both thresholds are kept as they were.
            if np.any(monkey_real_v_temp > 1):
                last_moving = np.flatnonzero(monkey_real_v_temp[::-1] > 0.5)[0]
                stop_idx = int(monkey_real_v_temp.size - last_moving) - 1
            else:
                stop_idx = int(monkey_real_v_temp.size) - 1

            if stop_idx <= start_idx:
                trial_counter += 1
                continue

            # 7. add log file stuff
            FFX_temp = FFX[trial_counter]
            FFY_temp = FFY[trial_counter]
            isFullOn_temp = isFullOn[trial_counter]
            isRewarded_temp = isRewarded[trial_counter]
            gain_V_temp = gain_V[trial_counter]
            gain_W_temp = gain_W[trial_counter]

            # 8. calculate monkey to target radius
            relative_radius_temp = np.linalg.norm(np.vstack([monkey_px_temp,monkey_py_temp]).reshape([2,-1]) -
                   np.vstack([FFX_temp,FFY_temp]).reshape([2,-1]),axis=0)

            # 9. calculate monkey to target angle
            relative_angle_temp = \
            get_relative_ang(monkey_px_temp,monkey_py_temp,np.radians(monkey_p_heading_temp),FFX_temp,FFY_temp)

            # 10. collect the row; the frame is assembled once after the loops
            active = slice(start_idx, stop_idx)
            kept_trials.append([
                trial_counter, gain_V_temp, gain_W_temp,
                monkey_px_temp[active], monkey_py_temp[active],
                monkey_p_heading_temp[active], monkey_real_v_temp[active],
                monkey_real_w_temp[active],
                monkey_trial_T_temp[active] - monkey_trial_T_temp[active][0],
                FFX_temp, FFY_temp, relative_radius_temp[active],
                np.degrees(relative_angle_temp)[active],
                isFullOn_temp, isRewarded_temp,
                monkey_real_v_temp[active]/gain_V_temp,
                monkey_real_w_temp[active]/gain_W_temp,
            ])

            trial_counter += 1

    # Rows mix scalars with arrays of different lengths, so fill an object
    # array cell by cell: ragged np.array([...]) raises on NumPy >= 1.24 and
    # DataFrame.append no longer exists in pandas 2.x.
    data = np.empty((len(kept_trials), len(COLUMNS)), dtype=object)
    for row_index, row in enumerate(kept_trials):
        for column_index, value in enumerate(row):
            data[row_index, column_index] = value

    monkey_trajectory = pd.DataFrame(data, columns=COLUMNS)
    monkey_trajectory['ep'] = monkey_trajectory.index.values

    return monkey_trajectory

# For old monkey data

## Load smr

In [3]:
class old_smr_extractor(object):
    """Load analog channels and marker events from Spike2 files.

    Files are taken from ``<folder_path>/<task>``; every directory entry
    whose name contains ``'smr'`` is used, in sorted order. The last two
    analog signals of each file are left out of the output.
    """

    def __init__(self, folder_path, task='gain'):
        self.folder_path = os.path.join(folder_path, task)
        matching = []
        for entry in os.listdir(self.folder_path):
            if 'smr' in entry:
                matching.append(os.path.join(self.folder_path, entry))
        self.full_path_file_names = sorted(matching)

    def extract_data(self):
        """Read every file and return its signals and markers.

        Returns
        -------
        tuple
            ``(frames, markers, sampling_rate)``: one DataFrame per file
            (one column per retained analog channel plus a final ``'Time'``
            column), one ``{'labels', 'values'}`` dict per file taken from
            the event channel named ``'marker'``, and the sampling rate of
            the first analog signal of the first file.
        """
        frames = []
        markers = []

        for file_number, path in enumerate(self.full_path_file_names):
            segment = neo.io.Spike2IO(filename=path).read_segment()
            signals = segment.analogsignals

            # The sampling rate is only read from the first file.
            if file_number == 0:
                smr_sampling_rate = signals[0].sampling_rate

            # Channels may differ in length; truncate all to the shortest
            # (the minimum is taken over all signals, including the dropped ones).
            n_samples = min([signal.size for signal in signals])

            # All analog channels except the last two, plus a trailing time column.
            table = np.ones((n_samples, segment.size['analogsignals'] - 2 + 1))
            column_names = []
            for column, signal in enumerate(signals[:-2]):
                table[:, column] = signal.as_array()[:n_samples].T
                column_names.append(signals[column].annotations['channel_names'][0])

            table[:, -1] = signals[0].times[:n_samples]
            column_names.append('Time')

            frames.append(pd.DataFrame(table, columns=column_names))

            # Locate the event channel called 'marker' (first match).
            marker_position = [position for position, event in enumerate(segment.events)
                               if event.name == 'marker'][0]
            marker_event = segment.events[marker_position]
            marker_labels = marker_event.get_labels().astype('int')
            marker_values = marker_event.as_array()
            markers.append({'labels': marker_labels, 'values': marker_values})

        return frames, markers, smr_sampling_rate

In [5]:
class old_log_extractor(object):
    """Parse per-trial settings from the text logs that accompany smr files.

    Parameters
    ----------
    files : list of str
        smr file paths; each log path is obtained by replacing every
        ``'smr'`` substring with ``'log'``.
    marker_list : list of dict
        One dict per file with a ``'labels'`` array. The number of label-3
        (trial end) markers caps how many trials are read from each log.
    """

    def __init__(self,files,marker_list):

        self.full_path_file_names = [file.replace('smr','log') for file in files]

        self.num_of_trials = [(value['labels']==3).sum() for value in marker_list] # usually need to remove the last trial

        self.isFullOn = []
        self.gain_V = []
        self.gain_W = []

    def extract_data(self):
        """Read every log file and append its per-trial values to the instance.

        Returns
        -------
        tuple
            ``(gain_V, gain_W, isFullOn)`` with one entry per trial.
        """
        fields = ('gain_V', 'gain_W', 'isFullOn')

        for index, file_name in enumerate(self.full_path_file_names):
            max_trials = self.num_of_trials[index]
            trial_count = 0
            # Collect in plain lists and concatenate once per file, instead of
            # re-allocating every array with np.hstack for each parsed line.
            parsed = {field: [] for field in fields}

            with open(file_name,'r',encoding='UTF-8') as content:
                log_content = content.readlines()

            for line in log_content:
                # The most recent gain settings apply to the trials that follow.
                if 'Joy Stick Max Velocity' in line:
                    gain_V = float(line.split(': ')[1])
                if 'Joy Stick Max Angular Velocity' in line:
                    gain_W = float(line.split(': ')[1])

                # Each 'Firefly Full ON' line counts as the start of a new trial.
                if 'Firefly Full ON' in line:
                    trial_count += 1
                    if trial_count > max_trials:
                        continue

                    parsed['isFullOn'].append(float(line.split(': ')[1]))
                    parsed['gain_V'].append(gain_V)
                    parsed['gain_W'].append(gain_W)

            # Attributes stay untouched (plain lists) when a file adds nothing,
            # which matches what line-by-line stacking produced.
            for field in fields:
                if parsed[field]:
                    setattr(self, field, np.hstack([getattr(self, field), parsed[field]]))

        return self.gain_V, self.gain_W, self.isFullOn

In [1]:
# process smr
def old_segment_trials(Channel_signal_output,marker_list,isFullOn,gain_V,gain_W,sampling_rate=833.33):
    """Cut continuous session recordings into one row per usable trial.

    Parameters
    ----------
    Channel_signal_output : list of pandas.DataFrame
        One frame per session with at least the columns ``AngularV``,
        ``ForwardV``, ``MonkeyX``, ``MonkeyY``, ``FireflyX``, ``FireflyY``
        and ``Time``.
    marker_list : list of dict
        One ``{'labels', 'values'}`` dict per session; label 2 marks trial
        starts, label 3 trial ends and label 4 juice delivery.
    isFullOn, gain_V, gain_W : indexable
        Per-trial values, indexed by a trial counter that runs across all
        sessions and also advances for skipped trials.
    sampling_rate : float, optional
        Divisor used when the angular velocity is cumulatively summed into
        a heading. Defaults to the previously hard-coded 833.33.
        NOTE(review): presumably the smr sampling rate in Hz - confirm.

    Returns
    -------
    pandas.DataFrame
        One row per kept trial, all columns of object dtype except ``ep``
        (the row number). Time-series columns hold one array per trial.

    Notes
    -----
    A trial is skipped when any of the following holds:

    * it has fewer than 301 samples (the smoothing window length);
    * the smoothed forward velocity never exceeds 10;
    * its ``Time`` span exceeds 3;
    * the detected stop index is not after the start index.

    A trial counts as rewarded when a juice marker falls strictly between
    its first and last sample time. The target position is the mean of the
    ``FireflyX`` / ``FireflyY`` channels over the trial, with Y negated.
    """
    COLUMNS = ['ep','gain_V','gain_W','px', 'py', 'p_heading', 'real_v','real_w','time',
              'FFX','FFY','real_relative_radius','real_relative_angle',
               'isFullOn','isRewarded','action_v','action_w']
    kept_trials = []  # one list of len(COLUMNS) values per kept trial
    trial_counter = 0

    for session_index, session_data in enumerate(Channel_signal_output):
        marker = marker_list[session_index]
        trial_start_times = marker['values'][marker['labels']==2]
        trial_stop_times = marker['values'][marker['labels']==3]
        juice_times = marker['values'][marker['labels']==4]

        # remove head (before the first trial start) and tail (after the last trial end)
        session_data = session_data[session_data.Time >= trial_start_times[0]]
        session_data = session_data[session_data.Time <= trial_stop_times[-1]]

        # segment trials: how many start / stop markers each sample has passed
        trial_index_start = np.digitize(session_data.Time, trial_start_times)
        trial_index_stop = np.digitize(session_data.Time, trial_stop_times)

        for trial_number in np.unique(trial_index_start):
            # Samples after this trial's start marker but not yet past its stop
            # marker. Pure boolean logic replaces np.subtract(..., dtype=np.float),
            # which fails on NumPy >= 1.24 where the np.float alias is gone.
            after_start = trial_index_start == trial_number
            trial_index_temp = after_start & ~(trial_index_stop == trial_number)
            session_data_temp = session_data[trial_index_temp]

            # skip trials too short for the 301-sample smoothing window
            if session_data_temp['AngularV'].values.size < 301:
                trial_counter += 1
                continue

            # Angular velocity is negated before smoothing.
            monkey_real_w_temp = savgol_filter(-session_data_temp['AngularV'].values,301,1)
            monkey_real_v_temp = savgol_filter(session_data_temp['ForwardV'].values,301,1)
            trial_time = session_data_temp['Time'].values

            # skip trials without movement, and trials whose Time span exceeds 3
            if not np.any(monkey_real_v_temp > 10) or (trial_time[-1] - trial_time[0]) > 3:
                trial_counter += 1
                continue

            # Y position is negated before smoothing.
            monkey_px_temp = savgol_filter(session_data_temp['MonkeyX'].values,301,1)
            monkey_py_temp = savgol_filter(-session_data_temp['MonkeyY'].values,301,1)

            # 1. rewarded if any juice marker lies strictly inside the trial
            isRewarded_temp = bool(np.any((juice_times > trial_time[0]) &
                                          (juice_times < trial_time[-1])))

            # 2. set trial start time as 0
            monkey_trial_T_temp = trial_time - trial_time[0]

            # 3. calculate start and stop action index
            # 3.1 first sample where |v| or |w| exceeds 0.5, plus 5 samples.
            # The v > 10 check above guarantees at least one candidate.
            moving_v = np.flatnonzero(np.abs(monkey_real_v_temp) > 0.5)
            moving_w = np.flatnonzero(np.abs(monkey_real_w_temp) > 0.5)
            move_T_idx_temp = np.nanmin([moving_v[0] if moving_v.size else np.nan,
                                         moving_w[0] if moving_w.size else np.nan])
            start_v_idx = int(move_T_idx_temp+5)

            # 3.2 sample among the first 400 that is closest to the origin (0,-32.5)
            start_x_idx = np.linalg.norm(np.vstack([monkey_px_temp[:400],
                                      monkey_py_temp[:400]+32.5]),axis=0).argmin()

            # 3.3 get start index
            start_idx = max(start_v_idx,start_x_idx+1)

            # 3.4 stop index: last sample with v > 0.5. The guard tests v > 1
            # while the index uses 0.5; both thresholds are kept as they were.
            if np.any(monkey_real_v_temp > 1):
                last_moving = np.flatnonzero(monkey_real_v_temp[::-1] > 0.5)[0]
                stop_idx = int(monkey_real_v_temp.size - last_moving) - 1
            else:
                stop_idx = int(monkey_real_v_temp.size) - 1

            if stop_idx <= start_idx:
                trial_counter += 1
                continue

            # 4. add log file stuff
            isFullOn_temp = isFullOn[trial_counter]
            gain_V_temp = gain_V[trial_counter]
            gain_W_temp = gain_W[trial_counter]

            # 8. calculate monkey to target radius
            FFX_temp = session_data_temp['FireflyX'].values.mean()
            FFY_temp = -session_data_temp['FireflyY'].values.mean()

            relative_radius_temp = np.linalg.norm(np.vstack([monkey_px_temp,monkey_py_temp]).reshape([2,-1]) -
                   np.vstack([FFX_temp,FFY_temp]).reshape([2,-1]),axis=0)

            # 9. calculate monkey to target angle; the heading is rebuilt by
            # cumulatively summing angular velocity over the active window
            active = slice(start_idx, stop_idx)
            monkey_p_heading_temp = (np.cumsum(monkey_real_w_temp[active])/sampling_rate) + 90
            relative_angle_temp = \
            get_relative_ang(monkey_px_temp[active],monkey_py_temp[active],
                             np.radians(monkey_p_heading_temp),FFX_temp,FFY_temp)

            # 10. collect the row; the frame is assembled once after the loops
            kept_trials.append([
                trial_counter, gain_V_temp, gain_W_temp,
                monkey_px_temp[active], monkey_py_temp[active],
                monkey_p_heading_temp, monkey_real_v_temp[active],
                monkey_real_w_temp[active],
                monkey_trial_T_temp[active] - monkey_trial_T_temp[active][0],
                FFX_temp, FFY_temp, relative_radius_temp[active],
                np.degrees(relative_angle_temp),
                isFullOn_temp, isRewarded_temp,
                monkey_real_v_temp[active]/gain_V_temp,
                monkey_real_w_temp[active]/gain_W_temp,
            ])

            trial_counter += 1

    # Rows mix scalars with arrays of different lengths, so fill an object
    # array cell by cell: ragged np.array([...]) raises on NumPy >= 1.24 and
    # DataFrame.append no longer exists in pandas 2.x.
    data = np.empty((len(kept_trials), len(COLUMNS)), dtype=object)
    for row_index, row in enumerate(kept_trials):
        for column_index, value in enumerate(row):
            data[row_index, column_index] = value

    monkey_trajectory = pd.DataFrame(data, columns=COLUMNS)
    monkey_trajectory['ep'] = monkey_trajectory.index.values

    return monkey_trajectory

In [None]:
def find_target_matched_trials(monkey_trajectory, pro_gains_1x=0.5, spatial_scale=400, max_distance=15):
    """Pair every gain trial with the normal trial whose target is nearest.

    Parameters
    ----------
    monkey_trajectory : pandas.DataFrame
        Must contain the columns ``name``, ``gain_V``, ``target_x`` and
        ``target_y``.
    pro_gains_1x, spatial_scale : float
        With a single subject name, trials with
        ``gain_V > pro_gains_1x * spatial_scale`` are the gain trials.
    max_distance : float, optional
        A normal trial is consumed by a match only when its target lies
        closer than this to the gain target (default 15, the previously
        hard-coded value).

    Returns
    -------
    tuple of pandas.DataFrame
        ``(gain_trials, matched_trials)``, both re-indexed from 0; row ``i``
        of the second frame is the match for row ``i`` of the first.

    Raises
    ------
    ValueError
        If the frame holds neither one nor two distinct names, or if there
        are gain trials but no normal trials to match against.

    Notes
    -----
    With two distinct names, the rows named ``'Viktor'`` are treated as the
    gain trials. Matching is greedy in row order. When no unused normal
    trial is within ``max_distance`` (or none are left), the nearest trial
    from the full normal pool is reused instead.
    """
    n_names = monkey_trajectory.name.unique().size
    if n_names == 1:
        gain_index = monkey_trajectory.gain_V>(pro_gains_1x*spatial_scale)
    elif n_names == 2:
        gain_index = (monkey_trajectory.name=='Viktor')
    else:
        raise ValueError('expected one or two distinct names, found %d' % n_names)

    target_columns = ['target_x','target_y']
    gain_targets = monkey_trajectory.loc[gain_index==1, target_columns].astype(float)
    normal_pool = monkey_trajectory.loc[gain_index==0, target_columns].astype(float)
    if len(gain_targets) > 0 and len(normal_pool) == 0:
        raise ValueError('no normal trials available to match against')

    def nearest(candidates, target_xy):
        """Return (index label, distance) of the candidate closest to target_xy."""
        distances = pd.Series(np.linalg.norm(candidates.values - target_xy, axis=1),
                              index=candidates.index)
        position = distances.argmin()
        return distances.index[position], distances.iloc[position]

    available = normal_pool.copy()  # normal trials not yet consumed by a match
    df_index_all = []
    for _, gain_target in gain_targets.iterrows():
        target_xy = gain_target.values
        df_index = None

        # Prefer an unused normal trial, but only if it is close enough.
        if len(available) > 0:
            candidate, distance = nearest(available, target_xy)
            if distance < max_distance:
                df_index = candidate
                available = available.drop(candidate)

        # Otherwise reuse the best trial from the full pool. This also covers
        # the case where every normal trial has already been consumed, which
        # used to raise on an empty argmin.
        if df_index is None:
            df_index, _ = nearest(normal_pool, target_xy)

        df_index_all.append(df_index)

    monkey_trajectory_gain = monkey_trajectory.loc[gain_index].reset_index(drop=True)
    monkey_trajectory_target_matched = monkey_trajectory.loc[df_index_all].reset_index(drop=True)

    return monkey_trajectory_gain, monkey_trajectory_target_matched