In [18]:
import numpy as np
from scipy.signal import hilbert, butter, filtfilt
import matplotlib.pyplot as plt

In [11]:
with open('test_pd_lfp_data.txt', 'r', encoding='utf-8') as file:
    lines = file.readlines()
split_data = [line.strip().split('\t') for line in lines]
data = list(zip(*split_data))

In [22]:
data_ctx = data[:2]
data_str = data[2:4]

In [15]:
def signal_filt(data, fs, lowcut, highcut):
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(2, [low, high], btype='band')
    filtered_data = filtfilt(b, a, data)
    return filtered_data


def mod_index_v2(phase, amp, position):
    """
    Compute the modulation index and mean amplitude distribution.

    Parameters:
    phase (array-like): Phase time series
    amp (array-like): Amplitude time series
    position (array-like): Phase bins (left boundary)

    Returns:
    MI (float): Modulation index
    mean_amp (numpy.ndarray): Amplitude distribution over phase bins (non-normalized)
    """
    nbin = len(position)
    winsize = 2 * np.pi / nbin

    # Compute the mean amplitude in each phase bin
    mean_amp = np.zeros(nbin)
    for j in range(nbin):
        I = np.where((phase < position[j] + winsize) & (phase >= position[j]))[0]
        mean_amp[j] = np.mean(amp[I])

    # Calculate the modulation index using normalized entropy
    p = mean_amp / np.sum(mean_amp)
    entropy = -np.sum(p * np.log(p))
    MI = (np.log(nbin) - entropy) / np.log(nbin)

    return MI, mean_amp


In [20]:
def cal_mi(lfp_amp, lfp_pha, PhaseFreqVector, AmpFreqVector, PhaseFreq_BandWidth, AmpFreq_BandWidth, fs):
    # Define phase bins
    data_length = len(lfp_amp)
    nbin = 18  # number of phase bins
    position = np.zeros(nbin)  # this variable will get the beginning (not the center) of each phase bin (in rads)
    winsize = 2 * np.pi / nbin
    for j in range(nbin):
        position[j] = -np.pi + j * winsize

    # Filtering and Hilbert transform
    Comodulogram = np.zeros((len(PhaseFreqVector), len(AmpFreqVector)), dtype=np.float32)
    AmpFreqTransformed = np.zeros((len(AmpFreqVector), data_length))
    PhaseFreqTransformed = np.zeros((len(PhaseFreqVector), data_length))

    # Amplitude filtering and Hilbert transform
    for ii, Af1 in enumerate(AmpFreqVector):
        Af2 = Af1 + AmpFreq_BandWidth
        AmpFreq = signal_filt(lfp_amp, fs, Af1, Af2)
        AmpFreqTransformed[ii, :] = np.abs(hilbert(AmpFreq))

    # Phase filtering and Hilbert transform
    for jj, Pf1 in enumerate(PhaseFreqVector):
        Pf2 = Pf1 + PhaseFreq_BandWidth
        PhaseFreq = signal_filt(lfp_pha, fs, Pf1, Pf2)
        PhaseFreqTransformed[jj, :] = np.angle(hilbert(PhaseFreq))
     
    # Initialize Comodulogram matrix
    Comodulogram = np.zeros((len(PhaseFreqVector), len(AmpFreqVector)))

    # Comodulation loop
    for ii in range(len(PhaseFreqVector)):
        for jj in range(len(AmpFreqVector)):
            MI, MeanAmp = mod_index_v2(PhaseFreqTransformed[ii, :], AmpFreqTransformed[jj, :], position)
            Comodulogram[ii, jj] = MI

    # Handle NaN values by replacing them with the minimum value in the Comodulogram
    #fill_value = np.nanmin(Comodulogram)
    #Comodulogram = np.nan_to_num(Comodulogram, nan=fill_value)

    return Comodulogram

        

In [21]:
PhaseFreqVector=np.arange(22,2,50)
AmpFreqVector= np.arange(50,5,200)

PhaseFreq_BandWidth=2
AmpFreq_BandWidth=10

fs = 10000

cal_mi(data_ctx, data_str, PhaseFreqVector, AmpFreqVector, PhaseFreq_BandWidth, AmpFreq_BandWidth, fs)

array([], shape=(0, 0), dtype=float64)