In [3]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import wfdb
import pandas as pd
import wignerdpy
from wignerdpy.toolkits import signal_toolkit
from torch.utils.data import Dataset, DataLoader, random_split

In [4]:
excel = "/home/abhishek/rashad_internship/Physionet/ptb-xl-1.0.3/ptbxl_database.csv"
path = '/home/abhishek/rashad_internship/Physionet/ptb-xl-1.0.3/'

R Peak detection

In [7]:
import pandas as pd
import torch
from torch.utils.data import Dataset
import wfdb
import wignerdpy
from wignerdpy.toolkits import signal_toolkit
from torchvision import transforms
import numpy as np
from scipy.ndimage import zoom
import ast
import random
from scipy.signal import ShortTimeFFT,get_window
import cv2


class SingleToThreeChannel:
    def __call__(self, image):
        return image.repeat(3, 1, 1)
    
class onedimTotwodim:
    def __call__(self, data):
        # Assuming this is your existing transformation
        # This can be replaced with STFT transformation
        stft_transformer = ShortTimeFFT(get_window('hann', 300), hop=100, fs=500)
        Zxx = stft_transformer.stft(data)
        spectrogram = np.abs(Zxx)  # Take absolute values of STFT
        plt.figure(figsize=(6, 6))  # Adjust figure size as needed
        plt.pcolormesh(spectrogram, shading='gouraud')
        plt.axis('off')  # Remove axes for a cleaner image
        plt.savefig('temp_spectrogram.png', bbox_inches='tight', pad_inches=0)
        plt.close()

        # Load the saved image
        image = cv2.imread('temp_spectrogram.png')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB

        # Resize the image to 224x224
        resized_image = cv2.resize(image, (224, 224))

        # Convert to tensor
        resized_image_tensor = transforms.ToTensor()(resized_image)

        return resized_image_tensor
# Correcting the transforms.Compose
transform = transforms.Compose([
    onedimTotwodim(),  # Apply WVD transformation
    # transforms.ToTensor(), # Convert single-channel to three-channel
])

class Custom_class(Dataset):
    def __init__(self, excelfile, path, num_data, transform=None, data_split='train', fold=None):
        self.dat = pd.read_csv(excelfile)
        self.col = self.dat['filename_hr']
        self.label = self.dat['scp_codes']
        self.strat_fold = self.dat['strat_fold']
        self.path = path
        self.transform = transform
        self.num_data = num_data
        self.data_split = data_split
        self.fold = fold
        

        if self.data_split == 'train':
            self.indices = [idx for idx in range(self.num_data) if (self.strat_fold[idx] != fold)]
        elif self.data_split == 'test':
            self.indices = [idx for idx in range(self.num_data) if (self.strat_fold[idx] == fold)]
        elif self.data_split == 'val':
            self.indices = [idx for idx in range(self.num_data) if (self.strat_fold[idx] == fold)]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Randomly select a channel between 0 and 11
        channel = random.randint(0, 11)
        idx = self.indices[idx]
        
        # Read the signal for the randomly selected channel
        y, _ = wfdb.rdsamp(self.path + self.col[idx], channels=[channel])
        y = y.flatten()  # Ensure y is a 1D array
        y = y[:1500]
        
        scp_code_dict = ast.literal_eval(self.label[idx])
        first_key = max(scp_code_dict, key=scp_code_dict.get)  # one key in scp_code dictionary with highest value is considered as label
        label = 0 if first_key == 'NORM' else 1
        
        if self.transform:
            y = self.transform(y)
        
            
        return y, label

# Example usage
train_dataset = Custom_class(excel,path, 1000, transform,data_split='train',fold = 10)
validation_dataset = Custom_class(excel,path, 1000, transform,data_split='test',fold = 10)

In [None]:
a,label = train_dataset[0]
print(label)

In [8]:

# Create DataLoaders for train and test sets
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(validation_dataset, batch_size=32, shuffle=False)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
from torch.nn import MultiheadAttention

class CustomResNet50(nn.Module):
    def __init__(self, num_classes=2):
        super(CustomResNet50, self).__init__()
        # Load pre-trained ResNet-50
        self.resnet = resnet50(pretrained=True)

        # Modify the first convolutional layer to accept single-channel input
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Extract layers up to the penultimate layer
        self.features = nn.Sequential(*list(self.resnet.children())[:-2])

        # Define Multi-Head Attention parameters
        self.attention = MultiheadAttention(embed_dim=2048, num_heads=1, batch_first=True)

        # Fully connected layers
        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        # Extract features
        x = self.features(x)

        # Global average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1)

        # Add dimension for attention (batch_size, seq_len, embed_dim)
        x = x.unsqueeze(1)

        # Apply multi-head attention
        attn_output, _ = self.attention(x, x, x)

        # Remove the extra dimension
        attn_output = attn_output.squeeze(1)

        # Fully connected layers
        x = F.relu(self.fc1(attn_output))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        x = torch.sigmoid(x)

        return x

# Example usage
model = CustomResNet50(num_classes=1)

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [9]:
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import copy

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# scheduler = LambdaLR(optimizer, lr_lambda)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

num_epochs = 25
# best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

def validate_model(model, dataloader, criterion):
    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            labels = labels.float()

            outputs = model(inputs).squeeze(1)
            preds = torch.round(outputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = running_corrects.double() / len(dataloader.dataset)
    
    return epoch_loss, epoch_acc

for epoch in range(num_epochs):
    running_loss = 0.0
    running_corrects = 0

    print(f'Epoch {epoch+1}/{num_epochs}')
    
    # Wrap the train_loader with tqdm
    for inputs, labels in tqdm(train_loader, desc='Training'):
        inputs = inputs.to(device)
        
        labels = labels.to(device)
        labels = labels.float()

        optimizer.zero_grad()

        outputs = model(inputs).squeeze(1)
        preds = torch.round(outputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects.double() / len(train_loader.dataset)
    scheduler.step()

    print(f'Training - Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    # Validate the model
    val_loss, val_acc = validate_model(model, val_loader, criterion)
    print(f'Validation - Epoch {epoch+1}/{num_epochs}, Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}')

    # Deep copy the model if the current validation accuracy is the best so far
    if val_acc > best_acc:
        best_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        # Save the best model
        torch.save(model.state_dict(), "best_model_resnet.pth")

# Load best model weights
model.load_state_dict(best_model_wts)

print(f"Training complete. Best validation accuracy: {best_acc:.4f}")


Epoch 1/25


Training: 100%|██████████| 28/28 [01:45<00:00,  3.78s/it]


Training - Epoch 1/25, Loss: 0.6709, Accuracy: 0.5890
Validation - Epoch 1/25, Loss: 0.6939, Accuracy: 0.5349
Epoch 2/25


Training: 100%|██████████| 28/28 [01:31<00:00,  3.28s/it]


Training - Epoch 2/25, Loss: 0.6981, Accuracy: 0.4902
Validation - Epoch 2/25, Loss: 0.7026, Accuracy: 0.4884
Epoch 3/25


Training: 100%|██████████| 28/28 [01:32<00:00,  3.32s/it]


Training - Epoch 3/25, Loss: 0.6932, Accuracy: 0.5258
Validation - Epoch 3/25, Loss: 0.6942, Accuracy: 0.4884
Epoch 4/25


Training: 100%|██████████| 28/28 [01:31<00:00,  3.28s/it]


Training - Epoch 4/25, Loss: 0.6923, Accuracy: 0.5258
Validation - Epoch 4/25, Loss: 0.6958, Accuracy: 0.4884
Epoch 5/25


Training: 100%|██████████| 28/28 [01:32<00:00,  3.31s/it]


Training - Epoch 5/25, Loss: 0.6939, Accuracy: 0.5040
Validation - Epoch 5/25, Loss: 0.6931, Accuracy: 0.5116
Epoch 6/25


Training: 100%|██████████| 28/28 [01:34<00:00,  3.36s/it]


Training - Epoch 6/25, Loss: 0.6932, Accuracy: 0.5166
Validation - Epoch 6/25, Loss: 0.6940, Accuracy: 0.4884
Epoch 7/25


Training: 100%|██████████| 28/28 [01:35<00:00,  3.43s/it]


Training - Epoch 7/25, Loss: 0.6923, Accuracy: 0.5258
Validation - Epoch 7/25, Loss: 0.6946, Accuracy: 0.4884
Epoch 8/25


Training: 100%|██████████| 28/28 [01:32<00:00,  3.30s/it]


Training - Epoch 8/25, Loss: 0.6913, Accuracy: 0.5258
Validation - Epoch 8/25, Loss: 0.6949, Accuracy: 0.4884
Epoch 9/25


Training: 100%|██████████| 28/28 [01:34<00:00,  3.36s/it]


Training - Epoch 9/25, Loss: 0.6911, Accuracy: 0.5258
Validation - Epoch 9/25, Loss: 0.6942, Accuracy: 0.4884
Epoch 10/25


Training: 100%|██████████| 28/28 [01:30<00:00,  3.23s/it]


Training - Epoch 10/25, Loss: 0.6905, Accuracy: 0.5258
Validation - Epoch 10/25, Loss: 0.6937, Accuracy: 0.4884
Epoch 11/25


Training: 100%|██████████| 28/28 [01:35<00:00,  3.40s/it]


Training - Epoch 11/25, Loss: 0.6892, Accuracy: 0.5258
Validation - Epoch 11/25, Loss: 0.6925, Accuracy: 0.4884
Epoch 12/25


Training: 100%|██████████| 28/28 [01:31<00:00,  3.25s/it]


Training - Epoch 12/25, Loss: 0.6873, Accuracy: 0.5258
Validation - Epoch 12/25, Loss: 0.6946, Accuracy: 0.4884
Epoch 13/25


Training: 100%|██████████| 28/28 [01:30<00:00,  3.23s/it]


Training - Epoch 13/25, Loss: 0.6752, Accuracy: 0.5258
Validation - Epoch 13/25, Loss: 0.6997, Accuracy: 0.4884
Epoch 14/25


Training: 100%|██████████| 28/28 [01:30<00:00,  3.23s/it]


Training - Epoch 14/25, Loss: 0.6667, Accuracy: 0.6119
Validation - Epoch 14/25, Loss: 0.6602, Accuracy: 0.6202
Epoch 15/25


Training: 100%|██████████| 28/28 [01:30<00:00,  3.23s/it]


Training - Epoch 15/25, Loss: 0.6446, Accuracy: 0.6659
Validation - Epoch 15/25, Loss: 0.6715, Accuracy: 0.6124
Epoch 16/25


Training: 100%|██████████| 28/28 [01:30<00:00,  3.23s/it]


Training - Epoch 16/25, Loss: 0.6336, Accuracy: 0.6935
Validation - Epoch 16/25, Loss: 0.6601, Accuracy: 0.6279
Epoch 17/25


Training: 100%|██████████| 28/28 [01:37<00:00,  3.50s/it]


Training - Epoch 17/25, Loss: 0.6376, Accuracy: 0.6487
Validation - Epoch 17/25, Loss: 0.6624, Accuracy: 0.6124
Epoch 18/25


Training: 100%|██████████| 28/28 [01:34<00:00,  3.37s/it]


Training - Epoch 18/25, Loss: 0.6339, Accuracy: 0.6590
Validation - Epoch 18/25, Loss: 0.6443, Accuracy: 0.6434
Epoch 19/25


Training: 100%|██████████| 28/28 [01:41<00:00,  3.62s/it]


Training - Epoch 19/25, Loss: 0.6194, Accuracy: 0.6728
Validation - Epoch 19/25, Loss: 0.6452, Accuracy: 0.6357
Epoch 20/25


Training: 100%|██████████| 28/28 [01:33<00:00,  3.35s/it]


Training - Epoch 20/25, Loss: 0.6155, Accuracy: 0.6567
Validation - Epoch 20/25, Loss: 0.7005, Accuracy: 0.5504
Epoch 21/25


Training: 100%|██████████| 28/28 [01:40<00:00,  3.59s/it]


Training - Epoch 21/25, Loss: 0.6067, Accuracy: 0.6762
Validation - Epoch 21/25, Loss: 0.6358, Accuracy: 0.6279
Epoch 22/25


Training: 100%|██████████| 28/28 [01:32<00:00,  3.32s/it]


Training - Epoch 22/25, Loss: 0.5902, Accuracy: 0.6946
Validation - Epoch 22/25, Loss: 0.6466, Accuracy: 0.6202
Epoch 23/25


Training: 100%|██████████| 28/28 [01:36<00:00,  3.44s/it]


Training - Epoch 23/25, Loss: 0.5988, Accuracy: 0.6889
Validation - Epoch 23/25, Loss: 0.5881, Accuracy: 0.7287
Epoch 24/25


Training: 100%|██████████| 28/28 [01:41<00:00,  3.61s/it]


Training - Epoch 24/25, Loss: 0.6016, Accuracy: 0.6797
Validation - Epoch 24/25, Loss: 0.6229, Accuracy: 0.6589
Epoch 25/25


Training: 100%|██████████| 28/28 [01:47<00:00,  3.83s/it]


Training - Epoch 25/25, Loss: 0.6113, Accuracy: 0.6762
Validation - Epoch 25/25, Loss: 0.6241, Accuracy: 0.6512
Training complete. Best validation accuracy: 0.7287
