In [None]:
# load dependencies
import h5py
import tensortools as tt # toolbox for TCA
import os
import numpy as np
import importlib as imp
import matplotlib.pyplot as plt

import utils
import load_preprocess_data

In [None]:
# recording session to analyze: base file name and the folder that holds it
fname = 'VJ_OFCVTA_7_260_D6'
fdir = 'C:\\2pData\\Vijay data\\VJ_OFCVTA_7_D8_trained\\'

# path of this session's `_sima_mc.h5` file inside that folder
sima_h5_path = os.path.join(fdir, '{}_sima_mc.h5'.format(fname))

# sampling rate of the recording (presumably in Hz -- TODO confirm)
fs = 5

In [None]:
# Window of each trial, in seconds relative to the ttl-onset/trial-onset
trial_start_end_seconds = np.array([-1, 3])

# Trial conditions to analyze
conditions = ['minus', 'plus_rewarded']

# Re-import the helper scripts so that edits made to them since the kernel
# started take effect (a loop, so no module repr is echoed as cell output)
for helper_module in (utils, load_preprocess_data):
    imp.reload(helper_module)

In [None]:
## extract trial data into xarray

# Number of segments to split trials over: single-trial plots in state space
# are noisy, so trials are broken up into groups and averaged to get a less
# noisy signal.
# NOTE(review): this is a float (5.0); code that needs a count (e.g. np.linspace)
# must cast it to int.
num_avg_groups = 5.0

data_dict = load_preprocess_data.load(
    fname,
    fdir,
    fs,
    trial_start_end_seconds,
    conditions,
    num_avg_groups,
)

In [None]:
# Assemble the tensor that is handed to the decomposition

condition = 'plus_rewarded'

# dims are (trial, yx_pix, time)
# (per-condition alternative: data_dict[condition]['xarr_flatten_xy'].data)
all_cond = data_dict['all_cond']
X = all_cond['flattenpix']

In [None]:
# Number of components (`rank`) computed by each decomposition.
num_ranks = 30

print('Run the same data through TCA twice to test the consistency of their models')

# Fit the same tensor twice with tt.ncp_hals (listed below as a nonnegative
# tensor decomposition). Each call returns a fit object whose `.factors`
# holds the fitted factor matrices.
hals_kwargs = dict(rank=num_ranks, verbose=False)
U = tt.ncp_hals(X, **hals_kwargs)
V = tt.ncp_hals(X, **hals_kwargs)

# Overlay the low-dimensional factors of the two fits on one figure
# (before any alignment).
fig, _, _ = tt.plot_factors(U.factors)
tt.plot_factors(V.factors, fig=fig)

print('Align the two runs/iterations and plot')

# Align the components of both fits to each other and report how similar
# the two models are.
sim = tt.kruskal_align(
    U.factors,
    V.factors,
    permute_U=True,
    permute_V=True,
)
print(sim)

# Same overlay as above, now with the aligned factors.
fig, ax, po = tt.plot_factors(U.factors)
tt.plot_factors(V.factors, fig=fig)

# Render every figure created in this cell.
plt.show()

# Test error and model similarity (within data set) as a function of rank

In [None]:
%time 

# Fit ensembles of tensor decompositions.

methods = (

  'cp_als',    # fits unconstrained tensor decomposition.

  'ncp_bcd',   # fits nonnegative tensor decomposition.

  'ncp_hals',  # fits nonnegative tensor decomposition.

)

ensembles = {}

"""
    fit_option arguments:
    
        tol : float, optional (default ``tol=1E-5``)
        Stopping tolerance for reconstruction error.

        max_iter : integer, optional (default ``max_iter = 500``)
        Maximum number of iterations to perform before exiting.
    
        min_iter : integer, optional (default ``min_iter = 1``)
        Minimum number of iterations to perform before exiting.

        max_time : integer, optional (default ``max_time = np.inf``)
        Maximum computational time before exiting.

        verbose : bool ``{'True', 'False'}``, optional (default ``verbose=True``)
        Display progress.
"""

for m in methods:

    ensembles[m] = tt.Ensemble(fit_method=m, fit_options=dict(tol=1e-3)) # tolerance: percent error threshold to terminate

    ensembles[m].fit(X, ranks=range(1, num_ranks), replicates=3)

# Plotting options for the unconstrained and nonnegative models.

plot_options = {

  'cp_als': {

    'line_kw': {

      'color': 'black',
      'label': 'cp_als',

    },

    'scatter_kw': {

      'color': 'black',

    },
  },

  'ncp_hals': {

    'line_kw': {

      'color': 'blue',
      'alpha': 0.5,
      'label': 'ncp_hals',

    },

    'scatter_kw': {

      'color': 'blue',
      'alpha': 0.5,

    },
  },

  'ncp_bcd': {

    'line_kw': {

      'color': 'red',
      'alpha': 0.5,
      'label': 'ncp_bcd',

    },

    'scatter_kw': {

      'color': 'red',
      'alpha': 0.5,

    },
  },
}



# Plot similarity and error plots.

plt.figure()

for m in methods:

    tt.plot_objective(ensembles[m], **plot_options[m])

plt.legend()



plt.figure()

for m in methods:

    tt.plot_similarity(ensembles[m], **plot_options[m])

plt.legend()



plt.show()

In [None]:
# Inspect one component of a single factor matrix from the second fit (V).
# `time_factor` and `component` were not defined anywhere in the visible cells,
# so this cell raised NameError on a fresh kernel; they are set explicitly here.
# X is described earlier as (trial, yx_pix, time), so the time mode is index 2.
time_factor = 2  # mode index of the time axis -- TODO confirm against X's layout
component = 0    # which component (column of the factor matrix) to display
V.factors[time_factor][:, component]

In [None]:

# Build 3-D state-space trajectories for every condition and draw them on one axes.
# NOTE(review): `update_lims` and `make_line_collection` are neither defined nor
# imported in the visible cells (only `import utils` is) -- confirm where they
# come from before running on a fresh kernel.

# container to store data relevant to the 3d plot
s_space_dict = {}
s_space_dict['line_cmaps'] = ['autumn', 'winter']

# determine alpha for each trial group (encoding time block).
# np.linspace needs an integer sample count; num_avg_groups is set as a float
# (5.0) upstream, so cast it here.
trial_group_alphas = np.linspace(0.3, 1, int(num_avg_groups))

# loop through conditions
for idx_condition, condition in enumerate(conditions):

    # set up variables for this condition
    s_space_dict[condition] = {}  # sub-dict for condition-specific data
    n = data_dict[condition]['num_samples']  # number of data points
    cmap_lc = s_space_dict['line_cmaps'][idx_condition]  # grab this condition's line cmap

    # set x, y, z, time data (first three columns of this condition's 'Xt')
    x = data_dict[condition]['Xt'][:, 0]
    y = data_dict[condition]['Xt'][:, 1]
    z = data_dict[condition]['Xt'][:, 2]
    svec = np.arange(0, n)  # sample vector; important for encoding color as time
    # USER DEFINE: which dimension to encode color; can be x, y, z, svec
    color_encode = svec

    # update x,y,z limits based on this condition's data
    if idx_condition == 0:
        # first condition initialises the limits
        xlim = [np.min(x), np.max(x)]
        ylim = [np.min(y), np.max(y)]
        zlim = [np.min(z), np.max(z)]
    else:
        xlim = update_lims([np.min(x), np.max(x)], xlim)
        ylim = update_lims([np.min(y), np.max(y)], ylim)
        zlim = update_lims([np.min(z), np.max(z)], zlim)

    ### Create line segment objects for ALL TRIAL-AVGED DATA ###
    s_space_dict[condition]['line_collect'] = make_line_collection(x, y, z, color_encode, cmap_lc)

    ### Create line segment objects for TRIAL-BLOCKED/GROUPED DATA ###
    s_space_dict[condition]['line_collect_trial'] = {}
    for idx, trial in enumerate(data_dict[condition]['Xt_trial']):

        # make the line segment object for this trial group
        s_space_dict[condition]['line_collect_trial'][idx] = make_line_collection(
            trial[:, 0], trial[:, 1], trial[:, 2],
            color_encode,
            cmap_lc,
            trial=True,
            alpha=trial_group_alphas[idx],
        )

        # update x,y,z limits based on this "trial's" data
        xlim = update_lims([np.min(trial[:, 0]), np.max(trial[:, 0])], xlim)
        ylim = update_lims([np.min(trial[:, 1]), np.max(trial[:, 1])], ylim)
        zlim = update_lims([np.min(trial[:, 2]), np.max(trial[:, 2])], zlim)

# create plot and set attributes
fig = plt.figure(figsize=(9, 7))
# Figure.gca(projection=...) was deprecated in Matplotlib 3.4 and later removed;
# add_subplot creates the 3-D axes explicitly on this new figure.
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim(xlim); ax.set_ylim(ylim); ax.set_zlim(zlim)
plt.title('PCA State Space')
ax.set_xlabel('PC0', fontsize=20); ax.set_ylabel('PC1', fontsize=20); ax.set_zlabel('PC2', fontsize=20);

# plot the line segments
# NOTE(review): `z` below is whatever the last iteration of the condition loop
# left behind; whether `zs` is actually used depends on the collection type
# returned by make_line_collection -- confirm.
for condition in conditions:

    # for all trial-avged data
    ax.add_collection3d(s_space_dict[condition]['line_collect'], zs=z, zdir='z')

    # for trial group data
    for trial_lc in s_space_dict[condition]['line_collect_trial'].values():

        ax.add_collection3d(trial_lc, zs=z, zdir='z')

# NOTE(review): labels are hard-coded for five groups of ten trials -- confirm
# they match num_avg_groups and the actual trial counts.
ax.legend(['All Trial Avg', 'Trial 1-10 Avg', 'Trial 11-20 Avg',
           'Trial 21-30 Avg', 'Trial 31-40 Avg', 'Trial 41-50 Avg']);