In [28]:
import os
import torch
import torchvision
import numpy as np
import pickle
import mne
import scipy
from scipy.signal import welch
import matplotlib.pyplot as plt
from scipy.io import loadmat
from sklearn.decomposition import PCA

from osl_dynamics.data import Data
from osl_dynamics.inference import tf_ops
from osl_dynamics.models.hmm import Config, Model






In [26]:
# load data
# Read the subject-1 recording from disk, prepare it (n_embeddings=15,
# reduced to 80 PCA components) and save the prepared arrays to `data_dir`
# so the training cell can reload them without repeating the preparation.
# time_axis_first=False: the saved array's first axis is not time
# (presumably channels-first — TODO confirm against the preprocessing code).
subject_dir = os.path.join('/', 'well', 'woolrich', 'projects', 'cichy118_cont', 'preproc_data_osl', 'subj1')
paths = [os.path.join(subject_dir, 'osl_dynamics100hz.npy')]
data_dir = os.path.join(subject_dir, 'hmm', '100hz')

data = Data(paths, time_axis_first=False)
data.prepare(n_embeddings=15, n_pca_components=80)
data.save(data_dir)

HBox(children=(HTML(value='Loading files'), FloatProgress(value=0.0, max=1.0), HTML(value='')))

Calculating PCA components:   0%|          | 0/1 [00:00<?, ?it/s]




Calculating PCA components: 100%|██████████| 1/1 [00:07<00:00,  7.20s/it]
2023-02-28 15:07:06 INFO osl-dynamics: Explained variance: 66.3%


HBox(children=(HTML(value='Preparing data'), FloatProgress(value=0.0, max=1.0), HTML(value='')))

Saving data: 100%|██████████| 1/1 [00:00<00:00,  5.12it/s]







In [36]:
# train HMM
# Directory to save the trained model and its training history to
model_dir = os.path.join('..', 'results', 'cichy_epoched', 'subj1', 'data', '100hz')

# GPU settings (presumably enables on-demand GPU memory growth — see
# osl_dynamics.inference.tf_ops.gpu_growth)
tf_ops.gpu_growth()

# Settings
# - If you run out of memory you can reduce the sequence_length
#   and/or batch_size.
# - You might want to play around with the learning rate.
# - Pick the parameters that give you the best final training loss.
# - You also want to show your results are robust to the choice
#   for n_states.
# - n_channels must equal the n_pca_components used when the data was
#   prepared (80 in the data-preparation cell); change both together.
config = Config(
    n_states=12,
    n_channels=80,
    sequence_length=2000,
    learn_means=False,
    learn_covariances=True,
    learn_trans_prob=True,
    batch_size=32,
    learning_rate=1e-2,
    n_epochs=20,
)

# Load the prepared data
# - pass the directory written by Data.save() in the data-preparation cell
#   (`data_dir`)
training_data = Data(data_dir)

# Build the model
model = Model(config)
model.summary()

# Train the model
print("Training model")
history = model.fit(training_data)

# Save the trained model
model.save(model_dir)

# Save the training history (contains the loss as a function of epochs).
# Make sure the directory exists, and use a context manager so the file is
# flushed and closed deterministically instead of relying on garbage
# collection of an anonymous file handle.
os.makedirs(model_dir, exist_ok=True)
with open(os.path.join(model_dir, "history.pkl"), "wb") as history_file:
    pickle.dump(history, history_file)

HBox(children=(HTML(value='Loading files'), FloatProgress(value=0.0, max=1.0), HTML(value='')))


Model: "HMM-Obs"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
inputs (InputLayer)             [(None, 2000, 92)]   0                                            
__________________________________________________________________________________________________
tf_op_layer_split_1 (TensorFlow [(None, 2000, 80), ( 0           inputs[0][0]                     
__________________________________________________________________________________________________
means (VectorsLayer)            (12, 80)             960         tf_op_layer_split_1[0][0]        
__________________________________________________________________________________________________
covs (CovarianceMatricesLayer)  (12, 80, 80)         38880       tf_op_layer_split_1[0][0]        
___________________________________________________________________________________________

In [None]:
# compute stats