# This notebook demonstrates Sparse Component Analysis (SCA) on a cycling dataset

Data are trial-averaged firing rates of individual neurons collected from motor cortex (primary motor and dorsal premotor) of a non-human primate.

The monkey was performing a cycling task ([Russo et al., 2018](https://pubmed.ncbi.nlm.nih.gov/29398358/)).

Briefly, this task involves the monkey using a bicycle pedal to traverse a virtual environment.

On each trial, the monkey is told how many cycles he will have to perform (0.5, 1, 2, 4, or 7).

After a variable delay period, the monkey receives a go cue, and he pedals to the target, comes to a stop, and is given a juice reward.

Single-trial rates have been (lightly) stretched or compressed such that each individual cycle is the same duration.

## Import packages

In [None]:
import numpy as np
import scipy
from scipy import io
import time

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sca.models import SCA, WeightedPCA, SCANonlinear
from sca.util import get_sample_weights, get_accuracy
from plotly.subplots import make_subplots

## Local path to the data directory
    - fill in the path to wherever you've saved 'datasets'

In [None]:
import os

# Directory that contains the dataset files (e.g. 'monkeyC_cycling.mat').
# Set the SCA_DATASETS environment variable to point elsewhere instead of
# editing this cell; otherwise the default below is used.
local_path = os.environ.get('SCA_DATASETS', '/Users/sherryan/Downloads/sca-main/datasets/')

# File paths are built by plain string concatenation (local_path + filename),
# so make sure the directory ends with a path separator.
if not local_path.endswith(('/', os.sep)):
    local_path += os.sep

## Load Data

In [None]:
# Load the trial-averaged rates and the accompanying 'mask' metadata.
data = io.loadmat(local_path + 'monkeyC_cycling.mat')

# Neural data matrix: rows = condition-times (CT), columns = neurons (N).
data_array = data['data']

# 'mask' was saved as a MATLAB structure, with fields like 'time', 'dist', 'direction'.
mask = data['mask']
vals = mask[0]

# Expose every struct field as a notebook-level variable with the same name it
# had in MATLAB. Assigning into globals() gives the same result as
# exec(key + '=val') without building and executing code from strings that
# come out of the data file.
# NOTE: the `time` field shadows the `time` module imported at the top of the
# notebook, and the `dir` field shadows the builtin dir().
for key in vals.dtype.names:
    globals()[key] = np.squeeze(vals[key][0])

# we now have variables 'time','dist','dir','pos','cycleNum','condNum','worldPos','firstTimeEachCond','lastTimeEachCond','cellIds', and 'cellNames'

# Decimate the mask variables by 10, the same factor used to downsample the
# neural data. All variables are indexed with positions derived from the
# length of `time` (the original, non-decimated length).
oL = time.shape[0]
dsIdx = np.arange(0, oL, 10)
time_ds = time[dsIdx]
dist_ds = dist[dsIdx]
dir_ds = dir[dsIdx]
condNum_ds = condNum[dsIdx]
cyNum_ds = cycleNum[dsIdx]

# Hard-coded target onset and move onset (in non-decimated time); used as
# x-axis tick positions in the plots.
tgt_idx = 500
move_idx = 1500

# Number of conditions (condition labels are 1-indexed). Cast to int so it can
# be passed to range() even if loadmat returned the labels as floats.
numConds = int(np.max(condNum_ds))

## Preprocess data
    downsample data by a factor of 10 (to speed up SCA)
    soft normalize each neuron's rate: divide it by (its firing-rate range across all conditions/times + 5)
        this normalization prevents high firing-rate neurons from dominating the factors
    mean-center each neuron

In [None]:
# Keep every 10th row (condition-time sample) to speed up SCA.
data_downsamp = data_array[np.arange(0, len(data_array), 10), :]

# Flip to neurons x samples (N x TC) so per-neuron statistics run along axis 1.
data_concat = data_downsamp.T

# Each neuron's firing-rate range (max - min), kept as an (N, 1) column.
fr_range = np.ptp(data_concat, axis=1, keepdims=True)

# Soft normalization: dividing by (range + 5) stops high-rate neurons dominating.
data_norm = data_concat / (fr_range + 5)

# Subtract each neuron's mean over all samples.
data_snm_norm = data_norm - data_norm.mean(axis=1, keepdims=True)

# The model takes samples x neurons (T x N); np.copy gives fit_data its own buffer.
fit_data = np.copy(data_snm_norm.T)

# Per-timestep weights for the fit.
sample_weights = get_sample_weights(fit_data)


## Define the SCA hyperparameters
R_est: number of dimensions (latent factors) to find

lam_sparse: how much to penalize non-sparse factors
    for this analysis, we're going to increase lam_sparse by a factor of 10 from the default (lam_sparse=0.03 is passed directly to the model in the fit cell below).
        On this dataset, this increased sparsity gives better (more interpretable) qualitative results while minimally impacting reconstruction accuracy

lam_orthog: how much to penalize non-orthogonal dimensions
    we're going to use the default lam_orthog for this analysis (it is not passed to the model).

In [None]:
# How many latent dimensions (factors) the model should extract.
R_est = 30



## Fit and plot nonlinear SCA (with the default orthogonality penalty)

In [None]:
# Nonlinear SCA: sparsity penalty 0.03, orthogonality penalty left at the
# model's default (not passed), 30000 epochs at learning rate .001.
sca = SCANonlinear(
    n_components=R_est,
    lam_sparse=0.03,
    n_epochs=30000,
    lr=.001,
)
sca_latent = sca.fit_transform(X=fit_data, sample_weight=sample_weights)

# Training loss over every epoch
plt.plot(sca.losses)
plt.title('loss')

# Fraction of total variance explained by the fitted factors
print('SCA R2: ' + str(sca.r2_score))

In [None]:
# Zoom in on the last 3000 epochs of the loss (fit used lr .001, 30000 epochs).
plt.plot(sca.losses[-3000:])
plt.title('loss')

# Loss value at the final epoch
print(sca.losses[-1])

In [None]:
# One colour per condition; the 4-colour pattern repeats for 5 groups (20 entries).
#   red vs. green  -> backward vs. forward conditions
#   dark vs. light -> top-start vs. bottom-start conditions
colors = np.array(['darkgreen', 'springgreen', 'maroon', 'red'] * 5)

In [None]:
# Plot every SCA latent for all conditions with plotly: one subplot per
# dimension, laid out in two columns.

# Number of subplot rows. Round up so an odd R_est still fits in two columns;
# int(R_est/2) would push the last dimension into a third column that does not
# exist (and give zero rows for R_est == 1).
numRows = (R_est + 1) // 2

# Shared y-axis range for every subplot
yRange = [-3.5, 3.5]

fig = make_subplots(rows=numRows, cols=2, shared_xaxes=True, vertical_spacing=0)

for i in range(R_est):
    # Subplot for dimension i: fill column 1 top to bottom, then column 2.
    # Computed once per dimension because it does not depend on the condition.
    rowIdx = i % numRows + 1
    colIdx = i // numRows + 1

    for j in range(int(numConds)):
        # grab the indices for this condition (conditions are 1-indexed)
        condIdx = condNum_ds == (j + 1)

        tempProj = go.Scatter(x=time_ds[condIdx], y=sca_latent[condIdx, i],
                              line=go.scatter.Line(color=colors[j % len(colors)], width=1),
                              showlegend=False)
        fig.add_trace(tempProj, row=rowIdx, col=colIdx)

    # Vertical scale bar from y=-1 to y=1 at x=-70
    scaleLine = go.Scatter(x=[-70, -70], y=[-1, 1], showlegend=False, mode='lines',
                           line=go.scatter.Line(color='black', width=5))
    fig.add_trace(scaleLine, row=rowIdx, col=colIdx)

fig.update_layout(height=1000, width=1000, title='SCA', title_font_color='black',
                  paper_bgcolor='white',
                  plot_bgcolor='white')
fig.update_yaxes(showgrid=False, zeroline=False, visible=False, range=yRange)
# Hide every x axis, then show only the bottom-left one, ticked at target and move onset.
fig.update_xaxes(color='black', showgrid=False, zeroline=False, visible=False)
fig.update_xaxes(color='black', showgrid=False, zeroline=False,
                 ticks='outside', tickvals=[tgt_idx, move_idx], ticktext=['target', 'move'],
                 visible=True, row=numRows, col=1)

# To save a copy: fig.write_image('cycling_sca_nonlinear.pdf')
fig.show()





In [None]:
# Scratch check: show the 5th-from-last decimated time stamp for one condition.
# NOTE(review): `condition_start` is only defined by the flat-plot cell further
# down, so this raises NameError when the notebook is run top to bottom; `condIdx`
# and `j` hold whatever the most recently executed cell left behind. Consider
# deleting this cell or moving it after the cell that defines `condition_start`.
time_ds[condIdx & (condNum_ds == (condition_start + j))][-5]

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Colour for each condition within a cycle-length group:
# red/green = backward/forward, dark/light = top-start vs. bottom-start.
colors = np.array(['darkgreen', 'springgreen', 'maroon', 'red'])
colors = np.tile(colors, (1, 5)).squeeze()

# Conditions come in consecutive groups of 4 (one group per cycle length),
# with condition labels starting at 1.
condsPerGroup = 4

# Latent dimensions to show (one row each) and cycle-length groups to show
# (one column each).
plotDims = [10, 15, 20, 26]
plotGroups = [0, 1, 4]

fig, axes = plt.subplots(len(plotDims), len(plotGroups), figsize=(15, 5),
                         sharex=True, sharey=True, squeeze=False)

for rowIdx, dim in enumerate(plotDims):
    for colIdx, group in enumerate(plotGroups):
        ax = axes[rowIdx, colIdx]

        # First (1-indexed) condition in this group; the group spans
        # condition_start .. condition_start + condsPerGroup - 1.
        condition_start = group * condsPerGroup + 1

        # One trace per condition in the group
        for j in range(condsPerGroup):
            condIdx = condNum_ds == (condition_start + j)
            ax.plot(time_ds[condIdx], sca_latent[condIdx, dim],
                    color=colors[j], linewidth=2, label=f'Condition {condition_start + j}')

        # Vertical scale bar (y from -1 to 1 at x=-70), first column only
        if colIdx == 0:
            ax.plot([-70, -70], [-1, 1], color='black', linewidth=2)

        # Strip gridlines, spines and ticks so only the traces remain
        ax.grid(False)
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])

        # Bottom row: mark target and movement onset on the x axis
        if rowIdx == len(plotDims) - 1:
            ax.set_xticks([tgt_idx, move_idx])
            ax.set_xticklabels(['target', 'move'])
            ax.tick_params(axis='x', direction='out', length=6)

# These panels show the (nonlinear) SCA latents fit above; the autoencoder is
# fit in a later section, so the figure is titled accordingly.
fig.suptitle('SCA latents', fontsize=16, color='black')
fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the title

fig.savefig('cycling_sca_flat.pdf')
plt.show()




## Fit and plot an autoencoder (nonlinear SCA with lam_sparse=0 and lam_orthog=0)

In [None]:
# Autoencoder: the same SCANonlinear model with both the sparsity and the
# orthogonality penalty set to zero (30000 epochs at learning rate .001).
sca = SCANonlinear(
    n_components=R_est,
    lam_sparse=0,
    lam_orthog=0,
    n_epochs=30000,
    lr=.001,
)
sca_latent = sca.fit_transform(X=fit_data, sample_weight=sample_weights)

# Training loss over every epoch
plt.plot(sca.losses)
plt.title('loss')

# Fraction of total variance explained by the fitted latents
print('SCA R2: ' + str(sca.r2_score))

In [None]:
# Loss over the final 3000 epochs only (fit used lr .001, 30000 epochs).
plt.plot(sca.losses[-3000:])
plt.title('loss')

# Loss value at the last epoch
print(sca.losses[-1])

In [None]:
# Plot every autoencoder latent for all conditions with plotly: one subplot per
# dimension, laid out in two columns.

# Number of subplot rows. Round up so an odd R_est still fits in two columns;
# int(R_est/2) would push the last dimension into a third column that does not
# exist (and give zero rows for R_est == 1).
numRows = (R_est + 1) // 2

# Shared y-axis range for every subplot
yRange = [-5, 5]

fig = make_subplots(rows=numRows, cols=2, shared_xaxes=True, vertical_spacing=0)

for i in range(R_est):
    # Subplot for dimension i: fill column 1 top to bottom, then column 2.
    # Computed once per dimension because it does not depend on the condition.
    rowIdx = i % numRows + 1
    colIdx = i // numRows + 1

    for j in range(int(numConds)):
        # grab the indices for this condition (conditions are 1-indexed)
        condIdx = condNum_ds == (j + 1)

        tempProj = go.Scatter(x=time_ds[condIdx], y=sca_latent[condIdx, i],
                              line=go.scatter.Line(color=colors[j % len(colors)], width=1),
                              showlegend=False)
        fig.add_trace(tempProj, row=rowIdx, col=colIdx)

    # Vertical scale bar from y=-1.5 to y=1.5 at x=-70
    scaleLine = go.Scatter(x=[-70, -70], y=[-1.5, 1.5], showlegend=False, mode='lines',
                           line=go.scatter.Line(color='black', width=5))
    fig.add_trace(scaleLine, row=rowIdx, col=colIdx)

fig.update_layout(height=1000, width=1000, title='Autoencoder', title_font_color='black',
                  paper_bgcolor='white',
                  plot_bgcolor='white')
fig.update_yaxes(showgrid=False, zeroline=False, visible=False, range=yRange)
# Hide every x axis, then show only the bottom-left one, ticked at target and move onset.
fig.update_xaxes(color='black', showgrid=False, zeroline=False, visible=False)
fig.update_xaxes(color='black', showgrid=False, zeroline=False,
                 ticks='outside', tickvals=[tgt_idx, move_idx], ticktext=['target', 'move'],
                 visible=True, row=numRows, col=1)

fig.write_image('cycling_sca_autoencoder_thin.pdf')

fig.show()





In [None]:
# Scratch check: position of the value 0 within the list [0, 1, 4].
np.flatnonzero(np.array([0, 1, 4]) == 0)[0]

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Colour for each condition within a cycle-length group:
# red/green = backward/forward, dark/light = top-start vs. bottom-start.
colors = np.array(['darkgreen', 'springgreen', 'maroon', 'red'])
colors = np.tile(colors, (1, 5)).squeeze()

# Conditions come in consecutive groups of 4 (one group per cycle length),
# with condition labels starting at 1.
condsPerGroup = 4

# Latent dimensions to show (one row each) and cycle-length groups to show
# (one column each).
plotDims = [0, 1, 2, 3]
plotGroups = [0, 1, 4]

fig, axes = plt.subplots(len(plotDims), len(plotGroups), figsize=(15, 5),
                         sharex=True, sharey=True, squeeze=False)

for rowIdx, dim in enumerate(plotDims):
    for colIdx, group in enumerate(plotGroups):
        ax = axes[rowIdx, colIdx]

        # First (1-indexed) condition in this group; the group spans
        # condition_start .. condition_start + condsPerGroup - 1.
        condition_start = group * condsPerGroup + 1

        # One trace per condition in the group
        for j in range(condsPerGroup):
            condIdx = condNum_ds == (condition_start + j)
            ax.plot(time_ds[condIdx], sca_latent[condIdx, dim],
                    color=colors[j], linewidth=2, label=f'Condition {condition_start + j}')

        # Vertical scale bar (y from -1 to 1 at x=-70), first column only
        if colIdx == 0:
            ax.plot([-70, -70], [-1, 1], color='black', linewidth=2)

        # Strip gridlines, spines and ticks so only the traces remain
        ax.grid(False)
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)
        ax.set_yticks([])
        ax.set_xticks([])

        # Bottom row: mark target and movement onset on the x axis
        if rowIdx == len(plotDims) - 1:
            ax.set_xticks([tgt_idx, move_idx])
            ax.set_xticklabels(['target', 'move'])
            ax.tick_params(axis='x', direction='out', length=6)

fig.suptitle('Autoencoder latents', fontsize=16, color='black')
fig.tight_layout(rect=[0, 0, 1, 0.96])  # leave room for the title

fig.savefig('cycling_autoencoder_flat.pdf')
plt.show()


