<a href="https://colab.research.google.com/github/sharath2004-tech/cardiac-disease-detection/blob/main/cardiac%20disease%20detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install wfdb


import wfdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:



# Download the 'mitdb' database into a local directory that is also named
# 'mitdb'; every later cell reads records from the 'mitdb/<record>' paths.
wfdb.dl_database(
    'mitdb',
    dl_dir='mitdb',
    keep_subdirs=True
)

In [None]:
import wfdb
import numpy as np
import matplotlib.pyplot as plt

# Exploratory look at a single record: read the signal of record 100 and its
# 'atr' annotation file.
record = wfdb.rdrecord('mitdb/100')
annotation = wfdb.rdann('mitdb/100', 'atr')

signal = record.p_signal[:,0]   # use first channel
# Annotation sample indices and their symbols; the next cell treats each
# annotated sample as the centre of one beat window.
r_peaks = annotation.sample
labels = annotation.symbol

print("Signal shape:", signal.shape)
# NOTE(review): this counts every annotation in the file — presumably that
# also includes non-beat markers; confirm before quoting it as a beat count.
print("Total beats:", len(r_peaks))

In [None]:
# Cut a fixed 180-sample window (90 samples before and 90 after) around every
# annotated sample of the record loaded above. Windows that would run past
# either end of the signal are skipped.
beats = []
beat_labels = []

for peak, symbol in zip(r_peaks, labels):
    start = peak - 90
    end = peak + 90

    # Inclusive bounds: a window starting at index 0, or ending exactly at
    # len(signal), is still a complete 180-sample slice and must be kept.
    if start >= 0 and end <= len(signal):
        beats.append(signal[start:end])
        beat_labels.append(symbol)

beats = np.array(beats)
print("Beats shape:", beats.shape)

In [None]:
# Annotation symbol -> class name ('N', 'S', 'V' or 'F'). Symbols that are not
# keys of this mapping are the ones extract_beats() leaves out.
AAMI_map = dict(
    [(symbol, 'N') for symbol in 'NLRej']
    + [(symbol, 'S') for symbol in 'AaJS']
    + [(symbol, 'V') for symbol in 'VE']
    + [('F', 'F')]
)

# Class name -> integer id used as the training target.
class_to_int = {name: index for index, name in enumerate('NSVF')}

In [None]:
# Standard DS1 (Train) and DS2 (Test) split used in research.
# Record ids are kept as strings because they are spliced into file paths.

train_records = (
    "101 106 108 109 112 114 115 116 "
    "118 119 122 124 201 203 205 207 "
    "208 209 215 220 223 230"
).split()

test_records = (
    "100 103 105 111 113 117 121 123 "
    "200 202 210 212 213 214 219 221 "
    "222 228 231 232 233 234"
).split()

print("Train records:", len(train_records))
print("Test records:", len(test_records))

# Safety check (no leakage): the two lists must not share a record id,
# so this should print an empty set.
print("Overlap:", set(train_records).intersection(test_records))

In [None]:
def extract_beats(record_list, half_window=90, data_dir='mitdb', channel=0):
    """Cut per-beat, z-scored windows out of a list of records.

    For every record id the signal and its 'atr' annotations are read from
    ``data_dir`` with wfdb. Each annotation whose symbol is a key of
    ``AAMI_map`` yields one window ``signal[s - half_window : s + half_window]``
    (``s`` is the annotated sample index), normalised to zero mean / unit
    standard deviation. Annotations whose window would not fit inside the
    signal are skipped.

    Args:
        record_list: iterable of record ids (e.g. ``'101'``).
        half_window: samples kept on each side of the annotated sample;
            the window length is ``2 * half_window``. Must be a positive int.
        data_dir: directory that holds the downloaded records.
        channel: column of ``p_signal`` to use.

    Returns:
        ``(beats, labels)`` where ``beats`` has shape
        ``(n_beats, 2 * half_window)`` and ``labels`` is an int64 array of
        shape ``(n_beats,)`` holding ``class_to_int[AAMI_map[symbol]]``.
        Both are well-formed (zero rows) when nothing was extracted.

    Raises:
        ValueError: if ``half_window`` is not positive.
    """
    if half_window <= 0:
        raise ValueError("half_window must be a positive integer")

    window = 2 * half_window
    beats = []
    labels = []

    for rec in record_list:
        print(f"Processing record {rec}...")

        record = wfdb.rdrecord(f'{data_dir}/{rec}')
        annotation = wfdb.rdann(f'{data_dir}/{rec}', 'atr')

        signal = record.p_signal[:, channel]
        n_samples = len(signal)

        for peak, symbol in zip(annotation.sample, annotation.symbol):
            if symbol not in AAMI_map:
                continue

            start = peak - half_window
            end = peak + half_window
            # Inclusive bounds: a window that starts at index 0 or ends
            # exactly at the last sample is complete and must be kept.
            if start < 0 or end > n_samples:
                continue

            beat = signal[start:end]

            # Per-beat normalization; the epsilon guards a flat window.
            beat = (beat - np.mean(beat)) / (np.std(beat) + 1e-8)

            beats.append(beat)
            labels.append(class_to_int[AAMI_map[symbol]])

    # The explicit reshape/dtype keep the result shapes stable even when no
    # beat was extracted (np.array([]) would be 1-D and float).
    return (
        np.asarray(beats, dtype=np.float64).reshape(-1, window),
        np.asarray(labels, dtype=np.int64),
    )

In [None]:
# Build the train / test arrays from the two record lists defined above.
X_train, y_train = extract_beats(train_records)
X_test, y_test = extract_beats(test_records)

print("Train shape:", X_train.shape)
print("Test shape:", X_test.shape)

import numpy as np
# Class ids that actually occur in the training labels.
print("Unique classes in train:", np.unique(y_train))

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import random

class ECGDataset(Dataset):
    """Beat windows and integer labels, with optional random augmentation.

    ``X`` is converted to a float32 tensor and given a channel axis at
    position 1, so one item has shape ``(1, length)``; ``y`` becomes a long
    tensor. With ``augment=True`` each of four perturbations is applied
    independently when ``random.random() > 0.5``.
    """

    def __init__(self, X, y, augment=False):
        self.augment = augment
        self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Work on a copy so the stored tensor is never modified.
        sample = self.X[idx].clone()
        target = self.y[idx]

        if not self.augment:
            return sample, target

        # Additive Gaussian noise.
        if random.random() > 0.5:
            sample = sample + torch.randn_like(sample) * 0.02

        # Amplitude scaling by a factor drawn from [0.9, 1.1].
        if random.random() > 0.5:
            sample = sample * random.uniform(0.9, 1.1)

        # Circular time shift by up to 10 samples in either direction.
        if random.random() > 0.5:
            sample = torch.roll(sample, random.randint(-10, 10), dims=-1)

        # Low-amplitude sinusoidal baseline wander over the window.
        if random.random() > 0.5:
            phase = torch.linspace(0, 4*3.14159, sample.size(-1))
            sample = sample + (torch.sin(phase) * 0.05).unsqueeze(0)

        return sample, target

In [None]:
# Augmentation is enabled for the training split only; the test split is
# served exactly as extracted.
train_dataset = ECGDataset(X_train, y_train, augment=True)
test_dataset = ECGDataset(X_test, y_test, augment=False)

# Smaller batch size for better generalization
# Only the training loader shuffles; num_workers=0 loads in the main process.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

In [None]:
import numpy as np

# Number of training beats per class id (shows the class imbalance that the
# class weights computed later compensate for).
unique, counts = np.unique(y_train, return_counts=True)

for u, c in zip(unique, counts):
    print(f"Class {u}: {c}")

In [None]:
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    """Focal loss built on per-sample cross-entropy.

    Each sample's cross-entropy ``ce`` is scaled by ``(1 - exp(-ce)) ** gamma``
    and, when ``alpha`` is given, by ``alpha[target]``. The batch mean is
    returned.

    Args:
        alpha: optional tensor indexed by class id (one weight per class).
            It is stored as a plain attribute, so it has to live on the same
            device as the targets.
        gamma: focusing exponent; ``gamma=0`` reduces to (weighted)
            cross-entropy.
    """

    def __init__(self, alpha=None, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        per_sample_ce = self.ce(inputs, targets)
        # exp(-ce) is the probability the model assigns to the true class.
        prob_true = torch.exp(-per_sample_ce)
        modulation = (1 - prob_true) ** self.gamma
        loss = modulation * per_sample_ce

        if self.alpha is None:
            return loss.mean()
        return (self.alpha[targets] * loss).mean()

In [None]:
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# 'balanced' class weights computed from the training labels (rarer classes
# get larger weights).
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train
)

# The weights are indexed by target inside the loss, so they are moved to the
# same device the model and batches will use.
alpha = torch.tensor(class_weights, dtype=torch.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alpha = alpha.to(device)

print("Class weights:", alpha)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ResidualBlock(nn.Module):
    """Two length-preserving 1-D convolutions with a residual connection.

    Main path: conv(k=5) -> BN -> ReLU -> dropout -> conv(k=5) -> BN.
    The skip path is the identity when the channel count is unchanged and a
    1x1 convolution followed by BN otherwise. The two paths are summed and
    passed through a final ReLU.
    """

    def __init__(self, in_channels, out_channels, dropout=0.3):
        super().__init__()

        # kernel_size=5 with padding=2 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(out_channels)

        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(out_channels)

        self.dropout = nn.Dropout(dropout)

        if in_channels == out_channels:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        skip = self.shortcut(x)

        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        hidden += skip
        return F.relu(hidden)


class SelfAttention(nn.Module):
    """Self-attention over the positions of a ``(batch, channels, length)`` map.

    Queries and keys are 1x1-conv projections to ``channels // 8`` channels,
    values keep all channels. The attended map is scaled by the learnable
    scalar ``gamma`` and added to the input; ``gamma`` starts at zero, so a
    freshly built module returns its input unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        reduced = channels // 8
        self.query = nn.Conv1d(channels, reduced, 1)
        self.key = nn.Conv1d(channels, reduced, 1)
        self.value = nn.Conv1d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channels, length = x.size()

        queries = self.query(x).view(batch, -1, length).transpose(1, 2)  # (B, L, C//8)
        keys = self.key(x).view(batch, -1, length)                        # (B, C//8, L)
        values = self.value(x).view(batch, -1, length)                    # (B, C, L)

        # Scaled dot-product scores; each query row is normalised over keys.
        scores = torch.bmm(queries, keys) / math.sqrt(channels // 8)
        weights = F.softmax(scores, dim=-1)

        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(batch, channels, length)

        return self.gamma * attended + x


class ECGResNet(nn.Module):
    """Residual 1-D CNN with self-attention for single-channel beat windows.

    Pipeline: stem conv + max-pool (halves the length) -> four residual
    blocks (32 -> 64 -> 128 -> 256 -> 256 channels) -> self-attention ->
    concatenated global average / max pooling (512 features) -> three-layer
    classifier head.

    Args:
        num_classes: size of the output layer (number of logits per sample).

    Input is ``(batch, 1, length)``; output is ``(batch, num_classes)``.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        self.layer2 = ResidualBlock(32, 64, dropout=0.2)
        self.layer3 = ResidualBlock(64, 128, dropout=0.3)
        self.layer4 = ResidualBlock(128, 256, dropout=0.3)
        self.layer5 = ResidualBlock(256, 256, dropout=0.4)

        self.attention = SelfAttention(256)

        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.global_max_pool = nn.AdaptiveMaxPool1d(1)

        # 512 = 256 average-pooled + 256 max-pooled features.
        self.fc1 = nn.Linear(512, 256)
        self.bn_fc = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(0.4)

        # Honour num_classes instead of a fixed 4-way output.
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)

        x = self.attention(x)

        # Concatenate avg and max pooling
        avg_pool = self.global_pool(x).squeeze(-1)
        max_pool = self.global_max_pool(x).squeeze(-1)
        x = torch.cat([avg_pool, max_pool], dim=1)

        x = F.relu(self.bn_fc(self.fc1(x)))
        x = self.dropout1(x)

        x = F.relu(self.fc2(x))
        x = self.dropout2(x)

        # Raw logits; the loss applies (log-)softmax itself.
        x = self.fc3(x)
        return x

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ECGResNet().to(device)

# Use AdamW with weight decay for better regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)

# Cosine annealing with warm restarts (this scheduler has no warm-up phase):
# the first cycle lasts T_0=10 epochs, every following cycle is T_mult=2 times
# longer, and the learning rate decays towards eta_min within each cycle.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

print("Using device:", device)
print("Model parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed targets, optionally class-weighted.

    The target distribution puts ``1 - smoothing + smoothing / n_class`` on
    the true class and ``smoothing / n_class`` on every other class. With
    ``weight`` set, each sample's loss is multiplied by ``weight[target]``
    before the plain batch mean is taken (the mean divides by the batch
    size, not by the sum of the weights).
    """

    def __init__(self, smoothing=0.1, weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.weight = weight

    def forward(self, pred, target):
        n_class = pred.size(1)

        # One-hot encode the targets, then blend with the uniform distribution.
        hard_targets = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        soft_targets = hard_targets * (1 - self.smoothing) + self.smoothing / n_class

        per_sample = -(soft_targets * F.log_softmax(pred, dim=1)).sum(dim=1)

        if self.weight is None:
            return per_sample.mean()
        return (per_sample * self.weight[target]).mean()

# Training loss: label-smoothed cross-entropy, weighted per class by `alpha`
# (the 'balanced' class weights tensor computed in an earlier cell).
criterion = LabelSmoothingCrossEntropy(smoothing=0.1, weight=alpha)

def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler,
                epochs=50, patience=15, restore_best=True):
    """Train ``model`` with per-epoch evaluation and early stopping.

    Every epoch runs one pass over ``train_loader`` (gradient norm clipped to
    1.0, scheduler stepped after every batch with a fractional epoch), then
    measures accuracy on ``test_loader``. Training stops once the evaluation
    accuracy has failed to improve for ``patience`` consecutive epochs.

    Args:
        model: network to optimise; batches are moved to the device of its
            first parameter (CPU if it has none).
        train_loader: loader yielding ``(inputs, targets)`` training batches.
        test_loader: loader used for the per-epoch evaluation.
            NOTE(review): the notebook passes the held-out test split here, so
            early stopping and "best" selection are tuned on test data.
        criterion: loss taking ``(outputs, targets)``.
        optimizer: optimiser over ``model.parameters()``.
        scheduler: scheduler whose ``step`` accepts a fractional epoch
            (e.g. ``CosineAnnealingWarmRestarts``).
        epochs: maximum number of epochs.
        patience: epochs without improvement tolerated before stopping.
        restore_best: when true, the weights of the best-scoring epoch are
            loaded back into ``model`` before returning.

    Returns:
        The best evaluation accuracy reached, in percent.
    """
    import copy  # local import so the cell does not depend on another cell

    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # max(..., 1) keeps the fractional-epoch computation defined for an
    # empty dataset.
    n_train = max(len(train_loader.dataset), 1)

    best_val_acc = 0.0
    best_state = None
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        correct = 0
        total = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(run_device), targets.to(run_device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Fractional epoch = whole epochs done + share of this epoch seen.
            scheduler.step(epoch + total / n_train)

            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        train_acc = 100 * correct / total if total else 0.0

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(run_device), targets.to(run_device)
                outputs = model(inputs)

                _, predicted = torch.max(outputs, 1)
                val_total += targets.size(0)
                val_correct += (predicted == targets).sum().item()

        val_acc = 100 * val_correct / val_total if val_total else 0.0

        # Track best model: keep the weights, not only the score.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            if restore_best:
                best_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        print(f"Epoch [{epoch+1}/{epochs}] "
              f"Train Acc: {train_acc:.2f}% "
              f"Val Acc: {val_acc:.2f}% "
              f"Best Val: {best_val_acc:.2f}% "
              f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    print(f"\nBest Validation Accuracy: {best_val_acc:.2f}%")
    return best_val_acc

In [None]:
import torch.nn.functional as F

In [None]:
# NOTE(review): this repeats the extraction already performed in an earlier
# cell. train_loader / test_loader were built from the earlier arrays and are
# not rebuilt afterwards, so this second pass does not affect training.
X_train, y_train = extract_beats(train_records)
X_test, y_test = extract_beats(test_records)

In [None]:
import numpy as np
# Sanity check: class ids present in the training labels.
print(np.unique(y_train))

In [None]:
# Run training for at most 50 epochs; train_model may stop earlier through its
# early-stopping check.
train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, epochs=50)

<a href="https://colab.research.google.com/github/sharath2004-tech/cardiac-disease-detection/blob/main/cardiac%20disease%20detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install wfdb


import wfdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader



In [None]:



# Download the 'mitdb' database into a local directory that is also named
# 'mitdb'; every later cell reads records from the 'mitdb/<record>' paths.
wfdb.dl_database(
    'mitdb',
    dl_dir='mitdb',
    keep_subdirs=True
)

Generating record list for: 100
Generating record list for: 101
Generating record list for: 102
Generating record list for: 103
Generating record list for: 104
Generating record list for: 105
Generating record list for: 106
Generating record list for: 107
Generating record list for: 108
Generating record list for: 109
Generating record list for: 111
Generating record list for: 112
Generating record list for: 113
Generating record list for: 114
Generating record list for: 115
Generating record list for: 116
Generating record list for: 117
Generating record list for: 118
Generating record list for: 119
Generating record list for: 121
Generating record list for: 122
Generating record list for: 123
Generating record list for: 124
Generating record list for: 200
Generating record list for: 201
Generating record list for: 202
Generating record list for: 203
Generating record list for: 205
Generating record list for: 207
Generating record list for: 208
Generating record list for: 209
Generati

In [None]:
import wfdb
import numpy as np
import matplotlib.pyplot as plt

# Exploratory look at a single record: read the signal of record 100 and its
# 'atr' annotation file.
record = wfdb.rdrecord('mitdb/100')
annotation = wfdb.rdann('mitdb/100', 'atr')

signal = record.p_signal[:,0]   # use first channel
# Annotation sample indices and their symbols; the next cell treats each
# annotated sample as the centre of one beat window.
r_peaks = annotation.sample
labels = annotation.symbol

print("Signal shape:", signal.shape)
# NOTE(review): this counts every annotation in the file — presumably that
# also includes non-beat markers; confirm before quoting it as a beat count.
print("Total beats:", len(r_peaks))

Signal shape: (650000,)
Total beats: 2274


In [None]:
# Cut a fixed 180-sample window (90 samples before and 90 after) around every
# annotated sample of the record loaded above. Windows that would run past
# either end of the signal are skipped.
beats = []
beat_labels = []

for peak, symbol in zip(r_peaks, labels):
    start = peak - 90
    end = peak + 90

    # Inclusive bounds: a window starting at index 0, or ending exactly at
    # len(signal), is still a complete 180-sample slice and must be kept.
    if start >= 0 and end <= len(signal):
        beats.append(signal[start:end])
        beat_labels.append(symbol)

beats = np.array(beats)
print("Beats shape:", beats.shape)

Beats shape: (2271, 180)


In [None]:
# Annotation symbol -> class name ('N', 'S', 'V' or 'F'). Symbols that are not
# keys of this mapping are the ones extract_beats() leaves out.
AAMI_map = dict(
    [(symbol, 'N') for symbol in 'NLRej']
    + [(symbol, 'S') for symbol in 'AaJS']
    + [(symbol, 'V') for symbol in 'VE']
    + [('F', 'F')]
)

# Class name -> integer id used as the training target.
class_to_int = {name: index for index, name in enumerate('NSVF')}

In [None]:
# Standard DS1 (Train) and DS2 (Test) split used in research.
# Record ids are kept as strings because they are spliced into file paths.

train_records = (
    "101 106 108 109 112 114 115 116 "
    "118 119 122 124 201 203 205 207 "
    "208 209 215 220 223 230"
).split()

test_records = (
    "100 103 105 111 113 117 121 123 "
    "200 202 210 212 213 214 219 221 "
    "222 228 231 232 233 234"
).split()

print("Train records:", len(train_records))
print("Test records:", len(test_records))

# Safety check (no leakage): the two lists must not share a record id,
# so this should print an empty set.
print("Overlap:", set(train_records).intersection(test_records))

Train records: 22
Test records: 22
Overlap: set()


In [None]:
def extract_beats(record_list, half_window=90, data_dir='mitdb', channel=0):
    """Cut per-beat, z-scored windows out of a list of records.

    For every record id the signal and its 'atr' annotations are read from
    ``data_dir`` with wfdb. Each annotation whose symbol is a key of
    ``AAMI_map`` yields one window ``signal[s - half_window : s + half_window]``
    (``s`` is the annotated sample index), normalised to zero mean / unit
    standard deviation. Annotations whose window would not fit inside the
    signal are skipped.

    Args:
        record_list: iterable of record ids (e.g. ``'101'``).
        half_window: samples kept on each side of the annotated sample;
            the window length is ``2 * half_window``. Must be a positive int.
        data_dir: directory that holds the downloaded records.
        channel: column of ``p_signal`` to use.

    Returns:
        ``(beats, labels)`` where ``beats`` has shape
        ``(n_beats, 2 * half_window)`` and ``labels`` is an int64 array of
        shape ``(n_beats,)`` holding ``class_to_int[AAMI_map[symbol]]``.
        Both are well-formed (zero rows) when nothing was extracted.

    Raises:
        ValueError: if ``half_window`` is not positive.
    """
    if half_window <= 0:
        raise ValueError("half_window must be a positive integer")

    window = 2 * half_window
    beats = []
    labels = []

    for rec in record_list:
        print(f"Processing record {rec}...")

        record = wfdb.rdrecord(f'{data_dir}/{rec}')
        annotation = wfdb.rdann(f'{data_dir}/{rec}', 'atr')

        signal = record.p_signal[:, channel]
        n_samples = len(signal)

        for peak, symbol in zip(annotation.sample, annotation.symbol):
            if symbol not in AAMI_map:
                continue

            start = peak - half_window
            end = peak + half_window
            # Inclusive bounds: a window that starts at index 0 or ends
            # exactly at the last sample is complete and must be kept.
            if start < 0 or end > n_samples:
                continue

            beat = signal[start:end]

            # Per-beat normalization; the epsilon guards a flat window.
            beat = (beat - np.mean(beat)) / (np.std(beat) + 1e-8)

            beats.append(beat)
            labels.append(class_to_int[AAMI_map[symbol]])

    # The explicit reshape/dtype keep the result shapes stable even when no
    # beat was extracted (np.array([]) would be 1-D and float).
    return (
        np.asarray(beats, dtype=np.float64).reshape(-1, window),
        np.asarray(labels, dtype=np.int64),
    )

In [None]:
# Build the train / test arrays from the two record lists defined above.
X_train, y_train = extract_beats(train_records)
X_test, y_test = extract_beats(test_records)

print("Train shape:", X_train.shape)
print("Test shape:", X_test.shape)

import numpy as np
# Class ids that actually occur in the training labels.
print("Unique classes in train:", np.unique(y_train))

Processing record 101...
Processing record 106...
Processing record 108...
Processing record 109...
Processing record 112...
Processing record 114...
Processing record 115...
Processing record 116...
Processing record 118...
Processing record 119...
Processing record 122...
Processing record 124...
Processing record 201...
Processing record 203...
Processing record 205...
Processing record 207...
Processing record 208...
Processing record 209...
Processing record 215...
Processing record 220...
Processing record 223...
Processing record 230...
Processing record 100...
Processing record 103...
Processing record 105...
Processing record 111...
Processing record 113...
Processing record 117...
Processing record 121...
Processing record 123...
Processing record 200...
Processing record 202...
Processing record 210...
Processing record 212...
Processing record 213...
Processing record 214...
Processing record 219...
Processing record 221...
Processing record 222...
Processing record 228...


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import random

class ECGDataset(Dataset):
    """Beat windows and integer labels, with optional random augmentation.

    ``X`` is converted to a float32 tensor and given a channel axis at
    position 1, so one item has shape ``(1, length)``; ``y`` becomes a long
    tensor. With ``augment=True`` each of four perturbations is applied
    independently when ``random.random() > 0.5``.

    The statements of ``__getitem__`` had been interleaved out of order in
    this cell (it no longer parsed); they are restored to their logical
    order here.
    """

    def __init__(self, X, y, augment=False):
        self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
        self.y = torch.tensor(y, dtype=torch.long)
        self.augment = augment

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Work on a copy so the stored tensor is never modified.
        beat = self.X[idx].clone()

        if self.augment:
            # Gaussian noise
            if random.random() > 0.5:
                noise = torch.randn_like(beat) * 0.02
                beat = beat + noise

            # Amplitude scaling
            if random.random() > 0.5:
                scale = random.uniform(0.9, 1.1)
                beat = beat * scale

            # Time shifting (circular shift)
            if random.random() > 0.5:
                shift = random.randint(-10, 10)
                beat = torch.roll(beat, shift, dims=-1)

            # Baseline wander
            if random.random() > 0.5:
                baseline = torch.sin(torch.linspace(0, 4*3.14159, beat.size(-1))) * 0.05
                beat = beat + baseline.unsqueeze(0)

        return beat, self.y[idx]

In [None]:
# Augmentation is enabled for the training split only; the test split is
# served exactly as extracted.
train_dataset = ECGDataset(X_train, y_train, augment=True)
test_dataset = ECGDataset(X_test, y_test, augment=False)

# Smaller batch size for better generalization
# Only the training loader shuffles; num_workers=0 loads in the main process.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

In [None]:
import numpy as np

# Number of training beats per class id (shows the class imbalance that the
# class weights computed later compensate for).
unique, counts = np.unique(y_train, return_counts=True)

for u, c in zip(unique, counts):
    print(f"Class {u}: {c}")

Class 0: 45856
Class 1: 944
Class 2: 3788
Class 3: 414


In [None]:
import torch
import torch.nn as nn
class FocalLoss(nn.Module):
    """Focal loss built on per-sample cross-entropy.

    Each sample's cross-entropy ``ce`` is scaled by ``(1 - exp(-ce)) ** gamma``
    and, when ``alpha`` is given, by ``alpha[target]``. The batch mean is
    returned.

    Args:
        alpha: optional tensor indexed by class id (one weight per class).
            It is stored as a plain attribute, so it has to live on the same
            device as the targets.
        gamma: focusing exponent; ``gamma=0`` reduces to (weighted)
            cross-entropy.
    """

    def __init__(self, alpha=None, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        per_sample_ce = self.ce(inputs, targets)
        # exp(-ce) is the probability the model assigns to the true class.
        prob_true = torch.exp(-per_sample_ce)
        modulation = (1 - prob_true) ** self.gamma
        loss = modulation * per_sample_ce

        if self.alpha is None:
            return loss.mean()
        return (self.alpha[targets] * loss).mean()

In [None]:
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# 'balanced' class weights computed from the training labels (rarer classes
# get larger weights).
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train
)

# `device` is otherwise first assigned in a later cell, so on a fresh kernel
# this cell failed with NameError. Define it here, where it is first needed.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The weights are indexed by target inside the loss, so they are moved to the
# same device the model and batches will use.
alpha = torch.tensor(class_weights, dtype=torch.float32).to(device)

print("Class weights:", alpha)

Class weights: tensor([ 0.2781, 13.5069,  3.3660, 30.7983], device='cuda:0')


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ResidualBlock(nn.Module):
    """Two length-preserving 1-D convolutions with a residual connection.

    Main path: conv(k=5) -> BN -> ReLU -> dropout -> conv(k=5) -> BN.
    The skip path is the identity when the channel count is unchanged and a
    1x1 convolution followed by BN otherwise. The two paths are summed and
    passed through a final ReLU.
    """

    def __init__(self, in_channels, out_channels, dropout=0.3):
        super().__init__()

        # kernel_size=5 with padding=2 keeps the sequence length unchanged.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm1d(out_channels)

        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(out_channels)

        self.dropout = nn.Dropout(dropout)

        if in_channels == out_channels:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()
        else:
            # Project the input so it can be added to the main path.
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        skip = self.shortcut(x)

        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)

        hidden += skip
        return F.relu(hidden)


class SelfAttention(nn.Module):
    """Self-attention over the positions of a ``(batch, channels, length)`` map.

    Queries and keys are 1x1-conv projections to ``channels // 8`` channels,
    values keep all channels. The attended map is scaled by the learnable
    scalar ``gamma`` and added to the input; ``gamma`` starts at zero, so a
    freshly built module returns its input unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        reduced = channels // 8
        self.query = nn.Conv1d(channels, reduced, 1)
        self.key = nn.Conv1d(channels, reduced, 1)
        self.value = nn.Conv1d(channels, channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channels, length = x.size()

        queries = self.query(x).view(batch, -1, length).transpose(1, 2)  # (B, L, C//8)
        keys = self.key(x).view(batch, -1, length)                        # (B, C//8, L)
        values = self.value(x).view(batch, -1, length)                    # (B, C, L)

        # Scaled dot-product scores; each query row is normalised over keys.
        scores = torch.bmm(queries, keys) / math.sqrt(channels // 8)
        weights = F.softmax(scores, dim=-1)

        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(batch, channels, length)

        return self.gamma * attended + x


class ECGResNet(nn.Module):
    """Residual 1-D CNN for single-channel beat windows.

    Pipeline: stem conv + max-pool (halves the length) -> three residual
    blocks (32 -> 64 -> 128 -> 256 channels) -> global average pooling ->
    two-layer classifier head.

    The lines of this cell had been interleaved out of order (``__init__``
    and ``forward`` statements mixed, so the cell did not parse); they are
    restored to their logical order using only the statements the cell
    contained.
    NOTE(review): this is the 4-block / avg-pool variant present in this
    cell; confirm whether the attention-based variant was intended instead.

    Args:
        num_classes: size of the output layer (number of logits per sample).

    Input is ``(batch, 1, length)``; output is ``(batch, num_classes)``.
    """

    def __init__(self, num_classes=4):
        super().__init__()

        self.layer1 = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        self.layer2 = ResidualBlock(32, 64, dropout=0.2)
        self.layer3 = ResidualBlock(64, 128, dropout=0.3)
        self.layer4 = ResidualBlock(128, 256)   # NEW deeper block

        self.global_pool = nn.AdaptiveAvgPool1d(1)

        self.fc1 = nn.Linear(256, 128)          # updated input size
        self.dropout = nn.Dropout(0.5)
        # Honour num_classes instead of a fixed 4-way output.
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)   # NEW

        # Collapse the length axis: (batch, 256, L) -> (batch, 256).
        x = self.global_pool(x).squeeze(-1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ECGResNet().to(device)

# Use AdamW with weight decay for better regularization
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)

# Cosine annealing with warm restarts (this scheduler has no warm-up phase):
# the first cycle lasts T_0=10 epochs, every following cycle is T_mult=2 times
# longer, and the learning rate decays towards eta_min within each cycle.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

print("Model parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
print("Using device:", device)

Using device: cuda


In [None]:
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed targets, optionally class-weighted.

    The target distribution puts ``1 - smoothing + smoothing / n_class`` on
    the true class and ``smoothing / n_class`` on every other class. With
    ``weight`` set, each sample's loss is multiplied by ``weight[target]``
    before the plain batch mean is taken (the mean divides by the batch
    size, not by the sum of the weights).
    """

    def __init__(self, smoothing=0.1, weight=None):
        super().__init__()
        self.smoothing = smoothing
        self.weight = weight

    def forward(self, pred, target):
        n_class = pred.size(1)

        # One-hot encode the targets, then blend with the uniform distribution.
        hard_targets = torch.zeros_like(pred).scatter(1, target.view(-1, 1), 1)
        soft_targets = hard_targets * (1 - self.smoothing) + self.smoothing / n_class

        per_sample = -(soft_targets * F.log_softmax(pred, dim=1)).sum(dim=1)

        if self.weight is None:
            return per_sample.mean()
        return (per_sample * self.weight[target]).mean()

# Training loss: label-smoothed cross-entropy, weighted per class by `alpha`
# (the 'balanced' class weights tensor computed in an earlier cell).
criterion = LabelSmoothingCrossEntropy(smoothing=0.1, weight=alpha)

def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler,
                epochs=50, patience=15, restore_best=True):
    """Train ``model`` with per-epoch evaluation and early stopping.

    Every epoch runs one pass over ``train_loader`` (gradient norm clipped to
    1.0, scheduler stepped after every batch with a fractional epoch), then
    measures accuracy on ``test_loader``. Training stops once the evaluation
    accuracy has failed to improve for ``patience`` consecutive epochs.

    The validation / early-stopping half of this function had been
    interleaved out of order in this cell (it no longer parsed); it is
    restored to its logical order here.

    Args:
        model: network to optimise; batches are moved to the device of its
            first parameter (CPU if it has none).
        train_loader: loader yielding ``(inputs, targets)`` training batches.
        test_loader: loader used for the per-epoch evaluation.
            NOTE(review): the notebook passes the held-out test split here, so
            early stopping and "best" selection are tuned on test data.
        criterion: loss taking ``(outputs, targets)``.
        optimizer: optimiser over ``model.parameters()``.
        scheduler: scheduler whose ``step`` accepts a fractional epoch
            (e.g. ``CosineAnnealingWarmRestarts``).
        epochs: maximum number of epochs.
        patience: epochs without improvement tolerated before stopping.
        restore_best: when true, the weights of the best-scoring epoch are
            loaded back into ``model`` before returning.

    Returns:
        The best evaluation accuracy reached, in percent.
    """
    import copy  # local import so the cell does not depend on another cell

    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    # max(..., 1) keeps the fractional-epoch computation defined for an
    # empty dataset.
    n_train = max(len(train_loader.dataset), 1)

    best_val_acc = 0.0
    best_state = None
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        correct = 0
        total = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(run_device), targets.to(run_device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            # Fractional epoch = whole epochs done + share of this epoch seen.
            scheduler.step(epoch + total / n_train)

            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        train_acc = 100 * correct / total if total else 0.0

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(run_device), targets.to(run_device)
                outputs = model(inputs)

                _, predicted = torch.max(outputs, 1)
                val_total += targets.size(0)
                val_correct += (predicted == targets).sum().item()

        val_acc = 100 * val_correct / val_total if val_total else 0.0

        # Track best model: keep the weights, not only the score.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            if restore_best:
                best_state = copy.deepcopy(model.state_dict())
        else:
            patience_counter += 1

        print(f"Epoch [{epoch+1}/{epochs}] "
              f"Train Acc: {train_acc:.2f}% "
              f"Val Acc: {val_acc:.2f}% "
              f"Best Val: {best_val_acc:.2f}% "
              f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Early stopping
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    print(f"\nBest Validation Accuracy: {best_val_acc:.2f}%")
    return best_val_acc
        

In [None]:
import torch.nn.functional as F

In [None]:
# NOTE(review): this repeats the extraction already performed in an earlier
# cell. train_loader / test_loader were built from the earlier arrays and are
# not rebuilt afterwards, so this second pass does not affect training.
X_train, y_train = extract_beats(train_records)
X_test, y_test = extract_beats(test_records)

Processing record 101...
Processing record 106...
Processing record 108...
Processing record 109...
Processing record 112...
Processing record 114...
Processing record 115...
Processing record 116...
Processing record 118...
Processing record 119...
Processing record 122...
Processing record 124...
Processing record 201...
Processing record 203...
Processing record 205...
Processing record 207...
Processing record 208...
Processing record 209...
Processing record 215...
Processing record 220...
Processing record 223...
Processing record 230...
Processing record 100...
Processing record 103...
Processing record 105...
Processing record 111...
Processing record 113...
Processing record 117...
Processing record 121...
Processing record 123...
Processing record 200...
Processing record 202...
Processing record 210...
Processing record 212...
Processing record 213...
Processing record 214...
Processing record 219...
Processing record 221...
Processing record 222...
Processing record 228...


In [None]:
import numpy as np
# Sanity check: class ids present in the training labels.
print(np.unique(y_train))

[0 1 2 3]


In [None]:
# Run training for at most 50 epochs.
# NOTE(review): the saved output under this cell reports "/35" epochs, so it
# was produced by an earlier run with a different epoch count; re-run to
# refresh it.
train_model(model, train_loader, test_loader, criterion, optimizer, scheduler, epochs=50)

Epoch [1/35] Train Acc: 76.55% Val Acc: 48.29%
Epoch [2/35] Train Acc: 88.22% Val Acc: 65.10%
Epoch [3/35] Train Acc: 91.06% Val Acc: 55.56%
Epoch [4/35] Train Acc: 92.85% Val Acc: 70.31%
Epoch [5/35] Train Acc: 93.50% Val Acc: 67.60%


KeyboardInterrupt: 