In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
import os
from pathlib import Path
import json
from tqdm import tqdm

def get_json_labels(json_file_path):
    """Read per-epoch (onset, offset) labels from a metadata JSON file.

    The file is expected to hold a top-level ``'ZM'`` dict containing
    ``'onset'`` and ``'offset'`` entries. Each entry may be a list or a
    single value; a single value is treated as a one-element list.

    Args:
        json_file_path: Path of the JSON file to read.

    Returns:
        Dict mapping the positional index of each onset/offset pair to the
        tuple ``(onset, offset)``. Pairs where either value is ``None``, or
        where one list is shorter than the other, are left out. An empty
        dict is returned when the expected structure is missing or when
        anything goes wrong while reading (the error is printed).
    """
    try:
        with open(json_file_path, 'r') as handle:
            payload = json.load(handle)

        # Guard clauses: bail out as soon as the expected layout is absent
        zm_section = payload['ZM'] if 'ZM' in payload else None
        if not isinstance(zm_section, dict):
            return {}
        if 'onset' not in zm_section or 'offset' not in zm_section:
            return {}

        starts = zm_section['onset']
        stops = zm_section['offset']

        # Scalars are wrapped so both fields can be walked in lockstep
        starts = starts if isinstance(starts, list) else [starts]
        stops = stops if isinstance(stops, list) else [stops]

        # zip() stops at the shorter list, so unmatched trailing entries drop out
        return {
            position: (start, stop)
            for position, (start, stop) in enumerate(zip(starts, stops))
            if start is not None and stop is not None
        }
    except Exception as e:
        print(f"Error reading JSON labels: {str(e)}")
        return {}

def load_emg_data(data_dir, n_channels=4, zm_channel_idx=2, epoch_length=435):
    """Load ZM-channel EMG epochs from every ``.fdt`` file in a directory.

    Each ``.fdt`` file is read as a flat float32 stream in which the
    channels are interleaved sample by sample. The ZM channel is extracted
    and cut into consecutive, non-overlapping epochs; trailing samples that
    do not fill a whole epoch are dropped. If a ``.json`` file with the same
    stem exists next to the ``.fdt`` file, its onset/offset labels are
    attached to the epochs by epoch index.

    Args:
        data_dir: Directory to search (non-recursively) for ``*.fdt`` files.
        n_channels: Number of interleaved channels in each file. The default
            of 4 corresponds to DI, OO, ZM, CS.
        zm_channel_idx: Zero-based index of the ZM channel.
        epoch_length: Number of samples per epoch.

    Returns:
        Tuple ``(all_epochs, epoch_names, pre_labels)`` of equal-length
        lists: the epoch arrays, names of the form ``"<stem>_epoch<i>"``,
        and for each epoch either an ``(onset, offset)`` tuple or ``None``.
        All three lists are empty when no ``.fdt`` file is found.

    Raises:
        ValueError: If ``n_channels`` or ``epoch_length`` is not positive,
            or ``zm_channel_idx`` is outside ``[0, n_channels)``.
    """
    if n_channels < 1 or epoch_length < 1 or not 0 <= zm_channel_idx < n_channels:
        raise ValueError(
            "n_channels and epoch_length must be positive and "
            "zm_channel_idx must satisfy 0 <= zm_channel_idx < n_channels"
        )

    # Sorted so the epoch order is reproducible; glob order is not guaranteed
    fdt_files = sorted(Path(data_dir).glob('*.fdt'))

    if not fdt_files:
        print("No .fdt files found.")
        return [], [], []

    all_epochs = []
    epoch_names = []
    pre_labels = []

    for fdt_file in fdt_files:
        print(f"Processing: {fdt_file}")

        # Metadata lives next to the recording, same stem with a .json suffix
        json_file = fdt_file.with_suffix('.json')

        if json_file.exists():
            epoch_labels = get_json_labels(json_file)
        else:
            print(f"No metadata file found for {fdt_file}")
            epoch_labels = {}

        # A failure in one file is reported and must not abort the others
        try:
            raw_data = np.fromfile(str(fdt_file), dtype=np.float32)
            print(f"Loaded {len(raw_data)} samples from {fdt_file}")

            # Drop trailing values that do not form a complete multi-channel frame
            total_samples = len(raw_data)
            if total_samples % n_channels != 0:
                print(f"Warning: Data size {total_samples} not divisible by {n_channels} channels")
                new_size = (total_samples // n_channels) * n_channels
                raw_data = raw_data[:new_size]
                print(f"Adjusted to {len(raw_data)} samples")

            # Rows of the reshaped array are frames; transpose to (channels, samples)
            samples_per_channel = len(raw_data) // n_channels
            data_reshaped = raw_data.reshape(samples_per_channel, n_channels).T
            print(f"Reshaped to {data_reshaped.shape} (channels, samples)")

            zm_data = data_reshaped[zm_channel_idx]
            print(f"Extracted ZM channel with {len(zm_data)} samples")

            # Integer division guarantees every epoch below is complete
            n_epochs = len(zm_data) // epoch_length
            print(f"Creating {n_epochs} epochs of length {epoch_length}")

            for i in range(n_epochs):
                start = i * epoch_length
                all_epochs.append(zm_data[start:start + epoch_length])
                epoch_names.append(f"{fdt_file.stem}_epoch{i}")
                # None marks an epoch without a pre-existing label
                pre_labels.append(epoch_labels.get(i))

        except Exception as e:
            print(f"Error loading .fdt file: {str(e)}")
            import traceback
            traceback.print_exc()

    print(f"Loaded {len(all_epochs)} total epochs")
    return all_epochs, epoch_names, pre_labels

def threshold_data(signal_data, fs=2000):
    """Detect EMG bursts by thresholding a smoothed envelope of the signal.

    The signal is band-pass filtered (10-500 Hz, 4th-order Butterworth,
    zero-phase), rectified, and low-pass filtered at 10 Hz to obtain an
    envelope. Noise statistics are taken from the lowest 20% of envelope
    values, and the threshold ``mean + k * std`` is evaluated for several
    multipliers ``k``. Activations whose ``offset - onset`` is below
    ``int(0.03 * fs)`` samples are discarded.

    Args:
        signal_data: 1-D EMG signal.
        fs: Sampling frequency in Hz.

    Returns:
        Dict with keys ``'signal'`` (the input), ``'envelope'``,
        ``'best_threshold'`` (the chosen multiplier), ``'threshold_value'``,
        ``'onsets'``, ``'offsets'`` (sample indices, offsets inclusive) and
        ``'all_results'`` (per-multiplier results). The chosen multiplier is
        the first of 3.0, 2.0, 4.0, 5.0 that yields one to three
        activations, falling back to 3.0.
    """
    nyquist = fs / 2

    # Band-limit, then rectify
    band_b, band_a = signal.butter(4, [10 / nyquist, 500 / nyquist], btype='bandpass')
    rectified = np.abs(signal.filtfilt(band_b, band_a, signal_data))

    # Smooth the rectified signal into an envelope
    env_b, env_a = signal.butter(4, 10 / nyquist, btype='lowpass')
    envelope = signal.filtfilt(env_b, env_a, rectified)

    # Noise statistics come from the quietest fifth of the envelope
    quiet_part = np.sort(envelope)[:int(len(envelope) * 0.2)]
    noise_floor = np.mean(quiet_part)
    noise_std = np.std(quiet_part)

    multipliers = [3.0, 2.0, 4.0, 5.0]
    min_samples = int(0.03 * fs)

    def _detect(mult):
        """Return threshold and surviving activations for one multiplier."""
        level = noise_floor + mult * noise_std
        mask = (envelope > level).astype(int)

        # Zero-padding both ends closes activations touching the signal edges
        edges = np.diff(np.concatenate([[0], mask, [0]]))
        rises = np.where(edges == 1)[0]
        falls = np.where(edges == -1)[0] - 1

        kept = [(rise, fall) for rise, fall in zip(rises, falls)
                if fall - rise >= min_samples]
        return {
            'threshold': level,
            'onsets': [rise for rise, _ in kept],
            'offsets': [fall for _, fall in kept]
        }

    results = {mult: _detect(mult) for mult in multipliers}

    # First multiplier (in preference order) with a plausible burst count wins
    best_mult = next(
        (mult for mult in multipliers if 1 <= len(results[mult]['onsets']) <= 3),
        3.0,
    )
    chosen = results[best_mult]

    return {
        'signal': signal_data,
        'envelope': envelope,
        'best_threshold': best_mult,
        'threshold_value': chosen['threshold'],
        'onsets': chosen['onsets'],
        'offsets': chosen['offsets'],
        'all_results': results
    }

def visualize_separate_bursts(result, filename, pre_label, output_dir, fs=2000):
    """
    Draw pre-existing and threshold-based burst labels as two stacked plots
    in one image, and write both label sets to a JSON file.

    The image is saved as ``<filename>_separate_burst_comparison.png`` and
    the labels as ``<filename>_burst_labels.json``, both inside output_dir
    (created if missing).

    Args:
        result: Dictionary with detection results; the keys 'signal',
            'onsets' and 'offsets' are read
        filename: Base name used for the output files and the figure title
        pre_label: Pre-existing label (onset, offset) pair, or None
        output_dir: Output directory
        fs: Sampling frequency

    Returns:
        The dictionary that was written to the JSON file
    """
    os.makedirs(output_dir, exist_ok=True)

    trace = result['signal']
    onsets, offsets = result['onsets'], result['offsets']

    # Sample index -> seconds
    time = np.arange(len(trace)) / fs

    def _centered_note(ax, message):
        """Write a message in the middle of the given axes."""
        ax.text(0.5, 0.5, message,
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=14)

    fig, (ax_orig, ax_new) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

    # Upper panel: the label that came with the data
    ax_orig.plot(time, trace, 'b-', label='EMG Signal')
    ax_orig.set_title('Original Detection (Pre-labels)')
    ax_orig.set_ylabel('Amplitude')
    ax_orig.grid(True)

    has_prelabels = False
    if pre_label is not None:
        onset, offset = pre_label
        # Shade only when both ends exist and index into the time vector
        if (onset is not None and offset is not None
                and 0 <= onset < len(time) and 0 <= offset < len(time)):
            ax_orig.axvspan(time[onset], time[offset],
                            color='c', alpha=0.4, label='Detected Burst')
            has_prelabels = True

    if not has_prelabels:
        _centered_note(ax_orig, 'No pre-labels found')

    ax_orig.legend(loc='upper right')

    # Lower panel: bursts found by thresholding
    ax_new.plot(time, trace, 'b-', label='EMG Signal')
    ax_new.set_title('New Detection (Threshold-based)')
    ax_new.set_xlabel('Time (s)')
    ax_new.set_ylabel('Amplitude')
    ax_new.grid(True)

    for burst_no, (start, stop) in enumerate(zip(onsets, offsets)):
        # Only the first span carries a label so the legend has one entry
        span_label = 'Detected Burst' if burst_no == 0 else None
        ax_new.axvspan(time[start], time[stop],
                       color='y', alpha=0.4, label=span_label)

    if not onsets:
        _centered_note(ax_new, 'No bursts detected')

    ax_new.legend(loc='upper right')

    plt.suptitle(f'Burst Detection Comparison - {filename}', fontsize=16)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f'{filename}_separate_burst_comparison.png'))
    plt.close()

    # Labels as plain ints/floats so they serialize to JSON
    json_results = {
        'filename': filename,
        'new_onsets': [int(start) for start in onsets],
        'new_offsets': [int(stop) for stop in offsets],
        'new_onsets_time': [float(start/fs) for start in onsets],
        'new_offsets_time': [float(stop/fs) for stop in offsets]
    }

    if pre_label:
        first, second = pre_label[0], pre_label[1]
        json_results['original_onset'] = None if first is None else int(first)
        json_results['original_offset'] = None if second is None else int(second)
        json_results['original_onset_time'] = None if first is None else float(first/fs)
        json_results['original_offset_time'] = None if second is None else float(second/fs)

    labels_path = os.path.join(output_dir, f'{filename}_burst_labels.json')
    with open(labels_path, 'w') as f:
        json.dump(json_results, f, indent=2)

    return json_results
def save_emg_data_and_labels(signal, filename, pre_label, output_dir, fs=2000):
    """
    Write one EMG signal to an NPZ file and its label metadata to a JSON file.

    The NPZ file ``<filename>_emg.npz`` stores the arrays ``signal``,
    ``time`` (sample index divided by fs) and ``fs``. The JSON file
    ``<filename>_labels.json`` stores the filename, signal length, sampling
    rate and, when both ends of the label are present, the onset/offset as
    sample indices and as times.

    Args:
        signal: EMG signal data
        filename: Base name for the output files
        pre_label: Pre-existing label (onset, offset) pair, or None
        output_dir: Output directory (created if missing)
        fs: Sampling frequency

    Returns:
        Dictionary with 'npz_path', 'json_path' and 'has_valid_label'
    """
    os.makedirs(output_dir, exist_ok=True)

    npz_path = os.path.join(output_dir, f'{filename}_emg.npz')
    json_path = os.path.join(output_dir, f'{filename}_labels.json')

    # Time axis in the same units as 1/fs, one entry per sample
    time = np.arange(len(signal)) / fs
    np.savez(npz_path, signal=signal, time=time, fs=fs)

    json_results = {
        'filename': filename,
        'signal_length': len(signal),
        'sampling_rate': fs
    }

    # A label counts as valid only when both of its ends are present
    onset = offset = None
    if pre_label:
        onset, offset = pre_label
    labelled = onset is not None and offset is not None

    if labelled:
        json_results['onset'] = int(onset)
        json_results['offset'] = int(offset)
        json_results['onset_time'] = float(onset/fs)
        json_results['offset_time'] = float(offset/fs)
    json_results['has_valid_label'] = labelled

    with open(json_path, 'w') as f:
        json.dump(json_results, f, indent=2)

    return {
        'npz_path': npz_path,
        'json_path': json_path,
        'has_valid_label': json_results['has_valid_label']
    }
def main():
    """Run the full pipeline: load epochs, detect bursts, save data and summary.

    For every epoch returned by ``load_emg_data`` this runs
    ``threshold_data``, stores the signal and its pre-existing label via
    ``save_emg_data_and_labels``, and collects one record per epoch. The
    records are written to ``all_burst_labels.json`` and aggregate counts to
    ``burst_detection_summary.json`` in the output directory; the summary is
    also printed.
    """
    data_dir = r"C:\EMG_onset_detection\LOL_project\continuous_EMG_data"
    output_dir = r"C:\EMG_onset_detection\LOL_project\continuous_prediction_data"
    os.makedirs(output_dir, exist_ok=True)

    # Load EMG data
    print("Loading EMG data...")
    epochs, epoch_names, pre_labels = load_emg_data(data_dir)

    if not epochs:
        print("No epochs loaded. Exiting.")
        return

    # Process each epoch
    print("Processing epochs...")
    all_results = []
    n_total = len(epochs)

    for i, (epoch, name, pre_label) in enumerate(zip(epochs, epoch_names, pre_labels), start=1):
        print(f"Processing epoch {i}/{n_total}: {name}")

        # Apply threshold detection
        result = threshold_data(epoch)

        # Save the signal itself; passing the whole result dict here would
        # store a pickled dict and report the dict's key count as the length
        record = save_emg_data_and_labels(result['signal'], name, pre_label, output_dir)

        # Keep the detected activations with the record so the statistics
        # below have something to count (plain ints for JSON serialization)
        record['new_onsets'] = [int(onset) for onset in result['onsets']]
        record['new_offsets'] = [int(offset) for offset in result['offsets']]
        all_results.append(record)

        if i % 10 == 0:
            print(f"Processed {i}/{n_total} epochs")

    # Save all results
    with open(os.path.join(output_dir, 'all_burst_labels.json'), 'w') as f:
        json.dump(all_results, f, indent=2)

    # Generate statistics
    total_epochs = len(all_results)
    epochs_with_activity = sum(1 for record in all_results if record['new_onsets'])
    total_onsets = sum(len(record['new_onsets']) for record in all_results)

    summary = {
        "total_epochs": total_epochs,
        "epochs_with_activity": epochs_with_activity,
        "percentage_with_activity": (epochs_with_activity / total_epochs) * 100 if total_epochs > 0 else 0,
        "total_onsets_detected": total_onsets,
        "average_onsets_per_epoch": total_onsets / total_epochs if total_epochs > 0 else 0
    }

    with open(os.path.join(output_dir, 'burst_detection_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    print("\nSummary Statistics:")
    print(f"Total epochs processed: {total_epochs}")
    print(f"Epochs with activity: {epochs_with_activity} ({summary['percentage_with_activity']:.1f}%)")
    print(f"Total onsets detected: {total_onsets}")
    print(f"Average onsets per epoch: {summary['average_onsets_per_epoch']:.2f}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

Loading EMG data...
Processing: C:\EMG_onset_detection\LOL_project\continuous_EMG_data\EMG_continuous_01.fdt
No metadata file found for C:\EMG_onset_detection\LOL_project\continuous_EMG_data\EMG_continuous_01.fdt
Loaded 3993600 samples from C:\EMG_onset_detection\LOL_project\continuous_EMG_data\EMG_continuous_01.fdt
Reshaped to (4, 998400) (channels, samples)
Extracted ZM channel with 998400 samples
Creating 2295 epochs of length 435
Processing: C:\EMG_onset_detection\LOL_project\continuous_EMG_data\EMG_continuous_10.fdt
No metadata file found for C:\EMG_onset_detection\LOL_project\continuous_EMG_data\EMG_continuous_10.fdt
Loaded 4060160 samples from C:\EMG_onset_detection\LOL_project\continuous_EMG_data\EMG_continuous_10.fdt
Reshaped to (4, 1015040) (channels, samples)
Extracted ZM channel with 1015040 samples
Creating 2333 epochs of length 435
Loaded 4628 total epochs
Processing epochs...
Processing epoch 1/4628: EMG_continuous_01_epoch0
Processing epoch 2/4628: EMG_continuous_01_epo