# jPCA
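This notebook reproduces the core jPCA algorithm: fitting a skew-symmetric (purely rotational) linear dynamical system to PCA-reduced neural population data.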

## Load in example data
For development, we will focus on just one condition. 

In [2]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from math import sqrt
from scipy import optimize

In [5]:
from scipy.io import loadmat

# `Data` is a MATLAB struct array; [0][0][0] pulls the first field of its first
# element -- presumably the data matrix of a single condition (TODO confirm field).
x = loadmat("../data/exampleData.mat")['Data'][0][0][0]
X = np.array(x)

## 1. Preprocess using PCA
jPCA also includes a normalization preprocessing step, but it is specific to the data rather than to the algorithm, so it is left out for now.

In [4]:
def pca_preproc(X, k):
    """Reduce X to its top-k principal components.

    Parameters
    ----------
    X : array_like, shape (n_samples, n_features)
    k : int
        Number of principal components to keep.

    Returns
    -------
    (ndarray, ndarray)
        The data projected onto the k components, and the fraction of
        variance explained by each of those components.
    """
    model = PCA(n_components=k).fit(X)  # fit() returns the fitted estimator
    projected = model.transform(X)
    return projected, model.explained_variance_ratio_

In [6]:
np.shape(X)

(61, 218)

In [11]:
# Number of principal components kept before fitting the rotational dynamics.
N_COMPONENTS = 6
# Keep the explained-variance ratio under a real name: unpacking into `_` would
# overwrite IPython's output-history variable and stop it from updating.
X_red, explained_variance_ratio = pca_preproc(X, k=N_COMPONENTS)

In [12]:
np.shape(X_red)

(61, 6)

## 2. Get Discrete Derivative of X_red
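Keep every state except the final one as `X_prestate`: row $t$ of `dX` is $x_{t+1} - x_t$, and the regression in step 3 pairs it with the state $x_t$ that precedes it.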

In [14]:
# Forward difference: row t of dX is X_red[t+1] - X_red[t] ...
dX = X_red[1:] - X_red[:-1]
# ... and is paired with the state it starts from (all rows but the last).
X_prestate = X_red[:-1, :]

In [22]:
# One shape per line; both should have the same number of rows.
print(X_prestate.shape, dX.shape, sep="\n")

(60, 6)
(60, 6)


## 3. Solve for M_skew
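The most important step: find the skew-symmetric matrix that best explains the observed changes,

$$M_{skew} = \arg\min_{M^\top = -M} \lVert \Delta X - X_{pre} M \rVert_F^2,$$

where $X_{pre}$ is `X_prestate` and $\Delta X$ is `dX`. The optimization starts from the skew-symmetric part of the unconstrained least-squares solution $M_0$.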

In [25]:
# Unconstrained least-squares fit of dX ~= X_prestate @ M0. Only the solution is
# kept; unpacking the other outputs into `_` would overwrite IPython's `_` history.
M0 = np.linalg.lstsq(X_prestate, dX, rcond=None)[0]
# Skew-symmetric part of M0: the starting point for the constrained fit below.
M0_skew = 0.5 * (M0 - M0.T)

In [29]:
# Inspect the initial skew-symmetric estimate. The stray `M0.shape` expression
# that used to sit here produced no output and has been removed.
print(M0_skew)

[[ 0.         -0.09082191  0.0214932  -0.02473931 -0.11288055  0.10015587]
 [ 0.09082191  0.          0.04595647  0.03914008  0.09296161 -0.01832981]
 [-0.0214932  -0.04595647  0.         -0.14769268  0.02202633  0.03025298]
 [ 0.02473931 -0.03914008  0.14769268  0.         -0.14714837  0.06671926]
 [ 0.11288055 -0.09296161 -0.02202633  0.14714837  0.          0.19425025]
 [-0.10015587  0.01832981 -0.03025298 -0.06671926 -0.19425025  0.        ]]


In [36]:
def mat2vec(mat):
    """Flatten an array to 1-D in column-major (Fortran) order.

    Reversing the axes and then flattening row-major visits the elements in
    exactly the same sequence as ``flatten('F')``. The result is always a copy.
    """
    reversed_axes = mat.T
    return reversed_axes.flatten()


def vec2mat(vec):
    """Rebuild a square matrix from its column-major flattening (inverse of mat2vec).

    Element ``vec[i + n*j]`` becomes ``M[i, j]``.

    Parameters
    ----------
    vec : array_like
        Exactly n*n elements.

    Returns
    -------
    ndarray, shape (n, n)

    Raises
    ------
    ValueError
        If the number of elements is not a perfect square. Previously such
        input was silently reshaped to a non-square matrix.
    """
    vec = np.asarray(vec)
    size = vec.size
    # round() guards against sqrt landing just below an integer for large sizes.
    side = int(round(sqrt(size)))
    if side * side != size:
        raise ValueError(
            "vec2mat expects a perfect-square number of elements, got %d" % size)
    return np.reshape(vec, (side, side), order="F")

In [31]:
# scipy.optimize works on 1-D parameter vectors: flatten column-major.
m_skew = mat2vec(mat=M0_skew)

In [61]:
M0_roundtrip = vec2mat(m_skew)
# The flatten/unflatten helpers must be exact inverses; otherwise the optimizer
# below would be fitting a scrambled matrix.
assert np.array_equal(M0_roundtrip, M0_skew), "mat2vec/vec2mat do not round-trip"
M0_roundtrip

array([[ 0.        , -0.09082191,  0.0214932 , -0.02473931, -0.11288055,
         0.10015587],
       [ 0.09082191,  0.        ,  0.04595647,  0.03914008,  0.09296161,
        -0.01832981],
       [-0.0214932 , -0.04595647,  0.        , -0.14769268,  0.02202633,
         0.03025298],
       [ 0.02473931, -0.03914008,  0.14769268,  0.        , -0.14714837,
         0.06671926],
       [ 0.11288055, -0.09296161, -0.02202633,  0.14714837,  0.        ,
         0.19425025],
       [-0.10015587,  0.01832981, -0.03025298, -0.06671926, -0.19425025,
         0.        ]])

In [67]:
def optimize_skew(m_skew, X_prestate, dX):
    """Fit a skew-symmetric M minimizing ||dX - X_prestate @ M||_F^2.

    The optimization variable is the full K x K matrix flattened in
    column-major order (the mat2vec / vec2mat convention). The gradient is
    projected onto the skew-symmetric subspace, and the starting point is
    projected too, so every iterate stays skew-symmetric.

    Parameters
    ----------
    m_skew : array_like, shape (K*K,)
        Column-major flattened initial guess. Only its skew-symmetric part is
        used.
    X_prestate : array_like, shape (T, K)
        States x_t.
    dX : array_like, shape (T, K)
        Changes x_{t+1} - x_t, row-aligned with X_prestate.

    Returns
    -------
    scipy.optimize.OptimizeResult
        ``result.x`` is the column-major flattened skew-symmetric solution.
    """
    X_prestate = np.asarray(X_prestate)
    dX = np.asarray(dX)
    k = X_prestate.shape[1]

    def to_mat(x):
        # Column-major, identical to vec2mat for a K*K vector; K comes from the data.
        return np.reshape(x, (k, k), order="F")

    def objective(x, X_prestate, dX):
        residual = dX - X_prestate @ to_mat(x)
        return np.linalg.norm(residual) ** 2

    def derivative(x, X_prestate, dX):
        residual = dX - X_prestate @ to_mat(x)
        # The Euclidean gradient w.r.t. M is -2 G with G = X^T R. Its projection
        # onto the skew-symmetric subspace is G^T - G. The earlier version
        # returned twice this, which disagreed with `objective`.
        G = X_prestate.T @ residual
        return (G.T - G).flatten(order="F")

    M_init = to_mat(np.asarray(m_skew, dtype=float))
    # With a projected gradient, any symmetric part of the start would never
    # change, so drop it up front.
    M_init = 0.5 * (M_init - M_init.T)
    return optimize.minimize(objective, M_init.flatten(order="F"),
                             jac=derivative, args=(X_prestate, dX))

In [69]:
result = optimize_skew(m_skew, X_prestate, dX)
if not result.success:
    # Don't silently continue with an unconverged fit; show the optimizer's reason.
    print("WARNING: skew-symmetric fit did not converge:", result.message)
m_star = result.x

In [78]:
M_star = vec2mat(m_star)
print(M_star)
# Sanity check: the fitted matrix should satisfy M^T == -M (up to tolerance).
np.all(np.isclose(M_star.T, -M_star))

[[ 0.         -0.09093967  0.00462909 -0.0024506  -0.00979637  0.0002772 ]
 [ 0.09093967  0.          0.03637784  0.01857033  0.01674314 -0.00111542]
 [-0.00462909 -0.03637784  0.         -0.14404759  0.01828862  0.02939032]
 [ 0.0024506  -0.01857033  0.14404759  0.         -0.10559591  0.0274157 ]
 [ 0.00979637 -0.01674314 -0.01828862  0.10559591  0.          0.19456586]
 [-0.0002772   0.00111542 -0.02939032 -0.0274157  -0.19456586  0.        ]]


True

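## 4. Get components
The eigenvalues of the fitted skew-symmetric matrix are (numerically) purely imaginary and come in complex-conjugate pairs. Each conjugate pair of eigenvectors spans one plane of rotational dynamics.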

In [79]:
# Keep the decomposition under a name so later cells can reuse it, then display it.
eig_result = np.linalg.eig(M_star)
eig_result

(array([-4.00043713e-13+0.23058957j, -4.00043713e-13-0.23058957j,
        -9.31998402e-13+0.08349076j, -9.31998402e-13-0.08349076j,
         1.33208028e-12+0.1480101j ,  1.33208028e-12-0.1480101j ]),
 array([[ 0.00569288+0.05562876j,  0.00569288-0.05562876j,
         -0.67743378+0.j        , -0.67743378-0.j        ,
          0.18479485-0.06170388j,  0.18479485+0.06170388j],
        [ 0.05649337-0.02435237j,  0.05649337+0.02435237j,
          0.00114418+0.63178732j,  0.00114418-0.63178732j,
         -0.08982334-0.2983131j , -0.08982334+0.2983131j ],
        [-0.20602622+0.00834137j, -0.20602622-0.00834137j,
          0.17465813+0.110003j  ,  0.17465813-0.110003j  ,
          0.64410985+0.j        ,  0.64410985-0.j        ],
        [ 0.07264157+0.44550322j,  0.07264157-0.44550322j,
          0.09004153-0.21667149j,  0.09004153+0.21667149j,
          0.03005386-0.49019674j,  0.03005386+0.49019674j],
        [ 0.66791704+0.j        ,  0.66791704-0.j        ,
          0.04918575+0.020569