In [1]:
import os
import pandas as pd
from torchvision.io import read_image
import re
import wfdb
import wfdb.processing
import scipy
from torch.utils.data import Dataset
import numpy as np
import json
import torch.nn as nn
import torch
from tqdm import tqdm

In [2]:
def extract_segment_with_padding(z, k, N):
    # Rozmiar segmentu to 2N + 1
    start_idx = k - N
    end_idx = k + N + 1  # Indeks końcowy +1, ponieważ Python używa wykluczającego indeksu
    
    # Upewnij się, że start_idx i end_idx mieszczą się w granicach tablicy
    if start_idx < 0:
        # Jeśli start_idx jest poza zakresem, dopełnij na początku
        padding_left = np.median(z[:end_idx])  # Wypełniamy medianą
        segment = np.concatenate([np.full(-start_idx, padding_left), z[:end_idx]])
    elif end_idx > len(z):
        # Jeśli end_idx jest poza zakresem, dopełnij na końcu
        padding_right = np.median(z[start_idx:])  # Wypełniamy medianą
        segment = np.concatenate([z[start_idx:], np.full(end_idx - len(z), padding_right)])
    else:
        # Normalny przypadek, kiedy zakres mieści się w tablicy
        segment = z[start_idx:end_idx]
    
    return segment

class MIT_BIH_Arythmia(Dataset):
    def __init__(self,N, M, dataset_dir = 'Datasets/files/', fs = 10, filename = "MIT-BIH_Arrythmia.json"):
        """
        n - number of samples of orginal signal resampled to fs, interval [-n,n]
        m - qrs times, interval [-m,m]
        """
        ecg_list = []
        exclusion_lst = ["00735", "03665", "04043", "04936", "05091", "06453", "08378", "08405", "08434", "08455"]
        for file in os.listdir(dataset_dir):
            name = re.match(r'^(.*\d\d+)\.atr$', file)
            if name:
                if name.group(1) in exclusion_lst:
                    continue
            if name:
                record = wfdb.rdsamp(f"{dataset_dir}{name.group(1)}") 
                annotation = wfdb.rdann(f"{dataset_dir}{name.group(1)}", 'atr')
                signal = record[0][:,0]
                fs_original = record[1]["fs"]
                num_samples_target = int(signal.shape[0] * fs / fs_original)
                resampled_signal = scipy.signal.resample(signal, num_samples_target)
                annotation_times_resampled = (annotation.sample * fs) / fs_original
                resampled_annotation = wfdb.Annotation('atr',annotation.symbol,annotation_times_resampled.astype(int),aux_note=annotation.aux_note)
                ecg_list.append({"name": name.group(1),"rec" : resampled_signal, "ann" : resampled_annotation})
        self.samples_list = []
        self.label_list = []
        self.qrs_samples = []
        for dic in ecg_list:
            print(dic["name"])
            # xqrs = wfdb.processing.XQRS(sig=dic["rec"], fs=fs)
            # xqrs.detect()
            # qrs_inds = xqrs.qrs_inds
            for n,i in enumerate(dic["ann"].sample):
                self.label_list.append(1 if dic["ann"].aux_note[n] == '(AFIB' else 0)
                self.samples_list.append(list(extract_segment_with_padding(dic["rec"], dic["ann"].sample[n],N)))
                # nearest_qrs_idx = find_nearest_qrs_index(dic["ann"].sample[n], qrs_inds)
                # self.qrs_samples.append(list(extract_segment_with_padding(qrs_inds,nearest_qrs_idx,M)))
        data = {
            'samples_list': self.samples_list,  # This would work if the segments are simple numeric lists
            'label_list': self.label_list,
            'qrs_samples': self.qrs_samples
        }
        with open(filename, 'w') as f:
            json.dump(data, f)
                
    def __len__(self):
        return len(self.label_list)

    def __getitem__(self, idx):
        data = torch.Tensor(self.samples_list[idx])
        label = torch.Tensor(self.label_list[idx])
        return data, label

In [3]:
ds = MIT_BIH_Arythmia(100,5,fs=100)

04015
04048
04126
04746
04908
05121
05261
06426
06995
07162
07859
07879
07910
08215
08219


In [9]:


class SimpleConv(nn.Module):
    def __init__(self, input = 201, num_classes = 2):
        super(SimpleConv, self).__init__()
        self.model = nn.Sequential(
            nn.Conv1d(input, 64, kernel_size=7, padding='same'),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=3, padding='same'),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=3, padding='same'),      # out 1 x 128 x n
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(2),                        # out 1 x 128 x n//2
            nn.Conv1d(128, 128, kernel_size=3, padding='same'),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=3, padding='same'),     # out 1 x 256 x n//2
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(2),                        # out 1 x 256 x n//4
            nn.Conv1d(256, 256, kernel_size=3, padding='same'),     
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, 512, kernel_size=3, padding='same'),     
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.MaxPool1d(2),                        # out 1 x 512 x n//8
            nn.Flatten(),
            nn.Linear(512*(input//8), 256),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )
        self.model.to('cuda:0')

    def forward(self, x):

        return self.model(x)
    
    def train_model(self, train_loader, valid_loader, num_epochs = 5, learning_rate=0.001, save_best = False, save_thr = 0.94):
        best_accuracy = 0.0
        total_step = len(train_loader)
        # Loss and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.RMSprop(self.parameters(), lr=learning_rate, weight_decay = 0.005, momentum = 0.9)  

        for epoch in range(num_epochs):
            # self.train()
            correct = 0
            total = 0
            for i, (images, labels) in enumerate(tqdm(train_loader)):
                # Move tensors to the configured device
                images = images.float().to("cuda")
                labels = labels.type(torch.LongTensor)
                labels = labels.to("cuda")


                optimizer.zero_grad()

                # Forward pass
                outputs = self.forward(images)
                loss = criterion(outputs, labels)
                # Backward and optimize
                loss.backward()
                
                optimizer.step()

                # accuracy
                _, predicted = torch.max(outputs.data, 1)
                correct += (torch.eq(predicted, labels)).sum().item()
                total += labels.size(0)

                del images, labels, outputs

            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.4f}'
                            .format(epoch+1, num_epochs, i+1, total_step, loss.item(), (float(correct))/total))


            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Validation
            with torch.no_grad():
                correct = 0
                total = 0
                for images, labels in valid_loader:
                    images = images.float().to("cuda")
                    labels = labels.to("cuda")
                    outputs = self.forward(images)
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (torch.eq(predicted, labels)).sum().item()
                    del images, labels, outputs
                if(((100 * correct / total) > best_accuracy) and save_best and ((100 * correct / total) > save_thr)):
                    torch.save(self.state_dict(), "best_resnet50_MINST-DVS2.pt")

                print('Accuracy of the network: {} %'.format( 100 * correct / total))

In [10]:
model = SimpleConv()

In [11]:
from torch.utils.data import DataLoader, random_split
train_set, val_set = random_split(ds, [0.8, 0.2])
train = DataLoader(train_set, batch_size=32, shuffle=True)
val = DataLoader(val_set, batch_size=32, shuffle=True)

In [8]:
model.train_model(train,val)

  0%|          | 0/8 [00:00<?, ?it/s]


RuntimeError: stack expects each tensor to be equal size, but got [0] at entry 0 and [1] at entry 3