# Final: saving all HMM data, including the state timecourses.

## The libraries and methods we need:

In [None]:
import numpy as np
from loader import load_oneIC

In [None]:
from sklearn.decomposition import PCA
from hmmlearn import hmm
import scipy.signal as signal
import pickle

In [None]:
import matplotlib.pyplot as plt

In [None]:
import h5py

In [None]:
import xarray as xr

In [None]:
import time as tm

In [None]:
# Where the HDF5-based .mat file (read with h5py) lives.
directory = "/home/INT/malfait.n/Documents/NIC_250819"
file = "FCK_LOCKED_IC_JYOTIKA_250819.mat"
path = "/".join((directory, file))

# Opened read-only and deliberately kept open: later cells read from it through load_oneIC.
mat_file = h5py.File(path, "r")
cells_refs = mat_file['FCK_LOCKED_IC_JYOTIKA']

# Number of ICs per subject and number of subjects in the file.
n_IC = 4
n_subj = 23

## Functions:

### Computing functions:

In [None]:
# The embedx function stacks len(lags) time-shifted copies of `x` (the time-delay
# embedding used as input of the TDE-HMM) and drops the edge samples where the
# circular shift of np.roll wrapped data around.

def embedx(x, lags):
    """Return the time-delay embedding of ``x`` and the mask of the samples kept.

    ``np.roll(x, lag, axis=0)[t] == x[t - lag]``. A positive lag reads the past and
    wraps the *last* samples onto the first ``lag`` positions. A negative lag reads
    the future and wraps the *first* samples onto the last ``-lag`` positions.
    Only samples ``max(lags, 0) <= t < n_samples - max(-min(lags), 0)`` are free of
    wrapped data, and only those are kept.

    Parameters
    ----------
    x : np.ndarray
        2-D array indexed as (sample, channel).
    lags : sequence of int
        Time shifts, in samples, applied to ``x``.

    Returns
    -------
    Xe : np.ndarray
        Shape (n_channels, n_valid_samples, len(lags)). ``Xe[c, :, k]`` is channel
        ``c`` shifted by ``lags[k]``, restricted to the valid samples.
    valid : np.ndarray of int8
        Shape (n_samples, 1): 1 for samples kept in ``Xe``, 0 for dropped edges.
        Use this mask (rather than hard-coded offsets) to map per-sample results
        computed from ``Xe`` back onto the original time axis.
    """
    lags = np.asarray(lags)
    n_samples, n_channels = x.shape

    Xe = np.zeros((n_channels, n_samples, len(lags)))
    for k, lag in enumerate(lags):
        Xe[:, :, k] = np.roll(x, lag, axis=0).swapaxes(0, 1)

    # Edge samples contaminated by the circular wrap of np.roll.
    n_head = max(int(lags.max()), 0)    # wrapped in by positive lags
    n_tail = max(int(-lags.min()), 0)   # wrapped in by negative lags

    valid = np.ones((n_samples, 1), dtype=np.int8)
    valid[:n_head, :] = 0
    # Explicit stop index: a `valid[-n_tail:]` slice would select everything when n_tail == 0.
    valid[max(n_samples - n_tail, 0):, :] = 0

    Xe = Xe[:, valid[:, 0] == 1, :]

    return Xe, valid


# The hmm_tde function fits a hidden Markov model on `y`,
# then returns the probability of presence of each state at each sample.

def hmm_tde(y: np.array, n_states=3, n_iter=100, n_components=8,
            covariance_type='full', model_type='GMMHMM', tol=0.01, n_mix=1, **kwargs):
    """Fit an hmmlearn HMM on ``y`` and return its state probabilities.

    Parameters
    ----------
    y : np.ndarray
        Observations, passed unchanged to ``model.fit`` and ``model.predict_proba``.
    n_states : int
        Number of hidden states (``n_components`` of the hmmlearn model).
    n_iter : int
        Maximum number of EM iterations.
    n_components : int
        Not used here; kept so existing calls keep working.
    covariance_type : str
        Passed to GMMHMM and GaussianHMM (not to MultinomialHMM).
    model_type : {'GMMHMM', 'GaussianHMM', 'MultinomialHMM'}
        Which hmmlearn model class to use.
    tol : float
        Convergence threshold of EM.
    n_mix : int
        Number of mixture components per state (GMMHMM only).
    **kwargs
        Forwarded to the hmmlearn model constructor.

    Returns
    -------
    gamma : np.ndarray
        ``model.predict_proba(y)``: one row per sample, one column per state.
    model
        The fitted hmmlearn model.

    Raises
    ------
    ValueError
        If ``model_type`` is not one of the supported names. Previously an error
        string was returned instead, which only failed later when callers unpacked it.
    """
    if model_type == 'GMMHMM':
        model = hmm.GMMHMM(n_components=n_states, n_iter=n_iter,
                           covariance_type=covariance_type, tol=tol, n_mix=n_mix, **kwargs)

    elif model_type == 'GaussianHMM':
        model = hmm.GaussianHMM(n_components=n_states, n_iter=n_iter,
                                covariance_type=covariance_type, tol=tol, **kwargs)

    elif model_type == 'MultinomialHMM':
        model = hmm.MultinomialHMM(n_components=n_states, n_iter=n_iter, tol=tol, **kwargs)

    else:
        raise ValueError(
            f"Non-existing model_type {model_type!r}. Please choose 'GMMHMM' or "
            "'GaussianHMM' or 'MultinomialHMM'. default='GMMHMM'"
        )

    model.fit(y)
    gamma = model.predict_proba(y)

    return gamma, model

In [None]:
def show_bigstates(
    gamma, n_states, xe, # the data we need for the PSDs

    subj, IC, # which IC of which subject is of interest here

    n_components, n_lags, covariance_type, model_type, n_mix, # run description (currently unused)

    fs=256, nfft=1000, n_freqs=196, threshold=0.6,
):
    """Return the power spectral density of one IC for each HMM state.

    For each state, the samples whose probability for that state exceeds
    ``threshold`` are selected. Welch's PSD is computed for every lagged copy of
    the IC (each column of ``xe[0]``) over those samples, then averaged across lags.

    Parameters
    ----------
    gamma : np.ndarray
        State probabilities, one row per sample (row-aligned with ``xe[0]``) and
        at least ``n_states`` columns.
    n_states : int
        Number of states to process.
    xe : np.ndarray
        Time-delay embedding as returned by ``embedx``. Only channel 0 is used,
        indexed as ``xe[0, sample, lag]``.
    subj, IC, n_components, n_lags, covariance_type, model_type, n_mix
        Not used in the computation; kept so existing calls keep working.
    fs : float
        Sampling rate in Hz given to Welch (default 256).
    nfft : int
        FFT length given to Welch (default 1000, i.e. bins spaced fs/nfft Hz).
    n_freqs : int
        Number of lowest-frequency bins kept (default 196: 0 to ~50 Hz with the
        default fs and nfft).
    threshold : float
        Minimum state probability for a sample to be attributed to that state.

    Returns
    -------
    np.ndarray
        Shape (n_states, n_freqs). A state with no sample above ``threshold`` gets
        a row of NaN (its PSD is undefined) instead of making the whole call crash.
    """
    psd_all = np.full((n_states, n_freqs), np.nan)
    for state in range(n_states):
        selected = xe[0, gamma[:, state] > threshold, :]  # (n_selected_samples, n_lags)
        if selected.shape[0] == 0:
            continue  # state never confidently present: leave its row as NaN

        # One Welch PSD per lag (samples along axis 0), then the mean over lags.
        _, psd = signal.welch(x=selected, fs=fs, nfft=nfft, axis=0)
        psd_all[state] = psd.mean(axis=1)[:n_freqs]

    return psd_all

## The routine:

#### Model and parameters:
Pipeline: input (58 delays × 4 ICs) × (7 s × 256 Hz × N_trials) → PCA, applied to each IC separately → (40 components × 4 ICs) × (7 s × 256 Hz × N_trials) → TDE-HMM (1 Gaussian per state, 4 to 6 states)

In [None]:
# Tunable parameters of the pipeline.

# Time-delay embedding: 58 lags, from -29 to +28 samples.
lags = np.arange(-29, 29)
n_lags = len(lags)

# PCA (fitted on each IC's embedding separately).
n_components = 40

# Hidden Markov Model.
n_states_max = 6
model_type = 'GaussianHMM'
covariance_type = 'diag'
n_mix = 1
n_iter = 100
tol = 0.01

##### Complete script

In [None]:
# Subjects analysed: all except subjects 1, 9, 13 and 15.
subj_list = list(range(2, 9)) + list(range(10, 13)) + [14] + list(range(16, n_subj + 1))
complete_time = tm.time()
# Frequency axis of the saved PSDs: the exact bins returned by
# signal.welch(fs=256, nfft=1000) (spacing 0.256 Hz), truncated to the first 196
# (0 to ~50 Hz) as show_bigstates does. np.arange(0, 50, 50/196) used a slightly
# different spacing (~0.255 Hz) and mislabelled the bins.
freqs = np.fft.rfftfreq(1000, d=1 / 256)[:196]

for subj in subj_list:
    print(f"---- SUBJECT{subj} ----")
    # Create the input matrix for the TDE-HMM: one PCA-reduced time-delay
    # embedding per IC, concatenated along the feature axis.
    print("Computing the imput matrix for the model")
    datall = []
    xeall = []
    IC_list = []
    for IC in range(1, n_IC + 1):
        try:
            data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
            big_timecourse = np.concatenate([data['raw_timecourse_256Hz'][i] for i in range(n_trials)])
            x = big_timecourse.reshape(-1, 1)
            xe, valid = embedx(x, lags)
            pca = PCA(n_components=n_components)
            y = pca.fit_transform(xe[0, :, :])
        except Exception as err:
            # Best effort: an IC that cannot be processed is skipped, but not silently.
            print(f"IC{IC} skipped ({type(err).__name__}: {err})")
            continue
        # Append only after every step succeeded, so xeall, datall and IC_list stay aligned.
        xeall.append(xe)
        datall.append(y)
        IC_list.append(IC)
        valid_mask = valid
        print(f"IC{IC} loaded")

    if not datall:
        print(f"No IC could be loaded for subject {subj}: skipped")
        continue
    y = np.concatenate(datall, axis=1)

    # Compute the model and the states timecourses
    for n_states in range(n_states_max, 3, -1):
        start_time = tm.time()
        try:
            print(f"Computing and saving the model with {n_states} states")
            gamma, model = hmm_tde(y, n_iter=n_iter, n_states=n_states, n_components=n_components,
                                   covariance_type=covariance_type, model_type=model_type, tol=tol, n_mix=n_mix)
            print("Score:", model.score(y))
            with open(f"tde-hmm2/pkl_files/su{subj}All_lg{n_lags}co{n_components}st{n_states}"
                      + f"{n_mix}" + model_type + "_model.pkl", "wb") as pkl_file:
                pickle.dump(model, pkl_file)

            # Compute the Power Spectral Density of each state, for each IC
            print(f"Computing the PSD of each of the {n_states} states")
            psds = [
                show_bigstates(
                    gamma, n_states, xe_ic,  # the data we need for the PSDs
                    subj, ic,  # which IC of which subject is of interest here
                    n_components, n_lags, covariance_type, model_type, n_mix,
                )
                for xe_ic, ic in zip(xeall, IC_list)
            ]

            # Save the states timecourses and PSDs thanks to xarray and netCDF
            print("Saving the states timecourses and PSDs")
            # Put gamma back on the full sample grid: samples dropped at the edges by
            # embedx (valid == 0) get probability 0. Using the mask keeps this aligned
            # whatever `lags` is.
            tcourse = np.zeros((valid_mask.shape[0], n_states))
            tcourse[valid_mask[:, 0] == 1] = gamma
            time_axis = data["time_axis"]
            t_len = time_axis.shape[0]
            # The concatenated timecourse is n_trials consecutive trials of t_len samples.
            tcourse_trials = tcourse[:n_trials * t_len].reshape(n_trials, t_len, n_states)
            ds = xr.Dataset(
                {
                    "states_timecourse": (("trials", "time", "states"), tcourse_trials),
                    "states_psd": (("IC", "states", "freq"), np.stack(psds)),
                },
                {
                    "IC": IC_list,
                    "time": time_axis,
                    "states": np.arange(1, n_states + 1),
                    "freq": freqs,
                },
            )
            ds.to_netcdf(f"tde-hmm2/nc_files/su{subj}-{n_states}states_data.nc")
            print(f"subj{subj}, {n_states} states: OK")
        except Exception as err:
            # Best effort: report the failure and move on to the next model size.
            print(f"Oops! su{subj}, {n_states} states ({type(err).__name__}: {err})")
        print("%s seconds" % (tm.time() - start_time))

# Total running time, as hours / minutes / seconds.
mn_time, sec_time = divmod(tm.time() - complete_time, 60)
hr_time, mn_time = divmod(int(mn_time), 60)
print(f"{hr_time}h{mn_time}mn{sec_time}s")

##### Save the fractional occupancy, and plot it with the PSD in each IC

In [None]:
subj_list = list(range(2, 9)) + list(range(10, 13)) + [14] + list(range(16, n_subj + 1))

# For each subject and model size, this cell:
#   - adds the fractional occupancy (state probabilities summed over trials and
#     divided by the number of trials, at each time point) to the saved dataset;
#   - writes the result as *_Multi.nc;
#   - plots it next to the per-IC state PSDs.
for subj in subj_list:
    print(f"----- Subject{subj} -----")
    for n_states in range(n_states_max, 3, -1):
        # One try per model size: a missing file for one size no longer skips the others.
        try:
            # Load into memory so the source file is closed right away.
            with xr.open_dataset(f"tde-hmm2/nc_files/su{subj}-{n_states}states_data.nc") as src:
                ds = src.load()
            ds = ds.assign(frac_occ=(ds["states_timecourse"].sum("trials") / ds.sizes["trials"]))
            ds.to_netcdf(f"tde-hmm2/nc_files/su{subj}-{n_states}states_data_Multi.nc", mode="w")

            labels = [f"state {state}" for state in ds["states"].values]
            ic_values = ds["IC"].values
            n_panels = len(ic_values) + 1
            fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
            try:
                ax = axes[0, 0]
                ax.set_title('Frac. Occupancy')
                ax.plot(ds["time"], ds["frac_occ"].values)
                ax.set_xlabel('Time (s)')
                ax.set_ylabel('Frac. occupancy')
                ax.legend(labels, loc='upper left')
                for i, ic in enumerate(ic_values):
                    ax = axes[0, i + 1]
                    ax.set_title(f"Power Spectrum Density of IC{ic}")
                    ax.plot(ds["freq"], ds["states_psd"].values[i].T)
                    ax.set_xlabel('Frequency (Hz)')
                    ax.set_ylabel('PSD')
                    ax.legend(labels, loc='upper right')
                fig.savefig(f"tde-hmm2/png_files/su{subj}-frac_occ_{n_states}st_Multi.png", dpi=300)
            finally:
                plt.close(fig)  # release the figure even if plotting failed
            print(f"su{subj}, {n_states} states done")
        except Exception as err:
            print(f"Ouch! su{subj}, {n_states} states ({type(err).__name__}: {err})")

In [None]:
subj_list = list(range(2, 9)) + list(range(10, 13)) + [14] + list(range(16, n_subj + 1))

# Same post-processing for the 3-state models. For each subject, this cell:
#   - adds the fractional occupancy (state probabilities summed over trials and
#     divided by the number of trials, at each time point);
#   - writes the result as *_Multi.nc;
#   - plots it next to the per-IC state PSDs.
for subj in subj_list:
    print(f"----- Subject{subj} -----")
    for n_states in [3]:
        try:
            # Load into memory so the source file is closed right away.
            with xr.open_dataset(f"tde-hmm2/nc_files/su{subj}-{n_states}states_data.nc") as src:
                ds = src.load()
            ds = ds.assign(frac_occ=(ds["states_timecourse"].sum("trials") / ds.sizes["trials"]))
            ds.to_netcdf(f"tde-hmm2/nc_files/su{subj}-{n_states}states_data_Multi.nc", mode="w")

            labels = [f"state {state}" for state in ds["states"].values]
            ic_values = ds["IC"].values
            n_panels = len(ic_values) + 1
            fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
            try:
                ax = axes[0, 0]
                ax.set_title('Frac. Occupancy')
                ax.plot(ds["time"], ds["frac_occ"].values)
                ax.set_xlabel('Time (s)')
                ax.set_ylabel('Frac. occupancy')
                ax.legend(labels, loc='upper left')
                for i, ic in enumerate(ic_values):
                    ax = axes[0, i + 1]
                    ax.set_title(f"Power Spectrum Density of IC{ic}")
                    ax.plot(ds["freq"], ds["states_psd"].values[i].T)
                    ax.set_xlabel('Frequency (Hz)')
                    ax.set_ylabel('PSD')
                    ax.legend(labels, loc='upper right')
                fig.savefig(f"tde-hmm2/png_files/su{subj}-frac_occ_{n_states}st_Multi.png", dpi=300)
            finally:
                plt.close(fig)  # release the figure even if plotting failed
            print(f"su{subj}, {n_states} states done")
        except Exception as err:
            print(f"Ouch! su{subj}, {n_states} states ({type(err).__name__}: {err})")

##### Tests

In [None]:
subj = 2
# Exact frequency bins of signal.welch(fs=256, nfft=1000): the first 196 (0 to ~50 Hz).
freqs = np.fft.rfftfreq(1000, d=1 / 256)[:196]

# Load every available IC of subject `subj`, embed it, and reduce it with PCA.
print("Computing the imput matrix for the model")
datall = []
IC_list = []
for IC in range(1, n_IC + 1):
    try:
        data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
        big_timecourse = np.concatenate([data['raw_timecourse_256Hz'][i] for i in range(n_trials)])
        x = big_timecourse.reshape(-1, 1)
        xe, valid = embedx(x, lags)
        pca = PCA(n_components=n_components)
        y = pca.fit_transform(xe[0, :, :])
    except Exception as err:
        # Skip the IC, but say why (a bare except also swallowed Ctrl-C).
        print(f"IC{IC} skipped ({type(err).__name__}: {err})")
        continue
    datall.append(y)
    IC_list.append(IC)
    print(f"IC{IC} loaded")
y = np.concatenate(datall, axis=1)

In [None]:
# Rebuild the full (non-PCA) time-delay embedding of each IC, needed for the PSDs.
xeall = []
IC_list = []
for IC in range(1, n_IC + 1):
    try:
        data, n_trials = load_oneIC(mat_file, cells_refs, subj, IC, comp=False)
        big_timecourse = np.concatenate([data['raw_timecourse_256Hz'][i] for i in range(n_trials)])
        x = big_timecourse.reshape(-1, 1)
        xe, valid = embedx(x, lags)
    except Exception as err:
        # Skip the IC, but say why (a bare except also swallowed Ctrl-C).
        print(f"IC{IC} skipped ({type(err).__name__}: {err})")
        continue
    xeall.append(xe)
    IC_list.append(IC)
    print(f"IC{IC} loaded")
print(len(xeall), xeall[0].shape if xeall else None)

In [None]:
# Fit the TDE-HMM on `y` and pickle the fitted model.
# NOTE(review): `n_states` is not set in this test section; it keeps whatever value
# an earlier cell left behind. Confirm it before running.
print("Computing and saving the model")
hmm_kwargs = dict(n_iter=n_iter, n_states=n_states, n_components=n_components,
                  covariance_type=covariance_type, model_type=model_type, tol=tol, n_mix=n_mix)
gamma, model = hmm_tde(y, **hmm_kwargs)
pkl_name = (f"tde-hmm2/pkl_files/su{subj}All_lg{n_lags}co{n_components}st{n_states}"
            f"{n_mix}{model_type}_model.pkl")
with open(pkl_name, "wb") as file:
    pickle.dump(model, file)

In [None]:
# PSD of every HMM state for each IC; each entry gets a leading axis of length 1
# so the list can be concatenated along an "IC" dimension.
print("Computing the PSD of each state")
psds = [
    show_bigstates(
        gamma, n_states, xeall[k],  # the data we need for the PSDs
        subj, ic,  # which IC of which subject is of interest here
        n_components, n_lags, covariance_type, model_type, n_mix,
    )[np.newaxis,]
    for k, ic in enumerate(IC_list)
]

In [None]:
np.shape(gamma)  # (samples kept by embedx, n_states)

In [None]:
n_psds = len(psds)  # one entry per loaded IC
n_psds

In [None]:
np.shape(psds[0])  # (1, n_states, n_freqs)

In [None]:
# Map the state probabilities back onto the full sample grid and cut them into trials.
print("Saving the states timecourses and PSDs")
# Samples dropped at the edges by embedx (valid == 0) get probability 0. Using the
# mask instead of hard-coded 29/28 zero padding keeps this aligned whatever `lags` is.
tcourse = np.zeros((valid.shape[0], n_states))
tcourse[valid[:, 0] == 1] = gamma
time_axis = data["time_axis"]
t_len = time_axis.shape[0]
# The concatenated timecourse is n_trials consecutive trials of t_len samples.
tcourse_trials = tcourse[:n_trials * t_len].reshape(n_trials, t_len, n_states)

In [None]:
np.shape(tcourse_trials)  # (n_trials, t_len, n_states)

In [None]:
list(IC_list)  # ICs successfully loaded for this subject

In [None]:
np.shape(time_axis)  # samples per trial

In [None]:
np.shape(freqs)  # must match the last axis of the PSDs

In [None]:
# Bundle the state timecourses and the per-IC state PSDs into one labelled
# dataset, then write it to netCDF.
data_vars = {
    "states_timecourse": (("trials", "time", "states"), tcourse_trials),
    "states_psd": (("IC", "states", "freq"), np.concatenate(psds)),
}
coords = {
    "IC": IC_list,
    "time": time_axis,
    "states": np.arange(1, n_states + 1),
    "freq": freqs,
}
ds = xr.Dataset(data_vars, coords)
ds.to_netcdf(f"tde-hmm2/nc_files/su{subj}-{n_states}states_data.nc")

In [None]:
ds  # display the Dataset (dimensions, coordinates, variables) in the notebook

In [None]:
# Post-processing for this single test subject:
#   - add the fractional occupancy (state probabilities summed over trials and
#     divided by the number of trials);
#   - save the result as *_Multi.nc;
#   - plot it next to the per-IC state PSDs.
with xr.open_dataset(f"tde-hmm2/nc_files/su{subj}-{n_states}states_data.nc") as src:
    ds = src.load()  # read into memory so the source file is closed, even on error
ds = ds.assign(frac_occ=(ds["states_timecourse"].sum("trials") / ds.sizes["trials"]))
ds.to_netcdf(f"tde-hmm2/nc_files/su{subj}-{n_states}states_data_Multi.nc", mode="w")

labels = [f"state {state}" for state in ds["states"].values]
ic_values = ds["IC"].values
n_panels = len(ic_values) + 1
fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
try:
    ax = axes[0, 0]
    ax.set_title('Frac. Occupancy')
    ax.plot(ds["time"], ds["frac_occ"].values)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frac. occupancy')
    ax.legend(labels, loc='upper left')
    for i, ic in enumerate(ic_values):
        ax = axes[0, i + 1]
        ax.set_title(f"Power Spectrum Density of IC{ic}")
        ax.plot(ds["freq"], ds["states_psd"].values[i].T)
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('PSD')
        ax.legend(labels, loc='upper right')
    fig.savefig(f"tde-hmm2/png_files/su{subj}-frac_occ_Multi.png", dpi=300)
finally:
    plt.close(fig)  # release the figure even if plotting failed
print(f"su{subj}")