In [1]:
%load_ext autoreload
%autoreload 2
# Imports, then resolve the training-data root (DATA_PATH) from the environment
import glob
from pathlib import Path
import pyedflib
from hmpai.pytorch.generators import MultiH5pyDataset, worker_init_fn, MultiNumpyDataset
from hmpai.pytorch.pretraining import random_masking
import pandas as pd
import numpy as np
import multiprocessing
import os
from tqdm.notebook import tqdm
import pickle
from hmpai.pytorch.models import *
from hmpai.pytorch.training import train_and_test, pretrain
from sklearn.model_selection import train_test_split
from hmpai.training import split_index_map_tueg
import itertools
import h5py
# Root directory of the data, taken from the DATA_PATH environment variable.
# os.getenv returns None when the variable is unset, and Path(None) would then
# fail with a TypeError that never mentions DATA_PATH, so check explicitly and
# name the missing variable in the error.
_data_path_env = os.getenv("DATA_PATH")
if _data_path_env is None:
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; "
        "set it to the data root directory before running this notebook."
    )
DATA_PATH = Path(_data_path_env)

### Create split_edf_files

In [86]:
# Optional, if split_edf_files.txt does not exist yet
path = DATA_PATH / 'tueg'

# Materialise the glob into a sorted list of strings:
#  - directory iteration order is filesystem-dependent, so sorting makes the
#    written list reproducible across machines and re-runs;
#  - a bare generator would be exhausted by the write below, leaving `files`
#    empty for any later cell that reads it without re-loading the text file.
files = sorted(str(f_name) for f_name in path.glob("*.csv"))

# One path per line, the format the loading cell expects.
with open('split_edf_files.txt', 'w') as file:
    file.writelines(f_name + "\n" for f_name in files)

### Load split_edf_files

In [2]:
# Assumes split_edf_files.txt already exists (one path per line).
# Iterate the handle line by line and drop each line's trailing newline.
with open('split_edf_files.txt', 'r') as listing:
    files = [entry.rstrip('\n') for entry in listing]
# files = files[:100]

### Create index map

In [3]:
# Display the loaded file paths as a quick sanity check
files

['/workspace/data_local/tueg/1034509_624.csv',
 '/workspace/data_local/tueg/1034507_597.csv',
 '/workspace/data_local/tueg/1034506_624.csv',
 '/workspace/data_local/tueg/1034509_303.csv',
 '/workspace/data_local/tueg/1034510_483.csv',
 '/workspace/data_local/tueg/1034510_624.csv',
 '/workspace/data_local/tueg/1034507_624.csv',
 '/workspace/data_local/tueg/1034505_624.csv',
 '/workspace/data_local/tueg/1034508_408.csv',
 '/workspace/data_local/tueg/1034508_624.csv',
 '/workspace/data_local/tueg/1034504_624.csv']

In [3]:
# Build a dataset over every listed file. No index_map is passed here; the
# following cells read gen.index_map and save it to index_map.csv.
# NOTE(review): presumably the dataset derives the index map itself when none
# is supplied — confirm against MultiNumpyDataset.
gen = MultiNumpyDataset(data_paths=files)

In [4]:
# Inspect the dataset's index map (rendered as a table in the cell output)
gen.index_map

Unnamed: 0,participant,session,path,offset,length,cumulative
0,aaaaabka,s005,../../data/tueg_split_npy/1034509_624.npy,0,512,512
1,aaaaabka,s005,../../data/tueg_split_npy/1034509_624.npy,512,1025,1537
2,aaaaabka,s005,../../data/tueg_split_npy/1034509_624.npy,1537,1025,2562
3,aaaaabka,s005,../../data/tueg_split_npy/1034509_624.npy,2562,512,3074
4,aaaaabka,s005,../../data/tueg_split_npy/1034509_624.npy,3074,1025,4099
...,...,...,...,...,...,...
491,aaaaaakk,s001,../../data/tueg_split_npy/1034504_624.npy,649131,2203,6611642
492,aaaaaakl,s001,../../data/tueg_split_npy/1034504_624.npy,651334,2011,6613653
493,aaaaaakm,s001,../../data/tueg_split_npy/1034504_624.npy,653345,2148,6615801
494,aaaaaakm,s002,../../data/tueg_split_npy/1034504_624.npy,655493,2145,6617946


In [91]:
# Save gen.index_map to index_map.csv so later cells can reload it with
# pd.read_csv; index=False keeps the DataFrame's row index out of the file.
gen.index_map.to_csv('index_map.csv', index=False)

In [5]:
# Reload the saved index map from disk (the same read is repeated at the top
# of the train/validation split cell).
index_map = pd.read_csv('index_map.csv')

### Load index map

In [3]:
# Read the saved index map back from disk
index_map = pd.read_csv('index_map.csv')
# index_map = index_map[:100]
# Split on participants (train_percentage=80), renumbering the rows of each
# part from 0 as it is unpacked.
idx_train, idx_val = (
    part.reset_index(drop=True)
    for part in split_index_map_tueg(index_map, train_percentage=80)
)
# One dataset per split; both are given the same list of data files.
train_data = MultiNumpyDataset(index_map=idx_train, data_paths=files)
val_data = MultiNumpyDataset(index_map=idx_val, data_paths=files)

### Pre-train

In [5]:
# model = Seq2SeqTransformer(d_model=19, ff_dim=512, num_heads=8, num_layers=6, num_classes=0, emb_dim=256)
# NOTE(review): the four positional arguments are undocumented here — confirm
# their meaning against MambaModel's signature.
model = MambaModel(256, 19, 0, 5, global_pool=False, dropout=0.1)
model.pretraining = True
# 8 workers, 8.5 b/s

# Keyword options for the pre-training run, gathered in one place.
pretrain_options = dict(
    batch_size=256,
    workers=8,
    pretrain_fn=random_masking,
    logs_path=Path("../logs/"),
    early_stopping=False,
    epochs=30,
)

pretrain(model, train_data, val_data, **pretrain_options)



  0%|          | 0/19418 [00:00<?, ? batch/s]

KeyboardInterrupt: 

In [None]:
# pretrain(model, train_data, val_data, batch_size=128, workers=8, pretrain_fn=random_masking)
from hmpai.pytorch.pretraining import pretrain_train

%lprun -f train_data.__getitem__ pretrain(model, train_data, val_data, batch_size=128, workers=0, pretrain_fn=random_masking)
# ~2 batch/s with 8 workers, converges to same with 14