In [6]:
import glob 
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Seed torch so the random train/test split below is reproducible.
torch.manual_seed(7)

DATAPATH = '/Users/wonsuk/projects/data/ecg/raw/2019-11-19'
DATA_LENGTH = 1000  # maximum number of .hd5 files to use
TRAIN_RATIO = 0.8   # fraction of samples assigned to the training split
# HDF5 dataset names (under the 'ecg_rest' group) of the 12 ECG leads, in the
# row order used when stacking each record into an array.
ecg_key_string_list = [
    "strip_I",
    "strip_II",
    "strip_III",
    "strip_V1",
    "strip_V2",
    "strip_V3",
    "strip_V4",
    "strip_V5",
    "strip_V6",
    "strip_aVF",
    "strip_aVL",
    "strip_aVR",
]


def list_hdf5_files(directory, limit):
    """Return at most *limit* ``*.hd5`` paths from *directory*, sorted by path.

    Sorting before truncating makes the selected subset deterministic:
    ``glob.glob`` returns matches in arbitrary, filesystem-dependent order,
    so taking the "first N" unsorted picks different files on different
    machines. ``limit=None`` returns every match.
    """
    return sorted(glob.glob("{}/*.hd5".format(directory)))[:limit]


hdf5_files = list_hdf5_files(DATAPATH, DATA_LENGTH)

print('Data Loading finished (row:{})'.format(len(hdf5_files)))

class ECGDataset(Dataset):
    """In-memory map-style dataset pairing ECG signals with regression targets.

    Args:
        data: array-like of per-sample signals, indexed along axis 0. In this
            notebook each sample is a (leads, samples) float array.
        target: array-like of per-sample targets; must have the same length
            as ``data``.
        transform: optional callable applied to each signal tensor when an
            item is fetched.

    Raises:
        ValueError: if ``data`` and ``target`` differ in length.
    """

    def __init__(self, data, target, transform=None):
        # Keep signals as float32 rather than int: an integer cast truncates
        # fractional sample values and produces tensors that float-weighted
        # torch layers (e.g. Conv1d, Linear) refuse as input.
        self.data = torch.as_tensor(np.asarray(data)).float()
        self.target = torch.as_tensor(np.asarray(target)).float()
        if len(self.data) != len(self.target):
            raise ValueError(
                "data and target must have the same length, got {} and {}".format(
                    len(self.data), len(self.target)))
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(signal, target)`` for *index*, applying ``transform`` to the signal."""
        x = self.data[index]
        if self.transform:
            x = self.transform(x)
        return x, self.target[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
    
print('Converting to TorchDataset...')

SAMPLES_PER_LEAD = 5000  # length every lead strip is expected to have

x_all = []
y_all = []
for hdf_file in hdf5_files:
    # Close each HDF5 file as soon as it has been read; otherwise one handle
    # per file stays open for the life of the kernel.
    with h5py.File(hdf_file, 'r') as f:
        y_all.append(f['continuous']['VentricularRate'][0])
        # One row per lead, in ecg_key_string_list order. Assignment raises
        # if a strip's length differs from SAMPLES_PER_LEAD.
        x = np.zeros(shape=(len(ecg_key_string_list), SAMPLES_PER_LEAD))
        for i, key in enumerate(ecg_key_string_list):
            x[i, :] = f['ecg_rest'][key][:]
    x_all.append(x)

data = ECGDataset(np.asarray(x_all), np.asarray(y_all))

train_size = int(TRAIN_RATIO * len(data))
test_size = len(data) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])
dataloader_train = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Distinct loop names so iterating does not rebind `data` (the dataset above).
for batch_idx, (batch_x, batch_y) in enumerate(dataloader_train):
    print('Batch idx {}, data shape {}, target shape {}'.format(batch_idx, batch_x.shape, batch_y.shape))

Data Loading finished (row:1000)
Converting to TorchDataset...
Batch idx 0, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 1, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 2, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 3, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 4, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 5, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 6, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 7, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 8, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 9, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 10, data shape torch.Size([64, 12, 5000]), target shape torch.Size([64])
Batch idx 11, data shape torch.Siz