In [1]:
%matplotlib qt
import os
import mne
import scipy.io
import numpy as np
import os.path as op
import matplotlib.pyplot as plt
from mne.preprocessing import ICA, create_eog_epochs, create_ecg_epochs

# Clean cardiac and ocular artifacts using ICA

In [None]:
# Given: file name and directory of the input recording.
fname_raw = 'run_tsss.fif'
data_path = op.expanduser('~/data/meg/')  # '~' is expanded to the user's home directory
fpath_raw = op.join(data_path, fname_raw)
# Show the resolved path so a wrong location is noticed before reading.
print(fpath_raw)

In [None]:
# Read data: open the FIF recording at `fpath_raw` with MNE.
raw = mne.io.read_raw_fif(fpath_raw)
# Bare expression on the last line: the notebook displays the measurement info.
raw.info

In [None]:
# Pick some channels that clearly show heartbeats and blinks.
# The pattern matches names of the form 'MEG' + [1 or 2] + [4 or 5] + [1, 2 or 3] + '1'.
# These picks only select which channels are displayed in the before/after
# comparison plots; the ICA itself is fitted on `raw` without them.
regexp = r'(MEG[12][45][123]1)'
artifact_picks = mne.pick_channels_regexp(raw.ch_names, regexp=regexp)

In [None]:
# Find ICA components: fit a 15-component decomposition on the recording.
# `random_state` is fixed so that repeated runs start from the same seed.
# NOTE(review): MNE's ICA guidance recommends high-pass filtering (around 1 Hz)
# before fitting — confirm whether the input file was already filtered.
ica = ICA(n_components=15, max_iter='auto', random_state=97)
ica.fit(raw)
# Bare expression on the last line: the notebook displays a summary of the fit.
ica

In [None]:
# Plot the fitted ICA components for `raw` so that the ones carrying
# cardiac/ocular activity can be identified by eye.
ica.plot_sources(raw, show_scrollbars=False)

In [None]:
# Mark ICA components carrying ECG and EOG activity for exclusion.
# NOTE(review): these indices were chosen by visual inspection and are specific to
# this recording and this fit (n_components=15, random_state=97); re-check them
# whenever the input file or the ICA settings change.
ica.exclude = [0, 1] # indices chosen from the plots above

In [None]:
# Load the raw data into memory and remove the ICA components marked for exclusion.
# A copy is taken first so the uncleaned signal stays available for comparison;
# `raw` itself is the object that is cleaned and later saved.
# NOTE(review): re-running this cell without re-reading `raw` would copy the
# already-cleaned data into `orig_raw`, making the before/after plots identical.
raw.load_data()
orig_raw = raw.copy()
ica.apply(raw)

In [None]:
# Plot the artifact-prone channels before (orig_raw) and after (raw) ICA cleaning
# to check the result visually. Both plots show the same channels
# (`artifact_picks`), all of them at once.
orig_raw.plot(order=artifact_picks, n_channels=len(artifact_picks), show_scrollbars=False)
raw.plot(order=artifact_picks, n_channels=len(artifact_picks), show_scrollbars=False)

In [None]:
# Save the ICA-cleaned recording in the same directory as the input file.
fname_raw_clean = 'clean_raw_tsss.fif'
fpath_raw_clean = op.join(data_path, fname_raw_clean)
# overwrite=True keeps the notebook re-runnable: without it the save fails as
# soon as the output file exists from a previous run.
raw.save(fpath_raw_clean, overwrite=True)

# Segment the data and calculate the SSVEF

In [None]:
# Given
fs = 1000
fname_raw = 'clean_raw_tsss.fif'
data_path = op.expanduser('~/data/meg/')
fpath_raw = op.join(data_path, fname_raw)

# Read the ICA-cleaned recording and keep the magnetometer channels only
raw = mne.io.read_raw_fif(fpath_raw)
raw.pick(['mag'])

# Show sensor locations (optional, disabled)
# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')
# raw.plot_sensors(axes=ax, kind='3d', show_names=True)
# fig.tight_layout()
# plt.show()

# Make events object: a single hand-specified event (columns: sample, previous
# value, event id) whose id 1 is labelled 'F'.
events = np.array([[558793, 0, 1]])
event_dict = {'F': 1}

# Show the event (optional, disabled)
# fig = mne.viz.plot_events(events, event_id=event_dict, sfreq=raw.info['sfreq'],
#                          first_samp=raw.first_samp)

# Cut one long epoch from 0.2 s before to 120 s after the event; it is dropped
# if any magnetometer exceeds the peak-to-peak limit below.
reject_criteria = dict(mag=4000e-15)    # 4000 fT
epochs = mne.Epochs(raw, events, event_id=event_dict, tmin=-0.2, tmax=120,
                    reject=reject_criteria, preload=True)

# Exactly one epoch is expected. Fail with a clear message if it was rejected,
# instead of indexing into an empty array further down.
n_epochs_kept = len(epochs['F'])
if n_epochs_kept != 1:
    raise RuntimeError("Expected exactly one 'F' epoch after rejection, got %d."
                       % n_epochs_kept)

# (n_channels, n_times) data of that epoch, with the first 201 samples removed.
# NOTE(review): 201 presumably skips the 0.2 s pre-event interval plus the sample
# at t = 0 for a 1000 Hz recording — confirm against raw.info['sfreq'].
trial_data = epochs['F'].get_data()[0][:, 201:]

# Channels to average
ch_vis = ['MEG1931', 'MEG2141', 'MEG1741', 'MEG1731', 'MEG1941', 'MEG1921', 'MEG2111',
          'MEG2121', 'MEG2331', 'MEG2131', 'MEG2541', 'MEG2511', 'MEG2321', 'MEG2341']
# Look the names up in the epochs object, since that is where trial_data comes from.
ch_vis_inds = [epochs.ch_names.index(ch) for ch in ch_vis]

# Average across the selected channels and store the resulting 1-D trace
x = np.mean(trial_data[ch_vis_inds, :], axis=0)
np.save(op.join(data_path, 'ssvef.npy'), x)

# Calculate Fourier and Hilbert Spectra

In [2]:
import pandas as pd
from PyEMD import EMD
from scipy.signal import hilbert

# Set the default font size used by all figures created after this cell.
plt.rcParams['font.size'] = 14

In [3]:
from scipy import signal

def __detrend(data, detrend_type='linear', bp=0):
    """
    Remove a constant or (piecewise) linear trend along the last axis of `data`.
                            [Adapted from MNE-Python]

    The input is never modified; a new array is returned.

    Parameters
    ----------
    data : array-like, 1-D or 2-D
        Signal(s) to detrend; the trend is removed along the last axis.
    detrend_type : {'linear', 'l', 'constant', 'c'}
        'constant' subtracts the mean; 'linear' subtracts a least-squares
        straight-line fit from each segment.
    bp : int or sequence of int
        Breakpoints (sample indices) separating the segments that are fitted
        independently in the linear case. The default fits a single line.

    Returns
    -------
    ndarray
        Detrended data with the same shape as `data`.

    Raises
    ------
    ValueError
        If `detrend_type` is unknown or a breakpoint exceeds the data length.
    """
    axis = -1
    if detrend_type not in ['linear', 'l', 'constant', 'c']:
        raise ValueError("Trend type must be 'linear' or 'constant'.")
    data = np.asarray(data)

    if detrend_type in ['constant', 'c']:
        return data - np.mean(data, axis, keepdims=True)

    n_samples = data.shape[axis]
    bp = np.sort(np.unique(np.r_[0, bp, n_samples]))
    if np.any(bp > n_samples):
        raise ValueError("Breakpoints must be less than length "
                         "of data along given axis.")

    # Work on a floating-point copy with time along the first axis. Without the
    # copy the fit would be subtracted in place from the caller's array, and
    # integer input would silently truncate the residuals.
    work_dtype = data.dtype if data.dtype.char in 'dfDF' else np.float64
    detrended = data.T.astype(work_dtype, copy=True)

    # Find the least-squares line of each segment and remove it
    for start, stop in zip(bp[:-1], bp[1:]):
        n_pts = stop - start
        design = np.ones((n_pts, 2))
        design[:, 0] = np.arange(1, n_pts + 1) / n_pts
        segment = slice(start, stop)
        coef, _, _, _ = np.linalg.lstsq(design, detrended[segment], rcond=None)
        detrended[segment] = detrended[segment] - np.dot(design, coef)

    return detrended.T
    
def computeFFTPSD(X, Fs, detrend=None, window=None, plot_psd=None):
    '''
    Compute the Fourier power spectrum of signal `X` sampled at frequency `Fs`.

    Parameters
    ----------
    X : array-like
        1-D signal. An array with time along the last axis is also accepted,
        in which case each row is transformed independently.
    Fs : float
        Sampling frequency in Hz.
    detrend : bool, optional
        If truthy, remove a linear trend before transforming.
    window : {'hamming', 'hanning', 'cosine'}, optional
        Taper applied before transforming; 'cosine' is a Tukey
        (tapered-cosine) window. No taper is applied when falsy.
    plot_psd : bool, optional
        If truthy, plot the spectrum up to 50 Hz.

    Returns
    -------
    psd : ndarray
        |FFT|^2 / n for frequency bins 1 .. n//2 - 1 (the DC bin is excluded).
    freq : ndarray
        Frequencies in Hz of the returned bins.

    Raises
    ------
    ValueError
        If `window` is not one of the supported names.
    '''
    X = np.asarray(X)

    # Detrend data
    if detrend:
        X = __detrend(X)

    # Window size
    n = np.size(X, -1)

    # Windowing
    if window:
        window_makers = {
            'hamming': np.hamming,
            'hanning': np.hanning,
            # The window functions live in scipy.signal.windows; the aliases
            # in the top-level scipy.signal namespace were removed.
            'cosine': signal.windows.tukey,
        }
        if window not in window_makers:
            raise ValueError("Unknown window %r; expected one of %s."
                             % (window, sorted(window_makers)))
        X = window_makers[window](n) * X

    # FFT using Numpy
    Xhat = np.fft.rfft(X)
    psd = np.real(Xhat * np.conj(Xhat)) / n  # FFT PSD
    # Only use the first half of frequencies, skipping the DC bin
    k = np.arange(1, n // 2)
    freq = Fs / n * k
    psd = psd[..., k]  # index the frequency (last) axis

    # Plot power spectrum
    if plot_psd:
        plt.figure()
        plt.plot(freq, psd.T)
        plt.xlim([0, 50])
        plt.xlabel('Frequency [Hz]')
        plt.ylabel('Power Spectral Density [V2/Hz]')
        plt.show()

    return psd, freq

def calculate_hilbert_spectrum(imfs, t, fs, n=5, plot_hilbert_spec=None, plot_inst_freq=None):
    '''
    Calculate the Hilbert amplitude spectrum of a set of intrinsic mode functions.

    For every IMF the analytic signal is computed, its instantaneous frequency
    is estimated with a central phase difference, and its instantaneous
    amplitude is accumulated into a time-frequency grid.

    Parameters
    ----------
    imfs : sequence of 1-D arrays
        Intrinsic mode functions, each with the same length as `t`.
    t : 1-D array
        Sample times in seconds.
    fs : float
        Sampling frequency in Hz.
    n : int, optional
        The highest frequency bin is approximately fs / n.
    plot_hilbert_spec : bool, optional
        If truthy, draw the summed spectrum.
    plot_inst_freq : bool, optional
        If truthy, draw each IMF with its envelope and instantaneous frequency.

    Returns
    -------
    ndarray, shape (n_bins, len(t) - 2)
        Amplitude summed over all IMFs per frequency bin and time sample. The
        first and last samples are dropped because the central difference is
        undefined there.
    '''

    ## Create Hilbert spectrum grid: bins of width 1/T centred on 1/T, 2/T, ...
    T = t[-1] - t[0]
    delta_t = 1 / fs
    fres = 1 / T
    fmin = fres
    N = int(T / (n * delta_t))
    bin_centres = np.arange(N) * fres + fmin
    bin_edges = np.arange(N + 1) * fres + (fmin - fres / 2)

    # Time axis of the spectrum; defined before the loop so that it also exists
    # when `imfs` is empty.
    t_spec = t[1:-1]
    time_inds = np.arange(len(t) - 2)
    hht_sum = np.zeros((N, len(t) - 2))

    for imf in imfs:
        Z = hilbert(imf)
        A = np.abs(Z)
        # Central-difference estimate: half the phase advance from sample k-1
        # to sample k+1, converted to Hz.
        f_spec = 0.5 * (np.angle(-Z[2:] * np.conj(Z[:-2])) + np.pi) / (2 * np.pi) * fs
        A_spec = A[1:-1]

        # Plot instantaneous frequency curves
        if plot_inst_freq:
            fig, (ax0, ax1) = plt.subplots(nrows=2)
            ax0.plot(t, imf, label='signal')
            ax0.plot(t, A, label='envelope')
            ax0.set_xlabel("time (s)")
            ax0.set_ylabel("signal (units)")
            ax0.legend()
            ax1.plot(t_spec, f_spec)
            ax1.set_xlabel("time (s)")
            ax1.set_ylabel("frequency (Hz)")
            fig.tight_layout()
            plt.show()

        # Binning of frequency values; code -1 marks a value outside every bin
        bin_inds = pd.cut(f_spec, bin_edges).codes

        # Populate the spectrum. Code 0 is the lowest valid bin and must be
        # kept; only the -1 codes are skipped.
        in_range = bin_inds >= 0
        hht_sum[bin_inds[in_range], time_inds[in_range]] += A_spec[in_range]

    # Plot Hilbert spectrum for all IMFs
    if plot_hilbert_spec:
        plt.figure()
        plt.pcolormesh(t_spec, bin_centres, hht_sum)
        plt.xlabel('Time (s)')
        plt.ylabel('Frequency (Hz)')
        plt.show()

    return hht_sum

In [11]:
# Given
data_path = op.expanduser('~/data/meg/')
fs = 1000           # sampling frequency in Hz
n_samples = 2001    # number of leading samples kept for the analysis
x = np.load(op.join(data_path, 'ssvef.npy'))
# Normalise by the mean absolute amplitude of the whole trace (before truncation)
x = x / np.mean(np.abs(x))
# t = np.arange(0, len(x)/fs, 1/fs)
# Build the time axis from an integer range: np.arange with a fractional step
# can return one element more or less than intended because of rounding.
t = np.arange(n_samples) / fs; x = x[:len(t)]

In [12]:
## Fourier Spectrum
# Detrend the signal, apply a Hamming window, then compute and plot its FFT power spectrum.
# The trailing ';' keeps the notebook from echoing the returned (psd, freq) tuple.
computeFFTPSD(x, fs, detrend=True, window="hamming", plot_psd=True);

In [13]:
## Hilbert Spectrum
# EMD: decompose the signal with empirical mode decomposition; `imfs` holds one
# component per row.
emd_decomp = EMD()
imfs = emd_decomp(x)
print(imfs.shape)
# Keep every row except the last one, which the visualisation cell labels as the
# residual; only these rows go into the Hilbert spectrum.
C = imfs[:-1]

(8, 2001)


In [14]:
# Visualize EMD: one panel per row of `imfs`, each drawn in black on top of the
# original signal (grey). The last row is labelled as the residual.
n_rows = len(imfs)
plt.figure(figsize=(12, 12))
for row, component in enumerate(imfs, start=1):
    # The grid has exactly as many rows as there are components, and the panel
    # index comes from the loop itself rather than from a leftover loop variable.
    plt.subplot(n_rows, 1, row)
    plt.plot(t, x, color='0.8')
    plt.plot(t, component, 'k')
    plt.xlim([np.min(t), np.max(t)])
    plt.ylabel('Residual' if row == n_rows else 'IMF ' + str(row))
# The current axes are still the last panel, so only it gets the time label
plt.xlabel('Time (s)')
plt.tight_layout()
plt.show()

  plt.tight_layout()


In [15]:
# Compute and plot the Hilbert amplitude spectrum of the IMFs in `C` (residual excluded).
hht = calculate_hilbert_spectrum(C, t, fs, plot_hilbert_spec=True)