In [1]:
import sys
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jl
import cebra.datasets
from cebra import CEBRA
import scipy.io as sio
from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
import sklearn.metrics
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") ### select either 0 or 1
print(f'Using device: {device}')

Using device: cuda:1


In [2]:
dur = 35
iterations = 20000
output_dimension = 3
cebra_veldir_model = CEBRA(model_architecture='offset1-model',
                           batch_size=512,
                           learning_rate = 0.0001,
                           temperature = 1,
                           output_dimension = output_dimension,
                           max_iterations=iterations,
                           distance='cosine',
                           conditional='time_delta',
                           device='cuda_if_available',
                           verbose=True,
                           time_offsets=1)
def split_data(neural, continuous_index, train_trial):
            split_idx = train_trial*dur 
            neural_train = neural[:split_idx]
            neural_test = neural[split_idx:]
            continuous_index_train = continuous_index[:split_idx]
            continuous_index_test = continuous_index[split_idx:]
            return neural_train,neural_test,continuous_index_train,continuous_index_test
        
directory = "./data/SU_3S1/"
files = os.listdir(directory)
for file in files:
    if "Han" in file or "Lando" in file: ### 
        mat_contents = sio.loadmat(os.path.join(directory, file))
        filename_parts = file.split("_neural_con_dis_index")
        new_filename = filename_parts[0] + "_embed_"+str(iterations)+"itr_S1.npz"
        file_save = os.path.join(directory, new_filename)
        print(file_save)

        neural = mat_contents['neural_S1']
        continuous_index_2d = mat_contents['continuous_index']*10
        discrete_index = mat_contents['discrete_index'].astype(int)*45 ### must change data-type here
        
        continuous_index = np.hstack((continuous_index_2d, discrete_index))
        total_trial = int(discrete_index.shape[0]/dur)
        train_trial = int(total_trial*0.8)
        test_trial = total_trial-train_trial
        
        neural_train, neural_test, continuous_index_train, continuous_index_test = split_data(neural, continuous_index, train_trial)
        target_angle_train = continuous_index_train[:, 2].copy()
        target_angle_test = continuous_index_test[:, 2].copy()
        cebra_veldir_model.fit(neural_train, continuous_index_train)
        cebra_veldir_train = cebra_veldir_model.transform(neural_train)
        cebra_veldir_test  = cebra_veldir_model.transform(neural_test)

        fig = plt.figure(figsize=(6,5))
        ax = plt.subplot(111)
        train_loss = cebra_veldir_model.state_dict_['loss']
        train_loss_stable = train_loss[-10:].numpy()
        ax.plot(train_loss, c='deepskyblue')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.set_xlabel('Iterations')
        ax.set_ylabel('RnC Loss')
        plt.legend(bbox_to_anchor=(0.5,0.3), frameon = False )
        plt.title('iterations='+str(iterations)+'  final loss='+str(np.mean(train_loss_stable)))
        new_filename = filename_parts[0] + "_train_loss_"+str(iterations)+"itr.pdf" 
        output_path = os.path.join(directory, new_filename)
        plt.savefig(output_path)
        plt.close(fig)

        initial_positions = np.zeros((train_trial, 1, 2))
        velocity_reshaped = continuous_index_train[:, 0:2].reshape(train_trial, dur, 2)
        displacements = np.cumsum(velocity_reshaped, axis=1)
        locations = initial_positions + displacements
        pos_truth = locations.reshape(train_trial*dur, 2)
        truth_XY = pos_truth[:, 0:2]

        knr = GridSearchCV(KNeighborsRegressor(), {'n_neighbors': [2,4,8,16,32,64,128,256,512,1024]}, n_jobs=8)
        knr.fit(cebra_veldir_train, continuous_index_train[:, 0:2]) 
        pred_vel = knr.predict(cebra_veldir_train) 
        knc = GridSearchCV(KNeighborsClassifier(), {'n_neighbors': [2,4,8,16,32,64,128,256,512,1024]}, n_jobs=8)
        knc.fit(cebra_veldir_train, target_angle_train)
        pred_dir = knc.predict(cebra_veldir_train)

        velocity_reshaped = pred_vel.reshape(train_trial, dur, 2)
        displacements = np.cumsum(velocity_reshaped, axis=1)
        locations = initial_positions + displacements
        pred_XY = locations.reshape(train_trial*dur, 2)
        
        posi_r2 = sklearn.metrics.r2_score(truth_XY, pred_XY) ### proportion of total variation explained by model
        vel_r2 = sklearn.metrics.r2_score(continuous_index_train[:, 0:2], pred_vel)
        
        differences = abs(pred_dir - target_angle_train)
        angle_diffs = np.where(differences > 180, 360 - differences, differences)
        fig = plt.figure(figsize=(10, 5))
        ax1 = plt.subplot(121)
        ax1.scatter(truth_XY[:, 0], truth_XY[:, 1], alpha=1, color=plt.cm.hsv(1/360*target_angle_train), s=0.3)
        ax1.spines["right"].set_visible(False)
        ax1.spines["top"].set_visible(False)
        plt.title('GridSearchCV-R2='+str(round(knr.best_score_,3))+' MAE='+str(round(np.mean(angle_diffs),1)))
        
        ax2 = plt.subplot(122)
        ax2.scatter(pred_XY[:, 0], pred_XY[:, 1], alpha=1, color=plt.cm.hsv(1/360*pred_dir), s=0.3)
        ax2.spines["right"].set_visible(False)
        ax2.spines["top"].set_visible(False)
        plt.title('True vs Pred-R2 vel='+str(round(vel_r2, 3))+' pos='+str(round(posi_r2, 3)))
        new_filename = filename_parts[0] + "_Decoding_"+str(iterations)+"itr.pdf"
        output_path = os.path.join(directory, new_filename)
        plt.savefig(output_path)
        plt.close(fig)
        
        idx1, idx2, idx3 = 0, 1, 2
        fig = plt.figure(figsize=(10, 5), dpi=250)
        ax = plt.subplot(121, projection = '3d')
        x = ax.scatter(cebra_veldir_train[:, idx1],
                       cebra_veldir_train[:, idx2],
                       cebra_veldir_train[:, idx3],
                       c=target_angle_train/360, ## [115800]
                       cmap=plt.cm.hsv,
                       edgecolors='none',
                       alpha=0.75,
                       s=0.3)
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        ax = plt.subplot(122, projection = '3d')
        for i in range(8):
            direction_trial = (target_angle_train//45 == i)
            trial_avg = cebra_veldir_train[direction_trial, :].reshape(-1,dur,output_dimension).mean(axis=0)
            ax.scatter(trial_avg[:, idx1], 
                       trial_avg[:, idx2],
                       trial_avg[:, idx3],
                       color=plt.cm.hsv(1 / 8 * i),
                       edgecolors='none',
                       alpha=0.75,
                       s=3)
            ax.plot(trial_avg[:, idx1], 
                trial_avg[:, idx2],
                trial_avg[:, idx3],
                color=plt.cm.hsv(1 / 8 * i), 
                linewidth=0.5,
                alpha=0.75)  
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        new_filename = filename_parts[0] + "_Embedding_"+str(iterations)+"itr.pdf"
        output_path = os.path.join(directory, new_filename)
        plt.savefig(output_path)
        plt.close(fig)

        np.savez(file_save,
                 train_loss_stable=train_loss,
                 cebra_veldir_train=cebra_veldir_train,
                 cebra_veldir_test=cebra_veldir_test,
                 neural_train=neural_train,
                 neural_test=neural_test,
                 continuous_index_train=continuous_index_train,
                 continuous_index_test=continuous_index_test)

./data/Han_20171207_embed_20000itr_S1.npz


pos: -0.8666 neg:  6.4072 total:  5.5406 temperature:  1.0000: 100%|█| 20000/200
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.


./data/Han_20171201_embed_20000itr_S1.npz


pos: -0.8445 neg:  6.4041 total:  5.5596 temperature:  1.0000: 100%|█| 20000/200
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.


./data/Han_20171204_embed_20000itr_S1.npz


pos: -0.8695 neg:  6.4149 total:  5.5454 temperature:  1.0000: 100%|█| 20000/200
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
