# This notebook demonstrates the use of Sparse Component Analysis (SCA) using a center-out reaching dataset

Data are trial-averaged firing rates of individual neurons collected from motor cortex (primary motor and dorsal premotor) of a non-human primate.

For this task, the monkey began each trial by touching a central touch-point. A peripheral target was shown, and after a variable, unpredictable delay period, a go cue was delivered. After capturing the peripheral target, the monkey received a juice reward, and returned his hand to the touch-point to begin the next trial (see [Lara et al., 2018](https://pubmed.ncbi.nlm.nih.gov/30132759/))

Data have been aligned to target onset, outward movement onset, and return reach onset.

'data' is a Condition x Neuron x Time tensor of trial-averaged firing rates.


## Import various packages

In [None]:
import numpy as np
from scipy import io

import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt

from plotly.subplots import make_subplots

from sca.models import SCA, WeightedPCA
from sca.util import get_sample_weights, get_accuracy

from sklearn.decomposition import FactorAnalysis
from sklearn.decomposition import FastICA
from sklearn.decomposition import SparsePCA

from nimfa import Nmf, Snmf

## Local path to the data directory
    - set `load_folder` (below) to the path wherever you've saved the data

In [None]:
# Directory holding the .mat file -- edit to point at your local copy
# (e.g. '../datasets/').
load_folder = '../data/'

# Which animal's recording to load
monkName = 'Balboa'

# scipy.io.loadmat returns a dict mapping MATLAB variable names to arrays
data_file = load_folder + monkName + '_outAndBack_redYellowConds_rawRates.mat'
data = io.loadmat(data_file)


## load data

In [None]:
# Alternative dataset file, kept for reference:
# data = io.loadmat(local_path + 'monkeyB_reaching.mat')

# The file's 'data' variable holds the PSTHs as a C x N x T array
# (conditions x neurons x time).
data_array = data['data']

## Preprocess data
We're going to perform two standard (to this dataset) pre-processing steps:

   1. soft-normalize the rates (divide each neuron's rates by its firing-rate range across all conditions and times, plus 5)
        This step has the effect of preventing the recovered factors from being dominated by a few high firing-rate neurons while also minimizing the impact of very low firing rate neurons

   2. subtract off the cross-condition mean
        The largest signal (in terms of variance) in motor cortex during reaching is a condition-invariant 'trigger signal' (see [Kaufman et al., 2016](https://pubmed.ncbi.nlm.nih.gov/27761519/) for further discussion).
        Because we are often more interested in the condition-specific signals, we typically subtract off this trigger signal. (In the cell below this step is currently commented out, so only each neuron's overall mean is removed.)


In [None]:
# Downsample in time to speed up SCA (keep every 10th sample)
downsamp_factor = 10
data_downsamp = data_array[:, :, ::downsamp_factor]

# pull out some useful numbers
numConds, numN, trlDur = np.shape(data_downsamp)


def _concat_conditions(tensor):
    """Reshape a (C x N x T) tensor to (N x C*T).

    Conditions are laid end to end along the second axis, condition-major:
    column c*T + t holds condition c at time bin t.
    """
    n_conds, n_neurons, n_times = tensor.shape
    return tensor.swapaxes(0, 1).reshape(n_neurons, n_conds * n_times)


# Firing-rate range of each neuron across all conditions and times, (N x 1).
# Taken from the downsampled rates *before* any mean subtraction.
fr_range = np.ptp(_concat_conditions(data_downsamp), axis=1)[:, None]

# time-bin index (0..T-1) of every column of the concatenated data
timeMask = np.tile(np.arange(trlDur), numConds)

# define the times we want to use for sca/pca (indices are downsampled bins)
# target on: 20
# move on:   77
# return:    200
trainTimes = np.arange(20, 230)

# 'training mask' for convenience: True for columns whose time bin is in
# trainTimes. np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
trainMask = np.isin(timeMask, trainTimes)

# Subtract cross-condition mean.
# NOTE: currently disabled, so the condition-invariant signal is kept.
# To enable, uncomment the next line and delete the pass-through after it.
# data_scm = data_downsamp - np.mean(data_downsamp, axis=0)[None, :, :]
data_scm = data_downsamp

# Concatenate all the conditions (N x C*T instead of C x N x T)
data_concat = _concat_conditions(data_scm)

# Soft normalize (divide each neuron by its fr range + 5)
data_norm = data_concat / (fr_range + 5)

# remove each neuron's mean over all conditions and times
data_snm_norm = data_norm - np.mean(data_norm, axis=1)[:, None]

# The models take (samples x features) = (C*T x N) input, hence the transposes.
fit_data = np.copy(data_snm_norm.T)   # mean-centered: SCA / ICA fits
fit_data_pos = np.copy(data_norm.T)   # uncentered: NMF-family fits

# Per-time-sample weights from sca.util.get_sample_weights; only consumed by
# the commented-out `sample_weight=` variants of the model fits.
sample_weights = get_sample_weights(fit_data)


## Define some SCA parameters
SCA has three hyperparameters:

   - number of requested factors (R_est)
   - lam_orthog: determines the degree to which non-orthogonal dimensions are penalized
   - lam_sparse: determines the degree to which non-sparse factors are penalized

For this analysis, we are going to use the default values for lam_orthog and lam_sparse.

Across all examined datasets, the recovered SCA factors vary little across different hyperparameter choices.

In [None]:
R_est = 50  # number of factors/dimensions every method is asked to recover


In [None]:
# --- SCA with default lam_orthog / lam_sparse, no sample weighting ---
sca = SCA(n_components=R_est, n_epochs=5000)
sca_latent = sca.fit_transform(X=fit_data)
# weighted variant:
# sca_latent = sca.fit_transform(X=fit_data, sample_weight=sample_weights)

# Other comparison methods, currently disabled (kept for reference):
#   pca_latent     = WeightedPCA(n_components=R_est).fit_transform(fit_data, sample_weight=sample_weights)
#   pca_var_latent = WeightedPCA(n_components=R_est, rotate=True).fit_transform(fit_data, sample_weight=sample_weights)
#   fa_latent      = FactorAnalysis(n_components=R_est).fit_transform(fit_data)
#   nmf_latent     = np.array(Nmf(fit_data_pos, rank=R_est)().basis())
#   spca_latent    = SparsePCA(n_components=R_est).fit_transform(fit_data)

# --- ICA (unit-variance whitening) ---
ica = FastICA(n_components=R_est, whiten='unit-variance')
ica_latent = ica.fit_transform(fit_data)

# --- Sparse NMF on the uncentered data; the basis matrix gives the latents ---
snmf = Snmf(fit_data_pos, rank=R_est, version='l')
snmf_fit = snmf()
snmf_latent = np.array(snmf_fit.basis())

In [None]:
# snmf = Snmf(fit_data_pos,rank=R_est,version='l',n_run=2)
# snmf_fit = snmf()
# snmf_latent=np.array(snmf_fit.basis())

In [None]:
# snmf = Snmf(fit_data_pos,rank=R_est,version='l',beta=1e-3)
# snmf_fit = snmf()
# snmf_latent=np.array(snmf_fit.basis())

In [None]:
# Scale each SNMF basis column by the L2 norm of the matching row of the
# coefficient matrix, so the latents carry the factor magnitudes.
# NOTE: not idempotent -- re-running this cell rescales again.
snmf_coef_norms = np.linalg.norm(snmf_fit.coef(), axis=1)
snmf_latent = snmf_latent * snmf_coef_norms[np.newaxis, :]

In [None]:
# Latents to compare, keyed by display name (dict preserves insertion order).
# Full set when every method is enabled:
#   SCA, PCA, FA, NMF, SPCA, ICA, SNMF, PCA+Varimax
latents_by_name = {
    'SCA': sca_latent,
    'ICA': ica_latent,
    'SNMF': snmf_latent,
}
latent_names = list(latents_by_name)
latents = list(latents_by_name.values())

num_comparisons = len(latents)

In [None]:
# Reshape each latent matrix from (C*T x R) -- conditions concatenated along
# time, condition-major -- into a (T x C x R) tensor. order='F' maps row
# c*T + t to index [t, c]. The condition count comes from numConds rather than
# a hard-coded 8, so other datasets reshape correctly.
rs_latents = [np.reshape(np.array(lat), (-1, numConds, R_est), order='F')
              for lat in latents]

# Time-ordered versions of rs_latents; filled by the following cells
# (SCA first, then the comparison methods).
rs_latents_ordered = []

In [None]:
# Fraction of the largest peak occupancy below which an SCA dimension is
# treated as (nearly) unoccupied and moved to the end of the ordering.
occThresh_fract = .07

rs_sca_latent = rs_latents[0]

# across-condition variance of each dimension at each time bin: (T x R)
sca_var = np.var(rs_sca_latent, axis=1)

# peak occupancy (max across-condition variance over time) of each dimension
maxOcc_sca = np.max(sca_var, axis=0)

# occupancy threshold, relative to the most-occupied dimension
occThresh = np.max(maxOcc_sca) * occThresh_fract

# split dimensions into occupied / low-occupancy sets
lowOccDims = np.flatnonzero(maxOcc_sca < occThresh)
highOccDims = np.flatnonzero(maxOcc_sca >= occThresh)
numLowOcc = lowOccDims.size
numHighOcc = highOccDims.size

# sort the occupied dimensions by the time of their peak occupancy
sca_var = sca_var[:, highOccDims]
pkIdx = np.argmax(sca_var, axis=0)
highOrder = np.argsort(pkIdx)

# Final order: sorted occupied dims, then low-occupancy dims in index order.
# Concatenating (rather than filling `order[-numLowOcc:]`) stays correct when
# there are no low-occupancy dims, where a `[-0:]` slice would span the whole
# array and fail to broadcast.
sca_order = np.concatenate([highOccDims[highOrder], lowOccDims]).astype(int)

# reorder sca latents by time of maximum occupancy
rs_sca_latent = rs_sca_latent[:, :, sca_order]

rs_latents_ordered.append(rs_sca_latent)

In [None]:
# Order every non-SCA method's factors by the time bin at which their
# across-condition variance peaks (no occupancy threshold, unlike SCA).
for rs_latent in rs_latents[1:]:
    cond_var = np.var(rs_latent, axis=1)
    peak_bins = np.argmax(cond_var, axis=0)
    rs_latents_ordered.append(rs_latent[:, :, np.argsort(peak_bins)])

In [None]:
# Symmetric y-limit per method: largest |latent| value plus 1% headroom
ymaxes = [1.01 * np.abs(lat).max() for lat in latents]

In [None]:
# One trace color per condition (indexed by condition in the plotting cell)
cMap = [
    '#5e0044',
    '#6f144e',
    '#812858',
    '#933c62',
    '#a5506d',
    '#b76477',
    '#c97881',
    '#db8c8c',
]

# Task-event time indices (downsampled bins): target on, move on, return
tgt_idx, move_idx, ret_idx = 20, 77, 200

In [None]:
# indices into latent_names / rs_latents_ordered of the methods to plot
plot_latents = [0, 1, 2]

titles = [latent_names[p] for p in plot_latents]

num_plotted_latents = len(plot_latents)

# one row per factor, one column per method
fig = make_subplots(rows=R_est, cols=num_plotted_latents, shared_xaxes=True,
                    vertical_spacing=0, subplot_titles=titles)

for k, l in enumerate(plot_latents):
    col = k + 1
    for ii in range(R_est):
        row = ii + 1

        # one trace per condition (cMap needs at least numConds colors)
        for jj in range(numConds):
            latTrace = go.Scatter(y=rs_latents_ordered[l][:, jj, ii],
                                  line=go.scatter.Line(color=cMap[jj], width=2.5),
                                  showlegend=False)
            fig.add_trace(latTrace, row=row, col=col)

        # mark important task events
        for event_idx in (tgt_idx, move_idx, ret_idx):
            fig.add_vline(x=event_idx, row=row, col=col, line_color='black')

        # clean up: hide y-axis, share symmetric limits within a method
        fig.update_yaxes(showgrid=False, zeroline=False, visible=False,
                         range=[-ymaxes[l], ymaxes[l]], row=row, col=col)

# figure-wide layout, applied once
fig.update_layout(height=2000, width=1500,
                  paper_bgcolor='white',
                  plot_bgcolor='white')

# Hide every x-axis once, THEN reveal the bottom axis of each column. A
# blanket update_xaxes inside the per-column loop would re-hide the bottom
# axes of all previously drawn columns, leaving only the last one labeled.
# Tick label '500' for bin 50 assumes 10 ms per downsampled bin -- confirm.
fig.update_xaxes(color='black', showgrid=False, zeroline=False, visible=False)
for col in range(1, num_plotted_latents + 1):
    fig.update_xaxes(color='black', showgrid=False, zeroline=False,
                     ticks='outside', tickvals=[0, 50], ticktext=['0', '500'],
                     visible=True, row=R_est, col=col)

# fig.show()
# NOTE: machine-specific output directory; edit before running elsewhere.
figDir = '/Users/jig289/Dropbox/SCA/SCA_Stuff/sca_resubmit/figures/'

# filename tracks R_est instead of a hard-coded dimension count
fig.write_image(figDir + monkName + f'_CO_{R_est}dims_defaultsca_noweight.pdf')

