### LFP analysis
 - Created December 3, 2024 by Thomas Elston

In [1]:
# imports
from pathlib import Path
import h5_utilities_module as h5u
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy import signal, optimize
from scipy.optimize import curve_fit
from itertools import product
import matplotlib.pyplot as plt
from scipy.stats import zscore
import h5py
from scipy.signal import detrend, butter, filtfilt, correlate, hilbert


In [2]:
# functions
def get_path_from_dir(base_folder, file_name):
    """
    Search for a file within the specified directory and its subdirectories
    by matching the given file name.

    Args:
    - base_folder (str): The directory path to start the search from.
    - file_name (str): The specific file name or part of the file name to be matched.

    Returns:
    - str or None: The path to the first file found with the given name,
      or None if no file is found.
    """
    base_folder = Path(base_folder)
    
    # Iterate through all files and directories in the base_folder
    for file in base_folder.glob('**/*'):
        if file.is_file() and file_name in file.name:
            return file.resolve()  # Return the path of the first file found with the given name
    
    # Print an error message if no file with the given name is found
    print(f"No '{file_name}' file found in the directory.")
    return None


def coherogram(lfp1, lfp2, n_perseg, n_fft, n_overlap, fq_min, fq_max, fs):
    """Computes a coherogram for an LFP stream of arbitrary length using a Hamming window.

    Inputs: 
        lfp1 (ndarray):  Datastream for one LFP channel
        lfp2 (ndarray):  Datastream for a different LFP channel
        n_perseg(int):   Window over which each piece of the coherogram is evaluated
        n_fft (int):     Number of samples to compute the FFT over (typically n_perseg)
        fq_min (int):    Minimum frequency coherence is computed at
        fq_max (int):    Maximum frequency coherence is evaluated at
        fs (int):        Sampling frequency (in Hz) of the LFP data (typically 1000 Hz)
            
    Returns:
        ts (ndarray):    Timestamps of each element of the coherogram
        coh (ndarray):   A n_frequencies x n_times array where each element describes the coherence between
                         lfp1 and lfp2 at each timestep and frequency
        z_coh (ndarray): Z-scored coherence (zscore applied to each frequency)
        freqs (ndarray): Array detailing which frquencies define the rows of coh     
    """
    
    if n_overlap < 1:
        n_overlap = int(n_overlap * n_perseg) # convert to number of segments if a fraction provided (noverlap <= 1)

    # find out the times to compute coherence at   
    ts = np.arange(0, lfp1.shape[0] + (n_perseg - n_overlap), n_perseg - n_overlap)

    n_freqs = ((fq_max - fq_min) * n_perseg // 1000) + 1

    # define the starts and stops of each window to assess coherence over
    # first columns is starts, second is stops
    win_details = np.zeros(shape=(len(ts), 2))

    # loop over each timestep and find appropriate windows. Truncate windows at the 
    # very beginning and end of of the data stream
    for i_t in range(len(ts)):

        # make centered windows
        win_details[i_t, 0] = ts[i_t] - np.floor((n_perseg/2)) # window starts
        win_details[i_t, 1] = ts[i_t] + np.floor((n_perseg/2)) # window ends

        # is the left border of the window before the start of the session?
        if win_details[i_t, 0] < 0: 
            # then set the window start to zero
            win_details[i_t, 0] = 0

        # is the right border of the window longer than the end of the session?
        if win_details[i_t, 1] > lfp1.shape[0]:
            # then set the window end to the end of the session
            win_details[i_t, 1] = lfp1.shape[0]

    # be sure these are ints
    win_details = win_details.astype(int)
    
    # run the coherence on the first window to see how many frequencies are obtained
    freqs, test_coh = signal.coherence(lfp1[0:n_perseg], lfp1[0:n_perseg], 1000, nfft=1000)
    
    n_freqs = len(freqs)

    # intialize an array to accumulate the coherence data into
    coh = np.zeros(shape=(n_freqs, len(ts)))
    coh[:] = np.nan

    # loop over each timestep
    for t in range(len(ts)):
        _, coh[:, t] = signal.coherence(lfp1[win_details[t, 0]: win_details[t, 1]],
                                           lfp2[win_details[t, 0]: win_details[t, 1]],
                                           fs=fs, nfft = n_fft, window='hamming')

    # select the frequency range defined by fq_min and fq_max
    freqs2keep = (freqs >= fq_min) & (freqs <= fq_max)

    # zscore the coherence
    z_coh = np.zeros_like(coh)
    for f in range(len(freqs)):
        f_mean = np.nanmean(coh[f,:])
        f_std = np.nanstd(coh[f,:])
        z_coh[f,:] = (coh[f,:] - f_mean) / f_std



    return ts, coh[freqs2keep,: ], z_coh[freqs2keep, :], freqs[freqs2keep]


def detrend_psd(freqs, psd):
    """
    Detrend the PSD by fitting and removing the aperiodic component using a second-order polynomial.

    Parameters:
    - freqs (array): Frequencies corresponding to the PSD.
    - psd (array): PSD values from scipy.signal.welch().

    Returns:
    - freqs (array): Frequencies corresponding to the PSD.
    - psd_detrended (array): Detrended PSD values.
    """

    # Ensure positive frequencies and PSD values
    freqs = np.array(freqs)
    psd = np.array(psd)
    valid_indices = (freqs > 0) & (psd > 0)
    freqs = freqs[valid_indices]
    psd = psd[valid_indices]

    # Convert to log-log space
    log_freqs = np.log10(freqs)
    log_psd = np.log10(psd)

    # Fit a second-order polynomial to log-log data
    def poly2(x, a, b, c):
        return a * x**2 + b * x + c

    popt, _ = curve_fit(poly2, log_freqs, log_psd)

    # Compute the aperiodic fit and detrend
    log_fit = poly2(log_freqs, *popt)
    log_psd_detrended = log_psd - log_fit

    # Convert back to linear space
    psd_detrended = 10 ** log_psd_detrended

    return freqs, psd_detrended

class AmplitudeCrossCorr:
    def __init__(self, low_freq, high_freq, fs, max_lags):
        self.low_freq = low_freq
        self.high_freq = high_freq
        self.fs = fs
        self.max_lags = max_lags
        
    def _preprocess(self, signal):
        """Internal preprocessing of each signal."""
        # Detrend
        signal = detrend(signal)
        
        # Z-score
        signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-10)
        
        # Bandpass filter
        nyq = self.fs / 2
        b, a = butter(3, [self.low_freq/nyq, self.high_freq/nyq], btype='band')
        signal = filtfilt(b, a, signal)
        
        return signal
        
    def compute(self, lfp1, lfp2):
        """
        Compute amplitude cross-correlation between two LFP signals.
        
        Parameters
        ----------
        lfp1, lfp2 : np.ndarray
            LFP time series
            
        Returns
        -------
        lags : np.ndarray
            Lag times in milliseconds
        xcorr : np.ndarray
            Cross-correlation values
        """
        # Preprocess both signals
        lfp1_proc = self._preprocess(lfp1)
        lfp2_proc = self._preprocess(lfp2)
        
        # Get amplitude envelopes
        amp1 = np.abs(hilbert(lfp1_proc))
        amp2 = np.abs(hilbert(lfp2_proc))
        
        # Compute cross-correlation
        xcorr = correlate(amp1, amp2, mode='same')
        
        # Normalize
        n = len(xcorr)
        auto1 = correlate(amp1, amp1, mode='same')[int(n/2)]
        auto2 = correlate(amp2, amp2, mode='same')[int(n/2)]
        xcorr = xcorr / np.sqrt(auto1 * auto2)
        
        # Create lag vector in milliseconds
        lags = np.arange(-self.max_lags, self.max_lags + 1) * 1000/self.fs
        
        # Extract the requested lags
        middle = len(xcorr) // 2
        max_lag_samples = int(self.max_lags)
        start = middle - max_lag_samples
        end = middle + max_lag_samples + 1
        xcorr = xcorr[start:end]
        
        return lags, xcorr


In [3]:
# where are the data?
data_dir = 'C:/Users/thome/Documents/PYTHON/OFC-CdN 3 state self control/files_for_decoder/'

# where to save the data?
save_dir = 'C:/Users/thome/Documents/PYTHON/OFC-CdN 3 state self control/lfp_PSDs_Coh/'

# get their relevant paths
data_files = h5u.find_h5_files(data_dir)

In [4]:
for this_file in data_files:

    f_name = Path(this_file).stem
    print(f_name)

    if 'D' in Path(this_file).stem:
        s = 0
    else:
        s = 1

    # load the data
    bhv = pd.read_hdf(this_file, key='bhv')
    ofc_lfp = np.clip(h5u.pull_from_h5(this_file, 'OFC_lfp'), -1e6, 1e6)
    cdn_lfp = np.clip(h5u.pull_from_h5(this_file, 'CdN_lfp'), -1e6, 1e6)
    lfp_ts = h5u.pull_from_h5(this_file, 'lfp_ts')

    # zscore the lfp
    z_ofc_lfp = np.zeros((len(bhv), len(lfp_ts), ofc_lfp.shape[2]), dtype='float16')
    z_cdn_lfp = np.zeros((len(bhv), len(lfp_ts), cdn_lfp.shape[2]), dtype='float16')

    print('\nZ-scoring lfps...')
    for ch in tqdm(range(ofc_lfp.shape[2])):

        ofc_ch_mean = np.nanmean(ofc_lfp[:,:, ch])
        ofc_ch_std = np.nanstd(ofc_lfp[:,:, ch])

        z_ofc_lfp[:,:, ch] = (ofc_lfp[:,:, ch] - ofc_ch_mean) / ofc_ch_std

    for ch in tqdm(range(cdn_lfp.shape[2])):
        cdn_ch_mean = np.nanmean(cdn_lfp[:,:, ch])
        cdn_ch_std = np.nanstd(cdn_lfp[:,:, ch])

        z_cdn_lfp[:,:, ch] = (cdn_lfp[:,:, ch] - cdn_ch_mean) / cdn_ch_std

    # delete the original LFP
    del ofc_lfp
    del cdn_lfp

    # compute PSD across the probe
    window = 'hamming'
    fs = 1000
    n_fft = fs
    n_overlap = np.round(.5*n_fft)

    ofc_pwr = np.zeros((len(bhv), 501, z_ofc_lfp.shape[2]-1))
    cdn_pwr = np.zeros((len(bhv), 501, z_cdn_lfp.shape[2]-1))

    print('\nComputing each channel PSD...')
    for ch in tqdm(range(ofc_pwr.shape[2])):
        for t in range(len(bhv)):
            fq, ofc_pwr[t,:, ch] = signal.welch(z_ofc_lfp[t,:,ch], fs=fs, noverlap=n_overlap, nfft = n_fft, nperseg=n_fft)
            fq, cdn_pwr[t,:, ch] = signal.welch(z_cdn_lfp[t,:,ch], fs=fs, noverlap=n_overlap, nfft = n_fft, nperseg=n_fft)

    # calculate the channel mean PSDs
    ofc_ch_means = np.nanmean(ofc_pwr, axis=0).T
    cdn_ch_means = np.nanmean(cdn_pwr, axis=0).T

    # subtract the aperiodic component
    ofc_fooof_psd = np.zeros_like(ofc_ch_means)
    cdn_fooof_psd = np.zeros_like(cdn_ch_means)

    print('\nRemoving aperiodic (1/f) component of PSD...')
    for ch in tqdm(range(ofc_ch_means.shape[0])):
        freqs, ofc_fooof_psd[ch, :] = detrend_psd(fq+1, ofc_ch_means[ch, :])
        _, cdn_fooof_psd[ch, :] = detrend_psd(fq+1, cdn_ch_means[ch, :])

    # account for the frequency shift imposed earlier (to avoid a divide-by-zero error)
    ofc_fooof_psd = np.roll(ofc_fooof_psd, -1, axis=1)
    cdn_fooof_psd = np.roll(cdn_fooof_psd, -1, axis=1)

    ofc_theta_pwr = np.mean(ofc_fooof_psd[:, 4:9], axis=1)
    ofc_alpha_pwr = np.mean(ofc_fooof_psd[:, 11:20], axis=1)
    cdn_theta_pwr = np.mean(cdn_fooof_psd[:, 4:9], axis=1)
    cdn_alpha_pwr = np.mean(cdn_fooof_psd[:, 11:20], axis=1)

    ##
    # get the mean activity over the probe for the coherence analysis
    ofc_data = np.mean(z_ofc_lfp[:,:,0:384], axis=2)
    cdn_data = np.mean(z_cdn_lfp[:,:,0:384], axis=2)

    # run a single trial to get shape of data
    ts, test_coh, _, freqs2 = coherogram(ofc_data[0,:], cdn_data[0,:], 1000, 1000, .95, 1, 50, 1000)


    # initialize arrays for the trialwise coherence
    coh_trials = np.zeros(shape=(len(freqs2), len(ts), len(bhv)), dtype='float16')
    coh_trials[:] = np.nan
    z_coh_trials = np.zeros(shape=(len(freqs2), len(ts), len(bhv)), dtype='float16')
    z_coh_trials[:] = np.nan

    print('\nComputing trial-by-trial coherence')
    for t in tqdm(range(len(bhv))):

        if not np.isnan(np.sum(ofc_data[t,:])):
            ts, coh_trials[:,:,t], z_coh_trials[:,:,t], freqs = coherogram(ofc_data[t,:], cdn_data[t,:],
                                                                            1000, 1000, .95, 1, 50, 1000)
            
    ts = ts - 1500    
    t_start = np.argmin(np.abs(ts - -1000))
    t_end = np.argmin(np.abs(ts - 1000))

    # chop off the borders because that will blow the zscore
    coh_trials = coh_trials[:,t_start:t_end,:]   
    z_coh_trials = z_coh_trials[:,t_start:t_end,:]    
    ts = ts[t_start:t_end]

    # Run an amplitude cross-correlation analysis
    print('\nComputing trial-by-trial amplitude cross correlation...')
    max_lag=100
    xcorr = AmplitudeCrossCorr(low_freq=10, high_freq=20, fs=1000, max_lags=max_lag)

    # Initialize lists to store results
    xcorr_results = np.zeros((len(bhv), (2*max_lag) + 1))
    xcorr_results[:] = np.nan
    max_lags = np.zeros((len(bhv), ))
    max_lags[:] = np.nan

    # Your trial loop
    for t in range(len(bhv)):
        if np.any(np.isnan(ofc_data[t])) or np.any(np.isnan(cdn_data[t])):
            continue
            
        # Compute for this trial
        lags, corr = xcorr.compute(ofc_data[t,1500:2000 + (500*s)], cdn_data[t,1500:2000 + (500*s)])
        xcorr_results[t, :] = corr
        peak_idx = np.argmax(np.abs(corr))
        max_lags[t] = lags[peak_idx]

    # set to nans trials with no resolved lag
    max_lags[max_lags == 0] = np.nan
    max_lags[max_lags == -100] = np.nan

    # save the data!
    save_name = save_dir + f_name + '_psd_coh.h5'
    
    # Open an HDF5 file in write mode ('w' or 'w-' to create or truncate the file)
    with h5py.File(save_name, 'w') as file:
        # Create datasets within the HDF5 file and write data
        file.create_dataset('ofc_psd', data=ofc_fooof_psd)  
        file.create_dataset('cdn_psd', data=cdn_fooof_psd)  
        file.create_dataset('coh_trials', data=coh_trials)  
        file.create_dataset('z_coh_trials', data=z_coh_trials)  
        file.create_dataset('coh_ts', data = ts)

        file.create_dataset('amp_x_corr_lags', data = max_lags)
        
        # Save behavior to the HDF5 file
        bhv.to_hdf(save_name, key='bhv', mode='a')

    print('File processed. \n')
    
print('All files complete :]')
    


D20231219_Rec05

Z-scoring lfps...


100%|██████████| 385/385 [00:33<00:00, 11.46it/s]
100%|██████████| 385/385 [00:33<00:00, 11.41it/s]



Computing each channel PSD...


100%|██████████| 384/384 [01:40<00:00,  3.84it/s]



Removing aperiodic (1/f) component of PSD...


100%|██████████| 384/384 [00:00<00:00, 2417.16it/s]



Computing trial-by-trial coherence


100%|██████████| 843/843 [00:32<00:00, 26.28it/s]



Computing trial-by-trial amplitude cross correlation...
File processed. 

D20231221_Rec06

Z-scoring lfps...


100%|██████████| 385/385 [00:49<00:00,  7.75it/s]
100%|██████████| 385/385 [00:49<00:00,  7.81it/s]



Computing each channel PSD...


100%|██████████| 384/384 [02:25<00:00,  2.64it/s]



Removing aperiodic (1/f) component of PSD...


100%|██████████| 384/384 [00:00<00:00, 3080.52it/s]



Computing trial-by-trial coherence


100%|██████████| 1246/1246 [00:47<00:00, 26.28it/s]



Computing trial-by-trial amplitude cross correlation...
File processed. 

D20231224_Rec07

Z-scoring lfps...


100%|██████████| 385/385 [00:49<00:00,  7.78it/s]
100%|██████████| 385/385 [00:49<00:00,  7.73it/s]



Computing each channel PSD...


100%|██████████| 384/384 [02:26<00:00,  2.63it/s]



Removing aperiodic (1/f) component of PSD...


100%|██████████| 384/384 [00:00<00:00, 2180.46it/s]



Computing trial-by-trial coherence


100%|██████████| 1242/1242 [00:48<00:00, 25.52it/s]



Computing trial-by-trial amplitude cross correlation...
File processed. 

D20231227_Rec08

Z-scoring lfps...


100%|██████████| 385/385 [00:33<00:00, 11.41it/s]
100%|██████████| 385/385 [00:34<00:00, 11.09it/s]



Computing each channel PSD...


100%|██████████| 384/384 [01:41<00:00,  3.78it/s]



Removing aperiodic (1/f) component of PSD...


100%|██████████| 384/384 [00:00<00:00, 2947.75it/s]



Computing trial-by-trial coherence


100%|██████████| 864/864 [00:25<00:00, 33.90it/s] 



Computing trial-by-trial amplitude cross correlation...
File processed. 

K20240707_Rec06

Z-scoring lfps...


100%|██████████| 385/385 [00:48<00:00,  8.00it/s]
100%|██████████| 385/385 [00:48<00:00,  8.01it/s]



Computing each channel PSD...


100%|██████████| 384/384 [02:23<00:00,  2.67it/s]



Removing aperiodic (1/f) component of PSD...


100%|██████████| 384/384 [00:00<00:00, 3043.31it/s]



Computing trial-by-trial coherence


100%|██████████| 1212/1212 [00:42<00:00, 28.28it/s]



Computing trial-by-trial amplitude cross correlation...
File processed. 

K20240710_Rec07

Z-scoring lfps...


100%|██████████| 385/385 [00:47<00:00,  8.15it/s]
100%|██████████| 385/385 [00:46<00:00,  8.20it/s]



Computing each channel PSD...


100%|██████████| 384/384 [02:21<00:00,  2.71it/s]



Removing aperiodic (1/f) component of PSD...


100%|██████████| 384/384 [00:00<00:00, 3014.20it/s]



Computing trial-by-trial coherence


100%|██████████| 1197/1197 [00:45<00:00, 26.05it/s]



Computing trial-by-trial amplitude cross correlation...
File processed. 

K20240712_Rec08

Z-scoring lfps...


100%|██████████| 385/385 [00:40<00:00,  9.43it/s]
100%|██████████| 385/385 [00:40<00:00,  9.52it/s]



Computing each channel PSD...


100%|██████████| 384/384 [02:02<00:00,  3.15it/s]



Removing aperiodic (1/f) component of PSD...


100%|██████████| 384/384 [00:00<00:00, 2981.55it/s]



Computing trial-by-trial coherence


100%|██████████| 1024/1024 [00:37<00:00, 27.49it/s]



Computing trial-by-trial amplitude cross correlation...
File processed. 

K20240715_Rec09

Z-scoring lfps...


100%|██████████| 385/385 [00:47<00:00,  8.08it/s]
100%|██████████| 385/385 [00:49<00:00,  7.83it/s]



Computing each channel PSD...


100%|██████████| 384/384 [02:16<00:00,  2.81it/s]



Removing aperiodic (1/f) component of PSD...


100%|██████████| 384/384 [00:00<00:00, 3199.38it/s]



Computing trial-by-trial coherence


100%|██████████| 1206/1206 [00:42<00:00, 28.56it/s]



Computing trial-by-trial amplitude cross correlation...
File processed. 

All files complete :]
