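## Generate figures from unimanual cycling data
- originally written for single-electrode data collected by AAR; the active cells below instead compare dimensionality-reduction methods on simulated data (the cycling-data loading and preprocessing cells are commented out)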

## Import packages

In [None]:
import numpy as np
from scipy import io

import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt

from plotly.subplots import make_subplots

from sca.models import SCA, WeightedPCA
from sca.util import get_sample_weights, get_accuracy

from sklearn.decomposition import FactorAnalysis
from sklearn.decomposition import FastICA
from sklearn.decomposition import SparsePCA

from nimfa import Nmf, Snmf

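## Load data (the cycling-data loading cells are commented out; simulated ground-truth data is generated instead)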

In [None]:
# monkName = 'Cousteau'
# # monkName = 'Drake'


# # save directory
# # figDir = '/Users/andrew/Documents/Projects/Churchland/Sparsity/figures/aarCycling/sparseLam0.03/'
# figDir='/Users/jig289/Dropbox/SCA/SCA_Stuff/sca_resubmit/figures/'


In [None]:
# # load psth and mask data from matlab
# # load_folder='/Users/andrew/Documents/Projects/Churchland/Sparsity/data/rawRates/'
# # load_folder='/Users/jig289/Dropbox/SCA/SCA_Stuff/sca_resubmit/data/'
# load_folder='/Users/jig289/Dropbox/Datasets/Cycling_data/'

# data=io.loadmat(load_folder+monkName+'_interp_cycling_m1_rawRates.mat')
# # data=io.loadmat(load_folder+monkName+'_interp_cycling_sma_rawRates.mat')

# # neural data is a matrix (CT x N )
# data_array=data['data'] #neural data (CT x N)

# # 'mask' was saved as a matlab structure, with fields like 'time','dist','direction'
# mask = data['mask']

# # grab the field names from mask
# keys = mask[0].dtype.descr

# # pull out the values
# vals = mask[0]

# # Assemble the keys and values into variables with the same name as that used in MATLAB
# for i in range(len(keys)):
#     key = keys[i][0]
#     val = np.squeeze(vals[key][0])
#     exec(key + '=val')

# # we now have variables 'time','dist','dir','pos','cycleNum','condNum','worldPos','firstTimeEachCond','lastTimeEachCond','cellIds', and 'cellNames'

# # decimate all of our mask variables
# # original length of our mask variables
# oL          = time.shape[0]
# time_ds     = time[np.arange(0,oL,10)]
# dist_ds     = dist[np.arange(0,oL,10)]
# dir_ds      = dir[np.arange(0,oL,10)]
# condNum_ds  = condNum[np.arange(0,oL,10)]
# cyNum_ds    = cycleNum[np.arange(0,oL,10)]

# # hard code target onset and move onset (in non-decimated time)
# tgt_idx = 500
# move_idx = 1500

# # how many conditions do we have?
# numConds = np.max(condNum_ds)

# # define an 'isMoving' vector (will be useful later)
# isMoving = np.zeros_like(time_ds)

# for ii in range(numConds):

#     tempIsMoving = np.zeros_like(time_ds[condNum_ds == (ii+1)])
#     tempIsMoving[150:] = 1
#     tempIsMoving[-50:] = 0

#     isMoving[condNum_ds == (ii+1)] = tempIsMoving

# # define a 'prep' vector that starts with condition start (time=0) and ends 150 ms before movement

# isPrep = np.zeros_like(time_ds)

# for ii in range(numConds):

#     tempIsPrep = np.zeros_like(time_ds[condNum_ds == (ii+1)])
#     tempIsPrep[:(150 - 15)] = 1

#     isPrep[condNum_ds == (ii+1)] = tempIsPrep


In [None]:
# Parameters for simulation 1
n_pad = 200  # zero-padding (samples) added before and after every condition

# One tuple per condition:
# (factor-1-only duration, co-occurrence duration, factor-2-only duration,
#  factor-1 value, factor-2 value)
_offset = 0.02
conditions = [
    (100, 200, 400, 1.0 + _offset, 1.0),
    (100, 100, 400, 1.0 + _offset, 1.0),
    (200, 200, 400, -1.0 - _offset, -1.0),
]

# Helper function to generate each condition block
def create_condition(f1_dur, cooccur_dur, f2_dur, f1_val, f2_val):
    """Build one simulated condition as a ``(T, 2)`` float array of two factors.

    The rows are three consecutive phases:
      * ``f1_dur`` rows where only factor 1 is non-zero (``f1_val``),
      * ``cooccur_dur`` rows where factor 1 is ``f1_val`` and factor 2 is ``f2_val``,
      * ``f2_dur`` rows where only factor 2 is non-zero (``f2_val``).

    Returns an array of shape ``(f1_dur + cooccur_dur + f2_dur, 2)``.
    """
    n_rows = f1_dur + cooccur_dur + f2_dur
    block = np.zeros((n_rows, 2))
    # Factor 1 is active through the end of the co-occurrence phase ...
    block[: f1_dur + cooccur_dur, 0] = f1_val
    # ... and factor 2 from the start of the co-occurrence phase onward.
    block[f1_dur:, 1] = f2_val
    return block

# Build all condition segments: each condition is zero-padded on both sides
# so both factors start and end at 0.
all_segments_sim1 = []
for (f1_dur, cooccur_dur, f2_dur, f1_val, f2_val) in conditions:
    cond = create_condition(f1_dur, cooccur_dur, f2_dur, f1_val, f2_val)
    pad = np.zeros((n_pad, 2))  # vstack copies, so one pad array serves both ends
    all_segments_sim1.append(np.vstack([pad, cond, pad]))

# Plot conditions in separate subplots (stacked vertically).
# squeeze=False keeps `axes` indexable even when there is a single condition.
fig, axes = plt.subplots(len(all_segments_sim1), 1, figsize=(10, 6), sharex=False, squeeze=False)
axes = axes[:, 0]

for ax, segment in zip(axes, all_segments_sim1):
    # `t` rather than `time` so the stdlib module name is not shadowed.
    t = np.arange(segment.shape[0])
    ax.plot(t, segment[:, 0], label="Factor 1")
    ax.plot(t, segment[:, 1], label="Factor 2")

axes[-1].set_xlabel("Time")
plt.tight_layout()
# plt.savefig('sim1_ground_truth.pdf')
plt.show()


# Parameters for simulation 2
n_samples = 1000              # Samples per condition
n_pad = 400                   # Padding before and after

# Condition blocks; columns are [factor 1, factor 2].
cond1 = np.column_stack([np.ones(n_samples)*1.5+0.02, np.zeros(n_samples)])     # [1.52, 0]: factor 1 only
cond2 = np.column_stack([np.zeros(n_samples), np.ones(n_samples)*2])            # [0, 2]: factor 2 only
cond3 = np.column_stack([np.ones(n_samples)+0.02, np.ones(n_samples)])          # [1.02, 1]: both factors

conditions = [cond1, cond2, cond3]

# Zero-pad each condition on both sides so both factors start and end at 0.
all_segments_sim2 = []
for cond in conditions:
    pad = np.zeros((n_pad, 2))  # vstack copies, so one pad array serves both ends
    all_segments_sim2.append(np.vstack([pad, cond, pad]))

# Plot conditions in separate subplots (stacked vertically).
# squeeze=False keeps `axes` indexable even when there is a single condition.
fig, axes = plt.subplots(len(all_segments_sim2), 1, figsize=(10, 6), sharex=False, squeeze=False)
axes = axes[:, 0]

for ax, segment in zip(axes, all_segments_sim2):
    # `t` rather than `time` so the stdlib module name is not shadowed.
    t = np.arange(segment.shape[0])
    ax.set_ylim(-0.1, 2.1)
    ax.plot(t, segment[:, 0], label="Factor 1")
    ax.plot(t, segment[:, 1], label="Factor 2")

axes[-1].set_xlabel("Time")
plt.tight_layout()
# plt.savefig('sim2_ground_truth.pdf')
plt.show()

In [None]:
import numpy.random as npr
from scipy.linalg import orth

# Stack the padded sim-1 conditions into one (T1 x 2) latent time series.
Z_sim1 = np.vstack(all_segments_sim1)
np.random.seed(0)  # fixed seed so the simulated populations are reproducible

N_neurons = 50    # number of simulated neurons
R_sim = 2         # dimensionality of the ground-truth latent space
noise_level = 0.1  # std of the additive Gaussian noise


def _simulate_population(Z):
    """Embed latent series ``Z`` (T x R_sim) into ``N_neurons`` noisy neurons.

    Draws from the global NumPy RNG in a fixed order (the order determines the
    simulated data, so do not reshuffle the draws):
      1. readout with orthonormal rows (R_sim x N_neurons),
      2. non-negative readout (R_sim x N_neurons),
      3. per-neuron offsets,
      4. additive noise for the signed population,
      5. non-negative per-neuron offsets,
      6. additive noise for the non-negative-readout population.

    Returns ``(X0, X0_pos, readout, readout_pos, offset)``.
    """
    readout = orth(npr.randn(R_sim, N_neurons).T).T
    readout_pos = npr.rand(R_sim, N_neurons)
    offset = npr.randn(N_neurons)

    signed = Z @ readout[:R_sim] + offset
    signed = signed + noise_level * npr.randn(signed.shape[0], signed.shape[1])

    positive = (
        Z @ readout_pos[:R_sim]
        + npr.rand(N_neurons)
        + noise_level * npr.randn(signed.shape[0], signed.shape[1])
    )
    return signed, positive, readout, readout_pos, offset


X0_sim1, X0_sim1_pos, V_tmp, V_tmp_pos, b = _simulate_population(Z_sim1)

Z_sim2 = np.vstack(all_segments_sim2)
X0_sim2, X0_sim2_pos, V_tmp, V_tmp_pos, b = _simulate_population(Z_sim2)


## Preprocess data

In [None]:
# Mean-centre every simulated neuron (column) before fitting.
X_sim1 = X0_sim1 - X0_sim1.mean(axis=0, keepdims=True)
X_sim2 = X0_sim2 - X0_sim2.mean(axis=0, keepdims=True)
R_est = 2  # number of latent dimensions each method is asked to find

# Non-negative versions for the NMF-family methods: negative entries -> 0.
X_sim1_pos = np.where(X0_sim1_pos < 0, 0, X0_sim1_pos)
X_sim2_pos = np.where(X0_sim2_pos < 0, 0, X0_sim2_pos)

# Every time sample gets the same weight (column vector of ones).
sample_weights_sim1 = np.ones((X_sim1.shape[0], 1))
sample_weights_sim2 = np.ones((X_sim2.shape[0], 1))

In [None]:
# #Downsample data (using a factor of 10 here)
# data_downsamp=data_array[np.arange(0,data_array.shape[0],10),:]

# #transpose (so the matrix is size N x TC instead of CT x N)
# data_concat=data_downsamp.T

# #fr range (for normalizing later
# fr_range=np.ptp(data_concat,axis=1)[:,None]

# #Soft normalize (divide each neuron by its range + 5)
# data_norm=data_concat/(fr_range+5)

# # mean-center each neuron
# data_snm_norm=data_norm-np.mean(data_norm,axis=1)[:,None]

# # rename the data for convenience
# #Note that model requires (T x N) input rather than (N x T), which is why there are transposes below
# fit_data_pos=np.copy(data_norm.T)
# fit_data_pos[fit_data_pos<0]=0
# fit_data=np.copy(data_snm_norm.T)

# # how much to weight each timestep (used by)
# sample_weights=get_sample_weights(fit_data)

# # number of dimensions to find
# R_est=40



In [None]:
# #### SCA params

# # sparsity lambda
# lam_sparse = 0.03 #0.1 #was 0.03 for 40 dims

# #orthog lambda????
# lam_orthog = 100

# # number of iterations
# numIters = 5000

# # Whether or not to impose hard orthogonality constraint
# # hardOrthFlag = True
# hardOrthFlag = False



# Fit methods

In [None]:
# Fit every dimensionality-reduction method to one simulated dataset.
# As written this fits simulation 2 (the alignment cell compares against
# Z_sim2 and stores metrics_sim2).
# NOTE(review): to produce metrics_sim1 these three assignments presumably
# need to point at the sim-1 variables instead -- confirm the intended workflow.
fit_data = X_sim2
fit_data_pos = X_sim2_pos
sample_weights = sample_weights_sim2

# fit SCA (lam_sparse and other hyperparameters left at the sca package defaults)
sca=SCA(n_components=R_est, n_epochs=5000)
# sca=SCA(n_components=R_est,lam_sparse=lam_sparse, n_epochs=5000)
sca_latent=sca.fit_transform(X=fit_data, sample_weight=sample_weights)

##### PCA (weighted by per-sample weights)
pca = WeightedPCA(n_components = R_est)
pca_latent = pca.fit_transform(fit_data,sample_weight=sample_weights)

##### PCA+Varimax (rotate=True; presumably a varimax rotation of the PCA solution)
pca_var = WeightedPCA(n_components = R_est, rotate=True)
pca_var_latent = pca_var.fit_transform(fit_data,sample_weight=sample_weights)

#### Factor Analysis
fa = FactorAnalysis(n_components= R_est)
fa_latent = fa.fit_transform(fit_data)

# #### NMF on the non-negative (clipped) data; the basis matrix is used as the
# latent time series (assumes nimfa factorizes fit_data_pos ~= basis @ coef
# with rows = time samples -- confirm).
nmf = Nmf(fit_data_pos,rank=R_est)
nmf_fit = nmf()
nmf_latent=np.array(nmf_fit.basis())


#### ICA
# ica = FastICA(R_est)
# NOTE(review): no random_state is passed here or to SparsePCA below, so
# (with sklearn's default) these fits can differ between runs; they may also
# share RNG state with the other fits, so keep the statement order unchanged.
ica = FastICA(R_est,whiten='unit-variance')
ica_latent=ica.fit_transform(fit_data)

#### Sparse NMF (only 20 iterations)
snmf = Snmf(fit_data_pos,rank=R_est,max_iter=20,track_error=True,version='l')
snmf_fit = snmf()
snmf_latent=np.array(snmf_fit.basis())
# Rescale each basis column by the L2 norm of the matching coefficient row.
snmf_latent=snmf_latent*np.linalg.norm(snmf_fit.coef(),axis=1)[None,:]

#### Sparse PCA
spca = SparsePCA(n_components= R_est)
spca_latent = spca.fit_transform(fit_data)

In [None]:
from scipy.optimize import linear_sum_assignment

# Ground-truth latents for the dataset fit above (sim 2).
Z_extra = Z_sim2

# Sample indices where one simulated condition ends and the next begins.
# Derived from the segment lengths (1800 and 3600 for the current sim-2
# settings) so the guide lines stay correct if durations or padding change.
condition_bounds = np.cumsum([seg.shape[0] for seg in all_segments_sim2])[:-1]

methods = [
    ("SCA", sca_latent),
    ("PCA", pca_latent),
    ("PCA+Varimax", pca_var_latent),
    ("ICA", ica_latent),
    ("FA", fa_latent),
    ("NMF", nmf_latent),
    ("SNMF", snmf_latent),
    ("SPCA", spca_latent),
]


def _align_to_truth(latent, truth, n_factors):
    """Permute and sign-flip the columns of ``latent`` to best match ``truth``.

    Matches estimated to true factors with the Hungarian algorithm on the
    absolute correlation matrix. Returns ``(aligned, score)``, where ``score``
    is the mean absolute correlation of the matched column pairs.
    """
    corr = np.corrcoef(latent.T, truth.T)[:n_factors, n_factors:]  # latent vs true
    # A constant (zero-variance) column yields NaN correlations, which
    # linear_sum_assignment rejects; treat those as zero correlation.
    corr = np.nan_to_num(corr)

    # Hungarian algorithm: minimizing -|corr| maximizes total |corr|.
    row_ind, col_ind = linear_sum_assignment(-np.abs(corr))

    aligned = np.zeros_like(latent)
    for r, c in zip(row_ind, col_ind):
        # Flip so matched pairs correlate positively; a correlation of exactly
        # 0 keeps the column unchanged instead of zeroing it (np.sign(0) == 0).
        sign = -1.0 if corr[r, c] < 0 else 1.0
        aligned[:, c] = latent[:, r] * sign

    matched = [np.corrcoef(aligned[:, j], truth[:, j])[0, 1] for j in range(n_factors)]
    # Undefined correlations (constant columns) count as 0 rather than NaN.
    score = np.mean(np.abs(np.nan_to_num(matched)))
    return aligned, score


aligned_methods = [("True", Z_extra)]  # first one is ground truth, no need to align
metrics = {"True": 1.0}  # perfect correlation baseline

# --- Align each method with ground truth ---
for name, latent in methods:
    aligned, score = _align_to_truth(latent, Z_extra, R_est)
    aligned_methods.append((name, aligned))
    metrics[name] = score

colors_true = ["#728599", "#b2821d"]
color_recon = "#6e0f3e"

num_methods = len(aligned_methods)
# Transposed layout: methods on rows, factors on columns
plt.figure(figsize=(6 * R_est, 2 * num_methods))

for i, (title, series) in enumerate(aligned_methods):   # loop over methods (rows)
    for j in range(R_est):                              # loop over factors (cols)
        ax = plt.subplot(num_methods, R_est, i * R_est + j + 1)

        # Ground truth gets one colour per factor (cycled if R_est exceeds the
        # palette); every reconstruction uses the same colour.
        if title.startswith("True"):
            color = colors_true[j % len(colors_true)]
        else:
            color = color_recon

        ax.plot(series[:, j], color=color, linewidth=0.5)

        # Dashed vertical lines at the boundaries between conditions.
        for bound in condition_bounds:
            ax.axvline(bound, color="gray", linestyle="--", linewidth=0.8)

        # Remove ticks and labels
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel("")
        ax.set_ylabel("")

        # Titles for factor columns (only on top row)
        if i == 0:
            ax.set_title(f"Factor {j+1}", fontsize=20)

        # Method names only at the start of each row (leftmost panel)
        if j == 0:
            ax.set_ylabel(title, rotation=0, labelpad=50, va="center", fontsize=20)

plt.tight_layout()

plt.savefig('sim2_all_methods.pdf')
plt.show()
metrics_sim2 = metrics

In [None]:
# Display the sim-1 alignment scores (method name -> mean |corr| with ground truth).
# NOTE(review): metrics_sim1 is not assigned in any visible cell (the alignment
# cell stores `metrics_sim2 = metrics`); it presumably comes from an earlier
# run on the sim-1 data, otherwise this raises NameError.
metrics_sim1

In [None]:
# Display the sim-2 alignment scores (method name -> mean |corr| with ground truth).
metrics_sim2

In [None]:
# Scatter each method's alignment score on simulation 1 (x) vs simulation 2 (y).
# Common methods between sim1 and sim2
common_methods = [m for m in metrics_sim1.keys() if m in metrics_sim2]

x = [metrics_sim1[m] for m in common_methods]
y = [metrics_sim2[m] for m in common_methods]

# Use a color cycle (different color for each method)
colors = plt.cm.tab10.colors  # 10 distinct colors
color_map = {m: colors[i % len(colors)] for i, m in enumerate(common_methods)}

plt.figure(figsize=(6, 6))

for xi, yi, name in zip(x, y, common_methods):
    plt.scatter(xi, yi, color=color_map[name], label=name)
    plt.text(xi + 0.01, yi, name, fontsize=15, color=color_map[name], va="center")

# Zoom in on the high-score corner, but lower the limit if any method scored
# below it so no point is silently clipped off the axes.
finite_scores = [s for s in x + y if np.isfinite(s)]
lower = min([0.79] + [s - 0.01 for s in finite_scores])
plt.xlim(lower, 1.01)
plt.ylim(lower, 1.01)

plt.xlabel("Mean |corr| (Sim 1)", fontsize=15)
plt.ylabel("Mean |corr| (Sim 2)", fontsize=15)
# plt.title("Method Performance Comparison")

# Remove top and right spines
ax = plt.gca()
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.tick_params(axis='both', which='major', labelsize=15)
plt.savefig('sim_summary.pdf')

# plt.grid(True, linestyle="--", alpha=0.6)
plt.show()

In [None]:
# t1=time.time()
# # #### Sparse NMF
# snmf = Snmf(fit_data_pos,rank=R_est,max_iter=20,track_error=True,version='l')
# snmf_fit = snmf()
# print(time.time()-t1)



In [None]:
# snmf_latent=np.array(snmf_fit.basis())

In [None]:
# snmf.residuals().shape

# Plot

In [None]:
# One colour per condition: a 4-colour pattern repeated 5 times (20 entries).
# conditions are ordered by: startPos -> direction -> distance (e.g., top start/forward/7 cycles, bottom start/forward/7 cycles, top start/backward/7 cycles ...)
_base_colors = np.array(['darkgreen', 'springgreen', 'maroon', 'red'])
colors = np.tile(_base_colors, 5)

In [None]:
# Methods for the next plotting cell, as (title, latent) pairs.
# Earlier selections, kept for reference:
# latents=[nmf_latent,pca_latent,sca_latent]
# latents=[sca_latent,ica_latent,pca_var_latent,nmf_latent]
# latents=[sca_latent,nmf_latent,spca_latent,fa_latent]
# latents=[sca_latent,snmf_latent,spca_latent,fa_latent]
# latent_names=['SCA','PCA','FA','NMF','SPCA','ICA','SNMF','PCA+Varimax']
_selected = [('SCA', sca_latent), ('NMF', nmf_latent)]
titles = [label for label, _ in _selected]
latents = [lat for _, lat in _selected]

num_comparisons = len(latents)

# Per method: column indices by decreasing total power (sum of squares over time).
orders = [np.argsort((lat ** 2).sum(axis=0))[::-1] for lat in latents]

# TODO: power ordering may be arbitrary for unit-variance outputs such as ICA (whiten='unit-variance').

In [None]:
# Plot each selected method's latents with plotly (one column per method, one
# row per latent in `orders` order, one trace per condition) and save to PDF.
# NOTE(review): numConds, condNum_ds, time_ds, tgt_idx, move_idx, monkName and
# figDir are only assigned in the commented-out cycling-data cells above; with
# the simulated data as-is this cell raises NameError.
# R_est=30
# number of rows
numRows = R_est
# symmetric y-limit per method: 5% above its largest absolute value
ymaxes=[1.05*np.max(np.abs(latents[i])) for i in range(num_comparisons)]

num_plotted_latents=len(latents)


# range of y axis
# yRange = [-2.5,2.5]

fig = make_subplots(rows=R_est,cols = num_plotted_latents,shared_xaxes = True,vertical_spacing = 0,subplot_titles=titles)
# fig = make_subplots(rows=numRows,cols = 4,shared_xaxes = True,vertical_spacing = 0)

for k in range(num_plotted_latents):
    for i in range(R_est):
        # NOTE(review): colors[j] needs numConds <= len(colors) (20 entries)
        for j in range(numConds):

            # grab the indices for this condition
            # remember that conditions are 1 indexed
            condIdx = condNum_ds == (j+1)

            # subplot indices
            rowIdx = i+1
            colIdx = k+1


            tempProj = go.Scatter(x = time_ds[condIdx],y = latents[k][condIdx,orders[k][i]],
                        line = go.scatter.Line(color = colors[j],width = 2),showlegend = False)
            fig.add_trace(tempProj,row = rowIdx,col=colIdx)

        # add a vertical line for scale
#         scaleLine = go.Scatter(x = [0,0],y = [-1,1],showlegend = False,mode = 'lines',
#                                    line = go.scatter.Line(color = 'black',width = 5))
#         fig.add_trace(scaleLine,row = rowIdx,col = colIdx)

        #clean up
        fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = [-ymaxes[k],ymaxes[k]],row=rowIdx,col=colIdx)

        
    # NOTE(review): the layout/x-axis updates below run once per method column
    # (inside the k loop); running them once after the loop would suffice.
    fig.update_layout(height = 2000,width =1000,title = monkName,title_font_color = 'black',
                      paper_bgcolor = 'white',
                      plot_bgcolor = 'white')
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
    # only the bottom-left axis shows the target/move ticks
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                     ticks = 'outside',tickvals = [tgt_idx,move_idx],ticktext = ['target','move'],visible = True,row = numRows,col = 1)

# save figure
# fig.write_image(figDir + monkName + '_scaProj_hardOrth.pdf')
# fig.write_image(figDir + monkName + '_scaProj_softOrth_60dims_ica.pdf')

# NOTE(review): file name says '40dims' but rows come from R_est (2 here)
fig.write_image(figDir + monkName + '_comparisons_40dims_nmf.pdf')


# fig.show()


In [None]:
# Methods for the next plotting cell, as (title, latent) pairs.
# Earlier selections, kept for reference:
# latents=[nmf_latent,pca_latent,sca_latent]
# latents=[sca_latent,ica_latent,pca_var_latent,nmf_latent]
# latents=[sca_latent,nmf_latent,spca_latent,fa_latent]
# latents=[sca_latent,snmf_latent,spca_latent,fa_latent]
# latent_names=['SCA','PCA','FA','NMF','SPCA','ICA','SNMF','PCA+Varimax']
_selected = [('SCA', sca_latent), ('PCA+Varimax', pca_var_latent)]
titles = [label for label, _ in _selected]
latents = [lat for _, lat in _selected]

num_comparisons = len(latents)

# Per method: column indices by decreasing total power (sum of squares over time).
orders = [np.argsort((lat ** 2).sum(axis=0))[::-1] for lat in latents]

# TODO: power ordering may be arbitrary for unit-variance outputs such as ICA (whiten='unit-variance').

In [None]:
# Plot each selected method's latents with plotly (one column per method, one
# row per latent in `orders` order, one trace per condition) and save to PDF.
# NOTE(review): numConds, condNum_ds, time_ds, tgt_idx, move_idx, monkName and
# figDir are only assigned in the commented-out cycling-data cells above; with
# the simulated data as-is this cell raises NameError.
# R_est=30
# number of rows
numRows = R_est
# symmetric y-limit per method: 5% above its largest absolute value
ymaxes=[1.05*np.max(np.abs(latents[i])) for i in range(num_comparisons)]

num_plotted_latents=len(latents)


# range of y axis
# yRange = [-2.5,2.5]

fig = make_subplots(rows=R_est,cols = num_plotted_latents,shared_xaxes = True,vertical_spacing = 0,subplot_titles=titles)
# fig = make_subplots(rows=numRows,cols = 4,shared_xaxes = True,vertical_spacing = 0)

for k in range(num_plotted_latents):
    for i in range(R_est):
        # NOTE(review): colors[j] needs numConds <= len(colors) (20 entries)
        for j in range(numConds):

            # grab the indices for this condition
            # remember that conditions are 1 indexed
            condIdx = condNum_ds == (j+1)

            # subplot indices
            rowIdx = i+1
            colIdx = k+1


            tempProj = go.Scatter(x = time_ds[condIdx],y = latents[k][condIdx,orders[k][i]],
                        line = go.scatter.Line(color = colors[j],width = 2),showlegend = False)
            fig.add_trace(tempProj,row = rowIdx,col=colIdx)

        # add a vertical line for scale
#         scaleLine = go.Scatter(x = [0,0],y = [-1,1],showlegend = False,mode = 'lines',
#                                    line = go.scatter.Line(color = 'black',width = 5))
#         fig.add_trace(scaleLine,row = rowIdx,col = colIdx)

        #clean up
        fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = [-ymaxes[k],ymaxes[k]],row=rowIdx,col=colIdx)

        
    # NOTE(review): the layout/x-axis updates below run once per method column
    # (inside the k loop); running them once after the loop would suffice.
    fig.update_layout(height = 2000,width =1500,title = monkName,title_font_color = 'black',
                      paper_bgcolor = 'white',
                      plot_bgcolor = 'white')
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
    # only the bottom-left axis shows the target/move ticks
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                     ticks = 'outside',tickvals = [tgt_idx,move_idx],ticktext = ['target','move'],visible = True,row = numRows,col = 1)

# save figure
# fig.write_image(figDir + monkName + '_scaProj_hardOrth.pdf')
# fig.write_image(figDir + monkName + '_scaProj_softOrth_60dims_ica.pdf')

# NOTE(review): file name says '10dims' but rows come from R_est (2 here)
fig.write_image(figDir + monkName + '_comparisons_10dims_pt1.pdf')


# fig.show()


In [None]:
# Display the fitted SCA model's score on fit_data.
# NOTE(review): the metric is defined by sca.util.get_accuracy, not visible here.
get_accuracy(sca,fit_data)

In [None]:
# Display the fitted weighted-PCA model's score on fit_data, for comparison with SCA.
# NOTE(review): the metric is defined by sca.util.get_accuracy, not visible here.
get_accuracy(pca,fit_data)

In [None]:
# Methods for the next plotting cell, as (title, latent) pairs.
# Earlier selections, kept for reference:
# latents=[nmf_latent,pca_latent,sca_latent]
# latents=[sca_latent,ica_latent,pca_var_latent,nmf_latent]
# latents=[sca_latent,nmf_latent,spca_latent,fa_latent]
# latents=[sca_latent,snmf_latent,spca_latent,fa_latent]
# latent_names=['SCA','PCA','FA','NMF','SPCA','ICA','SNMF','PCA+Varimax']
_selected = [
    ('SCA', sca_latent),
    ('PCA', pca_latent),
    ('FA', fa_latent),
    ('NMF', nmf_latent),
]
titles = [label for label, _ in _selected]
latents = [lat for _, lat in _selected]

num_comparisons = len(latents)

# Per method: column indices by decreasing total power (sum of squares over time).
orders = [np.argsort((lat ** 2).sum(axis=0))[::-1] for lat in latents]

# TODO: power ordering may be arbitrary for unit-variance outputs such as ICA (whiten='unit-variance').

In [None]:
# Plot each selected method's latents with plotly (one column per method, one
# row per latent in `orders` order, one trace per condition) and save to PDF.
# NOTE(review): numConds, condNum_ds, time_ds, tgt_idx, move_idx, monkName and
# figDir are only assigned in the commented-out cycling-data cells above; with
# the simulated data as-is this cell raises NameError.
# R_est=30
# number of rows
numRows = R_est
# symmetric y-limit per method: 5% above its largest absolute value
ymaxes=[1.05*np.max(np.abs(latents[i])) for i in range(num_comparisons)]

num_plotted_latents=len(latents)


# range of y axis
# yRange = [-2.5,2.5]

fig = make_subplots(rows=R_est,cols = num_plotted_latents,shared_xaxes = True,vertical_spacing = 0,subplot_titles=titles)
# fig = make_subplots(rows=numRows,cols = 4,shared_xaxes = True,vertical_spacing = 0)

for k in range(num_plotted_latents):
    for i in range(R_est):
        # NOTE(review): colors[j] needs numConds <= len(colors) (20 entries)
        for j in range(numConds):

            # grab the indices for this condition
            # remember that conditions are 1 indexed
            condIdx = condNum_ds == (j+1)

            # subplot indices
            rowIdx = i+1
            colIdx = k+1


            tempProj = go.Scatter(x = time_ds[condIdx],y = latents[k][condIdx,orders[k][i]],
                        line = go.scatter.Line(color = colors[j],width = 2),showlegend = False)
            fig.add_trace(tempProj,row = rowIdx,col=colIdx)

        # add a vertical line for scale
#         scaleLine = go.Scatter(x = [0,0],y = [-1,1],showlegend = False,mode = 'lines',
#                                    line = go.scatter.Line(color = 'black',width = 5))
#         fig.add_trace(scaleLine,row = rowIdx,col = colIdx)

        #clean up
        fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = [-ymaxes[k],ymaxes[k]],row=rowIdx,col=colIdx)

        
    # NOTE(review): the layout/x-axis updates below run once per method column
    # (inside the k loop); running them once after the loop would suffice.
    fig.update_layout(height = 2000,width =1500,title = monkName,title_font_color = 'black',
                      paper_bgcolor = 'white',
                      plot_bgcolor = 'white')
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
    # only the bottom-left axis shows the target/move ticks
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                     ticks = 'outside',tickvals = [tgt_idx,move_idx],ticktext = ['target','move'],visible = True,row = numRows,col = 1)

# save figure
# fig.write_image(figDir + monkName + '_scaProj_hardOrth.pdf')
# fig.write_image(figDir + monkName + '_scaProj_softOrth_60dims_ica.pdf')

# NOTE(review): file name says '40dims' but rows come from R_est (2 here)
fig.write_image(figDir + monkName + '_comparisons_40dims_plot1.pdf')


# fig.show()


In [None]:
# Methods for the next plotting cell, as (title, latent) pairs.
# Earlier selections, kept for reference:
# latents=[nmf_latent,pca_latent,sca_latent]
# latents=[sca_latent,ica_latent,pca_var_latent,nmf_latent]
# latents=[sca_latent,nmf_latent,spca_latent,fa_latent]
# latents=[sca_latent,snmf_latent,spca_latent,fa_latent]
# latent_names=['SCA','PCA','FA','NMF','SPCA','ICA','SNMF','PCA+Varimax']
_selected = [
    ('SPCA', spca_latent),
    ('ICA', ica_latent),
    ('SNMF', snmf_latent),
    ('PCA+Varimax', pca_var_latent),
]
titles = [label for label, _ in _selected]
latents = [lat for _, lat in _selected]

num_comparisons = len(latents)

# Per method: column indices by decreasing total power (sum of squares over time).
orders = [np.argsort((lat ** 2).sum(axis=0))[::-1] for lat in latents]

# TODO: power ordering may be arbitrary for unit-variance outputs such as ICA (whiten='unit-variance').

In [None]:
# Plot each selected method's latents with plotly (one column per method, one
# row per latent in `orders` order, one trace per condition) and save to PDF.
# NOTE(review): numConds, condNum_ds, time_ds, tgt_idx, move_idx, monkName and
# figDir are only assigned in the commented-out cycling-data cells above; with
# the simulated data as-is this cell raises NameError.
# R_est=30
# number of rows
numRows = R_est
# symmetric y-limit per method: 5% above its largest absolute value
ymaxes=[1.05*np.max(np.abs(latents[i])) for i in range(num_comparisons)]

num_plotted_latents=len(latents)


# range of y axis
# yRange = [-2.5,2.5]

fig = make_subplots(rows=R_est,cols = num_plotted_latents,shared_xaxes = True,vertical_spacing = 0,subplot_titles=titles)
# fig = make_subplots(rows=numRows,cols = 4,shared_xaxes = True,vertical_spacing = 0)

for k in range(num_plotted_latents):
    for i in range(R_est):
        # NOTE(review): colors[j] needs numConds <= len(colors) (20 entries)
        for j in range(numConds):

            # grab the indices for this condition
            # remember that conditions are 1 indexed
            condIdx = condNum_ds == (j+1)

            # subplot indices
            rowIdx = i+1
            colIdx = k+1


            tempProj = go.Scatter(x = time_ds[condIdx],y = latents[k][condIdx,orders[k][i]],
                        line = go.scatter.Line(color = colors[j],width = 2),showlegend = False)
            fig.add_trace(tempProj,row = rowIdx,col=colIdx)

        # add a vertical line for scale
#         scaleLine = go.Scatter(x = [0,0],y = [-1,1],showlegend = False,mode = 'lines',
#                                    line = go.scatter.Line(color = 'black',width = 5))
#         fig.add_trace(scaleLine,row = rowIdx,col = colIdx)

        #clean up
        fig.update_yaxes(showgrid = False,zeroline = False,visible = False,range = [-ymaxes[k],ymaxes[k]],row=rowIdx,col=colIdx)

        
    # NOTE(review): the layout/x-axis updates below run once per method column
    # (inside the k loop); running them once after the loop would suffice.
    fig.update_layout(height = 2000,width =1500,title = monkName,title_font_color = 'black',
                      paper_bgcolor = 'white',
                      plot_bgcolor = 'white')
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,visible = False)
    # only the bottom-left axis shows the target/move ticks
    fig.update_xaxes(color = 'black',showgrid = False,zeroline = False,
                     ticks = 'outside',tickvals = [tgt_idx,move_idx],ticktext = ['target','move'],visible = True,row = numRows,col = 1)

# save figure
# fig.write_image(figDir + monkName + '_scaProj_hardOrth.pdf')
# fig.write_image(figDir + monkName + '_scaProj_softOrth_60dims_ica.pdf')

# NOTE(review): file name says '40dims' but rows come from R_est (2 here)
fig.write_image(figDir + monkName + '_comparisons_40dims_plot2.pdf')


# fig.show()


In [None]:
# Methods for a plotting cell: SCA, ICA, PCA+Varimax and NMF.
# latents=[nmf_latent,pca_latent,sca_latent]
latents=[sca_latent,ica_latent,pca_var_latent,nmf_latent]
# Subplot titles in the same order as `latents`. Without this, the plotting
# cells would reuse whatever `titles` an earlier cell left behind, which does
# not match this selection.
titles=['SCA','ICA','PCA+Varimax','NMF']
# latents=[sca_latent,nmf_latent,spca_latent,fa_latent]

# latents=[sca_latent,snmf_latent,spca_latent,fa_latent]


num_comparisons=len(latents)

# Per method: column indices by decreasing total power (sum of squares over time).
orders = [np.argsort(np.sum(latents[i]**2,axis=0))[::-1] for i in range(len(latents))]

# TODO: power ordering may be arbitrary for unit-variance outputs such as ICA (whiten='unit-variance').

In [None]:
# Heat map of pairwise correlations between rows of the SNMF coefficient matrix.
loadings = np.array(snmf_fit.coef())
loading_corr = np.corrcoef(loadings)
plt.imshow(loading_corr)
plt.colorbar()

In [None]:
# Heat map of pairwise correlations between rows of SCA's fitted 'V' parameter.
# NOTE(review): 'V' is presumably SCA's loading matrix -- confirm in the sca package.
sca_V = sca.params['V']
plt.imshow(np.corrcoef(sca_V))
plt.colorbar()

In [None]:
# Repeat of the SNMF coefficient-row correlation heat map.
loadings = np.array(snmf_fit.coef())
plt.imshow(np.corrcoef(loadings))
colorbar_handle = plt.colorbar()
colorbar_handle

In [None]:
# Repeat of the SCA 'V' row-correlation heat map.
_v_corr = np.corrcoef(sca.params['V'])
plt.imshow(_v_corr)
plt.colorbar()

In [None]:
# Re-display the fitted SCA model's score on fit_data.
# NOTE(review): the metric is defined by sca.util.get_accuracy, not visible here.
get_accuracy(sca,fit_data)

In [None]:
from sklearn.metrics import r2_score

# Variance-weighted R^2 of the SNMF model's fitted reconstruction of the
# non-negative data (snmf.fitted(); presumably basis @ coef -- confirm in nimfa).
snmf_reconstruction = np.array(snmf.fitted())
r2_score(fit_data_pos, snmf_reconstruction, multioutput='variance_weighted')

In [None]:
#Drake

In [None]:
# SNMF coefficient-row correlation heat map (re-run after the "#Drake" marker cell).
loadings = np.array(snmf_fit.coef())
corr_matrix = np.corrcoef(loadings)
plt.imshow(corr_matrix)
plt.colorbar()

In [None]:
# Heat map of the raw SNMF coefficient matrix.
coef_matrix = np.array(snmf_fit.coef())
plt.imshow(coef_matrix)
plt.colorbar()

In [None]:
# Plot one row of the SNMF coefficient matrix.
# Row 20 was presumably chosen for the 40-component cycling fit; with fewer
# components (R_est = 2 for the simulations) that row does not exist, so fall
# back to the last available row instead of raising IndexError.
snmf_coef = np.array(snmf_fit.coef())
plt.plot(snmf_coef[min(20, snmf_coef.shape[0] - 1), :])
