In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from sklearn.metrics import mean_squared_error, mean_absolute_error
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from models import *
import random

# Evaluation on the Test Set

In [2]:
# Default number of samples per window for waveform-length feature extraction.
# The same value is interpolated into the checkpoint filename loaded later in this notebook.
WINDOW_SIZE = 200

def calculate_waveform_length_segment(segment):
    """Return the waveform length of a single segment.

    The waveform length is the total absolute change between consecutive
    samples, i.e. sum(|x[i+1] - x[i]|). A segment with fewer than two
    samples has no differences and therefore yields 0.
    """
    sample_steps = np.diff(segment)
    return np.abs(sample_steps).sum()

def calculate_waveform_length(signal, window_size=WINDOW_SIZE, overlap=0):
    """Compute the waveform length of each sliding window of a signal.

    Parameters
    ----------
    signal : sequence or np.ndarray
        Samples to analyse; it is measured with ``len()`` and sliced per window.
    window_size : int
        Number of samples in each window.
    overlap : float
        Fraction of a window shared with the next one (0 means back-to-back windows).

    Returns
    -------
    np.ndarray
        One waveform-length value per complete window. Trailing samples that do
        not fill a whole window are ignored; a signal shorter than one window
        gives an empty array.

    Raises
    ------
    ValueError
        If ``overlap`` leaves a step of less than one sample (e.g. overlap >= 1).
    """
    signal_length = len(signal)

    # Floating-point error can land just below the intended integer
    # (e.g. 10 * (1 - 0.9) == 0.9999999999999998); the small epsilon keeps
    # int() from truncating that to the wrong step.
    step_size = int(window_size * (1 - overlap) + 1e-9)
    if step_size < 1:
        raise ValueError(
            f"overlap={overlap} with window_size={window_size} gives a step of less than one sample."
        )

    num_segments = (signal_length - window_size) // step_size + 1

    # range() is empty when num_segments <= 0, so short signals produce an empty result.
    segment_waveform_lengths = [
        calculate_waveform_length_segment(signal[index * step_size:index * step_size + window_size])
        for index in range(num_segments)
    ]
    return np.array(segment_waveform_lengths)

def mean_variance_normalize(time_series):
    """Standardise a time series to zero mean and unit standard deviation.

    Raises:
        ValueError: if the standard deviation of the input is exactly zero,
            since the division would be undefined.
    """
    mu = np.mean(time_series)
    sigma = np.std(time_series)

    # Guard clause: a constant series cannot be scaled to unit variance.
    if sigma == 0:
        raise ValueError("Cannot normalize: Standard deviation is zero.")

    return (time_series - mu) / sigma

def normalize_to_unit_energy(signal):
    """Scale a signal so that the sum of its squared magnitudes equals one.

    Raises:
        ValueError: if the signal's energy is exactly zero (nothing to scale).
    """
    # Energy = sum of squared sample magnitudes.
    energy = (np.abs(signal) ** 2).sum()

    if energy == 0:
        raise ValueError("Cannot normalize: Energy is zero.")

    # Dividing by sqrt(energy) makes the resulting energy equal to 1.
    return signal / np.sqrt(energy)

def normalize_dataset(data, mean, std):
    """Standardise ``data`` with externally supplied statistics.

    Computes ``(data - mean) / std`` and returns the result; nothing is
    assigned back to ``data`` here, so callers must use the return value.
    """
    centered = data - mean
    return centered / std

# Example usage: extract waveform-length features from a random 10000-sample signal.
test_sig = np.random.rand(10000)
test_wl = calculate_waveform_length(test_sig)
print(test_sig.shape)
print(test_wl.shape)
# Number of windows produced for a signal of this length; passed to the model constructor later.
input_length = len(test_wl)

(10000,)
(50,)


In [3]:
# Path to the annotation CSV file
test_csv_path = "./test_annotations_E1.csv"

# Folder containing .npy files
data_folder = "./mixed_signals_E1"

# Read the annotations into a pandas DataFrame
test_df = pd.read_csv(test_csv_path)

# Per-file feature vectors and their SNR labels, in CSV row order
test_data = []
test_labels = []

# Load every mixed signal, apply the per-signal preprocessing chain, and
# extract its waveform-length features. Iterating the two columns directly
# avoids building a Series per row as iterrows() does.
for file_name, snr_value in zip(test_df['mixed_name'], test_df['snr']):
    # Each label is kept as a length-1 array so the stacked labels have shape (N, 1).
    snr = np.array([snr_value])
    npy_file_path = os.path.join(data_folder, file_name)
    data = np.load(npy_file_path)
    data = mean_variance_normalize(data)
    data = normalize_to_unit_energy(data)
    data = calculate_waveform_length(data)
    test_data.append(data)
    test_labels.append(snr)

# Convert lists to NumPy arrays
test_data = np.array(test_data)
test_labels = np.array(test_labels)

# Dataset-level normalization with fixed statistics
# (presumably computed on the training features, per the variable names).
train_mean = 0.4347099520974417
train_std = 0.3951968331310125
# normalize_dataset returns a new array; without the assignment the
# normalization was computed and then silently thrown away.
# NOTE(review): confirm the checkpoint was trained on features normalized this way.
test_data = normalize_dataset(test_data, train_mean, train_std)

# Print the shapes of the loaded data
print("Test data shape:", test_data.shape)

Test data shape: (43648, 50)


In [None]:
# Use the GPU this notebook was written for when it exists; otherwise fall back
# to the CPU so the evaluation can still run on other machines.
device = 'cuda:5' if torch.cuda.is_available() and torch.cuda.device_count() > 5 else 'cpu'

# torch.as_tensor with an explicit dtype yields float32 tensors, as the legacy
# torch.Tensor(...) constructor did, and also accepts an input that is already a tensor.
test_data = torch.as_tensor(test_data, dtype=torch.float32).to(device)
test_labels = torch.as_tensor(test_labels, dtype=torch.float32).to(device)

# Keep the original sample order (shuffle=False) so predictions line up with the annotation rows.
test_dataset = TensorDataset(test_data, test_labels)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [None]:
# WLDNN comes from the wildcard `models` import; it is constructed with the
# feature count computed in the example-usage cell.
model = WLDNN(input_length)
model = model.to(device)

print(model)

# map_location remaps the stored tensors onto `device`, so the checkpoint loads
# even if it was saved from a GPU that is not present on this machine.
checkpoint_path = f"./checkpoints/wldnn_window{WINDOW_SIZE}_30eps_seed2023.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
# Set model to evaluation mode
model.eval()

# Per-batch predictions and labels, both stored as NumPy arrays.
# NOTE: this rebinds `test_labels` (previously the label tensor) to a list;
# the dataloader still holds its own reference to the tensor.
test_predictions = []
test_labels = []

# Iterate through test data without tracking gradients
with torch.no_grad():
    for batch_data, batch_labels in test_dataloader:
        predictions = model(batch_data)
        # Convert both outputs the same way so the later concatenation does not
        # depend on implicit tensor-to-array conversion.
        test_predictions.append(predictions.cpu().numpy())
        test_labels.append(batch_labels.cpu().numpy())

In [None]:
# Concatenate the per-batch predictions and labels
test_predictions_np = np.concatenate(test_predictions, axis=0)
test_labels_np = np.concatenate(test_labels, axis=0)

# Flatten to 1-D vectors for the metric functions
test_predictions_flat = test_predictions_np.reshape(-1)
test_labels_flat = test_labels_np.reshape(-1)

# Save results; create the output folder first so np.save cannot fail on a fresh checkout.
results_dir = './test_results'
os.makedirs(results_dir, exist_ok=True)
model_tag = model.__class__.__name__.lower()
np.save(os.path.join(results_dir, f'y_pred_{model_tag}.npy'), test_predictions_flat)
np.save(os.path.join(results_dir, f'y_true_{model_tag}.npy'), test_labels_flat)

# Calculate Mean Squared Error (MSE)
mse = mean_squared_error(test_labels_flat, test_predictions_flat)

# Calculate Mean Absolute Error (MAE)
mae = mean_absolute_error(test_labels_flat, test_predictions_flat)

# Calculate correlation coefficients (Pearson and Spearman)
correlation_matrix = np.corrcoef(test_labels_flat, test_predictions_flat)
correlation_coefficient = correlation_matrix[0, 1]
spearmanr_cc, _ = stats.spearmanr(test_labels_flat, test_predictions_flat)

# Plot predictions against true labels
plt.scatter(test_labels_flat, test_predictions_flat, alpha=0.05, s=1, color='black')
plt.xlabel('True SNR(dB)')
plt.ylabel('Estimated SNR(dB)')
plt.title(f'CC = {correlation_coefficient:.4f}')

# Set x-axis and y-axis limits
plt.xlim(-17.5, 2.5)
plt.ylim(-17.5, 2.5)

# Perform linear regression of the predictions on the true labels
m, b = np.polyfit(test_labels_flat, test_predictions_flat, 1)
# A straight line only needs its two end points; this draws the same line as
# plotting every (unsorted) sample without tens of thousands of segments.
line_x = np.array([test_labels_flat.min(), test_labels_flat.max()])
plt.plot(line_x, m * line_x + b, label=f"y = {m:.2f}x + {b:.2f}")
plt.plot(line_x, line_x, label="y = x")
plt.legend(loc='upper left')
plt.show()

# Print the results
print(f'Correlation coefficient (CC): {correlation_coefficient:.4f}')
print(f"Spearman's rank correlation coefficient (SRCC): {spearmanr_cc:.4f}")
print(f'Mean Squared Error (MSE): {mse:.4f}')
print(f'Mean Absolute Error (MAE): {mae:.4f}')

In [None]:
# Calculate the squared error for each element
error = (test_labels_flat - test_predictions_flat)**2

# Mean and standard deviation of the squared error (the mean is the MSE)
mean_error = np.mean(error)
std_deviation_error = np.std(error)

# Plot the error distribution as a histogram.
# Only errors inside range=(0, 5) are binned; larger errors are not shown.
hist, bins, _ = plt.hist(error, bins=10, alpha=0.7, color='blue', edgecolor='black', range=(0, 5))
plt.xlabel('Squared Error')
plt.ylabel('Frequency')
plt.title('Squared Error Distribution')
plt.grid(True)

# Annotate each bar with its count, centred above the bar
for count, left_edge, right_edge in zip(hist, bins[:-1], bins[1:]):
    plt.text(left_edge + (right_edge - left_edge) / 2, count, f'{int(count)}', ha='center', va='bottom')

plt.show()

# Labels name the statistic explicitly, matching the absolute-error cell.
print("Mean Squared Error (MSE):", mean_error)
print("Standard Deviation of Squared Error:", std_deviation_error)

In [None]:
# Per-sample absolute error between the true and predicted values
error = np.abs(test_labels_flat - test_predictions_flat)

# Summary statistics of the absolute error
mean_error = np.mean(error)
std_deviation_error = np.std(error)

# Histogram of the error distribution, drawn through the explicit Axes interface.
# Only errors inside range=(0, 5) are binned.
fig, ax = plt.subplots()
hist, bins, _ = ax.hist(error, bins=10, alpha=0.7, color='blue', edgecolor='black', range=(0, 5))
ax.set_xlabel('Absolute Error')
ax.set_ylabel('Frequency')
ax.set_title('Absolute Error Distribution')
ax.grid(True)

# Write each bar's count just above the bar, centred between its bin edges
for count, left_edge, right_edge in zip(hist, bins[:-1], bins[1:]):
    ax.text(left_edge + (right_edge - left_edge) / 2, count, f'{int(count)}', ha='center', va='bottom')

plt.show()

print("Mean Absolute Error (MAE):", mean_error)
print("Standard Deviation of Absolute Error:", std_deviation_error)