In [2]:
import numpy as np
import os
import csv
import matplotlib.pyplot as plt
from scipy.signal import convolve
from scipy.signal import savgol_filter
import math
import boris_extraction as boris
import pandas as pd
from scipy.stats import sem
from statistics import mean
from scipy.stats import wilcoxon
from sklearn.decomposition import PCA


In [None]:
#to do

# test multirecording code
# specifically recording = name of object or object
# test PCA code
# write docstrings for zscore and pca code

# rearrange attributes of recordings vs event functions

# write code to add event dict to recordings 
# as well as subject per recording
# either as a batch process or individually (code for both) 

# add exporting functions for all data (collection) or some data (per recording) 
# add Fisher's exact test code somewhere

# figure out this equalize parameter
# potentially get rid of all options except user_defined 




In [16]:

def get_spiketrain(timestamp_array, timebin=1, sampling_rate=20000):
    """
    creates a spiketrain in timebin (ms) sized bins
    each array element is the number of spikes recorded in that bin

    Bins start at sample 0 and extend far enough to contain the largest
    timestamp, so spikes at the end of the recording are counted instead
    of being dropped by a bin-edge array that stops short of them.

    Args (3 total, 1 required):
        timestamp_array: numpy array (a single scalar timestamp is also
            accepted), spike timestamps in samples
        timebin: int, default=1, timebin (ms) of resulting spiketrain
        sampling_rate: int, default=20000, sampling rate in Hz of the ephys recording

    Returns (1):
        spiketrain: numpy array, array elements are number of spikes per timebin;
            empty when timestamp_array is empty

    Raises:
        ValueError: if one timebin is shorter than one sample
    """
    timestamps = np.atleast_1d(np.asarray(timestamp_array)).ravel()

    # number of samples per timebin; dividing by 1000 (rather than
    # multiplying by .001) avoids float round-off before truncation
    hz_to_timebin = int(sampling_rate * timebin / 1000)
    if hz_to_timebin < 1:
        raise ValueError(
            'timebin is shorter than one sample at this sampling_rate')

    if timestamps.size == 0:
        return np.zeros(0, dtype=np.int64)

    # one extra edge past the last timestamp so that it lands inside a bin
    n_bins = int(timestamps.max()) // hz_to_timebin + 1
    bin_edges = np.arange(n_bins + 1) * hz_to_timebin
    spiketrain = np.histogram(timestamps, bins=bin_edges)[0]

    return spiketrain

def get_firing_rate(spiketrain, smoothing_window = 250, timebin=1):
    """
    converts a spiketrain into a smoothed firing rate (spikes/second)

    The spiketrain is convolved with a flat (boxcar) kernel that is
    smoothing_window elements long; the kernel is scaled by 1000/timebin
    so that counts per timebin (ms) become spikes per second.

    Args (3 total, 1 required):
        spiketrain: numpy array, spike counts in timebin (ms) bins
        smoothing_window: int, default=250, number of kernel elements
            used for the moving average, min smoothing_window = 1
        timebin: int, default = 1, timebin (ms) of spiketrain

    Return (1):
        firing_rate: numpy array, result of np.convolve in 'same' mode
    """
    boxcar = np.ones(smoothing_window)
    boxcar = boxcar / smoothing_window
    boxcar = boxcar * 1000
    boxcar = boxcar / timebin
    return np.convolve(spiketrain, boxcar, mode='same')


def get_event_lengths(events):
    """
    calculates event lengths, the longest event length and the mean event length

    Args (1):
        events: numpy array or list of [[start (ms), stop (ms)] x n events]

    Returns (3):
        max event length: int, longest event length in ms
        event_lengths: lst of ints, event lengths in ms (truncated to int)
        mean event length: mean of event_lengths in ms

    Raises:
        ValueError: if events is empty
    """
    # plain lists of [start, stop] pairs are accepted as well as arrays
    events = np.asarray(events)
    event_lengths = [int(event[1] - event[0]) for event in events]
    if not event_lengths:
        raise ValueError('events is empty, no event lengths to calculate')
    return max(event_lengths), event_lengths, mean(event_lengths)


def trim_event(event, max_event):
    """
    trims a single event to a given maximum length

    The input is never modified; a new array is always returned.

    Args (2 total):
        event: array-like [start (ms), stop (ms)]
        max_event: int, max length (s) of event desired

    Returns (1):
        event: numpy array [start (ms), stop (ms)], with stop moved to
            start + max_event*1000 when the event was longer than max_event
    """
    # work on a copy so the caller's event is left untouched
    trimmed = list(event)
    max_length_ms = max_event * 1000
    if trimmed[1] - trimmed[0] > max_length_ms:
        trimmed[1] = trimmed[0] + max_length_ms
    return np.array(trimmed)


def pre_event_window(event, baseline_window):
    """
    builds an event-like window [start (ms), stop (ms)] covering
    baseline_window seconds that ends 1 ms before the event starts

    Args (2 total):
        event: np.array[start(ms), stop(ms)]
        baseline_window: int, seconds prior to an event

    Returns (1):
        preevent: np.array, [start(ms), stop(ms)] of the baseline window
    """
    onset = event[0]
    window_start = onset - (baseline_window*1000) - 1
    window_stop = onset - 1
    return np.array([window_start, window_stop])


def max_events(unit_dict, max_event, pre_window, timebin = 1):
    """
    clips every event snippet of every unit to at most
    (pre_window + max_event) seconds worth of timebins

    Args (4 total):
        unit_dict: dict, unit id as keys, values are sequences of event
            snippets (spiketrains or firing rates)
        max_event: int, longest event length (s) returned
        pre_window: int, amount of preevent time (s) included at the start
            of each snippet
        timebin: timebin (ms) of the snippets

    Returns (1):
        snippets_dict: dict, unit id as keys, values are lists of clipped
            snippets; if clipping a unit raises IndexError its snippets
            are passed through unchanged
    """
    def _clip(snippet):
        return snippet[0:int((pre_window + max_event)*1000/timebin)]

    snippets_dict = {}
    for unit, unit_snippets in unit_dict.items():
        try:
            clipped = list(map(_clip, unit_snippets))
        except IndexError:
            # best effort: keep the unit's snippets as they were
            clipped = unit_snippets
        snippets_dict[unit] = clipped
    return snippets_dict


def get_unit_average_events(unit_event_snippets):
    """
    averages the event snippets of each unit across events

    Args (1):
        unit_event_snippets: dict, unit ids as keys, values are sequences
            of equal-length event snippets

    Returns (1):
        unit_average_event: dict, unit ids as keys, values are the
            np.mean over axis 0 of that unit's snippets
    """
    return {
        unit: np.mean(snippets, axis=0)
        for unit, snippets in unit_event_snippets.items()
    }


class EphysRecording:
    """
    A class for an ephys recording after being spike sorted and manually curated using phy. 
    Ephys recording must have a phy folder. 

    Attributes:
        path: str, relative path to the phy folder
            formatted as: './folder/folder/phy'
        sampling_rate: int, sampling rate of the ephys device
            in Hz, standard in the PC lab is 20,000Hz
        labels_dict: dict, keys are unit ids (str) and
            values are labels (str)
        timestamps_var: numpy array, all spike timestamps 
            of good and mua units (no noise unit-generated spikes)
        unit_array: numpy array, unit ids associated with each
            spike in the timestamps_var
        unit_timestamps: dict, keys are unit ids (numpy ints), and
            values are numpy arrays of timestamps for all spikes 
            of that unit (every unit left after noise removal)
        zscored_events: dict, initialised empty
        wilcox_dfs: dict, initialised empty

        Not assigned in this class (expected to be attached by
        downstream analysis code): subject, spiketrain,
        unit_spiketrains, unit_firing_rates

    Methods: (all called in __init__)
        get_unit_labels: creates labels_dict
        get_spike_specs: creates timestamps_var and unit_array
        get_unit_timestamps: creates unit_timestamps dictionary
    """
    
    def __init__(self, path, sampling_rate=20000):
        """
        constructs all necessary attributes for the EphysRecording object
        including creating labels_dict, timestamps_var, and a unit_timestamps 
        dictionary 
        
        Arguments (2 total):
            path: str, relative path to the phy folder
                formatted as: './folder/folder/phy'
            sampling_rate: int, default=20000; sampling rate of 
                the ephys device in Hz
        Returns:
            None
        """
        self.path = path
        self.sampling_rate = sampling_rate
        self.zscored_events = {}
        self.wilcox_dfs = {}
        self.get_unit_labels()
        self.get_spike_specs()
        self.get_unit_timestamps()

    
    def get_unit_labels(self):
        """
        assigns self.labels_dict as a dictionary 
        with unit id (str) as key and label as values (str)
        read from 'cluster_group.tsv' in the phy folder
        labels: 'good', 'mua', 'noise' 

        Arguments:
            None

        Returns:
            None
        """
        labels = 'cluster_group.tsv'
        with open(os.path.join(self.path, labels), 'r') as f:
            reader = csv.DictReader(f, delimiter='\t')
            self.labels_dict = {row['cluster_id']: row['group'] for row in reader}

    
    def get_spike_specs(self):
        """
        imports spike_times and spike_clusters from the phy folder
        removes spikes from units labeled noise
        and assigns self.timestamps_var (numpy array) as the remaining timestamps 
        and assigns self.unit_array (numpy array) as the unit ids associated
        with each remaining spike
        
        Args:
            None
        
        Returns:
            None 

        Raises:
            KeyError: if a unit in spike_clusters has no label in labels_dict
        """
        # ravel so that (n, 1) shaped files behave like flat arrays
        timestamps_var = np.load(os.path.join(self.path, 'spike_times.npy')).ravel()
        unit_array = np.load(os.path.join(self.path, 'spike_clusters.npy')).ravel()

        # look up the label once per unit instead of once per spike
        noise_units = []
        for unit in np.unique(unit_array):
            try:
                label = self.labels_dict[str(unit)]
            except KeyError:
                raise KeyError(
                    f'unit {unit} has spikes but no label in cluster_group.tsv'
                ) from None
            if label == 'noise':
                noise_units.append(unit)

        keep = ~np.isin(unit_array, noise_units)
        self.timestamps_var = timestamps_var[keep]
        self.unit_array = unit_array[keep]

    
    def get_unit_timestamps(self):
        """
        creates a dictionary of units to spike timestamps
        keys are unit ids (numpy ints) and values are spike timestamps for 
        that unit (always numpy arrays, also for units with a single spike)
        and assigns dictionary to self.unit_timestamps
        
        Args:
            None
        
        Return:
            None
        """
        unit_timestamps = {}
        units, first_spike = np.unique(self.unit_array, return_index=True)
        # keep units in order of their first spike, as a spike-by-spike
        # pass over the arrays would produce
        for unit in units[np.argsort(first_spike)]:
            unit_timestamps[unit] = self.timestamps_var[self.unit_array == unit]
        
        self.unit_timestamps = unit_timestamps   
    

class MultiEvents:
    """
    A class for ephys statistics done on multiple event types
    built from a recording (an EphysRecording, or an object with a
    .collection dict of recordings) and an event dictionary
    where keys are event type names and values are arrays [[start (ms), stop(ms)]..]

    With a single recording the spiketrain / firing rate attributes are stored
    on this object; with a collection they are stored on each recording of the
    collection instead, and event snippets for one of those recordings are
    obtained by passing it as e_collection to get_unit_event_firing_rates.
    The statistics and plotting methods read self.unit_firing_rates, i.e. they
    work in single recording mode.

    Attributes:
        recording: the recording (or collection) passed to __init__
        e_recording: alias of recording
        event_dict: dict, dictionary of event names and event start and stop times
            key: str, name of the event 
            value: numpy array of [[start (ms), stop (ms)] x n events]
        event_types: lst of strs, from the keys of the event dict
        events: list of all start and stop times from event types 
        smoothing_window: int, default=250, window length in ms used to calculate firing rates
        timebin: int, default=1, bin size (in ms) for spike train and firing rate arrays
        ignore_freq: float, default=0.01, frequency in Hz that a good unit needs to
            exceed to be included in analysis
        longest_event: int, length of longest event (ms)
        event_lengths: lst, length of all events (ms)
        mean_event_length: mean length of all events (ms)
        spiketrain: numpy array, each element of the array 
            is the number of spikes per timebin throughout the whole recording
        unit_spiketrains: dict, keys are unit ids, values (numpy arrays) are each "good"
            units spiketrains in the specified timebins for the whole recording
        unit_firing_rates: dict, keys are unit ids, values (numpy array) are each "good"
            units firing rates calculated using smoothing_window in bins of size timebin

    Methods: 
        get_whole_spiketrain: spiketrain of all spikes
        get_unit_spiketrains: spiketrain per good unit
        get_unit_firing_rates: firing rate per good unit
        get_event_snippets: slices of one array around events
        get_unit_event_firing_rates: event snippets of firing rates per unit
        wilcox_baseline_v_event_stats: paired test, event vs pre-event baseline
        fishers_exact_wilcox: significant / not significant unit counts per event
        wilcox_baseline_v_event_plots: event triggered averages per unit
        wilcoxon_event_v_event_stats: paired test, event type vs event type
        get_zscore: baseline z-scored average event per unit
        get_zcore_plot: plots of the z-scored events
        PCA_trajectories: 2 component PCA of average event firing rates
    """
    def __init__(self, event_dict, recording, smoothing_window=250, timebin=1, ignore_freq=0.01):
        """
        stores the parameters, calculates event length statistics and
        computes spiketrains and firing rates for the recording(s)

        Args (5 total, 2 required):
            event_dict: dict, event names to arrays of [[start (ms), stop (ms)] x n events]
            recording: EphysRecording, or an object with a .collection dict of recordings
            smoothing_window: int, default=250, smoothing window (ms) for firing rates
            timebin: int, default=1, bin size (ms)
            ignore_freq: float, default=0.01, minimum firing frequency (Hz) of included units
        """
        self.recording = recording
        # same object under the attribute name this class originally exposed
        self.e_recording = recording
        self.event_dict = event_dict
        self.event_types = list(event_dict.keys())
        self.events = [value for sublist in event_dict.values() for value in sublist]
        self.smoothing_window = smoothing_window
        self.timebin = timebin
        self.ignore_freq = ignore_freq
        # self.events is a plain list; hand the helper an array
        self.longest_event, self.event_lengths, self.mean_event_length = get_event_lengths(np.asarray(self.events))
        self.get_whole_spiketrain()
        self.get_unit_spiketrains()
        self.get_unit_firing_rates()

    def _targets(self):
        """
        pairs of (object that receives computed attributes, recording that
        supplies the spikes): (self, recording) for a single recording,
        (rec, rec) for every recording of a collection
        """
        if hasattr(self.recording, 'collection'):
            return [(rec, rec) for rec in self.recording.collection.values()]
        return [(self, self.recording)]

    def _equalized_length(self, equalize):
        """
        translates the equalize option into an event length in ms

        Returns self.longest_event for 'max', self.mean_event_length for
        'average', equalize*1000 for a number (s), and None for None/False
        (no equalization, each event keeps its own length).

        Raises:
            ValueError: for any other string
        """
        if equalize is None or equalize is False:
            return None
        if isinstance(equalize, str):
            if equalize == 'max':
                return self.longest_event
            if equalize == 'average':
                return self.mean_event_length
            raise ValueError("equalize must be 'max', 'average' or a length in seconds")
        return equalize * 1000

    
    def get_whole_spiketrain(self):
        """
        creates a spiketrain of all spikes of the recording in timebin (ms) bins 
        and assigns it as spiketrain (on self for a single recording, 
        on each recording for a collection)
        
        Args:
            None
            
        Returns:
            None
        """
        for target, recording in self._targets():
            target.spiketrain = get_spiketrain(recording.timestamps_var,
                                               timebin=self.timebin,
                                               sampling_rate=recording.sampling_rate)
    
    def get_unit_spiketrains(self):  
        """
        Creates a dictionary and assigns it as unit_spiketrains
        (on self for a single recording, on each recording for a collection)
        where keys are 'good' unit ids (not 'mua') that exceed
        the ignore_freq threshold frequency, values are numpy arrays of 
        spiketrains in timebin sized bins; spiketrains are padded with 
        zeros to the length of the whole recording spiketrain
        
        Args:
            None
            
        Returns:
            None
            
        """
        for target, recording in self._targets():
            whole_spiketrain = getattr(target, 'spiketrain', None)
            if whole_spiketrain is None:
                whole_spiketrain = get_spiketrain(recording.timestamps_var,
                                                  timebin=self.timebin,
                                                  sampling_rate=recording.sampling_rate)
            no_bins = len(whole_spiketrain)
            unit_spiketrains = {}
            for unit in recording.unit_timestamps.keys():
                if recording.labels_dict[str(unit)] != 'good':
                    continue
                # a unit with one spike may be stored as a scalar
                timestamps = np.atleast_1d(recording.unit_timestamps[unit])
                no_spikes = len(timestamps)
                unit_freq = no_spikes/recording.timestamps_var[-1]*recording.sampling_rate
                if unit_freq > self.ignore_freq:
                    unit_spiketrain = get_spiketrain(timestamps,
                                                     timebin=self.timebin,
                                                     sampling_rate=recording.sampling_rate)
                    # a unit's own spiketrain stops at its last spike; pad so
                    # events after that spike read as zero rather than missing
                    if len(unit_spiketrain) < no_bins:
                        padding = np.zeros(no_bins - len(unit_spiketrain), dtype=unit_spiketrain.dtype)
                        unit_spiketrain = np.concatenate([unit_spiketrain, padding])
                    unit_spiketrains[unit] = unit_spiketrain
            target.unit_spiketrains = unit_spiketrains

    
    def get_unit_firing_rates(self):  
        """
        Calculates firing rates per unit,
        creates a dictionary and assigns it as unit_firing_rates
        (on self for a single recording, on each recording for a collection)
        the keys are unit ids and values are firing rates for the
        unit (numpy array) in timebin sized bins 
        calculated using smoothing_window for averaging
        
        Args:
            none
            
        Returns:
            none
        """
        for target, _ in self._targets():
            target.unit_firing_rates = {
                unit: get_firing_rate(unit_spiketrain, self.smoothing_window, self.timebin)
                for unit, unit_spiketrain in target.unit_spiketrains.items()
            }
    
    def get_event_snippets(self, event, whole_recording, equalize, pre_window=0, post_window=0):
        """
        takes snippets of spiketrains or firing rates for events
        optional pre-event and post-event windows (s) may be included
        all events can also be made of equal length
    
        Args (5 total, 3 required): 
            event: name of an event type in self.event_dict, or a numpy 
                array / list of [[start (ms), stop (ms)] x n events] 
            whole_recording: numpy array, spiketrain or firing rates 
                for the whole recording, for population or for a single unit
            equalize: {user_defined, 'max', 'average', None/False}, equalizes lengths 
                of events by padding with post event time or trimming event
                user_defined: float, makes all events user_defined (s) long   
                'max': makes all events as long as the longest event 
                'average': makes all events as long as the average event length 
                None/False: every event keeps its own length
            pre_window: int, default=0, seconds prior to start of event returned
            post_window: int, default=0, seconds after end of event returned
    
        Returns (1):
            event_snippets: a list of arrays, one slice of whole_recording per 
                event including pre_window & post_window; a slice is shorter 
                when its window reaches beyond the recording
        """
        if isinstance(event, (np.ndarray, list)):
            events = np.asarray(event)
        else:
            events = np.asarray(self.event_dict[event])
        target_length = self._equalized_length(equalize)
        pre_window = math.ceil(pre_window*1000)
        post_window = math.ceil(post_window*1000)
        event_snippets = []
        for window in events:
            start, stop = window[0], window[1]
            # lengths come from the events being cut, not from the
            # all-event-types list kept in self.event_lengths
            length = (stop - start) if target_length is None else target_length
            start_bin = math.ceil((start - pre_window)/self.timebin)
            no_bins = math.ceil((pre_window + length + post_window)/self.timebin)
            stop_bin = start_bin + no_bins
            # negative indices would wrap around to the end of the array
            event_snippets.append(whole_recording[max(start_bin, 0):max(stop_bin, 0)])
        return event_snippets
    
    def get_unit_event_firing_rates(self, event, equalize, pre_window = 0, post_window = 0, e_collection = False):
        """
        returns firing rates for events per unit
    
        Args (5 total, 2 required):
            event: name of an event type in self.event_dict, or an array 
                of [[start (ms), stop (ms)] x n events]
            equalize: {user_defined, 'max', 'average', None/False}, 
                see get_event_snippets
            pre_window: int, default=0, seconds prior to start of event returned
            post_window: int, default=0, seconds after end of event returned
            e_collection: default=False, object whose unit_firing_rates are 
                used instead of self.unit_firing_rates
            
        Return (1):
            unit_event_firing_rates: dict, keys are unit ids,
            values are lsts of numpy arrays of firing rates per event
        """
        source = e_collection if e_collection else self
        return {
            unit: self.get_event_snippets(event, firing_rate, equalize, pre_window, post_window)
            for unit, firing_rate in source.unit_firing_rates.items()
        }
        
    def wilcox_baseline_v_event_stats(self, event, baseline_window, equalize):
        #what if i wanted a random snippet from the first ten minutes instead of prior to the event?
        """
        calculates wilcoxon signed-rank test for average firing rates of two windows: event vs baseline
        baseline used is an amount of time immediately prior to the event
        wilcoxon signed-rank test is applied to two sets of measurements:
        average firing rate per event, average firing rate per baseline

        The result and the parameters used are also stored on the object 
        (wilcox_baseline_df, wilcox_event, wilcox_equalize, wilcox_baseline, 
        wilcox_xstop in ms) for wilcox_baseline_v_event_plots.
        
        Args (3 total, 3 required):
            event: str, name of the event type in self.event_dict
            baseline_window: int, length of baseline firing rate (s)
            equalize: {user_defined, 'max', 'average', None/False}, 
                see get_event_snippets
    
        Return (1):
            wilcoxon_df: pandas dataframe, columns are unit ids, 
            row 'Wilcoxon Stat' are wilcoxon statistics and row 'p value' are p values 
        
        """
        event_windows = np.asarray(self.event_dict[event])
        preevent_baselines = np.array([pre_event_window(window, baseline_window) for window in event_windows])
        unit_preevent_firing_rates = self.get_unit_event_firing_rates(preevent_baselines, baseline_window, 0, 0)
        unit_event_firing_rates = self.get_unit_event_firing_rates(event, equalize, 0, 0)
        unit_averages = {}
        for unit in unit_event_firing_rates.keys():
            event_snippets = unit_event_firing_rates[unit]
            preevent_snippets = unit_preevent_firing_rates[unit]
            if any(len(snippet) == 0 for snippet in list(event_snippets) + list(preevent_snippets)):
                print(f'Unit {unit} skipped: an event or baseline window holds no data')
                continue
            event_averages = [np.mean(snippet) for snippet in event_snippets]
            preevent_averages = [np.mean(snippet) for snippet in preevent_snippets]
            unit_averages[unit] = [event_averages, preevent_averages]
        wilcoxon_stats = {}
        for unit in unit_averages.keys(): 
            result = wilcoxon(unit_averages[unit][0], unit_averages[unit][1], method = 'approx')
            wilcoxon_stats[unit] = [result.statistic, result.pvalue]
        wilcoxon_df = pd.DataFrame(wilcoxon_stats, index=['Wilcoxon Stat', 'p value'])
        self.wilcox_baseline = baseline_window
        self.wilcox_xstop = self._equalized_length(equalize)
        self.wilcox_event = event
        self.wilcox_equalize = equalize
        self.wilcox_baseline_df = wilcoxon_df
        return wilcoxon_df

    def fishers_exact_wilcox(self, baseline_window, equalize):
        """
        counts, for every event type, the units with and without a significant
        (p <= 0.05) baseline vs event wilcoxon test

        Only the contingency table is built; no Fisher's exact test is run on it.
        Each event type is run through wilcox_baseline_v_event_stats, so the 
        attributes that method stores end up describing the last event type.

        Args (2 total):
            baseline_window: int, length of baseline firing rate (s)
            equalize: {user_defined, 'max', 'average', None/False}, 
                see get_event_snippets

        Return (1):
            fishers_df: pandas dataframe, rows are event types, 
            columns are 'Significant' and 'Not Significant' unit counts
        """
        sig_units = {}
        for event in self.event_dict.keys():
            wilcox_df = self.wilcox_baseline_v_event_stats(event, baseline_window, equalize) 
            p_values = wilcox_df.loc['p value']
            sig_units[event] = (int((p_values <= 0.05).sum()), int((p_values > .05).sum()))
        fishers_df = pd.DataFrame(list(sig_units.values()), index=list(sig_units.keys()), 
                                  columns=['Significant', 'Not Significant'])
        return fishers_df

    def wilcox_baseline_v_event_plots(self, title, p_value=None, units=None):
        """
        plots event triggered average firing rates for units, using the event, 
        baseline and equalize of the last wilcox_baseline_v_event_stats call
        all events need to be the same length

        Args(3 total, 1 required):
            title: str, title of figure
            p_value: int, default=None, all p values less than will be plotted
            units: lst, default=None, list of unit ids (ints) to be plotted

        Returns:
            none

        Raises:
            AttributeError: if wilcox_baseline_v_event_stats has not been run
        """ 
        if not hasattr(self, 'wilcox_baseline_df'):
            raise AttributeError('run wilcox_baseline_v_event_stats before plotting')
        wilcoxon_df = self.wilcox_baseline_df
        if p_value is not None:
            units_to_plot = [unit for unit in wilcoxon_df.columns.tolist() 
                             if wilcoxon_df[unit].iloc[1] < p_value]
        elif units is None:
            units_to_plot = wilcoxon_df.columns.tolist()
        else:
            units_to_plot = units
        no_plots = len(units_to_plot)
        if no_plots == 0:
            print('No units to plot')
            return
        height_fig = math.ceil(no_plots/3)
        plt.figure(figsize=(20,4*height_fig))
        unit_event_firing_rates = self.get_unit_event_firing_rates(self.wilcox_event, self.wilcox_equalize, 
                                                                   self.wilcox_baseline, 0)
        # wilcox_xstop is kept in ms, the baseline in s; plot in s
        x_stop = self.wilcox_xstop/1000
        for i, unit in enumerate(units_to_plot, start=1):
            mean_arr = np.mean(unit_event_firing_rates[unit], axis=0)
            sem_arr = sem(unit_event_firing_rates[unit], axis=0)
            unit_p_value = wilcoxon_df[unit].iloc[1]
            x = np.linspace(start=-self.wilcox_baseline,stop=x_stop,num=len(mean_arr))
            plt.subplot(height_fig,3,i)
            plt.plot(x, mean_arr, c= 'b')
            plt.axvline(x=0, color='r', linestyle='--')
            plt.fill_between(x, mean_arr-sem_arr, mean_arr+sem_arr, alpha=0.2)
            plt.title(f'Unit {unit} Average (p={unit_p_value})')
        plt.suptitle(title)
        plt.show()

    def wilcoxon_event_v_event_stats(self, event1, event2, equalize):
        """
        calculates wilcoxon signed-rank test for average firing rates of two event types
        wilcoxon signed-rank test is applied to two sets of measurements:
        average firing rate per event1 event, average firing rate per event2 event
        the test is paired, so it is left to scipy to reject event types 
        with different numbers of events
        
        Args (3 total, 3 required):
            event1: str, name of the first event type in self.event_dict
            event2: str, name of the second event type in self.event_dict
            equalize: {user_defined, 'max', 'average', None/False}, 
                see get_event_snippets
    
        Return (1):
            wilcoxon_df: pandas dataframe, columns are unit ids, 
            row[0] are wilcoxon statistics and row[1] are p values;
            also assigned to self.wilcoxon_df
        
        """
        unit_event1_firing_rates = self.get_unit_event_firing_rates(event1, equalize, 0, 0)
        unit_event2_firing_rates = self.get_unit_event_firing_rates(event2, equalize, 0, 0)
        unit_averages = {}
        for unit in unit_event1_firing_rates.keys():
            event1_snippets = unit_event1_firing_rates[unit]
            event2_snippets = unit_event2_firing_rates[unit]
            if any(len(snippet) == 0 for snippet in list(event1_snippets) + list(event2_snippets)):
                print(f'Unit {unit} skipped: an event window holds no data')
                continue
            event1_averages = [np.mean(snippet) for snippet in event1_snippets]
            event2_averages = [np.mean(snippet) for snippet in event2_snippets]
            unit_averages[unit] = [event1_averages, event2_averages]
        wilcoxon_stats = {}
        for unit in unit_averages.keys(): 
            result = wilcoxon(unit_averages[unit][0], unit_averages[unit][1], method = 'approx')
            wilcoxon_stats[unit] = [result.statistic, result.pvalue]
        wilcoxon_df = pd.DataFrame(wilcoxon_stats, index=[0, 1])
        self.wilcoxon_df = wilcoxon_df
        return wilcoxon_df

    def get_zscore(self, event, baseline_window, equalize):
        #nancy had a matrix of (neuron, timebin, trial)
        """
        z-scores the average event firing rate of each unit against the 
        average firing rate in the baseline_window (s) before the events

        Per unit: the event snippets (including baseline_window of pre-event 
        time) and the baseline snippets are each averaged across events; the 
        mean and standard deviation of the averaged baseline are then used to 
        z-score the averaged event. A baseline with zero standard deviation 
        gives inf/nan values.

        Assigns self.zscored_events (dict, unit ids to lists of z-scores), 
        self.zscore_baseline (s), self.zscore_event (event name) and 
        self.zscore_xstop (ms).

        Args (3 total):
            event: str, name of the event type in self.event_dict
            baseline_window: int, length of the baseline (s)
            equalize: {user_defined, 'max', 'average'}, see get_event_snippets

        Returns:
            none
        """
        event_windows = np.asarray(self.event_dict[event])
        preevent_baselines = np.array([pre_event_window(window, baseline_window) for window in event_windows])
        unit_event_firing_rates = self.get_unit_event_firing_rates(event, equalize, baseline_window, 0)
        unit_preevent_firing_rates = self.get_unit_event_firing_rates(preevent_baselines, baseline_window, 0, 0)
        zscored_events = {}
        for unit in unit_event_firing_rates:
            #calculate average event across all events per unit
            event_average = np.mean(unit_event_firing_rates[unit], axis = 0)
            #one average for all preevents 
            preevent_average = np.mean(unit_preevent_firing_rates[unit], axis = 0)
            mew = np.mean(preevent_average)
            sigma = np.std(preevent_average)
            zscored_events[unit] = list((event_average - mew)/sigma)
        self.zscored_events = zscored_events
        self.zscore_baseline = baseline_window
        self.zscore_event = event
        self.zscore_xstop = self._equalized_length(equalize)
        
    def get_zcore_plot(self, max_event, title):
        """
        plots the z-scored events of the last get_zscore call: 
        population mean with SEM on the left, one line per unit on the right

        Args (2 total):
            max_event: not used
            title: str, prefix of the figure title

        Returns:
            none
        """
        plt.figure(figsize=(20,6))
        baseline_window = self.zscore_baseline
        zscored_unit_event_firing_rates = self.zscored_events
        zscore_pop = np.array(list(zscored_unit_event_firing_rates.values()))
        mean_arr = np.mean(zscore_pop, axis=0)
        sem_arr = sem(zscore_pop, axis=0)
        # zscore_xstop is kept in ms, the baseline in s; plot in s
        x = np.linspace(start=-baseline_window,stop=self.zscore_xstop/1000,num=len(mean_arr))
        plt.subplot(1,2,1)
        plt.plot(x, mean_arr, c= 'b')
        plt.axvline(x=0, color='r', linestyle='--')
        plt.fill_between(x, mean_arr-sem_arr, mean_arr+sem_arr, alpha=0.2)
        plt.title(f'Population z-score {self.zscore_event} event')
        plt.subplot(1,2,2)
        for unit in zscored_unit_event_firing_rates.keys():
            plt.plot(x, zscored_unit_event_firing_rates[unit], linewidth = .5)
        plt.axvline(x=0, color='r', linestyle='--')
        plt.title(f'Unit z-score {self.zscore_event} event')
        plt.suptitle(f'{title} Z-scored average {self.zscore_event} event')
        plt.show()        

    def PCA_trajectories(self, pre_window = 0, post_window = 0, equalize = 'average'):
        """
        2 component PCA of the average event firing rates of all event types

        For each event type a (timebins x units) matrix of average firing 
        rates is built; the matrices of all event types are stacked along the 
        timebin axis and reduced to 2 components. 
        NOTE(review): timebins as samples and units as features is the 
        assumed intent - confirm.

        Assigns self.pca_trajectories (transformed matrix) and self.PCA_key 
        (lst with the event type of every row of the matrix).

        Args (3 total):
            pre_window: int, default=0, seconds prior to start of event included
            post_window: int, default=0, seconds after end of event included
            equalize: {user_defined, 'max', 'average'}, default='average', 
                see get_event_snippets

        Return (1):
            transformed_matrix: numpy array, (total timebins x 2)
        """
        event_matrices = []
        PCA_key = []
        for event in self.event_dict.keys():
            unit_event_firing_rates = self.get_unit_event_firing_rates(event, equalize, pre_window, post_window)
            unit_event_average = get_unit_average_events(unit_event_firing_rates) 
            event_matrix = np.array(list(unit_event_average.values())).T
            event_matrices.append(event_matrix)
            PCA_key = PCA_key + [event] * event_matrix.shape[0]
        PCA_matrix = np.concatenate(event_matrices, axis=0)
        pca = PCA(n_components = 2)
        transformed_matrix = pca.fit_transform(PCA_matrix)
        # stored under a name that does not shadow this method
        self.pca_trajectories = transformed_matrix
        self.PCA_key = PCA_key
        return transformed_matrix

class EphysRecordingCollection:
    """
    A collection of EphysRecording objects, one per directory whose name 
    ends with 'merged.rec' found anywhere below path

    Attributes:
        path: str, directory that is searched for recordings
        sampling_rate: int, sampling rate (Hz) passed to every EphysRecording
        collection: dict, keys are the names of the '*merged.rec' 
            directories, values are EphysRecording objects
        wilcox_dfs: dict, initialised empty
        zscored_events: dict, initialised empty
    """
    #this should have an event dictionary baked into it as well as subject per recording
    #and an initiated
    def __init__(self, path, sampling_rate=20000):
        """
        stores the parameters and builds the collection

        Args (2 total, 1 required):
            path: str, directory that is searched for recordings
            sampling_rate: int, default=20000, sampling rate in Hz
        """
        self.sampling_rate = sampling_rate
        self.path = path 
        self.wilcox_dfs = {}
        self.zscored_events = {}
        self.make_collection()

    def make_collection(self):
        """
        walks self.path and creates an EphysRecording from the 'phy' folder 
        inside every directory whose name ends with 'merged.rec'
        assigns the resulting dict to self.collection
        """
        collection = {}
        for root, dirs, files in os.walk(self.path):
            for directory in dirs:
                if directory.endswith('merged.rec'):
                    # join on root, the directory actually holding the
                    # recording, so nested recordings resolve correctly
                    tempobject = EphysRecording(os.path.join(root, directory, 'phy'), self.sampling_rate)
                    collection[directory] = tempobject
        self.collection = collection


    def get_by_name(self, name):
        """
        returns the recording stored under name (a '*merged.rec' directory name)
        """
        return self.collection[name] 

    def assign_events(self, events_filename=None):
        """
        reads a tab separated events file for every recording of the 
        collection and assigns it as that recording's event_dict 
        ({row['event']: row['start&stop']} of the file)

        NOTE(review): the events file is assumed to sit directly inside each 
        '*merged.rec' directory, and its name is not known from this code - 
        confirm both. Values are stored as read from the file (strings).

        Args (1):
            events_filename: str, name of the events file

        Raises:
            ValueError: if events_filename is not given
        """
        if events_filename is None:
            raise ValueError('events_filename is needed to locate the events file of each recording')
        for root, dirs, files in os.walk(self.path):
            for directory in dirs:
                if directory.endswith('merged.rec') and directory in self.collection:
                    with open(os.path.join(root, directory, events_filename), 'r') as f:
                        reader = csv.DictReader(f, delimiter='\t')
                        self.collection[directory].event_dict = {row['event']: row['start&stop'] for row in reader}

    

class MultiEvent_MultiSession:
    """
    A class for ephys statistics done on multiple event types for multiple recordings

    Each recording in the collection is expected to carry its own event
    dictionary (recording.event_dict) where keys are event type names and
    values are arrays [[start (ms), stop(ms)]..]

    Attributes:
        collection: EphysRecordingCollection, its .collection dict maps
            recording names (str) to recording objects
        event_dict: dict, events pooled over every recording in the collection
            key: str, name of the event 
            value: list of [start (ms), stop (ms)] x n events
        event_types: lst of strs, from the keys of the event dict
        events: list of all start and stop times from event types 
        smoothing_window: int, default=250, window length in ms used to calculate firing rates
        timebin: int, default=1, bin size (in ms) for spike train and firing rate arrays
        ignore_freq: float, default=0.01, frequency in Hz that a good unit needs to fire at
            to be included in analysis
        longest_event: length of longest pooled event (ms), None when there are no events
        event_lengths: lst, length of all pooled events (ms)
        mean_event_length: average pooled event length (ms), None when there are no events

    Attributes assigned to each recording:
        spiketrain: numpy array, number of spikes per timebin for the whole recording
        unit_spiketrains: dict, keys are unit ids, values (numpy arrays) are each "good"
            units spiketrains in the specified timebins for the whole recording
        unit_firing_rates: dict, keys are unit ids, values (numpy array) are each "good"
            units firing rates calculated using smoothing_window in bins of size timebin
        wilcox_dfs: dict of wilcoxon dataframes, wilcox_xstop: equalized event length (ms)
        zscored_events: dict of z-scored events, zscore_xstop: equalized event length (ms)

    Methods: 
        get_whole_spiketrain: 
        get_unit_spiketrains: 
        get_unit_firing_rates: 
        get_event_snippets:
        get_unit_event_firing_rates:
        wilcox_baseline_v_event_stats:
        wilcox_baseline_v_event_collection:
        wilcox_baseline_v_event_plots:
        wilcoxon_event_v_event_stats:
        wilcox_event_v_event_collection:
        get_zscore:
        get_zscore_collection:
        get_zcore_plot:
        PCA_trajectories:
    """
    def __init__(self, collection, smoothing_window=250, timebin=1, ignore_freq=0.01):
        """
        Args (4 total, 1 required):
            collection: EphysRecordingCollection, each recording within the collection
                should have its own event dictionary as an attribute
            smoothing_window: int, default=250, smoothing average window (ms)
            timebin: int, default=1, timebin (ms) of spiketrains and firing rates
            ignore_freq: float, default=0.01, minimum frequency (Hz) of an included unit
        """
        self.collection = collection
        self.smoothing_window = smoothing_window
        self.timebin = timebin
        self.ignore_freq = ignore_freq
        # there is no single event dictionary for a collection (the original read
        # an undefined name here), so pool the per-recording dictionaries instead
        event_dict = {}
        for recording in self._recordings().values():
            recording_events = getattr(recording, 'event_dict', None) or {}
            for event_type, event_times in recording_events.items():
                event_dict.setdefault(event_type, []).extend(list(event_times))
        self.event_dict = event_dict
        self.event_types = list(event_dict.keys())
        self.events = [value for sublist in event_dict.values() for value in sublist]
        if self.events:
            self.longest_event, self.event_lengths, self.mean_event_length = get_event_lengths(self.events)
        else:
            # recordings without events yet: 'max'/'average' equalizing is unavailable
            self.longest_event, self.event_lengths, self.mean_event_length = None, [], None
        self.get_whole_spiketrain()
        self.get_unit_spiketrains()
        self.get_unit_firing_rates()

    def _recordings(self):
        """returns the dict {recording name: recording} held by the collection object"""
        return self.collection.collection

    def _equalized_length_ms(self, equalize):
        """translates the equalize option ('max', 'average' or seconds) into a length in ms"""
        if isinstance(equalize, str):
            if equalize == 'max':
                target_length = self.longest_event
            elif equalize == 'average':
                target_length = self.mean_event_length
            else:
                raise ValueError(f"equalize must be 'max', 'average' or a number of seconds, got {equalize!r}")
            if target_length is None:
                raise ValueError(f"equalize={equalize!r} needs events, but no recording has an event_dict")
            return target_length
        return equalize*1000

    @staticmethod
    def _resolve_events(recording, event):
        """returns the [[start (ms), stop (ms)]..] array for an event name or an explicit array"""
        if isinstance(event, str):
            return np.asarray(recording.event_dict[event])
        return np.asarray(event)

    @staticmethod
    def _store_on(owner, attribute, key, value):
        """stores value under owner.<attribute>[key], creating the dict if it is missing"""
        if not hasattr(owner, attribute):
            setattr(owner, attribute, {})
        getattr(owner, attribute)[key] = value

    @staticmethod
    def _tag_recording_df(recording_df, recording_name, recording, event_label):
        """turns a unit-indexed dataframe into rows labeled with unit id, recording, subject and event"""
        tagged_df = recording_df.reset_index().rename(columns={'index': 'original unit id'})
        tagged_df['Recording'] = recording_name
        # subject is not guaranteed to be assigned on every recording yet
        tagged_df['Subject'] = getattr(recording, 'subject', None)
        # a scalar label broadcasts to every row; a list would need one entry per row
        tagged_df['Event'] = event_label
        return tagged_df

    @staticmethod
    def _concat_recording_dfs(recording_dfs):
        """stacks per-recording dataframes once; returns an empty dataframe for no recordings"""
        if not recording_dfs:
            return pd.DataFrame()
        return pd.concat(recording_dfs, axis=0).reset_index(drop=True)

    def _wilcoxon_df(self, recording, unit_rates_a, unit_rates_b):
        """runs a wilcoxon signed-rank test per unit on per-event average firing rates"""
        unit_averages = {}
        for unit in unit_rates_a.keys():
            try:
                averages_a = [mean(snippet) for snippet in unit_rates_a[unit]]
                averages_b = [mean(snippet) for snippet in unit_rates_b[unit]]
                unit_averages[unit] = [averages_a, averages_b]
            except Exception:
                # e.g. an empty snippet for a unit that stops firing before the event
                print(f'Unit {unit} has {len(recording.unit_timestamps[unit])} spikes')
        wilcoxon_stats = {}
        for unit, (averages_a, averages_b) in unit_averages.items():
            result = wilcoxon(averages_a, averages_b, method = 'approx')
            wilcoxon_stats[unit] = [result.statistic, result.pvalue]
        return pd.DataFrame(wilcoxon_stats, index=['Wilcoxon Stat', 'p value'])

    def get_whole_spiketrain(self):
        """
        creates a spiketrain with timebin length timebins 
        for each recording in the collection
        each array element is the number of spikes per timebin

        each spiketrain is assigned as an attribute for that recording
        
        Args:
            None

        Returns:
            None
         
        """
        for recording in self._recordings().values():
            # keywords: get_spiketrain takes (timestamp_array, timebin, sampling_rate)
            recording.spiketrain = get_spiketrain(recording.timestamps_var,
                                                  timebin=self.timebin,
                                                  sampling_rate=recording.sampling_rate)
    
    def get_unit_spiketrains(self):  
        """
        Creates a dictionary and assigns it as recording.unit_spiketrains
        for each recording in the collection
        where keys are 'good' unit ids (int) (not 'mua') that reach
        a threshold frequency, values are numpy arrays of 
        spiketrains in timebin sized bins
        
        Args:
            None
            
        Returns:
            None
            
        """
        sampling_rate = self.collection.sampling_rate
        for recording in self._recordings().values():
            unit_spiketrains = {}
            for unit, unit_timestamps in recording.unit_timestamps.items():
                if recording.labels_dict[str(unit)] != 'good':
                    continue
                # last timestamp / sampling rate = recording length in seconds
                unit_freq = len(unit_timestamps)/recording.timestamps_var[-1]*sampling_rate
                if unit_freq > self.ignore_freq:
                    unit_spiketrains[unit] = get_spiketrain(unit_timestamps,
                                                            timebin=self.timebin,
                                                            sampling_rate=sampling_rate)
            recording.unit_spiketrains = unit_spiketrains    
    
    def get_unit_firing_rates(self):  
        """
        Calculates firing rates per unit per recording in collection,
        creates a dictionary and assigns it as recording.unit_firing_rates
        the keys are unit ids (int) and values are firing rates for the
        unit (numpy array) in timebin sized bins 
        calculated using smoothing_window for averaging
        
        Args:
            none
            
        Returns:
            none
        """
        for recording in self._recordings().values():
            recording.unit_firing_rates = {
                unit: get_firing_rate(unit_spiketrain, self.smoothing_window, self.timebin)
                for unit, unit_spiketrain in recording.unit_spiketrains.items()
            }
    
    def get_event_snippets(self, recording, event, whole_recording, equalize, pre_window=0, post_window=0):
        """
        takes snippets of spiketrains or firing rates for events
        optional pre-event and post-event windows (s) may be included
        every snippet spans pre_window + equalized event length + post_window,
        so all snippets have the same number of bins unless a window runs
        past the start or the end of whole_recording
    
        Args (6 total, 4 required): 
            recording: recording object with an event_dict attribute
            event: str, key of recording.event_dict, or an explicit
                array of [[start (ms), stop (ms)] x n events]
            whole_recording: numpy array, spiketrain or firing rates 
                for the whole recording, for population or for a single unit
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events
                by padding with post event time or trimming event
                user_defined: float, makes all events user_defined (s) long   
                'max': makes all events as long as the longest event 
                'average': makes all events as long as the average event length 
            pre_window: int, default=0, seconds prior to start of event returned
            post_window: int, default=0, seconds after end of event returned
    
        Returns (1):
            event_snippets: a list of arrays, where each array holds firing rates
                or spiketrains during an event including pre_window&post_windows, 
                accounting for equalize and timebins
        """
        events = self._resolve_events(recording, event)
        target_length = self._equalized_length_ms(equalize)
        pre_window = math.ceil(pre_window*1000)
        post_window = math.ceil(post_window*1000)
        # one bin count for every event keeps snippets stackable for averaging
        n_bins = math.ceil((pre_window + target_length + post_window)/self.timebin)
        event_snippets = []
        for event_times in events:
            start_bin = math.ceil((event_times[0] - pre_window)/self.timebin)
            stop_bin = start_bin + n_bins
            # clamp at 0: a negative index would silently wrap around to the end
            event_snippets.append(whole_recording[max(start_bin, 0):max(stop_bin, 0)])
        return event_snippets
    
    def get_unit_event_firing_rates(self, recording, event, equalize, pre_window = 0, post_window = 0):
        """
        returns firing rates for events per unit
    
        Args (5 total, 3 required):
            recording: recording object with unit_firing_rates and event_dict attributes
            event: str, key of recording.event_dict, or an explicit
                array of [[start (ms), stop (ms)] x n events]
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events
                by padding with post event time or trimming event
                user_defined: float, makes all events user_defined (s) long   
                'max': makes all events as long as the longest event 
                'average': makes all events as long as the average event length  
            pre_window: int, default=0, seconds prior to start of event returned
            post_window: int, default=0, seconds after end of event returned
            
        Return (1):
            unit_event_firing_rates: dict, keys are unit ids,
            values are lsts of numpy arrays of firing rates per event
        """
        return {
            unit: self.get_event_snippets(recording, event, unit_firing_rate, equalize, pre_window, post_window)
            for unit, unit_firing_rate in recording.unit_firing_rates.items()
        }
    
    def wilcox_baseline_v_event_stats(self, recording, event, baseline_window, equalize):
        """
        calculates wilcoxon signed-rank test for average firing rates of two windows: event vs baseline
        baseline used is an amount of time immediately prior to the event
        wilcoxon signed-rank test is applied to two sets of measurements:
        average firing rate per event, average firing rate per baseline.
        the resulting dataframe of wilcoxon stats and p values for every unit 
        is added to a dictionary of dataframes for that recording. 

        Key for this dictionary item is '{event} vs {baselinewindow}second baseline' 
        and the value is the dataframe. The equalized event length (ms) is
        stored as recording.wilcox_xstop.
        
        Args (4 total):
            recording: recording object
            event: str, event type (key of recording.event_dict)
            baseline_window: int, length of baseline firing rate (s)
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events
                by padding with post event time or trimming event
                user_defined: float, makes all events user_defined (s) long   
                'max': makes all events as long as the longest event 
                'average': makes all events as long as the average event length
    
        Return (1):
            wilcoxon_df: pandas dataframe, columns are unit ids, 
            row[0] are wilcoxon statistics and row[1] are p values 
        
        """
        preevent_baselines = np.array([pre_event_window(event_times, baseline_window)
                                       for event_times in recording.event_dict[event]])
        # baselines are equalized to baseline_window itself, i.e. left as they are
        unit_preevent_firing_rates = self.get_unit_event_firing_rates(recording, preevent_baselines, baseline_window, 0, 0)
        unit_event_firing_rates = self.get_unit_event_firing_rates(recording, event, equalize, 0, 0)
        recording.wilcox_xstop = self._equalized_length_ms(equalize)
        wilcoxon_df = self._wilcoxon_df(recording, unit_event_firing_rates, unit_preevent_firing_rates)
        wilcox_key = f'{event} vs {baseline_window}second baseline'
        self._store_on(recording, 'wilcox_dfs', wilcox_key, wilcoxon_df)
        return wilcoxon_df

    # def fishers_exact_wilcox(self, baseline_window, equalize):
    #     sig_units = {}
    #     for event in self.event_dict.keys():
    #         wilcox_df = self.wilcox_baseline_v_event_stats(event, baseline_window, equalize) 
    #         sig_units[event] = (len(wilcox_df[(wilcox_df[1]<=0.05)]), len(wilcox_df[(wilcox_df[1]>.05)])) 
    #     fishers_df = pd.DataFrame(sig_units.values(), index=sig_units.keys(), columns=['Significant', 'Not Significant'])

    def wilcox_baseline_v_event_collection(self, event, baseline_window, equalize):  
        """
        Runs a wilcoxon signed rank test on all good units of 
        all recordings in the collection on the 
        given event's firing rate versus the given baseline window.
        Baseline window is the amount of time immediately prior to the event
        whose firing rate is being compared. 

        Creates a dataframe with rows for each unit and columns representing 
        Wilcoxon stats, p values, original unit ids, recording,
        subject and the event + baseline given. Dataframe is saved in the collections
        wilcox_dfs dictionary, key is '{event} vs {baseline_window}second baseline'

        Args(3 total):
            event: str, event firing rates for stats to be run on 
            baseline_window: float, time (s) prior to event for stats to be run on
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events
                by padding with post event time or trimming event
                user_defined: float, makes all events user_defined (s) long   
                'max': makes all events as long as the longest event 
                'average': makes all events as long as the average event length

        Returns (1):
            master_df: pandas dataframe, one row per unit over all recordings
        """
        wilcox_key = f'{event} vs {baseline_window}second baseline'
        recording_dfs = []
        for recording_name, recording in self._recordings().items():
            recording_df = self.wilcox_baseline_v_event_stats(recording, event, baseline_window, equalize)
            recording_dfs.append(self._tag_recording_df(recording_df.transpose(), recording_name, recording, wilcox_key))
        master_df = self._concat_recording_dfs(recording_dfs)
        self._store_on(self.collection, 'wilcox_dfs', wilcox_key, master_df)
        return master_df

    def wilcox_baseline_v_event_plots(self, recording, wilcoxon_df, event, equalize, baseline, title, p_value=None, units=None):
        """
        plots event triggered average firing rates for units of a given recording. 
        optional filtering for p value threshold and unit ids;
        when p_value is given it takes precedence over units.

        Args(8 total, 6 required):
            recording: an EphysRecording instance
            wilcoxon_df: pandas dataframe, columns are unit ids,
                row[0] are wilcoxon statistics and row[1] are p values
            event: str, event type to be plotted
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events
            baseline: int, seconds prior to the event included in the plot
            title: str, title of figure
            p_value: float, default=None, all p values less than will be plotted
            units: lst, default=None, list of unit ids (ints) to be plotted

        Returns:
            none
        """ 
        if p_value is not None:
            units_to_plot = [unit for unit in wilcoxon_df.columns.tolist()
                             if wilcoxon_df[unit].iloc[1] < p_value]
        elif units is None:
            units_to_plot = wilcoxon_df.columns.tolist()
        else:
            units_to_plot = units
        if len(units_to_plot) == 0:
            # a zero-height figure cannot be drawn
            print('No units to plot')
            return
        height_fig = math.ceil(len(units_to_plot)/3)
        plt.figure(figsize=(20,4*height_fig))
        unit_event_firing_rates = self.get_unit_event_firing_rates(
            recording,
            event,
            equalize,
            baseline,
            0
            )
        # the x axis is in seconds: baseline is given in s, the equalized length in ms
        x_stop = self._equalized_length_ms(equalize)/1000
        for plot_no, unit in enumerate(units_to_plot, start=1):
            mean_arr = np.mean(unit_event_firing_rates[unit], axis=0)
            sem_arr = sem(unit_event_firing_rates[unit], axis=0)
            unit_p_value = wilcoxon_df[unit].iloc[1]
            x = np.linspace(start=-baseline,stop=x_stop,num=len(mean_arr))
            plt.subplot(height_fig,3,plot_no)
            plt.plot(x, mean_arr, c= 'b')
            plt.axvline(x=0, color='r', linestyle='--')
            plt.fill_between(x, mean_arr-sem_arr, mean_arr+sem_arr, alpha=0.2)
            plt.title(f'Unit {unit} Average (p={unit_p_value})')
        plt.suptitle(title)
        plt.show()

    def wilcoxon_event_v_event_stats(self, recording, event1, event2, equalize):
        """
        calculates wilcoxon signed-rank test for average firing rates between
        two events for a given recording. the resulting dataframe of wilcoxon stats
        and p values for every unit is added to a dictionary of dataframes for that
        recording. 

        Key for this dictionary item is '{event1} vs {event2}' 
        and the value is the dataframe. 

        NOTE(review): the signed-rank test pairs the i-th event1 with the i-th
        event2, so both event types need the same number of events; confirm
        that this pairing is meaningful for the data.
        
        Args (4 total):
            recording: EphysRecording instance
            event1: str, first event type firing rates for stats to be run on
            event2: str, second event type firing rates for stats to be run on
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events
                by padding with post event time or trimming event
                user_defined: float, makes all events user_defined (s) long   
                'max': makes all events as long as the longest event 
                'average': makes all events as long as the average event length
    
        Return (1):
            wilcoxon_df: pandas dataframe, columns are unit ids, 
            row[0] are wilcoxon statistics and row[1] are p values 
        
        """
        unit_event1_firing_rates = self.get_unit_event_firing_rates(recording, event1, equalize, 0, 0)
        unit_event2_firing_rates = self.get_unit_event_firing_rates(recording, event2, equalize, 0, 0)
        wilcoxon_df = self._wilcoxon_df(recording, unit_event1_firing_rates, unit_event2_firing_rates)
        wilcox_key = f'{event1} vs {event2}'
        self._store_on(recording, 'wilcox_dfs', wilcox_key, wilcoxon_df)
        return wilcoxon_df

    def wilcox_event_v_event_collection(self, event1, event2, equalize):  
        """ 
        Runs a wilcoxon signed rank test on all good units of 
        all recordings in the collection on the 
        given event's firing rate versus another given event's firing rate.
    
        Creates a dataframe with rows for each unit and columns representing 
        Wilcoxon stats, p values, original unit ids, recording,
        subject and the events given.  Dataframe is saved in the collections
        wilcox_dfs dictionary, key is '{event1} vs {event2}' 

        Args(3 total):
            event1: str, first event type firing rates for stats to be run on
            event2: str, second event type firing rates for stats to be run on
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events
                by padding with post event time or trimming event
                user_defined: float, makes all events user_defined (s) long   
                'max': makes all events as long as the longest event 
                'average': makes all events as long as the average event length

        Returns (1):
            master_df: pandas dataframe, one row per unit over all recordings
        """
        wilcox_key = f'{event1} vs {event2}'
        recording_dfs = []
        for recording_name, recording in self._recordings().items():
            recording_df = self.wilcoxon_event_v_event_stats(recording, event1, event2, equalize)
            recording_dfs.append(self._tag_recording_df(recording_df.transpose(), recording_name, recording, wilcox_key))
        master_df = self._concat_recording_dfs(recording_dfs)
        self._store_on(self.collection, 'wilcox_dfs', wilcox_key, master_df)
        return master_df

    def get_zscore(self, recording, event, baseline_window, equalize):
        """
        z-scores the event averaged firing rate of every unit of a recording
        against the baseline immediately prior to the event. The mean and
        standard deviation are taken over the bins of the event averaged baseline.
        The returned traces start baseline_window seconds before event onset.

        The result is stored as recording.zscored_events[event] and the equalized
        event length (ms) as recording.zscore_xstop.

        Args (4 total):
            recording: recording object
            event: str, event type (key of recording.event_dict)
            baseline_window: int, length of baseline (s)
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events

        Returns (1):
            zscored_events: dict, keys are unit ids, values are lsts of
                z-scored firing rates per timebin
        """
        events = recording.event_dict[event]
        preevent_baselines = np.array([pre_event_window(event_times, baseline_window) for event_times in events])
        unit_event_firing_rates = self.get_unit_event_firing_rates(recording, event, equalize, baseline_window, 0)
        unit_preevent_firing_rates = self.get_unit_event_firing_rates(recording, preevent_baselines, baseline_window, 0, 0)
        zscored_events = {}
        for unit in unit_event_firing_rates:
            #calculate average event across all events per unit
            event_average = np.mean(unit_event_firing_rates[unit], axis = 0)
            #one average for all preevents 
            preevent_average = np.mean(unit_preevent_firing_rates[unit], axis = 0)
            mew = np.mean(preevent_average)
            sigma = np.std(preevent_average)
            # a flat baseline (sigma = 0) yields inf/nan here rather than an error
            zscored_events[unit] = list((event_average - mew)/sigma)
        self._store_on(recording, 'zscored_events', event, zscored_events)
        recording.zscore_xstop = self._equalized_length_ms(equalize)
        return zscored_events
        
    def get_zscore_collection(self, event, baseline_window, equalize):
        """
        z-scores the given event for every recording in the collection and
        stacks the results into one dataframe with a row per unit: one column
        per timebin plus original unit id, recording, subject and event.
        Dataframe is saved in the collections zscored_events dictionary,
        key is '{event} vs {baseline_window}second baseline'

        Args (3 total):
            event: str, event type to z-score
            baseline_window: int, length of baseline (s)
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events

        Returns (1):
            master_df: pandas dataframe, one row per unit over all recordings
        """
        zscore_key = f'{event} vs {baseline_window}second baseline'
        recording_dfs = []
        for recording_name, recording in self._recordings().items():
            zscored_events = self.get_zscore(recording, event, baseline_window, equalize)
            zscored_events_df = pd.DataFrame.from_dict(zscored_events, orient='index')
            recording_dfs.append(self._tag_recording_df(zscored_events_df, recording_name, recording, zscore_key))
        master_df = self._concat_recording_dfs(recording_dfs)
        self._store_on(self.collection, 'zscored_events', zscore_key, master_df)
        return master_df
        
    def get_zcore_plot(self, recording, event, zscore_xstop, baseline_window, max_event, title):
        """
        plots the population average z-score (mean +/- sem over units) and the
        z-score of every unit for an event that was z-scored with get_zscore

        Args (6 total):
            recording: recording object with zscored_events[event] assigned
            event: str, event type to plot
            zscore_xstop: x value of the last timebin; use the same unit as
                baseline_window (recording.zscore_xstop is stored in ms)
            baseline_window: int, baseline before event onset, x axis starts at -baseline_window
            max_event: not used by this method
            title: str, prefix of the figure title

        Returns:
            none
        """
        plt.figure(figsize=(20,6))
        zscored_unit_event_firing_rates = recording.zscored_events[event]
        zscore_pop = np.array(list(zscored_unit_event_firing_rates.values()))
        mean_arr = np.mean(zscore_pop, axis=0)
        sem_arr = sem(zscore_pop, axis=0)
        x = np.linspace(start=-baseline_window,stop=zscore_xstop,num=len(mean_arr))
        plt.subplot(1,2,1)
        plt.plot(x, mean_arr, c= 'b')
        plt.axvline(x=0, color='r', linestyle='--')
        plt.fill_between(x, mean_arr-sem_arr, mean_arr+sem_arr, alpha=0.2)
        plt.title(f'Population z-score {event}')
        plt.subplot(1,2,2)
        for unit_zscore in zscored_unit_event_firing_rates.values():
            plt.plot(x, unit_zscore, linewidth = .5)
        if zscored_unit_event_firing_rates:
            plt.axvline(x=0, color='r', linestyle='--')
            plt.title(f'Unit z-score {event} event')
        plt.suptitle(f'{title} Z-scored average {event} event')
        plt.show()        

    def PCA_trajectories(self, events, equalize, n_components=2, pre_window = 0, post_window = 0):
        """
        runs a PCA on the event averaged firing rates of every unit of every
        recording. Each row of the PCA matrix is one unit's average response to
        one event (timebins are the features). The result is saved as
        self.collection.PCA_df

        NOTE(review): the row layout follows the original keys (one event label
        per matrix row, one recording label per unit); get_unit_average_events is
        defined outside this class and is assumed to return, per unit, either one
        average array or a list of arrays. Confirm both before relying on this.

        Args (5 total, 2 required):
            events: lst of strs, event types to include
            equalize: {user_defined(float), 'max', 'average'}, equalizes lengths of events
            n_components: int, default=2, number of principal components kept
            pre_window: int, default=0, seconds prior to start of event included
            post_window: int, default=0, seconds after end of event included

        Returns (1):
            PCA_df: pandas dataframe, columns PC1..PCn, Recording, Event, original unit id
        """
        row_blocks = []
        recording_key = []
        event_key = []
        unit_key = []
        for recording_name, recording in self._recordings().items():
            for event in events:
                unit_event_firing_rates = self.get_unit_event_firing_rates(recording, event, equalize, pre_window, post_window)
                unit_event_average = get_unit_average_events(unit_event_firing_rates) 
                for unit, unit_average in unit_event_average.items():
                    unit_rows = np.atleast_2d(unit_average)
                    row_blocks.append(unit_rows)
                    # keys are built from the rows actually added so they always line up
                    recording_key += [recording_name] * unit_rows.shape[0]
                    event_key += [event] * unit_rows.shape[0]
                    unit_key += [unit] * unit_rows.shape[0]
        if not row_blocks:
            raise ValueError('PCA_trajectories found no unit firing rates for the given events')
        PCA_matrix = np.concatenate(row_blocks, axis=0)
        pca = PCA(n_components=n_components)
        transformed_matrix = pca.fit_transform(PCA_matrix)
        PCA_df = pd.DataFrame(transformed_matrix,
                              columns=[f'PC{component + 1}' for component in range(transformed_matrix.shape[1])])
        PCA_df['Recording'] = recording_key
        PCA_df['Event'] = event_key
        PCA_df['original unit id'] = unit_key
        self.collection.PCA_df = PCA_df 
        return PCA_df
            




In [9]:
# Builds the pilot collection from the processed-data folder.
# NOTE(review): hard-coded absolute path to one user's machine; move it to a
# configurable data directory so the notebook runs elsewhere.
pilot = EphysRecordingCollection("C://Users//megha//Documents//GitHub//diff_fam_social_memory_ephys//proc")

In [4]:
# Scratch cell: two toy per-recording tables (unit ids as columns, two rows of
# values each) used to prototype the per-recording -> master dataframe steps.
# NOTE(review): pandas is already imported in the first cell.
import pandas as pd

recording_name = 'recording'
recording2_name = 'recording2'
recording = {'1': [1.3, 2.3], '2': [3.4, 5]}
recording2 = {'1': [6, 0], '2': [1, 2]}
recording_df = pd.DataFrame(recording)
recording2_df = pd.DataFrame(recording2)
recording2_df

Unnamed: 0,1,2
0,6,1
1,0,2


In [5]:
# Transpose so each unit becomes a row, move the unit ids out of the index into
# an 'index' column, then tag every row with its recording name.
# NOTE(review): this overwrites its own inputs, so running the cell twice
# transposes the frames again; re-run the previous cell first.
recording2_df = recording2_df.transpose().reset_index()
recording_df = recording_df.transpose().reset_index()
recording_df['Recording'] = recording_name
recording2_df['Recording'] = recording2_name
recording_df

Unnamed: 0,index,0,1,Recording
0,1,1.3,2.3,recording
1,2,3.4,5.0,recording


In [6]:
# Give the former index column a descriptive name; note the keyword is `columns`.
recording2_df = recording2_df.rename(columns={'index': 'original unit id'})
recording_df = recording_df.rename(columns={'index': 'original unit id'})
recording2_df

Unnamed: 0,original unit id,0,1,Recording
0,1,6,0,recording2
1,2,1,2,recording2


In [7]:
# Stack both recordings into one table with a fresh 0..n-1 row index.
master_df = pd.concat([recording_df, recording2_df], axis = 0).reset_index(drop=True)
# not the last expression of the cell, so this bare name displays nothing
master_df

# Alternative layout: keep one dataframe per recording, keyed by recording name.
wilcox_dict = {recording_name: recording_df, recording2_name: recording2_df}

In [12]:
# Checks the '<a> vs <b>' string layout used for the results-dictionary keys.
recording_name + ' vs ' + recording2_name

'recording vs recording2'