In [1]:
import numpy as np
from loader import load_oneIC
import h5py
import xarray as xr

In [2]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale
from hmmlearn import hmm
import scipy.signal as signal
import pickle

In [3]:
import matplotlib.pyplot as plt
from wavelet_transform import wavelet_transform2

In [4]:
import mne

In [5]:
import time as tm

In [6]:
# Location of the input recording. The ".mat" file is opened through h5py,
# so it is presumably stored in an HDF5-based MATLAB format -- TODO confirm.
directory = "/home/INT/malfait.n/Documents/NIC_250819"
file = "FCK_LOCKED_IC_JYOTIKA_250819.mat"
path = "/".join((directory, file))

# The handle is deliberately left open: later cells pass `mat_file` and
# `cells_refs` to `load_oneIC`.
mat_file = h5py.File(path, mode="r")
cells_refs = mat_file["FCK_LOCKED_IC_JYOTIKA"]

# Dataset dimensions (not referenced by the other cells visible here).
n_IC, n_subj = 4, 23

In [7]:
# Output locations, relative to the notebook's working directory.
_univariate_root = "tde-hmm2/Univariate/"
figures_dir = _univariate_root + "figures/"   # where figures are written
data_dir = _univariate_root + "data/"         # where pickled models are written

In [8]:
# The embedx function builds a time-delay embedding of `x`: it stacks len(lags) circularly
# shifted copies of `x` (one per entry of `lags`) into `Xe`, then drops the edge samples.
# The result is the time-delay array given to the HMM.

def embedx(x, lags):
    """Build a time-delay embedding of ``x``.

    For every entry of ``lags`` a copy of ``x`` is circularly shifted along
    axis 0 by that many samples (``np.roll``) and stored in one slice of the
    last axis of the output. Edge samples are then dropped.

    Parameters
    ----------
    x : ndarray, shape (n_samples, n_channels)
        Signal to embed; time runs along axis 0.
    lags : sequence of int
        Shifts, in samples, one per embedded copy.

    Returns
    -------
    Xe : ndarray, shape (n_channels, n_kept, len(lags))
        Embedded signal restricted to the rows where ``valid`` is 1.
    valid : ndarray of int8, shape (n_samples, 1)
        1 for the rows of ``x`` kept in ``Xe``, 0 for the trimmed edges.
    """
    lags = np.asarray(lags)
    n_samples, n_channels = x.shape

    Xe = np.zeros((n_channels, n_samples, len(lags)))
    for l, lag in enumerate(lags):
        Xe[:, :, l] = np.roll(x, lag, axis=0).swapaxes(0, 1)

    # Remove edges: |min(lags)| rows at the start, |max(lags)| rows at the end.
    # NOTE(review): np.roll wraps samples into the *start* for positive lags
    # and into the *end* for negative lags, so this trimming only removes
    # every wrapped row when the lags are symmetric. The convention is kept
    # unchanged so the row alignment seen by existing callers is preserved.
    n_head = int(np.abs(np.min(lags)))
    n_tail = int(np.abs(np.max(lags)))

    valid = np.ones((n_samples, 1), dtype=np.int8)
    valid[:n_head, :] = 0
    # Slice from an explicit non-negative start: `valid[-0:]` would select the
    # whole array and wrongly discard every sample when max(lags) == 0.
    valid[max(n_samples - n_tail, 0):, :] = 0

    Xe = Xe[:, valid[:, 0] == 1, :]

    return Xe, valid

In [9]:
def statesPSD(gamma, n_states, xe, fs=256/3):
    """Estimate one power spectral density per hidden state.

    For each state, the rows of ``xe`` whose state probability exceeds 2/3
    are selected, a Welch PSD is computed separately for every column of that
    selection, and the column PSDs are averaged.

    Parameters
    ----------
    gamma : ndarray, shape (n_samples, >= n_states)
        State probabilities; column ``i`` is thresholded for state ``i``.
    n_states : int
        Number of states to process.
    xe : ndarray, shape (n_samples, n_columns)
        Signal whose rows are aligned with the rows of ``gamma``.
    fs : float, optional
        Sampling frequency handed to ``scipy.signal.welch``.

    Returns
    -------
    freqs : ndarray
        Frequency axis of the PSDs.
    psd_all : ndarray, shape (n_states, len(freqs))
        One averaged PSD per state. A state with no sample above the
        threshold gets a row of NaN instead of breaking the output shape.
    """
    nfft = 1000
    # One-sided frequency grid for this nfft; used when Welch is never run
    # (no states, or no state with any selected sample).
    freqs = np.fft.rfftfreq(nfft, d=1 / fs)

    psd_all = []
    for i in range(n_states):
        active = xe[gamma[:, i] > (2/3), :]

        if active.size == 0:
            # Welch on an empty segment returns empty arrays, which would
            # overwrite `freqs` and make `psd_all` ragged.
            psd_all.append(np.full(freqs.shape, np.nan))
            continue

        # Compute PSD separately for each lag (i.e. each column)
        tot = []
        for seg in active.T:
            freqs, psd = signal.welch(x=seg, fs=fs, nfft=nfft)
            tot.append(psd)
        psd_all.append(np.mean(np.asarray(tot), 0))

    psd_all = np.asarray(psd_all)

    return freqs, psd_all

## Parameters

In [10]:
### Input data parameters
# Subjects to process (original note: "same && all IC exist").
#   All except 1, 9, 13, 15: [2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 14, 16, 17, 18, 19, 20, 21, 22, 23]
#   Quick run:               [2, 3]
subj_list = [2, 3, 4, 6, 7, 8, 10, 11, 14, 16, 18, 19, 22]
IC_list = list(range(1, 5))
downsamp_rate = 3            # keep every 3rd sample; use 1 for no downsampling
lags = np.arange(-5, 5)      # alternatives: np.arange(-29, 29), np.arange(-11, 12)
n_lags = len(lags)
apply_PCA = False            # Do we apply a PCA before inferring the HMM?
n_components = 0             # Number of principal components in case of PCA (e.g. 40)

### HMM parameters
model_type = 'GaussianHMM'
covariance_type = 'full'     # use 'diag' ONLY IN CASE OF PCA
n_iter = 100
tol = 0.01

### Output data parameters
n_states_list = [3, 6]       # Number of hidden Markov states. Must be a list.

## Script

In [None]:
# For every (number of states, subject, IC): fit a Gaussian HMM on the
# time-delay embedded signal, pickle the model, then save the per-subject
# state probability time courses and state PSDs to a netCDF file.
complete_time = tm.time()
for n_states in n_states_list:
    for subj in subj_list:
        tcourses = []
        psds = []
        IC_new_list = []
        # Axes of the last IC that was processed successfully for this subject.
        time_axis = None
        freqs = None
        for IC in IC_list:
            start_time = tm.time()
            try:
                print(f"---- SUBJECT{subj}, IC{IC} ----")
                # Loading all data for subject{subj}, IC{IC}
                data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC)
                ic_time_axis = data["time_axis"][::downsamp_rate]
                t_len = ic_time_axis.shape[0]

                # Finding and saving the model
                # NOTE(review): trials are concatenated before embedding and
                # fitting, so lagged windows and state transitions span trial
                # boundaries -- confirm this is intended.
                big_timecourse = np.concatenate(
                    [data['raw_timecourse_256Hz'][i][::downsamp_rate] for i in range(n_trials)]
                )
                x = big_timecourse.reshape(-1, 1)
                print("Computing and saving the model")
                # Embed the signal itself (the original passed the `data` dict).
                xe, valid = embedx(x, lags)
                embedded = xe[0, :, :]  # single channel -> (n_kept, n_lags)
                if apply_PCA:
                    pca = PCA(n_components=n_components)
                    y = pca.fit_transform(embedded)
                else:
                    y = embedded
                model = hmm.GaussianHMM(n_components=n_states, n_iter=n_iter,
                                        covariance_type=covariance_type, tol=tol)
                model.fit(y)
                model_path = (data_dir + f"su{subj}IC{IC}_lg{n_lags}co{n_components}st{n_states}"
                              + "Univariate" + model_type + ".pkl")
                # Use a dedicated handle name so the global `file` (input file
                # name) is not overwritten by a closed file object.
                with open(model_path, "wb") as model_fh:
                    pickle.dump(model, model_fh)

                # Computing the states' probability time courses and PSDs
                print("Computing the probability time course and the PSD of each states")
                gamma = model.predict_proba(y)
                # Put gamma back on the full time axis; the rows trimmed by
                # embedx (valid == 0) keep a probability of 0.
                tcourse = np.zeros((valid.shape[0], n_states))
                tcourse[valid[:, 0] == 1] = gamma
                # Raises if the concatenated length is not n_trials * t_len.
                tcourse_trials = tcourse.reshape(n_trials, t_len, n_states)
                # statesPSD masks along the time axis, so it needs the 2-D
                # (n_kept, n_lags) array rather than the 3-D embedx output.
                ic_freqs, psd = statesPSD(gamma, n_states, embedded, fs=256/downsamp_rate)

                # Record results only once every step succeeded, so the three
                # lists always stay the same length.
                tcourses.append(tcourse_trials)
                psds.append(psd[np.newaxis,])
                IC_new_list.append(IC)
                time_axis, freqs = ic_time_axis, ic_freqs
                print("%s seconds" % (tm.time() - start_time))
                print(f"subj{subj}, IC{IC}: OK")
            except Exception as err:
                # Best effort: skip this IC but report why, instead of
                # silently swallowing every error (including KeyboardInterrupt).
                print(f"subj{subj}, IC{IC}: NOT POSSIBLE ({type(err).__name__}: {err})")

        if not IC_new_list:
            # Nothing to stack/concatenate; avoid writing stale axes.
            print(f"subj{subj}: no IC could be processed, nothing saved")
            continue

        print(f"Saving the probability time course and the PSD of each states for subj{subj}")
        ds = xr.Dataset(
            {
                # Stack adds the leading IC axis: (IC, trials, time, states).
                "states_timecourse": (("IC", "trials", "time", "states"), np.stack(tcourses)),
                "states_psd": (("IC", "states", "freq"), np.concatenate(psds)),
            },
            coords={
                "IC": IC_new_list,
                "time": time_axis,
                "states": np.arange(1, n_states+1),
                "freq": freqs,
            },
        )
        # TODO: the original author flagged the fractional-occupancy
        # computation as unfinished; this is the across-trial mean of the
        # state probabilities.
        ds = ds.assign(frac_occ=(ds["states_timecourse"].sum("trials")/ds.sizes["trials"]))
        ds.to_netcdf(f"tde-hmm2/test_files/su{subj}-{n_states}states_woPCA_preModel_data_Multi.nc")


# Report the total run time as hours / minutes / seconds.
sec_time = tm.time() - complete_time
hr_time = int(sec_time/3600)
sec_time = sec_time - (hr_time*3600)
mn_time = int(sec_time/60)
sec_time = sec_time - (mn_time*60)
print(f"{hr_time}h{mn_time}mn{sec_time}s")

## Tests 