In [15]:
import torch
from torch import nn
from torch.nn import functional as F
import pandas as pd
import numpy as np
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from tqdm.auto import tqdm
from joblib import Parallel, delayed
from torch.utils.data import Dataset, DataLoader

In [56]:
def load_ptbxl_split(representation_type, target, split, fs=100, root_dir='../../data/ptbxl'):
    """Load one split of a pre-computed PTB-XL representation together with its labels.

    Files are read from ``{root_dir}/representations_{fs}``:

    * ``data/{representation_type}/{split}_data.pt`` -- samples, read with ``torch.load``
    * ``labels/{target}/{split}_labels.npy`` -- labels, read with ``np.load``
    * ``labels/{target}/classes.npy`` -- class names, shared by every split

    Args:
        representation_type: Name of the representation sub-directory.
        target: Name of the label sub-directory.
        split: Prefix of the data/label file names (e.g. ``'train'``, ``'val'``, ``'test'``).
        fs: Value substituted into the ``representations_{fs}`` directory name
            (presumably the sampling rate in Hz -- confirm against the data layout).
        root_dir: Directory containing the ``representations_{fs}`` folders. The default
            is a relative path, so it is resolved against the current working directory.

    Returns:
        dict with keys ``'data'`` (object returned by ``torch.load``), ``'labels'``
        (array returned by ``np.load``) and ``'classes'`` (dict mapping the position of
        each entry of ``classes.npy`` to that entry).
    """
    base_dir = f'{root_dir}/representations_{fs}'
    label_dir = f'{base_dir}/labels/{target}'

    # NOTE(security): torch.load and np.load(allow_pickle=True) are pickle-based and can
    # execute code embedded in a crafted file; only point this loader at trusted data.
    data = torch.load(f'{base_dir}/data/{representation_type}/{split}_data.pt')
    labels = np.load(f'{label_dir}/{split}_labels.npy')
    class_names = np.load(f'{label_dir}/classes.npy', allow_pickle=True)

    return {'data': data, 'labels': labels, 'classes': dict(enumerate(class_names))}

def load_ptbxl_dataset(representation_type, target, fs=100):
    """Load the train, val and test splits of one PTB-XL representation.

    Args:
        representation_type: Forwarded to ``load_ptbxl_split``.
        target: Forwarded to ``load_ptbxl_split``.
        fs: Forwarded to ``load_ptbxl_split``; defaults to 100, the value that was
            implicitly used before this parameter was exposed.

    Returns:
        dict keyed by ``'train'``, ``'val'`` and ``'test'``; each value is the dict
        returned by ``load_ptbxl_split`` for that split.
    """
    return {
        split: load_ptbxl_split(representation_type, target, split, fs=fs)
        for split in ('train', 'val', 'test')
    }


class PTBXLDataset(Dataset):
    """PTB-XL Dataset class used in DeepLearning models.

    Wraps one split loaded by ``load_ptbxl_split`` and exposes it through the
    ``torch.utils.data.Dataset`` protocol.

    Attributes are set in ``__init__``:
        data: Samples of the split (the ``'data'`` entry of the loaded dict).
        labels: Labels of the split (the ``'labels'`` entry).
        classes: Class index -> class name mapping (the ``'classes'`` entry).
        transform: Optional callable applied to each sample in ``__getitem__``.
    """

    def __init__(self, representation_type, target, split, transform=None):
        """Load the requested split.

        Args:
            representation_type: Forwarded to ``load_ptbxl_split``.
            target: Forwarded to ``load_ptbxl_split``.
            split: Forwarded to ``load_ptbxl_split``.
            transform: Optional callable taking the float sample tensor and returning
                the transformed sample. ``None`` disables transformation.
        """
        dataset = load_ptbxl_split(representation_type, target, split)
        self.data = dataset['data']
        self.labels = dataset['labels']
        self.classes = dataset['classes']
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for ``idx``; the sample is cast to float.

        Tensor indices are converted to plain Python values before indexing.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.data[idx].float()
        # `transform` used to be stored but never applied, so a caller-supplied
        # transform was silently ignored. With the default (None) nothing changes.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, self.labels[idx]


class PTBXLDataModule(LightningDataModule):
    """PTB-XL DataModule class used as DeepLearning models DataLoaders provider.

    Builds ``PTBXLDataset`` instances for the train/val/test splits in ``setup`` and
    wraps them in ``DataLoader`` objects.
    """

    def __init__(self, representation_type, target: str = 'diagnostic_class', batch_size: int = 64, num_workers=8):
        """Store the configuration; no data is loaded until ``setup`` is called.

        Args:
            representation_type: Representation name passed to ``PTBXLDataset``.
            target: Label set name passed to ``PTBXLDataset``.
            batch_size: Batch size of the training loader; the val/test loaders use
                ten times this value.
            num_workers: ``num_workers`` passed to every ``DataLoader``.
        """
        super().__init__()
        self.representation_type = representation_type
        self.target = target
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Datasets are created in setup().
        self.train = None
        self.val = None
        self.test = None

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` loads train and val, ``"validate"`` loads val,
                ``"test"`` loads test, ``None`` loads all three. Any other value
                loads nothing.
        """
        if stage in (None, "fit"):
            self.train = PTBXLDataset(self.representation_type, self.target, split="train")
        # "validate" was previously ignored, which left self.val as None when only
        # validation was requested (e.g. via Trainer.validate).
        if stage in (None, "fit", "validate"):
            self.val = PTBXLDataset(self.representation_type, self.target, split="val")
        if stage in (None, "test"):
            self.test = PTBXLDataset(self.representation_type, self.target, split="test")

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the training split."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the validation split (10x batch size)."""
        return DataLoader(self.val, batch_size=10 * self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the test split (10x batch size)."""
        return DataLoader(self.test, batch_size=10 * self.batch_size, num_workers=self.num_workers)

In [2]:
# Make the project's `src` directory importable from this notebook. The path is
# relative, so it is resolved against the current working directory.
import sys         
sys.path.append('./../../src/')
# NOTE(review): this import rebinds `PTBXLDataModule`, so the class of the same name
# defined in an earlier notebook cell is no longer the one used by the cells below --
# confirm which implementation is intended.
from data.ptbxl import PTBXLDataModule

In [3]:
# Data module for the 'whole_signal_waveforms' representation with
# 'diagnostic_class' labels; batch size and worker count keep their defaults.
ptbxl_datamodule = PTBXLDataModule(
    representation_type='whole_signal_waveforms',
    target='diagnostic_class',
)

In [4]:
# Prepare the datasets. No stage is passed; in the PTBXLDataModule defined in this
# notebook that loads train, val and test -- presumably the imported class does the
# same (TODO confirm against src/data/ptbxl).
ptbxl_datamodule.setup()

In [5]:
# Sanity check: shape of the input tensor of the first training sample
# (presumably channels x time steps -- TODO confirm the axis meaning).
ptbxl_datamodule.train_dataloader().dataset[0][0].shape

torch.Size([12, 1000])

In [6]:
# Label index -> class name mapping exposed by the training dataset.
ptbxl_datamodule.train_dataloader().dataset.classes

{0: 'CD', 1: 'HYP', 2: 'MI', 3: 'NORM', 4: 'STTC'}