### basic neural analysis for LFP

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import scipy
import scipy.stats as st
import sklearn
from sklearn.neighbors import KernelDensity
from sklearn.decomposition import PCA
import string
import warnings
import pickle
import json

from scipy.ndimage import gaussian_filter1d
import scipy.io

from scipy.signal import welch
import scipy.signal as signal


import os
import glob
import random
from time import time


### define LFP analysis function

In [None]:
def compute_csd(lfp_data, spacing=0.1, smooth_sigma=1):
    """
    Estimate the current source density (CSD) as the negative second
    spatial derivative of the LFP (finite-difference method).

    Parameters:
    - lfp_data: 2D array (channels x time) of LFP signals; the channels are
      differenced along axis 0, so they are expected to be in spatial order.
    - spacing: Distance (mm) between electrodes (default: 0.1 mm).
    - smooth_sigma: Sigma (in channels) of the Gaussian smoothing applied
      along the channel axis. ``None`` or a non-positive value skips the
      smoothing and returns the raw second difference.

    Returns:
    - csd: 2D array of shape (channels - 2, time). The second difference
      consumes one channel at each end of the probe, so the result has two
      rows fewer than ``lfp_data``.
    """
    # Second spatial difference along the channel axis, scaled by the
    # electrode spacing; the sign flip follows the CSD convention used here.
    csd = -np.diff(lfp_data, n=2, axis=0) / spacing**2

    # Smoothing is optional: skip it instead of handing a zero/negative sigma
    # to the Gaussian filter.
    if smooth_sigma is None or smooth_sigma <= 0:
        return csd

    # Smooth across channels only (axis 0); the time axis is left untouched.
    return gaussian_filter1d(csd, sigma=smooth_sigma, axis=0)

In [None]:
def compute_csd_for_windows(lfp_data, window_size, spacing=0.1, smooth_sigma=1):
    """
    Split the LFP into consecutive, non-overlapping windows and compute the
    CSD of each one.

    Parameters:
    - lfp_data: 2D array (channels x time) of LFP signals.
    - window_size: Number of samples per window (e.g. 5000 samples for a
      5-second window at 1000 Hz).
    - spacing: Distance between electrodes in mm.
    - smooth_sigma: Gaussian smoothing factor for CSD.

    Returns:
    - csd_windows: List of CSD arrays, one per complete window. Samples left
      over after the last complete window are not used.
    """
    # Only complete windows are processed (integer division drops the tail)
    n_full_windows = lfp_data.shape[1] // window_size
    window_starts = (k * window_size for k in range(n_full_windows))

    return [
        compute_csd(
            lfp_data[:, first:first + window_size],
            spacing=spacing,
            smooth_sigma=smooth_sigma,
        )
        for first in window_starts
    ]


## Analyze each session

### define the analysis dates

In [None]:
# ---- recording sessions to analyze ----
# Root folder that holds one sub-folder per recording condition.
neural_data_folder = '/gpfs/radev/pi/nandy/jadi_gibbs_data/Marmoset_neural_recording/'

# One entry per recording condition; un-comment an entry to include it.
# Keep this list and dates_list below in the same order and length, because
# the loading loop indexes both with the same counter.
neural_record_conditions = [
    # '20240508_Kanga_SR',
    # '20240513_Kanga_chair',
    # '20240524_Kanga_SR',
    # '20240528_Kanga_chaired',
    # '20240606_Kanga_MC',
    # '20250204_Dodson_MC_withKoala',
    # '20250206_Dodson_MC_withKoala',
    # '20250206_Dodson_chaired',  # good LFP
    # '20250207_Dodson_chaired',
    '20250210_Dodson_SR_withKoala',  # good LFP
    # '20250210_Dodson_chaired',  # good LFP
    # '20250211_Dodson_MC_withKoala',
    # '20250212_Dodson_SR_withKoala',
    # '20250213_Dodson_MC_withKoala',
    # '20250214_Dodson_MC_withKoala',  # good LFP
    # '20250217_Dodson_SR_withKoala',
    # '20250217_Dodson_chaired',
    # '20250218_Dodson_MC_withKoala',
    # '20250218_Dodson_chaired',
]
n_record_conditions = len(neural_record_conditions)

# ---- information for the behavioral (bhv) analysis ----
# Session date of each entry in neural_record_conditions (same order).
dates_list = [
    # '20240508',
    # '20240513',
    # '20240524',
    # '20240528',
    # '20240606',
    # '20250204',
    # '20250206',
    # '20250206',
    # '20250207',
    '20250210',
    # '20250210',
    # '20250211',
    # '20250212',
    # '20250213',
    # '20250214',
    # '20250217',
    # '20250217',
    # '20250218',
    # '20250218',
]
session_start_times_camera = [0]  # need to update - 12042023

# Fixed animal order (lower-case names, compared against the session info).
animal1_fixedorder = ['dodson']
animal2_fixedorder = ['koala']
# animal1_fixedorder = ['dannon']
# animal2_fixedorder = ['kanga']

# Animal names as they appear in the behavioral file and folder names.
animal1_filename = "Dodson"
animal2_filename = "Koala"
# animal1_filename = "Dannon"
# animal2_filename = "Kanga"

# Sampling rates (Hz) of the spike and LFP recordings.
fs_spikes = 20000
fs_lfp = 1000

# Kilosort version whose output folder is read (2 or 4).
kilosortver = 4

# the total time of the analyzed sessions
total_session_time = 600  # in the unit of s

### get the basic neural data and behavioral data
### LFP analysis

In [None]:
# plot types
plot_allpulljuice = 0
plot_succ_fail_pull = 0


chanMapFile = '/home/ws523/kilisort_spikesorting/Channel-Maps/Neuronexus_whitematter_2x32_kilosort4_new.mat'
# chanMapFile = '/home/ws523/kilisort_spikesorting/Channel-Maps/Neuronexus_whitematter_2x32_kilosort4_mirroredMap_2.mat'


# Load .mat file
chan_map = scipy.io.loadmat(chanMapFile)


def _load_bhv_session_files(file_prefix):
    """
    Load the four task-code JSON exports whose file names start with
    ``file_prefix``.

    Parameters:
    - file_prefix: Folder plus the leading part of the file name
      (``<date>_<animal>_<animal>``); the suffixes ``_TrialRecord_*.json``,
      ``_bhv_data_*.json``, ``_session_info_*.json`` and ``_ni_data_*.json``
      are appended here and the first glob match of each is read.

    Returns:
    - (trial_record, bhv_data, session_info, ni_data): three DataFrames and
      the object decoded from the last line of the ni_data file (``None``
      if that file has no lines).

    Raises:
    - IndexError if one of the patterns matches no file.
    """
    trial_record_json = glob.glob(file_prefix + "_TrialRecord_" + "*.json")
    bhv_data_json = glob.glob(file_prefix + "_bhv_data_" + "*.json")
    session_info_json = glob.glob(file_prefix + "_session_info_" + "*.json")
    ni_data_json = glob.glob(file_prefix + "_ni_data_" + "*.json")
    #
    trial_record = pd.read_json(trial_record_json[0])
    bhv_data = pd.read_json(bhv_data_json[0])
    session_info = pd.read_json(session_info_json[0])
    #
    ni_data = None
    with open(ni_data_json[0]) as f:
        for line in f:
            ni_data = json.loads(line)
    return trial_record, bhv_data, session_info, ni_data


# load the neural activity for each condition
# NOTE: every variable assigned in this loop is a notebook-level name that
# later cells read, so after the loop they hold the LAST condition only.
for icondition in np.arange(0,n_record_conditions,1):

    neural_record_condition = neural_record_conditions[icondition]
    date_tgt = dates_list[icondition]

    # not use it for now, pre-process in matlab code first
    # (still raises IndexError when the condition folder has no .bin file)
    neural_record_filename = glob.glob(neural_data_folder+neural_record_condition+'/*.bin')[0]

    # load filtered lfp (indexed below as channels x samples)
    lfp_filt_filename = neural_data_folder+neural_record_condition+'/lfp_filt.txt'
    lfp_filt_data = np.loadtxt(lfp_filt_filename, delimiter=',')

    # arbitrarily blank the first 200 s of every channel to remove some big noise
    lfp_filt_data[:,0:200*fs_lfp]=np.nan
    # lfp_filt_data[:,500*fs_lfp:]=np.nan # arbitrarily remove some big noise

    # time stamp of every LFP sample, relative to the start of the recording
    ntimewins = np.shape(lfp_filt_data)[1]
    timewins = (np.arange(0,ntimewins,1))/fs_lfp # in the unit of second


    if 1:
        # load the behavioral data; the exported file names may list the two
        # animals in either order, so try animal2_animal1 first and fall
        # back to animal1_animal2 (a failure of the fallback propagates)
        bhv_data_path = "/gpfs/radev/pi/nandy/jadi_gibbs_data/VideoTracker_SocialInter/marmoset_tracking_bhv_data_from_task_code/"+date_tgt+"_"+animal1_filename+"_"+animal2_filename+"/"
        try:
            trial_record, bhv_data, session_info, ni_data = _load_bhv_session_files(
                bhv_data_path+date_tgt+"_"+animal2_filename+"_"+animal1_filename)
        except Exception:
            trial_record, bhv_data, session_info, ni_data = _load_bhv_session_files(
                bhv_data_path+date_tgt+"_"+animal1_filename+"_"+animal2_filename)

        # get animal info from the session information
        animal1_frombhv = session_info['lever1_animal'][0].lower()
        animal2_frombhv = session_info['lever2_animal'][0].lower()

        # get task type and cooperation threshold; fall back to defaults
        # when the session info does not provide them
        try:
            coop_thres = session_info["pulltime_thres"][0]
            tasktype = session_info["task_type"][0]
        except Exception:
            coop_thres = 0
            tasktype = 1


        # clean up the trial_record: keep only the first row of each trial.
        # The rows are collected and concatenated once, because
        # DataFrame.append no longer exists in pandas >= 2.0.
        warnings.filterwarnings('ignore')
        first_rows = [trial_record[trial_record['trial_number']==itrial+1].iloc[[0]]
                      for itrial in np.arange(0,np.max(trial_record['trial_number']),1)]
        if first_rows:
            trial_record_clean = pd.concat(first_rows)
        else:
            trial_record_clean = trial_record.iloc[:0].copy()
        trial_record_clean = trial_record_clean.reset_index(drop = True)

        # change bhv_data time to the absolute time
        # (time within the trial + start time of that trial)
        time_points_new = pd.DataFrame(np.zeros(np.shape(bhv_data)[0]),columns=["time_points_new"])
        for itrial in np.arange(0,np.max(trial_record_clean['trial_number']),1):
            ind = bhv_data["trial_number"]==itrial+1
            new_time_itrial = bhv_data[ind]["time_points"] + trial_record_clean["trial_starttime"].iloc[itrial]
            time_points_new["time_points_new"][ind] = new_time_itrial
        bhv_data["time_points"] = time_points_new["time_points_new"]
        # rows whose absolute time is exactly 0 are dropped
        bhv_data = bhv_data[bhv_data["time_points"] != 0]
        #
        # fix mislabeled successful trials.
        # This loop is kept statement-for-statement: it depends on the order
        # of the in-place edits. From the comments and variable names below,
        # event codes 1/2 are the two animals' pulls and 3/4 their juice
        # deliveries; 9 and 0 look like trial-end / trial-start markers
        # -- TODO confirm against the task code.
        # NOTE(review): with a single trial this loop never runs and
        # bhv_data_fixed is not (re)assigned -- confirm that cannot happen.
        for itrial in np.arange(0,np.max(trial_record_clean['trial_number'])-1,1):
            #
            # initialize: rows of the trial being examined (trial itrial+2)
            ind = bhv_data["trial_number"]==itrial+2
            bhv_data_itrial = bhv_data[ind]
            #
            if itrial == 0:
                # seed the fixed table with the first trial
                ind = bhv_data["trial_number"]==itrial+1
                bhv_data_pre_itrial = bhv_data[ind]
                #
                bhv_data_fixed = bhv_data_pre_itrial
            #
            # use the bhv_data_fixed for the bhv_data_pre_itrial
            #
            ind_pre_itrial = bhv_data_fixed["trial_number"]==itrial+1
            bhv_data_pre_itrial = bhv_data_fixed[ind_pre_itrial]

            # examine the itrial
            # a mislabeled successful trial that is missing one pull needs to be fixed
            if np.sum((bhv_data_itrial['behavior_events']==1)|(bhv_data_itrial['behavior_events']==2))==1:
                #
                # modify when the trial starts
                bhv_data_pre_itrial['trial_number'][bhv_data_pre_itrial['behavior_events']==9] = itrial+2
                bhv_data_pre_itrial['behavior_events'][bhv_data_pre_itrial['behavior_events']==9] = 0
                #
                # if animal 2 pull is missing
                if np.sum((bhv_data_itrial['behavior_events']==1))==1:
                    # modify the animal 2 pull
                    bhv_data_itrial['behavior_events'][bhv_data_itrial['behavior_events']==0] = 2
                    # modify the end of previous trial
                    bhv_data_pre_itrial['behavior_events'].iloc[np.where(bhv_data_pre_itrial['behavior_events']==2)[0][-1]]=9

                # if animal 1 pull is missing
                elif np.sum((bhv_data_itrial['behavior_events']==2))==1:
                    # modify the animal 1 pull
                    bhv_data_itrial['behavior_events'][bhv_data_itrial['behavior_events']==0] = 1
                    # modify the end of previous trial
                    bhv_data_pre_itrial['behavior_events'].iloc[np.where(bhv_data_pre_itrial['behavior_events']==1)[0][-1]]=9
                #
                bhv_data_fixed[ind_pre_itrial] = bhv_data_pre_itrial
                bhv_data_fixed = pd.concat([bhv_data_fixed,bhv_data_itrial])
                #
            else:

                bhv_data_fixed = pd.concat([bhv_data_fixed,bhv_data_itrial])
                bhv_data_old = bhv_data.copy()

        bhv_data = bhv_data_fixed

        # analyze behavior results: event code and time of every pull (1/2)
        # and every juice delivery (3/4)
        pullid = np.array(bhv_data[(bhv_data['behavior_events']==1)|(bhv_data['behavior_events']==2)]["behavior_events"])
        pulltime = np.array(bhv_data[(bhv_data['behavior_events']==1)|(bhv_data['behavior_events']==2)]["time_points"])
        #
        juiceid = np.array(bhv_data[(bhv_data['behavior_events']==3)|(bhv_data['behavior_events']==4)]["behavior_events"])
        juicetime = np.array(bhv_data[(bhv_data['behavior_events']==3)|(bhv_data['behavior_events']==4)]["time_points"])
        #
        # successful trial (rewarded > 0)
        trial_num_succ = np.array(trial_record_clean[trial_record_clean['rewarded']>0]['trial_number'])
        bhv_data_succ = bhv_data[np.isin(bhv_data['trial_number'],trial_num_succ)]
        #
        pullid_succ = np.array(bhv_data_succ[(bhv_data_succ['behavior_events']==1)|(bhv_data_succ['behavior_events']==2)]["behavior_events"])
        pulltime_succ = np.array(bhv_data_succ[(bhv_data_succ['behavior_events']==1)|(bhv_data_succ['behavior_events']==2)]["time_points"])
        #
        juiceid_succ = np.array(bhv_data_succ[(bhv_data_succ['behavior_events']==3)|(bhv_data_succ['behavior_events']==4)]["behavior_events"])
        juicetime_succ = np.array(bhv_data_succ[(bhv_data_succ['behavior_events']==3)|(bhv_data_succ['behavior_events']==4)]["time_points"])
        #
        # failed trial (rewarded == 0)
        trial_num_fail = np.array(trial_record_clean[trial_record_clean['rewarded']==0]['trial_number'])
        bhv_data_fail = bhv_data[np.isin(bhv_data['trial_number'],trial_num_fail)]
        #
        pullid_fail = np.array(bhv_data_fail[(bhv_data_fail['behavior_events']==1)|(bhv_data_fail['behavior_events']==2)]["behavior_events"])
        pulltime_fail = np.array(bhv_data_fail[(bhv_data_fail['behavior_events']==1)|(bhv_data_fail['behavior_events']==2)]["time_points"])
        #
        juiceid_fail = np.array(bhv_data_fail[(bhv_data_fail['behavior_events']==3)|(bhv_data_fail['behavior_events']==4)]["behavior_events"])
        juicetime_fail = np.array(bhv_data_fail[(bhv_data_fail['behavior_events']==3)|(bhv_data_fail['behavior_events']==4)]["time_points"])

        # map the lever-based event codes onto the fixed animal order, so
        # that "1" always means animal1_fixedorder and "2" animal2_fixedorder
        # NOTE(review): when the lever-1 animal matches neither fixed-order
        # name, none of these variables is (re)assigned for this session.
        if animal1_frombhv == animal1_fixedorder[0]:
            # all pulls
            npulls1 = sum(pullid==1)
            pulltimes1 = pulltime[pullid==1]
            npulls2 = sum(pullid==2)
            pulltimes2 = pulltime[pullid==2]
            # all juices
            njuice1 = sum(juiceid==3)
            juicetimes1 = juicetime[juiceid==3]
            njuice2 = sum(juiceid==4)
            juicetimes2 = juicetime[juiceid==4]
            # all succ pulls
            npulls1_succ = sum(pullid_succ==1)
            pulltimes1_succ = pulltime_succ[pullid_succ==1]
            npulls2_succ = sum(pullid_succ==2)
            pulltimes2_succ = pulltime_succ[pullid_succ==2]
            # all fail pulls
            npulls1_fail = sum(pullid_fail==1)
            pulltimes1_fail = pulltime_fail[pullid_fail==1]
            npulls2_fail = sum(pullid_fail==2)
            pulltimes2_fail = pulltime_fail[pullid_fail==2]
            #
        elif animal1_frombhv == animal2_fixedorder[0]:
            # levers are swapped relative to the fixed order
            # all pulls
            npulls1 = sum(pullid==2)
            pulltimes1 = pulltime[pullid==2]
            npulls2 = sum(pullid==1)
            pulltimes2 = pulltime[pullid==1]
            # all juices
            njuice1 = sum(juiceid==4)
            juicetimes1 = juicetime[juiceid==4]
            njuice2 = sum(juiceid==3)
            juicetimes2 = juicetime[juiceid==3]
            # all succ pulls
            npulls1_succ = sum(pullid_succ==2)
            pulltimes1_succ = pulltime_succ[pullid_succ==2]
            npulls2_succ = sum(pullid_succ==1)
            pulltimes2_succ = pulltime_succ[pullid_succ==1]
            # all fail pulls
            npulls1_fail = sum(pullid_fail==2)
            pulltimes1_fail = pulltime_fail[pullid_fail==2]
            npulls2_fail = sum(pullid_fail==1)
            pulltimes2_fail = pulltime_fail[pullid_fail==1]


        # session starting time compared with the neural recording
        session_start_time_niboard_offset = ni_data['session_t0_offset'] # in the unit of second
        neural_start_time_niboard_offset = ni_data['trigger_ts'][0]['elapsed_time'] # in the unit of second
        neural_start_time_session_start_offset = neural_start_time_niboard_offset-session_start_time_niboard_offset



    else:
        neural_start_time_session_start_offset = 0

    # align the LFP recording time stamps
    LFP_timewins_aligned = timewins+neural_start_time_session_start_offset # in the unit of second




    # load spike sorting results
    print('load spike data for '+neural_record_condition)
    # all sorter outputs of one condition live in a single folder that
    # depends on the Kilosort version
    if kilosortver == 2:
        kilosort_folder = neural_data_folder+neural_record_condition+'/Kilosort/'
    elif kilosortver == 4:
        kilosort_folder = neural_data_folder+neural_record_condition+'/kilosort4_6500HzNotch/'
        # kilosort_folder = neural_data_folder+neural_record_condition+'/kilosort4_6500HzNotch_mirroredMap/'
    else:
        raise ValueError('unsupported kilosortver: '+str(kilosortver)+' (expected 2 or 4)')
    #
    spike_time_file = kilosort_folder+'spike_times.npy'
    spike_time_data = np.load(spike_time_file)
    #
    # align the FR recording time stamps (spike times are in samples here)
    spike_time_data = spike_time_data + fs_spikes*neural_start_time_session_start_offset
    # down-sample the spike recording resolution to the same as the lfp
    spike_time_data = spike_time_data/fs_spikes*fs_lfp
    spike_time_data = np.round(spike_time_data)/fs_lfp
    #
    spike_clusters_file = kilosort_folder+'spike_clusters.npy'
    spike_clusters_data = np.load(spike_clusters_file)
    # starts as a copy of the cluster ids; overwritten with channel ids below
    spike_channels_data = np.copy(spike_clusters_data)
    #
    channel_maps_file = kilosort_folder+'channel_map.npy'
    channel_maps_data = np.load(channel_maps_file)
    #
    channel_pos_file = kilosort_folder+'channel_positions.npy'
    channel_pos_data = np.load(channel_pos_file)
    #
    clusters_info_file = kilosort_folder+'cluster_info.tsv'
    clusters_info_data = pd.read_csv(clusters_info_file,sep="\t")
    #
    # the cluster-id column is named either 'cluster_id' or 'id'
    cluster_id_col = 'cluster_id' if 'cluster_id' in clusters_info_data.columns else 'id'
    #
    # only get the spikes that are manually checked
    good_clusters = clusters_info_data[(clusters_info_data.group=='good')|(clusters_info_data.group=='mua')][cluster_id_col].values
    #
    clusters_info_data = clusters_info_data[~pd.isnull(clusters_info_data.group)]
    #
    # one mask for all three per-spike arrays, computed before any is filtered
    is_good_spike = np.isin(spike_clusters_data,good_clusters)
    spike_time_data = spike_time_data[is_good_spike]
    spike_channels_data = spike_channels_data[is_good_spike]
    spike_clusters_data = spike_clusters_data[is_good_spike]
    #
    nclusters = np.shape(clusters_info_data)[0]
    #
    # replace each spike's cluster id by the channel ('ch') of that cluster
    for icluster in np.arange(0,nclusters,1):
        cluster_id = clusters_info_data[cluster_id_col].iloc[icluster]
        spike_channels_data[np.isin(spike_clusters_data,cluster_id)] = clusters_info_data['ch'].iloc[icluster]
    #
    #
    # get the channel to depth information, change 2 shanks to 1 shank
    try:
        # positions whose first column is 0 / 1: interleave the two shanks
        # (even rows from the first, odd rows from the second)
        channel_depth=np.hstack([channel_pos_data[channel_pos_data[:,0]==0,1]*2,channel_pos_data[channel_pos_data[:,0]==1,1]*2+1])
        channel_to_depth = np.vstack([channel_maps_data.T,channel_depth])
    except Exception:
        # fallback for position files whose first column is 0 / 31.2
        # (reached when the stacking above fails, e.g. on a length mismatch)
        channel_depth=np.hstack([channel_pos_data[channel_pos_data[:,0]==0,1],channel_pos_data[channel_pos_data[:,0]==31.2,1]])
        channel_to_depth = np.vstack([channel_maps_data.T,channel_depth])
        channel_to_depth[1] = channel_to_depth[1]/30-64 # make the y axis consistent
    
    

#### plot the mean LFP across all channels to check the noise

In [None]:
# Average the LFP over channels (axis 0) at every sample, ignoring NaNs
mean_lfp = np.nanmean(lfp_filt_data, axis=0)

# Time stamp in seconds of every sample (sample index / sampling rate)
time_axis = np.arange(len(mean_lfp)) / fs_lfp

# Draw the channel-averaged trace on a wide figure
plt.figure(figsize=(10, 4))
plt.plot(time_axis, mean_lfp, linewidth=1, color='black')

# Title and axis labels
plt.title("Mean LFP Across Channels Over Entire Recording")
plt.ylabel("Mean LFP (µV)")
plt.xlabel("Time (s)")

# Light dashed grid to make noise bursts easier to spot
plt.grid(True, alpha=0.6, linestyle="--")
plt.tight_layout()

plt.show()

#### define and remove bad channels

In [None]:

# Compute variance of each channel over time (NaN samples are ignored)
channel_variances = np.nanvar(lfp_filt_data, axis=1)

# Define thresholds for bad channels.
# nanpercentile is used so that a channel that is entirely NaN (variance NaN)
# does not turn both thresholds into NaN and silently disable the detection.
low_threshold = np.nanpercentile(channel_variances, 5)  # Bottom 5% (flat signals)
high_threshold = np.nanpercentile(channel_variances, 93)  # Top 7% (too noisy)

# Identify bad channels: only the high-variance criterion is applied for now
# bad_channels = np.where((channel_variances < low_threshold) | (channel_variances > high_threshold))[0]
bad_channels = np.where((channel_variances > high_threshold))[0]

# Replace bad channels with NaN (currently disabled: channels are only reported)
# lfp_filt_data[bad_channels, :] = np.nan

# Print bad channels
print("Identified bad channels:", bad_channels)

# Plot channel variances for inspection, with both thresholds drawn
plt.figure(figsize=(8, 4))
plt.plot(channel_variances, marker='o', linestyle='-', color='b')
plt.axhline(low_threshold, color='r', linestyle='--', label='Low Variance Threshold')
plt.axhline(high_threshold, color='r', linestyle='--', label='High Variance Threshold')
plt.xlabel("Channel ID")
plt.ylabel("Variance")
plt.title("LFP Channel Variance Analysis; bad channels"+str(bad_channels))
plt.legend()

# Save the figure next to the recording it was computed from
fig_name = neural_data_folder+neural_record_condition+'/'+neural_record_condition+'_channels_variance_for_badones.pdf'
plt.savefig(fig_name, format="pdf", bbox_inches="tight")

plt.show()


#### match the LFP signals to the maps

In [None]:

# Column order that sorts the depth row (row 1) from largest to smallest
sorted_indices = np.flip(np.argsort(channel_to_depth[1]))

# Apply that order to both rows (channel ids in row 0, depths in row 1)
channel_to_depth_sorted = channel_to_depth[:, sorted_indices]

# LFP rows picked in depth-sorted order, using the channel ids as row indices
depth_ordered_rows = channel_to_depth_sorted[0].astype(int)
lfp_data_mapmatch_1shank = lfp_filt_data[depth_ordered_rows, :]

# LFP rows reordered by the channel map file; its entries are shifted by one
# to become 0-based row indices
lfp_data_mapmatch = lfp_filt_data[chan_map['chanMap'].ravel() - 1, :]


### plot the channel-to-channel LFP correlation matrices

In [None]:
# A channel counts as bad when every one of its samples is NaN
bad_channels = np.isnan(lfp_filt_data).all(axis=1)

# All remaining channels enter the correlation analysis
good_channels = ~bad_channels
lfp_good = lfp_filt_data[good_channels, :]

# A time point is usable only if no good channel is NaN there
valid_timepoints = ~np.isnan(lfp_good).any(axis=0)
lfp_clean = lfp_good[:, valid_timepoints]

# Pairwise correlation between the good channels (rows)
corr_matrix_good = np.corrcoef(lfp_clean)

# Embed the result in a channels x channels matrix; rows and columns that
# belong to bad channels stay NaN
num_channels = lfp_filt_data.shape[0]
corr_matrix_full = np.full((num_channels, num_channels), np.nan)
corr_matrix_full[np.ix_(good_channels, good_channels)] = corr_matrix_good

# Masked array so that the NaN entries are not colored
masked_corr = np.ma.masked_invalid(corr_matrix_full)

# Heatmap of the correlation matrix
plt.figure(figsize=(8, 6))
im = plt.imshow(masked_corr, cmap="jet", aspect="auto")

# Color scale spans exactly the range of the non-NaN correlations
valid_corr_values = corr_matrix_good[~np.isnan(corr_matrix_good)]
im.set_clim(valid_corr_values.min(), valid_corr_values.max())
# im.set_clim(0.8, 1)

plt.colorbar(label="Correlation Coefficient")
plt.title("Channel-to-Channel LFP Correlation Matrix (Filtered for Valid Periods)")
plt.ylabel("Channel ID")
plt.xlabel("Channel ID")

# Save into the folder of the recording condition
fig_name = f"{neural_data_folder}{neural_record_condition}/{neural_record_condition}_ch_to_ch_LFP_corr_noMapMatch.pdf"
plt.savefig(fig_name, format="pdf", bbox_inches="tight")

plt.show()


In [None]:
# A channel counts as bad when every one of its samples is NaN
bad_channels = np.isnan(lfp_data_mapmatch).all(axis=1)

# All remaining channels enter the correlation analysis
good_channels = ~bad_channels
lfp_good = lfp_data_mapmatch[good_channels, :]

# A time point is usable only if no good channel is NaN there
valid_timepoints = ~np.isnan(lfp_good).any(axis=0)
lfp_clean = lfp_good[:, valid_timepoints]

# Pairwise correlation between the good channels (rows)
corr_matrix_good = np.corrcoef(lfp_clean)

# Embed the result in a channels x channels matrix; rows and columns that
# belong to bad channels stay NaN
num_channels = lfp_data_mapmatch.shape[0]
corr_matrix_full = np.full((num_channels, num_channels), np.nan)
corr_matrix_full[np.ix_(good_channels, good_channels)] = corr_matrix_good

# Masked array so that the NaN entries are not colored
masked_corr = np.ma.masked_invalid(corr_matrix_full)

# Heatmap of the correlation matrix
plt.figure(figsize=(8, 6))
im = plt.imshow(masked_corr, cmap="jet", aspect="auto")

# Color scale spans exactly the range of the non-NaN correlations
valid_corr_values = corr_matrix_good[~np.isnan(corr_matrix_good)]
im.set_clim(valid_corr_values.min(), valid_corr_values.max())
# im.set_clim(0.8, 1)

plt.colorbar(label="Correlation Coefficient")
plt.title("Channel-to-Channel LFP Correlation Matrix (Bad Channels Removed); match the channel map")
plt.ylabel("Channel ID")
plt.xlabel("Channel ID")

# Save into the folder of the recording condition
# NOTE(review): the file name says "mirroredMap" -- confirm it matches the
# channel map that was actually loaded.
fig_name = f"{neural_data_folder}{neural_record_condition}/{neural_record_condition}_ch_to_ch_LFP_corr_withMapMatch_mirroredMap.pdf"
plt.savefig(fig_name, format="pdf", bbox_inches="tight")

plt.show()


In [None]:
# A channel counts as bad when every one of its samples is NaN
bad_channels = np.isnan(lfp_data_mapmatch_1shank).all(axis=1)

# All remaining channels enter the correlation analysis
good_channels = ~bad_channels
lfp_good = lfp_data_mapmatch_1shank[good_channels, :]

# A time point is usable only if no good channel is NaN there
valid_timepoints = ~np.isnan(lfp_good).any(axis=0)
lfp_clean = lfp_good[:, valid_timepoints]

# Pairwise correlation between the good channels (rows)
corr_matrix_good = np.corrcoef(lfp_clean)

# Embed the result in a channels x channels matrix; rows and columns that
# belong to bad channels stay NaN
num_channels = lfp_data_mapmatch_1shank.shape[0]
corr_matrix_full = np.full((num_channels, num_channels), np.nan)
corr_matrix_full[np.ix_(good_channels, good_channels)] = corr_matrix_good

# Masked array so that the NaN entries are not colored
masked_corr = np.ma.masked_invalid(corr_matrix_full)

# Heatmap of the correlation matrix
plt.figure(figsize=(8, 6))
im = plt.imshow(masked_corr, cmap="jet", aspect="auto")

# Color scale spans exactly the range of the non-NaN correlations
valid_corr_values = corr_matrix_good[~np.isnan(corr_matrix_good)]
im.set_clim(valid_corr_values.min(), valid_corr_values.max())
# im.set_clim(0.6, 0.7)

plt.colorbar(label="Correlation Coefficient")
plt.title("Channel-to-Channel LFP Correlation Matrix (Bad Channels Removed); match the channel map and combine to 1 shank")
plt.ylabel("Channel ID")
plt.xlabel("Channel ID")

# Save into the folder of the recording condition
# NOTE(review): the file name says "mirroredMap" -- confirm it matches the
# channel map that was actually loaded.
fig_name = f"{neural_data_folder}{neural_record_condition}/{neural_record_condition}_ch_to_ch_LFP_corr_withMapMatch1shank_mirroredMap.pdf"
plt.savefig(fig_name, format="pdf", bbox_inches="tight")

plt.show()

### plot the relative LFP power spectrum across channels (includes the beta and gamma bands)

In [None]:
# Define parameters
fs = fs_lfp  # Sampling rate in Hz; reuse the LFP rate instead of repeating the literal
nperseg = 500  # Segment length for Welch's method

# Choose which channel ordering of the LFP to analyze
# lfp_tgt = lfp_filt_data
lfp_tgt = lfp_data_mapmatch_1shank
# lfp_tgt = lfp_data_mapmatch

# Keep only the time samples at which no channel is NaN (the sum over
# channels is NaN as soon as one channel is NaN); the kept samples are
# concatenated, so gaps in time are not preserved
lfp_tgt = lfp_tgt[:,~np.isnan(np.sum(lfp_tgt,axis=0))]

# Compute PSD using Welch's method along the time axis
freqs, psd = signal.welch(lfp_tgt, fs=fs, nperseg=nperseg, axis=1)  # Shape (num_channels, freq_bins)

# Normalize by max power across channels for each frequency, so that at every
# frequency the strongest channel has relative power 1
relative_power = psd / np.nanmax(psd, axis=0, keepdims=True)  # Shape (num_channels, freq_bins)

# Plot heatmap (x: frequency, y: channel row of lfp_tgt)
plt.figure(figsize=(10, 6))
plt.imshow(relative_power, aspect='auto', cmap='jet', extent=[freqs[0], freqs[-1], 0, lfp_tgt.shape[0]], origin='lower')

# Set axis labels
plt.colorbar(label="Relative Power (0-1)")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Channel ID")
plt.title("Relative Power Across Channels and Frequencies")

# Save figure as PDF
fig_name = neural_data_folder+neural_record_condition+'/'+neural_record_condition+'_lfp_power_spectrum_heatmap_withMapMatch1shank_mirroredMap.pdf'
plt.savefig(fig_name, format="pdf", bbox_inches="tight")

plt.show()

In [None]:
# Display the current value of bad_channels. When the cells above are run in
# order, this is the boolean all-NaN mask assigned in the last correlation
# cell, not the index array produced by the variance check.
bad_channels

### plot the behavioral events aligned mean LFP

In [None]:
if 1:
    # Define parameters
    time_window = 3  # Time range before and after pull event (3s each)
    samples_window = time_window * fs_lfp  # samples per side (3000 at 1000 Hz)
    # Time stamp of every extracted sample relative to the pull: sample k of
    # a segment lies (k - samples_window) samples away from the event, so the
    # axis is built from the sample offsets rather than from a linspace whose
    # step would not equal 1/fs_lfp.
    time_axis = np.arange(-samples_window, samples_window) / fs_lfp

    n_channels, n_samples = lfp_data_mapmatch_1shank.shape

    # Locate every pull of animal 2 (pulltimes2) once; the window is the same
    # for all channels, so there is no need to search again per channel.
    event_windows = []
    for pull_time in pulltimes2:
        # Index of the LFP sample closest in time to the pull event
        pull_index = np.argmin(np.abs(LFP_timewins_aligned - pull_time))

        start_idx = pull_index - samples_window
        end_idx = pull_index + samples_window

        # Keep only windows that lie completely inside the recording
        # (end_idx is exclusive, so it may equal the number of samples)
        if start_idx >= 0 and end_idx <= n_samples:
            event_windows.append((start_idx, end_idx))

    # Mean LFP per channel (shape: channels x time); rows stay NaN when no
    # pull has a complete window
    lfp_avg_matrix = np.full((n_channels, 2 * samples_window), np.nan)

    # Loop over channels to keep memory use at one channel's segments at a time
    for ch in range(n_channels):
        if event_windows:
            lfp_segments = [lfp_data_mapmatch_1shank[ch, s:e] for s, e in event_windows]
            # Average across pull events, ignoring NaN samples
            lfp_avg_matrix[ch, :] = np.nanmean(np.array(lfp_segments), axis=0)

    # Create figure
    plt.figure(figsize=(12, 8))

    # Plot each channel's mean LFP trace with an offset for visibility
    offset = 20  # Increase the offset value to separate traces more
    for ch in range(n_channels):
        plt.plot(time_axis, lfp_avg_matrix[ch, :] + ch * offset, label=f"Channel {ch+1}")  # Add an offset for separation

    # Add vertical line for the pull event
    plt.axvline(0, color='red', linestyle='--', label="Pull Event")

    # Labels and Title
    plt.xlabel("Time (s)")
    plt.ylabel("LFP Signal (with offset per channel)")
    plt.title("Mean LFP Trace for Each Channel (with offsets)")

    # Remove legend if not needed
    # plt.legend(loc='upper right')

    plt.tight_layout()
    plt.show()

### plot the CSD of the behavioral-event-aligned mean LFP

In [None]:
if 1:
    # Event-triggered CSD: average the LFP of every channel in a
    # +/- `time_window` second window around each event in `pulltimes1`, then
    # estimate the CSD as the negative second difference across channels and
    # show it as a heatmap.

    # Define parameters
    time_window = 3  # Seconds kept before and after each pull event
    # int(round(...)) keeps the window a valid array size / slice offset even
    # if fs_lfp is stored as a float (same cast as the spike-aligned cell).
    samples_window = int(round(time_window * fs_lfp))
    # Sample k of a segment lies (k - samples_window) / fs_lfp seconds from the
    # event, so the event sample maps to exactly t = 0.
    time_axis = np.arange(-samples_window, samples_window) / fs_lfp

    n_channels = 64
    n_total_samples = lfp_data_mapmatch_1shank.shape[1]

    # Locate each event once (nearest LFP timestamp) instead of repeating the
    # full-timeline argmin for every channel.
    event_windows = []
    for pull_time in pulltimes1:
        pull_index = int(np.argmin(np.abs(LFP_timewins_aligned - pull_time)))
        start_idx = pull_index - samples_window
        end_idx = pull_index + samples_window
        # end_idx is an exclusive slice bound, so it may equal the data length.
        if start_idx >= 0 and end_idx <= n_total_samples:
            event_windows.append((start_idx, end_idx))

    # Mean LFP per channel (channels x time); stays NaN when no event window
    # fits inside the recording.
    lfp_avg_matrix = np.full((n_channels, 2 * samples_window), np.nan)
    if event_windows:
        for ch in range(n_channels):
            lfp_segments = np.array(
                [lfp_data_mapmatch_1shank[ch, start:end] for start, end in event_windows]
            )
            lfp_avg_matrix[ch, :] = np.nanmean(lfp_segments, axis=0)

    # CSD estimate for the interior channels, computed for all time points at
    # once: -(LFP(i+1) - 2*LFP(i) + LFP(i-1)). The minus sign follows
    # compute_csd() defined earlier in this file. The first and last rows have
    # no neighbour on one side and stay 0.
    # The result is NOT divided by electrode spacing^2 or scaled by
    # conductivity, so its magnitude is in arbitrary units.
    csd_matrix = np.zeros_like(lfp_avg_matrix)
    csd_matrix[1:-1, :] = -(
        lfp_avg_matrix[2:, :] - 2 * lfp_avg_matrix[1:-1, :] + lfp_avg_matrix[:-2, :]
    )

    # Plot the CSD as a heatmap
    plt.figure(figsize=(10, 8))

    # Only the interior rows (1-based channels 2 .. n_channels-1) are drawn, so
    # the vertical extent runs from 1.5 (top) to n_channels-0.5 (bottom) to put
    # each row's centre on its own channel number.
    plt.imshow(
        csd_matrix[1:-1, :],
        aspect='auto',
        cmap='jet',
        extent=[time_axis[0], time_axis[-1], n_channels - 0.5, 1.5],
    )
    plt.colorbar(label="CSD estimate (a.u., unscaled)")
    plt.axvline(0, color='red', linestyle='--')  # Mark pull event at t=0

    # Add labels and title
    plt.xlabel("Time (s)")
    plt.ylabel("Channel")
    plt.title(f"CSD Heatmap (-{time_window}s to +{time_window}s around Pull1)")

    # Show the plot
    plt.tight_layout()
    plt.show()

### plot the mean LFP and its CSD aligned to the spikes of a chosen channel

In [None]:
if 1:
    # Spike-triggered average: average the LFP of every channel in a
    # +/- `time_window` second window around a random subset of the spikes
    # recorded on one channel, then plot the mean traces and a CSD estimate.

    # Spike times whose channel label equals 37
    spike_time_tgt = spike_time_data[spike_channels_data==37]

    # Define parameters
    time_window = 0.1  # Seconds kept before and after each spike
    # round() before int() avoids losing a sample to float error in the product.
    samples_window = int(round(time_window * fs_lfp))
    # Sample k of a segment lies (k - samples_window) / fs_lfp seconds from the
    # spike, so the spike sample maps to exactly t = 0 and the axis follows
    # `time_window` instead of a hard-coded -0.1..0.1 range.
    time_axis = np.arange(-samples_window, samples_window) / fs_lfp

    # Select at most 200 random spike timestamps (without replacement)
    num_spikes = min(200, len(spike_time_tgt))
    selected_spikes = np.random.choice(spike_time_tgt, num_spikes, replace=False)

    n_channels = 64
    n_total_samples = lfp_data_mapmatch_1shank.shape[1]

    # Locate each selected spike once (nearest LFP timestamp) instead of
    # repeating the full-timeline argmin for every channel.
    spike_windows = []
    for spike_time in selected_spikes:
        spike_index = int(np.argmin(np.abs(LFP_timewins_aligned - spike_time)))
        start_idx = spike_index - samples_window
        end_idx = spike_index + samples_window
        # end_idx is an exclusive slice bound, so it may equal the data length.
        if start_idx >= 0 and end_idx <= n_total_samples:
            spike_windows.append((start_idx, end_idx))
    num_spikes_used = len(spike_windows)  # Spikes whose window fits in the data

    # Mean LFP per channel (channels x time); stays NaN when no spike window
    # fits inside the recording.
    lfp_avg_matrix = np.full((n_channels, 2 * samples_window), np.nan)
    if spike_windows:
        for ch in range(n_channels):
            lfp_segments = np.array(
                [lfp_data_mapmatch_1shank[ch, start:end] for start, end in spike_windows]
            )
            lfp_avg_matrix[ch, :] = np.nanmean(lfp_segments, axis=0)

    # CSD estimate for the interior channels:
    # -(LFP(i-1) - 2*LFP(i) + LFP(i+1)). The minus sign follows compute_csd()
    # defined earlier in this file. The first and last rows have no neighbour
    # on one side and stay 0.
    # The result is NOT divided by electrode spacing^2, so its magnitude is in
    # arbitrary units.
    csd_matrix = np.zeros_like(lfp_avg_matrix)
    csd_matrix[1:-1, :] = -(
        lfp_avg_matrix[:-2, :] - 2 * lfp_avg_matrix[1:-1, :] + lfp_avg_matrix[2:, :]
    )

    # --- PLOTTING ---

    fig, axes = plt.subplots(2, 1, figsize=(10, 12), gridspec_kw={'height_ratios': [2, 1]})

    # Plot LFP traces with offset
    offset = 0.8  # Vertical distance between consecutive channel traces
    for ch in range(n_channels):
        axes[0].plot(time_axis, lfp_avg_matrix[ch, :] + ch * offset)

    axes[0].axvline(0, color='red', linestyle='--')  # Spike event marker
    axes[0].set_xlabel("Time (s)")
    axes[0].set_ylabel("LFP Signal (with offset per channel)")
    # Report the number of spikes that actually entered the average.
    axes[0].set_title(f"LFP Aligned to {num_spikes_used} Randomly Selected Spike Events")

    # Plot CSD as a heatmap. All n_channels rows are drawn, so the vertical
    # extent runs from 0.5 (top) to n_channels+0.5 (bottom) to centre row i on
    # the 1-based channel number i+1.
    im = axes[1].imshow(
        csd_matrix,
        aspect='auto',
        extent=[time_axis[0], time_axis[-1], n_channels + 0.5, 0.5],
        cmap='jet',
        interpolation='none',
    )
    axes[1].axvline(0, color='red', linestyle='--')
    axes[1].set_xlabel("Time (s)")
    axes[1].set_ylabel("Channel")
    axes[1].set_title("CSD Aligned to Spikes")

    # Add colorbar for CSD
    fig.colorbar(im, ax=axes[1], label="CSD estimate (a.u., unscaled)")

    plt.tight_layout()
    plt.show()