In [3]:
import os, re
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression,LogisticRegression
import sklearn.metrics
from scipy.linalg import orthogonal_procrustes
from itertools import permutations, combinations
dur = 35
emb_dim = 3
N_angles = 8

directory = './data_NER/S1 3/NER_Han2017/NPZ/'
name_range = slice(-28, -22)
file_save = "./data_NER/Fig7/NER_Han2017.npz"

# directory = './data_NER/S1 3/Cebra_Han2017/NPZ/'
# name_range = slice(-28, -22)
# file_save = "./data_NER/Fig7/Cebra_Han2017.npz"

# directory = './data_NER/S1 3/piVAE_Han2017/NPZ/'
# name_range = slice(-19, -13)
# file_save = "./data_NER/Fig7/piVAE_Han2017.npz"

def get_best_R(R_all, emb_A, emb_A_8angle_align):
    determinants = [np.linalg.det(R_all[:, :, i]) for i in range(R_all.shape[2])]
    positive_dets = [det for det in determinants if det >= 0]
    negative_dets = [det for det in determinants if det < 0]

    if len(positive_dets)>0:
        target_dets = positive_dets
        differences = [abs(abs(det) - 1) for det in target_dets]
        min_index = np.argmin(differences)
        best_R_index_p = determinants.index(positive_dets[min_index])
        best_R_p = R_all[:, :, best_R_index_p]
        emb_A_whole_align_p = np.matmul(emb_A, best_R_p)
        align_diff_p = np.sum(abs(emb_A_whole_align_p-emb_A_8angle_align))
        ## print('diff positive detR=', align_diff_p)
    elif len(positive_dets) == 0:
        align_diff_p = 5000000 ### arbitory value
        
    if len(negative_dets)>0:
        target_dets = negative_dets
        differences = [abs(abs(det) - 1) for det in target_dets]
        min_index = np.argmin(differences)
        best_R_index_n = determinants.index(negative_dets[min_index])
        best_R_n = R_all[:, :, best_R_index_n]
        emb_A_whole_align_n = np.matmul(emb_A, best_R_n)
        align_diff_n = np.sum(abs(emb_A_whole_align_n-emb_A_8angle_align))
        ## print('diff negative detR=', align_diff_n)
    elif len(negative_dets) == 0:
        align_diff_n = 5000000
        
    if align_diff_p<align_diff_n:
        best_R = best_R_p
        ## print('Using positive R')
    elif align_diff_p>align_diff_n:
        best_R = best_R_n
        ## print('Using negative R')
    return best_R


def cross_decode(file_path1, file_path2):
    Monkey_A = np.load(file_path1)
    XYTarget_A = np.concatenate((Monkey_A['continuous_index_train'], Monkey_A['continuous_index_test']), axis=0)
    emb_A = np.concatenate((Monkey_A['cebra_veldir_train'], Monkey_A['cebra_veldir_test']), axis=0)
    if np.max(XYTarget_A[:, 2])>10: ### angles in 0-45-90-...315degrees
        XYTarget_A[:, 2] = XYTarget_A[:, 2]/45
        
    Monkey_B = np.load(file_path2)
    XYTarget_B = np.concatenate((Monkey_B['continuous_index_train'], Monkey_B['continuous_index_test']), axis=0)
    emb_B = np.concatenate((Monkey_B['cebra_veldir_train'], Monkey_B['cebra_veldir_test']), axis=0)
    if np.max(XYTarget_B[:, 2])>10:
        XYTarget_B[:, 2] = XYTarget_B[:, 2]/45
    
    train_trial_A = int(Monkey_A['continuous_index_train'].shape[0]/dur)
    test_trial_A = int(Monkey_A['continuous_index_test'].shape[0]/dur)
    train_trial_B = int(Monkey_B['continuous_index_train'].shape[0]/dur)
    test_trial_B = int(Monkey_B['continuous_index_test'].shape[0]/dur)
    
    R_all = np.zeros((emb_dim, emb_dim, N_angles))
    for a in range(N_angles):
        direction_trial = (XYTarget_A[:, 2] == a)
        trial_avg_A = emb_A[direction_trial, :].reshape(-1,dur,emb_dim).mean(axis=0)
        direction_trial = (XYTarget_B[:, 2] == a)
        trial_avg_B = emb_B[direction_trial, :].reshape(-1,dur,emb_dim).mean(axis=0)
        R, sca = orthogonal_procrustes(trial_avg_A, trial_avg_B) ### both are (dur, 3emb-dim)
        R_all[:,:, a] = R
        det_R = np.linalg.det(R)
    trial_arrays = []
    for i in range(N_angles):
        direction_trial = (XYTarget_A[:, 2] == i)
        trial_A = emb_A[direction_trial, :].reshape(-1,dur,emb_dim)
        trial_A = np.matmul(trial_A, R_all[:,:,i])
        trial_arrays.append((direction_trial, trial_A))
    emb_A_8angle_align = np.empty_like(emb_A)
    for mask, trial_data in trial_arrays: ### loop-through 8 times=angles
        flat_data = trial_data.reshape(-1, emb_dim) ### (n-trials*dur, 3emb-dim)
        emb_A_8angle_align[mask, :] = flat_data 
     
    emb_A_whole_align = np.matmul(emb_A, get_best_R(R_all, emb_A, emb_A_8angle_align))
    
    continuous_index_train = XYTarget_A[:train_trial_A*dur, :]
    cebra_veldir_train = emb_A_whole_align[:train_trial_A*dur, :] ####***** three choices here *****####
    continuous_index_test_B = XYTarget_B[-test_trial_B*dur:, :]
    cebra_veldir_test_B = emb_B[-test_trial_B*dur:, :]
    
    X = cebra_veldir_train
    y = continuous_index_train[:, 0:2]
    y_C = continuous_index_train[:, 2]
    reg = LinearRegression().fit(X, y) ### n_jobs = 8 >>> unnecessary
    pred_vel = reg.predict(X) 
    LogisticReg = LogisticRegression(max_iter=500, multi_class='multinomial', solver='lbfgs')
    LogisticReg.fit(X, y_C)
    ###******** this part will use previous trained "reg & LogisticReg" ###********
    ###******** this part will use previous trained "reg & LogisticReg" ###********
    X = cebra_veldir_test_B
    y = continuous_index_test_B[:, 0:2]
    y_C = continuous_index_test_B[:, 2]
    pred_vel_test = reg.predict(X) 
    velocity_reshaped = y.reshape(test_trial_B, dur, 2)
    locations = np.cumsum(velocity_reshaped, axis=1)
    truth_XY = locations.reshape(test_trial_B*dur, 2)
    velocity_reshaped = pred_vel_test.reshape(test_trial_B, dur, 2)
    locations = np.cumsum(velocity_reshaped, axis=1)
    pred_XY = locations.reshape(test_trial_B*dur, 2)

    posi_test_r2 = sklearn.metrics.r2_score(truth_XY, pred_XY)
    vel_test_r2 = sklearn.metrics.r2_score(pred_vel_test, y)

    pred_dir_test = LogisticReg.predict(X)
    pred_dir_acc = np.zeros((dur, test_trial_B))
    for i in range(test_trial_B): ## test_trial
        t_pred = pred_dir_test[dur*(i):dur*(i+1)]
        t_truth = y_C[dur*(i):dur*(i+1)]
        pred_dir_acc[np.where(t_pred == t_truth), i] = 1
    acc_time = 100*np.sum(pred_dir_acc, axis=1)/test_trial_B
    pred_max_acc = np.max(acc_time)
    return posi_test_r2,vel_test_r2,pred_max_acc,acc_time

def self_decode(file_path1):
    Monkey_A = np.load(file_path1)
    X = Monkey_A['cebra_veldir_train']
    y = Monkey_A['continuous_index_train'][:, 0:2]
    y_C = Monkey_A['continuous_index_train'][:, 2]

    train_trial_A = int(y.shape[0]/dur)
    velocity_reshaped = y.reshape(train_trial_A, dur, 2)
    locations = np.cumsum(velocity_reshaped, axis=1)
    truth_XY = locations.reshape(train_trial_A*dur, 2)
    
    reg = LinearRegression().fit(X, y) ### n_jobs = 8 >>> unnecessary
    pred_vel = reg.predict(X) 
    LogisticReg = LogisticRegression(max_iter=500, multi_class='multinomial', solver='lbfgs')
    LogisticReg.fit(X, y_C)
    ###******** this part will use previous trained "reg & LogisticReg" ###********
    ###******** this part will use previous trained "reg & LogisticReg" ###********
    X = Monkey_A['cebra_veldir_test']
    y = Monkey_A['continuous_index_test'][:,0:2]
    y_C = Monkey_A['continuous_index_test'][:, 2]
    
    pred_vel_test = reg.predict(X)
    test_trial_A = int(y.shape[0]/dur)
    velocity_reshaped = y.reshape(test_trial_A, dur, 2)
    locations = np.cumsum(velocity_reshaped, axis=1)
    truth_XY = locations.reshape(test_trial_A*dur, 2)
    velocity_reshaped = pred_vel_test.reshape(test_trial_A, dur, 2)
    locations = np.cumsum(velocity_reshaped, axis=1)
    pred_XY = locations.reshape(test_trial_A*dur, 2)

    posi_test_r2 = sklearn.metrics.r2_score(truth_XY, pred_XY)
    vel_test_r2 = sklearn.metrics.r2_score(pred_vel_test, y) ## default is "uniform_average"
    print('vel_test_r2=', vel_test_r2) ## two values for XY if using 'raw_values';otherwise, average X&Y

    pred_dir_test = LogisticReg.predict(X)     
    pred_dir_matches = pred_dir_test == y_C
    pred_dir_test_acc = 100*np.sum(pred_dir_matches)/test_trial_A/dur
    differences = 45*abs(pred_dir_test - y_C)
    angle_diffs = np.where(differences > 180, 360 - differences, differences)
    dir_test_r2 = sklearn.metrics.r2_score(pred_dir_test, y_C)
    
    pred_dir_acc = np.zeros((dur, test_trial_A))
    for i in range(test_trial_A):
        t_pred = pred_dir_test[dur*(i):dur*(i+1)]
        t_truth = y_C[dur*(i):dur*(i+1)]
        pred_dir_acc[np.where(t_pred == t_truth), i] = 1
    acc_time = 100*np.sum(pred_dir_acc, axis=1)/test_trial_A
    pred_max_acc = np.max(acc_time)

    return posi_test_r2, vel_test_r2, pred_max_acc, acc_time, pred_vel_test, y_C
            
### List all files in the directory
files = [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
n = len(files)
pos_R_2D = np.zeros((n, n))
vel_R_2D = np.zeros((n, n))
peak_acc_2D = np.zeros((n, n))
acc_time_2D = np.zeros((dur, n*n))
date_subjects = []
n_compare = 0

def list_and_sort_files(directory):
    # List all files
    files = [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
    def extract_date(filename):
        # This regex matches 6 consecutive digits that likely represent a date in YYMMDD format
        match = re.search(r'(\d{6})', os.path.basename(filename))
        date = match.group(0) if match else '000000'  # Default to '000000' if no date is found
        # Standardize to YYMMDD if necessary
        year_prefix = '20'  # assuming all dates are after the year 2000
        return int(year_prefix + date) if len(date) == 6 else int(date)
    sorted_files = sorted(files, key=extract_date)
    return sorted_files
sorted_files=list_and_sort_files(directory)

for i, file1 in enumerate(sorted_files):
    # print("Reading file:", file1)
    for j, file2 in enumerate(sorted_files):
        if i != j:    ### with-others
            posi_test_r2,vel_test_r2,pred_max_acc,acc_time = cross_decode(file1, file2)
            # print('#'+str(n_compare+1)+' cross compare')
        elif i == j:  ### with-itself
            posi_test_r2,vel_test_r2,pred_max_acc,acc_time, vel_pred, vel_real = self_decode(file1)
            print('#'+str(n_compare+1)+' self compare')
        pos_R_2D[i, j] = posi_test_r2
        vel_R_2D[i, j] = vel_test_r2
        peak_acc_2D[i, j] = pred_max_acc
        acc_time_2D[:, n_compare] = acc_time
        if "M1PMd" in directory:
            date = file1[-29:-23]
            suffix = file1[-7:-5]
            date_subjects.append(f"{date}{suffix}")  
        elif "M1PMd" not in directory:
            # print(file1[name_range])
            date_subjects.append(file1[name_range])
        n_compare = n_compare+1
# print('label of date:', np.unique(date_subjects))

np.savez(file_save, date_subjects = date_subjects,
         pos_R_2D=pos_R_2D, vel_R_2D=vel_R_2D, peak_acc_2D=peak_acc_2D, acc_time_2D=acc_time_2D)

vel_test_r2= 0.702426333732958
#1 self compare
vel_test_r2= 0.6414935940186965
#5 self compare
vel_test_r2= 0.6705147422030597
#9 self compare
