In [173]:
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader



In [223]:
# Integer class ids. HomogenicDataset turns an id into the two-column label
# row [1 - id, id], so HEALTHY -> [1, 0] and SEVERE -> [0, 1].
HEALTHY, SEVERE = 0, 1


class HomogenicDataset(Dataset):
    """Dataset in which every sample carries the same class label.

    Attributes:
        x: the feature tensor passed in, stored as-is (indexed along dim 0).
        y: float tensor of shape (len(x), 2); every row is ``[1 - state, state]``.
    """

    def __init__(self, tensor, state):
        self.x = tensor
        # One value per sample, in the default floating dtype.
        label_col = torch.full((len(tensor),), float(state))
        self.y = torch.stack([1 - label_col, label_col], dim=1)

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        # `idx` may be an int or a slice; a slice yields an (x_slice, y_slice)
        # pair, which is what CombinedDataset consumes.
        features = self.x[idx]
        label = self.y[idx]
        return features, label


class CombinedDataset(Dataset):
    """Concatenation of two (features, labels) pairs into one dataset.

    Each argument is indexed with ``[0]`` for features and ``[1]`` for labels,
    e.g. the tuple produced by slicing a HomogenicDataset. Rows of
    ``severe_ds`` come first, followed by rows of ``healthy_ds``.
    """

    def __init__(self, severe_ds, healthy_ds):
        parts = (severe_ds, healthy_ds)
        self.x = torch.concat([part[0] for part in parts])
        self.y = torch.concat([part[1] for part in parts])

    def __len__(self):
        return self.y.size(0)

    def __getitem__(self, idx):
        sample, target = self.x[idx], self.y[idx]
        return sample, target

In [164]:
def normalize(tensor, eps=1e-5):
    """Standardize every channel (dim 1) to zero mean and unit variance.

    Statistics are computed over all other dimensions of ``tensor`` (batch and,
    for 3-D input, time). The variance is biased, and ``eps`` is added to it,
    matching a freshly constructed ``torch.nn.BatchNorm1d`` in training mode
    (weight=1, bias=0).

    Unlike instantiating that module on the fly, this function:
      * runs under ``torch.no_grad()``, so the result is a plain tensor with no
        ``grad_fn`` pointing at throw-away BatchNorm parameters. Those would
        otherwise be kept alive by, and saved with, every Dataset built from it;
      * takes the channel count from ``tensor.shape[1]`` instead of assuming 12.

    Args:
        tensor: floating-point tensor of rank >= 2 with channels on dim 1.
        eps: value added to the variance for numerical stability.

    Returns:
        A new tensor with the same shape and dtype as ``tensor``.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError(
            f"expected a tensor with at least 2 dims (N, C, ...), got {tensor.dim()}-D"
        )
    # Reduce over every dim except the channel dim.
    reduce_dims = tuple(d for d in range(tensor.dim()) if d != 1)
    with torch.no_grad():
        mean = tensor.mean(dim=reduce_dims, keepdim=True)
        var = tensor.var(dim=reduce_dims, unbiased=False, keepdim=True)
        return (tensor - mean) / torch.sqrt(var + eps)


def prep_severe(severe_tensor, downsample_factor=2):
    """Swap the last two axes of the severe-class tensor and decimate the new last axis.

    The severe recordings are presumably stored as (N, length, leads); after
    this step they are (N, leads, length // factor), which matches the healthy
    tensor's layout. TODO confirm against the file's producer.

    Args:
        severe_tensor: 3-D tensor; dims 1 and 2 are swapped.
        downsample_factor: keep every ``downsample_factor``-th sample along the
            (new) last axis. The default of 2 reproduces the previous behaviour.
            This is plain decimation with no low-pass filter, so content above
            the new Nyquist frequency aliases.

    Returns:
        A new, contiguous tensor that never shares memory with the input.

    Raises:
        ValueError: if ``downsample_factor`` is smaller than 1.
    """
    if downsample_factor < 1:
        raise ValueError(f"downsample_factor must be >= 1, got {downsample_factor}")
    channels_first = severe_tensor.permute((0, 2, 1))
    decimated = channels_first[:, :, ::downsample_factor]
    # Always copy into a contiguous buffer, even when factor == 1 leaves a
    # dense-but-permuted view.
    return decimated.clone(memory_format=torch.contiguous_format)

In [224]:
# Project data directory: change this one constant to run on another machine.
DATA_DIR = Path(r"C:\Users\rbenjos\Desktop\ekg_proj")
SEVERE_PATH = DATA_DIR / "severe_ds.pt"
HEALTHY_PATH = DATA_DIR / "healthy_ds.pt"

# Only the severe recordings need axis re-ordering and decimation. The healthy
# file is presumably already channels-first; the shape check below enforces it.
# Raw tensors are not bound to names, so their memory can be freed right away.
severe_tensor = normalize(prep_severe(torch.load(SEVERE_PATH)))
healthy_tensor = normalize(torch.load(HEALTHY_PATH))

# Both classes are concatenated along dim 0 later, so all other dims must agree.
assert severe_tensor.shape[1:] == healthy_tensor.shape[1:], (
    f"layout mismatch: severe {tuple(severe_tensor.shape)} "
    f"vs healthy {tuple(healthy_tensor.shape)}"
)

# NOTE(review): each class is normalized with its own per-channel statistics.
# This erases class-level offset/scale differences and leaves no single set of
# statistics to apply to unlabeled data at inference time. Consider normalizing
# the pooled data (or using training-split statistics) instead.
severe_ds = HomogenicDataset(severe_tensor, SEVERE)
healthy_ds = HomogenicDataset(healthy_tensor, HEALTHY)

In [178]:
tuple(ds.x.shape for ds in (severe_ds, healthy_ds))  # (severe shape, healthy shape)

(torch.Size([41, 12, 1250]), torch.Size([3000, 12, 1250]))

In [152]:
# One loader per class: batches of 4, reshuffled each epoch (no fixed seed).
severe_loader, healthy_loader = (
    DataLoader(ds, batch_size=4, shuffle=True) for ds in (severe_ds, healthy_ds)
)

In [231]:
SEVERE_TRAIN_SIZE = 30
HEALTHY_TRAIN_SIZE = 2800

# Fail loudly instead of silently producing an empty (or everything-in-train) split.
assert 0 < SEVERE_TRAIN_SIZE < len(severe_ds), "SEVERE_TRAIN_SIZE must leave a non-empty test split"
assert 0 < HEALTHY_TRAIN_SIZE < len(healthy_ds), "HEALTHY_TRAIN_SIZE must leave a non-empty test split"

# Keyword arguments make each slice land in the parameter of the same name.
# The earlier positional call passed the healthy slice as `severe_ds`. Rows are
# now stored severe-first; each label stays attached to its own row.
# NOTE(review): the split takes the first N samples of each file with no
# shuffling. If the files are ordered (by patient/time), the test set is biased.
training_set = CombinedDataset(
    severe_ds=severe_ds[:SEVERE_TRAIN_SIZE],
    healthy_ds=healthy_ds[:HEALTHY_TRAIN_SIZE],
)
test_set = CombinedDataset(
    severe_ds=severe_ds[SEVERE_TRAIN_SIZE:],
    healthy_ds=healthy_ds[HEALTHY_TRAIN_SIZE:],
)

assert len(training_set) + len(test_set) == len(severe_ds) + len(healthy_ds)


In [232]:
TRAIN_SET_PATH = "C:\\Users\\rbenjos\\Desktop\\ekg_proj\\train_ds.pt"
TEST_SET_PATH = "C:\\Users\\rbenjos\\Desktop\\ekg_proj\\test_ds.pt"

# torch.save pickles the whole CombinedDataset objects, so reading them back
# needs the CombinedDataset class importable. Recent PyTorch versions default
# torch.load to weights_only=True, which needs to be overridden for these files.
torch.save(obj=training_set, f=TRAIN_SET_PATH)
torch.save(obj=test_set, f=TEST_SET_PATH)