In [1]:
%load_ext autoreload
%autoreload 2
# Get paths of training data
import glob
from pathlib import Path
import pyedflib
from hmpai.pytorch.generators import MultiNumpyDataset, worker_init_fn
from hmpai.pytorch.pretraining import random_masking
import pandas as pd
import numpy as np
import multiprocessing
import os
from tqdm.notebook import tqdm
import pickle
from hmpai.pytorch.models import *
from hmpai.pytorch.training import train_and_test, pretrain
from sklearn.model_selection import train_test_split
import itertools
import h5py
DATA_PATH = Path(os.getenv("DATA_PATH"))

### Create split_edf_files

In [2]:
# Optional, if split_edf_files.txt does not exist yet
path = Path("../../data/tueg_split")

files = path.glob("*.h5")

with open('split_edf_files.txt', 'w') as file:
    for f_name in files:
        file.write(str(f_name) + "\n")

### Load split_edf_files

In [2]:
# Assumes files are already gathered in split_edf_files.txt
with open('split_edf_files.txt', 'r') as file:
    files = file.readlines()
files = [file.rstrip('\n') for file in files]
# files = files[:100]

### Create index map

In [4]:
gen = MultiNumpyDataset(data_paths=files)

In [5]:
# Save gen.index_map to file
gen.index_map.to_csv('index_map.csv', index=False)

In [6]:
index_map = pd.read_csv('index_map.csv')

### Load index map

In [3]:
index_map = pd.read_csv('index_map.csv')
# index_map = gen.index_map
idx_train, idx_val = train_test_split(index_map, test_size=0.2, random_state=42)
idx_train = idx_train.reset_index(drop=True)
idx_val = idx_val.reset_index(drop=True)
train_data = MultiNumpyDataset(data_paths=files, index_map=idx_train)
val_data = MultiNumpyDataset(data_paths=files, index_map=idx_val)

### Pre-train

In [9]:
model = Seq2SeqTransformer(d_model=19, ff_dim=512, num_heads=8, num_layers=6, num_classes=0, emb_dim=128)



In [13]:
%load_ext line_profiler

In [15]:
import random
def random_samples():
    for idx in random.sample(range(len(train_data)), 32):
        train_data.__getitem__(idx)
%lprun -f train_data.__getitem__ random_samples()
# 30.7494s
# 5.20681s after chunking

Timer unit: 1e-09 s

Total time: 5.20681 s
File: /workspace/hmp-ai/src/hmpai/pytorch/generators.py
Function: __getitem__ at line 209

Line #      Hits         Time  Per Hit   % Time  Line Contents
   209                                               def __getitem__(self, idx):
   210        32    7033860.0 219808.1      0.1          info_idx = self._find_file_idx(idx)
   211        32   13840614.0 432519.2      0.3          info_row = self.index_map.iloc[info_idx]
   212        32    1274335.0  39823.0      0.0          file_path = info_row['path']
   213                                           
   214        32  525046617.0    2e+07     10.1          file = self._get_dataset(file_path)
   215                                           
   216        32    1835132.0  57347.9      0.0          sample_idx = idx if info_idx == 0 else idx - self.cumulative_sizes[info_idx]
   217        32 4657367460.0    1e+08     89.4          data = file[f'participants/{info_row["participant"]}/sessions

In [17]:
def subsequent_samples():
    for idx in range(1000000, 1000000+32):
        train_data.__getitem__(idx)
%lprun -f train_data.__getitem__ subsequent_samples()
# 0.419811s
# 1.13611s after chunking

Timer unit: 1e-09 s

Total time: 1.13611 s
File: /workspace/hmp-ai/src/hmpai/pytorch/generators.py
Function: __getitem__ at line 209

Line #      Hits         Time  Per Hit   % Time  Line Contents
   209                                               def __getitem__(self, idx):
   210        32    3911967.0 122249.0      0.3          info_idx = self._find_file_idx(idx)
   211        32    8029890.0 250934.1      0.7          info_row = self.index_map.iloc[info_idx]
   212        32     793383.0  24793.2      0.1          file_path = info_row['path']
   213                                           
   214        32      67999.0   2125.0      0.0          file = self._get_dataset(file_path)
   215                                           
   216        32     360236.0  11257.4      0.0          sample_idx = idx if info_idx == 0 else idx - self.cumulative_sizes[info_idx]
   217        32 1122751157.0    4e+07     98.8          data = file[f'participants/{info_row["participant"]}/sessions

In [13]:
pretrain(model, train_data, val_data, batch_size=128, workers=14, pretrain_fn=random_masking)
# ~2 batch/s with 8 workers, converges to same with 14

  0%|          | 0/937050 [00:00<?, ? batch/s]

tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(F

KeyboardInterrupt: 