# PCA & Clustering

The goal of this code is to input activity time-series data from a neural recording and cluster the cells/rois (samples) based on the neural activity (features). Clustering is performed on trial-averaged event-related responses; data from different trial conditions are concatenated and fed into dimensionality reduction (PCA) and finally into multiple clustering algorithms. The optimal hyperparameters for PCA and clustering methods are automatically determined based on the best silhouette score. 
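
As a rough illustration of silhouette-based selection, the sketch below sweeps the number of clusters for KMeans on synthetic data and keeps the value with the highest silhouette score. The synthetic blobs and the helper name `best_n_clusters` are assumptions made for this example; the notebook itself sweeps both KMeans and spectral clustering on the PCA-reduced data and scores with the cosine metric.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score


def best_n_clusters(X, candidate_n_clusters, random_state=0):
    """Return (best_k, scores); scores maps each candidate k to its silhouette score."""
    scores = {}
    for k in candidate_n_clusters:
        labels = KMeans(n_clusters=k, n_init=10, random_state=random_state).fit_predict(X)
        scores[int(k)] = silhouette_score(X, labels)
    best_k = max(scores, key=scores.get)  # candidate with the highest score
    return best_k, scores


# synthetic stand-in for ROI data: three well-separated groups of 40 points each
X, _ = make_blobs(n_samples=120, centers=[[-10, -10], [0, 10], [10, -10]],
                  cluster_std=1.0, random_state=0)
best_k, scores = best_n_clusters(X, range(2, 7))
print(best_k, {k: round(v, 3) for k, v in scores.items()})
```

Silhouette scores lie between -1 and 1; values near 1 mean that points sit much closer to their own cluster than to the next nearest one.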

1) PCA to reduce the dimensionality of trial-averaged event-related responses (rois x time) with respect to the time dimension. Intuitive concept: PCA is performed on the time dimension, so each time point is treated as a feature/variable. Each ROI is a data point in n-dimensional space, where n is the number of samples in the event-related window. PCA finds a new set of orthogonal axes, each a linear combination of the original time points, ordered so that the first principal components (PCs) capture the most variance in the activity.
 

2) Clustering: The roi data are now characterized by a reduced set of optimized axes describing time. We now cluster using either kMeans clustering or spectral clustering.
    
    1. KMeans clustering: Works best when the data clouds are roughly spherical (Gaussian-like). The three main steps of kMeans clustering are **A)** Initialize K cluster centroids, **B)** Assign each data point to its nearest centroid, **C)** Recompute each centroid as the mean of its assigned points; steps B and C repeat until the assignments stop changing
    
    2. Spectral clustering: Does not assume any particular shape for the clusters. The three main steps of spectral clustering are **A)** build a similarity (affinity) matrix describing how close each ROI is to the other ROIs in the PCA space, **B)** perform eigendecomposition of the graph Laplacian derived from the similarity matrix, **C)** use kMeans clustering on the transformed data. 
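
The k-means iteration (often called Lloyd's algorithm) can be sketched in a few lines of NumPy: initialize K centroids, assign every point to its nearest centroid, then recompute the centroids until nothing moves. This is a minimal illustration on synthetic data, not the scikit-learn implementation the notebook uses; the function name `kmeans_lloyd` and the random initialization are assumptions made for the example.

```python
import numpy as np


def kmeans_lloyd(X, k, n_iter=100, seed=0):
    """Minimal Lloyd's algorithm; returns (labels, centroids)."""
    rng = np.random.default_rng(seed)
    # A) initialize: pick k distinct data points as starting centroids
    centroids = X[rng.choice(X.shape[0], size=k, replace=False)].astype(float)
    for _ in range(n_iter):
        # B) assign each point to its nearest centroid (Euclidean distance)
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = np.argmin(dists, axis=1)
        # C) move each centroid to the mean of its assigned points
        new_centroids = centroids.copy()
        for j in range(k):
            if np.any(labels == j):
                new_centroids[j] = X[labels == j].mean(axis=0)
        if np.allclose(new_centroids, centroids):
            break  # converged: centroids no longer move
        centroids = new_centroids
    return labels, centroids


# two synthetic, well-separated groups of 30 points each
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(-5, 0.5, size=(30, 2)), rng.normal(5, 0.5, size=(30, 2))])
labels, centroids = kmeans_lloyd(X, k=2)
print(np.bincount(labels))
```

Because step B uses straight-line distance to a centroid, k-means favors compact, roughly spherical clusters, which is why spectral clustering is offered as an alternative for other shapes.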

Prerequisites
------------------------------------

All data should reside in a parent folder. This folder's name should be the name of the session and ideally be the same as the base name of the recording file.

Data need to be run through the NAPECA event_rel_analysis code in order to generate the event_data_dict.pkl file, which contains event-related activity across different behavioral conditions for all neurons/ROIs.


How to run this code
------------------------------------

In this jupyter notebook, just run all cells in order (shift + enter).

__You can indicate specific files and parameters to include in the second cell__

Required Packages
-----------------
Python 3.7, seaborn, matplotlib, pandas, scikit-learn, statsmodels, numpy, h5py

Custom code requirements: utils

Parameters 
----------

fname_signal : string
    
    Name of file that contains roi activity traces. Must include full file name with extension. Accepted file types: .npy, .csv. IMPORTANT: data dimensions should be rois (y) by samples/time (x)

fname_events : string

    Name of file that contains event occurrences. Must include full file name with extension. Accepted file types: .pkl, .csv. Pickle (pkl) files need to contain a dictionary where keys are the condition names and the values are lists containing samples/frames for each corresponding event. Csv's should have two columns (event condition, sample). The first row are the column names. Subsequent rows contain each trial's event condition and sample in tidy format. See example in sample_data folder for formatting, or this link: https://github.com/zhounapeuw/NAPE_imaging_postprocess/raw/main/docs/_images/napeca_post_event_csv_format.png
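
To make the two accepted event layouts concrete, here is a small sketch that builds a condition dictionary, serializes it as a pickle, and writes the same events as a tidy two-column CSV. The condition names, the event values, and the header names `event` and `sample` are placeholders chosen for illustration; check the sample_data folder for the exact headers and units the loader expects.

```python
import csv
import io
import pickle

# hypothetical events: condition name -> list of event occurrences
event_dict = {"plus": [12.0, 48.5, 90.2], "minus": [30.1, 71.4]}

# layout 1: pickled dictionary (kept in memory here instead of writing a file)
pkl_bytes = pickle.dumps(event_dict)

# layout 2: tidy CSV with one row per event; header names are placeholders
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(["event", "sample"])
for condition, occurrences in event_dict.items():
    for occurrence in occurrences:
        writer.writerow([condition, occurrence])
csv_text = buf.getvalue()

# round-trip the CSV back into a dictionary to check that both layouts agree
recovered = {}
for row in csv.DictReader(io.StringIO(csv_text)):
    recovered.setdefault(row["event"], []).append(float(row["sample"]))

print(csv_text)
print(recovered == pickle.loads(pkl_bytes))
```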

fdir : string 

    Root file directory containing the raw tif, tiff, h5 files. Note: leave off the trailing slash or backslash. For example: ../napeca_post/sample_data if you cloned the repo directly

trial_start_end : list of two entries  

    Entries can be ints or floats. The first entry is the time in seconds relative to the event/ttl onset for the start of the event analysis window (negative if start time is before the event/ttl onset). Event analysis window refers to the primary time window around events that is visualized in plots. The second entry is the time in seconds for the end of the event analysis window. For example if the desired analysis window is 5.5 seconds before event onset and 8 seconds after, `trial_start_end` would be [-5.5, 8].  
    
baseline_start_end : list of two entries  

    Entries can be ints or floats. The first entry is the time in seconds relative to the event/ttl onset for the start of the baseline window (negative if start time is before the event/ttl onset). Baseline window refers to the time window relative to event onset that is used to calculate the mean and/or standard deviation for normalization. The second entry is the time in seconds (relative to event onset) for the end of the baseline window. For example, if the desired baseline window is 5.5 seconds to 0.2 seconds before event onset, `baseline_start_end` would be [-5.5, -0.2]. If a single number is supplied, the baseline window onset will default to the first entry of `trial_start_end` and the window end will be the value supplied to `baseline_start_end`.
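
The window parameters are given in seconds and later converted to sample offsets using the sampling rate `fs`. The sketch below shows one way to do that conversion, including the single-number fallback described for `baseline_start_end`. The helper name `window_to_samples` and the explicit rounding are assumptions for this example and may differ from what the notebook and `utils` do internally.

```python
import numpy as np


def window_to_samples(window_sec, fs, default_start_sec=None):
    """Convert a [start, end] window in seconds (or a single end time) to sample offsets."""
    if isinstance(window_sec, (int, float)):
        # single number: the start falls back to default_start_sec (e.g. trial_start_end[0])
        window_sec = [default_start_sec, window_sec]
    return np.round(np.array(window_sec, dtype=float) * fs).astype(int)


fs = 5  # sampling rate in Hz, matching the example parameters
trial_samp = window_to_samples([-2, 8], fs)                         # two-entry form
baseline_samp = window_to_samples(-0.2, fs, default_start_sec=-2)   # single-number form
print(trial_samp, baseline_samp)
```

Negative offsets are samples before the event onset, so a baseline that ends at -0.2 s stops one sample short of the event at 5 Hz.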

event_sort_analysis_win : list with two float entries

    Time window [a, b] in seconds to which some visualization calculations apply. For example, if the user sets flag_sort_rois to True, ROIs in heatmaps will be sorted based on the mean activity in the time window between a and b. 

pca_num_pc_method : 0 or 1 (int)

    Method for calculating the number of principal components to retain from PCA preprocessing. 0 for the bend in the scree plot, 1 for the number of PCs that account for 90% of the variance.
    Users should try both methods and see which result better fits the experiment. In some datasets the choice does not change the results.
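
The two selection rules can be compared on a toy scree curve. In this sketch the bend is found by subtracting the straight line that joins the first and last points of the curve and taking the position of the minimum, and the variance rule takes the first PC count whose cumulative explained variance exceeds the threshold. The function names and the variance percentages are made up for illustration.

```python
import numpy as np


def n_pcs_scree_bend(explained_var_pct):
    """Number of PCs at the scree-plot bend: index of the point farthest below the chord."""
    x = np.asarray(explained_var_pct, dtype=float)
    # straight line (chord) from the first to the last point of the scree curve
    chord = x[0] + (x[-1] - x[0]) / (x.size - 1) * np.arange(x.size)
    return int(np.argmin(x - chord))


def n_pcs_variance(explained_var_pct, thresh=90):
    """Smallest number of PCs whose cumulative explained variance exceeds thresh (percent)."""
    cum = np.cumsum(explained_var_pct)
    return int(np.argmax(cum > thresh) + 1)


# hypothetical explained-variance percentages for 6 PCs (sums to 100)
var_pct = np.array([50.0, 25.0, 12.0, 6.0, 4.0, 3.0])
print(n_pcs_scree_bend(var_pct), n_pcs_variance(var_pct))
```

On this toy curve the bend rule keeps 2 PCs while the 90% rule keeps 4, which shows how far apart the two choices can be.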

flag_save_figs : boolean  

    Set as True to save figures as JPG and vectorized formats. 

selected_conditions : list of strings

    Specific conditions that the user wants to analyze; needs to be exactly the name of conditions in the events CSV or pickle file

flag_plot_reward_line: boolean  

    If set to True, plot a vertical line for secondary event. Time of vertical line is dictated by the variable second_event_seconds
    
second_event_seconds: int/float
    
    Time in seconds (relative to primary event onset) for plotting a vertical dotted line indicating an optional second event occurrence 

max_n_clusters : integer
    
    Maximum number of clusters expected for clustering models. As a general rule, select this number based on the maximum expected number of clusters in the data + ~5. Keep in mind that larger values will increase processing time
    
possible_n_nearest_neighbors : array of integers
    
    Candidate values of n_neighbors for spectral clustering. For each value n in possible_n_nearest_neighbors, every data point is connected to its n nearest neighbors to create the connectivity graph (affinity matrix).
    
    A common rule of thumb is to center the values of possible_n_nearest_neighbors around sqrt(N), where N is the number of samples (ROIs) in your dataset. ( https://towardsdatascience.com/a-simple-introduction-to-k-nearest-neighbors-algorithm-b3519ed98e )
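
One possible way to turn the sqrt(N) rule of thumb into a sweep is sketched below: generate a handful of integer candidates spread around sqrt(N) and clip them to a valid range. The helper name `neighbor_sweep`, the number of candidates, and the spread are all assumptions for illustration, and the result should still be tuned for each dataset.

```python
import numpy as np


def neighbor_sweep(n_rois, n_values=5, spread=0.5):
    """Hypothetical helper: integer candidates spanning sqrt(N) * (1 -/+ spread)."""
    center = np.sqrt(n_rois)
    candidates = np.linspace(center * (1 - spread), center * (1 + spread), n_values)
    # keep at least 2 neighbors, fewer neighbors than ROIs, and drop duplicates
    candidates = np.clip(np.round(candidates), 2, n_rois - 1).astype(int)
    return np.unique(candidates)


print(neighbor_sweep(100))   # candidates centered on sqrt(100) = 10
print(neighbor_sweep(4813))  # candidates centered on sqrt(4813), roughly 69
```

The returned array can be assigned directly to `possible_n_nearest_neighbors`.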

In [None]:
import pickle
import math
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC, SVR, LinearSVC
from sklearn.metrics import accuracy_score, silhouette_score, adjusted_rand_score, silhouette_samples
from sklearn.cluster import AgglomerativeClustering, SpectralClustering, KMeans
from sklearn.model_selection import KFold, LeaveOneOut, train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.kernel_ridge import KernelRidge
from sklearn import linear_model
from sklearn.manifold import TSNE
import scipy.stats as stats
import statsmodels.api as sm
import statsmodels.formula.api as smf
from patsy import (ModelDesc, EvalEnvironment, Term, EvalFactor, LookupFactor, dmatrices, INTERCEPT)
from statsmodels.distributions.empirical_distribution import ECDF
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.colorbar as colorbar
import sys
import os
import re
import glob
import seaborn as sns
import matplotlib.pyplot as plt
#important for text to be detected when importing saved figures into illustrator
import matplotlib
matplotlib.rcParams['pdf.fonttype']=42
matplotlib.rcParams['ps.fonttype']=42
plt.rcParams["font.family"] = "Arial"

import numpy as np
import pandas as pd
import json
import utils

In [None]:
"""
USER-DEFINED VARIABLES
"""

fname_signal = 'VJ_OFCVTA_7_260_D6_neuropil_corrected_signals_15_50_beta_0.8.npy'   # name of your npy or csv file that contains activity signals
fname_events = 'event_times_VJ_OFCVTA_7_260_D6_trained.csv' # name of your pickle or csv file that contains behavioral event times (in seconds)
# fdir specifies the root path of the data. Currently, the abspath phrase points to sample data from the repo.
# To specify a path that is on your local computer, use this string format: r'your_root_path', where you should copy/paste
# your path between the single quotes (important to keep the r to render as a complete raw string). See example below:
# r'C:\Users\stuberadmin\Documents\GitHub\NAPE_imaging_postprocess\napeca_post\sample_data' 
fdir = os.path.abspath('./sample_data/VJ_OFCVTA_7_260_D6') # for an explicit path, eg. r'C:\2pData\Vijay data\VJ_OFCVTA_7_D8_trained'
fs = 5 # sampling rate of activity data

# trial extraction info
trial_start_end = [-2, 8] # primary visualization window relative to event onset; [start, end] times (in seconds) 
baseline_start_end = [-2, -0.2] # baseline window (in seconds) for performing baseline normalization. either a list [start, end] or an int/float (see details in markdown above); I set this to -0.2 to be safe I'm not grabbing a sample that includes the event
event_sort_analysis_win = [0, 5] # time window (in seconds)

pca_num_pc_method = 0 # 0 for bend in scree plot, 1 for num PCs that account for 90% variance

# variables for clustering
max_n_clusters = 10 # from Vijay: Maximum number of clusters expected. This should be based on the number of functional neuron groups you expect + ~3. In your data, 
# might be worth increasing this, but it will take more time to run.

'''In spectral clustering: get n nearest neighbors for each data point and create connectivity graph (affinity matrix)'''
possible_n_nearest_neighbors = np.arange(1, 10) #np.array([3,5,10]) # This should be selected for each dataset
# appropriately. As a reference point, for a dataset with 4813 neurons, nearest neighbors of [30,40,30,50,60] provided a good sweep of the
# parameter space. But it will need to be changed for other data.

# optional arguments
selected_conditions = None # set to a list of strings if you want to filter specific conditions to analyze; eg. ['plus', 'minus']
flag_plot_reward_line = False # if there's a second event that happens after the main event, it can be indicated if set to True; timing is dictated by the next variables below
second_event_seconds = 1 # time in seconds
flag_save_figs = True # set to true if you want to save plots
heatmap_cmap_scaling = 1 # set to lower value if colormap range is too large

# set to True if the data you are loading in already has data from different conditions concatenated together
# do not set to True if data come directly from suite2p or sima/napeca preprocessing!
group_data = False
group_data_conditions = ['cs_plus', 'cs_minus']

In [None]:
# declare paths and names
fname = os.path.split(fdir)[1]
signals_fpath = os.path.join(fdir, fname_signal)
save_dir = os.path.join(fdir, 'event_rel_analysis')
utils.check_exist_dir(save_dir); # make the save directory

signals = utils.load_signals(signals_fpath)

trial_start_end_sec = np.array(trial_start_end) # trial windowing in seconds relative to ttl-onset/trial-onset, in seconds
if type(baseline_start_end) is list:
    baseline_start_end_sec = np.array(baseline_start_end)
elif isinstance(baseline_start_end, (int, float)):
    baseline_start_end_sec = np.array([trial_start_end_sec[0], baseline_start_end])
    
baseline_begEnd_samp = baseline_start_end_sec*fs
baseline_svec = (np.arange(baseline_begEnd_samp[0], baseline_begEnd_samp[1] + 1, 1) -
                        baseline_begEnd_samp[0]).astype('int')

if group_data:
    conditions = group_data_conditions

    if selected_conditions:
        conditions = selected_conditions

    num_conditions = len(conditions)

    populationdata = np.squeeze(np.apply_along_axis(utils.zscore_, -1, signals, baseline_svec))

    num_samples_trial = int(populationdata.shape[-1]/len(group_data_conditions))
    tvec = np.round(np.linspace(trial_start_end_sec[0], trial_start_end_sec[1], num_samples_trial), 2)

else:

    events_file_path = os.path.join(fdir, fname_events)

    
    glob_event_files = glob.glob(events_file_path) # look for a file in specified directory
    if not glob_event_files:
        # stop here with a clear error instead of failing on the index lookup below
        raise FileNotFoundError(f'{events_file_path} not detected. Please check if path is correct.')
    if 'csv' in glob_event_files[0]:
        event_times = utils.df_to_dict(glob_event_files[0])
    elif any(x in glob_event_files[0] for x in ['pkl', 'pickle']):
        with open(glob_event_files[0], "rb") as f: # context manager closes the file after loading
            event_times = pickle.load(f, fix_imports=True, encoding='latin1') # latin1 b/c original pickle made in python 2
    event_frames = utils.dict_time_to_samples(event_times, fs)


    # identify conditions to analyze
    all_conditions = event_frames.keys()
    conditions = [ condition for condition in all_conditions if len(event_frames[condition]) > 0 ] # keep conditions that have events

    conditions.sort()
    if selected_conditions:
        conditions = selected_conditions

    num_conditions = len(conditions)

    ### define trial timing

    # convert times to samples and get sample vector for the trial 
    trial_begEnd_samp = trial_start_end_sec*fs # turn trial start/end times to samples
    trial_svec = np.arange(trial_begEnd_samp[0], trial_begEnd_samp[1])
    # calculate time vector for plot x axes
    num_samples_trial = len( trial_svec )
    tvec = np.round(np.linspace(trial_start_end_sec[0], trial_start_end_sec[1], num_samples_trial+1), 2)


    """
    MAIN data processing function to extract event-centered data

    extract and save trial data, 
    saved data are in the event_rel_analysis subfolder, a pickle file that contains the extracted trial data
    """
    data_dict = utils.extract_trial_data(signals, tvec, trial_begEnd_samp, event_frames, 
                                        conditions, baseline_start_end_samp = baseline_begEnd_samp, save_dir=None)


    #### concatenate data across trial conditions

    # concatenates trial-averaged data across conditions along the time axis; populationdata dimensions are ROI by time (conditions are appended)
    populationdata = np.concatenate([data_dict[condition]['ztrial_avg_data'] for condition in conditions], axis=1)
    np.save(os.path.join(save_dir, 'cluster_pop_data.npy'), populationdata)
    
    # remove rows with nan values
    nan_rows = np.unique(np.where(np.isnan(populationdata))[0])
    if nan_rows.size != 0:
        populationdata = np.delete(populationdata, obj=nan_rows, axis=0)
        print('Some ROIs contain nan in tseries!')

cmax = np.nanmax(np.abs([np.nanmin(populationdata), np.nanmax(populationdata)])) # Maximum colormap value. 

In [None]:
def standardize_plot_graphics(ax):
    """
    Standardize plots
    """
    [i.set_linewidth(0.5) for i in ax.spines.values()] # change the width of spines for both axes (.values() replaces the Python 2 .itervalues())
    ax.spines['right'].set_visible(False) # remove the top and right spines
    ax.spines['top'].set_visible(False)
    return ax

def fit_regression(x, y):
    """
    Fit a linear regression with ordinary least squares
    """
    lm = sm.OLS(y, sm.add_constant(x)).fit() # add a column of 1s for intercept before fitting
    x_range = sm.add_constant(np.array([x.min(), x.max()]))
    x_range_pred = lm.predict(x_range)
    return lm.pvalues[1], lm.params[1], x_range[:,1], x_range_pred, lm.rsquared

def CDFplot(x, ax, **kwargs):
    """
    Create a cumulative distribution function (CDF) plot
    """
    x = np.array(x)
    ix= np.argsort(x)
    ax.plot(x[ix], ECDF(x)(x)[ix], **kwargs)
    return ax


def fit_regression_and_plot(x, y, ax, plot_label='', color='k', linecolor='r', markersize=3,
                            show_pval=True):
    """
    Fit a linear regression model with ordinary least squares and visualize the results
    """
    #linetype is a string like 'bo'
    pvalue, slope, temp, temppred, R2 = fit_regression(x, y)   
    if show_pval:
        plot_label = '%s p=%.2e\nr=%.3f'% (plot_label, pvalue, np.sign(slope)*np.sqrt(R2))
    else:
        plot_label = '%s r=%.3f'% (plot_label, np.sign(slope)*np.sqrt(R2))
    ax.scatter(x, y, color=color, label=plot_label, s=markersize)
    ax.plot(temp, temppred, color=linecolor)
    return ax, slope, pvalue, R2


def make_silhouette_plot(X, cluster_labels):

    """
    Create silhouette plot for the clusters
    """
    
    n_clusters = len(set(cluster_labels))
    
    fig, ax = plt.subplots(1, 1)
    fig.set_size_inches(4, 4)

    # The silhouette coefficient can range from -1 to 1; the x axis is
    # limited to [-0.4, 1], which covers the values typically seen here
    ax.set_xlim([-0.4, 1])
    # The (n_clusters+1)*10 is for inserting blank space between silhouette
    # plots of individual clusters, to demarcate them clearly.
    ax.set_ylim([0, len(X) + (n_clusters + 1) * 10])
    silhouette_avg = silhouette_score(X, cluster_labels, metric='cosine')

    # Compute the silhouette scores for each sample
    sample_silhouette_values = silhouette_samples(X, cluster_labels, metric='cosine')

    y_lower = 10
    for i in range(n_clusters):
        # Aggregate the silhouette scores for samples belonging to
        # cluster i, and sort them
        ith_cluster_silhouette_values = \
            sample_silhouette_values[cluster_labels == i]

        ith_cluster_silhouette_values.sort()

        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        color = colors_for_cluster[i]
        ax.fill_betweenx(np.arange(y_lower, y_upper),
                          0, ith_cluster_silhouette_values,
                          facecolor=color, edgecolor=color, alpha=0.9)

        # Label the silhouette plots with their cluster numbers at the middle
        ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i+1))

        # Compute the new y_lower for next plot
        y_lower = y_upper + 10  # 10 for the 0 samples

    ax.set_title("The silhouette plot for the various clusters.")
    ax.set_xlabel("The silhouette coefficient values")
    ax.set_ylabel("Cluster label")

    # The vertical line for average silhouette score of all the values
    ax.axvline(x=silhouette_avg, color="red", linestyle="--")

    ax.set_yticks([])  # Clear the yaxis labels / ticks
    ax.set_xticks([-0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8, 1])

In [None]:
# variables for plotting

# calculated variables
window_size = int(populationdata.shape[1]/num_conditions) # Total number of frames in a trial window; needed to split processed concatenated data
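# convert the sort window (seconds relative to event onset) to frame indices within the trial window,
# offsetting by the trial start so that column 0 of populationdata corresponds to trial_start_end[0]
sortwindow_frames = [int(np.round((time - trial_start_end[0])*fs)) for time in event_sort_analysis_win]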
sortresponse = np.argsort(np.mean(populationdata[:,sortwindow_frames[0]:sortwindow_frames[1]], axis=1))[::-1]
# sortresponse corresponds to an ordering of the neurons based on their average response in the sortwindow


In [None]:
fig, axs = plt.subplots(2,num_conditions,figsize=(3*2,3*2), sharex='all', sharey='row')

# loop through conditions and plot heatmaps of trial-avged activity
for t in range(num_conditions):
    
    if num_conditions == 1:
        ax = axs[0]
    else:
        ax = axs[0,t]

    plot_extent = [tvec[0], tvec[-1], populationdata.shape[0], 0 ] # set plot limits as [time_start, time_end, num_rois, 0]
    im = utils.subplot_heatmap(ax, ' ', populationdata[sortresponse, t*window_size: (t+1)*window_size], 
                               clims = [-cmax, cmax], extent_=plot_extent)
    ax.set_title(conditions[t])
    
    ax.axvline(0, linestyle='--', color='k', linewidth=0.5)   
    if flag_plot_reward_line:
        ax.axvline(second_event_seconds, linestyle='--', color='k', linewidth=0.5) 
    
    ### roi-avg tseries 
    if num_conditions == 1:
        ax = axs[1]
    else:
        ax = axs[1,t]
    mean_ts = np.mean(populationdata[sortresponse, t*window_size:(t+1)*window_size], axis=0)
    stderr_ts = np.std(populationdata[sortresponse, t*window_size:(t+1)*window_size], axis=0)/np.sqrt(populationdata.shape[0])
    ax.plot(tvec, mean_ts)
    shade = ax.fill_between(tvec, mean_ts - stderr_ts, mean_ts + stderr_ts, alpha=0.2) # this plots the shaded error bar
    ax.axvline(0, linestyle='--', color='k', linewidth=0.5)  
    if flag_plot_reward_line:
        ax.axvline(second_event_seconds, linestyle='--', color='k', linewidth=0.5)   
    ax.set_xlabel('Time from event (s)')   
 
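    if t==0:
        # label the heatmap (top row) and the mean trace (bottom row) of the first column
        (axs[0] if num_conditions == 1 else axs[0,t]).set_ylabel('Neurons')
        ax.set_ylabel('Mean norm. fluor.')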

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
cbar = fig.colorbar(im, ax = axs, shrink = 0.7)
cbar.ax.set_ylabel('Heatmap Z-Score Activity', fontsize=13);

# if flag_save_figs:
#     fig.savefig(os.path.join(save_dir, 'trial_avg_pop_responses.pdf'), format='pdf')
#     fig.savefig(os.path.join(save_dir, 'trial_avg_pop_responses.png'), format='png', dpi=300)

## Do PCA to reduce dimensionality in the time-domain

PCA: A linear algebra-based method that finds a new set of axes (i.e. variables) that best explain the variability of a dataset. Each new axis is a linear combination of the original axes and captures as much of the remaining variability in the data as possible while staying orthogonal to (uncorrelated with) the other new axes.

In this case, we are finding a new orthogonal parameter space in which most of the explained variance is concentrated in the first few axes.
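
A small synthetic example can make this concrete. The sketch below builds ROI-by-time data from two temporal motifs plus noise, runs scikit-learn's `PCA` with time points as features, and checks that the resulting axes are orthonormal and that two of them capture nearly all of the variance. The motifs, sizes, and noise level are invented for illustration and are not taken from the sample data.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
n_rois, n_timepoints = 60, 40
t = np.linspace(-2, 6, n_timepoints)

# synthetic ROI-by-time data: two temporal motifs with random per-ROI weights, plus noise
motifs = np.vstack([np.exp(-(t - 1) ** 2), np.tanh(t)])
weights = rng.normal(size=(n_rois, 2))
data = weights @ motifs + 0.05 * rng.normal(size=(n_rois, n_timepoints))

pca = PCA(n_components=5)
scores = pca.fit_transform(data)   # (n_rois, 5): each ROI's coordinates on the new axes
components = pca.components_       # (5, n_timepoints): each PC is a weighting over time points

# the new axes are orthonormal, and the first two capture nearly all of the variance
print(np.allclose(components @ components.T, np.eye(5)))
print(pca.explained_variance_ratio_[:2].sum())
```

Each row of `components` can be plotted against time, which is how the retained PCs are visualized further down in the notebook.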

In [None]:
def num_pc_explained_var(explained_var, explained_var_thresh=90):
    """
    Select pcs for those that capture more than threshold amount of variability in the data
    """
    cum_sum = 0
    for idx, PC_var in enumerate(explained_var):
        cum_sum += PC_var
        if cum_sum > explained_var_thresh:
            return idx+1

In [None]:
load_savedpca_or_dopca = 'dopca'
# Select 'dopca' for doing PCA on the data. Select 'savedpca' for loading my previous results

# perform PCA across time
if load_savedpca_or_dopca == 'dopca':
    pca = PCA(n_components=min(populationdata.shape[0],populationdata.shape[1]), whiten=True)
    pca.fit(populationdata) 
    with open(os.path.join(fdir, 'pcaresults.pickle'), 'wb') as f:
        pickle.dump(pca, f)
elif load_savedpca_or_dopca == 'savedpca':
    with open(os.path.join(fdir, 'OFCCaMKII_pcaresults.pickle'), 'rb') as f:
        pca = pickle.load(f)

# pca across time
transformed_data = pca.transform(populationdata)
# transformed data: each ROI is now described by its scores (coordinates) along the principal components,
# where each PC is a linear combination of the original time points
# np.save(os.path.join(save_dir, "transformed_data.npy"),transformed_data)

# grab eigenvectors (pca.components_); linear combination of original axes
pca_vectors = pca.components_ 
print(f'Number of PCs = {pca_vectors.shape[0]}')

# Number of PCs to be kept is defined as the number at which the 
# scree plot bends. This is done by subtracting from the scree plot
# the straight line joining (1, variance explained by first PC) and
# (num of PCs, variance explained by the last PC) and finding the 
# number of components just below the minimum of this detrended plot
x = 100*pca.explained_variance_ratio_ # eigenvalue ratios
xprime = x - (x[0] + (x[-1]-x[0])/(x.size-1)*np.arange(x.size))

# define number of PCs
num_retained_pcs_scree = np.argmin(xprime)
num_retained_pcs_var = num_pc_explained_var(x, 90)
if pca_num_pc_method == 0:
    num_retained_pcs = num_retained_pcs_scree
elif pca_num_pc_method == 1:
    num_retained_pcs = num_retained_pcs_var

In [None]:

print(f'Number of PCs to keep = {num_retained_pcs}')

# plot PCA plot
fig, ax = plt.subplots(figsize=(2,2))
ax.plot(np.arange(pca.explained_variance_ratio_.shape[0]).astype(int)+1, x, 'k')
ax.set_ylabel('Percentage of\nvariance explained')
ax.set_xlabel('PC number')
ax.axvline(num_retained_pcs, linestyle='--', color='k', linewidth=0.5)
ax.set_title('Scree plot')
[i.set_linewidth(0.5) for i in ax.spines.values()]
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

fig.subplots_adjust(left=0.3)
fig.subplots_adjust(right=0.98)
fig.subplots_adjust(bottom=0.25)
fig.subplots_adjust(top=0.9)

# if flag_save_figs:
#     fig.savefig(os.path.join(save_dir, 'scree_plot.png'), format='png', dpi=300)



### plot retained principal components
numcols = 2.0
fig, axs = plt.subplots(int(np.ceil(num_retained_pcs/numcols)), int(numcols), sharey='all',
                        figsize=(2.2*numcols, 2.2*int(np.ceil(num_retained_pcs/numcols))))
for pc in range(num_retained_pcs):
    ax = axs.flat[pc]
    for k, tempkey in enumerate(conditions):
        ax.plot(tvec, pca_vectors[pc, k*window_size:(k+1)*window_size],
                label='PC %d: %s'%(pc+1, tempkey))

    ax.axvline(0, linestyle='--', color='k', linewidth=1)
    ax.set_title(f'PC {pc+1}')

    # labels
    if pc == 0:
        ax.set_xlabel('Time from cue (s)')
        ax.set_ylabel( 'PCA weights')


fig.tight_layout()
for ax in axs.flat[num_retained_pcs:]:
    ax.set_visible(False)

plt.tight_layout()
    
# if flag_save_figs:
#     fig.savefig(os.path.join(save_dir, 'PCA.png'), format='png', dpi=300)

## Clustering

In [None]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

In [None]:
# calculate optimal number of clusters and nearest neighbors using silhouette scores
min_clusters = np.min([max_n_clusters+1, int(populationdata.shape[0])]) # exclusive upper bound; silhouette scores need fewer clusters than ROIs
possible_n_clusters = np.arange(2, min_clusters) #This requires a minimum of 2 clusters.
# When the data contain no clusters at all, it will be quite visible when inspecting the two obtained clusters, 
# as the responses of the clusters will be quite similar. This will also be visible when plotting the data in
# the reduced dimensionality PC space (done below).

possible_clustering_models = np.array(["Spectral", "Kmeans"])
silhouette_scores = np.nan*np.ones((possible_n_clusters.size,
                                    possible_n_nearest_neighbors.size,
                                    possible_clustering_models.size))

# loop through iterations of clustering params
for n_clustersidx, n_clusters in enumerate(possible_n_clusters):
    kmeans = KMeans(n_clusters=n_clusters, random_state=0) #tol=toler_options
    for nnidx, nn in enumerate(possible_n_nearest_neighbors):
        spectral = SpectralClustering(n_clusters=n_clusters, affinity='nearest_neighbors', n_neighbors=nn, random_state=0)
        models = [spectral,kmeans]
        for modelidx,model in enumerate(models):
            model.fit(transformed_data[:,:num_retained_pcs])
            silhouette_scores[n_clustersidx, nnidx, modelidx] = silhouette_score(transformed_data[:,:num_retained_pcs],
                                                                    model.labels_,
                                                                    metric='cosine')
            if modelidx == 0:
                print(f'Done with numclusters = {n_clusters}, num nearest neighbors = {nn}: score = {silhouette_scores[n_clustersidx, nnidx, modelidx]:.3f}')
            else:
                print(f'Done with numclusters = {n_clusters}, score = {silhouette_scores[n_clustersidx, nnidx, modelidx]:.3f}')
print(silhouette_scores.shape)
print('Done with model fitting')

silhouette_dict = {}
silhouette_dict['possible_clustering_models'] = possible_clustering_models
silhouette_dict['num_retained_pcs'] = num_retained_pcs
silhouette_dict['possible_n_clusters'] = possible_n_clusters
silhouette_dict['possible_n_nearest_neighbors'] = possible_n_nearest_neighbors
silhouette_dict['silhouette_scores'] = silhouette_scores
silhouette_dict['shape'] = 'cluster_nn'
#with open(os.path.join(save_dir,'silhouette_scores.pkl'), 'wb') as f:
#    pickle.dump(silhouette_dict, f)

## Recluster with optimal params

In [None]:
# Identify optimal parameters from the above parameter space
temp = np.where(silhouette_dict['silhouette_scores']==np.nanmax(silhouette_dict['silhouette_scores']))

n_clusters = silhouette_dict['possible_n_clusters'][temp[0][0]]
n_nearest_neighbors = silhouette_dict['possible_n_nearest_neighbors'][temp[1][0]]
num_retained_pcs = silhouette_dict['num_retained_pcs']
method = silhouette_dict['possible_clustering_models'][temp[2][0]]
print(f"clusters: {n_clusters}, nearest neighbors: {n_nearest_neighbors}, PCs: {num_retained_pcs}, method: {method}")

# Redo clustering with these optimal parameters
model = None
if method == 'Spectral':
    model = SpectralClustering(n_clusters=n_clusters,
                           affinity='nearest_neighbors',
                           n_neighbors=n_nearest_neighbors,
                           random_state=0)
else:
    model = KMeans(n_clusters=n_clusters, random_state=0)


# model = AgglomerativeClustering(n_clusters=9,
#                                 affinity='l1',
#                                 linkage='average')

model.fit(transformed_data[:,:num_retained_pcs])

temp = silhouette_score(transformed_data[:,:num_retained_pcs], model.labels_, metric='cosine')

print(f'Average silhouette score = {temp:.3f}')

# Save this optimal clustering model.
# with open(os.path.join(save_dir, 'clusteringmodel.pickle'), 'wb') as f:
#     pickle.dump(model, f)

In [None]:
# Since the clustering labels are arbitrary, I rename the clusters so that the first cluster will have the most
# positive response and the last cluster will have the most negative response.
def reorder_clusters(data, sort_win_frames, rawlabels):
    uniquelabels = list(set(rawlabels))
    responses = np.nan*np.ones((len(uniquelabels),))
    for l, label in enumerate(uniquelabels):
        responses[l] = np.mean(data[rawlabels==label, sort_win_frames[0]:sort_win_frames[1]])
    temp = np.argsort(responses).astype(int)[::-1]
    temp = np.array([np.where(temp==a)[0][0] for a in uniquelabels])
    outputlabels = np.array([temp[a] for a in list(np.digitize(rawlabels, uniquelabels)-1)])
    return outputlabels
newlabels = reorder_clusters(populationdata, sortwindow_frames, model.labels_)

# Create a new variable containing all unique cluster labels
uniquelabels = list(set(newlabels))

# np.save(os.path.join(summarydictdir, dt_string+'_'+ clusterkey+'_' + 'spectral_clusterlabels.npy'), newlabels)

colors_for_cluster = plt.cm.viridis(np.linspace(0,1,len(uniquelabels)+3))

np.save(os.path.join(save_dir, 'cluster_labels.npy'), newlabels)

In [None]:
# Plot z-score activity for each cluster over time
sortwindow = [15, 100]

fig, axs = plt.subplots(len(conditions),len(uniquelabels),
                        figsize=(2*len(uniquelabels),2*len(conditions)))
if len(axs.shape) == 1:
    axs = np.expand_dims(axs, axis=0)

numroisincluster = np.nan*np.ones((len(uniquelabels),))

for c, cluster in enumerate(uniquelabels):
    for k, tempkey in enumerate(conditions):
        temp = populationdata[np.where(newlabels==cluster)[0], k*window_size:(k+1)*window_size]
        numroisincluster[c] = temp.shape[0]
        ax=axs[k, cluster]
        sortresponse = np.argsort(np.mean(temp[:,sortwindow[0]:sortwindow[1]], axis=1))[::-1]
        
        plot_extent = [tvec[0], tvec[-1], len(sortresponse), 0 ]
        im = utils.subplot_heatmap(ax, ' ', temp[sortresponse], 
                                   clims = [-cmax*heatmap_cmap_scaling, cmax*heatmap_cmap_scaling], extent_=plot_extent)

        ax.grid(False)
        if k!=len(conditions)-1:
            ax.set_xticks([])  # only the bottom row keeps its time-axis ticks
        ax.set_yticks([])
        ax.axvline(0, linestyle='--', color='k', linewidth=0.5)
        if flag_plot_reward_line:
            ax.axvline(second_event_seconds, linestyle='--', color='k', linewidth=0.5)
        if cluster==0:
            ax.set_ylabel('%s'%(tempkey))
    axs[0, cluster].set_title('Cluster %d\n(n=%d)'%(cluster+1, numroisincluster[c]))
    
fig.text(0.5, 0.05, 'Time from cue (s)', fontsize=12,
         horizontalalignment='center', verticalalignment='center', rotation='horizontal')
fig.tight_layout()

fig.subplots_adjust(left=0.03, right=0.93, bottom=0.2, top=0.83,
                    wspace=0.1, hspace=0.1)

cbar = fig.colorbar(im, ax = axs, shrink = 0.7)
cbar.ax.set_ylabel('Z-Score Activity', fontsize=13);

if flag_save_figs:
    plt.savefig(os.path.join(save_dir, 'cluster_heatmap.png'))
    plt.savefig(os.path.join(save_dir, 'cluster_heatmap.pdf'))
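
The heatmap code relies on `populationdata` holding the trial-averaged responses of all conditions concatenated along the time axis, so that condition `k` occupies columns `k*window_size` to `(k+1)*window_size`. A minimal sketch of that indexing on made-up data follows; `toy_population`, `toy_labels`, and `toy_conditions` are illustrative names, not the notebook's variables.

```python
import numpy as np

window_size = 3
toy_conditions = ['cond_a', 'cond_b']
# 4 ROIs x (2 conditions * 3 frames), conditions concatenated along columns
toy_population = np.arange(4 * 6, dtype=float).reshape(4, 6)
toy_labels = np.array([0, 1, 0, 1])

cluster_means = {}
for cluster in np.unique(toy_labels):
    rows = np.where(toy_labels == cluster)[0]
    for k, cond in enumerate(toy_conditions):
        # Rows select the ROIs of this cluster, columns select this condition's window
        block = toy_population[rows, k * window_size:(k + 1) * window_size]
        cluster_means[(int(cluster), cond)] = block.mean(axis=0)

print(cluster_means[(0, 'cond_a')])  # [6. 7. 8.]
```

Each `block` has shape (ROIs in cluster, `window_size`), which is what gets passed to the heatmap for one subplot.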

In [None]:
# Map each column index (frame) to its time stamp so the melted DataFrame has time on the x-axis
tvec_convert_dict = dict(enumerate(tvec))

In [None]:
# Plot the ROI-averaged normalized fluorescence of each cluster over time, one line per condition
# squeeze=False keeps axs indexable even when there is only one cluster
fig, axs = plt.subplots(1,len(uniquelabels),
                        figsize=(3*len(uniquelabels),1.5*len(conditions)),
                        squeeze=False)
axs = axs[0]

for c, cluster in enumerate(uniquelabels):

    for k, tempkey in enumerate(conditions):
        temp = populationdata[np.where(newlabels==cluster)[0], k*window_size:(k+1)*window_size]
        numroisincluster[c] = temp.shape[0]
        sortresponse = np.argsort(np.mean(temp[:,sortwindow[0]:sortwindow[1]], axis=1))[::-1]
        # No hue variable is assigned, so no palette is passed; each condition takes the next color in the cycle
        sns.lineplot(x="variable", y="value",data = pd.DataFrame(temp[sortresponse]).rename(columns=tvec_convert_dict).melt(),
                    ax = axs[cluster],
                    label = tempkey,legend = False)
        axs[cluster].grid(False)
        axs[cluster].axvline(0, linestyle='--', color='k', linewidth=0.5)
        axs[cluster].spines['right'].set_visible(False)
        axs[cluster].spines['top'].set_visible(False)
        if cluster==0:
            axs[cluster].set_ylabel('Normalized fluorescence')
        else:
            axs[cluster].set_ylabel('')
        axs[cluster].set_xlabel('')
    axs[cluster].set_title('Cluster %d\n(n=%d)'%(cluster+1, numroisincluster[c]))
# Draw the condition legend once, on the first subplot
axs[0].legend()
fig.text(0.5, 0.05, 'Time from cue (s)', fontsize=12,
         horizontalalignment='center', verticalalignment='center', rotation='horizontal')
fig.tight_layout()

fig.subplots_adjust(left=0.03, right=0.93, bottom=0.2, top=0.83,
                    wspace=0.1, hspace=0.1)

if flag_save_figs:
    plt.savefig(os.path.join(save_dir, 'cluster_roiAvg_traces.png'))
    plt.savefig(os.path.join(save_dir, 'cluster_roiAvg_traces.pdf'))

In [None]:
# Run t-SNE separately on every pair of clusters to visualize how well they separate
num_clusterpairs = len(uniquelabels)*(len(uniquelabels)-1)//2

numrows = int(np.ceil(num_clusterpairs**0.5))
numcols = int(np.ceil(num_clusterpairs/numrows))
# figsize is (width, height): columns set the width and rows set the height.
# squeeze=False keeps axs 2D for any grid shape.
fig, axs = plt.subplots(numrows, numcols, figsize=(3*numcols, 3*numrows), squeeze=False)

tempsum = 0
for c1, cluster1 in enumerate(uniquelabels):
    for c2, cluster2 in enumerate(uniquelabels):
        if cluster1>=cluster2:
            continue

        temp1 = transformed_data[np.where(newlabels==cluster1)[0], :num_retained_pcs]
        temp2 = transformed_data[np.where(newlabels==cluster2)[0], :num_retained_pcs]
        X = np.concatenate((temp1, temp2), axis=0)

        tsne = TSNE(n_components=2, init='random',
                    random_state=0, perplexity=np.sqrt(X.shape[0]))
        Y = tsne.fit_transform(X)

        # Fill the subplot grid row by row
        ax = axs[tempsum//numcols, tempsum%numcols]
        # The first len(temp1) rows of Y belong to cluster1, the remaining rows to cluster2
        ax.scatter(Y[:len(temp1),0],
                   Y[:len(temp1),1],
                   color=colors_for_cluster[cluster1], label='Cluster %d'%(cluster1+1), alpha=1)
        ax.scatter(Y[len(temp1):,0],
                   Y[len(temp1):,1],
                   color=colors_for_cluster[cluster2+3], label='Cluster %d'%(cluster2+1), alpha=1)

        ax.set_xlabel('t-SNE dimension 1')
        ax.set_ylabel('t-SNE dimension 2')
        ax.legend()
        tempsum += 1

# Adjust the layout once, after all pairs have been plotted
fig.tight_layout()
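
The grid bookkeeping in the t-SNE cell can be checked without any plotting. The sketch below assumes at least two clusters and uses a hypothetical helper, `pair_grid`, to compute the near-square grid shape and the (row, column) slot of each cluster pair in the same order as the nested loop above.

```python
import math
from itertools import combinations

def pair_grid(n_clusters):
    """Illustrative helper: grid shape and (row, col) slot for every cluster pair."""
    # combinations yields (0, 1), (0, 2), ... in the same order as the nested loop
    pairs = list(combinations(range(n_clusters), 2))
    numrows = math.ceil(math.sqrt(len(pairs)))
    numcols = math.ceil(len(pairs) / numrows)
    # divmod fills the grid row by row
    slots = {pair: divmod(i, numcols) for i, pair in enumerate(pairs)}
    return numrows, numcols, slots

numrows, numcols, slots = pair_grid(4)
print(numrows, numcols)  # 3 2
print(slots[(2, 3)])     # (2, 1)
```

With four clusters there are six pairs, which fill a 3 x 2 grid exactly; for other cluster counts some subplots in the last row may stay empty.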

    