### Place Cell Identification

#### Three inputs:
#### 1. input_animal:
- H0466
- H0422
- etc

#### 2. input_session: 
- all = run through all sessions and stages
- N01
- N02
- I01
- I02
- A01
- A02
- P01
- P02

#### 3. input_stage:
- PRE
- SAM
- CHO
- PRO

In [55]:
#Input desired session and stage
input_animal = 'H0466'
input_session = 'all'
input_stage = 'all'

In [1]:
import os
import time

import place_cell_functions
import multiprocessing as mp
from _thread import start_new_thread

import pandas as pd
import numpy as np
import math

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import gridspec
import seaborn as sns

from scipy import stats
from scipy.spatial import distance
from scipy.ndimage import gaussian_filter

from sklearn.preprocessing import normalize
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    
#plotting params
mpl.rcParams['axes.facecolor'] = 'white'
mpl.rcParams['axes.edgecolor'] = 'black'
mpl.rcParams['axes.linewidth'] = '0.5'
mpl.rcParams['axes.labelsize'] = '8'
mpl.rcParams['axes.labelcolor'] = 'black'

mpl.rcParams['xtick.color'] = 'black'
mpl.rcParams['xtick.labelsize'] = '4'
mpl.rcParams['ytick.labelsize'] = '4'
mpl.rcParams['ytick.color'] = 'black'

In [58]:
#Read iin LR events and trace files
all_events_dlc = pd.read_csv('/Users/rufusmitchell-heggs/Desktop/data_analysis/preprocessing/event_arena/'+input_animal+'/output_directory/'+input_animal+'_event_dlc_LR.csv')
all_traces_dlc = pd.read_csv('/Users/rufusmitchell-heggs/Desktop/data_analysis/preprocessing/event_arena/'+input_animal+'/output_directory/'+input_animal+'_trace_dlc_LR.csv')

if input_session == 'all':
    input_session = list(set(list(all_traces_dlc['Session'])))
if input_stage == 'all':
    input_stage = list(set(list(all_traces_dlc['stage'])))
    

In [64]:
for session in input_session:
    for stage in input_stage:
        print(session, stage)
        
        #Session and stage selection based on input
        traces = all_traces_dlc[all_traces_dlc['Session']==session][all_traces_dlc[all_traces_dlc['Session']==session]['stage']==stage]
        events = all_events_dlc[all_events_dlc['Session']==session][all_events_dlc[all_events_dlc['Session']==session]['stage']==stage]
        traces = traces.reset_index(drop=True)
        events = events.reset_index(drop=True)
        
        #Remove any cells that are not registered during that session/stage
        traces = traces.loc[:,(traces!=' nan').all()]

        #List of cell IDs and their respective traces
        cells = traces.columns
        events = events[events.columns.intersection(cells)]
        cell_ids = events.columns[8:]
        
        #Assign variable name to global identity
        cells = cell_ids
        trace_dlc = traces
        events_dlc = events
        
        if len(events_dlc) != 0:

            # ----------------------------------------------------------------------------------------
    #         # CRITERIA 1 - Cell sorting to manually remove bad traces
            # ----------------------------------------------------------------------------------------
    #         num_cells_analysed = 0
    #         for cell in cells:

    #             #Allow user to see how many cells they have looked through
    #             num_cells_analysed +=1
    #             num_cells_left = len(cells)-1
    #             print(cell, num_cells_analysed,'/',num_cells_left)

    #             #plot the raw traces and overlapping events
    #             plt.figure(figsize=(8, 1), dpi=400)
    #             plt.plot(trace_dlc[str(cell)].astype(float), linewidth=0.5)
    #             plt.plot(events_dlc[str(cell)].astype(float), linewidth=1)
    #             plt.show()

    #             #Option to remove any bad cells
    #             good_events = input("Are all events good - y/n?")
    #             if good_events != 'y':
    #                 print(bcolors.FAIL+cell,'Dropped'+bcolors.ENDC)
    #                 trace_dlc = trace_dlc.drop([cell], axis=1)
    #                 events_dlc = events_dlc.drop([cell], axis=1)

    #             else:
    #                 print(bcolors.OKGREEN+cell,'Accepted'+bcolors.ENDC)

    #         #Update cells being analysed list
    #         cells = events_dlc.columns[8:]

            # ----------------------------------------------------------------------------------------
            # CRITERIA 2 - all cells with events lower than 0.3 = 0
            # ----------------------------------------------------------------------------------------
            # for cell in cells:
            #     a = np.array(events_dlc[cell].values.tolist())
            #     events_dlc[str(cell)] = np.where(a <= 0.3, 0, a).tolist() # <--- Condition, ignore events lower than 0.3
            #     events_dlc[str(cell)] = np.where(a >= 0.3, 1, a).tolist() # <--- Binarize events
            #     if sum(events_dlc[cell])<3:
            #         events_dlc = events_dlc.drop([cell], axis=1)

            for cell in cells:
                a = np.array(events_dlc[cell].values.tolist())
                events_dlc[str(cell)] = np.where(a > 0, 1, a).tolist() # <--- Binarize events
                if sum(events_dlc[cell])<3:
                    events_dlc = events_dlc.drop([cell], axis=1)
                    trace_dlc = trace_dlc.drop([cell], axis=1)

            cells = events_dlc.columns[8:]
            for cell in trace_dlc.columns:
                if cell not in events.columns:
                    trace_dlc = trace_dlc.drop([cell], axis=1)

            # ----------------------------------------------------------------------------------------
            # OCCUPANCY VECTOR GENERATOR
            # ----------------------------------------------------------------------------------------

            #Defining the boundaries of the arena
            xedges = np.arange(0, 700, 720/33)
            yedges = np.arange(0, 600, 500/33)

            #Extract the DLC x,y coordinates
            x = events_dlc['x']
            y = events_dlc['y']

            #Create occupancy map vector
            occupancy_map = []
            for x_pos, y_pos in zip(x,y):
                for y_bin in range(len(yedges)):
                    if y_bin < len(yedges)-1:
                        if  yedges[y_bin] <= y_pos <= yedges[y_bin+1]:
                            for x_bin in range(len(xedges)):
                                if x_bin < len(xedges)-1:
                                    if xedges[x_bin] <= x_pos <= xedges[x_bin+1]:
                                        occupancy_map.append(int(str(y_bin)+str(x_bin)))

            # ----------------------------------------------------------------------------------------                         
            # SPATIAL MUTUAL INFORMATION for each cell, the percentile and the shuffled distribution
            # ----------------------------------------------------------------------------------------

            pool = mp.Pool(processes=4)
            results = [pool.apply_async(place_cell_functions.mi_perc_dis, args=(np.array(events_dlc[str(cell)]),occupancy_map)) for cell in cells]
            mi_all, perc_all, dist_all = np.array([p.get() for p in results]).transpose()
            perc_all = np.array([item for sublist in perc_all for item in sublist])
            end = time.time()

            # ----------------------------------------------------------------------------------------                         
            #Criteria 3 - Only cells in the 95th pecentile are considered place cells
            # ----------------------------------------------------------------------------------------                         
            percentile = 95
            place_cell_status = []
            for perc in perc_all:
                if perc > percentile:
                    place_cell_status.append('y')
                else:
                    place_cell_status.append('n')


            #Create dataframe containing all cells + mutual information distribution and percentile
            place_cell_data = {'Animal':list((input_animal,)*len(cells)),'Session':list((session,)*len(cells)),'Stage':list((stage,)*len(cells)),
                               'Neuron':list(cells),'Status':place_cell_status,'Mutual_Information':mi_all,
                               'Percentile':np.array(perc_all).flatten('F'),'Distribution':dist_all}

            place_cell_table = pd.DataFrame(place_cell_data) 

            #Dataframe of other cells that werent analysed
            cell_status = [x for x in all_traces_dlc.columns[8:] if x not in cells]
            removed_cell_data = {'Animal':list((input_animal,)*len(cell_status)),'Session':list((session,)*len(cell_status)),
                                 'Stage':list((stage,)*len(cell_status)),'Neuron':cell_status,'Status':list(('N/A',)*len(cell_status)),
                                 'Mutual_Information':list(('N/A',)*len(cell_status)),'Percentile':list(('N/A',)*len(cell_status)),
                                 'Distribution':list(('N/A',)*len(cell_status))}

            removed_cell_data = pd.DataFrame(removed_cell_data) 
            place_cell_table = place_cell_table.append(removed_cell_data)
            # ----------------------------------------------------------------------------------------                         
            # SAVE PLACE CELL TABLE
            # ----------------------------------------------------------------------------------------                         
            #Looks for place cell mutual info file - if it doesn't exist, it creates new one
            try:
                csv_place_cells = pd.read_csv('/Users/rufusmitchell-heggs/Desktop/data_analysis/projects/hippocampus/event_arena/secondary_analysis/'+input_animal+'/'+input_animal+'_place_cell_mutual_info.csv')
            except FileNotFoundError:
                #Creating new csv
                place_cell_table.to_csv('/Users/rufusmitchell-heggs/Desktop/data_analysis/projects/hippocampus/event_arena/secondary_analysis/'+input_animal+'/'+input_animal+'_place_cell_mutual_info.csv', index=False)
            else:
                #If there is an existing CSV, this checks the animal, session and stage is already present, 
                #and asks if you want to add it again
                csv_animal = csv_place_cells[csv_place_cells['Animal']==input_animal]
                csv_session = csv_animal[csv_animal['Session']==session]
                csv_stage = csv_session[csv_session['Stage']==stage]
                if len(csv_stage) > 0:
                    add_again = input('Animal '+input_animal+' Session '+ session+' Stage '+stage+' is already in csv table - are you sure you want to add it again? (y/n)')
                    if add_again == 'y':
                        place_cell_table.to_csv('/Users/rufusmitchell-heggs/Desktop/data_analysis/projects/hippocampus/event_arena/secondary_analysis/'+input_animal+'/'+input_animal+'_place_cell_mutual_info.csv', mode='a', header=False, index=False)
                        print('Animal '+input_animal+' Session '+ session+' Stage '+stage+' added to csv table)

P02 CHO


  result = method(y)


P02 PRO
P02 PRE
P02 SAM
N02 CHO
Animal H0466 Session N02 Stage CHO is already in csv table - are you sure you want to add it again? (y/n)y
N02 PRO
N02 PRE
N02 SAM
A01 CHO
Animal H0466 Session A01 Stage CHO is already in csv table - are you sure you want to add it again? (y/n)n
A01 PRO
A01 PRE
A01 SAM
I02 CHO
Animal H0466 Session I02 Stage CHO is already in csv table - are you sure you want to add it again? (y/n)y
I02 PRO
I02 PRE
I02 SAM
P01 CHO
P01 PRO
P01 PRE
P01 SAM
N01 CHO
Animal H0466 Session N01 Stage CHO is already in csv table - are you sure you want to add it again? (y/n)y
N01 PRO
N01 PRE
N01 SAM
A02 CHO
Animal H0466 Session A02 Stage CHO is already in csv table - are you sure you want to add it again? (y/n)y
A02 PRO
A02 PRE
A02 SAM
I01 CHO
Animal H0466 Session I01 Stage CHO is already in csv table - are you sure you want to add it again? (y/n)y
I01 PRO
I01 PRE
I01 SAM
