# Final: saving all HMM data, including the state timecourses.

## The libraries and methods we need:

In [3]:
!pip install numpy

Collecting numpy
[33m  Cache entry deserialization failed, entry ignored[0m
  Downloading https://files.pythonhosted.org/packages/3a/5f/47e578b3ae79e2624e205445ab77a1848acdaa2929a00eeef6b16eaaeb20/numpy-1.16.6-cp27-cp27mu-manylinux1_x86_64.whl (17.0MB)
[K    100% |████████████████████████████████| 17.0MB 60kB/s  eta 0:00:01
[?25hInstalling collected packages: numpy
Successfully installed numpy-1.16.6


In [4]:
import time as tm

import numpy as np
from loader import load_oneIC

ModuleNotFoundError: No module named 'numpy'

In [2]:
from sklearn.decomposition import PCA
from hmmlearn import hmm
import scipy.signal as signal
import pickle

In [3]:
import matplotlib.pyplot as plt

In [4]:
import h5py

In [5]:
import xarray as xr

In [6]:
# Where the recording lives on disk (machine-specific absolute location).
directory = "E:/timot/Documents/1 - Centrale Marseille/0.5 - Semestre S8/Stage/NIC_250819"
file = "FCK_LOCKED_IC_JYOTIKA_250819.mat"
path = "/".join((directory, file))

# Open the .mat file read-only through h5py and keep a handle on its top-level entry;
# both objects are handed to load_oneIC further down.
mat_file = h5py.File(path, mode="r")
cells_refs = mat_file["FCK_LOCKED_IC_JYOTIKA"]

# Upper bounds used by the loops below: ICs run from 1 to n_IC, subjects up to n_subj.
n_IC, n_subj = 4, 23

## Functions:

### Computing functions:

In [20]:
# The embedx function copies the `x` array len(lags) times into `xe`
# with lags (i.e. time delays) between lags[0] and lags[-1] (we implement the time-delay array for the HMM).

def embedx(x, lags):
    """Build a time-delay embedding of `x`.

    For every entry of `lags`, `x` is circularly shifted along its first axis
    with ``np.roll`` and stored as one slice of the output. Rows that contain
    samples wrapped around by the circular shift are then removed.

    Parameters
    ----------
    x : np.ndarray, shape (n_samples, n_channels)
        Signal to embed; time runs along axis 0.
    lags : sequence of int
        Shifts handed to ``np.roll``. A positive lag ``k`` puts ``x[t - k]`` in
        row ``t`` (so the first ``k`` rows wrap around); a negative lag ``-k``
        puts ``x[t + k]`` in row ``t`` (so the last ``k`` rows wrap around).

    Returns
    -------
    Xe : np.ndarray, shape (n_channels, n_valid, len(lags))
        Embedded signal restricted to the rows flagged as valid.
    valid : np.ndarray of int8, shape (n_samples, 1)
        1 for the rows of the original time axis kept in `Xe`, 0 for the
        discarded edge rows.

    Notes
    -----
    The previous version trimmed ``|min(lags)|`` rows at the start and
    ``|max(lags)|`` rows at the end, which is the mirror image of what
    ``np.roll`` contaminates, and it discarded *every* row when
    ``max(lags) == 0`` because ``valid[-0:]`` addresses the whole array.
    For lags symmetric around zero the result is unchanged.
    """
    lags = np.asarray(lags)
    n_samples, n_channels = x.shape

    Xe = np.zeros((n_channels, n_samples, len(lags)))
    for l, lag in enumerate(lags):
        Xe[:, :, l] = np.roll(x, int(lag), axis=0).swapaxes(0, 1)

    # Remove edges: positive lags wrap samples into the head, negative lags into the tail.
    head = max(int(np.max(lags)), 0)
    tail = max(-int(np.min(lags)), 0)

    valid = np.ones((n_samples, 1), dtype=np.int8)
    valid[:head, :] = 0
    if tail > 0:
        # Explicit start index: a "-0:" slice would select the whole array.
        valid[max(n_samples - tail, 0):, :] = 0

    Xe = Xe[:, valid[:, 0] == 1, :]

    return Xe, valid


# The hmm_tde function finds parameters for the HMM,
# then uses them to determine the probability of presence of each found state over time.

def hmm_tde(y: np.ndarray, lags, subj, IC, n_lags, n_states=3, n_iter=100, n_components=8,
            covariance_type='full', model_type='GMMHMM', tol=0.01, n_mix=1, **kwargs):
    """Fit an hmmlearn model on `y` and return the state posteriors with the model.

    Parameters
    ----------
    y : np.ndarray
        Observations passed unchanged to ``model.fit`` and ``model.predict_proba``.
    lags, subj, IC, n_lags, n_components
        Accepted for call-site compatibility; not used inside this function.
    n_states : int
        Number of hidden states (forwarded to hmmlearn as ``n_components``).
    n_iter : int
        Maximum number of fitting iterations.
    covariance_type : str
        Covariance structure; used by 'GMMHMM' and 'GaussianHMM' only.
    model_type : {'GMMHMM', 'GaussianHMM', 'MultinomialHMM'}
        Which hmmlearn class to instantiate.
    tol : float
        Convergence threshold forwarded to the model.
    n_mix : int
        Number of mixture components per state; used by 'GMMHMM' only.
    **kwargs
        Extra keyword arguments forwarded to the model constructor.

    Returns
    -------
    gamma
        Result of ``model.predict_proba(y)``.
    model
        The fitted hmmlearn model.

    Raises
    ------
    ValueError
        If `model_type` is not one of the supported names. (This used to be
        reported by returning a message string, which callers unpacking
        ``gamma, model = hmm_tde(...)`` could not handle.)
    """
    if model_type == 'GMMHMM':
        model = hmm.GMMHMM(n_components=n_states, n_iter=n_iter,
                           covariance_type=covariance_type, tol=tol, n_mix=n_mix, **kwargs)

    elif model_type == 'GaussianHMM':
        model = hmm.GaussianHMM(n_components=n_states, n_iter=n_iter,
                                covariance_type=covariance_type, tol=tol, **kwargs)

    elif model_type == 'MultinomialHMM':
        model = hmm.MultinomialHMM(n_components=n_states, n_iter=n_iter, tol=tol, **kwargs)

    else:
        raise ValueError(
            f"Non-existing model_type {model_type!r}. "
            "Please choose 'GMMHMM' or 'GaussianHMM' or 'MultinomialHMM'. default='GMMHMM'"
        )

    model.fit(y)
    gamma = model.predict_proba(y)

    return gamma, model

## The routine:

#### Model and parameters:
|_ 58delays x (7s x 256Hz x 10trials)   -----PCA-----> |_ 40components x (7s x 256Hz x 10trials) -----> |_ TDE-HMM
                                                                                                        (
                                                                                                        1 Gaussian/state,
                                                                                                        3 states
                                                                                                        )

In [8]:
# Tunable settings for the time-delay embedding, the PCA and the HMM.

# Time-delay embedding
lags = np.arange(-29, 29)
n_lags = len(lags)

# Principal component analysis
n_components = 40

# Hidden Markov Model
n_states = 3
n_iter = 100
tol = 0.01
n_mix = 1
covariance_type = 'diag'
model_type = 'GMMHMM'

In [None]:
# Subjects to process: all except subjects 1, 9, 13, 15.
subj_list = list(range(2, 9)) + list(range(10, 13)) + [14] + list(range(16, n_subj + 1))

complete_time = tm.time()
for subj in subj_list:
    print(f"---- SUBJECT{subj} ----")
    start_time = tm.time()

    # --- Load every IC of the subject: raw trials -> embedding -> PCA scores ---
    datall = []
    IC_list = []
    for IC in range(1, n_IC+1):
        try:
            data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
            big_timecourse = np.concatenate([data['raw_timecourse_256Hz'][i] for i in range(n_trials)])
            x = big_timecourse.reshape(-1, 1)
            xe, valid = embedx(x, lags)
            pca = PCA(n_components=n_components)
            y = pca.fit_transform(xe[0, :, :])
            datall.append(y)
            IC_list.append(IC)
            print(f"IC{IC} loaded")
        except Exception as err:
            # Best effort: an IC that cannot be loaded/processed is skipped, but the reason
            # is reported. `Exception` (not a bare except) keeps Ctrl-C able to stop the run.
            print(f"IC{IC} skipped: {err!r}")

    if not datall:
        # np.concatenate([]) would raise; there is nothing to fit for this subject.
        print(f"subj{subj}: no IC could be loaded, skipping this subject")
        continue

    # --- Fit one HMM on the PCA scores of all loaded ICs, stacked column-wise ---
    y = np.concatenate(datall, axis=1)
    gamma, model = hmm_tde(y, lags, subj, IC, n_lags, n_iter=n_iter, n_states=n_states, n_components=n_components, 
                           covariance_type=covariance_type, model_type=model_type, tol=tol, n_mix=n_mix)

    print("Saving the model")
    model_path = (f"tde-hmm2/pkl_files/su{subj}All_lg{n_lags}co{n_components}st{n_states}"
                  + f"{n_mix}" + model_type + "_model.pkl")
    # The handle is deliberately not called `file`: that name holds the .mat filename
    # in the configuration cell and must not be overwritten.
    with open(model_path, "wb") as pkl_file:
        pickle.dump(model, pkl_file)

    print("Saving the states timecourse")
    # TODO: the per-IC export of `gamma` was never written; the original `for IC in IC_list:`
    # loop had no body, which made the whole cell a SyntaxError. Until it is implemented,
    # only reproduce what that loop would have left behind: IC is the last IC that loaded.
    IC = IC_list[-1]

    print("%s seconds" % (tm.time() - start_time))
    start_time = tm.time()

    # Saving the Power Spectral Density of each state
    # NOTE(review): `xe` and `data` are whatever the last iteration of the loading loop
    # assigned, i.e. a single IC, while `gamma` comes from the model fitted on all ICs.
    print("Computing the PSD of each states and saving a plot")
    psd_all, freqs, max_power = show_bigstates(
        gamma, n_states, xe, # the data we need for the plot
        subj, IC, # which IC of which subject is of interest here
        n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
    )
    print("Saving the PSD of each state")
    psdxr = xr.DataArray(
        psd_all,
        dims=['states', 'power'],
        coords={
            "subject": subj,
            "IC": IC,
        },
    )
    # !!! TODO: add the freqs array in here !!!
    ds = xr.Dataset({"states_psd": psdxr},)
    ds.to_netcdf(f"tde-hmm2/nc_files/su{subj}IC{IC}-states_psd.nc")

    # Saving a plot of probability timecourses of the states over the tfr 
    print("Saving a plot of the states timecourses over a plot of the tfr")
    time = data['time_axis']
    # Only the first 3 trials are plotted; each trial's time axis is offset by 4 + 7*i.
    bigtime = np.concatenate([time+4+(7*i) for i in range(3)])
    bigtfr = np.concatenate([data[f'tfr_256Hz trial{i+1}'] for i in range(3)], axis=2)
    plot_hmm_over_bigtfr(   
        bigtime, bigtfr, gamma, lags, n_states, max_power, # the data we need for the plot
        subj, IC, # which IC of which subject is of interest here, how many trials
        n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
    )
    print("%s seconds" % (tm.time() - start_time))
    print(f"subj{subj}, IC{IC}: OK")

# Total wall-clock time, split into hours / minutes / seconds.
sec_time = tm.time() - complete_time
hr_time = int(sec_time/3600)
sec_time = sec_time - (hr_time*3600)
mn_time = int(sec_time/60)
sec_time = sec_time - (mn_time*60)
print(f"{hr_time}h{mn_time}mn{sec_time}s")

In [None]:
# NOTE(review): unfinished draft. `gamma[]` is not valid Python and `psds` is not defined
# anywhere in the visible notebook, so this cell cannot run as written.
# Presumably meant to bundle, for one subject, the state probability timecourses and the
# per-state PSDs into a single NetCDF file. TODO: confirm the intended array shapes
# against the declared dims before completing it.
ds = xr.Dataset(
    {
        "states_timecourse": (("IC","trials","time", "states"), np.concatenate((gamma[]))),
        "states_psd": (("IC", "states", "freq"), np.concatenate((psds))),
    },
    {
        "IC":IC_list,
        "time":data["time_axis"],
        "states":np.arange(n_states),
        "freq": freqs,
    }
)
ds.to_netcdf(f"tde-hmm2/nc_files/su{subj}-{n_states}states_data.nc")

In [24]:
subj = 2

# For each IC of this subject: join the raw trials into one long timecourse,
# time-delay embed it, reduce it with its own PCA, and collect the scores.
datall = []
for IC in range(1, n_IC + 1):
    data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
    big_timecourse = np.concatenate(
        [data['raw_timecourse_256Hz'][trial] for trial in range(n_trials)]
    )
    x = big_timecourse.reshape((-1, 1))
    xe, valid = embedx(x, lags)
    pca = PCA(n_components=n_components)
    y = pca.fit_transform(xe[0])
    datall.append(y)

# Put the per-IC scores side by side and fit a single HMM on the result.
y = np.concatenate(datall, axis=1)
hmm_options = dict(
    n_iter=n_iter,
    n_states=n_states,
    n_components=n_components,
    covariance_type=covariance_type,
    model_type=model_type,
    tol=tol,
    n_mix=n_mix,
)
gamma, model = hmm_tde(y, lags, subj, IC, n_lags, **hmm_options)

Loading the raw timecourse
Loading the raw timecourse
Loading the raw timecourse
Loading the raw timecourse


In [27]:
# Join the arrays collected in `datall` along axis 1, then fit the HMM on the result.
y = np.concatenate(datall, axis=1)
hmm_options = dict(
    n_iter=n_iter,
    n_states=n_states,
    n_components=n_components,
    covariance_type=covariance_type,
    model_type=model_type,
    tol=tol,
    n_mix=n_mix,
)
gamma, model = hmm_tde(y, lags, subj, IC, n_lags, **hmm_options)


(1210218, 160)

In [9]:
subj = 2

# For each IC of this subject: join the raw trials into one long timecourse and
# time-delay embed it. No PCA here; the embeddings themselves are collected.
datall = []
for IC in range(1, n_IC + 1):
    data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
    big_timecourse = np.concatenate(
        [data['raw_timecourse_256Hz'][trial] for trial in range(n_trials)]
    )
    x = big_timecourse.reshape((-1, 1))
    xe, valid = embedx(x, lags)
    datall.append(xe)


Loading the raw timecourse
Loading the raw timecourse
Loading the raw timecourse
Loading the raw timecourse


In [10]:
# Join the per-IC embeddings along their last axis (the lag axis of embedx's output),
# then display the resulting shape.
xe = np.concatenate(datall, axis=-1)
xe.shape

(1, 1210218, 232)

In [12]:
# One PCA over the stacked embedding: rows of xe[0] are the observations.
pca = PCA(n_components=n_components)
y = pca.fit_transform(xe[0])


In [13]:
# Sanity check on the PCA output; the recorded result has n_components (40) columns.
y.shape

(1210218, 40)

In [14]:
# Fit the HMM on the current `y`, using the settings from the parameter cell.
hmm_options = dict(
    n_iter=n_iter,
    n_states=n_states,
    n_components=n_components,
    covariance_type=covariance_type,
    model_type=model_type,
    tol=tol,
    n_mix=n_mix,
)
gamma, model = hmm_tde(y, lags, subj, IC, n_lags, **hmm_options)

In [26]:
def show_bigstates(
    gamma, n_states, xe, # the data we need for the plot
    
    subj, IC, # which IC of which subject is of interest here
    
    n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
):
    """Compute and plot the power spectral density (PSD) associated with each state.

    For state ``i``, the rows of ``xe[0]`` where ``gamma[:, i]`` exceeds 0.6 are
    selected; a Welch PSD is computed on every column of that selection and the
    PSDs are averaged. The curves are drawn on one figure, which is saved as a
    PNG under ``tde-hmm2/png_files/`` and then closed.

    Parameters
    ----------
    gamma : array indexed as ``gamma[:, i]``
        Per-sample probability of each state; its first axis must match axis 1 of `xe`.
    n_states : int
        Number of states (columns of `gamma`) to process.
    xe : 3-D array
        Embedded data; only ``xe[0]`` is used.
    subj, IC, n_components, n_lags, model_type, n_mix
        Only used to build the name of the saved figure.
    covariance_type
        Accepted for call-site compatibility; not used in this function.

    Returns
    -------
    psd_all : np.ndarray, shape (n_states, 196)
        First 196 PSD bins of every state (zeros for a state that was never selected).
    freqs : np.ndarray
        The first 196 frequencies of the PSD axis.
    max_power : np.ndarray, shape (n_states,)
        Maximum of each state's PSD from bin 80 upwards (0 for a state that was never selected).
    """
    fs = 256            # sampling frequency handed to Welch (Hz)
    nfft = 1000         # FFT length handed to Welch
    n_freqs = 196       # number of low-frequency bins kept in the output and the plot
    peak_start = 80     # first PSD bin considered when looking for the maximum power
    threshold = .6      # a sample belongs to a state when its probability exceeds this

    # Frequency axis of a one-sided spectrum of length nfft. Defined up front so that
    # `freqs` exists even when no state has any sample above the threshold.
    freqs = np.fft.rfftfreq(nfft, d=1.0 / fs)

    fig = plt.figure(figsize=(6, 6))
    plt.title('Power Spectrum Density')
    max_power = np.zeros(n_states)
    psd_all = np.zeros((n_states, n_freqs))
    for i in range(n_states):

        in_state = gamma[:, i] > threshold
        if np.any(in_state):
            # Compute PSD separately for each lag
            tot = []
            for seg in xe[0, in_state, :].T:
                freqs, psd = signal.welch(x=seg, fs=fs, nfft=nfft)
                tot.append(psd)
            psd = np.mean(np.asarray(tot), 0)

            max_power[i] = np.amax(psd[peak_start:])
            psd_all[i] = psd[:n_freqs]
        else:
            # No sample is attributed to this state: Welch would return an empty PSD and
            # np.amax would raise on it. Keep the zero-initialised row instead.
            print(f"state {i+1}: no sample with probability > {threshold}, PSD left at zero")

        # Always draw one line per state so the legend entries stay aligned with the states.
        plt.plot(freqs[:n_freqs], psd_all[i])
    
    plt.ylabel('PSD')
    plt.xlabel('Frequencies (Hz)')
    plt.legend([f'state {i+1}' for i in range(n_states)], loc='upper right')
    plt.tight_layout()   

    # NOTE(review): `ts` is neither a parameter nor a local; it must exist as a notebook
    # global when this runs, and no definition is visible in this notebook. TODO: confirm.
    plt.savefig(f'tde-hmm2/png_files/test{ts}su{subj}IC{IC}All_lg{n_lags}co{n_components}st{n_states}'
                    +f'{n_mix}'+model_type+'_states-info.png', dpi=600)
    
    plt.close(fig)
    
    return psd_all, freqs[:n_freqs], max_power