# PCA & Clustering

The goal of this code is to take activity time-series data from a neural recording and cluster the cells/ROIs (samples) based on their neural activity (features). Clustering is performed on trial-averaged event-related responses: data from different trial conditions are concatenated, fed into dimensionality reduction (PCA), and finally into multiple clustering algorithms. The number of PCs to retain is chosen from the scree plot or an explained-variance threshold, and the clustering hyperparameters (clustering method, number of clusters, and number of nearest neighbors for spectral clustering) are determined automatically as those giving the best silhouette score.
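The notebook scores clustering results with scikit-learn's `silhouette_score(..., metric='cosine')`. As a rough illustration of what that number measures, here is a from-scratch NumPy sketch of the cosine silhouette coefficient. The function names are hypothetical and this is not the library's implementation.

```python
import numpy as np

def cosine_distances(X):
    """Pairwise cosine distance (1 - cosine similarity) between the rows of X."""
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    return 1.0 - Xn @ Xn.T

def mean_silhouette_cosine(X, labels):
    """Mean silhouette coefficient of a labeling, using cosine distance.

    For each sample i: a = mean distance to the other members of its own cluster,
    b = lowest mean distance to the members of any other cluster,
    s_i = (b - a) / max(a, b). Samples in singleton clusters get s_i = 0.
    """
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    D = cosine_distances(X)
    scores = np.zeros(len(X))
    for i in range(len(X)):
        own = labels == labels[i]
        n_own = own.sum()
        if n_own == 1:
            continue  # singleton cluster: silhouette is defined as 0
        a = D[i, own].sum() / (n_own - 1)  # D[i, i] is ~0, so divide by the other members only
        b = min(D[i, labels == k].mean() for k in np.unique(labels) if k != labels[i])
        denom = max(a, b)
        scores[i] = 0.0 if denom == 0 else (b - a) / denom
    return scores.mean()

# toy example: two groups of ROIs with similar response "directions"
X = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
good = mean_silhouette_cosine(X, [0, 0, 1, 1])
bad = mean_silhouette_cosine(X, [0, 1, 0, 1])
```

A labeling that groups similar response profiles scores close to 1, while a mixed labeling goes negative; the hyperparameter sweep in this notebook keeps the setting with the highest mean score.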

1) PCA to reduce the dimensionality of the trial-averaged event-related responses (ROIs x time) along the time dimension. Intuitive concept: PCA is performed on the time dimension, i.e., each time point is treated as a feature/variable. Each ROI's response is therefore a single data point in n-dimensional space, where n is the number of samples in the event-related window (times the number of concatenated conditions). PCA finds a new set of orthogonal axes that successively maximize the variance in the activity across ROIs. These new axes, the principal components (PCs), are linear combinations of the original time points.
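To make the "time points as features" idea concrete, here is a NumPy sketch of PCA on a toy ROI x time matrix via the singular value decomposition. The toy data (two temporal templates plus noise) are invented for illustration; the notebook itself uses scikit-learn's `PCA`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_rois, n_time = 50, 40  # samples (ROIs) x features (time points)

# toy data: each ROI mixes two temporal templates, plus a little noise
t = np.linspace(-2, 8, n_time)
templates = np.vstack([np.exp(-(t - 1) ** 2), np.exp(-(t - 4) ** 2 / 4)])
weights = rng.normal(size=(n_rois, 2))
data = weights @ templates + 0.05 * rng.normal(size=(n_rois, n_time))

# center each time point (feature) across ROIs, as PCA does
centered = data - data.mean(axis=0)

# rows of Vt are the PCs: unit-length weight vectors over the original time points
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
explained_variance_ratio = S ** 2 / np.sum(S ** 2)

# each ROI's coordinates (scores) along the new axes
scores = centered @ Vt.T
```

Because the toy responses are built from two templates, the first two PCs capture almost all of the variance. The notebook's `PCA(..., whiten=True)` additionally divides each column of the scores by its standard deviation, which rescales the scores but leaves the PCs themselves unchanged.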
 

2) Clustering: The ROI data are now characterized by a reduced set of optimized axes describing time. We then cluster the ROIs with both k-means clustering and spectral clustering, keeping whichever model and settings give the best silhouette score.
    
    1. KMeans clustering: Assumes compact, roughly spherical clusters of similar size. The three main steps of k-means clustering are **A)** initialize K cluster centroids, **B)** assign each data point to its nearest centroid, **C)** move each centroid to the mean of the points assigned to it; steps B and C repeat until the assignments stop changing.
    
    2. Spectral clustering: Does not assume any particular shape for the clusters. The three main steps of spectral clustering are **A)** build a similarity (affinity) graph that connects each ROI to the ROIs closest to it in PCA space, **B)** perform an eigendecomposition of the graph Laplacian derived from this similarity matrix, which embeds the ROIs in a new low-dimensional space, **C)** use k-means clustering on the embedded data.
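Because k-means is easy to confuse with k-nearest-neighbors classification, here is a minimal NumPy sketch of Lloyd's algorithm following steps A to C above. It is illustrative only: scikit-learn's `KMeans`, which the notebook uses, defaults to k-means++ initialization rather than the random pick of data points used here.

```python
import numpy as np

def kmeans(X, k, n_iter=100, seed=0):
    """Minimal Lloyd's-algorithm k-means; returns (labels, centroids)."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    # A) initialize: pick k distinct data points as the starting centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # B) assign each point to its nearest centroid (squared Euclidean distance)
        d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        # C) move each centroid to the mean of its assigned points (keep it if the cluster is empty)
        new_centroids = np.array([X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
                                  for j in range(k)])
        if np.allclose(new_centroids, centroids):
            break  # converged: centroids (and hence assignments) no longer change
        centroids = new_centroids
    return labels, centroids
```

Step C moves centroids to cluster means, which is why k-means favors compact, roughly spherical clusters; spectral clustering sidesteps this by running the same procedure on a graph-based embedding instead of the raw PCA coordinates.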

Prerequisites
------------------------------------

All data should reside in a parent folder. This folder's name should be the name of the session and ideally be the same as the base name of the recording file.

Data need to be run through the NAPECA event_rel_analysis code in order to generate the event_data_dict.pkl file, which contains event-related activity across different behavioral conditions for all neurons/ROIs.
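The notebook's `format_trial_data_for_ml` concatenates each condition's trial-averaged array (ROIs x time) along the time axis and drops any ROI containing NaNs. A minimal sketch of that step, assuming a plain dict mapping condition name to array (a simplification of the nested `data_dict` the notebook builds; `concat_conditions` is a hypothetical name):

```python
import numpy as np

def concat_conditions(trial_avg_by_condition, conditions):
    """Concatenate per-condition (rois x time) arrays along time; drop ROIs containing NaNs.

    trial_avg_by_condition is a hypothetical {condition: 2D array} dict.
    Returns the concatenated array and the indices of the dropped rows.
    """
    data = np.concatenate([trial_avg_by_condition[c] for c in conditions], axis=1)
    nan_rows = np.unique(np.where(np.isnan(data))[0])
    data = np.delete(data, nan_rows, axis=0)
    return data, nan_rows

minus = np.zeros((3, 4))
minus[1, 2] = np.nan  # ROI 1 has a missing sample in one condition
data, dropped = concat_conditions({'cs_plus': np.ones((3, 4)), 'cs_minus': minus},
                                  ['cs_plus', 'cs_minus'])
```

Note that a single NaN in any condition removes that ROI from every condition, so the row count stays consistent across the concatenated windows.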


How to run this code
------------------------------------

In this Jupyter notebook, run all cells in order (Shift + Enter).

__You can indicate specific files and parameters to include in the second cell__

Required Packages
-----------------
Python 3.7, seaborn, matplotlib, pandas, scikit-learn, statsmodels, numpy, h5py

Custom code requirements: utils

Parameters 
----------

fname_signal : string
    
    Name of file that contains roi activity traces. Must include full file name with extension. Accepted file types: .npy, .csv. IMPORTANT: data dimensions should be rois (y) by samples/time (x)

fname_events : string

    Name of file that contains event occurrences. Must include full file name with extension. Accepted file types: .pkl, .csv. Pickle (pkl) files need to contain a dictionary where keys are the condition names and the values are lists containing samples/frames for each corresponding event. CSV files should have two columns (event condition, sample). The first row contains the column names; each subsequent row contains one trial's event condition and sample in tidy format. See the example in the sample_data folder for formatting, or this link: https://github.com/zhounapeuw/NAPE_imaging_postprocess/raw/main/docs/_images/napeca_post_event_csv_format.png
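For reference, here is a sketch of reading the tidy two-column CSV layout described above into the same {condition: [samples]} structure the pickle option uses. It relies only on the standard library, reads the columns by position, and is not the notebook's `utils.df_to_dict` helper; `events_csv_to_dict` and the header names in the example are placeholders.

```python
import csv
import io
from collections import defaultdict

def events_csv_to_dict(file_obj):
    """Parse a tidy events CSV (column 1: condition, column 2: sample) into {condition: [samples]}."""
    reader = csv.reader(file_obj)
    next(reader)  # the first row holds the column names
    events = defaultdict(list)
    for row in reader:
        if not row:
            continue  # skip blank lines
        condition, sample = row[0].strip(), row[1].strip()
        events[condition].append(float(sample))
    return dict(events)

example = io.StringIO("event,sample\nplus,10\nminus,25\nplus,40\n")
event_dict = events_csv_to_dict(example)  # {'plus': [10.0, 40.0], 'minus': [25.0]}
```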

fdir : string 

    Root file directory containing the session data (e.g., the raw tif, tiff, h5 files and the signal and event files specified above). Note: leave off the trailing slash/backslash. For example: ../napeca_post/sample_data if you cloned the repo directly

trial_start_end : list of two entries  

    Entries can be ints or floats. The first entry is the time in seconds relative to the event/ttl onset for the start of the event analysis window (negative if before the event/ttl onset). The second entry is the time in seconds for the end of the event analysis window. For example, if the desired analysis window is 5.5 seconds before event onset and 8 seconds after, `trial_start_end` would be [-5.5, 8].  
    
baseline_end : int/float  

    Time in seconds for the end of the baseline epoch. By default, the baseline epoch start time will be the first entry of `trial_start_end`. This baseline epoch is used for calculating baseline normalization metrics.
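The seconds-to-samples conversion done in `make_timing_info` can be checked by hand. This sketch reproduces that arithmetic with the example parameters (fs = 5, `trial_start_end` = [-2, 8], `baseline_end` = -0.2); it is a standalone re-derivation, not the notebook function itself.

```python
import numpy as np

fs = 5                                # sampling rate (Hz)
trial_start_end = np.array([-2, 8])   # seconds relative to event onset
baseline_end = -0.2                   # seconds

trial_begEnd_samp = trial_start_end * fs                               # [-10, 40]
num_samples_trial = int(trial_begEnd_samp[1] - trial_begEnd_samp[0])   # 50 samples per trial window

baseline_begEnd_samp = np.array([trial_start_end[0], baseline_end]) * fs  # [-10., -1.]
# baseline sample indices counted from the start of the trial window (end sample inclusive)
baseline_svec = (np.arange(baseline_begEnd_samp[0], baseline_begEnd_samp[1] + 1)
                 - baseline_begEnd_samp[0]).astype(int)                # 0, 1, ..., 9
```

With these settings the baseline epoch spans the first 10 samples (-2.0 s to -0.2 s) of each 50-sample trial window.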

event_sort_analysis_win : list with two float entries

    Time window [a, b] in seconds to which some visualization calculations apply. For example, ROIs in heatmaps are sorted by their mean activity in the time window between a and b. 

pca_num_pc_method : 0 or 1 (int)

    Method for calculating the number of principal components to retain from PCA preprocessing: 0 for the bend in the scree plot, 1 for the number of PCs that account for 90% of the variance.
    Users should try both methods and observe which result fits the experiment; in some datasets the choice may not affect the results.

flag_save_figs : boolean  

    Set as True to save figures as PNG and, for some plots, vectorized (PDF) formats. 

selected_conditions : list of strings

    Specific conditions that the user wants to analyze; needs to be exactly the name of conditions in the events CSV or pickle file

flag_plot_reward_line: boolean  

    If set to True, plot a vertical line for a secondary event. The time of the vertical line is set by the variable second_event_seconds
    
second_event_seconds: int/float
    
    Time in seconds (relative to primary event onset) for plotting a vertical dotted line indicating an optional second event occurrence 

max_n_clusters : integer
    
    Maximum number of clusters expected for clustering models. As a general rule, select this number based on the maximum expected number of clusters in the data + ~5. Keep in mind that larger values will increase processing time
    
possible_n_nearest_neighbors : array of integers
    
    In spectral clustering, each value n in possible_n_nearest_neighbors is tried as n_neighbors: every data point is connected to its n nearest neighbors to create the connectivity graph (affinity matrix), and the value giving the best silhouette score is kept.
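With `affinity='nearest_neighbors'`, scikit-learn's `SpectralClustering` builds its affinity matrix from a k-nearest-neighbor connectivity graph and symmetrizes it. The NumPy sketch below illustrates that idea so the effect of `n_neighbors` is visible. It is an approximation for intuition, assuming each point counts itself among its neighbors; `knn_affinity` is a hypothetical name, not scikit-learn's code.

```python
import numpy as np

def knn_affinity(X, n_neighbors):
    """Symmetrized k-nearest-neighbor connectivity matrix (illustrative sketch).

    Each point is linked (weight 1) to its n_neighbors closest points by
    Euclidean distance, counting itself; the matrix is then averaged with its
    transpose so that one-sided links get weight 0.5.
    """
    X = np.asarray(X, dtype=float)
    n = len(X)
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2)  # pairwise squared distances
    nearest = np.argsort(d2, axis=1)[:, :n_neighbors]         # normally starts with the point itself
    A = np.zeros((n, n))
    A[np.repeat(np.arange(n), n_neighbors), nearest.ravel()] = 1.0
    return 0.5 * (A + A.T)

# four ROIs in a 1-D "PC space": two tight pairs far apart
X = np.array([[0.0], [1.0], [10.0], [11.0]])
affinity = knn_affinity(X, n_neighbors=2)
```

Smaller `n_neighbors` gives a sparser graph that can split into disconnected pieces; in this sketch, `n_neighbors=1` connects each point only to itself, so the smallest values in `possible_n_nearest_neighbors` are worth inspecting on your data.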

In [None]:
import pickle
import math
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC, SVR, LinearSVC
from sklearn.metrics import accuracy_score, silhouette_score, adjusted_rand_score, silhouette_samples
from sklearn.cluster import AgglomerativeClustering, SpectralClustering, KMeans
from sklearn.model_selection import KFold, LeaveOneOut, train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.kernel_ridge import KernelRidge
from sklearn import linear_model
from sklearn.manifold import TSNE
import scipy.stats as stats
import statsmodels.api as sm
import statsmodels.formula.api as smf
from patsy import (ModelDesc, EvalEnvironment, Term, EvalFactor, LookupFactor, dmatrices, INTERCEPT)
from statsmodels.distributions.empirical_distribution import ECDF
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.colorbar as colorbar
import sys
import os
import re
import glob
import seaborn as sns
import matplotlib.pyplot as plt
#important for text to be detected when importing saved figures into illustrator
import matplotlib
matplotlib.rcParams['pdf.fonttype']=42
matplotlib.rcParams['ps.fonttype']=42
plt.rcParams["font.family"] = "Arial"

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

import numpy as np
import pandas as pd
import json
import utils

In [None]:
"""
USER-DEFINED VARIABLES
"""


def define_params(method = 'single'):
    
    fparams = {}
    
    if method == 'single':
        
        fparams['fname_signal'] = 'VJ_OFCVTA_7_260_D6_neuropil_corrected_signals_15_50_beta_0.8.npy' # name of your npy or csv file that contains activity signals
        fparams['fname_events'] = 'event_times_VJ_OFCVTA_7_260_D6_trained.csv' # name of your pickle or csv file that contains behavioral event times (in seconds)
        # fdir specifies the root path of the data. Currently, the abspath phrase points to sample data from the repo.
        # To specify a path that is on your local computer, use this string format: r'your_root_path', where you should copy/paste
        # your path between the single quotes (important to keep the r to render as a complete raw string). See example below:
        # r'C:\Users\stuberadmin\Documents\GitHub\NAPE_imaging_postprocess\napeca_post\sample_data' 
        fparams['fdir'] = os.path.abspath("../napeca_post/sample_data/VJ_OFCVTA_7_260_D6") # for an explicit path, eg. r'C:\2pData\Vijay data\VJ_OFCVTA_7_D8_trained' 
        fparams['fname'] = os.path.split(fparams['fdir'])[1]
        fparams['flag_close_figs_after_save'] = True

        fparams['fs'] = 5 # sampling rate of activity data

        # trial extraction info
        fparams['trial_start_end'] = [-2, 8] # trial [start, end] times (in seconds); centered on event onset
        fparams['baseline_end'] = -0.2 # baseline epoch end time (in seconds) for performing baseline normalization; I set this to -0.2 to be safe I'm not grabbing a sample that includes the event
        fparams['event_sort_analysis_win'] = [0, 5] # time window (in seconds)
        fparams['flag_normalization'] = 'zscore' # z-score highly recommended as machine learning algorithms prefer normalized data
        
        # set to 0 to determine number of PCs to where the scree plot bends, 1 for num PCs that account for 90% variance
        fparams['pca_num_pc_method'] = 0 

        # variables for clustering
        fparams['max_n_clusters'] = 10 # from Vijay: maximum number of clusters expected. This should be based on the number of functional neuron groups you expect + ~3.
        # It might be worth increasing this for your data, but it will take more time to run.
        # In spectral clustering: get n nearest neighbors for each data point and create connectivity graph (affinity matrix)
        fparams['possible_n_nearest_neighbors'] = np.arange(1, 10) #np.array([3,5,10]) # This should be selected for each dataset
        # appropriately. When 4813 neurons are present, the above number of nearest neighbors ([30,40,30,50,60]) provides a good sweep of the
        # parameter space. But it will need to be changed for other data.

        # optional arguments
        fparams['selected_conditions'] = None # set to a list of strings if you want to filter specific conditions to analyze; eg. ['plus', 'minus']
        fparams['flag_plot_reward_line'] = False # if there's a second event that happens after the main event, it can be indicated if set to True; timing is dictated by the next variables below
        fparams['second_event_seconds'] = 1 # time in seconds
        fparams['flag_save_figs'] = False # set to true if you want to save plots
        fparams['load_savedpca_or_dopca'] = 'dopca' # Select 'dopca' for doing PCA on the data. Select 'savedpca' for loading my previous results

        fparams['group_data'] = False
        fparams['group_data_conditions'] = ['cs_plus', 'cs_minus']
                
    return fparams

fparams = define_params(method = 'single') # options are 'single', 'f2a', 'root_dir'; only 'single' is defined in define_params above




In [None]:
def tvec_to_dict(tvec):
    """
    Used for relabeling pandas df columns to make lineplot from df
    """
    dict_ = {}
    for i in range(len(tvec)):
        dict_[i] = tvec[i] 
    return dict_

In [None]:

def set_cluster_colors(dict_vars):
    dict_vars['colors_for_cluster'] = [[0.933, 0.250, 0.211],
                                              [0.941, 0.352, 0.156],
                                              [0.964, 0.572, 0.117],
                                              [0.980, 0.686, 0.250],
                                              [0.545, 0.772, 0.247],
                                              [0.215, 0.701, 0.290],
                                              [0, 0.576, 0.270],
                                              [0, 0.650, 0.611],
                                              [0.145, 0.662, 0.878]]
    return dict_vars

def normalization_vars(dict_vars, fparams):
    if 'zscore' in fparams['flag_normalization']:
        dict_vars['data_trial_resolved_key'] = 'zdata'
        dict_vars['data_trial_avg_key'] = 'ztrial_avg_data'
        dict_vars['cmap_'] = None
        dict_vars['ylabel'] = 'Z-score Activity'
    else:
        dict_vars['data_trial_resolved_key'] = 'data'
        dict_vars['data_trial_avg_key'] = 'trial_avg_data'
        dict_vars['cmap_'] = 'inferno'
        dict_vars['ylabel'] = 'Activity'
    
    return dict_vars

def define_paths(dict_vars, fparams):
    dict_vars['signals_fpath'] = os.path.join(fparams['fdir'], fparams['fname_signal'])
    dict_vars['events_file_path'] = os.path.join(fparams['fdir'], fparams['fname_events'])

    dict_vars['save_dir']= os.path.join(fparams['fdir'], 'event_rel_analysis')
    utils.check_exist_dir(dict_vars['save_dir']); # make the save directory

    return dict_vars

def make_timing_info(dict_vars, fparams):
    
    # convert times to samples 
    dict_vars['trial_start_end_sec'] = np.array(fparams['trial_start_end']) # trial windowing in seconds relative to ttl-onset/trial-onset, in seconds
    dict_vars['trial_begEnd_samp'] = dict_vars['trial_start_end_sec']*fparams['fs'] # turn trial start/end times to samples
    # get sample vector for the trial 
    trial_svec = np.arange(dict_vars['trial_begEnd_samp'][0], dict_vars['trial_begEnd_samp'][1])
    # calculate time vector for plot x axes
    num_samples_trial = len( trial_svec )
    dict_vars['tvec'] = np.round(np.linspace(dict_vars['trial_start_end_sec'][0], dict_vars['trial_start_end_sec'][1], num_samples_trial+1), 2)
    
    baseline_start_end_sec = np.array([dict_vars['trial_start_end_sec'][0], fparams['baseline_end']])
    dict_vars['baseline_begEnd_samp'] = baseline_start_end_sec*fparams['fs']
    dict_vars['baseline_svec'] = (np.arange(dict_vars['baseline_begEnd_samp'][0], dict_vars['baseline_begEnd_samp'][1] + 1, 1) -
                            dict_vars['baseline_begEnd_samp'][0]).astype('int')
    

    return dict_vars

def load_behav(dict_vars, fparams):
    ### load behavioral data and trial info
    # requires define_paths method to be run first 

    glob_event_files = glob.glob(dict_vars['events_file_path']) # look for a file in specified directory
    if not glob_event_files:
        raise FileNotFoundError('{} not detected. Please check if path is correct.'.format(dict_vars['events_file_path']))
    event_file = glob_event_files[0]
    if event_file.endswith('.csv'):
        event_times = utils.df_to_dict(event_file)
    elif event_file.endswith('.pkl'):
        with open(event_file, "rb") as f:
            event_times = pickle.load(f, fix_imports=True, encoding='latin1') # latin1 b/c original pickle made in python 2
    else:
        raise ValueError('Unsupported event file type: {} (expected .csv or .pkl)'.format(event_file))
    dict_vars['event_frames'] = utils.dict_samples_to_time(event_times, fparams['fs'])

    # identify conditions to analyze
    all_conditions = dict_vars['event_frames'].keys()
    conditions = [ condition for condition in all_conditions if len(dict_vars['event_frames'][condition]) > 0 ] # keep conditions that have events

    conditions.sort()
    if fparams['selected_conditions']:
        conditions = fparams['selected_conditions']
    dict_vars['conditions'] = conditions
    #dict_vars['cmap_lines'] = get_cmap(len(conditions)) # colors for plotting lines for each condition

    return dict_vars


def format_trial_data_for_ml(dict_vars, signals):
    """
    MAIN data processing function to extract event-centered data

    extract and save trial data, 
    saved data are in the event_rel_analysis subfolder, a pickle file that contains the extracted trial data
    """
    data_dict = utils.extract_trial_data(dict_vars, signals, dict_vars['event_frames'], 
                                         dict_vars['conditions'], save_dir=dict_vars['save_dir'])
    
    #### concatenate data across trial conditions

    # concatenates trial-averaged data across conditions along the time axis; populationdata dimensions are ROI by time (conditions are appended)
    formatted_data = np.concatenate([data_dict[condition][dict_vars['data_trial_avg_key']] 
                                     for condition in dict_vars['conditions']], axis=1)

    # remove rows with nan values
    nan_rows = np.unique(np.where(np.isnan(formatted_data))[0])
    if nan_rows.size != 0:
        formatted_data = np.delete(formatted_data, obj=nan_rows, axis=0)
        print('Some ROIs contain nans in tseries!')
        
    return formatted_data


def sort_rois(dict_vars, fparams, data):
    # sortresponse is an ordering of the ROIs (descending) based on their average response in the sort window
    dict_vars['window_size'] = int(data.shape[1]/len(dict_vars['conditions'])) # Total number of frames in a trial window; needed to split processed concatenated data
    dict_vars['sortwindow_frames'] = [utils.get_tvec_sample(dict_vars['tvec'], time) for time in fparams['event_sort_analysis_win']] # sample indices for the start/end of event_sort_analysis_win
    dict_vars['sortresponse'] = np.argsort(np.mean(data[:,dict_vars['sortwindow_frames'][0]:dict_vars['sortwindow_frames'][1]], axis=1))[::-1]
    
    return dict_vars


In [None]:
def make_silhouette_plot(dict_vars, X, cluster_labels):

    """
    Create silhouette plot for the clusters
    """
    
    n_clusters = len(set(cluster_labels))
    
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(4, 4)

    # The silhouette coefficient can range from -1 to 1; the x-axis is limited
    # to [-0.4, 1] here, so strongly negative values are clipped from view
    ax.set_xlim([-0.4, 1])
    # The (n_clusters+1)*10 is for inserting blank space between silhouette
    # plots of individual clusters, to demarcate them clearly.
    ax.set_ylim([0, len(X) + (n_clusters + 1) * 10])
    silhouette_avg = silhouette_score(X, cluster_labels, metric='cosine')

    # Compute the silhouette scores for each sample
    sample_silhouette_values = silhouette_samples(X, cluster_labels, metric='cosine')

    y_lower = 10
    for i in range(n_clusters):
        # Aggregate the silhouette scores for samples belonging to
        # cluster i, and sort them
        ith_cluster_silhouette_values = \
            sample_silhouette_values[cluster_labels == i]

        ith_cluster_silhouette_values.sort()

        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        color = dict_vars['colors_for_cluster'][i % len(dict_vars['colors_for_cluster'])] # cycle colors if there are more clusters than colors
        ax.fill_betweenx(np.arange(y_lower, y_upper),
                          0, ith_cluster_silhouette_values,
                          facecolor=color, edgecolor=color, alpha=0.9)

        # Label the silhouette plots with their cluster numbers at the middle
        ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i+1))

        # Compute the new y_lower for next plot
        y_lower = y_upper + 10  # 10 for the 0 samples

    ax.set_title("The silhouette plot for the various clusters.")
    ax.set_xlabel("The silhouette coefficient values")
    ax.set_ylabel("Cluster label")

    # The vertical line for average silhouette score of all the values
    ax.axvline(x=silhouette_avg, color="red", linestyle="--")

    ax.set_yticks([])  # Clear the yaxis labels / ticks
    ax.set_xticks([-0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8, 1])

    
def plot_event_data(dict_vars, fparams, data):

    fig, axs = plt.subplots(2,len(dict_vars['conditions']),figsize=(3*2,3*2), sharex='all', sharey='row')

    # loop through conditions and plot heatmaps of trial-avged activity
    for t in range(len(dict_vars['conditions'])):

        # set fig panel to plot in
        if len(dict_vars['conditions']) == 1:
            ax_heat = axs[0]
            ax_trace = axs[1]
        else:
            ax_heat = axs[0,t]
            ax_trace = axs[1,t]

        im = utils.subplot_heatmap(ax_heat, ' ', 
                                   data[dict_vars['sortresponse'], t*dict_vars['window_size']: (t+1)*dict_vars['window_size']], 
                                   clims = [-dict_vars['cmax'], dict_vars['cmax']], 
                                   extent_= [dict_vars['tvec'][0], dict_vars['tvec'][-1], data.shape[0], 0 ]) # plot_extent: set plot limits as [time_start, time_end, num_rois, 0]

        ax_heat.set_title(dict_vars['conditions'][t])
        ax_heat.axvline(0, linestyle='--', color='k', linewidth=0.5)   
        if fparams['flag_plot_reward_line']:
            ax_heat.axvline(fparams['second_event_seconds'], linestyle='--', color='k', linewidth=0.5) 

        ### roi-avg tseries 

        mean_ts = np.mean(data[dict_vars['sortresponse'], t*dict_vars['window_size']:(t+1)*dict_vars['window_size']], axis=0)
        stderr_ts = np.std(data[dict_vars['sortresponse'], t*dict_vars['window_size']:(t+1)*dict_vars['window_size']], axis=0)/np.sqrt(data.shape[0])
        ax_trace.plot(dict_vars['tvec'], mean_ts)
        shade = ax_trace.fill_between(dict_vars['tvec'], mean_ts - stderr_ts, mean_ts + stderr_ts, alpha=0.2) # this plots the shaded error bar

        ax_trace.axvline(0, linestyle='--', color='k', linewidth=0.5)  
        if fparams['flag_plot_reward_line']:
            ax_trace.axvline(fparams['second_event_seconds'], linestyle='--', color='k', linewidth=0.5)   
        ax_trace.set_xlabel('Time from event (s)')   

        # plot labels on first panel only
        if t==0:
            ax_heat.set_ylabel('Neurons'); ax_trace.set_ylabel(dict_vars['ylabel'])

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    cbar = fig.colorbar(im, ax = axs, shrink = 0.7)
    cbar.ax.set_ylabel('Heatmap {}'.format(dict_vars['ylabel']), fontsize=13);

    if fparams['flag_save_figs']:
        fig.savefig(os.path.join(dict_vars['save_dir'], 'trial_avg_pop_responses.pdf'), format='pdf')
        fig.savefig(os.path.join(dict_vars['save_dir'], 'trial_avg_pop_responses.png'), format='png', dpi=300)

In [None]:
# declare paths and names
dict_analysis_vars = {}
dict_analysis_vars = set_cluster_colors(dict_analysis_vars)
dict_analysis_vars = normalization_vars(dict_analysis_vars, fparams)
dict_analysis_vars = define_paths(dict_analysis_vars, fparams)
dict_analysis_vars = make_timing_info(dict_analysis_vars, fparams)

# load extracted traces from ROIs
signals = utils.load_signals(dict_analysis_vars['signals_fpath'])

if fparams['group_data']:
    
    dict_analysis_vars['conditions'] = fparams['group_data_conditions']

    # this line takes the pre-trial-averaged data and zscores the traces based on baseline period supplied by user
    if fparams['flag_normalization']:
        populationdata = np.squeeze(np.apply_along_axis(utils.zscore_, -1, signals, dict_analysis_vars['baseline_svec']))
    else:
        populationdata = np.squeeze(signals)
        
    num_samples_trial = int(populationdata.shape[-1]/len(fparams['group_data_conditions']))
    dict_analysis_vars['tvec'] = np.round(np.linspace(dict_analysis_vars['trial_start_end_sec'][0], dict_analysis_vars['trial_start_end_sec'][1], num_samples_trial), 2)

else:

    # load behav data from event times and get condition info
    dict_analysis_vars = load_behav(dict_analysis_vars, fparams)
    # turn trial-based data into trial-averaged concatenated data ready for machine learning
    # data normalization (or lack of) is determined by fparams['flag_normalization']
    populationdata = format_trial_data_for_ml(dict_analysis_vars, signals)

dict_analysis_vars['cmax'] = np.nanmax(np.abs([np.nanmin(populationdata), np.nanmax(populationdata)])) # Maximum colormap value. 

In [None]:
# sort ROIs based on activity in time window; used for plotting event data and sorting ROIs after clustering
dict_analysis_vars = sort_rois (dict_analysis_vars, fparams, populationdata)
plot_event_data(dict_analysis_vars, fparams, populationdata)

## Do PCA to reduce dimensionality in the time-domain

PCA: A linear-algebra method for re-expressing a set of variables so that a few new variables explain most of the variability in a dataset. Concretely, PCA finds a new set of axes (i.e., variables) that are linear combinations of the original axes, where each new axis captures as much of the remaining variability in the data as possible while staying orthogonal to (uncorrelated with) the other new axes.

In this case, we are finding new orthogonal axes over the time dimension such that most of the explained variance is concentrated in the top few axes.

The goal of using PCA here is to reduce the number of variables describing the data, since the subsequent clustering algorithms tend to perform better with fewer variables. One method of choosing how many variables to keep is to retain the first n PCs that together explain a chosen fraction of the variance in the data. A common threshold is 90% of variance explained, on the assumption that the remaining 10% mostly reflects stochastic noise.
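The 90% rule amounts to a cumulative sum over the scree values. This sketch uses the same "first count whose cumulative percentage exceeds the threshold" convention as `num_pc_explained_var` in the notebook, with invented example percentages; `num_pcs_for_variance` is a hypothetical name.

```python
import numpy as np

def num_pcs_for_variance(explained_var_pct, thresh=90):
    """Smallest number of PCs whose cumulative explained variance (in %) exceeds thresh."""
    cum = np.cumsum(explained_var_pct)
    above = np.flatnonzero(cum > thresh)
    return int(above[0]) + 1 if above.size else len(explained_var_pct)

# illustrative percentages (not real data): cumulative sums 50, 75, 85, 93, ...
n_keep = num_pcs_for_variance([50, 25, 10, 8, 4, 2, 1])  # 4
```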

Another method is to find where the scree plot bends. To do this, subtract from the scree plot the straight line joining (1, variance explained by the first PC) and (number of PCs, variance explained by the last PC), then keep the number of components just before the minimum of this detrended curve.
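The scree-bend rule can be sketched the same way. Using invented percentages, the code below subtracts the straight line joining the first and last points and returns the 0-based position of the minimum, mirroring the `xprime`/`argmin` computation in `do_PCA`; `scree_elbow` is a hypothetical name.

```python
import numpy as np

def scree_elbow(explained_var_pct):
    """Number of PCs before the bend: argmin of the scree curve minus its end-to-end line."""
    x = np.asarray(explained_var_pct, dtype=float)
    line = x[0] + (x[-1] - x[0]) / (x.size - 1) * np.arange(x.size)
    # the 0-based index of the minimum equals the count of PCs that precede it
    return int(np.argmin(x - line))

n_keep = scree_elbow([50, 25, 10, 8, 4, 2, 1])  # 2
```

Here the detrended curve is lowest at the third PC, so the two PCs before the bend are kept; on the same numbers a 90% cumulative-variance rule would keep 4, which is why comparing both `pca_num_pc_method` options can be informative.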

In [None]:
### Functions to perform PCA and plot visualizations

def num_pc_explained_var(explained_var, explained_var_thresh=90):
    """
    Return the smallest number of PCs whose cumulative explained variance (in %) exceeds the threshold
    """
    cum_sum = 0
    for idx, PC_var in enumerate(explained_var):
        cum_sum += PC_var
        if cum_sum > explained_var_thresh:
            return idx+1
    return len(explained_var) # threshold never exceeded: keep all PCs

# perform PCA across time
def do_PCA(dict_vars, fparams, data):

    if fparams['load_savedpca_or_dopca'] == 'dopca':
        pca = PCA(n_components=min(data.shape[0],data.shape[1]), whiten=True)
        pca.fit(data) 
        with open(os.path.join(fparams['fdir'], 'pcaresults.pickle'), 'wb') as f:
            pickle.dump(pca, f)
    elif fparams['load_savedpca_or_dopca'] == 'savedpca':
        with open(os.path.join(fparams['fdir'], 'pcaresults.pickle'), 'rb') as f:
            pca = pickle.load(f)

    # pca across time; transformed data: each ROI is now represented by its scores (projections) on the PCs
    dict_vars['transformed_data'] = pca.transform(data)

    # grab eigenvectors (pca.components_); linear combination of original axes
    dict_vars['pca_vectors'] = pca.components_ 
    dict_vars['pca_explained_variance'] = pca.explained_variance_ratio_ 
    print('Number of PCs = {}'.format(dict_vars['pca_vectors'].shape[0]))

    # calculate where bend in scree plot is
    x = 100*pca.explained_variance_ratio_ # eigenvalue ratios
    xprime = x - (x[0] + (x[-1]-x[0])/(x.size-1)*np.arange(x.size))

    # define number of PCs
    num_retained_pcs_scree = np.argmin(xprime)
    num_retained_pcs_var = num_pc_explained_var(x, 90)
    if fparams['pca_num_pc_method'] == 0:
        dict_vars['num_retained_pcs'] = num_retained_pcs_scree
    elif fparams['pca_num_pc_method'] == 1:
        dict_vars['num_retained_pcs'] = num_retained_pcs_var
   
    return dict_vars


def plot_PCA_scree(dict_vars, fparams):
    # plot PCA scree plot
    fig, ax = plt.subplots(figsize=(2,2))
    ax.plot(np.arange(dict_vars['pca_explained_variance'].shape[0]).astype(int)+1, dict_vars['pca_explained_variance']*100, 'k')
    ax.set_ylabel('Percentage of\nvariance explained')
    ax.set_xlabel('PC number')
    ax.axvline(dict_vars['num_retained_pcs'], linestyle='--', color='k', linewidth=0.5)
    ax.set_title('Scree plot')
    
    [i.set_linewidth(0.5) for i in ax.spines.values()]
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

    fig.subplots_adjust(left=0.3)
    fig.subplots_adjust(right=0.98)
    fig.subplots_adjust(bottom=0.25)
    fig.subplots_adjust(top=0.9)
    
    if fparams['flag_save_figs']:
        fig.savefig(os.path.join(dict_vars['save_dir'], 'scree_plot.png'), format='png', dpi=300)

### plot retained principal components
def plot_PCs(dict_vars, fparams):
    numcols = 2.0
    fig, axs = plt.subplots(int(np.ceil(dict_vars['num_retained_pcs']/numcols)), int(numcols), sharey='all',
                            figsize=(2.2*numcols, 2.2*int(np.ceil(dict_vars['num_retained_pcs']/numcols))))
    
    for pc in range(dict_vars['num_retained_pcs']):
        ax = axs.flat[pc]
        for k, tempkey in enumerate(dict_vars['conditions']):
            ax.plot(dict_vars['tvec'], dict_vars['pca_vectors'][pc, k*dict_vars['window_size']:(k+1)*dict_vars['window_size']],
                    label='PC %d: %s'%(pc+1, tempkey))

        ax.axvline(0, linestyle='--', color='k', linewidth=1)
        ax.set_title(f'PC {pc+1}')
        # labels on first panel
        if pc == 0:
            ax.set_xlabel('Time from event (s)'); ax.set_ylabel( 'PCA weights')
            ax.legend(dict_vars['conditions'])

    fig.tight_layout()
    # hide empty plot if present
    for ax in axs.flat[dict_vars['num_retained_pcs']:]:
        ax.set_visible(False)

    if fparams['flag_save_figs']:
        fig.savefig(os.path.join(dict_vars['save_dir'], 'PCA.png'), format='png', dpi=300)


In [None]:
       
dict_analysis_vars = do_PCA(dict_analysis_vars, fparams, populationdata)       
plot_PCA_scree(dict_analysis_vars, fparams)
plot_PCs(dict_analysis_vars, fparams)

## Clustering

In [None]:
### Functions to perform clustering

def prep_param_testing(dict_vars, fparams, data):
    silhouette_dict = {}
    silhouette_dict['num_retained_pcs'] = dict_vars['num_retained_pcs']
    silhouette_dict['possible_n_nearest_neighbors'] = fparams['possible_n_nearest_neighbors']
    silhouette_dict['shape'] = 'cluster_nn'
    # silhouette scores require 2 <= n_clusters <= n_samples - 1
    max_n_clusters = np.min([fparams['max_n_clusters'], int(data.shape[0]) - 1])
    silhouette_dict['possible_n_clusters'] = np.arange(2, max_n_clusters+1) #This requires a minimum of 2 clusters.
    # When the data contain no clusters at all, it will be quite visible when inspecting the two obtained clusters, 
    # as the responses of the clusters will be quite similar. This will also be visible when plotting the data in
    # the reduced dimensionality PC space (done below).
    
    # initialize variables
    silhouette_dict['possible_clustering_models'] = np.array(["Spectral", "Kmeans"])
    silhouette_dict['silhouette_scores'] = np.nan*np.ones((silhouette_dict['possible_n_clusters'].size,
                                        silhouette_dict['possible_n_nearest_neighbors'].size,
                                        silhouette_dict['possible_clustering_models'].size))

    return silhouette_dict


def perform_hyperparameterization(silhouette_dict, dict_vars, fparams):
    # calculate optimal number of clusters and nearest neighbors using silhouette scores
    # loop through iterations of clustering params
    for n_clustersidx, n_clusters in enumerate(silhouette_dict['possible_n_clusters']):
        kmeans = KMeans(n_clusters=n_clusters, random_state=0) #tol=toler_options
        for nnidx, nn in enumerate(silhouette_dict['possible_n_nearest_neighbors']):
            spectral = SpectralClustering(n_clusters=n_clusters, affinity='nearest_neighbors', n_neighbors=nn, random_state=0)
            models = [spectral,kmeans]
            for modelidx,model in enumerate(models):
                model.fit(dict_vars['transformed_data'][:,:dict_vars['num_retained_pcs']])
                silhouette_dict['silhouette_scores'][n_clustersidx, nnidx, modelidx] = silhouette_score(dict_vars['transformed_data'][:,:dict_vars['num_retained_pcs']],
                                                                        model.labels_,
                                                                        metric='cosine')
                if modelidx == 0:
                    print('Done with numclusters = {}, num nearest neighbors = {}: score = {:.3f}'.
                          format(n_clusters, nn, silhouette_dict['silhouette_scores'][n_clustersidx, nnidx, modelidx]))
                else:
                    print('Done with numclusters = {}, score = {:.3f}'.
                         format(n_clusters, silhouette_dict['silhouette_scores'][n_clustersidx, nnidx, modelidx]))
    print(silhouette_dict['silhouette_scores'].shape)
    print('Done with model fitting')

#     with open(os.path.join(dict_vars['save_dir'],'silhouette_scores.pkl'), 'wb') as f:
#         pickle.dump(silhouette_dict['silhouette_scores'], f)
    
    return silhouette_dict



def extract_optimal_params(silhouette_dict):
    # Identify optimal parameters from the hyperparameter evaluation
    temp = np.where(silhouette_dict['silhouette_scores']==np.nanmax(silhouette_dict['silhouette_scores']))

    silhouette_dict['best_method'] = silhouette_dict['possible_clustering_models'][temp[2][0]]
    silhouette_dict['best_n_clusters'] = silhouette_dict['possible_n_clusters'][temp[0][0]]
    
    if 'Spectral' in silhouette_dict['best_method']:
        silhouette_dict['best_n_nearest_neighbors'] = silhouette_dict['possible_n_nearest_neighbors'][temp[1][0]]
        print('Optimal params: {} clusters, {} nearest neighbors, {} PCs, {} clustering, silhouette score {}'.
          format(silhouette_dict['best_n_clusters'], 
                 silhouette_dict['best_n_nearest_neighbors'], 
                 silhouette_dict['num_retained_pcs'], 
                 silhouette_dict['best_method'],
                np.nanmax(silhouette_dict['silhouette_scores'])))
    else:
        print('Optimal params: {} clusters, {} PCs, {} clustering, silhouette score {}'.
          format(silhouette_dict['best_n_clusters'],
                 silhouette_dict['num_retained_pcs'], 
                 silhouette_dict['best_method'], 
                 np.nanmax(silhouette_dict['silhouette_scores'])))
    
    return silhouette_dict


def reorder_clusters(data, sort_win_frames, rawlabels):
    # Since the clustering labels are arbitrary, I rename the clusters so that the first cluster (label 0) has the
    # most positive mean response in the sort window and the last cluster has the most negative response.
    uniquelabels = np.unique(rawlabels)  # sorted, as required by np.searchsorted below
    responses = np.nan*np.ones((len(uniquelabels),))
    for l, label in enumerate(uniquelabels):
        responses[l] = np.mean(data[rawlabels==label, sort_win_frames[0]:sort_win_frames[1]])
    order = np.argsort(responses)[::-1]  # positions in uniquelabels, most positive response first
    rank = np.empty(order.size, dtype=int)
    rank[order] = np.arange(order.size)  # rank[l] = new label for uniquelabels[l]
    outputlabels = rank[np.searchsorted(uniquelabels, rawlabels)]
    return outputlabels


def recluster_optimal_params(dict_vars, fparams, silhouette_dict, data):
    # Redo clustering with these optimal parameters
    model = None
    if silhouette_dict['best_method'] == 'Spectral':
        model = SpectralClustering(n_clusters=silhouette_dict['best_n_clusters'],
                               affinity='nearest_neighbors',
                               n_neighbors=silhouette_dict['best_n_nearest_neighbors'],
                               random_state=0)
    else:
        model = KMeans(n_clusters=silhouette_dict['best_n_clusters'], random_state=0)

    model.fit(dict_vars['transformed_data'][:,:silhouette_dict['num_retained_pcs']])

    # cluster label of each ROI, renumbered so that cluster 0 has the most positive response
    dict_vars['newlabels'] = reorder_clusters(data, dict_vars['sortwindow_frames'], model.labels_)

    # Create a new variable containing all unique cluster labels (sorted, so they can index subplot columns)
    dict_vars['uniquelabels'] = list(np.unique(dict_vars['newlabels']))


    # Save this optimal clustering model.
    # with open(os.path.join(save_dir, 'clusteringmodel.pickle'), 'wb') as f:
    #     pickle.dump(model, f)

    # np.save(os.path.join(summarydictdir, dt_string+'_'+ clusterkey+'_' + 'spectral_clusterlabels.npy'), newlabels)

    return dict_vars
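The silhouette-based search in `perform_hyperparameterization` and `extract_optimal_params` can be tried in miniature without any recording. This is a sketch under stated assumptions, not the notebook's pipeline: `make_blobs` generates a synthetic stand-in for the retained PC scores, and the grid values are arbitrary. Unlike the notebook, it fits KMeans once per cluster count, since KMeans ignores the nearest-neighbor setting and its score is simply copied along that axis.

```python
import numpy as np
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Synthetic stand-in for the retained PC scores (ROIs x PCs); NOT recorded data.
X, _ = make_blobs(n_samples=150, centers=3, n_features=5, cluster_std=0.5, random_state=0)

possible_n_clusters = np.arange(2, 6)
possible_n_nearest_neighbors = np.array([5, 10, 20])
possible_clustering_models = np.array(["Spectral", "Kmeans"])
scores = np.full((possible_n_clusters.size, possible_n_nearest_neighbors.size,
                  possible_clustering_models.size), np.nan)

for i, n_clusters in enumerate(possible_n_clusters):
    # KMeans ignores n_neighbors, so fit it once and copy its score along the neighbor axis.
    km_labels = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit_predict(X)
    scores[i, :, 1] = silhouette_score(X, km_labels, metric='cosine')
    for j, nn in enumerate(possible_n_nearest_neighbors):
        spectral = SpectralClustering(n_clusters=n_clusters, affinity='nearest_neighbors',
                                      n_neighbors=nn, random_state=0)
        scores[i, j, 0] = silhouette_score(X, spectral.fit_predict(X), metric='cosine')

# Same selection rule as extract_optimal_params: first grid cell holding the NaN-aware maximum.
idx = np.where(scores == np.nanmax(scores))
best_method = possible_clustering_models[idx[2][0]]
best_n_clusters = possible_n_clusters[idx[0][0]]
best_nn = possible_n_nearest_neighbors[idx[1][0]]  # only meaningful when best_method is Spectral
print(f"{best_method}: {best_n_clusters} clusters, nn={best_nn}, score={np.nanmax(scores):.3f}")
```

Cosine distance is used to match the notebook's `silhouette_score(..., metric='cosine')`; on real PC scores it compares the shape of the response profiles rather than their overall magnitude.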


In [None]:
silhouette_dict = prep_param_testing(dict_analysis_vars, fparams, populationdata) # set up variables for hyperparameterization
silhouette_dict = perform_hyperparameterization(silhouette_dict, dict_analysis_vars, fparams)
silhouette_dict = extract_optimal_params(silhouette_dict)
dict_analysis_vars = recluster_optimal_params(dict_analysis_vars, fparams, silhouette_dict, populationdata)
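`extract_optimal_params` locates the best hyperparameters with `np.where(scores == np.nanmax(scores))` and keeps the first match. The toy grid below (invented numbers, laid out as n_clusters x n_neighbors x [Spectral, Kmeans]) checks that `np.nanargmax` combined with `np.unravel_index` picks the same cell, including when the grid contains NaNs.

```python
import numpy as np

# Hypothetical score grid: (n_clusters options, n_neighbors options, [Spectral, Kmeans]).
# The Kmeans column repeats across neighbors because KMeans ignores n_neighbors.
scores = np.array([[[0.41, 0.38], [0.44, 0.38]],
                   [[0.52, 0.47], [0.55, 0.47]],
                   [[0.49, 0.45], [np.nan, 0.45]]])

# Rule used in extract_optimal_params: first index where the score equals the NaN-aware max.
temp = np.where(scores == np.nanmax(scores))
via_where = (int(temp[0][0]), int(temp[1][0]), int(temp[2][0]))

# Equivalent: np.nanargmax skips NaNs and returns the first maximum in C (row-major) order,
# which is also the order in which np.where reports its matches.
via_argmax = tuple(int(v) for v in np.unravel_index(np.nanargmax(scores), scores.shape))

print(via_where, via_argmax)  # (1, 1, 0) (1, 1, 0)
```

Because both approaches break ties by taking the first cell in row-major order, a KMeans optimum is always reported with the first nearest-neighbor value, which is harmless since KMeans never uses it.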

In [None]:
### functions for cluster plotting

def plot_cluster_heatmaps(dict_vars, fparams, data):
    
    # squeeze=False keeps axs 2-D (conditions x clusters) even when either dimension has length 1
    fig, axs = plt.subplots(len(dict_vars['conditions']), len(dict_vars['uniquelabels']),
                            figsize=(2*len(dict_vars['uniquelabels']), 2*len(dict_vars['conditions'])),
                            squeeze=False)

    numroisincluster = np.nan*np.ones((len(dict_vars['uniquelabels']),))

    for c, cluster in enumerate(dict_vars['uniquelabels']):
        for k, tempkey in enumerate(dict_vars['conditions']):

            temp = data[np.where(dict_vars['newlabels']==cluster)[0], k*dict_vars['window_size']:(k+1)*dict_vars['window_size']]
            numroisincluster[c] = temp.shape[0]
            ax=axs[k, cluster]
            sortresponse = np.argsort(np.mean(temp[:,dict_vars['sortwindow_frames'][0]:dict_vars['sortwindow_frames'][1]], axis=1))[::-1]

            plot_extent = [dict_vars['tvec'][0], dict_vars['tvec'][-1], len(sortresponse), 0 ]
            im = utils.subplot_heatmap(ax, ' ', temp[sortresponse], 
                                       clims = [-dict_vars['cmax'], dict_vars['cmax']], extent_=plot_extent)

            ax.grid(False)
            if k != len(dict_vars['conditions'])-1:
                ax.set_xticks([])  # only the bottom row keeps time ticks
            ax.set_yticks([])
            ax.axvline(0, linestyle='--', color='k', linewidth=0.5)
            if fparams['flag_plot_reward_line']:
                ax.axvline(second_event_seconds, linestyle='--', color='k', linewidth=0.5)
            if cluster==0:
                ax.set_ylabel('%s'%(tempkey))
        axs[0, cluster].set_title('Cluster %d\n(n=%d)'%(cluster+1, numroisincluster[c]))

    fig.text(0.4, 0.15, 'Time from cue (s)', fontsize=12,
             horizontalalignment='center', verticalalignment='center', rotation='horizontal')
    fig.tight_layout()

    fig.subplots_adjust(wspace=0.1, hspace=0.1, left=0.03, right=0.93, bottom=0.2, top=0.83)

    cbar = fig.colorbar(im, ax = axs, shrink = 0.7)
    cbar.ax.set_ylabel('Z-Score Activity', fontsize=13);

    if fparams['flag_save_figs']:
        plt.savefig(os.path.join(dict_vars['save_dir'], 'cluster_heatmap.png'))
        plt.savefig(os.path.join(dict_vars['save_dir'], 'cluster_heatmap.pdf'))

        
def plot_cluster_traces(dict_vars, fparams, data):
    # Plot the ROI-averaged normalized fluorescence of each cluster over time, one line per condition
    fig, axs = plt.subplots(1, len(dict_vars['uniquelabels']),
                            figsize=(3*len(dict_vars['uniquelabels']), 1.5*len(dict_vars['conditions'])),
                            squeeze=False)
    axs = axs[0]  # 1-D array with one axis per cluster, even when there is only one cluster

    numroisincluster = np.nan*np.ones((len(dict_vars['uniquelabels']),))

    for c, cluster in enumerate(dict_vars['uniquelabels']):
        for k, tempkey in enumerate(dict_vars['conditions']):
            temp = data[np.where(dict_vars['newlabels']==cluster)[0], k*dict_vars['window_size']:(k+1)*dict_vars['window_size']]
            numroisincluster[c] = temp.shape[0]
            # melt() gives one row per (ROI, time point); seaborn averages across ROIs at each time point
            sns.lineplot(x="variable", y="value",
                         data=pd.DataFrame(temp).rename(columns=tvec_to_dict(dict_vars['tvec'])).melt(),
                         ax=axs[cluster], label=tempkey, legend=False)
            axs[cluster].grid(False)
            axs[cluster].axvline(0, linestyle='--', color='k', linewidth=0.5)
            axs[cluster].spines['right'].set_visible(False); axs[cluster].spines['top'].set_visible(False)
            if cluster==0:
                axs[cluster].set_ylabel('Normalized fluorescence')
            else:
                axs[cluster].set_ylabel('')
            axs[cluster].set_xlabel('')
        axs[cluster].set_title('Cluster %d\n(n=%d)'%(cluster+1, numroisincluster[c]))
    axs[0].legend()
    fig.text(0.5, 0.05, 'Time from cue (s)', fontsize=12,
             horizontalalignment='center', verticalalignment='center', rotation='horizontal')
    fig.tight_layout()

    fig.subplots_adjust(wspace=0.1, hspace=0.1, left=0.03, right=0.93, bottom=0.2, top=0.83)

    if fparams['flag_save_figs']:
        plt.savefig(os.path.join(dict_vars['save_dir'], 'cluster_roiAvg_traces.png'))
        plt.savefig(os.path.join(dict_vars['save_dir'], 'cluster_roiAvg_traces.pdf'))
        
        
def plot_clusters_in_pca_space(dict_vars, fparams, silhouette_dict):
    # For every pair of clusters, embed their ROIs' retained-PC scores in 2-D with t-SNE and
    # scatter them to check visually how well the two clusters separate
    num_clusterpairs = len(dict_vars['uniquelabels'])*(len(dict_vars['uniquelabels'])-1)//2
    numrows = int(np.ceil(num_clusterpairs**0.5)); numcols = int(np.ceil(num_clusterpairs/numrows))
    # figsize is (width, height), i.e. (columns, rows); squeeze=False keeps axs 2-D
    fig, axs = plt.subplots(numrows, numcols, figsize=(3*numcols, 3*numrows), squeeze=False)

    tempsum = 0
    for c1, cluster1 in enumerate(dict_vars['uniquelabels']):
        for c2, cluster2 in enumerate(dict_vars['uniquelabels']):
            if cluster1>=cluster2:
                continue

            temp1 = dict_vars['transformed_data'][np.where(dict_vars['newlabels']==cluster1)[0], :silhouette_dict['num_retained_pcs']]
            temp2 = dict_vars['transformed_data'][np.where(dict_vars['newlabels']==cluster2)[0], :silhouette_dict['num_retained_pcs']]
            X = np.concatenate((temp1, temp2), axis=0)

            tsne = TSNE(n_components=2, init='random',
                        random_state=0, perplexity=np.sqrt(X.shape[0]))
            Y = tsne.fit_transform(X)

            ax = axs[tempsum // numcols, tempsum % numcols]  # fill the grid row by row
            # the first temp1.shape[0] rows of Y belong to cluster1, the rest to cluster2
            ax.scatter(Y[:temp1.shape[0], 0], Y[:temp1.shape[0], 1],
                       color=dict_vars['colors_for_cluster'][cluster1], label='Cluster %d'%(cluster1+1), alpha=1)
            ax.scatter(Y[temp1.shape[0]:, 0], Y[temp1.shape[0]:, 1],
                       color=dict_vars['colors_for_cluster'][cluster2+3], label='Cluster %d'%(cluster2+1), alpha=1)

            ax.set_xlabel('tsne dimension 1')
            ax.set_ylabel('tsne dimension 2')
            ax.legend()
            tempsum += 1

    # hide grid cells left over when the number of pairs does not fill the grid
    for ax in axs.flat[tempsum:]:
        ax.axis('off')
    fig.tight_layout()
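The subplot grid in `plot_clusters_in_pca_space` holds one panel per unordered cluster pair, k(k-1)/2 of them, arranged in ceil(sqrt(pairs)) rows. The hypothetical helper `pair_grid` below reproduces that arithmetic outside matplotlib so the layout can be checked for a given number of clusters; it assumes at least two labels, sorted ascending.

```python
import itertools
import math

def pair_grid(labels):
    """Grid size and (row, col) slot for every unordered pair of cluster labels.

    Assumes at least two labels, sorted ascending, so each pair (a, b) has a < b and pairs
    come out in the same order as a nested loop that skips a >= b.
    """
    pairs = list(itertools.combinations(labels, 2))  # k*(k-1)/2 pairs
    numrows = math.ceil(math.sqrt(len(pairs)))
    numcols = math.ceil(len(pairs) / numrows)
    # divmod(i, numcols) == (i // numcols, i % numcols): fill the grid row by row
    return numrows, numcols, [(pair, divmod(i, numcols)) for i, pair in enumerate(pairs)]

numrows, numcols, layout = pair_grid([0, 1, 2, 3])
print(numrows, numcols)  # 3 2
for pair, slot in layout:
    print(pair, slot)
```

With four clusters the six pairs fill a 3 x 2 grid exactly; with three clusters the three pairs land in a 2 x 2 grid, leaving one empty panel.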

In [None]:

plot_cluster_heatmaps(dict_analysis_vars, fparams, data)   
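Each heatmap panel takes one condition's block of columns from the concatenated (ROIs x conditions*window_size) array and orders the ROIs by their mean activity in the sort window, strongest first. The toy example below uses invented numbers (3 ROIs, 2 conditions, 4 frames each) to show which row order results.

```python
import numpy as np

# Hypothetical data: 3 ROIs, 2 conditions concatenated along time, 4 frames per condition.
window_size = 4
data = np.array([[0,  1,  1, 0, 0, 2, 2, 0],
                 [0,  3,  3, 0, 0, 0, 0, 0],
                 [0, -1, -1, 0, 0, 1, 1, 0]], dtype=float)
sortwindow_frames = (1, 3)  # frames (within a condition block) averaged to rank ROIs

orders = []
for k in range(data.shape[1] // window_size):
    block = data[:, k*window_size:(k+1)*window_size]  # condition k only
    mean_in_window = block[:, sortwindow_frames[0]:sortwindow_frames[1]].mean(axis=1)
    orders.append(np.argsort(mean_in_window)[::-1])   # strongest response on top
    print(k, orders[-1])
```

Since the sort is recomputed per condition, the same ROI can sit at a different height in each row of panels.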


In [None]:
plot_cluster_traces(dict_analysis_vars, fparams, data)
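The trace panels rely on `sns.lineplot`, which by default draws the mean across ROIs with a bootstrapped 95% confidence interval. If a deterministic error band is preferred, mean ± standard error of the mean (SEM) can be computed directly. This is a swapped-in alternative, not what the notebook plots, and `mean_and_sem` is a hypothetical helper.

```python
import numpy as np

def mean_and_sem(traces):
    """Mean and standard error across ROIs (rows) at each time point (columns)."""
    traces = np.asarray(traces, dtype=float)
    mean = traces.mean(axis=0)
    sem = traces.std(axis=0, ddof=1) / np.sqrt(traces.shape[0])  # needs at least 2 ROIs
    return mean, sem

# Two invented ROI traces, three time points each
traces = np.array([[0.0, 1.0, 2.0],
                   [0.0, 3.0, 2.0]])
mean, sem = mean_and_sem(traces)
print(mean, sem)
```

The band could then be drawn with `ax.plot(tvec, mean)` followed by `ax.fill_between(tvec, mean - sem, mean + sem, alpha=0.3)`.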

In [None]:
plot_clusters_in_pca_space(dict_analysis_vars, fparams, silhouette_dict)    
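A self-contained version of the per-pair embedding in `plot_clusters_in_pca_space`, using random Gaussian points as stand-ins for two clusters' retained PC scores (the cluster sizes, the 5 PCs, and the offsets are invented). It keeps the notebook's t-SNE settings, including `perplexity=sqrt(n_samples)`, which stays below the number of samples as scikit-learn's `TSNE` requires.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
# Synthetic PC scores for two hypothetical clusters (40 and 30 ROIs, 5 retained PCs)
temp1 = rng.normal(loc=2.0, size=(40, 5))
temp2 = rng.normal(loc=-2.0, size=(30, 5))
X = np.concatenate((temp1, temp2), axis=0)

# Same settings as plot_clusters_in_pca_space: perplexity scales with sqrt(n_samples)
tsne = TSNE(n_components=2, init='random', random_state=0, perplexity=np.sqrt(X.shape[0]))
Y = tsne.fit_transform(X)

# Rows keep their order, so the first len(temp1) rows of Y belong to cluster 1
Y1, Y2 = Y[:len(temp1)], Y[len(temp1):]
print(Y.shape, Y1.shape, Y2.shape)
```

t-SNE preserves local neighborhoods rather than global distances, so these panels are a qualitative check on separation; the silhouette scores remain the quantitative criterion.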