In [None]:
import numpy as np
from scipy.signal import hilbert, butter, filtfilt, sosfiltfilt
import matplotlib.pyplot as plt

In [None]:
def signal_filt(data, fs, lowcut, highcut, order=2, axis=-1):
    """Zero-phase Butterworth band-pass filter.

    The filter is designed as second-order sections (SOS) and applied
    forward and backward with ``scipy.signal.sosfiltfilt``, so the output
    has no phase shift and the magnitude response is squared.

    Parameters
    ----------
    data : array_like
        Signal to filter; filtering runs along ``axis``.
    fs : float
        Sampling rate of ``data``, in the same units as ``lowcut`` and
        ``highcut``.
    lowcut, highcut : float
        Pass-band edges. Must satisfy ``0 < lowcut < highcut < fs / 2``.
    order : int, optional
        Butterworth prototype order passed to ``scipy.signal.butter``
        (default 2, the value this function always used). For a band-pass,
        ``butter`` produces a filter of order ``2 * order``.
    axis : int, optional
        Axis of ``data`` along which to filter (default -1).

    Returns
    -------
    numpy.ndarray
        The band-passed signal, same shape as ``data``.

    Raises
    ------
    ValueError
        If the band edges are not ``0 < lowcut < highcut < fs / 2``, or
        (raised by SciPy) if ``data`` is too short for the edge padding.
    """
    nyquist = 0.5 * fs
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"band edges must satisfy 0 < lowcut < highcut < fs/2; "
            f"got lowcut={lowcut}, highcut={highcut}, fs={fs}"
        )
    # Use SOS instead of (b, a): for narrow low-frequency bands (e.g. a few Hz
    # at kHz sampling rates) the transfer-function coefficients lose precision
    # and the filter can go unstable; cascaded biquads stay well conditioned.
    sos = butter(order, [lowcut / nyquist, highcut / nyquist], btype='band', output='sos')
    return sosfiltfilt(sos, data, axis=axis)

In [None]:
def cal_mi(lfp_amp, lfp_pha, PhaseFreqVector, AmpFreqVector, PhaseFreq_BandWidth, AmpFreq_BandWidth, fs):
    # Define phase bins
    data_length = len(lfp_amp)
    nbin = 18  # number of phase bins
    position = np.zeros(nbin)  # this variable will get the beginning (not the center) of each phase bin (in rads)
    winsize = 2 * np.pi / nbin
    for j in range(nbin):
        position[j] = -np.pi + j * winsize

    # Filtering and Hilbert transform
    Comodulogram = np.zeros((len(PhaseFreqVector), len(AmpFreqVector)), dtype=np.float32)
    AmpFreqTransformed = np.zeros((len(AmpFreqVector), data_length))
    PhaseFreqTransformed = np.zeros((len(PhaseFreqVector), data_length))

    # Amplitude filtering and Hilbert transform
    for ii, Af1 in enumerate(AmpFreqVector):
        Af2 = Af1 + AmpFreq_BandWidth
        AmpFreq = signal_filt(lfp_amp, fs, Af1, Af2)
        AmpFreqTransformed[ii, :] = np.abs(hilbert(AmpFreq))

    # Phase filtering and Hilbert transform
    for jj, Pf1 in enumerate(PhaseFreqVector):
        Pf2 = Pf1 + PhaseFreq_BandWidth
        PhaseFreq = signal_filt(lfp_pha, fs, Pf1, Pf2)
        PhaseFreqTransformed[jj, :] = np.angle(hilbert(PhaseFreq))