# This notebook demonstrates the use of Sparse Component Analysis (SCA) using a center-out reaching dataset

Data are trial-averaged firing rates of individual neurons collected from motor cortex (primary motor and dorsal premotor) of a non-human primate.

For this task, the monkey began each trial by touching a central touch-point. A peripheral target was shown, and after a variable, unpredictable delay period, a go cue was delivered. After capturing the peripheral target, the monkey received a juice reward and returned his hand to the touch-point to begin the next trial (see [Lara et al., 2018](https://pubmed.ncbi.nlm.nih.gov/30132759/)).

Data have been aligned to target onset, outward movement onset, and return reach onset.

'data' is a Condition x Neuron x Time tensor of trial-averaged firing rates.
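
The cells below flatten this tensor into a 2-D matrix for model fitting and later fold the latents back into a Time x Condition layout. A minimal sketch of that round trip on a toy tensor (the sizes and variable names are illustrative and not taken from the dataset):

```python
import numpy as np

# Toy tensor with the same axis convention as 'data': Condition x Neuron x Time
C, N, T = 3, 4, 5
tensor = np.arange(C * N * T, dtype=float).reshape(C, N, T)

# Concatenate conditions along time: N x (C*T), then transpose to (C*T) x N
concat = tensor.swapaxes(0, 1).reshape(N, C * T)
flat = concat.T

# Fold back to Time x Condition x Neuron; order='F' undoes the
# condition-major concatenation used above
folded = np.reshape(flat, (T, C, N), order='F')

print(flat.shape)
print(folded.shape)
```

Here `folded[t, c, n]` equals `tensor[c, n, t]`, which is why the plotting cells can index latents as time x condition x factor after reshaping with `order='F'`.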


## Import various packages

In [None]:
import numpy as np
from scipy import io

import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt

from plotly.subplots import make_subplots

from sca.models import SCA, WeightedPCA
from sca.util import get_sample_weights, get_accuracy

from sklearn.decomposition import FactorAnalysis
from sklearn.decomposition import FastICA
from sklearn.decomposition import SparsePCA

from nimfa import Nmf, Snmf

## local path to the directory
    - fill in the path to wherever you've saved 'datasets'

In [None]:
# local_path = ## FILL IN PATH ##

# local_path = '../datasets/'
# local_path = '../data/'

load_folder = '/Users/sherryan/Desktop/sca-main/datasets/'
monkName = 'Balboa'

data=io.loadmat(load_folder + monkName + '_outAndBack_redYellowConds_rawRates.mat')


## load data

In [None]:
# data=io.loadmat(local_path + 'monkeyB_reaching.mat')

# pull out the PSTHs
# data is a C x N x T tensor 
data_array=data['data']

## Preprocess data
We're going to perform two standard (to this dataset) pre-processing steps:

   1. soft-normalize the rates (firing rate range across all conditions and times + 5)
        This step has the effect of preventing the recovered factors from being dominated by a few high firing-rate neurons while also minimizing the impact of very low firing rate neurons

   2. subtract off the cross-condition mean
        The largest signal (in terms of variance) in motor cortex during reaching is a condition-invariant 'trigger signal' (see [Kaufman et al., 2016](https://pubmed.ncbi.nlm.nih.gov/27761519/) for further discussion).
        Because we are often more interested in the condition-specific signals, we typically subtract off this trigger signal.
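
The two steps above can be sketched on synthetic rates as follows. This is an illustration under assumed array sizes and random data, applied directly to the Condition x Neuron x Time tensor rather than to the concatenated matrix used in the notebook:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic Condition x Neuron x Time firing rates (illustrative values only)
rates = rng.gamma(shape=2.0, scale=10.0, size=(8, 30, 50))

# 1. Soft-normalize: divide each neuron by (its firing-rate range + 5)
fr_range = rates.max(axis=(0, 2)) - rates.min(axis=(0, 2))  # one value per neuron
soft_norm = rates / (fr_range[None, :, None] + 5)

# 2. Subtract the cross-condition mean at every neuron and time point
cond_mean = soft_norm.mean(axis=0, keepdims=True)
cond_specific = soft_norm - cond_mean

# After step 1 every neuron's range is below 1; after step 2 the
# across-condition mean is zero everywhere
norm_range = soft_norm.max(axis=(0, 2)) - soft_norm.min(axis=(0, 2))
print(norm_range.max() < 1)
print(np.allclose(cond_specific.mean(axis=0), 0))
```

The additive constant keeps low-rate neurons from being inflated to the same scale as high-rate neurons, since a neuron with a range near zero is divided by roughly 5 rather than by a value near zero.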


In [None]:
#Downsample data to speed up SCA (using a factor of 10 here)
data_downsamp=data_array[:,:,np.arange(0,data_array.shape[2],10)]

# pull out some useful numbers
numConds,numN,trlDur = np.shape(data_downsamp)

#Concatenate all the conditions (so the matrix is size N x TC instead of C x N x T)
data_concat=data_downsamp.swapaxes(0,1).reshape([data_downsamp.shape[1],data_downsamp.shape[0]*data_downsamp.shape[2]])

#fr range
fr_range=np.ptp(data_concat,axis=1)[:,None]

# make a time mask
timeMask = np.tile(np.arange(trlDur),(1,numConds)).T.flatten()

# define the times we want to use for sca/pca
# target on: 20
# move on:   77
# return:    200
trainTimes = np.arange(20,230)

# define a 'training mask' for convenience 
trainMask = np.isin(timeMask,trainTimes)

#Subtract cross-condition mean
# NOTE: this step is currently disabled; swap the two lines below to apply it
# data_scm=data_downsamp-np.mean(data_downsamp,axis=0)[None,:,:]
data_scm=data_downsamp


#Concatenate all the conditions (so the matrix is size N x TC instead of C x N x T)
data_concat=data_scm.swapaxes(0,1).reshape([data_scm.shape[1],data_scm.shape[0]*data_scm.shape[2]])

#Soft normalize (divide each neuron by its fr range + 5)
data_norm=data_concat/(fr_range+5)

data_snm_norm=data_norm-np.mean(data_norm,axis=1)[:,None]


# rename the data for convenience
#Note that model requires (T x N) input rather than (N x T), which is why there are transposes below
fit_data=np.copy(data_snm_norm.T)
fit_data_pos=np.copy(data_norm.T)


# how much to weight each timestep (passed as sample_weight when fitting SCA and weighted PCA)
sample_weights=get_sample_weights(fit_data)


## Define some SCA parameters
SCA has three hyperparameters:

   1. R_est: the number of requested factors
   2. lam_orthog: determines the degree to which non-orthogonal dimensions are penalized
   3. lam_sparse: determines the degree to which non-sparse factors are penalized

For this analysis, we are going to use the default values for lam_orthog and lam_sparse.

Across all examined datasets, the recovered SCA factors vary little across different hyperparameter choices.
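
To make the role of the two penalties concrete, here is a hedged sketch of an SCA-style objective: reconstruction error plus an L1 penalty on the latents plus a penalty on non-orthonormal decoding dimensions. The function name `sca_style_loss` and the exact form of each term are assumptions made for illustration; the `sca` package may define and scale its loss differently.

```python
import numpy as np

def sca_style_loss(X, U, V, lam_sparse, lam_orthog):
    """Illustrative objective; an assumption, not the sca package's code."""
    Z = X @ U                                    # latents, T x K
    recon_err = np.sum((X - Z @ V) ** 2)         # reconstruction error
    sparsity = np.sum(np.abs(Z))                 # L1 penalty on latent activity
    gram = V @ V.T
    orthog = np.sum((gram - np.eye(V.shape[0])) ** 2)  # deviation from orthonormal rows
    return recon_err + lam_sparse * sparsity + lam_orthog * orthog

rng = np.random.default_rng(0)
T, N, K = 200, 20, 3
# Orthonormal decoder rows (K x N) and data lying exactly in their span
V = np.linalg.qr(rng.standard_normal((N, K)))[0].T
X = rng.standard_normal((T, K)) @ V
U = V.T

loss_plain = sca_style_loss(X, U, V, lam_sparse=0.0, lam_orthog=0.0)
loss_sparse = sca_style_loss(X, U, V, lam_sparse=0.1, lam_orthog=0.0)
loss_skewed = sca_style_loss(X, U, 2 * V, lam_sparse=0.0, lam_orthog=1.0)
print(loss_plain, loss_sparse, loss_skewed)
```

In this toy setup the unpenalized loss is numerically zero, raising `lam_sparse` adds a cost proportional to total latent activity, and rescaling `V` away from unit norm is penalized through `lam_orthog`.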

In [None]:
# number of dimensions to find
R_est=15
lam_sparse=.1


In [None]:
# fit SCA
sca=SCA(n_components=R_est,lam_sparse=lam_sparse, n_epochs=5000)
sca_latent=sca.fit_transform(X=fit_data, sample_weight=sample_weights)

# fit SCA unweighted
sca_uw=SCA(n_components=R_est,lam_sparse=lam_sparse, n_epochs=5000)
sca_uw_latent=sca_uw.fit_transform(X=fit_data)


##### PCA
pca = WeightedPCA(n_components = R_est)
pca_latent = pca.fit_transform(fit_data)


##### PCA+Varimax
pca_var = WeightedPCA(n_components = R_est, rotate=True)
pca_var_latent = pca_var.fit_transform(fit_data)


#### Factor Analysis
fa = FactorAnalysis(n_components= R_est)
fa_latent = fa.fit_transform(fit_data)

#### NMF
nmf = Nmf(fit_data_pos,rank=R_est)
nmf_fit = nmf()
nmf_latent=np.array(nmf_fit.basis())

#### ICA
# ica = FastICA(R_est)
ica = FastICA(R_est,whiten='unit-variance')
ica_latent=ica.fit_transform(fit_data)

#### Sparse NMF
snmf = Snmf(fit_data_pos,rank=R_est,version='l')
snmf_fit = snmf()
snmf_latent=np.array(snmf_fit.basis())

#### Sparse PCA
spca = SparsePCA(n_components= R_est)
spca_latent = spca.fit_transform(fit_data)

In [None]:
#Create list that has latents from all comparison methods
latents=[sca_latent,sca_uw_latent,pca_latent,fa_latent,nmf_latent, spca_latent,ica_latent,snmf_latent,pca_var_latent]
latent_names=['SCA','SCA_Unweighted','PCA','FA','NMF','SPCA','ICA','SNMF','PCA+Varimax']

num_comparisons=len(latents)

In [None]:
# Reshape latents

rs_latents=[np.reshape(np.array(latents[i]),(-1,8,R_est),order = 'F') for i in range(num_comparisons)]

### Ordered latents

## order SCA latents for plotting 
rs_latents_ordered=[]

for i,rs_latent in enumerate(rs_latents): 
    
    # calculate across condition variance
    var = np.var(rs_latent,axis = 1)
    # find peak occupancy of each dimension
    pkIdx = np.argmax(var,axis = 0)
    # define plotting order
    order = np.argsort(pkIdx)
    
    # re-sort the latents by time of maximum occupancy
    rs_latents_ordered.append(rs_latent[:,:,order]) 

In [None]:
# Reshape latents

### Ordered latents

## order SCA latents for plotting 
sca_latent_order = [0,4,14,1,2,3,5,6,7,12,13,8,9,10,11]
sca_latent = rs_latents_ordered[0][:, :, sca_latent_order]
# sca_latent = rs_latents[0]
rs_latents_ordered2=[sca_latent]

for i,rs_latent in enumerate(rs_latents): 
    if i == 0: ## this is SCA
        continue
    sca_flat = sca_latent.reshape(-1, sca_latent.shape[-1])
    other_flat = rs_latent.reshape(-1, rs_latent.shape[-1])
    corr = np.corrcoef(sca_flat.T, other_flat.T)[:sca_flat.shape[1], sca_flat.shape[1]:]
    abs_corr = np.abs(corr)
    # greedy match: for each SCA dim, find best latent
    order = np.argmax(abs_corr, axis=1)
    signs = np.sign(corr[np.arange(sca_flat.shape[1]), order])
    reordered = np.zeros_like(rs_latent)
    for j in range(rs_latent.shape[-1]):
        reordered[:, :, j] = rs_latent[:, :, order[j]] * signs[j]
    rs_latents_ordered2.append(reordered)


In [None]:
rs_latent.shape

In [None]:
#Calculate max latent values to create y limits for plotting
ymaxes=[1.05*np.max(np.abs(latents[i])) for i in range(num_comparisons)]

In [None]:
cMap  = ['#5e0044', '#6f144e', '#812858', '#933c62', '#a5506d', '#b76477', '#c97881', '#db8c8c']

# define some useful time points
tgt_idx=20
move_idx=77
ret_idx=200

In [None]:
#Write which methods to plot from the previous list of latents 
plot_latents=[0,1,2,3,4]
titles=[latent_names[p]for p in plot_latents]


num_plotted_latents=len(plot_latents)

fig = make_subplots(rows=R_est,cols = num_plotted_latents,shared_xaxes = True,vertical_spacing = 0,subplot_titles=titles)

for k,l in enumerate(plot_latents):
    for ii in range(R_est):
        for jj in range(numConds):
            latTrace = go.Scatter(y = rs_latents_ordered2[l][:,jj,ii], line = go.scatter.Line(color = cMap[jj],width = 2.5),showlegend = False)
            fig.add_trace(latTrace,row = ii+1,col=k+1)

        # mark important task events
        fig.add_vline(x = tgt_idx,row = ii+1,col = k+1, line_color = 'black')
        fig.add_vline(x = move_idx,row = ii+1,col = k+1, line_color = 'black')
        fig.add_vline(x = ret_idx,row = ii+1,col = k+1, line_color = 'black')

    #clean up
        fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = [-ymaxes[l],ymaxes[l]],row=ii+1,col=k+1)

    # clean up
    fig.update_layout(height = 2000,width =1500,
                      paper_bgcolor = 'white',
                      plot_bgcolor = 'white')    


    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                     ticks = 'outside',tickvals = [0,50],ticktext = ['0','500'],visible = True,row = R_est,col = k+1)

fig.show()

# figDir='/Users/jig289/Dropbox/SCA/SCA_Stuff/sca_resubmit/figures/'

fig.write_image( monkName + '_CO_15dims_plot1_v2_ordered_invFirst.pdf')

# fig.write_image(figDir + monkName + '_CO_15dims_plot1_nonweighted.pdf')
# fig.write_image(figDir + monkName + '_CO_15dims_plot1_weighted_defaultsparsity_3000ep.pdf')

In [None]:
plot_latents=[0,5,6,7,8]
titles=[latent_names[p]for p in plot_latents]

num_plotted_latents=len(plot_latents)

fig = make_subplots(rows=R_est,cols = num_plotted_latents,shared_xaxes = True,vertical_spacing = 0,subplot_titles=titles)

for k,l in enumerate(plot_latents):
    for ii in range(R_est):

        for jj in range(numConds):
            latTrace = go.Scatter(y = rs_latents_ordered2[l][:,jj,ii], line = go.scatter.Line(color = cMap[jj],width = 2.5),showlegend = False)
            fig.add_trace(latTrace,row = ii+1,col=k+1)

        # mark important task events
        fig.add_vline(x = tgt_idx,row = ii+1,col = k+1, line_color = 'black')
        fig.add_vline(x = move_idx,row = ii+1,col = k+1, line_color = 'black')
        fig.add_vline(x = ret_idx,row = ii+1,col = k+1, line_color = 'black')

    #clean up
        fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = [-ymaxes[l],ymaxes[l]],row=ii+1,col=k+1)

    # clean up
    fig.update_layout(height = 2000,width =1500,
                      paper_bgcolor = 'white',
                      plot_bgcolor = 'white')    


    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                     ticks = 'outside',tickvals = [0,50],ticktext = ['0','500'],visible = True,row = R_est,col = k+1)

# fig.show()
# figDir='/Users/jig289/Dropbox/SCA/SCA_Stuff/sca_resubmit/figures/'

# fig.write_image(figDir + monkName + '_CO_15dims_plot2_nonweighted.pdf')
# fig.write_image(figDir + monkName + '_CO_15dims_plot2_weighted.pdf')
fig.write_image(monkName + '_CO_15dims_plot2_v2_ordered.pdf')


In [None]:
plot_latents=[0,7]
titles=[latent_names[p]for p in plot_latents]

num_plotted_latents=len(plot_latents)

fig = make_subplots(rows=R_est,cols = num_plotted_latents,shared_xaxes = True,vertical_spacing = 0,subplot_titles=titles)

for k,l in enumerate(plot_latents):
    for ii in range(R_est):

        for jj in range(numConds):
            latTrace = go.Scatter(y = rs_latents_ordered[l][:,jj,ii], line = go.scatter.Line(color = cMap[jj],width = 2.5),showlegend = False)
            fig.add_trace(latTrace,row = ii+1,col=k+1)

        # mark important task events
        fig.add_vline(x = tgt_idx,row = ii+1,col = k+1, line_color = 'black')
        fig.add_vline(x = move_idx,row = ii+1,col = k+1, line_color = 'black')
        fig.add_vline(x = ret_idx,row = ii+1,col = k+1, line_color = 'black')

    #clean up
        fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = [-ymaxes[l],ymaxes[l]],row=ii+1,col=k+1)

    # clean up
    fig.update_layout(height = 2000,width =1500,
                      paper_bgcolor = 'white',
                      plot_bgcolor = 'white')    


    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                     ticks = 'outside',tickvals = [0,50],ticktext = ['0','500'],visible = True,row = R_est,col = k+1)

# fig.show()
figDir='/Users/jig289/Dropbox/SCA/SCA_Stuff/sca_resubmit/figures/'

fig.write_image(figDir + monkName + '_CO_6dims_nonweighted_both_pt1.pdf')

In [None]:
#.05
get_accuracy(sca,fit_data)

In [None]:
#.1
get_accuracy(sca,fit_data)

In [None]:
# figDir='/Users/jig289/Dropbox/SCA/SCA_Stuff/sca_resubmit/figures/'

# fig.write_image(figDir + monkName + '_CO.pdf')

In [None]:
loadings=np.array(snmf_fit.coef())
plt.imshow(np.corrcoef(loadings))
plt.colorbar()

In [None]:
snmf.rss()

In [None]:
np.sum(np.array(snmf.residuals())**2)

In [None]:
get_accuracy(sca,fit_data)

In [None]:
from sklearn.metrics import r2_score

r2_score(fit_data_pos,np.array(snmf.fitted()),multioutput='variance_weighted')

In [None]:
#15 dims

In [None]:
get_accuracy(sca,fit_data)

In [None]:
r2_score(fit_data_pos,np.array(snmf.fitted()),multioutput='variance_weighted')

In [None]:
plt.plot(np.array((snmf_fit.coef()))[10,:])


In [None]:
plt.imshow(np.array((snmf_fit.coef())))
plt.colorbar()

In [None]:
snmf = Snmf(fit_data_pos,rank=R_est,version='l')
snmf_fit = snmf()
snmf_latent=np.array(snmf_fit.basis())

In [None]:
plt.imshow(np.array((snmf_fit.coef())))
plt.colorbar()

In [None]:
plt.plot(np.array((snmf_fit.coef()))[10,:])


In [None]:
snmf = Snmf(fit_data_pos,rank=R_est,version='l',beta=1e-2)
snmf_fit = snmf()
snmf_latent=np.array(snmf_fit.basis())

In [None]:
plt.imshow(np.array((snmf_fit.coef())))
plt.colorbar()

In [None]:
snmf = Snmf(fit_data_pos,rank=R_est,version='r',beta=1e-2)
snmf_fit = snmf()
snmf_latent=np.array(snmf_fit.basis())

In [None]:
plt.imshow(np.array((snmf_fit.coef())))
plt.colorbar()

In [None]:
snmf = Snmf(fit_data_pos,rank=R_est,version='r',beta=1)
snmf_fit = snmf()
snmf_latent=np.array(snmf_fit.basis())

In [None]:
plt.imshow(np.array((snmf_fit.coef())))
plt.colorbar()

In [None]:
plt.plot(np.array((snmf_fit.coef()))[10,:])


In [None]:
snmf = Snmf(fit_data_pos,rank=R_est,version='l',beta=1)
snmf_fit = snmf()
snmf_latent=np.array(snmf_fit.basis())

In [None]:
plt.imshow(np.array((snmf_fit.coef())))
plt.colorbar()

In [None]:
plt.plot(snmf_latent[:,0])

In [None]:
snmf = Snmf(fit_data_pos,rank=R_est,version='l',beta=1e-4)
snmf_fit = snmf()
snmf_latent=np.array(snmf_fit.basis())

In [None]:
plt.plot(snmf_latent[:,0])

In [None]:
# latents=[pca_latent,fa_latent,nmf_latent,ica_latent,snmf_latent,pca_var_latent,sca_latent]

# latents=[sca_latent,pca_latent,nmf_latent]

# latent_name=['ica','snmf','var','sca']

latents=[spca_latent,ica_latent,snmf_latent,pca_var_latent]


num_comparisons=len(latents)

In [None]:
ymaxes=[1.05*np.max(np.abs(latents[i])) for i in range(num_comparisons)]

In [None]:
rs_latents=[np.reshape(np.array(latents[i]),(-1,8,R_est),order = 'F') for i in range(num_comparisons)]


### Ordered latents

rs_latents_ordered=[]

for rs_latent in rs_latents: 

    # calculate across condition variance
    var = np.var(rs_latent,axis = 1)

    # find peak occupancy of each dimension
    pkIdx = np.argmax(var,axis = 0)

    # define plotting order
    order = np.argsort(pkIdx)

    # re-sort the latents by time of maximum occupancy
    rs_latents_ordered.append(rs_latent[:,:,order]) 

In [None]:
# range for y axis
# yRange = [-1.8,1.8]

num_plotted_latents=len(latents)

yRange = [-2.3,2.3]


fig = make_subplots(rows=R_est,cols = num_plotted_latents,shared_xaxes = True,vertical_spacing = 0)

for k in range(num_plotted_latents):
    for ii in range(R_est):

        for jj in range(numConds):
            latTrace = go.Scatter(y = rs_latents_ordered[k][:,jj,ii], line = go.scatter.Line(color = cMap[jj],width = 2.5),showlegend = False)
            fig.add_trace(latTrace,row = ii+1,col=k+1)

        # mark important task events
        fig.add_vline(x = tgt_idx,row = ii+1,col = k+1, line_color = 'black')
        fig.add_vline(x = move_idx,row = ii+1,col = k+1, line_color = 'black')
        fig.add_vline(x = ret_idx,row = ii+1,col = k+1, line_color = 'black')

        # add a vertical line for scale
#         scaleLine = go.Scatter(x = [0,0],y = [-ymaxes[k],ymaxes[k]],showlegend = False,mode = 'lines',
#                                 line = go.scatter.Line(color = 'black',width = 5))
#         fig.add_trace(scaleLine,row = ii+1,col = k+1)

    #clean up
        fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = [-ymaxes[k],ymaxes[k]],row=ii+1,col=k+1)

    # clean up
    fig.update_layout(height = 2000,width =1000,title = 'SCA',title_font_color = 'black',
                      paper_bgcolor = 'white',
                      plot_bgcolor = 'white')
    
    
#     fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = yRange)
#     next(fig.select_yaxes(row=ii+1, col= k+1)).update(range=[-ymaxes[k],ymaxes[k]])

    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                     ticks = 'outside',tickvals = [0,50],ticktext = ['0','500'],visible = True,row = R_est,col = k+1)

fig.show()

## Fit SCA

In [None]:
# fit model
# sca = SCA(n_components=R_est,lam_sparse=.1,lam_orthog=95)
sca = SCA(n_components=R_est)#,lam_sparse=.05,lam_orthog=95)

# sca.fit(X=fit_data[trainMask,:],sample_weight = sample_weights[trainMask])
sca.fit(X=fit_data,sample_weight = sample_weights)



# project all data into the sca dimensions
sca_latent=sca.transform(fit_data)


# plot loss
plt.plot(sca.losses);
plt.title('loss');

# display fraction of total variance explained by the SCA factors
print('SCA R2: ' + str(sca.r2_score))

In [None]:
sca_latent.shape

In [None]:
# plot the dot product of the learned U weights (which are not constrained to be orthonormal)
U = sca.params['U']
U_dProd = U.T@U

# plot
fig = px.imshow(U_dProd)
fig.update_layout(height = 300, width = 300,title = 'dot product of U')

In [None]:
# plot the dot product of the learned V weights (which are constrained to have a norm of 1, and may or may not be constrained to be orthogonal, depending on the 'orthFlg')
# V = sca.params['V']
# V_dProd = V@V.T

# # plot
# fig = px.imshow(abs(V_dProd),range_color = [0,1])
# fig.update_layout(height = 300, width = 300,title = 'dot product of V')



product=sca.params['V']@sca.params['V'].T
fig = go.Figure(go.Heatmap(z=product, zmin=-0.5, zmax=0.5, colorscale='RdBu'))
fig.update_layout(height=500, width=500)
# fig.write_html(figDir + monkName + '_scaW_lamSparse_' + str(lam_sparse) + '_' + str(R_est) + 'dims.html')
fig.show()

### Fit wPCA for comparison
   We're going to use a variant of PCA (weighted PCA), which is a more direct comparison to SCA (they primarily differ in the use of sparsity, rather than also differing due to sample weighting).

   The weighting ensures that low-firing rate time periods are not ignored in favor of high firing-rate periods.
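
The effect of sample weighting can be sketched with a small NumPy example. The helper `weighted_pca` and the inverse-squared-norm weights below are hypothetical stand-ins: they are not the `WeightedPCA` class, and the weights may differ from what `get_sample_weights` computes. The sketch also assumes the data are already mean-centered.

```python
import numpy as np

def weighted_pca(X, weights, n_components):
    """PCA on rows scaled by sqrt(weight); a sketch, not WeightedPCA itself."""
    Xw = X * np.sqrt(weights)[:, None]
    _, _, Vt = np.linalg.svd(Xw, full_matrices=False)
    V = Vt[:n_components]             # K x N, orthonormal rows
    return X @ V.T, V                 # project the unweighted data

rng = np.random.default_rng(0)
N = 10
# 200 low-activity samples along neuron 0, 100 high-activity samples along neuron 1
low = np.zeros((200, N))
low[:, 0] = 0.1 * rng.standard_normal(200)
high = np.zeros((100, N))
high[:, 1] = 5.0 * rng.standard_normal(100)
X = np.vstack([low, high]) + 1e-3 * rng.standard_normal((300, N))

uniform = np.ones(len(X))
inv_norm = 1.0 / (np.sum(X ** 2, axis=1) + 1e-12)   # hypothetical weighting

_, V_plain = weighted_pca(X, uniform, n_components=2)
_, V_weighted = weighted_pca(X, inv_norm, n_components=2)

# Index of the neuron that dominates the first dimension in each case
print(np.abs(V_plain[0]).argmax())
print(np.abs(V_weighted[0]).argmax())
```

With uniform weights the first dimension follows the high-activity period, whereas down-weighting high-norm samples lets the more numerous low-activity samples determine it.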

In [None]:
#Fit model and project data into the PC space
wpca = WeightedPCA(n_components=R_est)
# pca_latent = wpca.fit_transform(X=fit_data[trainMask,:], sample_weight=sample_weights[trainMask])
wpca.fit(X=fit_data[trainMask,:], sample_weight=sample_weights[trainMask])



pca_latent = wpca.transform(fit_data)

# get the factor weights
U_pca = wpca.params['U']
V_pca = wpca.params['V']

# calculate reconstruction R2
[pca_r2_score, pca_reconstruction_loss]=get_accuracy(wpca,fit_data[trainMask,:],sample_weights[trainMask])

# display results
print('PCA R2: ' + str(pca_r2_score))

In [None]:
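# Rotator is not imported at the top of the notebook; it is assumed here to come from the factor_analyzer package
from factor_analyzer import Rotator

rotator = Rotator(method='varimax')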

pca_latents_rotated=rotator.fit_transform(pca_latent)


### Note that SCA and wPCA account for virtually identical amounts of neural variance

### define some plotting colors

In [None]:
cMap  = ['#5e0044', '#6f144e', '#812858', '#933c62', '#a5506d', '#b76477', '#c97881', '#db8c8c']

### Plot factors
order factors by the time of maximum occupancy (cross-condition variance)

In [None]:
# calculate across-condition variance of each projection as a function of time

# reshape both latents to be size T x C x K 
rs_sca_latent = np.reshape(sca_latent,(-1,8,R_est),order = 'F')

# calculate across condition variance
sca_var = np.var(rs_sca_latent,axis = 1)

# find peak occupancy of each dimension
pkIdx = np.argmax(sca_var,axis = 0)

# define plotting order
sca_order = np.argsort(pkIdx)

# re-sort the SCA latents by time of maximum occupancy
rs_sca_latent = rs_sca_latent[:,:,sca_order]

# do the same for the pca projections
# rs_pca_latent = np.reshape(pca_latent,(-1,8,R_est),order = 'F')
rs_pca_latent = np.reshape(pca_latents_rotated,(-1,8,R_est),order = 'F')



pca_var = np.var(rs_pca_latent,axis = 1)
pkIdx = np.argmax(pca_var,axis = 0)
pca_order = np.argsort(pkIdx)
rs_pca_latent = rs_pca_latent[:,:,pca_order]

# define some useful time points
tgt_idx=20
move_idx=77
ret_idx=200


#### SCA
vertical lines mark target onset, outward reach onset, and return reach onset

In [None]:
# range for y axis
# yRange = [-1.8,1.8]

yRange = [-2,2]


fig = make_subplots(rows=R_est,cols = 1,shared_xaxes = True,vertical_spacing = 0)

for ii in range(R_est):

    for jj in range(numConds):
        latTrace = go.Scatter(y = rs_sca_latent[:,jj,ii], line = go.scatter.Line(color = cMap[jj],width = 2.5),showlegend = False)
        fig.add_trace(latTrace,row = ii+1,col=1)

    # mark important task events
    fig.add_vline(x = tgt_idx,row = ii+1,col = 1, line_color = 'black')
    fig.add_vline(x = move_idx,row = ii+1,col = 1, line_color = 'black')
    fig.add_vline(x = ret_idx,row = ii+1,col = 1, line_color = 'black')

    # add a vertical line for scale
    scaleLine = go.Scatter(x = [0,0],y = [-1,1],showlegend = False,mode = 'lines',
                            line = go.scatter.Line(color = 'black',width = 5))
    fig.add_trace(scaleLine,row = ii+1,col = 1)


# clean up
fig.update_layout(height = 2000,width =600,title = 'SCA',title_font_color = 'black',
                  paper_bgcolor = 'white',
                  plot_bgcolor = 'white')
fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = yRange)
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                 ticks = 'outside',tickvals = [0,50],ticktext = ['0','500'],visible = True,row = R_est,col = 1)

fig.show()

### PCA

In [None]:
# range for y axis
yRange = [-1.8,1.8]

fig = make_subplots(rows=R_est,cols = 1,shared_xaxes = True,vertical_spacing = 0)

for ii in range(R_est):

    for jj in range(numConds):
        latTrace = go.Scatter(y = rs_pca_latent[:,jj,ii], line = go.scatter.Line(color = cMap[jj],width = 2.5),showlegend = False)
        fig.add_trace(latTrace,row = ii+1,col=1)

    # mark important task events
    fig.add_vline(x = tgt_idx,row = ii+1,col = 1, line_color = 'black')
    fig.add_vline(x = move_idx,row = ii+1,col = 1, line_color = 'black')
    fig.add_vline(x = ret_idx,row = ii+1,col = 1, line_color = 'black')

    # add a vertical line for scale
    scaleLine = go.Scatter(x = [0,0],y = [-1,1],showlegend = False,mode = 'lines',
                            line = go.scatter.Line(color = 'black',width = 5))
    fig.add_trace(scaleLine,row = ii+1,col = 1)


# clean up
fig.update_layout(height = 2000,width =600,title = 'PCA',title_font_color = 'black',
                  paper_bgcolor = 'white',
                  plot_bgcolor = 'white')
fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = yRange)
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                 ticks = 'outside',tickvals = [0,50],ticktext = ['0','500'],visible = True,row = R_est,col = 1)

fig.show()

# Not sample weighted

In [None]:
#Fit model and project data into the PC space
wpca = WeightedPCA(n_components=R_est)
# pca_latent = wpca.fit_transform(X=fit_data[trainMask,:])

wpca.fit(X=fit_data[trainMask,:])



pca_latent = wpca.transform(fit_data)


# get the factor weights
U_pca = wpca.params['U']
V_pca = wpca.params['V']

# calculate reconstruction R2
[pca_r2_score, pca_reconstruction_loss]=get_accuracy(wpca,fit_data[trainMask,:],sample_weights[trainMask])

# display results
print('PCA R2: ' + str(pca_r2_score))

In [None]:

rotator = Rotator(method='varimax')

pca_latents_rotated=rotator.fit_transform(pca_latent)



rs_pca_latent = np.reshape(pca_latents_rotated,(-1,8,R_est),order = 'F')
# rs_pca_latent = np.reshape(pca_latent,(-1,8,R_est),order = 'F')



pca_var = np.var(rs_pca_latent,axis = 1)
pkIdx = np.argmax(pca_var,axis = 0)
pca_order = np.argsort(pkIdx)
rs_pca_latent = rs_pca_latent[:,:,pca_order]

In [None]:
# range for y axis
yRange = [-1.8,1.8]

fig = make_subplots(rows=R_est,cols = 1,shared_xaxes = True,vertical_spacing = 0)

for ii in range(R_est):

    for jj in range(numConds):
        latTrace = go.Scatter(y = rs_pca_latent[:,jj,ii], line = go.scatter.Line(color = cMap[jj],width = 2.5),showlegend = False)
        fig.add_trace(latTrace,row = ii+1,col=1)

    # mark important task events
    fig.add_vline(x = tgt_idx,row = ii+1,col = 1, line_color = 'black')
    fig.add_vline(x = move_idx,row = ii+1,col = 1, line_color = 'black')
    fig.add_vline(x = ret_idx,row = ii+1,col = 1, line_color = 'black')

    # add a vertical line for scale
    scaleLine = go.Scatter(x = [0,0],y = [-1,1],showlegend = False,mode = 'lines',
                            line = go.scatter.Line(color = 'black',width = 5))
    fig.add_trace(scaleLine,row = ii+1,col = 1)


# clean up
fig.update_layout(height = 2000,width =600,title = 'PCA',title_font_color = 'black',
                  paper_bgcolor = 'white',
                  plot_bgcolor = 'white')
fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = yRange)
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                 ticks = 'outside',tickvals = [0,50],ticktext = ['0','500'],visible = True,row = R_est,col = 1)

fig.show()

## Comparisons between SCA and PCA factors
Each SCA factor is primarily active before a reach, during a reach, or during the period when the monkey is holding his hand at the peripheral target.

These different patterns of activity correspond to different component computations that generate a reach: preparation, reach execution, and postural maintenance.

Without any supervision, SCA finds factors that reflect the computational division of labor during reaching.

PCA factors, on the other hand, do not show this clean parcellation. Most PCA factors are active during multiple computational epochs.
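
One possible way to put a number on this contrast is to ask what fraction of each factor's cross-condition variance falls within each task epoch. The sketch below uses synthetic latents; the helper `epoch_variance_fractions` and the epoch boundaries are hypothetical, and this is not an analysis taken from the notebook.

```python
import numpy as np

def epoch_variance_fractions(latents, epoch_edges):
    """latents: Time x Condition x Factor; returns Epoch x Factor fractions."""
    occupancy = np.var(latents, axis=1)          # T x K cross-condition variance
    totals = np.array([occupancy[a:b].sum(axis=0)
                       for a, b in zip(epoch_edges[:-1], epoch_edges[1:])])
    return totals / totals.sum(axis=0, keepdims=True)

rng = np.random.default_rng(0)
T, C, K = 120, 8, 2
latents = np.zeros((T, C, K))
# Factor 0 is active only in the first epoch; factor 1 is active throughout
latents[:40, :, 0] = rng.standard_normal((40, C))
latents[:, :, 1] = rng.standard_normal((T, C))

fractions = epoch_variance_fractions(latents, epoch_edges=[0, 40, 80, 120])
print(fractions.round(2))
```

A factor confined to one epoch has a fraction near 1 in a single row, while a factor active across epochs spreads its variance over several rows. For the real data, the edges could be set from the event indices used for plotting.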