In [1]:
import os

import h5py
import numpy as np
import xarray as xr

from loader import load_oneIC

In [2]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale
from hmmlearn import hmm
import scipy.signal as signal
import pickle

In [3]:
import matplotlib.pyplot as plt
from wavelet_transform import wavelet_transform2

In [4]:
import mne

In [5]:
import time as tm

In [6]:
# Location of the HDF5-based .mat file holding the IC time courses (it is opened with h5py).
# The NIC_DATA_DIR environment variable overrides the directory; by default it is the
# original author's machine path.
directory = os.environ.get("NIC_DATA_DIR", "/home/INT/malfait.n/Documents/NIC_250819")
file = "FCK_LOCKED_IC_JYOTIKA_250819.mat"

path = os.path.join(directory, file)

# The handle is deliberately left open: later cells pass `mat_file` / `cells_refs` to load_oneIC.
mat_file = h5py.File(path, "r")
cells_refs = mat_file['FCK_LOCKED_IC_JYOTIKA']

n_IC = 4     # ICs loaded per subject (the script requires every subject to have all of them)
n_subj = 23  # total number of subjects (subject ids go up to 23 in the lists below)

In [7]:
# Output folders, relative to the working directory the notebook runs from.
# Keep the trailing "/": file names are appended to these by plain string concatenation.
figures_dir, data_dir = (
    "tde-hmm2/Multivariate/figures/",
    "tde-hmm2/Multivariate/data/",
)

In [8]:
# Time-delay embedding used to build the HMM input: one shifted copy of the signal per lag.

def embedx(x, lags):
    """Time-delay-embed a multichannel signal.

    Parameters
    ----------
    x : array, shape (n_samples, n_channels)
        Signal with time along axis 0.
    lags : sequence of int
        Delays to apply. Following the ``np.roll`` convention, lag ``l`` places
        ``x[t - l]`` at position ``t``: positive lags look into the past, negative
        lags into the future.

    Returns
    -------
    Xe : ndarray, shape (n_channels, n_valid, len(lags))
        ``Xe[c, :, k]`` is channel ``c`` shifted by ``lags[k]``, restricted to the
        valid time points.
    valid : ndarray of int8, shape (n_samples, 1)
        1 where every lagged copy holds genuine (non-wrapped) samples, 0 otherwise.

    Raises
    ------
    ValueError
        If ``lags`` is empty.
    """
    lags = np.asarray(lags)
    if lags.size == 0:
        raise ValueError("lags must contain at least one delay")
    n_samples, n_channels = x.shape[0], x.shape[1]

    Xe = np.zeros((n_channels, n_samples, len(lags)))
    for k, lag in enumerate(lags):
        Xe[:, :, k] = np.roll(x, lag, axis=0).swapaxes(0, 1)

    # np.roll wraps around. A positive lag moves samples from the END into the first
    # `lag` rows, and a negative lag moves samples from the START into the last `|lag|`
    # rows. Trim exactly those rows. Using max(..., 0) also avoids the `[-0:]` slice,
    # which would select (and discard) the whole array when max(lags) == 0.
    n_head = max(int(lags.max()), 0)
    n_tail = max(-int(lags.min()), 0)
    valid = np.ones((n_samples, 1), dtype=np.int8)
    valid[:n_head, :] = 0
    valid[max(n_samples - n_tail, 0):, :] = 0

    Xe = Xe[:, valid[:, 0] == 1, :]
    return Xe, valid

In [9]:
def statesPSD(gamma, n_states, xe, fs=256/3, threshold=2/3, nfft=1000):
    """Average Welch power spectrum of `xe` for each HMM state.

    For state ``i``, the rows of ``xe`` whose posterior ``gamma[:, i]`` exceeds
    ``threshold`` are selected. A Welch PSD is computed for each column of that
    selection, and the column PSDs are averaged.

    Parameters
    ----------
    gamma : array, shape (n_samples, >= n_states)
        State posterior probabilities (column ``i`` belongs to state ``i``).
    n_states : int
        Number of states (columns of ``gamma``) to process.
    xe : array, shape (n_samples, n_columns)
        Signal whose spectra are computed (e.g. the lagged copies of one channel).
    fs : float
        Sampling frequency in Hz.
    threshold : float
        Minimum posterior for a sample to count as belonging to a state.
    nfft : int
        FFT length passed to ``scipy.signal.welch``.

    Returns
    -------
    freqs : ndarray, shape (nfft // 2 + 1,)
        Frequency bins of the one-sided spectra.
    psd_all : ndarray, shape (n_states, nfft // 2 + 1)
        Mean PSD per state. The row is NaN when no sample passes ``threshold``.
    """
    gamma = np.asarray(gamma)
    xe = np.asarray(xe)
    # Welch's one-sided frequency grid depends only on nfft and fs. Computing it up
    # front keeps `freqs` well defined even when some (or all) states select nothing.
    freqs = np.fft.rfftfreq(nfft, d=1.0 / fs)

    psd_all = np.full((n_states, freqs.size), np.nan)
    for i in range(n_states):
        selected = xe[gamma[:, i] > threshold, :]
        if selected.size == 0:
            # No sample is confidently assigned to this state: leave its row as NaN.
            continue
        # NOTE(review): the selected samples generally come from disjoint time windows but
        # are analysed as one continuous signal, so the spectra include jump artefacts.
        # Compute the PSD separately for each column (lag), then average the columns.
        col_psds = [signal.welch(x=col, fs=fs, nfft=nfft)[1] for col in selected.T]
        psd_all[i] = np.mean(col_psds, axis=0)

    return freqs, psd_all

## Parameters

In [10]:
### Input data parameters
# subj_list = [2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 14, 16, 17, 18, 19, 20, 21, 22, 23] # All except 1, 9, 13, 15
subj_list = [2, 3, 4, 6, 7, 8, 10, 11, 14, 16, 18, 19, 22]  # subset of the list above for which all ICs exist
# subj_list = [2, 3]
downsamp_rate = 3  # keep every 3rd sample: 256 Hz -> 256/3 ≈ 85.33 Hz
# downsamp_rate = 1
lags = np.arange(-5, 5)  # delays -5..4 (np.arange excludes the stop value)
# lags = np.arange(-29, 29)
# lags = np.arange(-11, 12)
n_lags = lags.shape[0]
apply_PCA = False   # Do we apply a PCA before inferring the HMM?
n_components = 0    # Number of principal components per IC; only used when apply_PCA is True
# n_components = 40

### HMM parameters
model_type = 'GaussianHMM'  # The script only builds hmmlearn GaussianHMM models
covariance_type = 'full'
# covariance_type = 'diag'  # ONLY IN CASE OF PCA
n_iter = 100  # maximum number of EM iterations
tol = 0.01    # EM stops once the log-likelihood gain falls below this

### Output data parameters
n_states_list = [3, 6]    # Number of hidden Markov states. Must be a list.

# Fail fast on inconsistent settings rather than deep inside PCA / hmmlearn.
if apply_PCA and n_components <= 0:
    raise ValueError("apply_PCA is True but n_components is not a positive integer")
if not isinstance(n_states_list, list) or not all(n >= 1 for n in n_states_list):
    raise ValueError("n_states_list must be a list of positive integers")


## Script

Works only when every subject has the same number of independent components (ICs) — here 4 ICs per subject.

In [None]:
# Build the HMM input matrix. For each subject and each IC: concatenate the trials,
# downsample, z-score and time-delay-embed the signal. The ICs of a subject are then placed
# side by side (columns) and the subjects stacked (rows). `subj_lengths` stores how many
# rows each subject contributes, so the HMM can treat subjects as separate sequences.
print("Computing the imput matrix for the model")
subj_lengths = []
embedded = []  # one entry per subject: list of n_IC arrays of shape (n_samples_subj, n_lags)
for subj in subj_list:
    xeall = []
    for IC in range(1, n_IC + 1):
        data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
        # NOTE(review): exactly the first 100 trials are used whatever `n_trials` is — confirm
        # every subject has at least 100 trials and that dropping any extra ones is intended.
        # NOTE(review): `[::downsamp_rate]` decimates without an anti-aliasing low-pass filter.
        big_timecourse = np.concatenate(
            [data['raw_timecourse_256Hz'][i][::downsamp_rate] for i in range(100)]
        )  # with downsamp_rate = 3: 256 Hz -> 256/3 ≈ 85.33 Hz
        big_timecourse = scale(big_timecourse)  # zero mean / unit variance per subject and IC
        x = big_timecourse.reshape(-1, 1)
        # NOTE(review): trials are concatenated before embedding, so delay vectors that
        # straddle a trial boundary mix samples from two different trials.
        xe, valid = embedx(x, lags)
        xeall.append(xe[0, :, :])
        print(f"IC{IC} loaded")

    # All ICs of a subject must have the same length to be placed side by side.
    lengths = {ic_mat.shape[0] for ic_mat in xeall}
    if len(lengths) != 1:
        raise ValueError(f"Subject {subj}: ICs have different lengths after embedding: {sorted(lengths)}")
    embedded.append(xeall)
    subj_lengths.append(xeall[0].shape[0])

if apply_PCA:
    # Fit ONE PCA per IC on the embedded data of all subjects, then project every subject
    # onto it. A separate PCA per subject would give each subject its own basis, so the
    # same column would mean different things for different subjects in the group HMM.
    for k in range(n_IC):
        pca = PCA(n_components=n_components)
        pca.fit(np.concatenate([xeall_subj[k] for xeall_subj in embedded]))
        for xeall_subj in embedded:
            xeall_subj[k] = pca.transform(xeall_subj[k])

tde_imput = np.concatenate([np.concatenate(xeall_subj, axis=1) for xeall_subj in embedded])
del embedded  # free the per-IC copies; tde_imput holds everything the model needs
assert tde_imput.shape[0] == sum(subj_lengths)

In [None]:
# Fit one group-level Gaussian HMM per requested number of states on the stacked data of
# all subjects. `subj_lengths` is passed as hmmlearn's `lengths`, so each subject is treated
# as an independent sequence.
hmm_random_state = 0  # seed for hmmlearn's parameter initialisation -> reproducible fits
os.makedirs(data_dir, exist_ok=True)  # the pickle cannot be written into a missing folder

for n_states in n_states_list:
    start_time = tm.perf_counter()
    print(f"Computing and saving the general model for {n_states} states")
    model = hmm.GaussianHMM(n_components=n_states, n_iter=n_iter,
                            covariance_type=covariance_type, tol=tol,
                            random_state=hmm_random_state)
    model.fit(tde_imput, subj_lengths)
    # hmmlearn's `monitor_.converged` is also True when n_iter is reached, so check the
    # iteration count directly.
    if model.monitor_.iter >= n_iter:
        print(f"Warning: the {n_states}-state model reached n_iter={n_iter} "
              f"and may not have converged (tol={tol})")

    out_path = os.path.join(
        data_dir,
        f"ALLSUBJECTS_lg{n_lags}co{n_components}st{n_states}Multivariate{model_type}.pkl",
    )
    # Use a dedicated handle name so the `file` variable of the config cell is not clobbered.
    with open(out_path, "wb") as fh:
        pickle.dump(model, fh)
    print("%s seconds" % (tm.perf_counter() - start_time))