# Final: saving all HMM data, including the state timecourses.

## The libraries and methods we need:

In [1]:
import numpy as np
from loader import load_oneIC

In [2]:
from sklearn.decomposition import PCA
from hmmlearn import hmm
import scipy.signal as signal
import pickle

In [3]:
import matplotlib.pyplot as plt

In [4]:
import h5py

In [5]:
import xarray as xr

In [6]:
# Location of the .mat file (read through h5py) that holds the independent components.
directory = "E:/timot/Documents/1 - Centrale Marseille/0.5 - Semestre S8/Stage/NIC_250819"
file = "FCK_LOCKED_IC_JYOTIKA_250819.mat"

path = f"{directory}/{file}"

# Opened read-only; the handle is never closed in the visible code and is
# passed to load_oneIC for every subject / IC pair.
mat_file = h5py.File(path, "r")
# Top-level entry of the file, handed to load_oneIC together with mat_file.
cells_refs = mat_file['FCK_LOCKED_IC_JYOTIKA']

# The main loop iterates ICs 1..n_IC and subjects up to n_subj (inclusive).
n_IC = 4
n_subj = 23

## Functions:

### Computing functions:

In [7]:
# The embedx function stacks len(lags) time-shifted copies of the `x` array into `Xe`,
# one copy per delay listed in `lags` (this builds the time-delay embedding fed to the HMM).

def embedx(x, lags):
    """Build a time-delay embedding of ``x``.

    Each entry of ``lags`` produces one copy of ``x`` shifted along axis 0
    with ``np.roll``. Samples that ``np.roll`` wrapped around the array edges
    are then discarded.

    Parameters
    ----------
    x : np.ndarray, shape (n_samples, n_channels)
        Signal to embed; the shift is applied along axis 0.
    lags : sequence of int
        Shifts, in samples. May be negative, zero or positive.

    Returns
    -------
    Xe : np.ndarray, shape (n_channels, n_kept, len(lags))
        The shifted copies, restricted to the samples flagged in ``valid``.
    valid : np.ndarray, shape (n_samples, 1), dtype int8
        1 for the rows of ``x`` kept in ``Xe``, 0 for the trimmed edges.
    """
    n_samples = x.shape[0]
    Xe = np.zeros((x.shape[1], n_samples, len(lags)))

    for col, lag in enumerate(lags):
        Xe[:, :, col] = np.roll(x, lag, axis=0).swapaxes(0, 1)

    # np.roll with a positive shift wraps samples into the *start* of the
    # array, a negative shift wraps them into the *end*. Trim exactly those
    # regions. The previous version trimmed |min| at the start and |max| at
    # the end, which is only right for symmetric lags, and its `-0:` slice
    # wiped the whole mask when max(lags) was 0.
    n_head = max(int(np.max(lags)), 0)
    n_tail = max(-int(np.min(lags)), 0)

    valid = np.ones((n_samples, 1), dtype=np.int8)
    valid[:n_head, :] = 0
    # Explicit start index instead of a negative slice so that n_tail == 0
    # removes nothing.
    valid[max(n_samples - n_tail, 0):, :] = 0

    Xe = Xe[:, valid[:, 0] == 1, :]

    return Xe, valid


# The hmm_tde function finds parameters for the HMM,
# then uses them to determine the probability of presence of each found state over time.

def hmm_tde(data: np.ndarray, lags, subj, IC, n_lags, n_states=3, n_iter=100, n_components=8, 
            covariance_type='full', model_type='GMMHMM', tol=0.01, n_mix=1, **kwargs):
    """Fit a time-delay-embedded HMM and return the state probabilities.

    The signal is embedded with ``embedx``, the embedding of the first
    channel is reduced with PCA, an hmmlearn model is fitted on the PCA
    scores and ``predict_proba`` is evaluated on the same scores.

    Parameters
    ----------
    data : np.ndarray
        Signal passed as-is to ``embedx`` (the caller in this file passes a
        column vector of shape (n_samples, 1)).
    lags : sequence of int
        Delays, in samples, passed to ``embedx``.
    subj, IC, n_lags
        Not used by this function; kept so existing calls keep working.
    n_states : int
        Number of hidden states (``n_components`` of the hmmlearn model).
    n_iter : int
        ``n_iter`` of the hmmlearn model.
    n_components : int
        Number of principal components kept by the PCA.
    covariance_type : str
        Passed to ``GMMHMM`` / ``GaussianHMM``; not used for ``MultinomialHMM``.
    model_type : {'GMMHMM', 'GaussianHMM', 'MultinomialHMM'}
        Which hmmlearn class to fit.
    tol : float
        ``tol`` of the hmmlearn model.
    n_mix : int
        Number of mixture components per state; only used for ``GMMHMM``.
    **kwargs
        Forwarded to the hmmlearn model constructor.

    Returns
    -------
    gamma : np.ndarray
        Output of ``model.predict_proba`` on the PCA scores.
    model
        The fitted hmmlearn model.
    xe : np.ndarray
        The time-delay embedding returned by ``embedx``.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported names.
    """
    # Validate first: the embedding and PCA are costly, and returning an error
    # string (as before) only failed later, when the caller unpacked it.
    supported = ('GMMHMM', 'GaussianHMM', 'MultinomialHMM')
    if model_type not in supported:
        raise ValueError(
            f"Non-existing model_type {model_type!r}. "
            "Please choose 'GMMHMM' or 'GaussianHMM' or 'MultinomialHMM'. default='GMMHMM'"
        )

    # Embed time series
    xe, _ = embedx(data, lags)

    # Only the first channel of the embedding is used.
    pca = PCA(n_components=n_components)
    y = pca.fit_transform(xe[0, :, :])

    if model_type == 'GMMHMM':
        model = hmm.GMMHMM(n_components=n_states, n_iter=n_iter,
                           covariance_type=covariance_type, tol=tol, n_mix=n_mix, **kwargs)
    elif model_type == 'GaussianHMM':
        model = hmm.GaussianHMM(n_components=n_states, n_iter=n_iter,
                                covariance_type=covariance_type, tol=tol, **kwargs)
    else:  # 'MultinomialHMM'
        model = hmm.MultinomialHMM(n_components=n_states, n_iter=n_iter, tol=tol, **kwargs)

    model.fit(y)
    gamma = model.predict_proba(y)

    return gamma, model, xe

### Plotting and saving functions:

In [8]:
def show_bigstates(
    gamma, n_states, xe, # the data we need for the plot
    
    subj, IC, # which IC of which subject is of interest here
    
    n_components, n_lags, covariance_type, model_type, n_mix, temp, # infos we put in the .png name if we want to save it
):
    """Plot, save and return the power spectrum of each HMM state.

    For every state, the samples whose probability exceeds 0.6 are selected
    from the first channel of ``xe``; a Welch PSD is computed separately for
    each lag column and the PSDs are averaged over lags. The figure is always
    written under ``tde-hmm2/png_files/`` and then closed.

    Parameters
    ----------
    gamma : np.ndarray, shape (n_samples, n_states)
        State probabilities; row ``t`` must match row ``t`` of ``xe[0]``.
    n_states : int
        Number of states to plot.
    xe : np.ndarray, shape (n_channels, n_samples, n_lags)
        Time-delay embedding; only ``xe[0]`` is used.
    subj, IC, n_components, n_lags, model_type, n_mix, temp
        Only used to build the name of the saved .png file.
    covariance_type
        Not used; kept so existing calls keep working.

    Returns
    -------
    psd_all : np.ndarray, shape (n_states, 196)
        Lag-averaged PSD of each state, first 196 frequency bins.
    freqs : np.ndarray, shape (196,)
        Frequencies of those bins, as returned by ``scipy.signal.welch``.
    max_power : np.ndarray, shape (n_states,)
        Maximum of each state's PSD from bin 80 upwards.

    Raises
    ------
    ValueError
        If a state never exceeds the probability threshold, so that no
        sample is available to estimate its spectrum.
    """
    threshold = .6     # a sample belongs to a state when its probability exceeds this
    fs = 256           # sampling rate handed to welch
    nfft = 1000        # with fs=256 this gives 0.256 Hz per bin
    n_freqs = 196      # bins kept for the plot / output (up to about 50 Hz)
    first_bin = 80     # max_power is searched from this bin (about 20.5 Hz) upwards

    fig = plt.figure(figsize=(6, 6))
    # try/finally so the figure is released even when something below raises;
    # the caller swallows exceptions inside a long loop, which used to leave
    # one open figure behind per failure.
    try:
        plt.title('State Power Spectrum')
        max_power = np.zeros(n_states)
        psd_all = np.zeros((n_states, n_freqs))
        for i in range(n_states):
            active = gamma[:, i] > threshold
            if not active.any():
                # Previously this surfaced as an obscure scalar-index error.
                raise ValueError(
                    f"state {i+1} never exceeds a probability of {threshold}; "
                    "its power spectrum cannot be estimated"
                )

            # Compute PSD separately for each lag
            tot = []
            for seg in xe[0, active, :].T:
                freqs, psd = signal.welch(x=seg, fs=fs, nfft=nfft)
                tot.append(psd)
            psd = np.mean(np.asarray(tot), 0)

            max_power[i] = np.amax(psd[first_bin:])

            psd_all[i] = psd[:n_freqs]
            plt.plot(freqs[:n_freqs], psd_all[i])

        plt.ylabel('PSD')
        plt.xlabel('Frequencies (Hz)')
        plt.legend([f'state {i+1}' for i in range(n_states)], loc='upper right')
        plt.tight_layout()

        plt.savefig(f'tde-hmm2/png_files/su{subj}IC{IC}All'+temp+f'_lg{n_lags}co{n_components}st{n_states}'
                        +f'{n_mix}'+model_type+'_states-info.png', dpi=600)
    finally:
        plt.close(fig)

    return psd_all, freqs[:n_freqs], max_power
    

In [9]:
def plot_hmm_over_bigtfr(   
    bigtime, bigtfr, gamma, lags, n_states, max_power, # the data we need for the plot
    
    subj, IC, # which IC of which subject is of interest here
    
    n_components, n_lags, covariance_type, model_type, n_mix, # infos we put in the .png name if we want to save it
):
    """Plot the HMM state probabilities above a time-frequency map and save it.

    The top panel shows the probability of every state over time; the state
    with the largest ``max_power`` is treated as the "burst" state, drawn in
    red together with its ``> 0.6`` thresholded timecourse. The bottom panel
    shows ``bigtfr[0]``. The figure is always written under
    ``tde-hmm2/png_files/`` and then closed.

    Parameters
    ----------
    bigtime : array-like
        Time axis of the un-embedded signal; trimmed here with ``lags``.
    bigtfr : np.ndarray
        Time-frequency data; only ``bigtfr[0, :, :1793*3]`` is displayed.
    gamma : np.ndarray, shape (n_samples, n_states)
        State probabilities.
    lags : sequence of int
        Delays used for the embedding; used to trim ``bigtime``.
    n_states : int
        Number of states.
    max_power : array-like, length n_states
        Per-state power used to pick the burst state (argmax).
    subj, IC, n_components, n_lags, model_type, n_mix
        Only used to build the name of the saved .png file.
    covariance_type
        Not used; kept so existing calls keep working.
    """
    # presumably 3 trials of 1793 samples (about 7 s at 256 Hz) minus trimmed
    # edges - TODO confirm against the data actually passed in
    n_shown = 1793*3-60
    n_tfr_shown = 1793*3
    t_max = 7*3

    fig = plt.figure(figsize=(16*3, 5))
    # try/finally so the (large) figure is released even if plotting or saving fails.
    try:
        # HMM states probability plot
        plt.subplot(211)
        plt.title('HMM States probability')
        # NOTE(review): assumes gamma was computed on bigtime trimmed by
        # |min(lags)| samples at the start and |max(lags)| at the end - verify
        # against the embedding step that produced gamma.
        time = bigtime[np.abs(np.min(lags)):-np.abs(np.max(lags))]

        burst = np.argmax(max_power) # this is the burst state index

        # Labels are matched to artists by position in plt.legend below;
        # this presumably relies on matplotlib listing the line (burst
        # threshold) before the filled areas - keep the plotting order as is.
        labels = [f'burst state (state {burst+1})']
        for i in range(n_states):
            if i == burst:
                continue
            plt.fill_between(x=time[:n_shown], y1=gamma[:n_shown, i], alpha=0.2)
            labels.append(f'state {i+1}')
        # Set once, outside the loop, so the limits also apply when the burst
        # state is the only state.
        plt.xlim(0, t_max)
        plt.plot(time[:n_shown], gamma[:n_shown, burst]>0.6, 'red')
        plt.fill_between(x=time[:n_shown], y1=gamma[:n_shown, burst], alpha=0.2, color='red')
        plt.ylabel('State probability')
        plt.legend(labels, loc='upper left')

        # Time-frequency plot
        plt.subplot(212)

        plt.title('Wavelet transform')
        plt.imshow(bigtfr[0, :, :n_tfr_shown],
                   aspect='auto', origin='lower', extent=[0, t_max, 2, 50], cmap='RdBu_r')
        plt.xlabel('Time (s)')
        plt.ylabel('Frequencies (Hz)')
        plt.tight_layout()

        plt.savefig(f'tde-hmm2/png_files/su{subj}IC{IC}All_lg{n_lags}co{n_components}st{n_states}'
                        +f'{n_mix}'+model_type+'_hmm-tfr.png', dpi=600)
    finally:
        plt.close(fig)

## The routine:

#### Model and parameters:
|_ 58delays x (7s x 256Hz x 10trials)   -----PCA-----> |_ 40components x (7s x 256Hz x 10trials) -----> |_ TDE-HMM
                                                                                                        (
                                                                                                        1 Gaussian/state,
                                                                                                        3 states
                                                                                                        )

In [10]:
# The parameters we change to hope for some results
# 58 delays, from -29 to 28 samples (the upper bound of arange is exclusive).
lags = np.arange(-29, 29)
# Number of delays; in the visible code it only ends up in output file names.
n_lags = lags.shape[0]
# n_iter and tol are forwarded to the hmmlearn model by hmm_tde.
n_iter=100
n_states=3    # for the Hidden Markov Model
n_components=40     # For the principal component analysis
covariance_type='diag'
# One of 'GMMHMM', 'GaussianHMM' or 'MultinomialHMM' (the names hmm_tde accepts).
model_type='GMMHMM'
tol=0.01
# Number of mixture components per state; hmm_tde only passes it to GMMHMM.
n_mix=1

In [11]:
import time as tm

In [12]:
# Main routine: for every subject / IC pair, fit one TDE-HMM on the samples
# before index `ind` of each trial and one on the samples after it, then save
# the fitted models (.pkl), the state timecourses (.nc), the per-state PSD
# plots (.png) and the per-state PSDs (.nc) under tde-hmm2/.

# All subjects from 2 to n_subj except 9, 13 and 15 (subject 1 is skipped too).
subj_list = [s for s in range(2, n_subj + 1) if s not in (9, 13, 15)]

# Shared suffix of the model file names.
name_tag = f"_lg{n_lags}co{n_components}st{n_states}{n_mix}{model_type}"
# Keyword arguments common to both hmm_tde calls.
hmm_kwargs = dict(n_iter=n_iter, n_states=n_states, n_components=n_components,
                  covariance_type=covariance_type, model_type=model_type, tol=tol, n_mix=n_mix)

complete_time = tm.time()
for subj in subj_list:
    for IC in range(1, n_IC + 1):
        start_time = tm.time()
        try:
            print(f"---- SUBJECT{subj}, IC{IC} ----")
            # Loading all data for subject{subj}, IC{IC}
            data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC)

            # `time` is currently unused: the split index is hard-coded. A
            # disabled variant advanced `ind` while time[ind] < 0 instead.
            time = data['time_axis']
            ind = 1024

            # Concatenate, over trials, the part of each trial before / after `ind`.
            data_before = np.concatenate([data['raw_timecourse_256Hz'][i, :ind] for i in range(n_trials)])
            data_after = np.concatenate([data['raw_timecourse_256Hz'][i, ind:] for i in range(n_trials)])

            # Finding and saving the models
            x_before = data_before.reshape(-1, 1)
            x_after = data_after.reshape(-1, 1)
            print("Computing the two models")
            gamma_before, model_before, xe_before = hmm_tde(x_before, lags, subj, IC, n_lags, **hmm_kwargs)
            gamma_after, model_after, xe_after = hmm_tde(x_after, lags, subj, IC, n_lags, **hmm_kwargs)

            print("Saving the model")
            for period, model in (("Before", model_before), ("After", model_after)):
                # `pkl_file`, not `file`: the latter is the module-level .mat file name.
                with open(f"tde-hmm2/pkl_files/su{subj}IC{IC}All{period}{name_tag}_model.pkl", "wb") as pkl_file:
                    pickle.dump(model, pkl_file)

            print("Saving the states timecourse")
            for period, gamma in (("Before", gamma_before), ("After", gamma_after)):
                gammaxr = xr.DataArray(
                    gamma,
                    dims=['time', 'states'],
                    coords={
                        "subject": subj,
                        "IC": IC,
                        "temp": period.lower()
                    },
                )
                ds = xr.Dataset({"states_timecourse_256Hz": gammaxr},)
                ds.to_netcdf(f"tde-hmm2/nc_files/su{subj}IC{IC}{period}-states_timecourse_256Hz.nc")
            print("%s seconds" % (tm.time() - start_time))
            start_time = tm.time()

            # Saving the Power Spectral Density of each state
            print("Computing the PSD of each states and saving a plot")
            psd_all_before, freqs, max_power = show_bigstates(
                gamma_before, n_states, xe_before, # the data we need for the plot
                subj, IC, # which IC of which subject is of interest here
                n_components, n_lags, covariance_type, model_type, n_mix, "Before", # infos we put in the .png name
            )
            psd_all_after, freqs, max_power = show_bigstates(
                gamma_after, n_states, xe_after, # the data we need for the plot
                subj, IC, # which IC of which subject is of interest here
                n_components, n_lags, covariance_type, model_type, n_mix, "After", # infos we put in the .png name
            )

            print("Saving the PSD of each state")
            # TODO: also store the `freqs` array in these datasets.
            for period, psd_all in (("Before", psd_all_before), ("After", psd_all_after)):
                psdxr = xr.DataArray(
                    psd_all,
                    dims=['states', 'power'],
                    coords={
                        "subject": subj,
                        "IC": IC,
                        "temp": period.lower()
                    },
                )
                ds = xr.Dataset({"states_psd": psdxr},)
                ds.to_netcdf(f"tde-hmm2/nc_files/su{subj}IC{IC}{period}-states_psd.nc")

            print("%s seconds" % (tm.time() - start_time))
            print(f"subj{subj}, IC{IC}: OK")
        except Exception as err:
            # `except Exception` (not a bare `except`) so that Ctrl-C can stop
            # this many-hours loop; the cause is printed so that failures can
            # be told apart.
            print(f"subj{subj}, IC{IC}: NOT POSSIBLE ({type(err).__name__}: {err})")

# Total elapsed time, printed as hours / minutes / seconds.
mn_time, sec_time = divmod(int(tm.time() - complete_time), 60)
hr_time, mn_time = divmod(mn_time, 60)
print(f"{hr_time}h{mn_time}mn{sec_time}s")

---- SUBJECT2, IC1 ----
Loading the raw timecourse
Computing and loading the time-frequency wavelet transformation for 3 trials
Computing the two models
Saving the model
Saving the states timecourse
557.5257723331451 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
21.115009784698486 seconds
subj2, IC1: OK
---- SUBJECT2, IC2 ----
Loading the raw timecourse
Computing and loading the time-frequency wavelet transformation for 3 trials
Computing the two models
Saving the model
Saving the states timecourse
37570.98857188225 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
23.54384207725525 seconds
subj2, IC2: OK
---- SUBJECT2, IC3 ----
Loading the raw timecourse
Computing and loading the time-frequency wavelet transformation for 3 trials
Computing the two models
Saving the model
Saving the states timecourse
375.0128138065338 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
20

Saving the model
Saving the states timecourse
462.2281348705292 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
25.790138483047485 seconds
subj8, IC1: OK
---- SUBJECT8, IC2 ----
Loading the raw timecourse
Computing and loading the time-frequency wavelet transformation for 3 trials
Computing the two models
Saving the model
Saving the states timecourse
502.0229878425598 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
24.876614809036255 seconds
subj8, IC2: OK
---- SUBJECT8, IC3 ----
Loading the raw timecourse
Computing and loading the time-frequency wavelet transformation for 3 trials
Computing the two models
Saving the model
Saving the states timecourse
454.5971050262451 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
23.750486373901367 seconds
subj8, IC3: OK
---- SUBJECT8, IC4 ----
Loading the raw timecourse
Computing and loading the time-frequency wavelet transformati

1012.6228334903717 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
27.954464435577393 seconds
subj17, IC1: OK
---- SUBJECT17, IC2 ----
Loading the raw timecourse
Computing and loading the time-frequency wavelet transformation for 3 trials
Computing the two models
Saving the model
Saving the states timecourse
759.2678833007812 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
39.98350405693054 seconds
subj17, IC2: OK
---- SUBJECT17, IC3 ----
Loading the raw timecourse
The independent component IC3 of the subject 17 is not in the .mat file.
subj17, IC3: NOT POSSIBLE
---- SUBJECT17, IC4 ----
Loading the raw timecourse
Computing and loading the time-frequency wavelet transformation for 3 trials
Computing the two models
Saving the model
Saving the states timecourse
988.2769455909729 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
31.235021352767944 seconds
subj17, IC4: OK
--

Saving the PSD of each state
22.61465883255005 seconds
subj23, IC2: OK
---- SUBJECT23, IC3 ----
Loading the raw timecourse
The independent component IC3 of the subject 23 is not in the .mat file.
subj23, IC3: NOT POSSIBLE
---- SUBJECT23, IC4 ----
Loading the raw timecourse
Computing and loading the time-frequency wavelet transformation for 3 trials
Computing the two models
Saving the model
Saving the states timecourse
423.1374125480652 seconds
Computing the PSD of each states and saving a plot
Saving the PSD of each state
22.442832946777344 seconds
subj23, IC4: OK
23h28mn23s
