In [1]:
import numpy as np
from loader import load_oneIC
import h5py
import xarray as xr

In [2]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale
from hmmlearn import hmm
import scipy.signal as signal
import pickle

In [17]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from wavelet_transform import wavelet_transform2

In [4]:
import mne

In [5]:
import time as tm

In [6]:
# Source data: the .mat file is opened with h5py, so it must be an HDF5-based MAT-file.
# NOTE(review): absolute, user-specific path — adjust (or make configurable) on other machines.
directory = "/home/INT/malfait.n/Documents/NIC_250819"
file = "FCK_LOCKED_IC_JYOTIKA_250819.mat"
path = "/".join([directory, file])

# Read-only handle kept open for the rest of the notebook: later cells pass `mat_file`
# and `cells_refs` to load_oneIC.
mat_file = h5py.File(path, "r")
cells_refs = mat_file["FCK_LOCKED_IC_JYOTIKA"]

# Dataset dimensions: independent components per subject, number of subjects
n_IC, n_subj = 4, 23

In [7]:
# Output locations, relative to the notebook's working directory (must already exist).
_OUT_ROOT = "tde-hmm2/Multivariate"
figures_dir = f"{_OUT_ROOT}/figures/"
data_dir = f"{_OUT_ROOT}/data/"

In [8]:
# Time-delay embedding (TDE): every sample is described by the signal values in a window
# of lags around it; this embedded array is what the HMM is fitted on.
def embedx(x, lags):
    """Build the time-delay embedding of `x`.

    Column ``l`` of the embedding holds ``x`` shifted by ``lags[l]`` samples, i.e.
    ``Xe[c, t, l] == x[t + lags[l], c]``. Samples whose window would run past either
    end of the signal are dropped instead of being filled by wrap-around values.

    Parameters
    ----------
    x : ndarray, shape (n_samples, n_channels)
    lags : sequence of int
        Lags (in samples) to embed, e.g. ``np.arange(-5, 5)``.

    Returns
    -------
    Xe : ndarray, shape (n_channels, n_valid, len(lags))
        Embedded signal restricted to the valid samples.
    valid : ndarray of int8, shape (n_samples, 1)
        1 where the sample is kept in ``Xe``, 0 where it was dropped: the first
        ``max(-min(lags), 0)`` and the last ``max(max(lags), 0)`` samples.
    """
    lags = np.asarray(lags)
    n_samples, n_channels = x.shape
    Xe = np.zeros((n_channels, n_samples, len(lags)))
    for k, lag in enumerate(lags):
        # np.roll(x, -lag)[t] == x[t + lag]; the samples where this wraps around the
        # array edges are removed below through `valid`.
        Xe[:, :, k] = np.roll(x, -lag, axis=0).swapaxes(0, 1)

    # A negative lag needs past samples, a positive lag future samples: drop the samples
    # at both edges whose window is incomplete. Explicit counts avoid the `a[-0:]` trap
    # (which would select the whole array when max(lags) == 0).
    n_start = max(-int(lags.min()), 0)
    n_end = max(int(lags.max()), 0)
    valid = np.ones((n_samples, 1), dtype=np.int8)
    valid[:n_start, :] = 0
    valid[max(n_samples - n_end, 0):, :] = 0

    Xe = Xe[:, valid[:, 0] == 1, :]
    return Xe, valid

In [9]:
def statesPSD(gamma, n_states, xe, fs=256/3, threshold=2/3, nfft=1000):
    """Welch PSD of the embedded signal restricted to the samples dominated by each state.

    For each state, the rows of `xe` where the state's probability exceeds `threshold` are
    gathered, a Welch PSD is computed separately for each column (presumably one column
    per lag), and the column PSDs are averaged.

    Note: the selected samples are generally not contiguous in time; they are concatenated
    as-is before the Welch estimate.

    Parameters
    ----------
    gamma : ndarray, shape (n_samples, n_states)
        State probabilities per sample (e.g. from ``model.predict_proba``).
    n_states : int
        Number of states (columns of `gamma`) to process.
    xe : ndarray, shape (n_samples, n_columns)
        Embedded signal, row-aligned with `gamma`.
    fs : float
        Sampling frequency in Hz (default: 256 Hz downsampled by 3).
    threshold : float
        A sample is attributed to a state when its probability is strictly above this value.
    nfft : int
        FFT length passed to ``scipy.signal.welch``.

    Returns
    -------
    freqs : ndarray, shape (nfft // 2 + 1,)
    psd_all : ndarray, shape (n_states, nfft // 2 + 1)
        Row ``i`` is all-NaN when state ``i`` never exceeds `threshold`.
    """
    # Same one-sided frequency grid welch returns for real input; computed up front so it
    # is defined even when no state has any selected sample.
    freqs = np.fft.rfftfreq(nfft, d=1 / fs)
    psd_all = np.full((n_states, freqs.size), np.nan)

    for state in range(n_states):
        selected = xe[gamma[:, state] > threshold, :]
        if selected.shape[0] == 0:
            continue  # state never dominant: leave its PSD as NaN
        # One PSD per column (axis=0 is time), then average across columns
        _, psd_per_column = signal.welch(x=selected, fs=fs, nfft=nfft, axis=0)
        psd_all[state] = psd_per_column.mean(axis=1)

    return freqs, psd_all

## Parameters

In [43]:
### Input data parameters
# Subjects to process. Alternatives used previously:
#   [2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 14, 16, 17, 18, 19, 20, 21, 22, 23]  -> all except 1, 9, 13, 15
#   [2, 3, 4, 6, 7, 8, 10, 11, 14, 16, 18, 19, 22]                          -> same, restricted to subjects where all ICs exist
subj_list = [2]
IC_list = list(range(1, 5))     # alternative: [1, 2]
downsamp_rate = 3               # 256 Hz -> 256/3 Hz (use 1 to keep the native rate)
lags = np.arange(-5, 5)         # alternatives: np.arange(-29, 29), np.arange(-11, 12)
n_lags = len(lags)
apply_PCA = False               # apply a PCA (per IC) to the embedded signal before fitting the HMM?
n_components = 0                # principal components kept per IC when apply_PCA is True (e.g. 40)

### HMM parameters
all_subj_first = False          # fit a model on all subjects first, then refine it for each subject?
model_type = 'GaussianHMM'
covariance_type = 'full'        # 'diag' is only intended for use together with PCA
n_iter = 100
tol = 0.01

### Output data parameters
n_states_list = [3]             # numbers of hidden states to try; must be a list, e.g. [3, 4, 5, 6]

## Script

The group-level model (fitted on all subjects at once) can only be computed on subjects that have the same number of independent components (ICs) — here, 4 ICs. The parameters it finds can then be refined on each subject individually, whatever that subject's number of ICs (missing ICs are replaced by zero-filled columns).

In [44]:
# Optional first pass: fit one "general" HMM on the concatenated embedded data of all
# subjects in subj_list and pickle it, so the next cell can refine it subject by subject.
# Assumes every subject has every IC in IC_list: there is no fallback for a missing IC here.
if all_subj_first:
    print("Computing the imput matrix for the model")
    subj_lengths = []   # rows contributed by each subject, passed to hmmlearn as `lengths`
    tde_imput = []
    for subj in subj_list:
        # Loading all data for subject{subj}
        xeall = []
        if apply_PCA:
            pcaall = []
        for IC in IC_list:
            data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
            # Concatenate every trial the loader reports (not a fixed count),
            # downsampled to 256/3 = 85.33 Hz
            big_timecourse = np.concatenate([data['raw_timecourse_256Hz'][i][::downsamp_rate]
                                             for i in range(n_trials)])
            big_timecourse = scale(big_timecourse)
            x = big_timecourse.reshape(-1, 1)
            xe, valid = embedx(x, lags)
            xeall.append(xe[0, :, :])
            if apply_PCA:
                pca = PCA(n_components=n_components)
                pcaall.append(pca.fit_transform(xe[0, :, :]))
            print(f"IC{IC} loaded")
        if apply_PCA:
            y = np.concatenate(pcaall, axis=1)
        else:
            y = np.concatenate(xeall, axis=1)
        tde_imput.append(y)
        subj_lengths.append(y.shape[0])
    tde_imput = np.concatenate(tde_imput)

    for n_states in n_states_list:
        start_time = tm.time()
        print(f"Computing and saving the general model for {n_states} states")
        model = hmm.GaussianHMM(n_components=n_states, n_iter=n_iter,
                                covariance_type=covariance_type, tol=tol)
        # `lengths` tells hmmlearn where one subject's sequence ends and the next begins
        model.fit(tde_imput, subj_lengths)

        with open(data_dir + f"ALLSUBJECTS_st{n_states}_lg{n_lags}co{n_components}"
                  + "Multivariate" + model_type + ".pkl", "wb") as file:
            pickle.dump(model, file)
        print("%s seconds" % (tm.time() - start_time))

In [45]:
# Fit (or, with all_subj_first, refine the general model into) one multivariate TDE-HMM per
# subject and number of states; save the model (pickle) and the state time courses/PSDs (netCDF).
for n_states in n_states_list:
    general_model_path = (data_dir + f"ALLSUBJECTS_st{n_states}_lg{n_lags}co{n_components}"
                          + "Multivariate" + model_type + ".pkl")

    for subj in subj_list:
        start_time = tm.time()
        # Loading all data for subject{subj}
        print("Computing the imput matrix for the model")
        # One slot per IC; a slot left as None is an IC that could not be loaded/processed.
        xeall = [None] * len(IC_list)
        pcaall = [None] * len(IC_list)
        xeshape = None
        for i_ic, IC in enumerate(IC_list):
            try:
                data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
                # Concatenate all trials, downsampled to 256/3 = 85.33 Hz
                big_timecourse = np.concatenate([data['raw_timecourse_256Hz'][i][::downsamp_rate]
                                                 for i in range(n_trials)])
                big_timecourse = scale(big_timecourse)
                x = big_timecourse.reshape(-1, 1)
                xe, valid = embedx(x, lags)
                xe_ic = xe[0, :, :]
                pca_ic = PCA(n_components=n_components).fit_transform(xe_ic) if apply_PCA else None
                # Store only once everything for this IC succeeded
                xeall[i_ic], pcaall[i_ic] = xe_ic, pca_ic
                xeshape = xe_ic.shape
                print(f"IC{IC} loaded")
            except Exception as err:
                # Best effort: some subjects lack some ICs; they are replaced by zeros below.
                print(f"IC{IC} not available for su{subj}: {err!r}")

        if xeshape is None:
            # Nothing loaded: `data`, `n_trials` and `valid` would be stale from another subject.
            print(f"su{subj}: no IC could be loaded, skipping")
            continue

        # Zero-filled placeholders keep the feature layout identical across subjects.
        # NOTE(review): constant columns give a degenerate covariance block with
        # covariance_type='full' (hmmlearn only adds min_covar) — confirm this is acceptable.
        for i_ic in range(len(IC_list)):
            if xeall[i_ic] is None:
                xeall[i_ic] = np.zeros(xeshape)
                if apply_PCA:
                    pcaall[i_ic] = np.zeros((xeshape[0], n_components))
        y = np.concatenate(pcaall if apply_PCA else xeall, axis=1)

        print("Computing and saving the model")
        if all_subj_first:
            # Reload the general model for every subject so that each subject is refined from
            # the group-level fit rather than from the previous subject's refined model.
            # init_params='st': only start probabilities and transition matrix are re-initialised.
            with open(general_model_path, "rb") as file:
                model = pickle.load(file)
            model.init_params = 'st'
        else:
            model = hmm.GaussianHMM(n_components=n_states, n_iter=n_iter,
                                    covariance_type=covariance_type, tol=tol)
        model.fit(y)
        gamma = model.predict_proba(y)   # (n_valid_samples, n_states)
        with open(data_dir + f"su{subj}_st{n_states}_lg{n_lags}co{n_components}"
                  + "Multivariate" + model_type + ".pkl", "wb") as file:
            pickle.dump(model, file)

        print("Computing the PSD of each state")
        psds = []
        for i_ic in range(len(IC_list)):
            freqs, psd = statesPSD(gamma, n_states, xeall[i_ic])
            psds.append(psd[np.newaxis])

        # Save the states timecourses and PSDs thanks to xarray and netCDF
        print("Saving the states timecourses and PSDs")
        # Put gamma back on the full concatenated time grid using embedx's mask;
        # the edge samples dropped by the embedding get probability 0.
        tcourse = np.zeros((valid.shape[0], n_states))
        tcourse[valid[:, 0] == 1] = gamma
        time_axis = data["time_axis"][::downsamp_rate]
        t_len = time_axis.shape[0]
        # Split the concatenated time course back into trials
        tcourse_trials = tcourse[:n_trials * t_len].reshape(n_trials, t_len, n_states)
        ds = xr.Dataset(
            {
                "states_timecourse": (("trials", "time", "states"), tcourse_trials),
                "states_psd": (("IC", "states", "freq"), np.concatenate(psds)),
            },
            {
                "IC": IC_list,
                "time": time_axis,
                "states": np.arange(1, n_states + 1),
                "freq": freqs,
            }
        )
        # Fractional occupancy: mean state probability across trials at each time point
        ds = ds.assign(frac_occ=(ds["states_timecourse"].sum("trials") / ds.sizes["trials"]))
        ds.to_netcdf(data_dir + f"su{subj}_st{n_states}_lg{n_lags}co{n_components}_data.nc", mode="w")
        print("%s seconds" % (tm.time() - start_time))
        print(f"su{subj}, {n_states} states done")

Computing the imput matrix for the model
Loading the raw timecourse
IC1 loaded
Loading the raw timecourse
IC2 loaded
Computing and saving the model
Computing the PSD of each state
Saving the states timecourses and PSDs
222.4885814189911 seconds
su3, 3 states done


In [46]:
# Summary figure per (n_states, subject):
# - top rows: fractional occupancy of each state over time;
# - one row per IC: time-frequency map (left), its colorbar (middle), per-state PSD (right).
info = mne.create_info(ch_names=['signal'], sfreq=256, ch_types=['eeg'])

for n_states in n_states_list:
    for subj in subj_list:
        # Load into memory and close the netCDF file at once, so that re-running the
        # fitting cell can overwrite it.
        with xr.open_dataset(data_dir + f"su{subj}_st{n_states}_lg{n_lags}co{n_components}_data.nc") as ds_file:
            ds = ds_file.load()
        ic_values = ds["IC"].values
        n_ic = len(ic_values)

        widths = [14, 1, 5]
        heights = [1] * n_states + [4] * n_ic
        gs_kw = dict(width_ratios=widths, height_ratios=heights, wspace=0.0, hspace=0.0)
        fig, f_axes = plt.subplots(figsize=(sum(widths), sum(heights)), ncols=3, nrows=(n_states + n_ic),
                                   constrained_layout=True, gridspec_kw=gs_kw, squeeze=False)

        # Fractional occupancy, one row per state
        for state in ds["states"].values:
            ax = f_axes[state - 1, 0]
            ax.plot(ds["time"].values, ds["frac_occ"].values[:, state - 1], color=f"C{state - 1}")
            ax.set_xlim([-4, 3])
            ax.set_ylabel(f"state {state}")

        # One row per IC, indexed by position so any IC_list (not only 1..4) fits the grid
        for row_ic, IC in enumerate(ic_values):
            row = n_states + row_ic
            try:
                data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
                tfr = wavelet_transform2(data, info, trial=np.arange(1, n_trials + 1))
                ax = f_axes[row, 0]
                mappable = ax.imshow(tfr[0], aspect='auto', origin='lower', extent=[-4, 3, 2, 50],
                                     norm=colors.PowerNorm(gamma=0.5), cmap='RdYlBu_r')
                ax.set_ylabel(f'Frequencies IC{IC} (Hz)')
                fig.colorbar(mappable, cax=f_axes[row, 1])
            except Exception as err:
                # Best effort (e.g. IC missing for this subject): leave the panel empty
                print(f"su{subj} IC{IC}: time-frequency map skipped ({err!r})")
            ax = f_axes[row, 2]
            ax.plot(ds["freq"].values, ds["states_psd"].values[row_ic].T)
            ax.set_ylabel(f'PSD IC{IC}')

        # The PSD lines (one per state, default colour cycle) provide the legend handles
        lines = f_axes[n_states, 2].get_lines()
        labels = [f"State {i}" for i in range(1, n_states + 1)]
        f_axes[0, 2].legend(lines, labels, loc='upper left')
        f_axes[-1, 0].set_xlabel('Time (s)')
        f_axes[-1, 2].set_xlabel('Frequency (Hz)')

        fig.savefig(figures_dir + f'grid-{n_states}states-subj{subj}.png', dpi=300)
        plt.close(fig)

Loading the raw timecourse
Loading the raw timecourse


In [47]:
# For each state, plot its probability time course on the first `n_trials_to_plot` trials.
n_trials_to_plot = 5

for n_states in n_states_list:
    for subj in subj_list:
        # Load into memory and close the file so that the fitting cell can overwrite it later
        with xr.open_dataset(data_dir + f"su{subj}_st{n_states}_lg{n_lags}co{n_components}_data.nc") as ds_file:
            ds = ds_file.load()
        # Never request more trials than the subject has
        n_plot = min(n_trials_to_plot, ds.sizes["trials"])
        fig, f_axes = plt.subplots(figsize=(20, n_states * n_plot), ncols=1, nrows=(n_states * n_plot),
                                   constrained_layout=True, squeeze=False)
        time_values = ds["time"].values
        tcourse = ds["states_timecourse"].values   # dims: (trials, time, states)
        for state in range(1, n_states + 1):
            color = f"C{state - 1}"
            for trial in range(1, n_plot + 1):
                ax = f_axes[n_plot * (state - 1) + trial - 1, 0]
                # `trial` is 1-based for the label; the array is 0-based
                ax.fill_between(time_values, tcourse[trial - 1, :, state - 1], color=color)
                ax.set_xlim([-4, 3])
                ax.set_ylim([0, 1])
                ax.set_ylabel(f"Prob. st{state} tr{trial}")
        ax.set_xlabel("Time (s)")
        fig.savefig(figures_dir + f'tcourses-{n_states}states-subj{subj}.png', dpi=300)
        plt.close(fig)