In [None]:
import os

import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

# Directory holding the saved tensors; "" means the current working directory.
data_dir = ""

# Load data. os.path.join yields "test_z100.pth" for an empty data_dir, whereas
# the f"{data_dir}/..." form produced "/test_z100.pth" (the filesystem root).
# map_location="cpu" lets tensors saved from a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load files you trust.
z100 = torch.load(os.path.join(data_dir, "test_z100.pth"), map_location="cpu")  # expected [N, C, L]
z50 = torch.load(os.path.join(data_dir, "test_z50.pth"), map_location="cpu")    # expected [N, C, L]

# Convert to numpy arrays. detach()/cpu() make .numpy() safe for tensors that
# require grad or were not already on the CPU.
z100 = z100.detach().cpu().numpy()
z50 = z50.detach().cpu().numpy()

# The metrics compare element-wise, and the plotting cell unpacks a 3-D shape.
# Fail loudly here rather than silently broadcasting mismatched arrays.
assert z100.shape == z50.shape, f"shape mismatch: z100 {z100.shape} vs z50 {z50.shape}"
assert z100.ndim == 3, f"expected [N, C, L] arrays, got shape {z100.shape}"

# Evaluation metrics functions
def calculate_mse(x, y):
    """Mean squared error between ``x`` and ``y``.

    The squared residual is averaged over every element of the inputs, so
    all axes are pooled into a single scalar.
    """
    residual = x - y
    return np.mean(residual * residual)

def calculate_nmse(x, y):
    """Normalized mean squared error of ``x`` relative to the reference ``y``.

    Computed as sum((x - y)**2) / sum(y**2): the residual energy divided by
    the energy of ``y``, so the argument order matters. Unlike
    ``calculate_snr`` there is no epsilon guard, so an all-zero ``y`` gives
    a zero denominator (inf/nan from numpy).
    """
    residual = x - y
    residual_energy = np.sum(residual * residual)
    reference_energy = np.sum(y * y)
    return residual_energy / reference_energy

def calculate_pcc(original, reconstructed):
    """Pearson correlation coefficient between two arrays, pooled over all elements.

    Both inputs are flattened to 1-D first, so every sample, channel and time
    point contributes to a single correlation value. The result is the
    off-diagonal entry of the 2x2 correlation matrix from ``np.corrcoef``.
    """
    correlation_matrix = np.corrcoef(original.ravel(), reconstructed.ravel())
    return correlation_matrix[0, 1]

def calculate_snr(original, reconstructed):
    """Signal-to-noise ratio in decibels, with ``original`` as the reference signal.

    Signal power is the mean square of ``original``. Noise power is the mean
    square of the residual ``original - reconstructed``. A fixed 1e-8 is added
    to the noise power to avoid dividing by zero when the residual is exactly
    zero.
    """
    residual = original - reconstructed
    power_ratio = np.mean(original ** 2) / (np.mean(residual ** 2) + 1e-8)
    return 10 * np.log10(power_ratio)

# Overall metrics. z100 is the reference: it is the normalizer for NMSE and
# the signal for SNR.
print("Comparison between z50 and z100:")
for metric_label, metric_fn, metric_args, metric_suffix in (
    ("MSE  :", calculate_mse, (z50, z100), ()),
    ("NMSE :", calculate_nmse, (z50, z100), ()),
    ("PCC  :", calculate_pcc, (z100, z50), ()),
    ("SNR  :", calculate_snr, (z100, z50), ("dB",)),
):
    # Each metric is computed right before its line is printed, in table order.
    print(metric_label, metric_fn(*metric_args), *metric_suffix)

In [None]:
# Visualize every channel of one sample, overlaying z50 on the z100 reference.
N, C, L = z100.shape  # dimensions: samples, channels, time points
sample_index = 0  # which sample to plot; change to inspect a different one

# One figure per channel, using the explicit Figure/Axes interface.
for channel_index in range(C):
    fig, ax = plt.subplots(figsize=(14, 3))

    # Plot z100 and z50 values for this channel of the chosen sample
    ax.plot(z100[sample_index, channel_index], label='z100 (original)')
    ax.plot(z50[sample_index, channel_index], label='z50 (masked)')

    ax.set(
        title=f'Channel {channel_index} - Sample {sample_index}',
        xlabel='Time Points',
        ylabel='Signal Value',
    )
    ax.legend()

    fig.tight_layout()  # Improve spacing
    plt.show()  # Display the plot
    # Release the figure so a loop over many channels doesn't keep every
    # figure open in pyplot's registry.
    plt.close(fig)