In [None]:




class MultiEvents:
    """
    Event-triggered firing-rate analyses for several named event types
    recorded in a single EphysRecording.

    Attributes:
        recording: EphysRecording instance; this class reads its
            timestamps_var, sampling_rate, unit_timestamps and labels_dict
        event_dict: dict, keys are event names, values are event windows
            [[start (ms), stop (ms)], ...]
        events: list, every [start, stop] window of every event type, flattened
        smoothing_window: int, default=250, window length in ms used to calculate firing rates
        timebin: int, default=1, bin size (in ms) for spike train and firing rate arrays
        ignore_freq: float, default=0.01, firing rate in Hz that a good unit must
            exceed to be included in analysis
        longest_event, event_lengths, mean_event_length: output of
            get_event_lengths(events) (ms), pooled across all event types
        spiketrain: output of get_spiketrain for the whole recording
        unit_spiketrains: dict, keys are unit ids, values are each "good"
            unit's spiketrain in timebin sized bins for the whole recording
        unit_firing_rates: dict, keys are unit ids, values are each "good"
            unit's firing rates calculated using smoothing_window in bins of size timebin

    Methods:
        get_whole_spiketrain: population spiketrain of the whole recording
        get_unit_spiketrains: per-unit spiketrains of good, active units
        get_unit_firing_rates: per-unit smoothed firing rates
        get_event_snippets: slices a whole-recording array around events
        get_unit_event_firing_rates: per-unit firing-rate snippets for an event
        wilcox_baseline_v_event_stats: per-unit Wilcoxon test, event vs pre-event baseline
        fishers_exact_wilcox: counts significant / non-significant units per event type
        wilcox_baseline_v_event_plots: plots event-triggered averages per unit
        wilcoxon_average_firingrates: per-unit Wilcoxon test, event1 vs event2
        get_zscore: z-scores event-averaged firing rates against baseline
        get_zcore_plot: plots population and per-unit z-scores
        PCA_trajectories: 2-component PCA of event-averaged firing rates
    """

    def __init__(self, event_dict, recording, smoothing_window=250, timebin=1, ignore_freq=0.01):
        self.recording = recording
        self.event_dict = event_dict
        # every [start, stop] window of every event type, flattened into one list
        self.events = [value for sublist in event_dict.values() for value in sublist]
        self.smoothing_window = smoothing_window
        self.timebin = timebin
        self.ignore_freq = ignore_freq
        # computed over all event types together, so 'max'/'average' equalization
        # gives snippets of every event type the same target length
        self.longest_event, self.event_lengths, self.mean_event_length = get_event_lengths(self.events)
        self.get_whole_spiketrain()
        self.get_unit_spiketrains()
        self.get_unit_firing_rates()

    def get_whole_spiketrain(self):
        """
        Bin every spike of the recording into one population spiketrain.

        Sets self.spiketrain to get_spiketrain(recording.timestamps_var,
        recording.sampling_rate, self.timebin), i.e. spike counts in
        self.timebin (ms) sized bins.

        Args:
            None

        Returns:
            None
        """
        self.spiketrain = get_spiketrain(self.recording.timestamps_var, self.recording.sampling_rate, self.timebin)

    def get_unit_spiketrains(self):
        """
        Build self.unit_spiketrains for 'good' units that fire often enough.

        A unit is kept when recording.labels_dict marks it 'good' (not 'mua')
        and its mean rate over the whole recording exceeds self.ignore_freq (Hz).
        Values are the unit's spiketrain in self.timebin sized bins.

        Args:
            None

        Returns:
            None
        """
        unit_spiketrains = {}
        for unit in self.recording.unit_timestamps.keys():
            if self.recording.labels_dict[str(unit)] == 'good':
                no_spikes = len(self.recording.unit_timestamps[unit])
                # spikes / recording length (samples) * samples per second -> Hz
                # (assumes timestamps_var is in samples; TODO confirm)
                unit_freq = no_spikes/self.recording.timestamps_var[-1]*self.recording.sampling_rate
                if unit_freq > self.ignore_freq:
                    unit_spiketrains[unit] = get_spiketrain(self.recording.unit_timestamps[unit],
                                                            self.recording.sampling_rate, self.timebin)
        self.unit_spiketrains = unit_spiketrains

    def get_unit_firing_rates(self):
        """
        Compute smoothed firing rates for every unit in self.unit_spiketrains.

        Sets self.unit_firing_rates: keys are unit ids, values are the firing
        rates from get_firing_rate using self.smoothing_window and self.timebin.

        Args:
            None

        Returns:
            None
        """
        self.unit_firing_rates = {
            unit: get_firing_rate(spiketrain, self.smoothing_window, self.timebin)
            for unit, spiketrain in self.unit_spiketrains.items()
        }

    def _equalized_length_ms(self, equalize):
        """Return the target event length (ms) implied by an `equalize` option."""
        if equalize == 'max':
            return self.longest_event
        if equalize == 'average':
            return self.mean_event_length
        # user defined length, given in seconds
        return equalize * 1000

    def get_event_snippets(self, event, whole_recording, equalize, pre_window=0, post_window=0):
        """
        Take snippets of a whole-recording array (spiketrain or firing rates)
        around each event window, with optional pre/post windows and optional
        equalization of all snippet lengths.

        Args (5 total, 3 required):
            event: str, a key of self.event_dict, OR an array/list of
                [[start (ms), stop (ms)], ...] windows to use directly
            whole_recording: numpy array, spiketrain or firing rates
                for the whole recording, for population or for a single unit
            equalize: {user_defined, 'max', 'average', None, False}, equalizes
                lengths of events by padding with post event time or trimming event
                user_defined: float, makes all events user_defined (s) long
                'max': makes all events as long as the longest event
                'average': makes all events as long as the average event length
                None/False: no equalization, each event keeps its own length
            pre_window: float, default=0, seconds prior to start of event returned
            post_window: float, default=0, seconds after end of event returned

        Returns (1):
            event_snippets: list of slices of whole_recording, one per event,
                including pre_window & post_window, accounting for equalize and timebins

        Raises:
            KeyError: if `event` is a name not present in self.event_dict
        """
        if isinstance(event, (np.ndarray, list, tuple)):
            events = event
        else:
            events = self.event_dict[event]
        no_equalize = equalize is None or equalize is False
        target_length = None if no_equalize else self._equalized_length_ms(equalize)
        pre_window = math.ceil(pre_window*1000)
        post_window = math.ceil(post_window*1000)
        event_snippets = []
        for start, stop in events:
            # length of THIS window; the shared self.event_lengths list is
            # pooled over all event types and cannot be indexed per event type
            event_diff = 0 if target_length is None else math.ceil(target_length - (stop - start))
            # clamp so a window reaching before t=0 is truncated instead of
            # wrapping around through Python's negative slicing
            pre_event = max(0, math.ceil((start - pre_window)/self.timebin))
            post_event = math.ceil((stop + post_window + event_diff)/self.timebin)
            event_snippets.append(whole_recording[pre_event:post_event])
        return event_snippets

    def get_unit_event_firing_rates(self, event, equalize, pre_window = 0, post_window = 0):
        """
        Return firing-rate snippets around an event for every unit.

        Args (4 total, 2 required):
            event: str key of self.event_dict, or an array of
                [[start (ms), stop (ms)], ...] windows
            equalize: see get_event_snippets ({float (s), 'max', 'average', None, False})
            pre_window: float, default=0, seconds prior to start of event returned
            post_window: float, default=0, seconds after end of event returned

        Return (1):
            unit_event_firing_rates: dict, keys are unit ids (as in
            self.unit_spiketrains), values are lists of firing-rate snippets, one per event
        """
        unit_event_firing_rates = {}
        for unit in self.unit_spiketrains.keys():
            unit_event_firing_rates[unit] = self.get_event_snippets(
                event, self.unit_firing_rates[unit], equalize,
                pre_window=pre_window, post_window=post_window)
        return unit_event_firing_rates

    def _paired_wilcoxon_df(self, unit_rates_a, unit_rates_b):
        """
        Per unit, average each snippet and run a Wilcoxon signed-rank test
        between the paired snippet averages of condition a and condition b.

        Units whose averages cannot be computed, or whose test fails, are
        reported with a print and left out.

        Returns:
            pandas DataFrame, columns are unit ids, rows are
            'Wilcoxon Stat' and 'p value'
        """
        wilcoxon_stats = {}
        for unit in unit_rates_a.keys():
            try:
                averages_a = [mean(snippet) for snippet in unit_rates_a[unit]]
                averages_b = [mean(snippet) for snippet in unit_rates_b[unit]]
            except Exception:
                print(f'Unit {unit} has {len(self.recording.unit_timestamps[unit])} spikes')
                continue
            try:
                result = wilcoxon(averages_a, averages_b, method = 'approx')
            except ValueError as err:
                # e.g. unequal numbers of paired events
                print(f'Unit {unit}: Wilcoxon test failed ({err})')
                continue
            wilcoxon_stats[unit] = (result.statistic, result.pvalue)
        # passing the index up front also works when no unit survived
        return pd.DataFrame(wilcoxon_stats, index=['Wilcoxon Stat', 'p value'])

    def wilcox_baseline_v_event_stats(self, event, baseline_window, equalize):
        """
        Calculate a Wilcoxon signed-rank test per unit comparing the average
        firing rate during each event with the average firing rate in a
        baseline of baseline_window seconds immediately prior to that event.

        Args (3 total, 3 required):
            event: str, key of self.event_dict
            baseline_window: float, length of baseline firing rate (s)
            equalize: {float (s), 'max', 'average'}, see get_event_snippets

        Return (1):
            wilcoxon_df: pandas dataframe, columns are unit ids,
            row 'Wilcoxon Stat' holds statistics and row 'p value' p values

        Side effects:
            stores wilcoxon_df, the event, baseline_window, equalize and the
            event end (ms) on self for wilcox_baseline_v_event_plots
        """
        # TODO: optionally take a random snippet (e.g. from the first ten
        # minutes) as baseline instead of the window right before each event
        preevent_baselines = np.array([pre_event_window(window, baseline_window)
                                       for window in self.event_dict[event]])
        # every baseline snippet is equalized to baseline_window seconds
        unit_preevent_firing_rates = self.get_unit_event_firing_rates(preevent_baselines, baseline_window, 0, 0)
        unit_event_firing_rates = self.get_unit_event_firing_rates(event, equalize, 0, 0)
        wilcoxon_df = self._paired_wilcoxon_df(unit_event_firing_rates, unit_preevent_firing_rates)
        self.wilcox_xstop = self._equalized_length_ms(equalize)
        self.wilcox_baseline = baseline_window
        self.wilcox_event = event
        self.wilcox_equalize = equalize
        self.wilcoxon_df = wilcoxon_df
        return wilcoxon_df

    def fishers_exact_wilcox(self, baseline_window, equalize, alpha=0.05):
        """
        Count, for each event type, the units whose event-vs-baseline Wilcoxon
        p value is <= alpha ('Significant') or > alpha ('Not Significant').

        The returned contingency table is the input for a Fisher's exact
        test; the test itself is not computed here.

        Args (3 total, 2 required):
            baseline_window: float, length of baseline (s)
            equalize: {float (s), 'max', 'average'}, see get_event_snippets
            alpha: float, default=0.05, significance threshold

        Return (1):
            fishers_df: pandas dataframe, index are event names, columns are
            'Significant' and 'Not Significant' unit counts (also stored on self)
        """
        sig_units = {}
        for event in self.event_dict.keys():
            wilcox_df = self.wilcox_baseline_v_event_stats(event, baseline_window, equalize)
            p_values = wilcox_df.loc['p value']
            sig_units[event] = (int((p_values <= alpha).sum()), int((p_values > alpha).sum()))
        fishers_df = pd.DataFrame(list(sig_units.values()), index=list(sig_units.keys()),
                                  columns=['Significant', 'Not Significant'])
        self.fishers_df = fishers_df
        return fishers_df

    def wilcox_baseline_v_event_plots(self, title, p_value=None, units=None):
        """
        Plot event triggered average firing rates (mean +/- SEM) per unit,
        using the results of the last wilcox_baseline_v_event_stats call.
        All events need to be the same length (i.e. equalized).

        Args(3 total, 1 required):
            title: str, title of figure
            p_value: float, default=None, units with p values less than this are plotted
            units: lst, default=None, list of unit ids to be plotted
                (ignored when p_value is given; all units when both are None)

        Returns:
            none
        """
        wilcoxon_df = self.wilcoxon_df
        if p_value is not None:
            units_to_plot = [unit for unit in wilcoxon_df.columns.tolist()
                             if wilcoxon_df[unit].iloc[1] < p_value]
        elif units is None:
            units_to_plot = wilcoxon_df.columns.tolist()
        else:
            units_to_plot = units
        if not units_to_plot:
            print('No units to plot')
            return
        height_fig = math.ceil(len(units_to_plot)/3)
        plt.figure(figsize=(20,4*height_fig))
        # include the baseline before each event so the plot shows both windows
        unit_event_firing_rates = self.get_unit_event_firing_rates(
            self.wilcox_event, self.wilcox_equalize, self.wilcox_baseline, 0)
        # the baseline is in seconds while the event end is in ms
        x_start = -self.wilcox_baseline * 1000
        for i, unit in enumerate(units_to_plot, start=1):
            mean_arr = np.mean(unit_event_firing_rates[unit], axis=0)
            sem_arr = sem(unit_event_firing_rates[unit], axis=0)
            unit_p_value = wilcoxon_df[unit].iloc[1]
            x = np.linspace(start=x_start, stop=self.wilcox_xstop, num=len(mean_arr))
            plt.subplot(height_fig,3,i)
            plt.plot(x, mean_arr, c= 'b')
            plt.axvline(x=0, color='r', linestyle='--')
            plt.fill_between(x, mean_arr-sem_arr, mean_arr+sem_arr, alpha=0.2)
            plt.title(f'Unit {unit} Average (p={unit_p_value})')
        plt.suptitle(title)
        plt.show()

    def wilcoxon_average_firingrates(self, event1, event2, equalize):
        """
        Calculate a Wilcoxon signed-rank test per unit comparing the average
        firing rate per event of event1 with that of event2. Both event types
        must contain the same number of events (the test is paired).

        Args (3 total, 3 required):
            event1: str, key of self.event_dict
            event2: str, key of self.event_dict
            equalize: {float (s), 'max', 'average'}, equalizes lengths of events
                by padding with post event time or trimming event
                'max': makes all events as long as the longest event
                'average': makes all events as long as the average event length

        Return (1):
            wilcoxon_df: pandas dataframe, columns are unit ids,
            row 'Wilcoxon Stat' holds statistics and row 'p value' p values
            (also stored as self.wilcoxon_df)
        """
        unit_event1_firing_rates = self.get_unit_event_firing_rates(event1, equalize, 0, 0)
        unit_event2_firing_rates = self.get_unit_event_firing_rates(event2, equalize, 0, 0)
        wilcoxon_df = self._paired_wilcoxon_df(unit_event1_firing_rates, unit_event2_firing_rates)
        self.wilcoxon_df = wilcoxon_df
        return wilcoxon_df

    def get_zscore(self, event, baseline_window, equalize):
        """
        Z-score each unit's event-averaged firing rate (baseline + event)
        against the mean and std of its average pre-event baseline.

        Args (3 total, 3 required):
            event: str, key of self.event_dict
            baseline_window: float, length of baseline (s)
            equalize: {float (s), 'max', 'average'}, see get_event_snippets

        Returns:
            None; sets self.zscored_events (dict unit id -> list of z-scores),
            self.zscore_baseline, self.zscore_xstop and self.zscore_event
        """
        # NOTE: an alternative layout is a (neuron, timebin, trial) matrix
        event_windows = self.event_dict[event]
        preevent_baselines = np.array([pre_event_window(window, baseline_window) for window in event_windows])
        unit_event_firing_rates = self.get_unit_event_firing_rates(event, equalize, baseline_window, 0)
        unit_preevent_firing_rates = self.get_unit_event_firing_rates(preevent_baselines, baseline_window, 0, 0)
        zscored_events = {}
        for unit in unit_event_firing_rates:
            # average event across all events per unit
            event_average = np.mean(unit_event_firing_rates[unit], axis = 0)
            # one average for all preevents
            preevent_average = np.mean(unit_preevent_firing_rates[unit], axis = 0)
            mew = np.mean(preevent_average)
            # a flat baseline gives sigma == 0 -> inf/nan z-scores (numpy warns)
            sigma = np.std(preevent_average)
            zscored_events[unit] = [(event_bin - mew)/sigma for event_bin in event_average]
        self.zscored_events = zscored_events
        self.zscore_baseline = baseline_window
        self.zscore_event = event
        self.zscore_xstop = self._equalized_length_ms(equalize)

    def get_zcore_plot(self, max_event, title):
        """
        Plot the population mean +/- SEM z-score and every unit's z-score
        trace from the last get_zscore call.

        Args:
            max_event: unused, kept for backward compatibility
            title: str, prefix of the figure title

        Returns:
            None
        """
        event = self.zscore_event
        zscored_unit_event_firing_rates = self.zscored_events
        zscore_pop = np.array(list(zscored_unit_event_firing_rates.values()))
        mean_arr = np.mean(zscore_pop, axis=0)
        sem_arr = sem(zscore_pop, axis=0)
        # the baseline is in seconds while the event end is in ms
        x = np.linspace(start=-self.zscore_baseline * 1000, stop=self.zscore_xstop, num=len(mean_arr))
        plt.figure(figsize=(20,6))
        plt.subplot(1,2,1)
        plt.plot(x, mean_arr, c= 'b')
        plt.axvline(x=0, color='r', linestyle='--')
        plt.fill_between(x, mean_arr-sem_arr, mean_arr+sem_arr, alpha=0.2)
        plt.title(f'Population z-score {event} event')
        plt.subplot(1,2,2)
        for unit in zscored_unit_event_firing_rates.keys():
            plt.plot(x, zscored_unit_event_firing_rates[unit], linewidth = .5)
        plt.axvline(x=0, color='r', linestyle='--')
        plt.title(f'Unit z-score {event} event')
        plt.suptitle(f'{title} Z-scored average {event} event')
        plt.show()

    def PCA_trajectories(self, pre_window = 0, post_window = 0, equalize = 'average'):
        """
        Run a 2-component PCA on the per-unit event averages of every event type.

        Rows of the PCA input are the flattened values of
        get_unit_average_events(...) for each event type, in event_dict order.

        Args (3 total, 0 required):
            pre_window: float, default=0, seconds prior to each event
            post_window: float, default=0, seconds after each event
            equalize: default='average', see get_event_snippets

        Returns:
            transformed matrix from PCA.fit_transform; also stored as
            self.PCA_transformed, with self.PCA_key holding each row's event name
        """
        PCA_matrix = []
        PCA_key = []
        for event in self.event_dict.keys():
            unit_event_firing_rates = self.get_unit_event_firing_rates(event, equalize, pre_window, post_window)
            unit_event_average = get_unit_average_events(unit_event_firing_rates)
            event_rows = [value for sublist in unit_event_average.values() for value in sublist]
            PCA_matrix.extend(event_rows)
            PCA_key.extend([event] * len(event_rows))
        pca = PCA(n_components = 2)
        transformed_matrix = pca.fit_transform(PCA_matrix)
        # not stored as self.PCA_trajectories: that would shadow this method
        # on the instance and make any later call fail
        self.PCA_transformed = transformed_matrix
        self.PCA_key = PCA_key
        return transformed_matrix

        