In [None]:
# load dependencies
import h5py
import tensortools as tt # toolbox for TCA
import os
import numpy as np
import importlib as imp
import matplotlib.pyplot as plt

# 3d state space plot
from matplotlib.colors import ListedColormap, BoundaryNorm    
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Line3DCollection

import utils
import load_preprocess_data

In [None]:
# --- recording to analyze ---
# session basename shared by the files of this recording
fname = 'VJ_OFCVTA_7_260_D6'
# folder that holds the session files
fdir = 'C:\\2pData\\Vijay data\\VJ_OFCVTA_7_D8_trained\\'
# <fdir>/<fname>_sima_mc.h5 (presumably the motion-corrected movie -- confirm)
sima_h5_path = os.path.join(fdir, f'{fname}_sima_mc.h5')

# sampling rate handed to the loader
fs = 5

In [None]:
# Trial window [start, end], in seconds relative to ttl-onset/trial-onset.
trial_start_end_seconds = np.array([-1, 3])
# Condition names handed to the loader.
conditions = ['minus', 'plus_rewarded']

# Re-import the helper modules so that edits made to them on disk take effect
# without restarting the kernel.
for _helper_module in (utils, load_preprocess_data):
    imp.reload(_helper_module)

In [None]:
## extract trial data into xarray

# Number of segments to split trials over: single-trial plots in state space
# are noisy, so trials are broken up into groups and averaged to get a less
# noisy signal.
num_avg_groups = 5.0

data_dict = load_preprocess_data.load(
    fname, fdir, fs,
    trial_start_end_seconds, conditions,
    num_avg_groups,
)

In [None]:
# Select the data tensor to decompose; it is read from data_dict rather than generated here.
condition = 'plus_rewarded' # condition name; reused by the plotting cells further down
R = 20 # dimensions/rank
# R is number of components

X = data_dict['all_cond']['flattenpix'].data # tensor passed to the decomposition
#dims are (trial, yx_pix, time)

In [None]:
# Fit tensor decomposition

# Fixed seed so the randomly initialised fit returns the same factors on
# every Restart & Run All.
tca_seed = 0

# Nonnegative CP decomposition fitted by hierarchical alternating least squares (HALS).
# The `rank` sets the number of components to be computed.
# The fitted factor matrices are available as `tca_out.factors`.
# The objective function is the Frobenius norm of the reconstruction error.
tca_out = tt.ncp_hals(X, rank=R, random_state=tca_seed, verbose=False)


In [None]:
def predict(factors):
    """Rebuild the full tensor from CP factor matrices.

    Parameters
    ----------
    factors : sequence of 2-D arrays
        One matrix per tensor mode, each of shape (mode_length, n_components).
        Every matrix must have the same number of columns.

    Returns
    -------
    numpy.ndarray
        Array whose shape is the tuple of mode lengths. It equals the sum over
        components of the outer product of that component's per-mode vectors.
    """
    max_rank = factors[0].shape[-1]
    prediction = np.zeros([mode_factor.shape[0] for mode_factor in factors])

    for rank in range(max_rank):
        for idx, mode_factor in enumerate(factors):
            if idx == 0:
                # start from the first mode's vector, shape (I,)
                outer_prod = mode_factor[:, rank]
            else:
                # Append a trailing axis and broadcast against the next mode's
                # vector: each step adds exactly one dimension, so any number
                # of modes (including a single one) is handled uniformly.
                outer_prod = outer_prod[..., None] * mode_factor[:, rank]
        prediction += outer_prod

    return prediction


In [None]:
def obj_frob_norm(X, pred):
    """Return the norm of the residual (X - pred) divided by the norm of X.

    Both norms are `np.linalg.norm` with default arguments. The result is 0
    when `pred` equals `X`.
    """
    residual_norm = np.linalg.norm(X - pred)
    data_norm = np.linalg.norm(X)
    return residual_norm / data_norm

# tca_out.factors.factors : contains the outer product vectors
# predict(tca_out.factors.factors) is equivalent to tca_out.factors.full()
# Relative reconstruction error of the fitted model (0 would be a perfect fit);
# left as the last expression so the notebook shows it as the cell output.
obj_frob_norm(X, predict(tca_out.factors.factors))

In [None]:
# Report the shape of each factor matrix.
# Mode order: 0 = trial factor, 1 = pixel factor, 2 = time factor.
for mode_index in (0, 1, 2):
    print(tca_out.factors[mode_index].shape)



In [None]:
# Draw the fitted factors; the first call creates the figure.
fig, _, _ = tt.plot_factors(tca_out.factors)
# NOTE(review): this second call draws the same factors again onto the same
# figure -- looks like a leftover from overlaying two fits; confirm it is intended.
tt.plot_factors(tca_out.factors, fig=fig);

In [None]:
def plot_img_vectorized_component(n_columns, data, original_dims):
    """Make a figure with one heatmap subplot per row of `data`.

    Parameters
    ----------
    n_columns : int
        Number of subplot columns; as many rows as needed are created.
    data : 2-D array
        One flattened image per row (data.shape[0] images in total).
    original_dims : tuple of int
        Shape each row is reshaped to before it is drawn.

    All panels share one colour scale (global min/max of `data`) and a single
    colorbar is attached to the whole grid.
    """
    clims = [np.min(data), np.max(data)]

    num_comps = data.shape[0]
    n_rows = int(np.ceil(num_comps/n_columns))

    # squeeze=False always returns a 2-D (n_rows, n_columns) axes array, so a
    # single row and/or a single column is handled the same way as a full grid.
    fig, ax = plt.subplots(nrows=n_rows, ncols=n_columns, squeeze=False,
                           figsize = (15, n_rows*4))
    panels = ax.ravel() # row-major order: component i is drawn in panel i

    for iComp in range(num_comps):

        title = f"Comp #{iComp}"
        pc_pixel_weights = data[iComp,:].reshape(original_dims)

        # the returned object is kept so it can drive the shared colorbar below
        im = utils.subplot_heatmap(panels[iComp], title, pc_pixel_weights, cmap = 'Reds', clims = clims)

    # hide the left-over panels when num_comps does not fill the last row
    for spare_panel in panels[num_comps:]:
        spare_panel.set_axis_off()

    fig.colorbar(im, ax = ax, shrink = 0.5)

In [None]:
pixel_factor = 1 # index of the pixel mode in tca_out.factors
time_factor = 2 # index of the time mode in tca_out.factors

n_columns = 3 # number of subplot columns in the component grid

original_dims = data_dict[condition]['data'].shape[1:3] # data dim in format of trial,y,x,sample

# Transpose so that each row is one component: the plotting function iterates
# over data.shape[0] and reshapes each row to original_dims.
plot_img_vectorized_component(n_columns, np.transpose(tca_out.factors[pixel_factor]), original_dims)

In [None]:
component = 0 # which TCA component's time course to show

# column `component` of the time-mode factor matrix
tseries = tca_out.factors[time_factor][:,component]

# Explicit figure/axes so the plot carries its own title and labels.
ts_fig, ts_ax = plt.subplots()
ts_ax.plot(data_dict['trial_tvec'], tseries)
ts_ax.set_title(f'TCA time factor: component {component}')
ts_ax.set_xlabel('Trial time (trial_tvec)')
ts_ax.set_ylabel('Factor weight'); # trailing ';' hides the text repr in the cell output

In [None]:
def make_line_collection(x, y, z, color_encode, cmap, trial = False, alpha = 1.0):
    """Build a colour-mapped 3D line collection from coordinate vectors.

    Parameters
    ----------
    x, y, z : 1-D sequences of equal length
        Coordinates of the consecutive points of the trajectory.
    color_encode : array-like
        Values mapped through `cmap` to colour the segments; the colour scale
        spans min(color_encode) to max(color_encode).
    cmap : str
        Name of a matplotlib colormap.
    trial : bool, optional
        False (default): style for the all-trial average (dotted, width 4).
        True: style for a trial-group line (width 1.5, transparency `alpha`).
    alpha : float, optional
        Transparency applied only when `trial` is True.

    Returns
    -------
    Line3DCollection
        The styled collection, ready to be added to a 3D axes.
    """
    # Pair consecutive points into segments: n points -> n-1 segments,
    # each holding its start and end point (shape (2, 3)).
    points = np.array([x, y, z]).T.reshape(-1, 1, 3)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # Create the 3D-line collection object
    lc = Line3DCollection(segments, cmap=plt.get_cmap(cmap),
                        norm=plt.Normalize(np.min(color_encode), np.max(color_encode))) # set LUT for segment colors
    lc.set_array(color_encode) # set the dimension and values for color encoding

    if trial:
        # trial group lines encode the segment number in alpha
        lc.set_alpha(alpha)
        lc.set_linewidth(1.5)
    else:
        # all-trial-averaged lines are drawn thick and dotted
        lc.set_linestyle(':')
        lc.set_linewidth(4)

    return lc

In [None]:

# container to store data relevant to the 3d plot
s_space_dict = {}
s_space_dict['line_cmaps'] = ['autumn','winter']

# determine alpha for each trial (encoding time block)
# np.linspace requires an integer sample count and num_avg_groups is stored as a float
trial_group_alphas = np.linspace(0.3, 1, int(num_avg_groups))

# conditions drawn in the state-space plot
plot_conditions = [condition]


def _expand_lims(new_lims, current_lims):
    """Return the [min, max] limits that cover both `new_lims` and `current_lims`."""
    # Stands in for `update_lims`, which is not defined anywhere in this notebook.
    return [min(new_lims[0], current_lims[0]), max(new_lims[1], current_lims[1])]


# loop through conditions
for idx_condition, cond in enumerate(plot_conditions):

    # set up variables for this condition
    s_space_dict[cond] = {} # sub-dict for condition-specific data
    n = data_dict[cond]['num_samples'] # number of data points
    cmap_lc = s_space_dict['line_cmaps'][idx_condition] # grab this condition's line cmap

    # set x,y,z, time data: the first three components of the time factor
    # NOTE(review): these do not depend on the condition, so every condition
    # would draw the same trajectory -- confirm this is intended.
    x = tca_out.factors[time_factor][:,0]
    y = tca_out.factors[time_factor][:,1]
    z = tca_out.factors[time_factor][:,2]
    svec = np.arange(0,n) # sample vector; important for encoding color as time
    # USER DEFINE: which dimension to encode color; can be x, y, z, svec
    color_encode = svec

    # update x,y,z limits based on this condition's data
    if idx_condition == 0:
        xlim = [np.min(x), np.max(x)]
        ylim = [np.min(y), np.max(y)]
        zlim = [np.min(z), np.max(z)]
    else:
        xlim = _expand_lims([np.min(x), np.max(x)], xlim)
        ylim = _expand_lims([np.min(y), np.max(y)], ylim)
        zlim = _expand_lims([np.min(z), np.max(z)], zlim)

    ### Create line segment objects for ALL TRIAL-AVGED DATA ###
    s_space_dict[cond]['line_collect'] = make_line_collection(x, y, z, color_encode, cmap_lc)

    ### Create line segment objects for TRIAL-BLOCKED/GROUPED DATA ###
#     s_space_dict[cond]['line_collect_trial'] = {}
#     for idx, trial in enumerate(data_dict[cond]['Xt_trial']):

#         # make the line segment object for this trial group
#         s_space_dict[cond]['line_collect_trial'][idx] = make_line_collection(trial[:,0], trial[:,1], trial[:,2],
#                                                                              color_encode,
#                                                                              cmap_lc,
#                                                                              trial = True,
#                                                                              alpha = trial_group_alphas[idx])

#         # update x,y,z limits based on this "trial's" data
#         xlim = _expand_lims([np.min(trial[:,0]), np.max(trial[:,0])], xlim)
#         ylim = _expand_lims([np.min(trial[:,1]), np.max(trial[:,1])], ylim)
#         zlim = _expand_lims([np.min(trial[:,2]), np.max(trial[:,2])], zlim)

# create plot and set attributes
fig = plt.figure(figsize = (9,7))
# add_subplot replaces fig.gca(projection='3d'): gca() no longer accepts
# keyword arguments in current Matplotlib releases.
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim(xlim); ax.set_ylim(ylim); ax.set_zlim(zlim)
plt.title('TCA State Space')
ax.set_xlabel('TC0', fontsize = 20); ax.set_ylabel('TC1', fontsize = 20); ax.set_zlabel('TC2', fontsize = 20);

# plot the line segments
for cond in plot_conditions:

    # for all trial-avged data
    ax.add_collection3d(s_space_dict[cond]['line_collect'], zs=z, zdir='z')

#     # for trial group data
#     for trial_lc in s_space_dict[cond]['line_collect_trial'].values():

#         ax.add_collection3d(trial_lc, zs=z, zdir='z')

# Only the all-trial average is drawn, so only that entry is labelled.
# If the trial-group lines above are re-enabled, append their labels:
# 'Trial 1-10 Avg', 'Trial 11-20 Avg', 'Trial 21-30 Avg', 'Trial 31-40 Avg', 'Trial 41-50 Avg'
ax.legend(['All Trial Avg']);