In [9]:
### There are two losses: one for training, one for validation
### Training for 1000 iterations reduces only the training loss
### https://github.com/zhd96/pi-vae/blob/main/examples/pi-vae_macaque_data.ipynb
### mcp = ModelCheckpoint(model_chk_path, monitor="val_loss", save_best_only=True, save_weights_only=True)
### The validation loss is critical here
import sys, os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jl
import cebra.datasets
from cebra import CEBRA
import scipy.io as sio
from sklearn.linear_model import LinearRegression,LogisticRegression
import sklearn.metrics
import torch

sys.path.insert(0, './third_party/pivae')
import pivae_code.datasets, pivae_code.conv_pi_vae, pivae_code.pi_vae
# ---- Data split: fractions of the time bins of each recording ----
train_percent = 0.60
valid_percent = 0.20
# NOTE(review): test_percent is not read by the cells visible here; the test
# split there is whatever remains after the train and validation bins.
test_percent = 0.20

# ---- Trial layout and model / training settings ----
dur = 35             # bins per trial; used to align split boundaries and to reshape into trials
n_conds = 8          # number of direction classes looped over when plotting trial averages
seed = 42
embed_dimension = 3  # passed to the model as dim_z
batch_size = 200     # target number of samples per batch
np.random.seed(seed)
iterations = 50      # passed to the model as the number of epochs

# Maps an angle in degrees to a class index 0..7 (both -180 and 180 map to 4).
angle_to_new_value = {-180: 4,-135: 5,-90: 6,-45: 7,0: 0,45: 1,90: 2,135: 3,180: 4}

directory = "./data/SU_3S1/"
# os.listdir returns entries in arbitrary order; sort so that sessions are
# always processed in the same order (the global NumPy seed is set only once,
# so processing order affects any NumPy randomness consumed per session).
files = sorted(os.listdir(directory))


In [10]:
def dataset_2D_to_3D(dataset_2D, receptive_field_size=10):
    """Expand a 2-D array into zero-padded sliding windows along its first axis.

    With ``half = receptive_field_size // 2``, the result satisfies
    ``out[t, n, k] == dataset_2D[t - half + k, n]`` whenever ``t - half + k``
    is a valid row index, and ``0`` otherwise (i.e. windows that run past the
    start or end of the data are zero-padded).

    Parameters
    ----------
    dataset_2D : array-like, shape (time_bins, neurons)
        Input samples; must be exactly two-dimensional.
    receptive_field_size : int, default 10
        Number of slots in each window. The default reproduces the previously
        hard-coded value.

    Returns
    -------
    numpy.ndarray of float64, shape (time_bins, neurons, receptive_field_size)
    """
    dataset_2D = np.asarray(dataset_2D)
    time_bins, neurons = dataset_2D.shape
    half_window = receptive_field_size // 2
    dataset_3D = np.zeros((time_bins, neurons, receptive_field_size))

    # One vectorised copy per window slot instead of a Python loop over every
    # (time bin, neuron) pair: slot k holds the input shifted by k - half_window.
    for slot in range(receptive_field_size):
        shift = slot - half_window
        # Rows t for which the source row t + shift exists.
        first = max(0, -shift)
        last = min(time_bins, time_bins - shift)
        if first >= last:
            # Recording shorter than this shift: the slot stays all zeros.
            continue
        dataset_3D[first:last, :, slot] = dataset_2D[first + shift:last + shift, :]
    return dataset_3D

def to_batch_list(x, y, batch_size):
    """Split samples and labels into lists of consecutive batches.

    Parameters
    ----------
    x : numpy.ndarray
        Samples. Singleton axes are removed first; if the result is 3-D its
        last two axes are swapped before splitting.
    y : numpy.ndarray
        Labels, split along the first axis.
    batch_size : int
        Target number of rows per batch. The number of batches is
        ``len // batch_size`` (at least 1), and ``np.array_split`` spreads any
        remainder, so individual batches may be slightly larger.

    Returns
    -------
    (list of numpy.ndarray, list of numpy.ndarray)
        Batches of ``x`` and batches of ``y``.
    """
    # NOTE(review): squeeze() also drops the first axis when it has length 1
    # (and a length-1 neuron axis) -- confirm callers never pass such input.
    x = x.squeeze()
    if x.ndim == 3:
        x = x.transpose(0, 2, 1)
        print(x.shape)
    # A split smaller than batch_size used to yield 0 sections, which makes
    # np.array_split raise ValueError; fall back to a single batch instead.
    n_x_batches = max(1, len(x) // batch_size)
    n_y_batches = max(1, len(y) // batch_size)
    x_batch_list = np.array_split(x, n_x_batches)
    print(n_x_batches)
    y_batch_list = np.array_split(y, n_y_batches)
    return x_batch_list, y_batch_list

def custom_data_generator(x_all, u_one_hot):
    """Yield ``([x_batch, u_batch], None)`` pairs forever.

    Walks ``x_all`` from first to last batch, pairing each entry with the
    entry of ``u_one_hot`` at the same index, then starts over. The generator
    never terminates on its own; the consumer decides how many items to draw.
    """
    while True:
        for batch_idx, x_batch in enumerate(x_all):
            yield ([x_batch, u_one_hot[batch_idx]], None)
            
      

In [11]:
# Per-session pipeline. For every entry of `files` whose name contains "Han":
#   1. load the .mat file and build a 3-column label array,
#   2. split neural data and labels chronologically into train / valid / test,
#   3. fit a conv_vae_mdl model on the training batches with validation batches,
#   4. save a loss-curve PDF, an embedding PDF and a decoding PDF,
#   5. save train/test embeddings and labels to an .npz file.
for file in files:
    if "Han" in file:   
        mat_contents = sio.loadmat(os.path.join(directory, file))
        # Output names reuse the first 7 and last 6 characters of the part of
        # the file name that precedes "_neural_con_dis_index".
        filename_parts = file.split("_neural_con_dis_index")
        new_filename = filename_parts[0][:7]+filename_parts[0][-6:] + "_embed_S1.npz"
        file_save = os.path.join(directory, new_filename)
        neural = mat_contents['neural_S1']
        # NOTE(review): the reason for the factor 10 is not visible here -- confirm units.
        continuous_index_2d = mat_contents['continuous_index']*10
        discrete_index = mat_contents['discrete_index']
        # Labels larger than 10 are treated as angles in degrees and converted
        # to class indices through angle_to_new_value.
        if np.max(discrete_index)>10:
            print('map angle to new value')
            vectorized_map = np.vectorize(lambda x: angle_to_new_value[x])
            discrete_index = vectorized_map(discrete_index)
        # Append class index * 45 as the last label column; astype(int) widens
        # the stored type before multiplying.
        continuous_index = np.hstack((continuous_index_2d, discrete_index.astype(int)*45)) ## S1 data is <uint8> 
        N_bins, N_neurons = neural.shape
        # Both split boundaries are rounded down to a multiple of `dur`, which
        # the trial-wise reshapes below rely on for the training split.
        # Assumes every trial spans exactly `dur` consecutive bins -- TODO confirm.
        train_end = int(N_bins * train_percent)// dur * dur
        valid_end = train_end + int(N_bins * valid_percent)
        valid_end = valid_end// dur * dur

        # Chronological split: [0, train_end) / [train_end, valid_end) / rest.
        train_neural = neural[:train_end, :]
        Y_train = continuous_index[:train_end, :]
        valid_neural = neural[train_end:valid_end, :]
        Y_valid = continuous_index[train_end:valid_end, :]
        test_neural = neural[valid_end:, :]
        Y_test = continuous_index[valid_end:, :]
        # Show which values occur in the third label column of the training split.
        print(np.unique(Y_train[:, 2]))
        # Turn each (bins, neurons) split into zero-padded windows per bin.
        X_train = dataset_2D_to_3D(train_neural)
        X_valid = dataset_2D_to_3D(valid_neural)
        X_test = dataset_2D_to_3D(test_neural)

        train_x, train_u = to_batch_list(X_train, Y_train, batch_size)
        train_loader = custom_data_generator(train_x, train_u)

        valid_x, valid_u = to_batch_list(X_valid, Y_valid, batch_size)
        valid_loader = custom_data_generator(valid_x, valid_u)

        test_x, test_u = to_batch_list(X_test, Y_test, batch_size)
        # NOTE(review): test_loader is not used again in this cell.
        test_loader  = custom_data_generator(test_x, test_u)
        
        conv_pivae = pivae_code.conv_pi_vae.conv_vae_mdl(
                dim_x = N_neurons, ### number of neurons
                dim_z = embed_dimension, ### embedding dimension
                dim_u = 3, ### label dimension; here the 3 columns of continuous_index
                time_window=10, ### NOTE(review): presumably must equal the window length used by dataset_2D_to_3D -- confirm
                gen_nodes=60,
                n_blk=2,
                mdl="poisson",
                disc=False,
                learning_rate=0.00025)      
        # The generators are infinite, so the step counts define one pass over
        # the batch lists per epoch.
        s_n = conv_pivae.fit_generator(
            train_loader, ### will call "def custom_data_generator" 
            steps_per_epoch=len(train_x), ### number of training batches
            epochs=iterations, ### iterations
            verbose=1,
            validation_data = valid_loader,
            validation_steps = len(valid_x))
        
        # Re-assemble the batches (original order) and run the fitted model on
        # the training and test splits.
        X = np.concatenate(train_x) ### (Xbins, 10=5ms-offset+5ms-offset, Xneurons)
        labels = np.concatenate(train_u) ### (Xbins, position+direction)
        outputs_train = conv_pivae.predict([X, labels])
        X = np.concatenate(test_x) 
        labels = np.concatenate(test_u) 
        outputs_test = conv_pivae.predict([X, labels])
        ### Outputs: post_mean, post_log_var, z_sample,fire_rate, lam_mean, lam_log_var, z_mean, z_log_var
        # The first output (post_mean per the list above) is used as the embedding.
        embed_train = outputs_train[0]
        embed_test = outputs_test[0]
        
        # ---- Figure 1: training and validation loss per epoch ----
        fig = plt.figure(figsize=(6,5))
        ax = plt.subplot(111)
        val_loss = s_n.history['val_loss'][:]
        # NOTE(review): the next assignment is immediately overwritten by the one after it.
        loss = s_n.history['loss'][:]
        loss = np.array(s_n.history['loss'])
        # Mean training loss over the last 10 epochs is reported in the title.
        loss_stable = loss[-10:]
        plt.plot(val_loss, c='deepskyblue', label='val-loss')
        plt.plot(loss, c='blue', label='loss')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xlabel('Iterations')
        ax.set_ylabel('piVAE Loss')
        plt.legend(bbox_to_anchor=(0.5,0.3), frameon = False )
        plt.title('itr='+str(iterations)+' loss='+str(int(np.mean(loss_stable))))
        new_filename = filename_parts[0][:7]+filename_parts[0][-6:] + "_training_loss.pdf"
        output_path = os.path.join(directory, new_filename)
        plt.savefig(output_path)
        plt.close(fig)
        
        # ---- Figure 2: 3-D training embedding (left: every bin coloured by
        # the third label column / 360; right: per-class trial averages) ----
        ## %matplotlib notebook
        ## fig = plt.figure(figsize=(4, 2), dpi=250)
        fig = plt.figure(figsize=(6,5))
        ax = fig.add_subplot(121, projection='3d')
        norm = plt.Normalize(vmin=0, vmax=1)
        ax.scatter(embed_train[:, 0],embed_train[:, 1],embed_train[:, 2],
                       c=Y_train[:,2]/360, cmap=plt.cm.hsv,edgecolors='none',norm=norm,alpha=0.75,s=1)
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])

        ax = plt.subplot(122, projection='3d')
        idx1, idx2, idx3= 0, 1, 2
        for i in range(n_conds):
            # Select bins of class i, fold them into (trials, dur, dims) and
            # average over trials. The reshape requires the number of selected
            # bins to be a multiple of `dur`.
            direction_trial = (Y_train[:,2]//45 == i)
            trial_avg = embed_train[direction_trial, :].reshape(-1,dur,embed_dimension).mean(axis=0)
            ax.scatter(trial_avg[:, idx1], trial_avg[:, idx2], trial_avg[:, idx3],
                       color=plt.cm.hsv(1 / n_conds * i), edgecolors='none', alpha=0.75, s=2)
            ax.plot(trial_avg[:, idx1],trial_avg[:, idx2], trial_avg[:, idx3],
                color=plt.cm.hsv(1 / n_conds * i),linewidth=0.5, alpha=0.75)  
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        new_filename = filename_parts[0][:7]+filename_parts[0][-6:] + "_Embedding.pdf"
        output_path = os.path.join(directory, new_filename)
        plt.savefig(output_path)
        plt.close(fig)
        
        # ---- Decoding from the training embedding (fit and scored on the
        # same training data) ----
        # Label columns 0:2 are treated as per-bin velocity; a cumulative sum
        # within each trial turns them into a trajectory.
        train_trial = int(Y_train.shape[0]/dur)
        velocity_reshaped = Y_train[:, 0:2].reshape(train_trial, dur, 2)
        locations = np.cumsum(velocity_reshaped, axis=1)
        truth_XY = locations.reshape(train_trial*dur, 2)
        
        X = embed_train
        y = Y_train[:, 0:2]
        reg = LinearRegression().fit(X, y) ### n_jobs = 8 >>> unnecessary
        pred_vel = reg.predict(X) 
        # Third label column is used as the class target for direction decoding.
        y_C = Y_train[:, 2]
        LogisticReg = LogisticRegression(max_iter=500, multi_class='multinomial', solver='lbfgs')
        LogisticReg.fit(X, y_C)
        pred_dir = LogisticReg.predict(X)

        # Integrate the predicted velocity the same way as the ground truth.
        velocity_reshaped = pred_vel.reshape(train_trial, dur, 2)
        locations = np.cumsum(velocity_reshaped, axis=1)
        pred_XY = locations.reshape(train_trial*dur, 2)

        posi_r2 = sklearn.metrics.r2_score(truth_XY, pred_XY) ### proportion of total variation explained by model
        vel_r2 = sklearn.metrics.r2_score(Y_train[:, 0:2], pred_vel) ### == reg.score(X, y)

        # Absolute angular error, wrapped so that it never exceeds 180.
        differences = 1*abs(pred_dir - Y_train[:, 2])
        angle_diffs = np.where(differences > 180, 360 - differences, differences)

        # ---- Figure 3: true (left) vs. predicted (right) trajectories ----
        fig = plt.figure(figsize=(10, 5))
        ax1 = plt.subplot(121)
        ax1.scatter(truth_XY[:, 0], truth_XY[:, 1], alpha=1, 
                    color=plt.cm.hsv(1/360*Y_train[:, 2]), s=0.3)
        ax1.spines["right"].set_visible(False)
        ax1.spines["top"].set_visible(False)
        plt.title('Var-R2 vel='+str(round(reg.score(X, y), 3))+' dir='+str(round(LogisticReg.score(X, y_C), 3))\
                 +' MAE='+str(round(np.mean(angle_diffs),1)))
        ax2 = plt.subplot(122)
        ax2.scatter(pred_XY[:, 0], pred_XY[:, 1], alpha=1, color=plt.cm.hsv(1/360*pred_dir), s=0.3)
        ax2.spines["right"].set_visible(False)
        ax2.spines["top"].set_visible(False)

        plt.title('True vs Pred-R2 vel='+str(round(vel_r2, 3))+' pos='+str(round(posi_r2, 3)))
        new_filename = filename_parts[0][:7]+filename_parts[0][-6:] + "_train_variance_"+str(round(posi_r2, 3))+".pdf"
        output_path = os.path.join(directory, new_filename)
        plt.savefig(output_path)
        plt.close()

        # Persist embeddings and labels. NOTE(review): the key names start with
        # "cebra_" although the embeddings stored here come from conv_pivae;
        # presumably kept for compatibility with downstream readers -- confirm.
        np.savez(file_save,
                cebra_veldir_train=embed_train,
                 cebra_veldir_test=embed_test,
                 continuous_index_train=Y_train,
                 continuous_index_test=Y_test)

[  0.  45.  90. 135. 180. 225. 270. 315.]
(6125, 10, 70)
30
(2030, 10, 70)
10
(2065, 10, 70)
10




Model: "vae"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 10, 70)       0                                            
__________________________________________________________________________________________________
input_6 (InputLayer)            (None, 3)            0                                            
__________________________________________________________________________________________________
encoder (Model)                 [(None, 3), (None, 3 28856       input_4[0][0]                    
                                                                 input_6[0][0]                    
__________________________________________________________________________________________________
decoder (Model)                 (None, 10, 70)       1015168     encoder[1][2]                  