In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MultiLabelBinarizer
from sklearn.metrics import classification_report, confusion_matrix 
import pywt
import wfdb
import ast
import os
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import shap
import warnings
warnings.filterwarnings('ignore')


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class ECGDataset(Dataset):
    """Custom Dataset for PTB-XL ECG data"""
    
    def __init__(self, signals, labels, transform=None):
        self.signals = signals
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.signals)
    
    def __getitem__(self, idx):
        signal = self.signals[idx]
        label = self.labels[idx]
        
        if self.transform:
            signal = self.transform(signal)
        
        return torch.FloatTensor(signal), torch.FloatTensor(label)

In [5]:
class ResNetBlock(nn.Module):
    """1D ResNet Block for ECG signals"""
    
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1, stride),
                nn.BatchNorm1d(out_channels)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [6]:
class ResNet1D(nn.Module):
    """1D ResNet for ECG feature extraction"""
    
    def __init__(self, input_channels=12, num_classes=5):
        super(ResNet1D, self).__init__()
        self.conv1 = nn.Conv1d(input_channels, 64, 7, 2, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.maxpool = nn.MaxPool1d(3, 2, padding=1)
        
        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(512, 256)
        
    def _make_layer(self, in_channels, out_channels, blocks, stride):
        layers = []
        layers.append(ResNetBlock(in_channels, out_channels, stride=stride))
        for _ in range(1, blocks):
            layers.append(ResNetBlock(out_channels, out_channels))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


In [7]:
class DWTFeatureExtractor:
    """Discrete Wavelet Transform feature extractor"""
    
    def __init__(self, wavelet='db4', levels=4):
        self.wavelet = wavelet
        self.levels = levels
    
    def extract_features(self, signal):
        """Extract DWT features from ECG signal"""
        features = []
        
        for lead in range(signal.shape[0]):  # For each lead
            lead_signal = signal[lead, :]
            
            # Perform DWT decomposition
            coeffs = pywt.wavedec(lead_signal, self.wavelet, level=self.levels)
            
            # Extract statistical features from each level
            for coeff in coeffs:
                features.extend([
                    np.mean(coeff),
                    np.std(coeff),
                    np.var(coeff),
                    np.max(coeff),
                    np.min(coeff),
                    np.median(coeff)
                ])
        
        return np.array(features)

In [8]:
class ECGClassificationModel(nn.Module):
    """Complete ECG Classification Model combining ResNet and DWT features"""
    
    def __init__(self, resnet_features=256, dwt_features=360, num_classes=5):
        super(ECGClassificationModel, self).__init__()
        
        # ResNet branch
        self.resnet = ResNet1D()
        
        # Feature combination
        self.feature_combiner = nn.Sequential(
            nn.Linear(resnet_features + dwt_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        
        # Deep Neural Network
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
            nn.Sigmoid()  # For multi-label classification
        )
    
    def forward(self, ecg_signal, dwt_features):
        # ResNet features
        resnet_features = self.resnet(ecg_signal)
        
        # Combine features
        combined_features = torch.cat([resnet_features, dwt_features], dim=1)
        combined_features = self.feature_combiner(combined_features)
        
        # Classification
        output = self.classifier(combined_features)
        return output



In [30]:
class PTBXLDataProcessor:
    """PTB-XL dataset processor"""
    
    def __init__(self, data_path, sampling_rate=500):
        self.data_path = Path(data_path)
        self.sampling_rate = sampling_rate
        self.dwt_extractor = DWTFeatureExtractor()
        
        # Class mapping
        self.class_mapping = {
            'CD': 0,   # Conduction Disturbance
            'HYP': 1,  # Hypertrophy
            'MI': 2,   # Myocardial Infarction  
            'NORM': 3, # Normal
            'STTC': 4  # ST/T Change
        }
    
    def load_database(self):
        """Load PTB-XL database CSV"""
        # Try different possible file names and locations
        possible_paths = [
            self.data_path / 'ptbxl_database.csv',
            self.data_path / 'ptb-xl_database.csv',
            self.data_path / 'database.csv'
        ]
        
        db_path = None
        for path in possible_paths:
            if path.exists():
                db_path = path
                break
        
        if db_path is None:
            # List all CSV files in the directory
            csv_files = list(self.data_path.glob('*.csv'))
            print(f"Available CSV files in {self.data_path}:")
            for file in csv_files:
                print(f"  - {file.name}")
            
            # Try to find database file by looking for 'database' in name
            database_files = [f for f in csv_files if 'database' in f.name.lower()]
            if database_files:
                db_path = database_files[0]
                print(f"Using database file: {db_path.name}")
            else:
                raise FileNotFoundError(f"Could not find database CSV file in {self.data_path}")
        
        print(f"Loading database from: {db_path}")
        self.df = pd.read_csv(db_path, index_col='ecg_id')
        
        # Convert scp_codes from string to dict
        self.df.scp_codes = self.df.scp_codes.apply(lambda x: ast.literal_eval(x))
        
        print(f"Loaded {len(self.df)} ECG records")
        return self.df
    
    def load_scp_statements(self):
        """Load SCP statements for label mapping"""
        # Try different possible file names
        possible_paths = [
            self.data_path / 'scp_statements.csv',
            self.data_path / 'scp-statements.csv',
            self.data_path / 'statements.csv'
        ]
        
        scp_path = None
        for path in possible_paths:
            if path.exists():
                scp_path = path
                break
        
        if scp_path is None:
            # List all CSV files that might be SCP statements
            csv_files = list(self.data_path.glob('*.csv'))
            scp_files = [f for f in csv_files if 'scp' in f.name.lower() or 'statement' in f.name.lower()]
            if scp_files:
                scp_path = scp_files[0]
                print(f"Using SCP statements file: {scp_path.name}")
            else:
                print("Warning: Could not find scp_statements.csv file")
                print("Available CSV files:")
                for file in csv_files:
                    print(f"  - {file.name}")
                return None
        
        print(f"Loading SCP statements from: {scp_path}")
        self.scp_df = pd.read_csv(scp_path, index_col=0)
        return self.scp_df
    
    def extract_labels(self):
        """Extract labels from scp_codes"""
        # Get diagnostic superclass
        def get_diagnostic_class(scp_codes):
            labels = np.zeros(5)  # CD, HYP, MI, NORM, STTC
            
            for code, _ in scp_codes.items():
                if code in self.scp_df.index:
                    diagnostic_class = self.scp_df.loc[code, 'diagnostic_class']
                    if diagnostic_class in self.class_mapping:
                        labels[self.class_mapping[diagnostic_class]] = 1
            
            return labels
        
        self.labels = np.array([get_diagnostic_class(codes) for codes in self.df.scp_codes])
        print(f"Label distribution:\n{np.sum(self.labels, axis=0)}")
        return self.labels
    
    def load_signals(self, limit=None):
        """Load ECG signals"""
        signals = []
        valid_indices = []
        
        for idx, (ecg_id, row) in enumerate(tqdm(self.df.iterrows(), desc="Loading signals")):
            if limit and idx >= limit:
                break
                
            try:
                # Load signal file
                signal_path = self.data_path / row['filename_lr']
                record = wfdb.rdsamp(str(signal_path.with_suffix('')))
                signal = record[0].T  # Shape: (12, 5000)
                
                # Normalize per lead
                signal = self.normalize_signal(signal)
                signals.append(signal)
                valid_indices.append(idx)
                
            except Exception as e:
                print(f"Error loading {ecg_id}: {e}")
                continue
        
        self.signals = np.array(signals)
        self.labels = self.labels[valid_indices]
        
        print(f"Loaded {len(signals)} signals with shape {self.signals.shape}")
        return self.signals, self.labels
    
    def normalize_signal(self, signal):
        """Normalize ECG signal per lead"""
        normalized = np.zeros_like(signal)
        for lead in range(signal.shape[0]):
            lead_signal = signal[lead, :]
            normalized[lead, :] = (lead_signal - np.mean(lead_signal)) / (np.std(lead_signal) + 1e-8)
        return normalized
    
    def extract_dwt_features(self):
        """Extract DWT features from all signals"""
        print("Extracting DWT features...")
        dwt_features = []
        
        for signal in tqdm(self.signals):
            features = self.dwt_extractor.extract_features(signal)
            dwt_features.append(features)
        
        self.dwt_features = np.array(dwt_features)
        print(f"DWT features shape: {self.dwt_features.shape}")
        return self.dwt_features

In [31]:
def train_model(model, train_loader, val_loader, num_epochs=50, device='cuda'):
    """Train the ECG classification model"""
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
    
    train_losses = []
    val_losses = []
    
    model.to(device)
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0
        
        for ecg_batch, dwt_batch, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            ecg_batch = ecg_batch.to(device)
            dwt_batch = dwt_batch.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(ecg_batch, dwt_batch)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_loss = 0.0
        
        with torch.no_grad():
            for ecg_batch, dwt_batch, labels in val_loader:
                ecg_batch = ecg_batch.to(device)
                dwt_batch = dwt_batch.to(device)
                labels = labels.to(device)
                
                outputs = model(ecg_batch, dwt_batch)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        
        scheduler.step(val_loss)
        
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    
    return train_losses, val_losses


In [32]:
def evaluate_model(model, test_loader, device='cuda'):
    """Evaluate the model and return predictions"""
    model.eval()
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for ecg_batch, dwt_batch, labels in test_loader:
            ecg_batch = ecg_batch.to(device)
            dwt_batch = dwt_batch.to(device)
            
            outputs = model(ecg_batch, dwt_batch)
            predictions = (outputs > 0.5).float()
            
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    return np.array(all_predictions), np.array(all_labels)


In [33]:
def explain_with_shap(model, sample_data, device='cuda'):
    """Use SHAP for model interpretability"""
    model.eval()
    
    # Create SHAP explainer
    def model_wrapper(x):
        ecg_signals = torch.FloatTensor(x[:, :, :5000]).to(device)  # ECG signals
        dwt_features = torch.FloatTensor(x[:, 0, 5000:]).to(device)  # DWT features
        return model(ecg_signals, dwt_features).cpu().detach().numpy()
    
    # Sample explanation (simplified for demonstration)
    print("SHAP analysis would be performed here for model interpretability")
    print("This would show which ECG features contribute most to each prediction")


In [34]:
# Main execution pipeline
def main():
    """Main pipeline execution"""
    # Configuration
    DATA_PATH = r"C:\Users\JANPALLY SUSHEEL\Desktop\new del pxlb\ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3"
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    BATCH_SIZE = 32
    NUM_EPOCHS = 50
    
    print(f"Using device: {DEVICE}")
    
    # Initialize data processor
    processor = PTBXLDataProcessor(DATA_PATH)
    
    # Load and process data
    print("Loading database...")
    processor.load_database()
    processor.load_scp_statements()
    
    print("Extracting labels...")
    processor.extract_labels()
    
    print("Loading signals...")
    signals, labels = processor.load_signals(limit=1000)  # Limit for demo
    
    print("Extracting DWT features...")
    dwt_features = processor.extract_dwt_features()
    
    # Prepare data for training
    X_train_ecg, X_test_ecg, X_train_dwt, X_test_dwt, y_train, y_test = train_test_split(
        signals, dwt_features, labels, test_size=0.2, random_state=42
    )
    
    X_train_ecg, X_val_ecg, X_train_dwt, X_val_dwt, y_train, y_val = train_test_split(
        X_train_ecg, X_train_dwt, y_train, test_size=0.2, random_state=42
    )
    
    # Create datasets and dataloaders
    class ECGDWTDataset(Dataset):
        def __init__(self, ecg_signals, dwt_features, labels):
            self.ecg_signals = ecg_signals
            self.dwt_features = dwt_features
            self.labels = labels
        
        def __len__(self):
            return len(self.ecg_signals)
        
        def __getitem__(self, idx):
            return (torch.FloatTensor(self.ecg_signals[idx]),
                   torch.FloatTensor(self.dwt_features[idx]),
                   torch.FloatTensor(self.labels[idx]))
    
    train_dataset = ECGDWTDataset(X_train_ecg, X_train_dwt, y_train)
    val_dataset = ECGDWTDataset(X_val_ecg, X_val_dwt, y_val)
    test_dataset = ECGDWTDataset(X_test_ecg, X_test_dwt, y_test)
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    
    # Initialize model
    model = ECGClassificationModel(dwt_features=dwt_features.shape[1])
    
    print("Starting training...")
    train_losses, val_losses = train_model(model, train_loader, val_loader, NUM_EPOCHS, DEVICE)
    
    # Evaluate model
    print("Evaluating model...")
    predictions, true_labels = evaluate_model(model, test_loader, DEVICE)
    
    # Print results
    class_names = ['CD', 'HYP', 'MI', 'NORM', 'STTC']
    print("\nClassification Report:")
    print(classification_report(true_labels, predictions, target_names=class_names))
    
    # SHAP analysis
    print("\nPerforming SHAP analysis...")
    sample_data = np.concatenate([X_test_ecg[:10], X_test_dwt[:10]], axis=1)
    explain_with_shap(model, sample_data, DEVICE)
    
    print("Pipeline completed successfully!")

if __name__ == "__main__":
    main()

Using device: cpu
Loading database...
Loading database from: C:\Users\JANPALLY SUSHEEL\Desktop\new del pxlb\ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3\ptbxl_database.csv
Loaded 21799 ECG records
Loading SCP statements from: C:\Users\JANPALLY SUSHEEL\Desktop\new del pxlb\ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3\scp_statements.csv
Extracting labels...
Label distribution:
[4898. 2649. 5469. 9514. 5235.]
Loading signals...


Loading signals: 1000it [00:31, 32.16it/s]


Loaded 1000 signals with shape (1000, 12, 1000)
Extracting DWT features...
Extracting DWT features...


100%|██████████| 1000/1000 [00:07<00:00, 125.76it/s]


DWT features shape: (1000, 360)
Starting training...


Epoch 1: 100%|██████████| 20/20 [00:12<00:00,  1.57it/s]


Epoch 1: Train Loss: 0.5262, Val Loss: 0.5009


Epoch 2: 100%|██████████| 20/20 [00:13<00:00,  1.52it/s]


Epoch 2: Train Loss: 0.4356, Val Loss: 0.4132


Epoch 3: 100%|██████████| 20/20 [00:13<00:00,  1.47it/s]


Epoch 3: Train Loss: 0.4087, Val Loss: 0.3782


Epoch 4: 100%|██████████| 20/20 [00:12<00:00,  1.62it/s]


Epoch 4: Train Loss: 0.3780, Val Loss: 0.3604


Epoch 5: 100%|██████████| 20/20 [00:12<00:00,  1.58it/s]


Epoch 5: Train Loss: 0.3831, Val Loss: 0.3831


Epoch 6: 100%|██████████| 20/20 [00:12<00:00,  1.66it/s]


Epoch 6: Train Loss: 0.3550, Val Loss: 0.4693


Epoch 7: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s]


Epoch 7: Train Loss: 0.3682, Val Loss: 0.4851


Epoch 8: 100%|██████████| 20/20 [00:11<00:00,  1.68it/s]


Epoch 8: Train Loss: 0.3347, Val Loss: 0.4108


Epoch 9: 100%|██████████| 20/20 [00:14<00:00,  1.35it/s]


Epoch 9: Train Loss: 0.3109, Val Loss: 0.4048


Epoch 10: 100%|██████████| 20/20 [00:10<00:00,  1.87it/s]


Epoch 10: Train Loss: 0.3120, Val Loss: 0.4118


Epoch 11: 100%|██████████| 20/20 [00:11<00:00,  1.68it/s]


Epoch 11: Train Loss: 0.2921, Val Loss: 0.4464


Epoch 12: 100%|██████████| 20/20 [00:12<00:00,  1.54it/s]


Epoch 12: Train Loss: 0.2971, Val Loss: 0.4017


Epoch 13: 100%|██████████| 20/20 [00:12<00:00,  1.66it/s]


Epoch 13: Train Loss: 0.2831, Val Loss: 0.4029


Epoch 14: 100%|██████████| 20/20 [00:13<00:00,  1.51it/s]


Epoch 14: Train Loss: 0.2589, Val Loss: 0.4310


Epoch 15: 100%|██████████| 20/20 [00:10<00:00,  1.91it/s]


Epoch 15: Train Loss: 0.2434, Val Loss: 0.4993


Epoch 16: 100%|██████████| 20/20 [00:12<00:00,  1.54it/s]


Epoch 16: Train Loss: 0.2489, Val Loss: 0.4925


Epoch 17: 100%|██████████| 20/20 [00:11<00:00,  1.67it/s]


Epoch 17: Train Loss: 0.2563, Val Loss: 0.5697


Epoch 18: 100%|██████████| 20/20 [00:11<00:00,  1.74it/s]


Epoch 18: Train Loss: 0.2390, Val Loss: 0.4840


Epoch 19: 100%|██████████| 20/20 [00:10<00:00,  1.87it/s]


Epoch 19: Train Loss: 0.2204, Val Loss: 0.7258


Epoch 20: 100%|██████████| 20/20 [00:10<00:00,  1.84it/s]


Epoch 20: Train Loss: 0.2195, Val Loss: 0.5306


Epoch 21: 100%|██████████| 20/20 [00:15<00:00,  1.32it/s]


Epoch 21: Train Loss: 0.2159, Val Loss: 0.5360


Epoch 22: 100%|██████████| 20/20 [00:13<00:00,  1.48it/s]


Epoch 22: Train Loss: 0.2061, Val Loss: 0.5623


Epoch 23: 100%|██████████| 20/20 [00:10<00:00,  1.88it/s]


Epoch 23: Train Loss: 0.2102, Val Loss: 0.6649


Epoch 24: 100%|██████████| 20/20 [00:10<00:00,  1.98it/s]


Epoch 24: Train Loss: 0.1945, Val Loss: 0.6460


Epoch 25: 100%|██████████| 20/20 [00:10<00:00,  1.93it/s]


Epoch 25: Train Loss: 0.1970, Val Loss: 0.6458


Epoch 26: 100%|██████████| 20/20 [00:11<00:00,  1.80it/s]


Epoch 26: Train Loss: 0.1903, Val Loss: 0.7052


Epoch 27: 100%|██████████| 20/20 [00:10<00:00,  1.99it/s]


Epoch 27: Train Loss: 0.1942, Val Loss: 0.6736


Epoch 28: 100%|██████████| 20/20 [00:10<00:00,  1.92it/s]


Epoch 28: Train Loss: 0.1940, Val Loss: 0.5816


Epoch 29: 100%|██████████| 20/20 [00:10<00:00,  1.82it/s]


Epoch 29: Train Loss: 0.1886, Val Loss: 0.7073


Epoch 30: 100%|██████████| 20/20 [00:10<00:00,  1.89it/s]


Epoch 30: Train Loss: 0.1819, Val Loss: 0.7128


Epoch 31: 100%|██████████| 20/20 [00:10<00:00,  1.91it/s]


Epoch 31: Train Loss: 0.1803, Val Loss: 0.7034


Epoch 32: 100%|██████████| 20/20 [00:10<00:00,  1.91it/s]


Epoch 32: Train Loss: 0.1774, Val Loss: 0.7115


Epoch 33: 100%|██████████| 20/20 [00:10<00:00,  1.95it/s]


Epoch 33: Train Loss: 0.1739, Val Loss: 0.7294


Epoch 34: 100%|██████████| 20/20 [00:10<00:00,  1.95it/s]


Epoch 34: Train Loss: 0.1755, Val Loss: 0.7247


Epoch 35: 100%|██████████| 20/20 [00:10<00:00,  1.92it/s]


Epoch 35: Train Loss: 0.1760, Val Loss: 0.7190


Epoch 36: 100%|██████████| 20/20 [00:10<00:00,  1.92it/s]


Epoch 36: Train Loss: 0.1689, Val Loss: 0.7160


Epoch 37: 100%|██████████| 20/20 [00:13<00:00,  1.52it/s]


Epoch 37: Train Loss: 0.1731, Val Loss: 0.7194


Epoch 38: 100%|██████████| 20/20 [00:11<00:00,  1.71it/s]


Epoch 38: Train Loss: 0.1717, Val Loss: 0.7255


Epoch 39: 100%|██████████| 20/20 [00:09<00:00,  2.01it/s]


Epoch 39: Train Loss: 0.1628, Val Loss: 0.7289


Epoch 40: 100%|██████████| 20/20 [00:10<00:00,  1.93it/s]


Epoch 40: Train Loss: 0.1723, Val Loss: 0.7456


Epoch 41: 100%|██████████| 20/20 [00:10<00:00,  1.88it/s]


Epoch 41: Train Loss: 0.1667, Val Loss: 0.7351


Epoch 42: 100%|██████████| 20/20 [00:12<00:00,  1.56it/s]


Epoch 42: Train Loss: 0.1667, Val Loss: 0.7377


Epoch 43: 100%|██████████| 20/20 [00:10<00:00,  1.83it/s]


Epoch 43: Train Loss: 0.1642, Val Loss: 0.7517


Epoch 44: 100%|██████████| 20/20 [00:09<00:00,  2.01it/s]


Epoch 44: Train Loss: 0.1674, Val Loss: 0.7605


Epoch 45: 100%|██████████| 20/20 [00:10<00:00,  1.91it/s]


Epoch 45: Train Loss: 0.1678, Val Loss: 0.7633


Epoch 46: 100%|██████████| 20/20 [00:10<00:00,  1.88it/s]


Epoch 46: Train Loss: 0.1680, Val Loss: 0.7659


Epoch 47: 100%|██████████| 20/20 [00:10<00:00,  1.97it/s]


Epoch 47: Train Loss: 0.1687, Val Loss: 0.7567


Epoch 48: 100%|██████████| 20/20 [00:10<00:00,  1.92it/s]


Epoch 48: Train Loss: 0.1643, Val Loss: 0.7630


Epoch 49: 100%|██████████| 20/20 [00:10<00:00,  1.93it/s]


Epoch 49: Train Loss: 0.1612, Val Loss: 0.7639


Epoch 50: 100%|██████████| 20/20 [00:10<00:00,  1.82it/s]


Epoch 50: Train Loss: 0.1649, Val Loss: 0.7692
Evaluating model...

Classification Report:
              precision    recall  f1-score   support

          CD       0.39      0.28      0.33        39
         HYP       0.00      0.00      0.00        16
          MI       0.43      0.34      0.38        35
        NORM       0.84      0.79      0.82       115
        STTC       0.56      0.71      0.62        42

   micro avg       0.66      0.58      0.62       247
   macro avg       0.44      0.43      0.43       247
weighted avg       0.61      0.58      0.59       247
 samples avg       0.66      0.61      0.63       247


Performing SHAP analysis...


ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 3 dimension(s) and the array at index 1 has 2 dimension(s)