In [1]:
import os
import numpy as np
from scipy import io
from pathlib import Path
from sklearn.model_selection import train_test_split

def _channel_label(chan):
    """Return the label of one ``chanlocs`` entry as a stripped string, or None if it has none."""
    field_names = getattr(getattr(chan, 'dtype', None), 'names', None)
    if field_names and 'labels' in field_names:
        # With struct_as_record=True, loadmat returns MATLAB structs as NumPy
        # structured records: their fields are read by key, not as attributes.
        raw = chan['labels']
    elif hasattr(chan, 'labels'):
        # Attribute-style struct objects.
        raw = chan.labels
    else:
        return None

    raw = np.asarray(raw)
    if raw.size != 1:
        # Empty or multi-valued label: nothing usable to compare against.
        return None
    return str(raw.item()).strip()


def _find_channel_index(chanlocs, target, default=0):
    """Return the position of the channel labelled ``target`` in ``chanlocs``, else ``default``."""
    # atleast_1d also covers a single channel that squeeze_me collapsed to a scalar record.
    for position, chan in enumerate(np.atleast_1d(chanlocs)):
        if _channel_label(chan) == target:
            return position
    return default


def _as_list(value):
    """Wrap a scalar annotation in a list; map None to an empty list."""
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def _to_sample_index(value):
    """Convert one annotation value to an integer sample index, or None if it is unusable."""
    if value is None:
        return None
    try:
        # JSON numbers may arrive as floats (e.g. 120.0), which cannot be used as slice bounds.
        return int(round(float(value)))
    except (TypeError, ValueError, OverflowError):
        # Non-numeric, NaN or infinite entries are treated like missing annotations.
        return None


def _build_label_mask(onsets, offsets, time_points):
    """Build a length-``time_points`` array of zeros with 1s on every [onset, offset] span (inclusive)."""
    mask = np.zeros(time_points)
    # zip stops at the shorter list, so only complete onset/offset pairs are used.
    for onset, offset in zip(onsets, offsets):
        start = _to_sample_index(onset)
        end = _to_sample_index(offset)
        if start is None or end is None:
            continue
        # Clip the span to the epoch; spans lying entirely outside it are ignored.
        start = max(0, start)
        end = min(time_points - 1, end)
        if start < time_points and end >= 0:
            mask[start:end + 1] = 1
    return mask


def load_emg_data(folder_path):
    """
    Load EMG data from .set files and create binary label sequences marking onset-offset regions
    from corresponding .json files with the same base name.

    Files are processed in sorted name order so the epoch order is reproducible.
    A .set file is skipped (with a printed warning) when its .json is missing, when it
    holds no usable 3-D ``data`` array, or when the JSON has no "ZM" entry. Any other
    error while reading a file is printed and that file is skipped.

    The signal is taken from the channel labelled 'ZM' in ``chanlocs``; if no such
    label is found, the first channel is used.

    Args:
        folder_path: Path to the folder containing both .set and .json files

    Returns:
        X: shape (n_epochs, time_points) - EMG signal data
        y: shape (n_epochs, time_points) - Binary labels with 1s between onset and offset
        Returns (None, None) when the folder contains no .set files. When .set files
        exist but none could be loaded, both arrays are empty.
    """
    import json

    # Sorted: glob order depends on the file system.
    set_files = sorted(Path(folder_path).glob('*.set'))

    if not set_files:
        print(f"No .set files found in {folder_path}")
        return None, None

    all_data = []
    all_labels = []

    for set_file in set_files:
        try:
            # Annotations live next to the recording, same base name, .json suffix.
            json_file = set_file.with_suffix('.json')
            if not json_file.exists():
                print(f"Warning: No matching JSON file for {set_file.name}")
                continue

            mat_data = io.loadmat(str(set_file), struct_as_record=True, squeeze_me=True)

            # A string here means 'data' is not an in-file array.
            if 'data' not in mat_data or isinstance(mat_data['data'], str):
                print(f"Warning: No valid data in {set_file.name}")
                continue

            data = np.asarray(mat_data['data'])  # expected: (channels, time_points, epochs)
            if data.ndim != 3:
                # squeeze_me drops singleton axes, so single-channel or single-epoch
                # files arrive with fewer dimensions and cannot be unpacked reliably.
                print(f"Warning: Expected 3-D data (channels, time_points, epochs) in "
                      f"{set_file.name}, got shape {data.shape}")
                continue
            _, time_points, n_epochs = data.shape

            # Locate the ZM channel; default to the first channel.
            zm_idx = 0
            if 'chanlocs' in mat_data:
                zm_idx = _find_channel_index(mat_data['chanlocs'], 'ZM', default=0)

            # utf-8-sig reads UTF-8 with or without a BOM, independent of the OS default encoding.
            with open(json_file, 'r', encoding='utf-8-sig') as f:
                json_data = json.load(f)

            if "ZM" not in json_data:
                print(f"Warning: No ZM data in {json_file.name}")
                continue

            zm_json = json_data["ZM"]
            onsets = _as_list(zm_json.get("onset", []))
            offsets = _as_list(zm_json.get("offset", []))

            # NOTE(review): every onset/offset pair of the file is applied to every
            # epoch of that file. If the annotation lists are meant to hold one entry
            # per epoch, the mask must be built per epoch instead - confirm against
            # the annotation format.
            mask = _build_label_mask(onsets, offsets, time_points)

            # data[zm_idx] is (time_points, epochs); transpose to get one row per epoch.
            all_data.extend(data[zm_idx].T)
            # The same mask object can be repeated: np.array() below copies each row.
            all_labels.extend([mask] * n_epochs)

        except Exception as e:
            print(f"Error loading {set_file.name}: {e}")

    # Stacking assumes every loaded file has the same number of time points.
    X = np.array(all_data)  # (n_epochs, time_points)
    y = np.array(all_labels)  # (n_epochs, time_points)

    print(f"Loaded {X.shape[0]} epochs | X shape: {X.shape}, y shape: {y.shape}")
    return X, y

def preprocess_data(X, y, test_size=0.2, random_state=42):
    """
    Standardize each epoch, add a channel axis, and split into train and test sets.

    Args:
        X (numpy.ndarray): Input features (epochs, time_points)
        y (numpy.ndarray): Binary mask labels (epochs, time_points)
        test_size (float): Proportion of data to use for testing
        random_state (int): Random seed for reproducibility

    Returns:
        tuple: (X_train, X_test, y_train, y_test). X arrays have shape
        (samples, time_points, 1); y arrays are float32 masks of shape
        (samples, time_points). Returns (None, None, None, None) when there is
        no data, including when X or y is None (which load_emg_data returns when
        it finds no .set files).
    """
    if X is None or y is None or len(X) == 0:
        print("No data to preprocess.")
        return None, None, None, None

    # Standardize every epoch on its own (statistics taken along the time axis),
    # so no information is shared between epochs before the split.
    mean = np.mean(X, axis=1, keepdims=True)
    std = np.std(X, axis=1, keepdims=True)
    X_normalized = (X - mean) / (std + 1e-8)  # epsilon guards against flat (zero-variance) epochs

    # Add a trailing channel axis: (samples, time_points, channels)
    X_reshaped = X_normalized[:, :, np.newaxis]

    # The target stays a per-sample mask of the same length as the input sequence.
    y = np.asarray(y).astype(np.float32)

    X_train, X_test, y_train, y_test = train_test_split(
        X_reshaped, y, test_size=test_size, random_state=random_state
    )

    print(f"Training set: X={X_train.shape}, y={y_train.shape}")
    print(f"Test set: X={X_test.shape}, y={y_test.shape}")

    # Share of time points labelled 1, reported to expose class imbalance.
    train_masked_percentage = (np.sum(y_train > 0) / y_train.size) * 100
    test_masked_percentage = (np.sum(y_test > 0) / y_test.size) * 100
    print(f"Training set masked points: {train_masked_percentage:.2f}%")
    print(f"Test set masked points: {test_masked_percentage:.2f}%")

    return X_train, X_test, y_train, y_test

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import time
import tensorflow as tf

# Keep your load_emg_data function as previously defined
# from data_loader import load_emg_data, preprocess_data

print("Visualizing training samples with onset and offset markers...")

# Output folder for the per-epoch figures; created if it does not exist yet.
train_img_dir = os.path.join(r"C:\EMG_onset_detection\LOL_project\results", "train_img")
os.makedirs(train_img_dir, exist_ok=True)

threshold = 0  # An epoch is plotted only if its label maximum exceeds this value

saved_count = 0
# These arrays hold every loaded epoch; no train/test split is applied in this cell.
X_train, y_train = load_emg_data(r"C:\EMG_onset_detection\LOL_project\epoched_EMG_data")

# load_emg_data returns (None, None) when the folder has no .set files;
# in that case there is nothing to plot.
n_loaded = 0 if X_train is None or y_train is None else len(X_train)

# One color per onset/offset pair, cycled when an epoch has more pairs than colors.
colors = ['r', 'g', 'b', 'm', 'c', 'y']

for i in range(n_loaded):
    signal = np.squeeze(X_train[i])
    label = np.squeeze(y_train[i])

    # Skip epochs without any labelled region.
    if np.max(label) <= threshold:
        continue

    # Onset: index of the first sample of each run of 1s (0 -> 1 transition,
    # with a leading 0 prepended so a run starting at sample 0 is detected).
    onset_indices = np.flatnonzero(np.diff(np.concatenate(([0], label))) == 1)
    # Offset: index of the last sample still labelled 1 in each run (1 -> 0
    # transition, with a trailing 0 appended so a run reaching the end is detected).
    offset_indices = np.flatnonzero(np.diff(np.concatenate((label, [0]))) == -1)

    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(signal, label="EMG Signal", linewidth=1)

    for j, (onset, offset) in enumerate(zip(onset_indices, offset_indices)):
        color = colors[j % len(colors)]
        ax.axvline(onset, color=color, linestyle='--',
                   label=f"Onset {j+1}: {onset}")
        ax.axvline(offset, color=color, linestyle=':',
                   label=f"Offset {j+1}: {offset}")
        # Shade the labelled region between onset and offset.
        ax.axvspan(onset, offset, alpha=0.2, color=color)

    ax.set_title(f"Training Epoch {i}")
    ax.set_xlabel("Time")
    ax.set_ylabel("Amplitude")
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(train_img_dir, f"train_epoch_{i}.png"))
    # Close explicitly: one figure per epoch would otherwise accumulate in memory.
    plt.close(fig)
    saved_count += 1

print(f"✅ Saved {saved_count} training plots with valid onsets/offsets to: {train_img_dir}")

Visualizing training samples with onset and offset markers...
Loaded 800 epochs | X shape: (800, 435), y shape: (800, 435)
✅ Saved 800 training plots with valid onsets/offsets to: C:\EMG_onset_detection\LOL_project\results\train_img


In [19]:
# Check that the label and signal arrays have matching shapes
(y_train.shape, X_train.shape)

((800, 435), (800, 435))

In [36]:
import json

# Annotation file to inspect
json_path = r"C:\EMG_onset_detection\LOL_project\epoched_EMG_data\EMGfast_10.json"

# utf-8-sig reads UTF-8 with or without a BOM, independent of the OS default encoding.
with open(json_path, 'r', encoding='utf-8-sig') as f:
    data = json.load(f)

# Only a JSON object (loaded as a dict) has keys; any other top-level value
# (list, number, string, ...) is reported instead of raising on .keys().
if isinstance(data, dict):
    keys = list(data.keys())
    print("Top-level keys in the JSON file:")
    for key in keys:
        print(key)
else:
    keys = []
    print(f"Top-level JSON value is a {type(data).__name__}, not an object, so it has no keys.")


Top-level keys in the JSON file:
sub_id
ZM
