In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.metrics import mean_squared_error
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from models import WLDNN
import random
import time

## Set Seed to Ensure Reproducibility 

In [2]:
def set_random_seed(seed):
    # Set random seed for CPU
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Set random seed for GPU if available
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    
    print(f"Random seed set to {seed}")

seed = 2023  # You can replace this with your desired seed value
set_random_seed(seed)

Random seed set to 2023


## Data Preparation
<!-- * Training
    * sEMG subjects: 1-24
    * sEMG channels: 1, 2
    * ECG subjects: [16265, 16272, 16273, 16420, 16483, 16539, 16773, 16786, 16795, 17052, 17453, 18177] (9F+3M)
    * ECG channel: 1
    * SNR: [0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10]
* Validation
    * sEMG subjects: 25-32
    * sEMG channels: 5, 6
    * ECG subjects: [18184, 19088, 19090, 19093, 19140, 19830] (4F+2M)
    * ECG channel: 1
    * SNR: [0, -0.5, -1, -1.5, -2, -2.5, -3, -3.5, -4, -4.5, -5, -5.5, -6, -6.5, -7, -7.5, -8, -8.5, -9, -9.5, -10]
* Testing
    * sEMG subjects: 33-40
    * sEMG channels: 9, 10
    * ECG subjects: [18184, 19088, 19090, 19093, 19140, 19830] (4F+2M)
    * ECG channel: 2
    * SNR: [0, -0.5, -1, -1.5, -2, -2.5, -3, -3.5, -4, -4.5, -5, -5.5, -6, -6.5, -7, -7.5, -8, -8.5, -9, -9.5, -10] -->

In [3]:
WINDOW_SIZE = 200

def calculate_waveform_length_segment(segment):
    # Calculate the waveform length as the sum of the absolute differences between adjacent samples in the segment.
    return np.sum(np.abs(np.diff(segment)))

def calculate_waveform_length(signal, window_size=WINDOW_SIZE, overlap=0):
    # Calculate the number of segments based on segment size and overlap.
    signal_length = len(signal)
    step_size = int(window_size * (1 - overlap))
    num_segments = (signal_length - window_size) // step_size + 1

    segment_waveform_lengths = []
    for i in range(num_segments):
        start = i * step_size
        end = start + window_size
        segment = signal[start:end]
        
        # Calculate the waveform length for the current segment and store it in the result array.
        segment_length = calculate_waveform_length_segment(segment)
        segment_waveform_lengths.append(segment_length)
    
    return np.array(segment_waveform_lengths)

def mean_variance_normalize(time_series):
    # Calculate the mean and standard deviation of the time series.
    mean = np.mean(time_series)
    std_dev = np.std(time_series)
    
    # Ensure there's no division by zero.
    if std_dev == 0:
        raise ValueError("Cannot normalize: Standard deviation is zero.")
    
    # Normalize the time series to have mean 0 and standard deviation 1.
    normalized_series = (time_series - mean) / std_dev
    
    return normalized_series

def normalize_to_unit_energy(signal):
    # Calculate the energy of the signal as the sum of the squares of its samples.
    energy = np.sum(np.abs(signal)**2)
    
    # Ensure there's no division by zero.
    if energy == 0:
        raise ValueError("Cannot normalize: Energy is zero.")
    
    # Normalize the signal to have unit energy by dividing by the square root of the energy.
    normalized_signal = signal / np.sqrt(energy)
    
    return normalized_signal

def calculate_overall_mean_std(array):
    # Calculate the mean and standard deviation of the input array.
    overall_mean = np.mean(array)
    overall_std = np.std(array)
    
    # Print the calculated values for reference.
    print(f"train_mean = {overall_mean}")
    print(f"train_std = {overall_std}")
    
    return overall_mean, overall_std

def normalize_dataset(data, mean, std):
    # Normalize the dataset by subtracting the mean and dividing by the standard deviation.
    normalized_data = (data - mean) / std
    return normalized_data

# Example usage:
test_sig = np.random.rand(10000,)
print(test_sig.shape)
test_wl = calculate_waveform_length(test_sig)
print(test_wl.shape)
input_length = test_wl.shape[0]

(10000,)
(50,)


In [4]:
# Paths to annotation CSV files
train_csv_path = "train_annotations_E1.csv"
val_csv_path = "val_annotations_E1.csv"

# Folder containing .npy files
data_folder = "./mixed_signals_E1"

# Read the CSV files into pandas DataFrames
train_df = pd.read_csv(train_csv_path)
val_df = pd.read_csv(val_csv_path)

# Initialize lists to store data and labels
train_data = []
train_labels = []
val_data = []
val_labels = []

# Read .npy files and their corresponding SNRs
for _, row in train_df.iterrows():
    file_name = row['mixed_name']
    snr = np.array([row['snr']])
    npy_file_path = os.path.join(data_folder, file_name)
    data = np.load(npy_file_path)
    data = mean_variance_normalize(data)
    data = normalize_to_unit_energy(data)
    data = calculate_waveform_length(data)
    train_data.append(data)
    train_labels.append(snr)
    
for _, row in val_df.iterrows():
    file_name = row['mixed_name']
    snr = np.array([row['snr']])
    npy_file_path = os.path.join(data_folder, file_name)
    data = np.load(npy_file_path)
    data = mean_variance_normalize(data)
    data = normalize_to_unit_energy(data)
    data = calculate_waveform_length(data)
    val_data.append(data)
    val_labels.append(snr)

# Convert lists to NumPy arrays
train_data = np.array(train_data)
train_labels = np.array(train_labels)
val_data = np.array(val_data)
val_labels = np.array(val_labels)

# Normalize data
train_mean, train_std = calculate_overall_mean_std(train_data)
normalize_dataset(train_data, train_mean, train_std)
normalize_dataset(val_data, train_mean, train_std)

# Print the shapes of the loaded data
print("Train data shape:", train_data.shape)
print("Validation data shape:", val_data.shape)

train_mean = 0.4347099520974417
train_std = 0.3951968331310125
Train data shape: (137792, 50)
Validation data shape: (43710, 50)


In [5]:
device = 'cuda:6'
train_data = torch.Tensor(train_data).to(device)
train_labels = torch.Tensor(train_labels).to(device)
train_dataset = TensorDataset(train_data, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

val_data = torch.Tensor(val_data).to(device)
val_labels = torch.Tensor(val_labels).to(device)
val_dataset = TensorDataset(val_data, val_labels)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)

## 3. Model, Optimizer, and Loss Function

In [6]:
criterion = nn.MSELoss()

model = WLDNN(input_length)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print(model)
print("Number of trainable parameters: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
model_name = model.__class__.__name__

WLDNN(
  (hidden1): Linear(in_features=50, out_features=20, bias=True)
  (hidden2): Linear(in_features=20, out_features=20, bias=True)
  (output): Linear(in_features=20, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)
Number of trainable parameters:  1461


## 4. Training Loop

In [None]:
num_epochs = 100
save_every = 10  # Save the model every 10 epochs
steps_per_print = 1000  # Print loss every 1000 steps

# Initialize a list to store the losses
train_losses = []
val_losses = []  # Initialize a list to store validation losses

start = time.time()

for epoch in range(num_epochs):
    model.train()
    step_counter = 0  # Initialize step counter for the epoch
    epoch_loss = 0.0  # Initialize the epoch loss

    for batch_data, batch_labels in train_dataloader:
        optimizer.zero_grad()
        predictions = model(batch_data)
        loss = criterion(predictions, batch_labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()  # Accumulate the loss for the epoch

        # Increment step counter
        step_counter += 1

        # Print loss every `steps_per_print` steps
        if step_counter % steps_per_print == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{step_counter}/{len(train_dataloader)}], Loss: {loss.item():.4f}")

    # Calculate and record the mean loss for the epoch
    mean_epoch_loss = epoch_loss / len(train_dataloader)
    train_losses.append(mean_epoch_loss)

    # Validation loop
    model.eval()  # Set the model to evaluation mode
    val_loss = 0.0

    with torch.no_grad():
        for val_data, val_labels in val_dataloader:
            val_predictions = model(val_data)
            val_loss += criterion(val_predictions, val_labels).item()

    mean_val_loss = val_loss / len(val_dataloader)
    val_losses.append(mean_val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Mean Epoch Loss: {mean_epoch_loss:.4f}, Mean Validation Loss: {mean_val_loss:.4f}")
    print("==================================================")

    # Save the model every `save_every` epochs
    if (epoch + 1) % save_every == 0:
        model_checkpoint_path = f'./checkpoints/{model_name.lower()}_window{WINDOW_SIZE}_{epoch+1}eps_seed{seed}.pth'
        torch.save(model.state_dict(), model_checkpoint_path)
        print(f"Model saved at epoch {epoch+1}: {model_checkpoint_path}")

print("Training completed")
torch.save(model.state_dict(), f'./checkpoints/{model_name.lower()}_window{WINDOW_SIZE}_{num_epochs}eps_seed{seed}.pth')
end = time.time()

Epoch [1/100], Step [1000/4306], Loss: 18.3796
Epoch [1/100], Step [2000/4306], Loss: 20.7903
Epoch [1/100], Step [3000/4306], Loss: 28.8728
Epoch [1/100], Step [4000/4306], Loss: 24.1759
Epoch [1/100], Mean Epoch Loss: 22.8662, Mean Validation Loss: 20.0790
Epoch [2/100], Step [1000/4306], Loss: 17.3417
Epoch [2/100], Step [2000/4306], Loss: 21.8843
Epoch [2/100], Step [3000/4306], Loss: 21.2093
Epoch [2/100], Step [4000/4306], Loss: 25.1089
Epoch [2/100], Mean Epoch Loss: 21.2593, Mean Validation Loss: 20.0548
Epoch [3/100], Step [1000/4306], Loss: 23.0457
Epoch [3/100], Step [2000/4306], Loss: 23.6928
Epoch [3/100], Step [3000/4306], Loss: 23.1222
Epoch [3/100], Step [4000/4306], Loss: 16.9810
Epoch [3/100], Mean Epoch Loss: 21.2585, Mean Validation Loss: 20.0545
Epoch [4/100], Step [1000/4306], Loss: 10.2155
Epoch [4/100], Step [2000/4306], Loss: 6.3043
Epoch [4/100], Step [3000/4306], Loss: 3.8619
Epoch [4/100], Step [4000/4306], Loss: 6.8168
Epoch [4/100], Mean Epoch Loss: 9.2250

Epoch [27/100], Step [2000/4306], Loss: 1.7433
Epoch [27/100], Step [3000/4306], Loss: 3.1240
Epoch [27/100], Step [4000/4306], Loss: 1.9466
Epoch [27/100], Mean Epoch Loss: 2.6236, Mean Validation Loss: 2.1168
Epoch [28/100], Step [1000/4306], Loss: 3.3463
Epoch [28/100], Step [2000/4306], Loss: 3.2562
Epoch [28/100], Step [3000/4306], Loss: 2.4466
Epoch [28/100], Step [4000/4306], Loss: 2.2001
Epoch [28/100], Mean Epoch Loss: 2.6135, Mean Validation Loss: 1.9067
Epoch [29/100], Step [1000/4306], Loss: 2.3397
Epoch [29/100], Step [2000/4306], Loss: 2.6693
Epoch [29/100], Step [3000/4306], Loss: 2.4804
Epoch [29/100], Step [4000/4306], Loss: 2.5065
Epoch [29/100], Mean Epoch Loss: 2.6054, Mean Validation Loss: 1.8339
Epoch [30/100], Step [1000/4306], Loss: 2.5953
Epoch [30/100], Step [2000/4306], Loss: 2.9441
Epoch [30/100], Step [3000/4306], Loss: 2.3881
Epoch [30/100], Step [4000/4306], Loss: 2.4600
Epoch [30/100], Mean Epoch Loss: 2.5997, Mean Validation Loss: 2.0146
Model saved at 

Epoch [53/100], Step [1000/4306], Loss: 1.5234
Epoch [53/100], Step [2000/4306], Loss: 3.4914
Epoch [53/100], Step [3000/4306], Loss: 2.5504
Epoch [53/100], Step [4000/4306], Loss: 1.9714
Epoch [53/100], Mean Epoch Loss: 2.4906, Mean Validation Loss: 1.8653
Epoch [54/100], Step [1000/4306], Loss: 3.5359
Epoch [54/100], Step [2000/4306], Loss: 1.9585
Epoch [54/100], Step [3000/4306], Loss: 3.0480
Epoch [54/100], Step [4000/4306], Loss: 1.6142
Epoch [54/100], Mean Epoch Loss: 2.4907, Mean Validation Loss: 1.7494
Epoch [55/100], Step [1000/4306], Loss: 2.1549
Epoch [55/100], Step [2000/4306], Loss: 1.7289
Epoch [55/100], Step [3000/4306], Loss: 1.8454
Epoch [55/100], Step [4000/4306], Loss: 5.0534
Epoch [55/100], Mean Epoch Loss: 2.4874, Mean Validation Loss: 2.0478
Epoch [56/100], Step [1000/4306], Loss: 1.2757
Epoch [56/100], Step [2000/4306], Loss: 4.7137
Epoch [56/100], Step [3000/4306], Loss: 2.8795
Epoch [56/100], Step [4000/4306], Loss: 2.0831
Epoch [56/100], Mean Epoch Loss: 2.485

Epoch [79/100], Step [2000/4306], Loss: 3.3088
Epoch [79/100], Step [3000/4306], Loss: 1.7855


In [None]:
# Calculate the training time in seconds
training_time_seconds = end - start
# Convert training time to minutes
training_time_minutes = training_time_seconds // 60

# Print the training time in minutes
print(f"Training time: {training_time_minutes} minutes")

In [None]:
# Plot training and validation loss
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss (MSE)')
plt.title('Loss Curve')
plt.legend()  # Show legend indicating the lines
plt.show()

## Validation Loop

In [None]:
# Set model to evaluation mode
model.eval()

# Initialize lists to store predictions and labels
val_predictions = []
val_labels = []

# Iterate through validation data
with torch.no_grad():
    for batch_data, batch_labels in val_dataloader:
        predictions = model(batch_data)
        val_predictions.append(predictions.cpu())
        val_labels.append(batch_labels.cpu().numpy())

In [None]:
# Concatenate predictions and labels
val_predictions_np = np.concatenate(val_predictions, axis=0)
val_labels_np = np.concatenate(val_labels, axis=0)

# Reshape predictions and labels
val_predictions_flat = val_predictions_np.reshape(-1)
val_labels_flat = val_labels_np.reshape(-1)

# Calculate Mean Squared Error (MSE)
mse = mean_squared_error(val_labels_flat, val_predictions_flat)

# Calculate correlation coefficients
correlation_matrix = np.corrcoef(val_labels_flat, val_predictions_flat)
correlation_coefficient = correlation_matrix[0, 1]
spearmanr_cc, _ = stats.spearmanr(val_labels_flat, val_predictions_flat)

# Plot predictions against true labels
plt.scatter(val_labels_flat, val_predictions_flat, alpha=0.5, s=10, color='none', edgecolors='black')
plt.xlabel('True SNR')
plt.ylabel('Estimated SNR')
plt.title(f'CC = {correlation_coefficient:.4f}')

# Set x-axis and y-axis limits
plt.xlim(-17.5, 2.5)
plt.ylim(-17.5, 2.5)

# Perform linear regression
m, b = np.polyfit(val_labels_flat, val_predictions_flat, 1)
plt.plot(val_labels_flat, m * val_labels_flat + b, label=f"y = {m:.2f}x + {b:.2f}")
plt.plot(val_labels_flat, val_labels_flat, label="y = x")
plt.legend(loc='upper left')
plt.show()

# Print the results
print(f'Correlation coefficient (CC): {correlation_coefficient:.4f}')
print(f'Spearman rank-order correlation coefficient (SRCC): {spearmanr_cc:.4f}')
print(f'Mean Squared Error (MSE): {mse:.4f}')

In [None]:
# Calculate the square error for each element
error = (val_labels_flat - val_predictions_flat)**2

# Calculate mean and standard deviation of the error
mean_error = np.mean(error)
std_deviation_error = np.std(error)

# Plot the error distribution as a histogram
hist, bins, _ = plt.hist(error, bins=10, alpha=0.7, color='blue', edgecolor='black', range=(0, 5))
plt.xlabel('Squared Error')
plt.ylabel('Frequency')
plt.title('Squared Error Distribution')
plt.grid(True)

# Annotate each bar with the frequency
for i in range(len(hist)):
    plt.text(bins[i] + (bins[i+1] - bins[i]) / 2, hist[i], f'{int(hist[i])}', ha='center', va='bottom')

plt.show()

print("Mean Error:", mean_error)
print("Standard Deviation of Error:", std_deviation_error)

In [None]:
# Calculate the absolute error for each element
error = np.abs(val_labels_flat - val_predictions_flat)

# Calculate mean and standard deviation of the error
mean_error = np.mean(error)
std_deviation_error = np.std(error)

# Plot the error distribution as a histogram
hist, bins, _ = plt.hist(error, bins=10, alpha=0.7, color='blue', edgecolor='black', range=(0, 5))
plt.xlabel('Absolute Error')
plt.ylabel('Frequency')
plt.title('Squared Error')
plt.grid(True)

# Annotate each bar with the frequency
for i in range(len(hist)):
    plt.text(bins[i] + (bins[i+1] - bins[i]) / 2, hist[i], f'{int(hist[i])}', ha='center', va='bottom')

plt.show()

print("Mean Absolute Error (MAE):", mean_error)
print("Standard Deviation of Absolute Error:", std_deviation_error)