# Final: saving all HMM data, including the state timecourses.

## The libraries and methods we need:

In [1]:
import numpy as np
from loader import load_oneIC

In [2]:
from sklearn.decomposition import PCA
from hmmlearn import hmm
import scipy.signal as signal
import pickle

In [3]:
import matplotlib.pyplot as plt

In [4]:
import h5py

In [5]:
# Input: a .mat file opened with h5py (so presumably a MATLAB v7.3 / HDF5 file).
directory = "E:/timot/Documents/1 - Centrale Marseille/0.5 - Semestre S8/Stage/NIC_250819"
file = "FCK_LOCKED_IC_JYOTIKA_250819.mat"
path = "/".join((directory, file))

# The handle stays open on purpose: it is handed to load_oneIC in later cells.
mat_file = h5py.File(path, mode="r")
cells_refs = mat_file['FCK_LOCKED_IC_JYOTIKA']

# Dataset dimensions, used as loop bounds in later cells.
n_IC = 4     # ICs per subject (looped over as 1..n_IC)
n_subj = 23  # subjects (looped over up to n_subj)

## Functions:

### Computing functions:

In [6]:
# The embedx function copies the `x` array len(lags) times into `xe`
# with lags (i.e. time delays) between lags[0] and lags[-1] (we implement the time-delay array for the HMM).

def embedx(x, lags):
    
    Xe = np.zeros((x.shape[1], x.shape[0],  len(lags)))

    for l in range(len(lags)):
        Xe[:, :, l] = np.roll(x, lags[l], axis=0).swapaxes(0, 1)

    # Remove edges
    valid = np.ones((x.shape[0], 1), dtype=np.int8)
    valid[:np.abs(np.min(lags)), :] = 0
    valid[-np.abs(np.max(lags)):, :] = 0

    Xe = Xe[:, valid[:, 0] == 1, :]

    return Xe, valid


# The hmm_tde function finds parameters for the HMM,
# then uses them to determine the probability of presence of each found state over time.

def hmm_tde(data: np.array, lags, subj, IC, n_lags, n_states=3, n_iter=100, n_components=8, 
            covariance_type='full', model_type='GMMHMM', tol=0.01, n_mix=1, **kwargs):
    
    # Embed time serie
    xe, valid = embedx(data, lags)

    pca = PCA(n_components=n_components)
    y = pca.fit_transform(xe[0, :, :])
    
    if model_type=='GMMHMM':
        model = hmm.GMMHMM(n_components=n_states, n_iter=n_iter,
                            covariance_type=covariance_type, tol=tol, n_mix=n_mix, **kwargs)
        
    elif model_type=='GaussianHMM':
        model = hmm.GaussianHMM(n_components=n_states, n_iter=n_iter,
                            covariance_type=covariance_type, tol=tol, **kwargs)
        
    elif model_type=='MultinomialHMM':
        model = hmm.MultinomialHMM(n_components=n_states, n_iter=n_iter, tol=tol, **kwargs)
    
    else: 
        return "Non-exixting model_type. Please choose 'GMMHMM' or 'GaussianHMM' or 'MultinomialHMM'. default='GMMHMM'"
        
    model.fit(y)
    gamma = model.predict_proba(y)
    
    with open(f"tde-hmm2/su{subj}IC{IC}All_lg{n_lags}co{n_components}st{n_states}"
                +f"{n_mix}"+model_type+"_model.pkl", "wb") as file: pickle.dump(model, file)
        
#     with open(f"tde-hmm2/su{subj}IC{IC}All_lg{n_lags}co{n_components}st{n_states}"
#                 +f"{n_mix}"+model_type+"_gamma.pkl", "wb") as file: pickle.dump(gamma, file)

    return gamma, model

### Plotting and saving functions:

In [7]:
def show_bigstates(
    gamma, n_states, # the data we need for the plot
    
    subj, IC, # which IC of which subject is of interest here
    
    n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
    
    save=False, # do we really want to save the figure?
    xe=None, fs=256, nfft=1000, min_bin=50, threshold=0.6,
):
    """Plot the average power spectrum of the samples dominated by each HMM state.

    For every state ``i``, the samples where ``gamma[:, i] > threshold`` are
    taken from the embedded signal ``xe[0]``. A Welch PSD is computed for each
    delay column and averaged over delays. The figure is optionally saved to
    ``tde-hmm2/`` and is then closed, so it is not displayed.

    Parameters
    ----------
    gamma : np.ndarray, shape (n_valid, n_states)
        State probabilities from ``hmm_tde``. Its rows must be aligned with
        the rows of ``xe[0]``.
    n_states : int
        Number of states (columns of ``gamma``).
    subj, IC, n_components, n_lags, model_type, n_mix :
        Only used to build the file name of the saved figure.
        ``covariance_type`` is accepted but unused.
    save : bool
        Write the figure as a 600-dpi PNG.
    xe : np.ndarray, shape (n_channels, n_valid, n_lags), optional
        Time-delay embedding (first output of ``embedx``). When omitted, the
        function falls back to a global named ``xe``, which is the only
        source it used before this parameter existed.
    fs : float
        Sampling frequency passed to ``scipy.signal.welch``.
    nfft : int
        FFT length passed to ``scipy.signal.welch``.
    min_bin : int
        Frequency bins below this index are ignored when looking for the
        spectral peak. Bin 50 is 12.8 Hz with fs=256 and nfft=1000.
    threshold : float
        Probability above which a sample is attributed to a state.

    Returns
    -------
    max_freq, max_power : np.ndarray, shape (n_states,)
        Frequency and power of each state's PSD peak at or above ``min_bin``.
        A state with no sample above ``threshold`` gets a ``nan`` frequency
        and ``0.0`` power.

    Raises
    ------
    ValueError
        If no embedding is passed and no global ``xe`` exists.
    """
    if xe is None:
        xe = globals().get("xe")
    if xe is None:
        raise ValueError(
            "show_bigstates needs the time-delay embedding: pass xe=embedx(x, lags)[0]"
        )

    fig = plt.figure(figsize=(6, 6))
    plt.title('State Power Spectrum')
    max_power = np.zeros(n_states)
    max_freq = np.zeros(n_states)
    for i in range(n_states):
        in_state = gamma[:, i] > threshold
        if not np.any(in_state):
            # Nothing to analyse: leave power at 0 and mark the frequency unknown.
            max_freq[i] = np.nan
            continue

        # One Welch PSD per delay column (welch works along the last axis),
        # then the average over delays.
        freqs, psd_per_lag = signal.welch(x=xe[0, in_state, :].T, fs=fs, nfft=nfft)
        psd = psd_per_lag.mean(axis=0)

        # argmax of the slice is relative to min_bin: add the offset back
        # before indexing freqs (it used to be forgotten).
        peak = min_bin + int(np.argmax(psd[min_bin:]))
        max_freq[i] = freqs[peak]
        max_power[i] = psd[peak]

        # Explicit colour/label keep each state's colour and legend entry
        # stable even when another state is skipped.
        plt.plot(freqs, psd, color=f'C{i}', label=f'state {i+1}')

    # TODO: save freqs and psd up to 50 Hz for each state to a JSON or pickle
    # file; these small plots are useful to check the consistency of the data.

    plt.xlim(0, 50)
    plt.ylabel('PSD')
    plt.xlabel('Frequencies (Hz)')
    plt.legend(loc='upper right')
    plt.tight_layout()

    if save:
        plt.savefig(f'tde-hmm2/su{subj}IC{IC}All_lg{n_lags}co{n_components}st{n_states}'
                    + f'{n_mix}' + model_type + '_states-info.png', dpi=600)

    plt.close(fig)

    return max_freq, max_power
    

In [8]:
def plot_hmm_over_bigtfr(   
    bigtime, bigtfr, gamma, lags, n_states, max_power, # the data we need for the plot
    
    subj, IC, # which IC of which subject is of interest here
    
    n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
    
    save=False, # do we really want to save the figure?
    threshold=0.6, n_plot=1793 * 3 - 48, t_max=7 * 3, n_tfr=1793 * 3,
):
    """Plot the HMM state probabilities above the time-frequency map.

    The state whose entry in ``max_power`` is largest is treated as the
    "burst" state. It is drawn in red, together with a 0/1 trace of
    ``gamma > threshold``. The other states are drawn as shaded areas. The
    figure is optionally saved to ``tde-hmm2/`` and is then closed.

    Parameters
    ----------
    bigtime : np.ndarray, shape (n_times,)
        Time axis of the untrimmed signal (presumably seconds, given the
        axis label; confirm against the loader).
    bigtfr : np.ndarray
        Time-frequency data; ``bigtfr[0, :, :n_tfr]`` is shown as an image.
    gamma : np.ndarray, shape (n_valid, n_states)
        State probabilities from ``hmm_tde``.
    lags : sequence of int
        Delays used for the embedding; they decide how many edge samples of
        ``bigtime`` are dropped.
    n_states : int
        Number of states (columns of ``gamma``).
    max_power : sequence of float
        Peak PSD power per state, as returned by ``show_bigstates``.
    subj, IC, n_components, n_lags, model_type, n_mix :
        Only used to build the file name of the saved figure.
        ``covariance_type`` is accepted but unused.
    save : bool
        Write the figure as a 600-dpi PNG.
    threshold : float
        Probability above which the burst state is considered active.
    n_plot : int
        Maximum number of samples drawn in the probability panel.
    t_max : float
        Right x-limit of the probability panel and right extent of the TFR image.
    n_tfr : int
        Number of TFR time bins shown.
    """
    fig = plt.figure(figsize=(16 * 3, 5))

    # HMM states probability plot
    plt.subplot(211)
    plt.title('HMM States probability')
    # Drop |min(lags)| samples at the start and |max(lags)| at the end. The
    # end index is computed explicitly because `bigtime[a:-0]` would be empty.
    start = abs(int(np.min(lags)))
    stop = len(bigtime) - abs(int(np.max(lags)))
    t_axis = bigtime[start:stop]
    # Never draw more samples than both arrays provide. With the former
    # fixed count, fill_between failed on a shape mismatch whenever the
    # trimmed time axis was shorter than that count.
    n = min(len(t_axis), len(gamma), n_plot)
    t_axis = t_axis[:n]

    burst = int(np.argmax(max_power))  # the state with the largest peak power

    other_fills = []
    for i in range(n_states):
        if i != burst:
            other_fills.append(
                plt.fill_between(x=t_axis, y1=gamma[:n, i], alpha=0.2, label=f'state {i+1}')
            )
    (burst_line,) = plt.plot(t_axis, gamma[:n, burst] > threshold, 'red',
                             label=f'burst state (state {burst+1})')
    plt.fill_between(x=t_axis, y1=gamma[:n, burst], alpha=0.2, color='red')
    plt.xlim(0, t_max)
    plt.ylabel('State probability')
    # Pass the handles explicitly. legend(labels) pairs labels with artists
    # by insertion order, which put the burst label on a non-burst area.
    plt.legend(handles=[burst_line, *other_fills], loc='upper left')

    # TODO: save `gamma > threshold` to a JSON or pickle file (possibly from a
    # few lines outside this function); these small plots are useful to check
    # the consistency of the data.

    # Time-frequency plot
    plt.subplot(212)
    plt.title('Wavelet transform')
    # NOTE(review): the extent assumes the TFR rows span 2-50 Hz; confirm against the loader.
    plt.imshow(bigtfr[0, :, :n_tfr],
               aspect='auto', origin='lower', extent=[0, t_max, 2, 50], cmap='RdBu_r')
    plt.xlabel('Time (s)')
    plt.ylabel('Frequencies (Hz)')
    plt.tight_layout()

    if save:
        plt.savefig(f'tde-hmm2/su{subj}IC{IC}All_lg{n_lags}co{n_components}st{n_states}'
                    + f'{n_mix}' + model_type + '_hmm-tfr.png', dpi=600)

    plt.close(fig)

## The routine:

#### Model and parameters:
58 delays × (7 s × 256 Hz × 10 trials) --PCA--> 40 components × (7 s × 256 Hz × 10 trials) --> TDE-HMM (1 Gaussian per state, 3 states)

In [9]:
# Analysis parameters: the knobs we turn while exploring the model.
trial = 0                    # not referenced by the visible cells; all trials are loaded
lags = np.arange(-29, 29)    # delays from -29 to 28 samples
n_lags = len(lags)
n_iter = 100                 # passed to the hmmlearn model
n_states = 3                 # hidden states of the Hidden Markov Model
n_components = 40            # components kept by the principal component analysis
covariance_type = 'diag'
model_type = 'GMMHMM'
tol = 0.01
n_mix = 1                    # mixture components per state (GMMHMM)
save = True                  # write the figures to disk

In [13]:
import xarray as xr

In [15]:
# Quick single-case run. subj and IC were used below without ever being
# defined; they are now set to match the data actually loaded.
subj, IC = 2, 1
data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC)
# time = data['time_axis']
# bigtime = np.concatenate([time+4+(7*i) for i in range(3)])

# Finding and saving the model on all loaded trials concatenated end to end
# (n_trials comes from the loader instead of a hard-coded 10).
big_timecourse = np.concatenate([data['raw_timecourse_256Hz'][i] for i in range(n_trials)])
x = big_timecourse.reshape(-1, 1)
gamma, model = hmm_tde(x, lags, subj, IC, n_lags, n_iter=n_iter, n_states=n_states, n_components=n_components, 
                       covariance_type=covariance_type, model_type=model_type, tol=tol, n_mix=n_mix)

# TODO: ideally, cut gamma into 7-second slices (one per trial) before saving.

# # Saving the Power Spectral Density of each state
# max_freq, max_power = show_bigstates(
#     gamma, n_states, # the data we need for the plot
#     subj, IC, # which IC of which subject is of interest here
#     n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
#     save # do we really want to save the figure?
# )

# # Saving probability timecourses of the states over the tfr 
# bigtfr = np.concatenate([data[f'tfr_256Hz trial{i+1}'] for i in range(3)], axis=2)
# plot_hmm_over_bigtfr(   
#     bigtime, bigtfr, gamma, lags, n_states, max_power, # the data we need for the plot
#     subj, IC, # which IC of which subject is of interest here, how many trials
#     n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
#     save # do we really want to save the figure?
# )
print(f"subj{subj}, IC{IC}: OK")

loading the raw timecourse
subj2, IC1: OK


In [17]:
# Wrap the state probabilities in a labelled array (subject 2, IC 1).
gammaxr = xr.DataArray(
    data=gamma,
    dims=("time", "states"),
    coords=dict(subject=2, IC=1),
)

In [19]:
# Display the DataArray (cell output).
gammaxr

In [18]:
# Bundle the DataArray into a Dataset so it can be written to netCDF.
ds = xr.Dataset(data_vars={"states_timecourse_256Hz": gammaxr})

In [20]:
# Display the Dataset (cell output).
ds

In [21]:
# Write the test Dataset to disk.
ds.to_netcdf(path="tde-hmm2/saved_on_disk.nc")

In [17]:
# Scratch check: a bare `raise` inside an `except` block re-raises the caught
# exception to the enclosing handler, so "Coucou" is never printed.
# Only the NameError is caught, instead of everything (bare `except:`).
try:
    try:
        blabla  # undefined on purpose -> NameError
    except NameError:
        print("aïe")
        raise
    print("Coucou")
except NameError:
    print("Bravo !")

aïe
Bravo !


In [10]:
# All subjects except 1, 9, 13 and 15 (i.e. 2-8, 10-12, 14, 16-n_subj).
excluded_subjects = {1, 9, 13, 15}
subj_list = [s for s in range(1, n_subj + 1) if s not in excluded_subjects]

for subj in subj_list:  
    for IC in range(1, n_IC + 1): 
        try:
            print(f"---- SUBJECT{subj}, IC{IC} ----")
            # Loading all data for subject{subj}, IC{IC}
            data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC)
            
            # TODO: load_oneIC should raise an error when the requested data is
            # not in the file, so that we jump straight to the `except` branch.
            
            # Finding and saving the model on all trials concatenated end to end
            big_timecourse = np.concatenate([data['raw_timecourse_256Hz'][i] for i in range(n_trials)])
            x = big_timecourse.reshape(-1, 1)
            print("Computing the model")
            gamma, model = hmm_tde(x, lags, subj, IC, n_lags, n_iter=n_iter, n_states=n_states, n_components=n_components, 
                                   covariance_type=covariance_type, model_type=model_type, tol=tol, n_mix=n_mix)
            print("Saving the states timecourse")
            gammaxr = xr.DataArray(
                gamma,
                dims=['time', 'states'],
                coords={
                    "subject": subj,
                    "IC": IC,
                },
            )
            ds = xr.Dataset({"states_timecourse_256Hz": gammaxr})
            ds.to_netcdf(f"tde-hmm2/su{subj}IC{IC}-states_timecourse_256Hz.nc")
            
            # show_bigstates reads the time-delay embedding from a global named
            # `xe` (hmm_tde keeps its own copy local), so rebuild it here at
            # notebook scope.
            xe, _valid = embedx(x, lags)

            # Saving the Power Spectral Density of each state
            max_freq, max_power = show_bigstates(
                gamma, n_states, # the data we need for the plot
                subj, IC, # which IC of which subject is of interest here
                n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
                save # do we really want to save the figure?
            )

            # Saving probability timecourses of the states over the tfr 
            time = data['time_axis']
            bigtime = np.concatenate([time+4+(7*i) for i in range(3)])
            bigtfr = np.concatenate([data[f'tfr_256Hz trial{i+1}'] for i in range(3)], axis=2)
            plot_hmm_over_bigtfr(   
                bigtime, bigtfr, gamma, lags, n_states, max_power, # the data we need for the plot
                subj, IC, # which IC of which subject is of interest here, how many trials
                n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
                save # do we really want to save the figure?
            )
            print(f"subj{subj}, IC{IC}: OK")
        except Exception as err:
            # Keep going with the next subject/IC, but report why this one failed.
            # Catching Exception (not a bare except) still lets Ctrl-C stop the loop.
            print(f"subj{subj}, IC{IC}: NOT POSSIBLE")
            print(f"    {type(err).__name__}: {err}")

loading the raw timecourse
subj2, IC1: OK


In [11]:
# for subj in range(2, n_subj+1):
#     for IC in range(1, n_IC+1): 
#         try:
#             # Loading all data for subject{subj}, IC{IC}
#             data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC)
# #             time = data['time_axis']
# #             bigtime = np.concatenate([time+4+(7*i) for i in range(3)])
            
#             # Finding and saving the model
#             big_timecourse = np.concatenate([data[f'raw_timecourse_256Hz'][i] for i in range(n_trials)])
#             x = big_timecourse.reshape(-1, 1)
#             gamma, model = hmm_tde(x, lags, subj, IC, n_lags, n_iter=n_iter, n_states=n_states, n_components=n_components, 
#                                        covariance_type=covariance_type, model_type=model_type, tol=tol, n_mix=n_mix)
#             '''
#             Dans l'idéal je coupe gamma en tranches de 7 secondes avant de sauvegarder.
#             '''
            
# #             # Saving the Power Spectral Density of each state
# #             max_freq, max_power = show_bigstates(
# #                 gamma, n_states, # the data we need for the plot
# #                 subj, IC # which IC of which subject is of interest here
# #                 n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
# #                 save # do we really want to save the figure?
# #             )

# #             # Saving probability timecourses of the states over the tfr 
# #             bigtfr = np.concatenate([data[f'tfr_256Hz trial{i+1}'] for i in range(3)], axis=2)
# #             plot_hmm_over_bigtfr(   
# #                 bigtime, bigtfr, gamma, lags, n_states, max_power, # the data we need for the plot
# #                 subj, IC, # which IC of which subject is of interest here, how many trials
# #                 n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
# #                 save # do we really want to save the figure?
# #             )
#             print(f"subj{subj}, IC{IC}: OK")
#         except:
#             print(f"subj{subj}, IC{IC}: NOT POSSIBLE")