In [1]:
%load_ext autoreload
%autoreload 2
# Get paths of training data
import glob
from pathlib import Path
import pyedflib
from hmpai.pytorch.generators import MultiNumpyDataset, worker_init_fn
from hmpai.pytorch.pretraining import random_masking
import pandas as pd
import numpy as np
import multiprocessing
import os
from tqdm.notebook import tqdm
import pickle
from hmpai.pytorch.models import *
from hmpai.pytorch.training import train_and_test, pretrain
from sklearn.model_selection import train_test_split
import itertools
import h5py
DATA_PATH = Path(os.getenv("DATA_PATH"))

### Create split_edf_files

In [2]:
# Optional, if split_edf_files.txt does not exist yet
path = Path("../../data/tueg_split")

files = path.glob("*.h5")

with open('split_edf_files.txt', 'w') as file:
    for f_name in files:
        file.write(str(f_name) + "\n")

### Load split_edf_files

In [2]:
# Assumes files are already gathered in split_edf_files.txt
with open('split_edf_files.txt', 'r') as file:
    files = file.readlines()
files = [file.rstrip('\n') for file in files]
# files = files[:100]

### Create index map

In [3]:
gen = MultiNumpyDataset(data_paths=files)

In [4]:
# Save gen.index_map to file
gen.index_map.to_csv('index_map.csv', index=False)

In [5]:
index_map = pd.read_csv('index_map.csv')

### Load index map

In [3]:
index_map = pd.read_csv('index_map.csv')
# index_map = gen.index_map
idx_train, idx_val = train_test_split(index_map, test_size=0.2, random_state=42)
idx_train = idx_train.reset_index(drop=True)
idx_val = idx_val.reset_index(drop=True)
train_data = MultiNumpyDataset(data_paths=files, index_map=idx_train)
val_data = MultiNumpyDataset(data_paths=files, index_map=idx_val)

In [6]:
file = h5py.File(files[0])

In [15]:
index_map = {'path': [], 'participant': [], 'session': [], 'n_samples': []}
for participant in file['participants']:
    for session in file[f'participants/{participant}/sessions']:
        # print(participant, session)
        for dataset in file[f'participants/{participant}/sessions/{session}']:
            ds = file[f'participants/{participant}/sessions/{session}/{dataset}']
            print(ds.shape, ds.chunks)

(6148, 19, 150) (385, 2, 19)
(2208, 19, 150) (276, 3, 19)
(2213, 19, 150) (277, 3, 19)
(2203, 19, 150) (276, 3, 19)
(2210, 19, 150) (277, 3, 19)
(2279, 19, 150) (143, 3, 19)
(2148, 19, 150) (269, 3, 19)
(2214, 19, 150) (277, 3, 19)
(2370, 19, 150) (149, 3, 19)
(2381, 19, 150) (149, 3, 19)
(2300, 19, 150) (144, 3, 19)
(2203, 19, 150) (276, 3, 19)
(2370, 19, 150) (149, 3, 19)
(2510, 19, 150) (157, 3, 19)
(2020, 19, 150) (253, 3, 19)
(2143, 19, 150) (268, 3, 19)
(2182, 19, 150) (273, 3, 19)
(2051, 19, 150) (257, 3, 19)
(2060, 19, 150) (258, 3, 19)
(2016, 19, 150) (252, 3, 19)
(2005, 19, 150) (251, 3, 19)
(2416, 19, 150) (151, 3, 19)
(2016, 19, 150) (252, 3, 19)
(5355, 19, 150) (335, 3, 19)
(9489, 19, 150) (594, 2, 19)
(2046, 19, 150) (256, 3, 19)
(9761, 19, 150) (611, 2, 19)
(6357, 19, 150) (398, 2, 19)
(2206, 19, 150) (276, 3, 19)
(2170, 19, 150) (272, 3, 19)
(2003, 19, 150) (251, 3, 19)
(2263, 19, 150) (142, 3, 19)
(2631, 19, 150) (165, 3, 19)
(2470, 19, 150) (155, 3, 19)
(2205, 19, 150

### Pre-train

In [4]:
model = Seq2SeqTransformer(d_model=19, ff_dim=2048, num_heads=8, num_layers=6, num_classes=0, emb_dim=128)



In [4]:
%load_ext line_profiler

In [5]:
import random
def random_samples():
    for idx in random.sample(range(len(train_data)), 32):
        train_data.__getitem__(idx)
%lprun -f train_data.__getitem__ random_samples()
# 30.7494s


Timer unit: 1e-09 s

Total time: 29.8634 s
File: /workspace/hmp-ai/src/hmpai/pytorch/generators.py
Function: __getitem__ at line 209

Line #      Hits         Time  Per Hit   % Time  Line Contents
   209                                               def __getitem__(self, idx):
   210        32    7448004.0 232750.1      0.0          info_idx = self._find_file_idx(idx)
   211        32   13791608.0 430987.8      0.0          info_row = self.index_map.iloc[info_idx]
   212        32    1200396.0  37512.4      0.0          file_path = info_row['path']
   213                                           
   214        32  726734532.0    2e+07      2.4          file = self._get_dataset(file_path)
   215                                           
   216        32    1906661.0  59583.2      0.0          sample_idx = idx if info_idx == 0 else idx - self.cumulative_sizes[info_idx]
   217        32        3e+10    9e+08     97.5          data = file[f'participants/{info_row["participant"]}/sessions

In [36]:
def subsequent_samples():
    for idx in range(1000000, 1000000+32):
        train_data.__getitem__(idx)
%lprun -f train_data.__getitem__ subsequent_samples()
# 0.419811s

Timer unit: 1e-09 s

Total time: 0.419811 s
File: /workspace/hmp-ai/src/hmpai/pytorch/generators.py
Function: __getitem__ at line 209

Line #      Hits         Time  Per Hit   % Time  Line Contents
   209                                               def __getitem__(self, idx):
   210        32    2463816.0  76994.2      0.6          info_idx = self._find_file_idx(idx)
   211        32    5260257.0 164383.0      1.3          info_row = self.index_map.iloc[info_idx]
   212        32     496066.0  15502.1      0.1          file_path = info_row['path']
   213                                           
   214        32      44235.0   1382.3      0.0          file = self._get_dataset(file_path)
   215                                           
   216        32     294342.0   9198.2      0.1          sample_idx = idx if info_idx == 0 else idx - self.cumulative_sizes[info_idx]
   217        32  411132315.0    1e+07     97.9          data = file[f'participants/{info_row["participant"]}/session

In [11]:
from torch.utils.data import DataLoader, Dataset

train_loader = DataLoader(
    train_data, 16, shuffle=True, num_workers=0, worker_init_fn=worker_init_fn
)

In [13]:
# Just gathering one batch takes 1m
import cProfile
import pstats
profiler = cProfile.Profile()
profiler.enable()
for batch in train_loader:
    print(batch)
    print(batch.shape)
    break
profiler.disable()
stats = pstats.Stats(profiler).sort_stats('cumtime')
stats.print_stats()

tensor([[[-5.3684e+00,  4.2105e+00,  1.4754e+01,  ...,  2.3860e+00,
          -9.4737e-01, -8.2456e-01],
         [-3.6842e+00,  2.4561e+00,  1.4298e+01,  ...,  9.8246e-01,
          -2.6140e+00, -1.2105e+00],
         [-1.6491e+00,  4.1404e+00,  1.2982e+01,  ...,  3.0877e+00,
          -2.5088e+00, -1.6842e+00],
         ...,
         [ 4.7193e+00,  1.2737e+01,  7.0000e+01,  ...,  4.2105e-01,
          -3.5088e-01,  1.9123e+00],
         [ 3.8421e+00,  1.5421e+01,  7.1491e+01,  ...,  3.8596e-01,
           1.4211e+00,  1.2807e+00],
         [ 6.1579e+00,  1.5702e+01,  7.0140e+01,  ..., -6.1403e-01,
          -7.0175e-02,  1.0526e+00]],

        [[-1.5433e+00, -3.9370e-01,  1.3386e-01,  ..., -2.1890e+00,
          -7.7165e-01, -9.8425e-01],
         [-1.3858e+00, -2.9134e-01,  2.6772e-01,  ..., -2.1890e+00,
          -8.3465e-01, -9.2126e-01],
         [-1.0866e+00, -1.1811e-01,  3.5433e-01,  ..., -2.1339e+00,
          -9.4488e-01, -8.5039e-01],
         ...,
         [-8.3465e-01, -7

<pstats.Stats at 0x7facedc6b880>

In [5]:
pretrain(model, train_data, val_data, batch_size=128, workers=8, pretrain_fn=random_masking)

  0%|          | 0/922390 [00:00<?, ? batch/s]

KeyboardInterrupt: 