In [13]:
import os
import h5py
import re
from scipy.io import loadmat
import numpy as np
from preproc import *
from sklearn.preprocessing import MinMaxScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.decomposition import PCA
from dPCA import dPCA
from sklearn.manifold import Isomap
from sklearn.metrics import calinski_harabasz_score, davies_bouldin_score
from scipy.spatial.distance import directed_hausdorff
from itertools import combinations
from scipy.special import factorial
from scipy.stats import f_oneway
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from scipy.stats import ttest_ind
from scipy.stats import ttest_1samp
from scipy.stats import mannwhitneyu
from statsmodels.multivariate.manova import MANOVA
import pandas as pd
import statsmodels.api as sm
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
from sklearn.cluster import KMeans
from scipy.signal import savgol_filter
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import StrMethodFormatter
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from matplotlib.colors import ListedColormap
from matplotlib.widgets import Slider
from IPython.display import HTML
import textwrap
import pickle
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler


In [14]:
%matplotlib inline
plt.rcParams['font.size'] = 12
day_list = ['20181105', '20181102', '20181101']


In [15]:
# Directory to save figures
figsave_dir = 'figures'
to_save = True
save_name = 'July-2024'

# Set name/annotation to use in plot titles
popl_name = '364 cells'

# Some common constants
num_sess = len(day_list)
num_goals = 6
tbin_size = 0.1

In [16]:
# Common functions
def group_by_trajectory(timeseries: np.array, trajectories: np.array) -> list:
    # num_goals = 6
    grouped = {(i+1, j+1): list() for j in range(num_goals) for i in range(num_goals) if i != j}
    for idx, traj in enumerate(trajectories):
        goal1, goal2 = int(traj[0]), int(traj[1])
        if goal1 == 0:
            continue
        grouped[(goal1, goal2)].append(timeseries[idx])
    return grouped

In [17]:
good_cell_labels = list()
# pattern = re.compile(r'/(\d{8})/session\d+/array0*\d+/channel0*(\d+)/cell0*(\d+)')
with open('data/cell_list_hm.txt', 'r') as file:
    for line in file:
        line = line.strip().split('/')
        good_cell_labels.append(f'{line[5]}ch{str(int(line[8][7:]))}c{str(int(line[9][4:]))}')

In [18]:
place_cell_labels = list()
with open('sigcells/place_cells.txt', 'r') as file:
    for line in file:
        line = line.strip().split('/')
        place_cell_labels.append(f'{line[5]}ch{str(int(line[8][7:]))}c{str(int(line[9][4:]))}')
# View cells
view_cell_labels = list()
with open('sigcells/view_cells.txt', 'r') as file:
    for line in file:
        line = line.strip().split('/')
        view_cell_labels.append(f'{line[5]}ch{str(int(line[8][7:]))}c{str(int(line[9][4:]))}')
# Merge and remove duplicates from both lists
pv_cell_labels = list(set(place_cell_labels).union(set(view_cell_labels)))

In [19]:
all_cell_labels = list()
all_net_information_gain = list()
all_net_information_gain_scaled= list()

for day in day_list:
    with open(f'data/infogain/{day}_data.pkl', 'rb') as file:
        data = pickle.load(file)
        all_cell_labels.extend(data['cell_labels'])
        all_net_information_gain.append(data['net_information_gain'])
        all_net_information_gain_scaled.append(data['net_information_gain_scaled'])

all_net_information_gain = np.vstack(all_net_information_gain)
all_net_information_gain_scaled = np.vstack(all_net_information_gain_scaled)
num_all_cells = len(all_cell_labels)
for i, label in enumerate(all_cell_labels):
    all_cell_labels[i] = all_cell_labels[i][4:]

In [20]:
class ResponseClassifier():
    def __init__(self, model):
        self.model = model

    def fit(self, X, y, thresh):
        self.model.fit(X, y)
        self.thresh = thresh

    def set_thresh(self, thresh):
        self.thresh = thresh

    def predict(self, X):
        classes = self.model.predict(X)
        probs = self.model.predict_proba(X)
        filt = np.where(np.max(probs, axis=1) < self.thresh)
        classes[filt] = 0
        return classes


def train_test_split(X: np.array, y: np.array, split: tuple) -> tuple:
    # itr: current iteration of k-fold validation
    # way: value of k in k-fold validation
    itr, way = split
    X_train, X_test = list(), list()
    y_train, y_test = list(), list()
    for num, obs in enumerate(X):
        if num % way == itr:
            X_test.append(obs)
            y_test.append(y[num])
        else:
            X_train.append(obs)
            y_train.append(y[num])
    # Return format: X_train, X_test, y_train, y_test   
    return np.array(X_train), np.array(X_test), np.array(y_train), np.array(y_test)

def confusion_matrix(y_pred: np.array, y_actl: np.array, num_classes: int) -> np.array:
    res = np.zeros((num_classes, num_classes), dtype=int)
    for pred, actl in zip(y_pred, y_actl):
        pred, actl = int(pred - 1), int(actl - 1)
        res[pred, actl] += 1
    return res

def prediction_accuracy(y_pred: np.array, y_actl: np.array) -> float:
    count, total = 0, y_actl.shape[0]
    for num, obs in enumerate(y_pred):
        if obs == y_actl[num]:
            count += 1
    return count / total

In [21]:
all_spikerates_cue, all_spikerates_hints, all_spikerates_navend = list(), list(), list()
all_goals_cue, all_goals_hints, all_goals_navend = list(), list(), list()
all_trajectories_cue, all_trajectories_hints, all_trajectories_navend = list(), list(), list()

'''
for day in day_list:
    with open(f'data/{day}_data.pkl', 'rb') as file:
        data = pickle.load(file)
        # Drop the first trial from each day's dataset
        all_spikerates_cue.append(data['raw_data']['spikerates_cue'][1:])
        all_spikerates_hints.append(data['raw_data']['spikerates_hints'][1:])
        all_spikerates_navend.append(data['raw_data']['spikerates_navend'][1:])

        all_goals_cue.append(data['raw_data']['goals_cue'][1:])
        all_goals_hints.append(data['raw_data']['goals_hints'][1:])
        all_goals_navend.append(data['raw_data']['goals_navend'][1:])

        all_trajectories_cue.append(data['raw_data']['trajectories_cue'][1:])
        all_trajectories_hints.append(data['raw_data']['trajectories_hints'][1:])
        all_trajectories_navend.append(data['raw_data']['trajectories_navend'][1:])
'''
        
for day in day_list:
    with open(f'data/combined/{day}_data.pkl', 'rb') as file:
        data = pickle.load(file)
        # Drop the first trial from each day's dataset
        all_spikerates_cue.append(data['cue_mean']['spikerates_cue'][1:])
        all_spikerates_hints.append(data['hint_mean']['spikerates_hints'][1:])
        all_spikerates_navend.append(data['navend_mean']['spikerates_navend'][1:])

        all_goals_cue.append(data['cue_mean']['goals_cue'][1:])
        all_goals_hints.append(data['hint_mean']['goals_hints'][1:])
        all_goals_navend.append(data['navend_mean']['goals_navend'][1:])

        all_trajectories_cue.append(data['cue_mean']['trajectories_cue'][1:])
        all_trajectories_hints.append(data['hint_mean']['trajectories_hints'][1:])
        all_trajectories_navend.append(data['navend_mean']['trajectories_navend'][1:])

In [22]:
# Append cells from each session to form a pseudopopulation, append in chronological order on responses grouped by (start, end) goals
# Cue phase spike rates
all_spikerates_cue_per_traj = list()
for i in range(num_sess):
    all_spikerates_cue_per_traj.append(group_by_trajectory(all_spikerates_cue[i], all_trajectories_cue[i]))
trajectories = list(all_spikerates_cue_per_traj[0].keys())
combined_spikerates_cue_per_traj = {key: list() for key in trajectories}
for traj in trajectories:
    trials_data = [sess[traj] for sess in all_spikerates_cue_per_traj]
    num_trials = min(map(len, trials_data))
    for trial in range(num_trials):
        combined_spikerates_cue_per_traj[traj].append(np.hstack([sess[trial] for sess in trials_data]))
    combined_spikerates_cue_per_traj[traj] = np.vstack(combined_spikerates_cue_per_traj[traj])

# Hint view spike rates
all_spikerates_hints_per_traj = list()
for i in range(num_sess):
    all_spikerates_hints_per_traj.append(group_by_trajectory(all_spikerates_hints[i], all_trajectories_hints[i]))
combined_spikerates_hints_per_traj = {key: list() for key in trajectories}
for traj in trajectories:
    trials_data = [sess[traj] for sess in all_spikerates_hints_per_traj]
    num_trials = min(map(len, trials_data))
    for trial in range(num_trials):
        combined_spikerates_hints_per_traj[traj].append(np.hstack([sess[trial] for sess in trials_data]))
    combined_spikerates_hints_per_traj[traj] = np.vstack(combined_spikerates_hints_per_traj[traj])

# Nav end phase spike rates
all_spikerates_navend_per_traj = list()
for i in range(num_sess):
    all_spikerates_navend_per_traj.append(group_by_trajectory(all_spikerates_navend[i], all_trajectories_navend[i]))
combined_spikerates_navend_per_traj = {key: list() for key in trajectories}
for traj in trajectories:
    trials_data = [sess[traj] for sess in all_spikerates_navend_per_traj]
    num_trials = min(map(len, trials_data))
    for trial in range(num_trials):
        combined_spikerates_navend_per_traj[traj].append(np.hstack([sess[trial] for sess in trials_data]))
    combined_spikerates_navend_per_traj[traj] = np.vstack(combined_spikerates_navend_per_traj[traj])

# Filter out cells not in good_cell_labels
cell_filter = np.array([idx for idx, cell in enumerate(all_cell_labels) if cell in set(good_cell_labels)])

for traj, data in combined_spikerates_cue_per_traj.items():
    for n, trial in enumerate(data):
        data[n] = trial[:,cell_filter]

for traj, data in combined_spikerates_hints_per_traj.items():
    for n, trial in enumerate(data):
        data[n] = trial[:,cell_filter]

for traj, data in combined_spikerates_navend_per_traj.items():
    for n, trial in enumerate(data):
        data[n] = trial[:,cell_filter]

# Collect spike rate data into single arrays for each phase
combined_spikerates_cue, combined_goals_cue = list(), list()
for traj, obs in combined_spikerates_cue_per_traj.items():
    combined_spikerates_cue.append(obs)
    combined_goals_cue.extend(obs.shape[0]*[traj[1]])
combined_spikerates_cue = np.vstack(combined_spikerates_cue)
combined_goals_cue = np.array(combined_goals_cue)

combined_spikerates_hints, combined_goals_hints = list(), list()
for traj, obs in combined_spikerates_hints_per_traj.items():
    combined_spikerates_hints.append(obs)
    combined_goals_hints.extend(obs.shape[0]*[traj[1]])
combined_spikerates_hints = np.vstack(combined_spikerates_hints)
combined_goals_hints = np.array(combined_goals_hints)

combined_spikerates_navend, combined_goals_navend = list(), list()
for traj, obs in combined_spikerates_navend_per_traj.items():
    combined_spikerates_navend.append(obs)
    combined_goals_navend.extend(obs.shape[0]*[traj[1]])
combined_spikerates_navend = np.vstack(combined_spikerates_navend)
combined_goals_navend = np.array(combined_goals_navend)

'''
# Shuffle the order of the observations
np.random.seed(0)
shuffle_order = np.random.permutation(combined_goals_cue.shape[0])
combined_spikerates_cue = combined_spikerates_cue[shuffle_order]
combined_goals_cue = combined_goals_cue[shuffle_order]

np.random.seed(0)
shuffle_order = np.random.permutation(combined_goals_hints.shape[0])
combined_spikerates_hints = combined_spikerates_hints[shuffle_order]
combined_goals_hints = combined_goals_hints[shuffle_order]

np.random.seed(0)
shuffle_order = np.random.permutation(combined_goals_navend.shape[0])
combined_spikerates_navend = combined_spikerates_navend[shuffle_order]
combined_goals_navend = combined_goals_navend[shuffle_order]
'''

IndexError: arrays used as indices must be of integer (or boolean) type

In [None]:
# Clean up large memory variables
del all_spikerates_cue
del all_spikerates_hints
del all_spikerates_navend
del all_goals_cue
del all_goals_hints
del all_goals_navend
del all_trajectories_cue
del all_trajectories_hints
del all_trajectories_navend

In [23]:
# Build decoder for 6+1 classes
k_fold = 10
ppop_probmins_goal = np.zeros(k_fold)
ppop_accuracy_goal_cue = np.zeros(k_fold)
ppop_confusion_goal_cue = np.zeros((num_goals, num_goals))
ppop_models_goal = list()
for i in range(k_fold):
    X_train, X_test, y_train, y_test = train_test_split(combined_spikerates_cue, combined_goals_cue, (i, k_fold))
    lda = LinearDiscriminantAnalysis()
    lda.fit(X_train, y_train)
    y_pred = lda.predict(X_test)
    y_pred_proba = lda.predict_proba(X_train)
    y_pred_proba = np.max(y_pred_proba, axis=1)
    ppop_probmins_goal[i] = np.min(y_pred_proba)
    ppop_accuracy_goal_cue[i] = prediction_accuracy(y_pred, y_test)
    ppop_confusion_goal_cue += confusion_matrix(y_pred, y_test, num_goals)
    ppop_models_goal.append(lda)

# Fit models using each fold of data
ppop_confusion_goal_hints, ppop_confusion_goal_navend = np.zeros((num_goals+1, num_goals+1)), np.zeros((num_goals+1, num_goals+1))
ppop_accuracy_goal_hints, ppop_accuracy_goal_navend = np.zeros(k_fold), np.zeros(k_fold)

for i, model in enumerate(ppop_models_goal):
    ppop_models_goal[i] = ResponseClassifier(model)
    ppop_models_goal[i].set_thresh(ppop_probmins_goal[i])

    # Prediction on hint views
    hints_preds = model.predict(combined_spikerates_hints)
    ppop_confusion_goal_hints += confusion_matrix(hints_preds, combined_goals_hints, num_goals+1)
    ppop_accuracy_goal_hints[i] = prediction_accuracy(hints_preds, combined_goals_hints)

    # Prediction on navend phases
    navend_preds = model.predict(combined_spikerates_navend)
    ppop_confusion_goal_navend += confusion_matrix(navend_preds, combined_goals_navend, num_goals+1)
    ppop_accuracy_goal_navend[i] = prediction_accuracy(navend_preds, combined_goals_navend)

NameError: name 'combined_spikerates_cue' is not defined

In [24]:
class DiscreteFrechet(object):
    """
    Calculates the discrete Fréchet distance between two poly-lines using the
    original recursive algorithm
    """

    def __init__(self, dist_func):
        """
        Initializes the instance with a pairwise distance function.
        :param dist_func: The distance function. It must accept two NumPy
        arrays containing the point coordinates (x, y), (lat, long)
        """
        self.dist_func = dist_func
        self.ca = np.array([0.0])

    def distance(self, p: np.ndarray, q: np.ndarray) -> float:
        """
        Calculates the Fréchet distance between poly-lines p and q
        This function implements the algorithm described by Eiter & Mannila
        :param p: Poly-line p
        :param q: Poly-line q
        :return: Distance value
        """

        def calculate(i: int, j: int) -> float:
            """
            Calculates the distance between p[i] and q[i]
            :param i: Index into poly-line p
            :param j: Index into poly-line q
            :return: Distance value
            """
            if self.ca[i, j] > -1.0:
                return self.ca[i, j]

            d = self.dist_func(p[i], q[j])
            if i == 0 and j == 0:
                self.ca[i, j] = d
            elif i > 0 and j == 0:
                self.ca[i, j] = max(calculate(i-1, 0), d)
            elif i == 0 and j > 0:
                self.ca[i, j] = max(calculate(0, j-1), d)
            elif i > 0 and j > 0:
                self.ca[i, j] = max(min(calculate(i-1, j),
                                        calculate(i-1, j-1),
                                        calculate(i, j-1)), d)
            else:
                self.ca[i, j] = np.infty
            return self.ca[i, j]

        n_p = p.shape[0]
        n_q = q.shape[0]
        self.ca = np.zeros((n_p, n_q))
        self.ca.fill(-1.0)
        return calculate(n_p - 1, n_q - 1)
    
def mean_euclidean_distance(arr1, arr2):
    distances = np.zeros_like(arr1)
    for i in range(arr1.shape[0]):
        distances[i] = euclidean_distance(arr1[i,:], arr2[i,:])
    return np.mean(distances)

def euclidean_distance(point1, point2):
    return np.linalg.norm(point1 - point2)


def bootstrapping(sess_data, num_trials):
    # Bootstraps the session data in sess_data to meet the minimum number of required trials given by num_trials
    # sess_data is given as a list, so output a list as well
    # Set random seed for repeatable randomization
    np.random.seed(0)
    idxs = np.random.choice(np.arange(len(sess_data)), num_trials)
    res = list()
    for i in idxs:
        res.append(sess_data[i])
    return res

In [25]:
all_spikerates_cue = list()
all_spikerates_iti_cue = list()
# all_goals_cue = list()
all_trajectories_cue = list()
all_cell_labels = list()

'''
for day in day_list:
    with open(f'data/pcaview/{day}_data.pkl', 'rb') as file:
        data = pickle.load(file)
        # Drop the first trial from each day's dataset
        all_spikerates_cue.append(data['spikerates_cue'][1:])
        all_goals_cue.append(data['goals_cue'][1:])
        all_trajectories_cue.append(data['trajectories_cue'][1:])
'''

for day in day_list:
    with open(f'data/combined/{day}_data.pkl', 'rb') as file:
        data = pickle.load(file)
        # Drop the first trial from each day's dataset
        cue_data = data['cue_100ms']['spikerates_cue'][1:]
        all_spikerates_cue.append(cue_data)
        # Drop the last iti, tag itis to the next trial's cue phase
        iti_data = data['iti_100ms']['spikerates_iti'][:-1]
        # Drop the last navend, tag navends to the next trial's cue phase
        navend_data = data['navend_100ms']['spikerates_navend'][:-1]
        # Drop the first navst
        navst_data = data['navst_100ms']['spikerates_navst'][1:]
        for i in range(len(iti_data)):
            iti_data[i] = np.concatenate([navend_data[i], iti_data[i], cue_data[i], navst_data[i]], axis=0)
        all_spikerates_iti_cue.append(iti_data)
        # all_goals_cue.append(data['cue_100ms']['goals_cue'][1:])
        all_trajectories_cue.append(data['cue_100ms']['trajectories_cue'][1:])
        all_cell_labels.extend(data['cell_labels'])

# num_all_cells = len(all_cell_labels)
num_all_cells = len(good_cell_labels)

In [26]:
# Group trials within each session according to trajectory
all_spikerates_cue_per_traj = list()
for i in range(num_sess):
    all_spikerates_cue_per_traj.append(group_by_trajectory(all_spikerates_cue[i], all_trajectories_cue[i]))

all_spikerates_iti_cue_per_traj = list()
for i in range(num_sess):
    all_spikerates_iti_cue_per_traj.append(group_by_trajectory(all_spikerates_iti_cue[i], all_trajectories_cue[i]))

trajectories = list(all_spikerates_cue_per_traj[0].keys())

# Map out the number of trials per trajectory for each session
num_trials_per_trajectory = {traj: [] for traj in trajectories}
for sess in all_spikerates_cue_per_traj:
    for traj in trajectories:
        num_trials_per_trajectory[traj].append(len(sess[traj]))

num_trials_per_trajectory_per_sess = np.zeros((num_sess, len(trajectories)))
for i, traj in enumerate(trajectories):
    num_trials_per_trajectory_per_sess[:,i] = num_trials_per_trajectory[traj]

# Get target and lower bound (to drop sessions) of number of trials per trajectory
target_num_trials = np.percentile(num_trials_per_trajectory_per_sess, 25, axis=0).astype(int)
min_num_trials = (0.67 * target_num_trials).astype(int)

In [27]:
sum(map(min, num_trials_per_trajectory.values()))

311

In [28]:
### Assemble pseudopopulation responses, grouped by trials of the same trajectory

## Drop all trials beyond the least number of trials out of any session per trajectory
# # Append cells from each session to form a pseudopopulation, append in chronological order on responses grouped by (start, end) goals
# trajectories = list(all_spikerates_cue_per_traj[0].keys())
# combined_spikerates_cue_per_traj = {key: list() for key in trajectories}
# for traj in trajectories:
#     trials_data = [sess[traj] for sess in all_spikerates_cue_per_traj]
#     num_trials = min(map(len, trials_data))
#     for trial in range(num_trials):
#         for i, sess in enumerate(trials_data):
#             if sess[trial].shape[0] > 10:
#                 # Trim the length of the observation to 10x time bins (0 - 1000 ms)
#                 trials_data[i][trial] = sess[trial][:10,:]
#         combined_spikerates_cue_per_traj[traj].append(np.hstack([sess[trial] for sess in trials_data]))


## Bootstrap sessions to some minimum number of trials per trajectory

# Mark out sessions with at least 1 trajectory that falls below required min number of trials
sessions_to_keep = set()
for s in range(num_sess):
    if np.all(num_trials_per_trajectory_per_sess[s,:] >= min_num_trials):
        sessions_to_keep.add(s)

# Update all_cell_labels after dropping sessions
days_to_keep = set([day for d, day in enumerate(day_list) if d in sessions_to_keep])
new_cell_labels = list()
for cell in all_cell_labels:
    if cell[:8] in days_to_keep:
        new_cell_labels.append(cell)
all_cell_labels = new_cell_labels

# Append cells from each session to form a pseudopopulation, and append in chronological order on responses grouped by trajectories
# But rather than dropping trials to the minimum number of observations for the given trajectory type, try to bootstrap up to the same average number of trials per trajectory
combined_spikerates_cue_per_traj = {key: list() for key in trajectories}
num_timebins = 10
for t, traj in enumerate(trajectories):
    # Set number of average trials per trajectory to meet
    trials_data = [sess[traj] for s, sess in enumerate(all_spikerates_cue_per_traj) if s in sessions_to_keep]
    for i, sess in enumerate(trials_data):
        # num_trials_in_sess = len(sess)
        if len(sess) < target_num_trials[t]:
            # Bootstrap trials of this session if the number of obs is less than the required amount
            trials_data[i] = bootstrapping(sess, target_num_trials[t])
    for trial in range(target_num_trials[t]):
        for i, sess in enumerate(trials_data):
            if sess[trial].shape[0] > num_timebins:
                # Trim the length of the observation to 10x time bins (0 - 1000 ms)
                trials_data[i][trial] = sess[trial][:num_timebins,:]
        combined_spikerates_cue_per_traj[traj].append(np.hstack([sess[trial] for sess in trials_data]))

combined_spikerates_iti_cue_per_traj = {key: list() for key in trajectories}
num_timebins_ext = 50
for t, traj in enumerate(trajectories):
    # Set number of average trials per trajectory to meet
    trials_data = [sess[traj] for s, sess in enumerate(all_spikerates_iti_cue_per_traj) if s in sessions_to_keep]
    for i, sess in enumerate(trials_data):
        # num_trials_in_sess = len(sess)
        if len(sess) < target_num_trials[t]:
            # Bootstrap trials of this session if the number of obs is less than the required amount
            trials_data[i] = bootstrapping(sess, target_num_trials[t])
    for trial in range(target_num_trials[t]):
        for i, sess in enumerate(trials_data):
            if sess[trial].shape[0] > num_timebins_ext:
                # Trim the length of the observation to __ time bins
                trials_data[i][trial] = sess[trial][:num_timebins_ext,:]
        combined_spikerates_iti_cue_per_traj[traj].append(np.hstack([sess[trial] for sess in trials_data]))

In [29]:
sum(map(len, combined_spikerates_cue_per_traj.values()))

338

In [30]:
# Filter out cells not in good_cell_labels
cell_filter = np.array([idx for idx, cell in enumerate(all_cell_labels) if cell in set(good_cell_labels)])
for traj, data in combined_spikerates_cue_per_traj.items():
    for n, trial in enumerate(data):
        data[n] = trial[:,cell_filter]
for traj, data in combined_spikerates_iti_cue_per_traj.items():
    for n, trial in enumerate(data):
        data[n] = trial[:,cell_filter]

# Update num_all_cells and all_cell_labels to reflect number of cells in pseudopopulation
num_all_cells = cell_filter.shape[0]
all_cell_labels = [all_cell_labels[i] for i in cell_filter]

In [32]:
# Combine trials into a pseudosession and fit PCA to all trials
combined_pcspikerates_cue = list()
combined_trajectories_cue = list()
for traj in trajectories:
    combined_pcspikerates_cue.extend(combined_spikerates_cue_per_traj[traj])
    combined_trajectories_cue.extend(len(combined_spikerates_cue_per_traj[traj]) * [traj])

n_dims = 3
pca = PCA(n_components=n_dims)
combined_pcspikerates_cue = pca.fit_transform(np.vstack(combined_pcspikerates_cue))

# Regroup into trajectories for plotting
num_timebins = 10
combined_pcspikerates_cue_per_traj = [[np.empty((0, num_timebins, n_dims)) for i in range(num_goals)] for j in range(num_goals)]
for k, (i, j) in enumerate(combined_trajectories_cue):
    i -= 1
    j -= 1
    combined_pcspikerates_cue_per_traj[i][j] = np.concatenate([combined_pcspikerates_cue_per_traj[i][j], combined_pcspikerates_cue[num_timebins*k:num_timebins*(k+1),:].reshape(-1, num_timebins, n_dims)])

'''
# Group cue phase responses according to trajectory, then group by same start goal and same end goal
combined_spikerates_cue_start_goals, combined_spikerates_cue_end_goals = [list() for _ in range(num_goals)], [list() for _ in range(num_goals)]
combined_start_goals_labels, combined_end_goals_labels = [list() for _ in range(num_goals)], [list() for _ in range(num_goals)]
for traj, responses in combined_spikerates_cue_per_traj.items():
    start_goal, end_goal = traj[0] - 1, traj[1] - 1
    combined_spikerates_cue_start_goals[start_goal].extend(responses)
    combined_start_goals_labels[start_goal].extend(len(responses) * [end_goal + 1])
    combined_spikerates_cue_end_goals[end_goal].extend(responses)
    combined_end_goals_labels[end_goal].extend(len(responses) * [start_goal + 1])

# Fit PCA within each group of same starting/same ending goal
pca = PCA(n_components=3)
combined_pcspikerates_cue_start_goals, combined_pcspikerates_cue_end_goals = list(), list()
for i in range(num_goals):
    pc_start_goals = pca.fit_transform(np.vstack(combined_spikerates_cue_start_goals[i]))
    pc_end_goals = pca.fit_transform(np.vstack(combined_spikerates_cue_end_goals[i]))
    combined_pcspikerates_cue_start_goals.append(np.array([pc_start_goals[10*j:10*(j+1),:] for j in range(pc_start_goals.shape[0]//10)]))
    combined_pcspikerates_cue_end_goals.append(np.array([pc_end_goals[10*j:10*(j+1),:] for j in range(pc_end_goals.shape[0]//10)]))

# Regroup into trajectories for plotting
combined_pcspikerates_cue_per_traj = [[np.empty((0, 10, 3)) for i in range(num_goals)] for j in range(num_goals)]
for i in range(num_goals):
    for j, data in enumerate(combined_pcspikerates_cue_start_goals[i]):
        k = combined_start_goals_labels[i][j] - 1
        combined_pcspikerates_cue_per_traj[i][k] = np.concatenate([combined_pcspikerates_cue_per_traj[i][k], data.reshape(-1, 10, 3)], axis=0)
'''

# Plot the average response per trajectory
combined_pcspikerates_cue_per_traj_avg = [[np.empty((0, num_timebins, n_dims)) for i in range(num_goals)] for j in range(num_goals)]
for i in range(num_goals):
    for j, data in enumerate(combined_pcspikerates_cue_per_traj[i]):
        if i == j:
            continue
        combined_pcspikerates_cue_per_traj_avg[i][j] = np.concatenate([combined_pcspikerates_cue_per_traj_avg[i][j], np.mean(data, axis=0).reshape(1, num_timebins, n_dims)], axis=0)

'''
# Subtract the first frame from all subsequent frames to get the displacement from start
combined_pcspikerates_cue_per_traj_ref = [[combined_pcspikerates_cue_per_traj_avg[i][j][:,0,:] for j in range(num_goals)] for i in range(num_goals)]
combined_pcspikerates_cue_per_traj_start = [[combined_pcspikerates_cue_per_traj_avg[i][j].copy() for j in range(num_goals)] for i in range(num_goals)]
for i in range(num_goals):
    for j in range(num_goals):
        if i == j:
            continue
        combined_pcspikerates_cue_per_traj_start[i][j] -= combined_pcspikerates_cue_per_traj_ref[i][j]

# Subtract the last frame from all previous frames to get the displacement from end
combined_pcspikerates_cue_per_traj_ref = [[combined_pcspikerates_cue_per_traj_avg[i][j][:,-1,:] for j in range(num_goals)] for i in range(num_goals)]
combined_pcspikerates_cue_per_traj_end = [[combined_pcspikerates_cue_per_traj_avg[i][j].copy() for j in range(num_goals)] for i in range(num_goals)]
for i in range(num_goals):
    for j in range(num_goals):
        if i == j:
            continue
        combined_pcspikerates_cue_per_traj_end[i][j] -= combined_pcspikerates_cue_per_traj_ref[i][j]
'''

# Calculate Frechet distances between each group, Euclidean distances between start points across groups,
# between end points across groups, and between start and end within group
frechet = DiscreteFrechet(euclidean_distance)
combined_pcspikerates_cue_per_traj_frechdist = [[dict(), dict()] for _ in range(num_goals)]
combined_pcspikerates_cue_per_traj_meaneucdist = [[dict(), dict()] for _ in range(num_goals)]
combined_pcspikerates_cue_per_traj_startdist = [[dict(), dict()] for _ in range(num_goals)]
combined_pcspikerates_cue_per_traj_enddist = [[dict(), dict()] for _ in range(num_goals)]
combined_pcspikerates_cue_per_traj_startenddist = [[dict(), dict()] for _ in range(num_goals)]
for i in range(num_goals):
    traj_pairs = list(range(num_goals))
    traj_pairs.remove(i)
    for j in traj_pairs:
        combined_pcspikerates_cue_per_traj_startenddist[i][0][j] = euclidean_distance(combined_pcspikerates_cue_per_traj_avg[i][j][0,0,:], combined_pcspikerates_cue_per_traj_avg[i][j][0,-1,:])
        combined_pcspikerates_cue_per_traj_startenddist[i][1][j] = euclidean_distance(combined_pcspikerates_cue_per_traj_avg[j][i][0,0,:], combined_pcspikerates_cue_per_traj_avg[j][i][0,-1,:])

    traj_pairs = list(combinations(traj_pairs, 2))
    for (j1, j2) in traj_pairs:
        combined_pcspikerates_cue_per_traj_frechdist[i][0][(j1, j2)] = frechet.distance(combined_pcspikerates_cue_per_traj_avg[i][j1][0,:,:], combined_pcspikerates_cue_per_traj_avg[i][j2][0,:,:])
        combined_pcspikerates_cue_per_traj_frechdist[i][1][(j1, j2)] = frechet.distance(combined_pcspikerates_cue_per_traj_avg[j1][i][0,:,:], combined_pcspikerates_cue_per_traj_avg[j2][i][0,:,:])
        combined_pcspikerates_cue_per_traj_meaneucdist[i][0][(j1, j2)] = mean_euclidean_distance(combined_pcspikerates_cue_per_traj_avg[i][j1][0,:,:], combined_pcspikerates_cue_per_traj_avg[i][j2][0,:,:])
        combined_pcspikerates_cue_per_traj_meaneucdist[i][1][(j1, j2)] = mean_euclidean_distance(combined_pcspikerates_cue_per_traj_avg[j1][i][0,:,:], combined_pcspikerates_cue_per_traj_avg[j2][i][0,:,:])

        combined_pcspikerates_cue_per_traj_startdist[i][0][(j1, j2)] = euclidean_distance(combined_pcspikerates_cue_per_traj_avg[i][j1][0,0,:], combined_pcspikerates_cue_per_traj_avg[i][j2][0,0,:])
        combined_pcspikerates_cue_per_traj_startdist[i][1][(j1, j2)] = euclidean_distance(combined_pcspikerates_cue_per_traj_avg[j1][i][0,0,:], combined_pcspikerates_cue_per_traj_avg[j2][i][0,0,:])
        combined_pcspikerates_cue_per_traj_enddist[i][0][(j1, j2)] = euclidean_distance(combined_pcspikerates_cue_per_traj_avg[i][j1][0,-1,:], combined_pcspikerates_cue_per_traj_avg[i][j2][0,-1,:])
        combined_pcspikerates_cue_per_traj_enddist[i][1][(j1, j2)] = euclidean_distance(combined_pcspikerates_cue_per_traj_avg[j1][i][0,-1,:], combined_pcspikerates_cue_per_traj_avg[j2][i][0,-1,:])


# Calculate separation between clusters per (100 ms) timestep in cue phase
combined_pcspikerates_cue_per_traj_vrc = [[np.zeros(num_timebins), np.zeros(num_timebins)] for _ in range(num_goals)]
combined_pcspikerates_cue_per_traj_dbi = [[np.zeros(num_timebins), np.zeros(num_timebins)] for _ in range(num_goals)]
for i in range(num_goals):
    pcspikerates_start, pcspikerates_end = np.zeros((0, num_timebins, n_dims)), np.zeros((0, num_timebins, n_dims))
    goals_start, goals_end = list(), list()
    for j in range(num_goals):
        if i == j:
            continue
        pcspikerates_start = np.concatenate([pcspikerates_start, combined_pcspikerates_cue_per_traj[i][j]], axis=0)
        pcspikerates_end = np.concatenate([pcspikerates_end, combined_pcspikerates_cue_per_traj[j][i]], axis=0)
        goals_start.extend(combined_pcspikerates_cue_per_traj[i][j].shape[0] * [j])
        goals_end.extend(combined_pcspikerates_cue_per_traj[j][i].shape[0] * [j])
    for t in range(10):
        combined_pcspikerates_cue_per_traj_vrc[i][0][t] = calinski_harabasz_score(pcspikerates_start[:,t,:], np.array(goals_start))
        combined_pcspikerates_cue_per_traj_vrc[i][1][t] = calinski_harabasz_score(pcspikerates_end[:,t,:], np.array(goals_end))
        combined_pcspikerates_cue_per_traj_dbi[i][0][t] = davies_bouldin_score(pcspikerates_start[:,t,:], np.array(goals_start))
        combined_pcspikerates_cue_per_traj_dbi[i][1][t] = davies_bouldin_score(pcspikerates_end[:,t,:], np.array(goals_end))

In [None]:
# Clean up large memory variables
del all_spikerates_cue
del all_spikerates_iti_cue
# del all_goals_cue
del all_trajectories_cue
del all_spikerates_cue_per_traj
del all_spikerates_iti_cue_per_traj

In [33]:
# Array axes: observation, timebin, cell
combined_spikerates_iti_cue = np.empty((0, num_timebins_ext, num_all_cells))
for (s, g) in trajectories:
    combined_spikerates_iti_cue = np.concatenate([combined_spikerates_iti_cue, np.stack(combined_spikerates_iti_cue_per_traj[(s, g)])], axis=0)
combined_trajectories_iti_cue = np.array(combined_trajectories_cue)

# Compute correlation between cells
cell_correlation = list()
for timebin in range(num_timebins_ext):
    cell_correlation.append(np.corrcoef(combined_spikerates_iti_cue[:,timebin,:], rowvar=False))
cell_correlation = np.array(cell_correlation)

# Set correlation matrix diagonal to zero
for t in range(num_timebins_ext):
    for c in range(num_all_cells):
        cell_correlation[t,c,c] = 0

  c /= stddev[:, None]
  c /= stddev[None, :]


In [36]:
print(cell_correlation)

[[[ 0.00000000e+00 -4.03948063e-02 -3.37814319e-04 ...  6.17163919e-02
    1.50695094e-01 -1.27340582e-02]
  [-4.03948063e-02  0.00000000e+00 -9.72349319e-03 ... -3.38628307e-02
    6.23137418e-02 -9.41301605e-03]
  [-3.37814319e-04 -9.72349319e-03  0.00000000e+00 ...  1.88603616e-02
    4.13119170e-03 -3.23873767e-02]
  ...
  [ 6.17163919e-02 -3.38628307e-02  1.88603616e-02 ...  0.00000000e+00
   -1.35810901e-02 -1.06749183e-02]
  [ 1.50695094e-01  6.23137418e-02  4.13119170e-03 ... -1.35810901e-02
    0.00000000e+00  6.55416977e-04]
  [-1.27340582e-02 -9.41301605e-03 -3.23873767e-02 ... -1.06749183e-02
    6.55416977e-04  0.00000000e+00]]

 [[ 0.00000000e+00 -4.09911594e-02  5.70814813e-02 ...  2.50769501e-03
    2.99906583e-03 -2.22159734e-02]
  [-4.09911594e-02  0.00000000e+00  1.91888041e-03 ... -3.31579860e-02
   -6.97752140e-02  1.69640977e-01]
  [ 5.70814813e-02  1.91888041e-03  0.00000000e+00 ...  2.08693968e-03
   -3.54188555e-02 -5.75453040e-02]
  ...
  [ 2.50769501e-03 -3.3

In [35]:
data_DBSCAN = cell_correlation[timebin]
data_standardized = StandardScaler().fit_transform(data_DBSCAN)
eps = 0.55
min_samples = 5
dbscan = DBSCAN(eps=eps, min_samples=min_samples)
clusters_dbscan = dbscan.fit_predict(data_standardized)
print("DBSCAN Cluster Assignments:", clusters_dbscan)
num_clusters_dbscan = len(set(clusters_dbscan)) - (1 if -1 in clusters_dbscan else 0)
print("Number of clusters found:", num_clusters_dbscan)


DBSCAN Cluster Assignments: [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1]
Number of clusters found: 0
