In [1]:

import numpy as np
import os
import usleep
import typing

from IPython.display import clear_output


In [2]:

# In[82]:


# ARGS

class ARGS(object):
    """Prediction settings for running U-Sleep on an EDF recording.

    NOTE(review): the attribute names appear to mirror the options of utime's
    ``predict_one`` script; confirm against that script before relying on it.

    Args:
        file: Path to a PSG ``.edf`` file. Relative paths are made absolute
            against the current working directory.
        data_per_prediction: Stored as ``self.data_per_prediction``
            (default 128).
    """

    def __init__(self, file, data_per_prediction: int = 128):
        self.f = os.path.abspath(file) if not os.path.isabs(file) else file
        # Output (.npy) and log (.log) paths sit next to the input file; only a
        # trailing ".edf" extension is swapped, never ".edf" elsewhere in the path.
        self.o = self._swap_edf_suffix(self.f, ".npy")
        self.logging_out_path = self._swap_edf_suffix(self.f, ".log")

        self.auto_channel_grouping =  ['EOG', 'EEG']
        self.auto_reference_types  =  None
        # "<channel>==<group>" entries, e.g. O1-M2 assigned to the EEG group.
        self.channels              =  ['O1-M2==EEG', 'O2-M1==EEG', 'E1-M1==EOG', 'E2-M1==EOG']
        # Honour the constructor argument instead of a hard-coded value.
        self.data_per_prediction   =  data_per_prediction
        self.force_gpus            =  ''
        self.header_file_name      =  None
        # "<model name>:<model version>"
        self.model                 =  'u-sleep:1.0'
        self.no_argmax             =  True
        self.num_gpus              =  0
        self.overwrite             =  True
        model_parts = self.model.split(":")
        self.project_dir           =  usleep.get_model_path(model_name=model_parts[0], model_version=model_parts[-1])
        self.strip_func            =  'trim_psg_trailing'
        self.weights_file_name     =  None

    @staticmethod
    def _swap_edf_suffix(path: str, new_suffix: str) -> str:
        """Return ``path`` with a trailing ``.edf`` replaced by ``new_suffix``; other paths are returned unchanged."""
        if path.endswith(".edf"):
            return path[:-len(".edf")] + new_suffix
        return path


# In[83]:


from utime import Defaults
from utime.hyperparameters import YAMLHParams

# Load arguments and hyperparameters.
# `args` is built from one example recording, but the later cells only read its
# shared settings (channels, strip_func, project_dir, data_per_prediction, ...),
# not its per-file paths — NOTE(review): confirm before relying on args.f/args.o.
# The hyperparameters come from the hparams file of the model project directory.
args = ARGS(file="../edf_data/9JQY.edf", data_per_prediction=128)
hparams = YAMLHParams(Defaults.get_hparams_path(args.project_dir), no_version_control=True)


# In[84]:


from psg_utils.dataset.sleep_study import SleepStudy
from utime.bin.predict_one import get_sleep_study

def get_and_load_study(file, args: ARGS, hparams: YAMLHParams) -> SleepStudy:
    """Load and pre-process one PSG file into a SleepStudy.

    The notebook settings in ``args`` (channel list and strip function) are
    written into ``hparams['prediction_params']`` — this mutates the caller's
    ``hparams`` — and all prediction params are then forwarded to utime's
    ``get_sleep_study`` together with the header/channel-grouping options.

    Args:
        file: Path to the PSG file to load.
        args: Notebook settings (channels, strip_func, header_file_name,
            auto_channel_grouping, auto_reference_types).
        hparams: Model hyperparameters; ``hparams['prediction_params']`` is
            updated in place.

    Returns:
        The loaded study, with the channel groups returned by
        ``get_sleep_study`` attached as ``study.channel_groups``.
    """
    print(f"Loading and pre-processing PSG file {file}...")

    pred_params = hparams['prediction_params']
    pred_params['channels'] = args.channels
    pred_params['strip_func']['strip_func_str'] = args.strip_func

    loader_options = dict(
        header_file_name=args.header_file_name,
        auto_channel_grouping=args.auto_channel_grouping,
        auto_reference_types=args.auto_reference_types,
    )
    loaded_study, groups = get_sleep_study(psg_path=file, **loader_options, **pred_params)

    loaded_study.channel_groups = groups
    return loaded_study


# In[85]:


from utime.bin.evaluate import get_and_load_model, get_and_load_one_shot_model
from keras import Model
from keras.layers import Input

    
# Build the base U-Sleep model from the project directory via utime's
# get_and_load_model (presumably also loading weights from weights_file_name).
# NOTE(review): `base` is not referenced by any later cell shown here — the
# prediction loop builds its own one-shot model per study — so this cell may be
# redundant; confirm before removing.
base = get_and_load_model(
            project_dir=args.project_dir,
            hparams=hparams,
            weights_file_name=hparams.get('weights_file_name')
        )
clear_output(wait=False)    # Clear the cell output to hide the Glorot initialization warning.



# In[ ]:

def predict_on_study(mdl, study: "SleepStudy"):
    """Run the model on every channel group of a study.

    Args:
        mdl: Model exposing ``output_shape`` and ``predict_on_batch``. The
            result rows per group are ``study.n_periods * output_shape[-2]``
            and columns are ``output_shape[-1]`` (presumably predictions per
            period and number of sleep stages — confirm).
        study: Loaded study with a ``channel_groups`` attribute (each group
            exposing ``channel_indices``) and a ``get_all_periods()`` method
            returning the PSG with channels on the last axis.

    Returns:
        np.ndarray of shape ``(len(study.channel_groups),
        study.n_periods * mdl.output_shape[-2], mdl.output_shape[-1])`` —
        one output matrix per channel group.
    """
    n_per_period, n_classes = mdl.output_shape[-2], mdl.output_shape[-1]
    prob = np.empty([len(study.channel_groups), study.n_periods * n_per_period, n_classes])
    if not study.channel_groups:
        # Nothing to predict; avoid loading the PSG at all.
        return prob

    # The full PSG (with a leading batch axis of size 1) does not depend on the
    # channel group, so load it once rather than once per group.
    psg = np.expand_dims(study.get_all_periods(), 0)

    for i, channel_group in enumerate(study.channel_groups):
        # Keep only this group's channels (last axis).
        psg_subset = psg[..., list(channel_group.channel_indices)]
        prob_i = mdl.predict_on_batch(psg_subset)
        prob[i, ...] = prob_i.reshape(-1, n_classes)

    return prob

In [3]:
import glob

# glob's result order depends on the file system, so sort it to make the study
# order (and hence the order of the ids/probs collected later) reproducible.
files = sorted(glob.glob("../edf_data/*.edf"))
if not files:
    # Fail loudly instead of silently running the rest of the notebook on nothing.
    raise FileNotFoundError("No .edf files found matching ../edf_data/*.edf")

# Load and pre-process every recording.
studies = [get_and_load_study(x, args, hparams) for x in files]
clear_output(wait=False)  # hide the per-file "Loading and pre-processing ..." messages

In [42]:
from utime.bin.evaluate import get_and_load_model, get_and_load_one_shot_model

ids = []
probs = []
# Build the one-shot models with the configured data_per_prediction.
hparams["build"]["data_per_prediction"] = args.data_per_prediction


for s in studies:
    print(s.psg_file_path)
    # The one-shot model is sized by the study's n_periods, so it is rebuilt
    # (and its weights loaded) for every study.
    mdl = get_and_load_one_shot_model(
                n_periods=s.n_periods,
                project_dir=args.project_dir,
                hparams=hparams,
                # .get(): consistent with the base-model cell; a missing key means "no explicit weights file".
                weights_file_name=hparams.get('weights_file_name'))

    prob = predict_on_study(mdl, s)

    # Recording id = file name without directory or extension. os.path handles
    # the platform's separators; splitting on "\\" only worked on Windows.
    ids.append(os.path.splitext(os.path.basename(s.psg_file_path))[0])
    probs.append(prob)

c:\code\mU-Sleep\edf_data\0ncr.edf
c:\code\mU-Sleep\edf_data\0pai.edf
c:\code\mU-Sleep\edf_data\3J4W.edf
c:\code\mU-Sleep\edf_data\3P0D.edf
c:\code\mU-Sleep\edf_data\40kO.edf
c:\code\mU-Sleep\edf_data\5bSg.edf
c:\code\mU-Sleep\edf_data\6iwd.edf
c:\code\mU-Sleep\edf_data\6JVj.edf
c:\code\mU-Sleep\edf_data\9098.edf
c:\code\mU-Sleep\edf_data\9JQY.edf
c:\code\mU-Sleep\edf_data\Af8a.edf
c:\code\mU-Sleep\edf_data\AsLD.edf
c:\code\mU-Sleep\edf_data\AXbm.edf
c:\code\mU-Sleep\edf_data\bkx9.edf
c:\code\mU-Sleep\edf_data\BSvO.edf
c:\code\mU-Sleep\edf_data\C1Wu.edf
c:\code\mU-Sleep\edf_data\cblr.edf
c:\code\mU-Sleep\edf_data\csxQ.edf
c:\code\mU-Sleep\edf_data\d3ET.edf
c:\code\mU-Sleep\edf_data\ddTG.edf
c:\code\mU-Sleep\edf_data\DjrT.edf
c:\code\mU-Sleep\edf_data\Dr51.edf
c:\code\mU-Sleep\edf_data\DSfb.edf
c:\code\mU-Sleep\edf_data\DYYI.edf
c:\code\mU-Sleep\edf_data\EHED.edf
c:\code\mU-Sleep\edf_data\EMcQ.edf
c:\code\mU-Sleep\edf_data\EyTS.edf
c:\code\mU-Sleep\edf_data\f8H5.edf
c:\code\mU-Sleep\edf

In [49]:
from scipy.io import savemat

# Write one MATLAB file per recording, named after its id, holding the variable "probs".
assert len(ids) == len(probs), "ids and probs are out of sync"
out_dir = os.path.join("Matlab", "probs")
os.makedirs(out_dir, exist_ok=True)  # savemat does not create missing directories
for rec_id, rec_probs in zip(ids, probs):
    savemat(os.path.join(out_dir, f"{rec_id}.mat"), mdict={"probs": rec_probs})