### **INSTALLING NECESSARY LIBRARIES**

### **Note:**
Run This Notebook With GPU For Better Performance Or Change The Code Structure Accordingly To Run With CPU.

In [None]:
# Install the wfdb package (quietly); it provides wfdb.rdsamp / wfdb.rdrecord,
# which later cells use to read the ECG records.
!pip install -q wfdb

In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Copy the dataset from Google Drive into the current working directory.
# The quoted path is a placeholder and must be replaced with the real location.
# NOTE(review): later cells read from "ptbdb_data/"; if the source is a
# directory, plain `cp` would need `-r` — confirm against the actual dataset.
!cp "path of the dataset in gdrive" .

### **IMPORT LIBRARIES**

In [None]:
from __future__ import print_function
import numpy as np
import pandas as pd
import torch
import torchvision as tv
import torch.nn as nn
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
%matplotlib inline
import wfdb
import time
import random
from sklearn.preprocessing import minmax_scale
import sys
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

### **CHANNEL CREATION AND REPRODUCIBILITY**

In [None]:
# Single seed shared by every random number generator this notebook relies on.
seed_num = 49

# get_batch draws records and windows with Python's `random` module and shuffles
# each batch with NumPy, so seeding torch alone leaves training non-reproducible.
random.seed(seed_num)
np.random.seed(seed_num)
torch.manual_seed(seed_num)

# Identifier of the current experiment run.
run_num = 7

# Names of the three ECG channels read from every record, in model input order.
channel_1 = 'v6'
channel_2 = 'vz'
channel_3 = 'ii'
print(seed_num, run_num, channel_1, channel_2, channel_3)

### **LOAD REAL DATA (PTBDB)**

In [None]:
# Loading The Real Data (PTBDB)
# Every line of RECORDS is used as a record path relative to ptbdb_data/.
with open('ptbdb_data/RECORDS') as fp:
    lines = fp.readlines()

files_unhealthy, files_healthy = [], []

for file in lines:
    # The last character (the newline) is dropped to build the header path.
    # The raw line is what gets stored, because later cells strip it again
    # with file[:-1].
    file_path = "ptbdb_data/" + file[:-1] + ".hea"

    # Read the header once, inside a context manager so the handle is closed
    # (the file used to be opened twice per record and never closed).
    with open(file_path) as header_fp:
        header_text = header_fp.read()

    # Reading The Header To Determine The Class
    if 'Myocardial infarction' in header_text:
        files_unhealthy.append(file)

    if 'Healthy control' in header_text:
        files_healthy.append(file)

### **SHUFFLING DATA**

In [None]:
# Shuffle both record lists in place under a fixed NumPy seed, then cut each
# one at the 80% mark: the head becomes the training split, the tail validation.
np.random.seed(seed_num)
for record_list in (files_unhealthy, files_healthy):
    np.random.shuffle(record_list)

healthy_cut = int(0.8*len(files_healthy))
unhealthy_cut = int(0.8*len(files_unhealthy))

healthy_train, healthy_val = files_healthy[:healthy_cut], files_healthy[healthy_cut:]
unhealthy_train, unhealthy_val = files_unhealthy[:unhealthy_cut], files_unhealthy[unhealthy_cut:]

### **LIST OF INTERSECTIONS**

In [None]:
def intersection(lst1, lst2):
    """Return the elements present in both lists, de-duplicated and sorted.

    The result is sorted because set iteration order for strings changes
    between interpreter runs (hash randomization). Callers slice the returned
    list to decide which patients move to which split, so an unordered result
    would make that split non-reproducible.

    Args:
        lst1: First iterable of hashable, mutually comparable items.
        lst2: Second iterable of hashable, mutually comparable items.

    Returns:
        A sorted list of the items that occur in both inputs.
    """
    return sorted(set(lst1) & set(lst2))

# The first 10 characters of each record entry are taken as the patient id.
_patient_prefix = slice(0, 10)

patient_ids_unhealthy_train = [name[_patient_prefix] for name in unhealthy_train]
patient_ids_unhealthy_val = [name[_patient_prefix] for name in unhealthy_val]
patient_ids_healthy_train = [name[_patient_prefix] for name in healthy_train]
patient_ids_healthy_val = [name[_patient_prefix] for name in healthy_val]

# Patient ids that ended up on both sides of the train/validation split.
intersection_unhealthy = intersection(
    patient_ids_unhealthy_train,
    patient_ids_unhealthy_val,
)
intersection_healthy = intersection(
    patient_ids_healthy_train,
    patient_ids_healthy_val,
)

### **INTERSECTION (UNHEALTHY)**

In [None]:
# UnHealthy
# Patients with records on both sides of the split are assigned to exactly one
# side: the first half of the overlapping ids goes to train, the rest to val.
n_to_train = int(0.5*len(intersection_unhealthy))
move_to_train = intersection_unhealthy[:n_to_train]
move_to_val = intersection_unhealthy[n_to_train:]

for patient_id in move_to_train:
  # Collect this patient's validation records before touching the list.
  # Calling .remove() while iterating over the same list skips the element
  # that follows each removal, which left some records behind in val.
  in_val = [file_ for file_ in unhealthy_val if file_[:10] == patient_id]
  unhealthy_val[:] = [file_ for file_ in unhealthy_val if file_[:10] != patient_id]

  # Adding To Train
  unhealthy_train.extend(in_val)


for patient_id in move_to_val:
  # Finding And Removing All Files In Train
  in_train = [file_ for file_ in unhealthy_train if file_[:10] == patient_id]
  unhealthy_train[:] = [file_ for file_ in unhealthy_train if file_[:10] != patient_id]

  # Adding To Val
  unhealthy_val.extend(in_train)

### **INTERSECTION (HEALTHY)**

In [None]:
# Healthy
# Patients with records on both sides of the split are assigned to exactly one
# side: the first half of the overlapping ids goes to train, the rest to val.
n_to_train = int(0.5*len(intersection_healthy))
move_to_train = intersection_healthy[:n_to_train]
move_to_val = intersection_healthy[n_to_train:]

for patient_id in move_to_train:
  # Collect this patient's validation records before touching the list.
  # Calling .remove() while iterating over the same list skips the element
  # that follows each removal, which left some records behind in val.
  in_val = [file_ for file_ in healthy_val if file_[:10] == patient_id]
  healthy_val[:] = [file_ for file_ in healthy_val if file_[:10] != patient_id]

  # Adding To Train
  healthy_train.extend(in_val)


for patient_id in move_to_val:
  # Finding And Removing All Files In Train
  in_train = [file_ for file_ in healthy_train if file_[:10] == patient_id]
  healthy_train[:] = [file_ for file_ in healthy_train if file_[:10] != patient_id]

  # Adding To Val
  healthy_val.extend(in_train)

### **DATA SEPARATION FOR TRAINING AND VALIDATION**

In [None]:
def _load_record_channels(record_files):
  """Read the three configured channels for every record in `record_files`.

  Args:
    record_files: Record entries as stored from RECORDS; the last character
      of each entry (the newline) is dropped before building the path.

  Returns:
    A list with one item per record. Each item is a list of three flattened
    arrays, ordered as channel_1, channel_2, channel_3.
  """
  loaded = []
  for file in record_files:
    record_path = "ptbdb_data/" + file[:-1]
    channels = []
    # One read per channel, so the stored order always follows the configured
    # channel order.
    for channel in (channel_1, channel_2, channel_3):
      samples, _ = wfdb.rdsamp(record_path, channel_names=[str(channel)])
      channels.append(samples.flatten())
    loaded.append(channels)
  return loaded


# The same loader serves all four splits (this used to be four copy-pasted loops).
data_healthy_train = _load_record_channels(healthy_train)
data_healthy_val = _load_record_channels(healthy_val)
data_unhealthy_train = _load_record_channels(unhealthy_train)
data_unhealthy_val = _load_record_channels(unhealthy_val)

In [None]:
# Number of consecutive samples cut from a record for one model input.
window_size = 10000

# Convert the per-split lists to object-dtype arrays so they can be indexed
# with a list of positions (presumably object dtype because record lengths
# differ — TODO confirm).
data_healthy_train = np.asarray(data_healthy_train, dtype=object)
data_healthy_val = np.asarray(data_healthy_val, dtype=object)
data_unhealthy_train = np.asarray(data_unhealthy_train, dtype=object)
data_unhealthy_val = np.asarray(data_unhealthy_val, dtype=object)

In [None]:
# Position arrays (0 .. n-1) for each split; get_batch samples record
# positions from these.

#Training Sets
data_unhealthy_train_np, data_healthy_train_np = (
    np.arange(len(split_data))
    for split_data in (data_unhealthy_train, data_healthy_train)
)

#Validation Sets
data_unhealthy_val_np, data_healthy_val_np = (
    np.arange(len(split_data))
    for split_data in (data_unhealthy_val, data_healthy_val)
)

### **DEFINING 'GET BATCHES' FUNCTION**

In [None]:
def get_batch(batch_size, split='train', device=None):
  """Build one class-balanced, shuffled batch of normalized ECG windows.

  Half of the batch comes from unhealthy records (target 0.1) and half from
  healthy records (target 0.9). For every chosen record a random window of
  `window_size` samples is cut, and each of its three channels is min-max
  scaled independently.

  Args:
    batch_size: Requested batch size; int(batch_size / 2) records are drawn
      per class, so an odd value yields batch_size - 1 samples.
    split: 'train' or 'val', selecting which data arrays to sample from.
    device: torch device for the returned tensors. Defaults to CUDA when it
      is available and to the CPU otherwise.

  Returns:
    (batch_x, batch_y): float tensors of shape (n, 3, window_size) and (n, 1).

  Raises:
    ValueError: If `split` is unknown, or a sampled record is not longer than
      `window_size`.
  """
  if split == 'train':
    unhealthy_data, healthy_data = data_unhealthy_train, data_healthy_train
    unhealthy_pool, healthy_pool = data_unhealthy_train_np, data_healthy_train_np
  elif split == 'val':
    unhealthy_data, healthy_data = data_unhealthy_val, data_healthy_val
    unhealthy_pool, healthy_pool = data_unhealthy_val_np, data_healthy_val_np
  else:
    # Previously an unknown split surfaced later as an UnboundLocalError.
    raise ValueError(f"split must be 'train' or 'val', got {split!r}")

  if device is None:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  half = int(batch_size / 2)
  unhealthy_indices = random.sample(tuple(unhealthy_pool), k=half)
  healthy_indices = random.sample(tuple(healthy_pool), k=half)
  unhealthy_batch = unhealthy_data[unhealthy_indices]
  healthy_batch = healthy_data[healthy_indices]

  def _random_window(sample):
    """Cut one random window from a record and min-max scale each channel."""
    max_start = len(sample[0]) - window_size
    if max_start <= 0:
      raise ValueError(
          f"record has {len(sample[0])} samples; more than {window_size} are needed")
    start = random.randrange(max_start)
    # Normalize each channel on its own, over the window only
    return np.array(
        [minmax_scale(sample[channel][start:start+window_size]) for channel in range(3)])

  # Unhealthy windows first, then healthy, matching the label order below
  batch_x = [_random_window(sample) for sample in unhealthy_batch]
  batch_x.extend(_random_window(sample) for sample in healthy_batch)

  # Smoothed targets: 0.1 for unhealthy, 0.9 for healthy
  batch_y = [0.1] * half + [0.9] * half

  # Shuffle inputs and targets with the same permutation
  indices = np.arange(len(batch_y))
  np.random.shuffle(indices)
  batch_x = np.array(batch_x)[indices]
  batch_y = np.array(batch_y)[indices]

  batch_x = np.reshape(batch_x, (-1, 3, window_size))
  batch_x = torch.from_numpy(batch_x).float().to(device)

  batch_y = np.reshape(batch_y, (-1, 1))
  batch_y = torch.from_numpy(batch_y).float().to(device)

  return batch_x, batch_y

### **CNN ARCHITECTURE**

In [None]:
class ConvNetQuake(nn.Module):
  """1-D CNN for three-channel signals that outputs one sigmoid score per sample.

  Eight Conv1d -> ReLU -> BatchNorm1d stages (kernel 3, stride 2, padding 1,
  32 output channels) each halve the sequence length. The flattened result
  passes through two linear layers and a sigmoid. `linear1` expects 1280
  features, i.e. 32 channels x 40 steps, which is what an input length of
  10000 produces after the eight stages.

  Layers keep the attribute names conv1..conv8, linear1, linear2, sigmoid and
  bn1..bn8, created in that order, so state_dict keys are unchanged.
  """

  def __init__(self):
    super().__init__()

    stage_numbers = range(1, 9)

    # Only the first stage sees the 3 input channels; the rest are 32 -> 32
    for stage in stage_numbers:
      in_channels = 3 if stage == 1 else 32
      setattr(
          self,
          f"conv{stage}",
          nn.Conv1d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=2, padding=1),
      )

    self.linear1 = nn.Linear(1280, 128)
    self.linear2 = nn.Linear(128, 1)
    self.sigmoid = nn.Sigmoid()

    for stage in stage_numbers:
      setattr(self, f"bn{stage}", nn.BatchNorm1d(32))

  def forward(self, x):
    """Map a (batch, 3, length) tensor to (batch, 1) scores in [0, 1]."""
    for stage in range(1, 9):
      conv = getattr(self, f"conv{stage}")
      norm = getattr(self, f"bn{stage}")
      # Batch norm is applied after the ReLU
      x = norm(F.relu(conv(x)))

    flat = x.view(x.size(0), -1)
    hidden = self.linear1(flat)
    logits = self.linear2(hidden)

    return self.sigmoid(logits)

### **DEFINING MODEL AND PARAMETERS**
1. MODEL = CONVNETQUAKE
2. OPTIMISER = ADAM
3. REGULARIZER = L2 REGULARIZER (RIDGE REGULARIZER)
4. LOSS FUNCTION = BINARY CROSS-ENTROPY

In [None]:
# Defining The Model
# Use the GPU when one is present and fall back to the CPU otherwise; an
# unconditional model.cuda() raises on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ConvNetQuake()
model.to(device)

# DataParallel wrapper restricted to device 0. Note that it prefixes every
# state_dict key with "module.".
model = nn.DataParallel(model, device_ids=[0])

# Adam with weight decay (L2 penalty) and binary cross-entropy on the
# sigmoid output of the model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
criterion = nn.BCELoss()

### **TRAINING AND VALIDATION OF THE MODEL**

In [None]:
# Training loop
writer = SummaryWriter('/content/Runs')

# num_iters = 30000
# num_iters = 35000
num_iters = 5000
batch_size = 10

# Number of random batches averaged for every accuracy estimate
iterations = 100

acc_values = []
acc_values_train = []


def _mean_accuracy(split, n_batches):
  """Return the mean accuracy (in percent) over `n_batches` random batches of `split`.

  Predictions and targets are both rounded to 0/1 and compared element-wise
  over the whole batch. The previous loop advanced its label index only on a
  correct prediction, so after the first miss every later prediction was
  compared against the same stale label, and it assumed a batch size of 10.
  """
  total = 0.0
  for _ in range(n_batches):
    batch_x, batch_y = get_batch(batch_size, split=split)
    predictions = model(batch_x)
    matches = torch.round(predictions) == torch.round(batch_y)
    total += 100.0 * matches.float().mean().item()
  return total / n_batches


# Make sure training starts in train mode even if the model was put in eval
# mode by an earlier cell run.
model.train()

for iters in range(num_iters):
  batch_x, batch_y = get_batch(batch_size, split='train')
  y_pred = model(batch_x)
  loss = criterion(y_pred, batch_y)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  # Validation
  if iters%100 == 0 and iters != 0:
    writer.add_scalar('Loss/train', loss.item(), iters)

    # Evaluate in eval mode: in train mode BatchNorm normalizes with batch
    # statistics and keeps updating its running statistics, so validation
    # batches would leak into the model.
    model.eval()
    with torch.no_grad():
      # Test_set
      val_acc = _mean_accuracy('val', iterations)
      # Train_set
      train_acc = _mean_accuracy('train', iterations)
    model.train()

    acc_values.append(val_acc)
    writer.add_scalar('Accuracy/val', val_acc, iters)

    acc_values_train.append(train_acc)
    writer.add_scalar('Accuracy/train', train_acc, iters)

    # Printing The Values Of Iters, Loss, And Accuracy
    print(f"Iteration {iters}: Loss = {loss.item():.4f}, Validation Accuracy = {val_acc:.4f}, Train Accuracy = {train_acc:.4f}")

# Flush pending events to disk and release the log file
writer.close()

### **LOADING THE TRAINED MODEL AND PREDICTING CUSTOM ECG SIGNALS**

In [None]:
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load("MID_Model_Dicts.pth", map_location=torch.device('cpu'))

# `model` is already wrapped in DataParallel by the model-definition cell.
# Wrapping it again on every run of this cell nests the wrappers and changes
# the expected key prefix ("module.module. ..."), so wrap only when needed.
if not isinstance(model, nn.DataParallel):
  model = nn.DataParallel(model)

# Locate the bare network underneath any number of DataParallel wrappers
network = model
while isinstance(network, nn.DataParallel):
  network = network.module

# Every DataParallel layer present at save time added one "module." prefix.
# Strip them all so the checkpoint loads however deeply the saved model was
# wrapped.
clean_state_dict = {}
for key, value in state_dict.items():
  while key.startswith('module.'):
    key = key[len('module.'):]
  clean_state_dict[key] = value

network.load_state_dict(clean_state_dict)

In [None]:
def predict(input_path:str,model):
  """Classify one random window of an ECG record and print the verdict.

  The window is prepared the same way get_batch prepares training inputs: the
  three configured channels (channel_1, channel_2, channel_3) are read, a
  window of `window_size` samples is cut at a random position, and each
  channel is min-max scaled on its own, giving a (1, 3, window_size) tensor.

  Args:
    input_path: Record path as accepted by wfdb.rdsamp.
    model: Trained network whose forward pass already ends in a sigmoid (as
      ConvNetQuake does), i.e. it returns scores in [0, 1].

  Raises:
    ValueError: If the record is not longer than `window_size` samples.
  """
  record = input_path

  #Read The Three Training Channels And Convert Into NP Arrays
  # Reading every channel and reshaping (samples, channels) into
  # (-1, 3, window) mixed the time and channel axes and fed the model inputs
  # unlike anything it was trained on.
  leads = []
  for channel in (channel_1, channel_2, channel_3):
    data, _ = wfdb.rdsamp(record, channel_names=[str(channel)])
    leads.append(np.asarray(data).flatten())

  n_samples = len(leads[0])
  if n_samples <= window_size:
    raise ValueError(
        f"record has {n_samples} samples; more than {window_size} are needed")
  start = random.randrange(n_samples - window_size)

  #Normalize Each Channel With MinMax Scale
  normalized = np.array(
      [minmax_scale(lead[start:start + window_size]) for lead in leads])

  #Reshape According To The Trained Model Size
  np_reshape = np.reshape(normalized, (-1, 3, window_size))

  #Convert Arrays To Tensors
  pred_torch = torch.from_numpy(np_reshape)
  pred_float = pred_torch.float()
  # Put the input on the same device as the model weights
  pred_float = pred_float.to(next(model.parameters()).device)

  #Evaluate Using The Model
  with torch.no_grad():
    model.eval()
    ecg_pred = model(pred_float)
    # The model output is already a sigmoid score. A second sigmoid squeezed
    # it into (0.5, 0.73), so almost every record was reported as healthy.
    ecg_pred_probs = ecg_pred*100
    pred = torch.mean(ecg_pred_probs)
    prediction = round(float(pred.item()))

  #Limit The Predictions Accordingly
    # Training targets are 0.1 for unhealthy and 0.9 for healthy records
    if prediction <= 50:
      print(f"This Patient Has Myocardial Infarction")
    else:
      print(f"This Patient Is Healthy")

In [None]:
# Record to classify and plot, given without a file extension; it is passed
# to wfdb.rdsamp (inside predict) and to wfdb.rdrecord below.
record_path = "/content/s0010_re"

In [None]:
# Classify one randomly chosen window of the record; the verdict is printed.
predict(record_path, model = model)

### **PLOTTING THE RESPECTIVE ECG SIGNAL**

In [None]:
# Loading The ECG Signal
# Only the first 1000 samples of channels 0 and 1 are read
record_ecg = wfdb.rdrecord(record_path, sampfrom=0, sampto=1000, channels=[0,1])

# p_signal is (samples, channels). Keep it 2-D: flattening interleaves the two
# channels sample by sample and doubles the apparent duration.
signal = record_ecg.p_signal

# Create A Time Array Based On The Sampling Frequency
# (named time_axis so the imported `time` module is not shadowed)
fs = record_ecg.fs
time_axis = [i / fs for i in range(len(signal))]

# Plot The ECG Signal, one line per channel
plt.figure(figsize=(10,4))
for channel_index, channel_name in enumerate(record_ecg.sig_name):
  plt.plot(time_axis, signal[:, channel_index], label=channel_name)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('ECG Signal')
plt.legend()

# Save before show(): with the inline backend the figure is closed once it has
# been shown, so saving afterwards writes an empty image.
plt.savefig('ecg_signal.png')
plt.show()