## bpm normal

In [None]:
import json
import struct
import numpy as np
import pywt
import scipy.signal as sp_signal
import matplotlib.pyplot as plt
import base64   
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad


def decrypt(input_file):
    """Read an encrypted file and return the JSON document stored in it.

    The first 24 bytes of the file are skipped; the rest is Base64-decoded,
    decrypted with AES in ECB mode, unpadded with a block size of 16, decoded
    as UTF-8 and parsed as JSON.

    Args:
        input_file: path of the encrypted file.

    Returns:
        The object produced by ``json.loads`` on the decrypted text.
    """
    # NOTE(review): the key is embedded in source; consider loading it from
    # configuration or the environment instead.
    private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
    cipher = AES.new(private_key.encode('utf-8'), AES.MODE_ECB)

    with open(input_file, 'rb') as handle:
        raw_bytes = handle.read()

    ciphertext = base64.b64decode(raw_bytes[24:])
    plaintext = unpad(cipher.decrypt(ciphertext), 16)
    return json.loads(plaintext.decode('utf-8'))

    
def baseline_wander(X, sampling_rate=500):
    """Subtract a median-filter baseline estimate from ``X``.

    Two median filters (a 0.2 s window followed by a 0.6 s window) are applied
    in cascade to estimate the baseline, which is then subtracted from ``X``.

    Args:
        X: signal samples.
        sampling_rate: samples per second, used to turn the window durations
            into kernel widths. Defaults to 500, the value that used to be
            hard-coded.

    Returns:
        ``X`` minus the estimated baseline.
    """
    def get_median_filter_width(rate, duration):
        # medfilt requires an odd kernel: even widths are reduced by one.
        width = int(rate * duration)
        width += (width % 2) - 1
        # Guard against a non-positive kernel for very low sampling rates.
        return max(width, 1)

    baseline = X
    for duration in (0.2, 0.6):
        kernel = get_median_filter_width(sampling_rate, duration)
        baseline = sp_signal.medfilt(baseline, kernel)
    return np.subtract(X, baseline)


def normalize(signal, min_val, max_val):
    """Min-max scale ``signal`` with the supplied bounds.

    Returns ``(signal - min_val) / (max_val - min_val)``. When the two bounds
    coincide, an all-zero array shaped like ``signal`` is returned instead so
    that no division by zero happens.
    """
    span = max_val - min_val
    if span == 0:
        return np.zeros_like(signal)
    return (signal - min_val) / span


def process_signal(signal_data, min_val, max_val):
    """Return ``signal_data`` scaled by ``normalize`` using the given bounds."""
    return normalize(signal_data, min_val, max_val)


def data_process(filename):
    """Extract the 'dataL2' channel, scale it and reshape it.

    Despite its name, ``filename`` is indexed by key, so it must be a mapping
    that contains 'dataL2'. The channel is cast to float32, scaled with the
    global min/max of the stacked data, and returned with a leading axis added
    and the last two axes swapped. The stacked shape and the final array are
    printed along the way.

    Args:
        filename: mapping holding the sample sequence under 'dataL2'.

    Returns:
        The scaled array after ``expand_dims(axis=0)`` and
        ``transpose(0, 2, 1)``.
    """
    channel_keys = ['dataL2']
    channels = [np.array(filename[name]).astype('float32') for name in channel_keys]

    stacked = np.array(channels)
    print(stacked.shape)

    lowest = np.min(stacked)
    highest = np.max(stacked)

    scaled_rows = [
        process_signal(stacked[row, :], lowest, highest)
        for row in range(stacked.shape[0])
    ]

    final_data = np.expand_dims(np.stack(scaled_rows), axis=0).transpose(0, 2, 1)
    print(final_data)
    return final_data


def remove_close_peaks(r_peaks, validation_signal, min_dist_samples):
    """Remove peaks that are too close together, keeping the stronger one.

    Peaks are processed in ascending order. A peak that lies fewer than
    ``min_dist_samples`` after the most recently kept peak competes with it:
    whichever has the larger absolute value in ``validation_signal`` survives.

    Args:
        r_peaks: peak sample indices (any order).
        validation_signal: signal whose absolute value at each peak measures
            the peak's strength.
        min_dist_samples: minimum allowed spacing between kept peaks.

    Returns:
        Array of the kept peak indices in ascending order; an empty array when
        ``r_peaks`` is empty.
    """
    r_peaks = np.asarray(r_peaks)
    if r_peaks.size == 0:
        return np.array([])

    r_peaks = np.sort(r_peaks)
    strengths = np.abs(np.asarray(validation_signal)[r_peaks.astype(int)])

    keep = []
    # Strength of keep[-1]. The previous code looked this up with
    # validation_abs[len(keep) - 1], which indexes by position in ``keep``
    # rather than by the kept peak and so compared against the wrong peak once
    # any peak had been dropped.
    kept_strength = None

    for current, strength in zip(r_peaks, strengths):
        if not keep or current - keep[-1] >= min_dist_samples:
            keep.append(current)
            kept_strength = strength
        elif strength > kept_strength:
            keep[-1] = current
            kept_strength = strength

    return np.array(keep)


def amplitude_based_filtering(ecg_signal, peaks, segment_num="Unknown", *, remove_outliers=False):
    """Identify high-amplitude outlier peaks (IQR rule) and optionally drop them.

    A peak is an outlier when its absolute amplitude exceeds
    ``Q3 + 1.5 * IQR`` of all peak amplitudes.

    Args:
        ecg_signal: signal indexed by the peak positions.
        peaks: array of peak sample indices.
        segment_num: unused; kept for backward compatibility.
        remove_outliers: when True, outlier peaks are removed from the result.
            The default (False) returns every peak, which matches the previous
            behaviour where the removal mask had been switched off.

    Returns:
        Tuple ``(cleaned_peaks, cleaned_amplitudes)``. For empty ``peaks`` the
        input is returned together with an empty amplitude array.
    """
    if len(peaks) == 0:
        return peaks, np.array([])

    peak_amplitudes = np.abs(ecg_signal[peaks.astype(int)])

    if not remove_outliers:
        return peaks, peak_amplitudes

    q75, q25 = np.percentile(peak_amplitudes, [75, 25])
    iqr = q75 - q25
    if iqr <= 0:
        # No spread in amplitudes: nothing can be classified as an outlier.
        return peaks, peak_amplitudes

    is_outlier = peak_amplitudes > q75 + 1.5 * iqr
    if is_outlier.all():
        # Never discard every peak.
        return peaks, peak_amplitudes

    keep = ~is_outlier
    return peaks[keep], peak_amplitudes[keep]


def remove_t_waves(ecg_signal, peaks, sampling_rate):
    """Drop peaks that look like T-waves trailing the preceding peak.

    A peak is rejected when all of the following hold relative to the peak
    immediately before it in sorted order:
      * it follows by more than 160 ms and less than 450 ms,
      * its absolute amplitude is below half of the preceding peak's, and
      * its width at half of its own amplitude (scanned at most 100 samples
        to each side) exceeds 40 ms.

    Inputs with fewer than three peaks are returned untouched.

    Args:
        ecg_signal: signal indexed by the peak positions.
        peaks: peak sample indices.
        sampling_rate: samples per second.

    Returns:
        Array of the surviving peaks in ascending order (or ``peaks`` itself
        when it has fewer than three entries).
    """
    if len(peaks) < 3:
        return peaks

    last_index = len(ecg_signal) - 1

    def _is_t_wave(candidate, predecessor):
        # Timing gate.
        gap_ms = (candidate - predecessor) / sampling_rate * 1000
        if not (160 < gap_ms < 450):
            return False

        # Amplitude gate: T-waves are expected to be markedly smaller.
        predecessor_amp = abs(ecg_signal[int(predecessor)])
        candidate_amp = abs(ecg_signal[int(candidate)])
        if not (candidate_amp < predecessor_amp * 0.5):
            return False

        # Width gate: walk outwards until the signal drops below half maximum.
        level = candidate_amp * 0.5

        lower = candidate
        while (lower > 0 and lower > candidate - 100
               and not (abs(ecg_signal[int(lower)]) < level)):
            lower -= 1

        upper = candidate
        while (upper < last_index and upper < candidate + 100
               and not (abs(ecg_signal[int(upper)]) < level)):
            upper += 1

        span_ms = (upper - lower) / sampling_rate * 1000
        return span_ms > 40

    ordered = np.sort(peaks)
    survivors = [
        point
        for position, point in enumerate(ordered)
        if position == 0 or not _is_t_wave(point, ordered[position - 1])
    ]
    return np.array(survivors)


def robust_qrs_detect_internal(data_clean, sampling_rate):
    """Multi-strategy robust QRS detection for difficult cases.

    Candidate peaks are collected from three strategies (multi-band energy,
    prominence, derivative). Each candidate must pass a local "sharpness" gate
    derived from a 10-40 Hz band-passed copy of the signal. The candidates are
    then merged, and neighbours closer than 0.15 s are resolved in favour of
    the one with the larger local sharpness.

    Args:
        data_clean: signal samples (assumed to be a 1-D NumPy array - TODO confirm).
        sampling_rate: samples per second.

    Returns:
        Sorted array of peak sample indices, or an empty array when no
        candidate was collected.
    """
    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate
    
    # Calculate sharpness threshold
    # strict_score = squared first difference of |10-40 Hz band-passed signal|
    low_strict = 10 / nyquist
    high_strict = 40 / nyquist
    b2, a2 = sp_signal.butter(2, [low_strict, high_strict], btype='band')
    filtered_strict = sp_signal.filtfilt(b2, a2, data_clean)
    diff_strict = np.diff(np.abs(filtered_strict))
    diff_strict = np.append(diff_strict, 0)  # pad so the length matches the input
    strict_score = diff_strict ** 2
    
    # The gate is the 94th percentile of the score over the whole signal
    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0
    
    all_candidate_peaks = []
    
    # Strategy 1: Multi-band detection with multiple thresholds
    freq_bands = [(5, 15), (8, 24), (10, 30), (12, 40)]
    
    for low_freq, high_freq in freq_bands:
        low = low_freq / nyquist
        high = high_freq / nyquist
        b, a = sp_signal.butter(2, [low, high], btype='band')
        filtered = sp_signal.filtfilt(b, a, data_clean)
        
        # Band energy smoothed with a 0.15 s moving average
        squared = filtered ** 2
        window_size = int(0.15 * sampling_rate)
        integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
        
        mean_val = np.mean(integrated)
        std_val = np.std(integrated)
        
        thresholds = [mean_val + 0.1 * std_val, mean_val + 0.2 * std_val, mean_val + 0.3 * std_val]
        
        for threshold in thresholds:
            candidates, _ = sp_signal.find_peaks(
                integrated, 
                height=threshold,
                distance=int(0.2 * sampling_rate)
            )
            
            search_window = int(0.1 * sampling_rate)
            sharp_window = int(0.18 * sampling_rate)
            
            for peak in candidates:
                start_sharp = max(0, peak - sharp_window)
                end_sharp = min(len(strict_score), peak + sharp_window)
                if start_sharp < end_sharp:
                    local_sharpness = np.max(strict_score[start_sharp:end_sharp])
                    
                    if local_sharpness > sharpness_threshold:
                        # Snap the candidate to the largest |sample| within +/- 0.1 s
                        start = max(0, peak - search_window)
                        end = min(len(original_data), peak + search_window)
                        if start < end:
                            local_segment = original_data[start:end]
                            local_max_idx = np.argmax(np.abs(local_segment))
                            refined_peak = start + local_max_idx
                            all_candidate_peaks.append(refined_peak)
    
    # Strategy 2: Prominence-based detection
    peaks_prom, properties = sp_signal.find_peaks(
        original_data,
        distance=int(0.2 * sampling_rate),
        prominence=0.02
    )
    
    sharp_window = int(0.18 * sampling_rate)
    for peak in peaks_prom:
        start_sharp = max(0, peak - sharp_window)
        end_sharp = min(len(strict_score), peak + sharp_window)
        if start_sharp < end_sharp:
            local_sharpness = np.max(strict_score[start_sharp:end_sharp])
            # Relaxed gate (80% of the threshold); the peak position is used as-is
            if local_sharpness > sharpness_threshold * 0.8:
                all_candidate_peaks.append(peak)
    
    # Strategy 3: Derivative-based detection
    diff_signal = np.diff(original_data)
    diff_squared = diff_signal ** 2
    diff_squared = np.append(diff_squared, 0)
    
    mean_diff = np.mean(diff_squared)
    std_diff = np.std(diff_squared)
    
    diff_peaks, _ = sp_signal.find_peaks(
        diff_squared,
        height=mean_diff + 0.5 * std_diff,
        distance=int(0.15 * sampling_rate)
    )
    
    search_window = int(0.08 * sampling_rate)
    
    for peak in diff_peaks:
        start_sharp = max(0, peak - sharp_window)
        end_sharp = min(len(strict_score), peak + sharp_window)
        if start_sharp < end_sharp:
            local_sharpness = np.max(strict_score[start_sharp:end_sharp])
            
            # Relaxed gate (70% of the threshold), then snap within +/- 0.08 s
            if local_sharpness > sharpness_threshold * 0.7:
                start = max(0, peak - search_window)
                end = min(len(original_data), peak + search_window)
                if start < end:
                    local_segment = original_data[start:end]
                    local_max_idx = np.argmax(np.abs(local_segment))
                    refined_peak = start + local_max_idx
                    all_candidate_peaks.append(refined_peak)
    
    # Merge and deduplicate peaks
    if len(all_candidate_peaks) > 0:
        all_candidate_peaks = np.unique(all_candidate_peaks)
        
        min_distance = int(0.15 * sampling_rate)
        sorted_peaks = np.sort(all_candidate_peaks)
        
        if len(sorted_peaks) > 0:
            # keep_mask[k] tells whether sorted_peaks[k] survives; the first peak starts as kept
            keep_mask = [True]
            for i in range(1, len(sorted_peaks)):
                if sorted_peaks[i] - sorted_peaks[i-1] >= min_distance:
                    keep_mask.append(True)
                else:
                    # Too close to the previous entry: compare local sharpness of the two
                    start1 = max(0, sorted_peaks[i-1] - sharp_window)
                    end1 = min(len(strict_score), sorted_peaks[i-1] + sharp_window)
                    start2 = max(0, sorted_peaks[i] - sharp_window)
                    end2 = min(len(strict_score), sorted_peaks[i] + sharp_window)
                    
                    sharp1 = np.max(strict_score[start1:end1]) if start1 < end1 else 0
                    sharp2 = np.max(strict_score[start2:end2]) if start2 < end2 else 0
                    
                    if sharp2 > sharp1:
                        keep_mask[-1] = False
                        keep_mask.append(True)
                    else:
                        keep_mask.append(False)
            
            sorted_peaks = sorted_peaks[keep_mask]
    
    # sorted_peaks only exists when at least one candidate was collected
    return sorted_peaks if len(all_candidate_peaks) > 0 else np.array([])


def qrs_detect(data, sampling_rate, segment_duration=None):
    """
    Enhanced QRS detection with Amplitude Guardrails for AV Blocks

    Pipeline: band-pass energy candidates (8-24 Hz) are confirmed by a
    sharpness score (10-40 Hz), thinned by minimum spacing, optionally
    supplemented by gap filling, sanity-checked against an expected peak-count
    range (falling back to ``robust_qrs_detect_internal`` when outside it), and
    finally converted to a BPM estimate.

    Args:
        data: signal samples (presumably a 1-D NumPy array; ``.copy()`` and
            integer-array indexing are applied to it).
        sampling_rate: samples per second.
        segment_duration: duration in seconds used for the expected peak-count
            range; derived from ``len(data)`` when None.

    Returns:
        Tuple ``(data, cleaned_r, bpm, cleaned_r)``: the input signal as
        passed in, the detected peak indices, the BPM estimate (0 when it
        cannot be computed) and the same peak indices a second time.
    """
    # Apply baseline wander removal
    # data_clean = baseline_wander(data) 
    data_clean = data # baseline removal is currently bypassed; the raw input is used

    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate
    
    # --- STREAM 1: Standard Detection ---
    low = 8 / nyquist
    high = 24 / nyquist
    b, a = sp_signal.butter(2, [low, high], btype='band')
    filtered_standard = sp_signal.filtfilt(b, a, data_clean)
    
    # Squared first difference of the rectified band-passed signal
    filtered_abs = np.abs(filtered_standard)
    diff = np.diff(filtered_abs)
    diff = np.append(diff, 0)
    squared = diff ** 2
    
    # 0.15 s moving average
    window_size = int(0.15 * sampling_rate)
    integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
    
    mean_val = np.mean(integrated)
    std_val = np.std(integrated)
    threshold = mean_val + 0.20 * std_val
    
    candidates, _ = sp_signal.find_peaks(
        integrated,
        height=threshold,
        distance=int(0.12 * sampling_rate)
    )
    
    # --- STREAM 2: Sharpness Validator ---
    low_strict = 10 / nyquist
    high_strict = 40 / nyquist
    b2, a2 = sp_signal.butter(2, [low_strict, high_strict], btype='band')
    filtered_strict = sp_signal.filtfilt(b2, a2, data_clean)
    diff_strict = np.diff(np.abs(filtered_strict))
    diff_strict = np.append(diff_strict, 0)
    strict_score = diff_strict ** 2
    
    # Gate = 94th percentile of the sharpness score over the whole signal
    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0
    
    confirmed_peaks = []
    search_window = int(0.18 * sampling_rate)
    
    for peak in candidates:
        start_check = max(0, peak - search_window)
        end_check = min(len(strict_score), peak + search_window)
        if start_check >= end_check:
            continue
            
        local_sharpness = np.max(strict_score[start_check:end_check])
        
        if local_sharpness > sharpness_threshold:
            # Snap to the largest |sample| inside the same +/- 0.18 s window
            local_segment = original_data[start_check:end_check]
            if len(local_segment) > 0:
                abs_local_segment = np.abs(local_segment)
                local_max_idx = np.argmax(abs_local_segment)
                confirmed_peaks.append(start_check + local_max_idx)
    
    r_peaks = np.array(confirmed_peaks)
    
    # Remove close peaks
    min_dist = int(0.15 * sampling_rate)
    r_peaks = remove_close_peaks(r_peaks, original_data, min_dist)
    
    # Drop float NaN entries (if any) and sort
    cleaned_r = np.sort(np.array([x for x in r_peaks if not (isinstance(x, float) and np.isnan(x))]))
    
    # =================================================================
    # CRITICAL FIX: GAP FILLING WITH AMPLITUDE GUARDRAILS
    # =================================================================
    if len(cleaned_r) >= 2:
        # Calculate reference height (Median of existing peaks)
        existing_heights = np.abs(original_data[cleaned_r.astype(int)])
        median_r_height = np.median(existing_heights) if len(existing_heights) > 0 else 0
        
        rr_intervals = np.diff(cleaned_r) / sampling_rate
        median_rr = np.median(rr_intervals) if len(rr_intervals) > 0 else 1.0
        new_peaks = list(cleaned_r)
        
        # Only fill gaps if median_rr suggests a normal rhythm (< 1.5s).
        # If median_rr is already 2.0s (bradycardia), huge gaps are normal.
        if median_rr < 1.5: 
            for i in range(len(rr_intervals)):
                # A gap is any RR interval more than 1.4x the median
                if rr_intervals[i] > 1.4 * median_rr:
                    gap_start = cleaned_r[i]
                    gap_end = cleaned_r[i+1]
                    if gap_start >= gap_end:
                        continue
                        
                    gap_integrated = integrated[gap_start:gap_end]
                    # Lower threshold slightly for gap search
                    low_thresh = mean_val * 0.6 
                    
                    gap_candidates, _ = sp_signal.find_peaks(
                        gap_integrated,
                        height=low_thresh,
                        distance=int(0.10 * sampling_rate)
                    )
                    
                    for gc in gap_candidates:
                        # gc is relative to the gap slice; convert to a signal index
                        abs_idx = gap_start + gc
                        sw_start = max(0, abs_idx - search_window)
                        sw_end = min(len(strict_score), abs_idx + search_window)
                        if sw_start >= sw_end:
                            continue
                            
                        # 1. Check Sharpness
                        local_sharp_max = np.max(strict_score[sw_start:sw_end])
                        if local_sharp_max > sharpness_threshold * 0.4:
                            
                            # 2. Refine Position
                            local_segment = original_data[sw_start:sw_end]
                            abs_local_segment = np.abs(local_segment)
                            refine_idx = np.argmax(abs_local_segment)
                            candidate_peak = sw_start + refine_idx
                            
                            # 3. AMPLITUDE CHECK (The Fix)
                            # Even if it's sharp, is it tall enough?
                            # AV Block P-waves are sharp but short.
                            candidate_amp = np.abs(original_data[candidate_peak])
                            
                            # Must exceed 45% of the median R-peak height
                            if candidate_amp > 0.45 * median_r_height:
                                new_peaks.append(candidate_peak)

        # Merge existing and filled-in peaks, then re-apply the spacing rule
        new_peaks = np.sort(np.unique(new_peaks))
        cleaned_r = remove_close_peaks(new_peaks, original_data, min_dist)
    
    # =================================================================

    # Determine expected peak count range
    if segment_duration is None:
        segment_duration = len(data_clean) / sampling_rate
    
    # Relaxed expectations for Bradycardia/AV Block
    # (30 to 180 beats per minute over the segment duration)
    min_expected_peaks = int(30/60 * segment_duration) 
    max_expected_peaks = int(180/60 * segment_duration)
    
    # Fallback to robust only if counts are extremely off
    if len(cleaned_r) < min_expected_peaks or len(cleaned_r) > max_expected_peaks:
        initial_peaks = robust_qrs_detect_internal(data_clean, sampling_rate)
        initial_peaks = remove_t_waves(data_clean, initial_peaks, sampling_rate)
        cleaned_r, peak_amplitudes = amplitude_based_filtering(data_clean, initial_peaks, "Segment")
    else:
        cleaned_r = remove_t_waves(data_clean, cleaned_r, sampling_rate)
    
    # Calculate BPM
    if len(cleaned_r) > 1:
        rr_intervals = np.diff(cleaned_r) / sampling_rate
        
        # Valid intervals widened to account for Bradycardia/Pauses
        # valid_rr = rr_intervals[(rr_intervals > 0.2) & (rr_intervals < 3.5)] 
        valid_rr = rr_intervals[(rr_intervals > 0.2) & (rr_intervals < 4.0)] 
        
        if len(valid_rr) > 0:
            mean_rr = np.mean(valid_rr)
            bpm = 60 / mean_rr if mean_rr > 0 else 0
        else:
            bpm = 0
    else:
        bpm = 0
    
    # The first element is the untouched input; the peaks are returned twice
    return data, cleaned_r, bpm, cleaned_r



def process_ecg_segments(ecg_data, sampling_rate, num_segments=7, min_segment_length=3500):
    """Run ``qrs_detect`` on evenly spaced windows of ``ecg_data``.

    ``num_segments`` windows of ``min_segment_length`` samples are spread from
    the start to the end of the signal (they overlap when the signal is shorter
    than the combined window length). Windows are clamped to the signal
    bounds. A window with fewer than 100 samples is reported with empty
    results and a BPM of 0 without running detection.

    Args:
        ecg_data: full signal.
        sampling_rate: samples per second.
        num_segments: number of windows to analyse.
        min_segment_length: window length in samples.

    Returns:
        List of dicts with keys 'segment_num', 'start_idx', 'end_idx',
        'ecg_filtered', 'r_peaks', 'bpm', 'cleaned_r' and 'ecg_raw'. Peak
        indices are offset by the window start so that they index ``ecg_data``.
    """
    max_len = len(ecg_data)
    window_step = (
        round((max_len - min_segment_length) / (num_segments - 1))
        if num_segments > 1
        else 0
    )

    def _bounds(index):
        # Clamp the window to [0, max_len].
        lo = index * window_step
        hi = lo + min_segment_length
        if hi > max_len:
            lo, hi = max_len - min_segment_length, max_len
        if lo < 0:
            lo, hi = 0, min(min_segment_length, max_len)
        return lo, hi

    results = []
    for i in range(num_segments):
        start_idx, end_idx = _bounds(i)
        segment = ecg_data[start_idx:end_idx]

        entry = {
            'segment_num': i + 1,
            'start_idx': start_idx,
            'end_idx': end_idx,
        }

        if len(segment) < 100:
            # Too short to analyse: report an empty result.
            entry.update({
                'ecg_filtered': np.array([]),
                'r_peaks': np.array([]),
                'bpm': 0,
                'cleaned_r': np.array([]),
                'ecg_raw': segment,
            })
            results.append(entry)
            continue

        ecg_filtered, r_peaks, bpm, cleaned_r = qrs_detect(
            segment, sampling_rate, len(segment) / sampling_rate
        )
        print(f"Segment {i+1}: Detected {len(r_peaks)} R-peaks, BPM: {bpm:.1f}")

        # Shift window-relative indices back into full-signal coordinates.
        if len(r_peaks) > 0:
            adjusted_r_peaks = r_peaks + start_idx
        else:
            adjusted_r_peaks = np.array([])
        if len(cleaned_r) > 0:
            adjusted_cleaned_r = np.array(cleaned_r) + start_idx
        else:
            adjusted_cleaned_r = np.array([])

        entry.update({
            'ecg_filtered': ecg_filtered,
            'r_peaks': adjusted_r_peaks,
            'bpm': bpm,
            'cleaned_r': adjusted_cleaned_r,
            'ecg_raw': segment,
        })
        results.append(entry)

    return results


def plot_ecg_segments(ecg_data, sampling_rate, results, title="ECG Segments with R-peaks and BPM"):
    """Plot every segment result in its own subplot and print a summary.

    Each subplot shows the segment's raw samples against time, with the
    detected R-peaks overlaid as red markers, and is titled with the segment's
    sample range, duration and BPM. After the figure is shown, a text summary
    of all segments is printed.

    Args:
        ecg_data: full signal; peak values are looked up in it.
        sampling_rate: samples per second.
        results: list of segment dicts as built by ``process_ecg_segments``.
        title: accepted for API compatibility; not rendered by this function.
    """
    num_segments = len(results)
    fig, axes = plt.subplots(num_segments, 1, figsize=(15, 3*num_segments))

    # A single-row layout yields one Axes object instead of a sequence.
    panels = [axes] if num_segments == 1 else axes

    time = np.arange(len(ecg_data)) / sampling_rate

    for ax, result in zip(panels, results):
        segment_num = result['segment_num']
        start_idx, end_idx = result['start_idx'], result['end_idx']
        bpm = result['bpm']
        r_peaks = result['r_peaks']

        segment_time = time[start_idx:end_idx]
        ax.plot(segment_time, result['ecg_raw'], 'b-', alpha=0.7, linewidth=1, label='ECG Raw')

        if len(r_peaks) > 0:
            ax.plot(
                r_peaks / sampling_rate,
                ecg_data[r_peaks.astype(int)],
                'ro', markersize=8, label='R-peaks', alpha=0.7,
            )

        segment_duration = (end_idx - start_idx) / sampling_rate
        ax.set_title(f'Segment {segment_num}: {start_idx}-{end_idx} samples '
                    f'({segment_duration:.2f}s), BPM: {bpm:.1f}')
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('Amplitude')
        ax.grid(True, alpha=0.3)
        ax.set_xlim([segment_time[0], segment_time[-1]])

    plt.tight_layout()
    plt.show()

    rule = "="*60
    print(rule)
    print("ECG SEGMENT ANALYSIS SUMMARY")
    print(rule)
    for result in results:
        print(f"\nSegment {result['segment_num']}:")
        print(f"  Samples: {result['start_idx']}-{result['end_idx']}")
        print(f"  Duration: {(result['end_idx']-result['start_idx'])/sampling_rate:.2f}s")
        print(f"  BPM: {result['bpm']:.1f}")
        print(f"  R-peaks detected: {len(result['r_peaks'])}")
    print(rule)
    
def plot_full_ecg(ecg_data, sampling_rate, title="Full ECG Signal Analysis"):
    """Detect peaks on the whole signal and show them in one continuous plot.

    ``qrs_detect`` is run once on ``ecg_data`` (letting it derive the duration
    itself). The signal is drawn against time, detected peaks are overlaid,
    every fifth peak is annotated with its time, and a one-line summary is
    printed after the figure is shown.

    Args:
        ecg_data: full signal.
        sampling_rate: samples per second.
        title: prefix of the figure title.
    """
    _, r_peaks, global_bpm, _ = qrs_detect(ecg_data, sampling_rate)

    # A wide canvas keeps long recordings readable.
    plt.figure(figsize=(20, 6))

    seconds = np.arange(len(ecg_data)) / sampling_rate
    plt.plot(seconds, ecg_data, 'b-', linewidth=0.8, alpha=0.8, label='Filtered ECG')

    if len(r_peaks) > 0:
        # Safety check: ignore any index at or beyond the end of the signal.
        in_range = r_peaks[r_peaks < len(ecg_data)].astype(int)
        peak_times = in_range / sampling_rate
        peak_values = ecg_data[in_range]

        plt.plot(peak_times, peak_values, 'ro', markersize=4, label='R-peaks')

        # Label every fifth peak with its timestamp to help navigation.
        for count, (t, v) in enumerate(zip(peak_times, peak_values)):
            if count % 5 != 0:
                continue
            plt.annotate(f'{t:.1f}s', (t, v), xytext=(0, 10),
                         textcoords='offset points', ha='center', fontsize=8, color='red')

    plt.title(f"{title} | Global BPM: {global_bpm:.1f} | Total Peaks: {len(r_peaks)}")
    plt.xlabel("Time (seconds)")
    plt.ylabel("Normalized Amplitude")
    plt.legend(loc='upper right')
    plt.grid(True, which='both', alpha=0.5)
    plt.tight_layout()
    plt.show()

    print(f"Global Analysis: {len(r_peaks)} peaks detected over {len(ecg_data)/sampling_rate:.2f} seconds.")
    



# ============================================================================

 
# # input_json = r"simulator\contec\bigeminy_1756103016311.json" 
# input_json = r"simulator\contec\trigeminy_1756103085272.json" 
# input_json = r"simulator\contec\asystl_1756103447146.json" 
# # input_json = r"simulator\contec\1d av_1756104504294.json" 
# input_json = r"simulator\contec\3d av_1756104633918.json" 
# input_json = r"simulator\contec\280bpm_1756100716422.json" 
# input_json = r"simulator\contec\av sequence_1756106676125.json"   #####
# input_json = r"simulator\contec\dmnd freq_1756106571373.json"
#    
# input_json = r"simulator\fluke\trigeminy_1754543043205.json"   
# input_json = r"simulator\fluke\3d av_1754545068278.json"   
# Recording selected for this run (backslash-separated relative path); the
# commented-out assignments above are alternative recordings.
input_json = r"simulator\fluke\asystole_1754544406847.json"   

# Parsed as plain JSON; the result is later indexed with 'dataL2' by data_process.
with open(input_json, 'r') as file:
    file_data = json.load(file)


# ============================================================================

# input_json = r"v01_prob\teton_ecg.ecgdatas.json"  ####
# with open(input_json, 'r') as file:
#     all_id_data = json.load(file)

# file_data = all_id_data[3]['ecgValue']   



# # input_json = r"bpms\afib_1766471694144.json"
# # input_json = r"bpms\bigeminy_1766467666407.json"
# input_json = r"bpms\pvc 6_1766467718685.json"    ########
# # input_json = r"bpms\tri_1766467618314.json"
# # input_json = r"v01_prob/220_1767858669130.json"
# # input_json = r"v01_prob/240bpm_1767858615562.json"
# # input_json = r"v01_prob\25 contec_1768375918389.json"
# # input_json = r"v01_prob\30bpm contec_1768375716454.json"
# # input_json = r"v01_prob\2d av_1754545008828.json"  #####
# # input_json = r"v01_prob\3rd_davb_1768554217066.json"  #####

# with open(input_json, 'r') as file:
#     file_data = json.load(file)



# # input_json = r"exception\L2_1759207950416.json"  #####
# input_json = r"0_bpm\L2_1760767200872.json"  
# # input_json = r"0_bpm\L2_1760254470484.json"  
# # input_json = r"0_bpm\L2_1760354748658.json"  
# # input_json = r"0_bpm\L2_1760770290911.json"  
# # input_json = r"issues\L2_1757064122874.json"  #####
# # input_json = r"v01_prob\run 5 pvc.json"  #####
# # input_json = r"issues\L2_1757579288752.json"
# # input_json = r"issues\L2_1757737806463.json"  #####
# # input_json = r"v01_prob\L2_1765984517025.json"  #####
# # input_json = r"1st-last-peaks\L2_1759908627949.json"
# # input_json = r"1st-last-peaks\L2_1759908888619.json"

# doubles = []
# with open(input_json, "rb") as f:
#     while chunk := f.read(8):
#         if len(chunk) < 8:
#             break
#         value = struct.unpack("<d", chunk)[0]
#         doubles.append(value)

# file_data = {'dataL2': doubles}   



# def decrypt(input_file):
#     """Decrypt encrypted JSON file (optional - commented out in your version)"""
#     private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
#     cipher = AES.new(private_key.encode('utf-8'), AES.MODE_ECB)
#     with open(input_file, 'rb') as f:
#         encrypted_data = f.read()
#     enc = base64.b64decode(encrypted_data[24:])
#     data = unpad(cipher.decrypt(enc), 16)
#     decoded_string = data.decode('utf-8')
#     return json.loads(decoded_string)

# # input_json = r"NHF2\DATA_1750689015865.json"
# # # input_json = r"NHF2\DATA_1750689460556.json"
# # # input_json = r"NHF2\DATA_1750851207409.json"
# # # input_json = r"NHF2\DATA_1750858856842.json"
# # # input_json = r"NHF2\DATA_1750862721789.json"
# # # input_json = r"NHF2\DATA_1750996455820.json"
# # file_data = decrypt(input_json)

# # input_json = r"NHF\DATA_1752067426678.json"  #####
# input_json = r"NHF\DATA_1752121970835.json"  ########
# # input_json = r"NHF\DATA_1754709586876.json"  #####
# # input_json = r"NHF\DATA_1754729551054.json"
# file_data = decrypt(input_json)



# def decrypt(input_file):
#     """Decrypt encrypted CSV file using AES ECB mode"""
#     Private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
    
#     cipher = AES.new(Private_key.encode(), AES.MODE_ECB)
#     with open(input_file, 'rb') as f:
#         encrypted_data = f.read()

#     enc = base64.b64decode(encrypted_data[24:])
#     cipher = AES.new(Private_key.encode('utf-8'), AES.MODE_ECB)
#     data = unpad(cipher.decrypt(enc), 16)

#     decoded_string = data.decode('utf-8')
#     data_list = decoded_string.split(",")
#     float_list = [float(x) for x in data_list]

#     return float_list

# selected_path = "v01_prob\ECG_1735798172211.csv"  ####
# # selected_path = "v01_prob\ECG_L2_1738637533455.csv"
# file_data = decrypt(selected_path)
# file_data = {'dataL2': file_data}


# ====================================================================


def low_pass_filter(data):
    """Apply the module-level low-pass filter (``b_lp``/``a_lp``) with ``filtfilt``.

    Best-effort: if filtering fails for any ordinary reason (for example the
    coefficients are not defined yet or the input is not filterable), ``data``
    is returned unchanged.

    Args:
        data: samples to filter.

    Returns:
        The filtered samples, or ``data`` itself when filtering failed.
    """
    try:
        return sp_signal.filtfilt(b_lp, a_lp, data)
    except Exception:
        # Narrowed from a bare ``except`` so that KeyboardInterrupt and
        # SystemExit are no longer swallowed; the fallback itself is unchanged.
        return data


def notch_filter(data):
    """Apply the module-level notch filter (``b_notch``/``a_notch``) with ``filtfilt``.

    Best-effort: if filtering fails for any ordinary reason (for example the
    coefficients are not defined yet or the input is not filterable), ``data``
    is returned unchanged.

    Args:
        data: samples to filter.

    Returns:
        The filtered samples, or ``data`` itself when filtering failed.
    """
    try:
        return sp_signal.filtfilt(b_notch, a_notch, data)
    except Exception:
        # Narrowed from a bare ``except`` so that KeyboardInterrupt and
        # SystemExit are no longer swallowed; the fallback itself is unchanged.
        return data

    

# data = data_process(file_data)
# NOTE(review): file_data is the object returned by json.load, and data_process
# indexes its argument with 'dataL2'. This line therefore only works when both
# filters hand their input back unchanged (their except-branch fallback), so
# presumably no filtering is actually applied here - confirm whether the
# filters were meant to run on the 'dataL2' samples instead.
data = data_process(low_pass_filter(notch_filter(file_data)))

# First (up to) 15000 samples of channel 0 from the (1, samples, channels) array.
ecg_full = data[0, :15000, 0]
# ecg_full = data[0, :15000, 0]
sampling_rate = 500

# 1. Segment-wise analysis: 4 windows of 4500 samples each.
results = process_ecg_segments(
    ecg_data=ecg_full,
    sampling_rate=sampling_rate,
    num_segments=4,
    min_segment_length=4500
)


# NOTE(review): the title text says "7 Segments" while num_segments=4 above;
# plot_ecg_segments does not render its title argument, so this has no visible effect.
plot_ecg_segments(ecg_full, sampling_rate, results, "ECG Analysis: 7 Segments with R-peak Detection")

# 2. Run the Full Data Plot (New logic)

print("\n--- Plotting Full Data ---")
plot_full_ecg(ecg_full, sampling_rate, "Final Full Data View")

## bpm normal & death

In [None]:
import json
import struct
import numpy as np
import pywt
import scipy.signal as sp_signal
import matplotlib.pyplot as plt
import base64   
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad

# Module-level filter coefficients (defined unconditionally). The names
# b_lp/a_lp and b_notch/a_notch are the ones read by low_pass_filter and
# notch_filter.
fs = 500  # sampling rate
nyq = 0.5 * fs  # Nyquist frequency; cutoffs below are normalized by it

# Example low pass filter (cutoff 40 Hz)
# 4th-order Butterworth
low_cutoff = 40 / nyq
b_lp, a_lp = sp_signal.butter(4, low_cutoff, btype='low')

# Example notch filter (50 Hz)
# q is the quality factor passed to iirnotch
q = 30
w0 = 50 / nyq
b_notch, a_notch = sp_signal.iirnotch(w0, q)

def decrypt(input_file):
    """Read an encrypted file and return the JSON document stored in it.

    The first 24 bytes of the file are skipped; the rest is Base64-decoded,
    decrypted with AES in ECB mode, unpadded with a block size of 16, decoded
    as UTF-8 and parsed as JSON.

    Args:
        input_file: path of the encrypted file.

    Returns:
        The object produced by ``json.loads`` on the decrypted text.
    """
    # NOTE(review): the key is embedded in source; consider loading it from
    # configuration or the environment instead.
    private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
    cipher = AES.new(private_key.encode('utf-8'), AES.MODE_ECB)

    with open(input_file, 'rb') as handle:
        raw_bytes = handle.read()

    ciphertext = base64.b64decode(raw_bytes[24:])
    plaintext = unpad(cipher.decrypt(ciphertext), 16)
    return json.loads(plaintext.decode('utf-8'))

    
def baseline_wander(X, sampling_rate=500):
    """Subtract a median-filter baseline estimate from ``X``.

    Two median filters (a 0.2 s window followed by a 0.6 s window) are applied
    in cascade to estimate the baseline, which is then subtracted from ``X``.

    Args:
        X: signal samples.
        sampling_rate: samples per second, used to turn the window durations
            into kernel widths. Defaults to 500, the value that used to be
            hard-coded.

    Returns:
        ``X`` minus the estimated baseline.
    """
    def get_median_filter_width(rate, duration):
        # medfilt requires an odd kernel: even widths are reduced by one.
        width = int(rate * duration)
        width += (width % 2) - 1
        # Guard against a non-positive kernel for very low sampling rates.
        return max(width, 1)

    baseline = X
    for duration in (0.2, 0.6):
        kernel = get_median_filter_width(sampling_rate, duration)
        baseline = sp_signal.medfilt(baseline, kernel)
    return np.subtract(X, baseline)


def normalize(signal, min_val, max_val):
    """Min-max scale ``signal`` with the supplied bounds.

    Returns ``(signal - min_val) / (max_val - min_val)``. When the two bounds
    coincide, an all-zero array shaped like ``signal`` is returned instead so
    that no division by zero happens.
    """
    span = max_val - min_val
    if span == 0:
        return np.zeros_like(signal)
    return (signal - min_val) / span


def process_signal(signal_data, min_val, max_val):
    """Return ``signal_data`` scaled by ``normalize`` using the given bounds."""
    return normalize(signal_data, min_val, max_val)


def data_process(filename):
    """Extract the ``'dataL2'`` channel, print its raw statistics and normalize it.

    Args:
        filename: Mapping that is indexed with the key ``'dataL2'`` (despite
            the parameter name, it is not used as a path here).

    Returns:
        A 3-tuple ``(final_data, stats, raw)``:
        * ``final_data`` - normalized channels stacked, given a leading axis
          and transposed with ``(0, 2, 1)``;
        * ``stats`` - dict of the raw min/max/mean/std/var/median/range;
        * ``raw`` - first row of the float32 channel array.
    """
    channels = [np.array(filename[key]).astype('float32') for key in ['dataL2']]
    datas_array = np.array(channels)

    # Statistics of the signal before any rescaling.
    raw_min = np.min(datas_array)
    raw_max = np.max(datas_array)
    raw_mean = np.mean(datas_array)
    raw_std = np.std(datas_array)
    raw_var = np.var(datas_array)
    raw_median = np.median(datas_array)
    raw_stats = {
        'raw_min': raw_min,
        'raw_max': raw_max,
        'raw_mean': raw_mean,
        'raw_std': raw_std,
        'raw_var': raw_var,
        'raw_median': raw_median,
        'raw_range': raw_max - raw_min,
    }

    print("\nRaw (pre-normalized) signal statistics:")
    print(f"  Min    = {raw_min:12.4f}")
    print(f"  Max    = {raw_max:12.4f}")
    print(f"  Mean   = {raw_mean:12.4f}")
    print(f"  Std    = {raw_std:12.4f}")
    print(f"  Var    = {raw_var:14.6f}")
    print(f"  Median = {raw_median:12.4f}")
    print(f"  Range  = {raw_max - raw_min:.4f}\n")

    # Every channel is rescaled with the global raw min/max.
    scaled_rows = [
        normalize(datas_array[row, :], raw_min, raw_max)
        for row in range(datas_array.shape[0])
    ]
    final_data = np.expand_dims(np.stack(scaled_rows), axis=0).transpose(0, 2, 1)

    return final_data, raw_stats, datas_array[0]

def remove_close_peaks(r_peaks, validation_signal, min_dist_samples):
    """Drop peaks closer than ``min_dist_samples``, keeping the stronger one.

    Peaks are visited in ascending order. A peak at least
    ``min_dist_samples`` after the last kept peak is kept. A closer peak
    replaces the last kept peak only when its absolute amplitude in
    ``validation_signal`` is strictly larger than that of the last kept peak.

    Args:
        r_peaks: Sample indices of candidate peaks (array-like).
        validation_signal: Signal whose absolute value ranks competing peaks.
        min_dist_samples: Minimum allowed spacing between kept peaks.

    Returns:
        Array of kept peak indices in ascending order; ``np.array([])`` when
        ``r_peaks`` is empty.
    """
    if len(r_peaks) == 0:
        return np.array([])

    r_peaks = np.sort(np.asarray(r_peaks))
    validation_abs = np.abs(np.asarray(validation_signal)[r_peaks.astype(int)])

    keep = []
    last_kept = -min_dist_samples
    last_kept_amp = None

    for current, amplitude in zip(r_peaks, validation_abs):
        if not keep or current - last_kept >= min_dist_samples:
            keep.append(current)
        elif amplitude > last_kept_amp:
            # Compare against the amplitude of the peak actually kept last,
            # not against whatever sits at position len(keep) - 1 of the input.
            keep[-1] = current
        else:
            continue
        last_kept = current
        last_kept_amp = amplitude

    return np.array(keep)


def amplitude_based_filtering(ecg_signal, peaks, segment_num="Unknown"):
    """Locate high-amplitude outlier peaks with the IQR rule.

    Outliers are peaks whose absolute amplitude exceeds ``Q3 + 1.5 * IQR``.
    NOTE(review): the mask entry for every outlier is set to ``True``, so no
    peak is removed at present - all peaks are returned. Confirm whether the
    removal is meant to stay disabled.

    Args:
        ecg_signal: Signal indexed by the peak positions.
        peaks: Array of peak sample indices.
        segment_num: Not used by the function body.

    Returns:
        ``(peaks, amplitudes)``; for empty ``peaks`` the amplitudes are
        ``np.array([])``.
    """
    if len(peaks) == 0:
        return peaks, np.array([])

    amplitudes = np.abs(ecg_signal[peaks.astype(int)])
    upper_q, lower_q = np.percentile(amplitudes, [75, 25])
    spread = upper_q - lower_q

    # Without a positive IQR there is no usable outlier threshold.
    if not spread > 0:
        return peaks, amplitudes

    outlier_idx = np.where(amplitudes > upper_q + 1.5 * spread)[0]

    # Leave the input untouched if every peak would count as an outlier.
    if len(peaks) - len(outlier_idx) <= 0:
        return peaks, amplitudes

    keep = np.ones(len(peaks), dtype=bool)
    keep[outlier_idx] = True  # outliers are retained (removal currently disabled)
    return peaks[keep], amplitudes[keep]


def remove_t_waves(ecg_signal, peaks, sampling_rate):
    """Discard peaks that look like T-waves following the preceding peak.

    A peak is rejected when all of the following hold relative to the peak
    immediately before it in sorted order:
      * it occurs more than 160 ms and less than 450 ms later,
      * its absolute amplitude is below half of the preceding peak's,
      * its width at half of its own amplitude (searched up to 100 samples
        to each side) exceeds 40 ms.

    Args:
        ecg_signal: Signal the peaks refer to.
        peaks: Peak sample indices.
        sampling_rate: Samples per second.

    Returns:
        ``peaks`` unchanged when fewer than three are given, otherwise an
        array of the accepted peaks in ascending order.
    """
    if len(peaks) < 3:
        return peaks

    ordered = np.sort(peaks)
    last_sample = len(ecg_signal) - 1

    def resembles_t_wave(position, previous):
        gap_ms = (position - previous) / sampling_rate * 1000
        if not (160 < gap_ms < 450):
            return False

        amp_here = abs(ecg_signal[int(position)])
        if not (amp_here < abs(ecg_signal[int(previous)]) * 0.5):
            return False

        level = amp_here * 0.5

        # Walk outwards until the signal falls below half maximum.
        lo = position
        while lo > 0 and lo > position - 100 and not abs(ecg_signal[int(lo)]) < level:
            lo -= 1

        hi = position
        while hi < last_sample and hi < position + 100 and not abs(ecg_signal[int(hi)]) < level:
            hi += 1

        return (hi - lo) / sampling_rate * 1000 > 40

    accepted = [
        position
        for k, position in enumerate(ordered)
        if k == 0 or not resembles_t_wave(position, ordered[k - 1])
    ]
    return np.array(accepted)


def robust_qrs_detect_internal(data_clean, sampling_rate):
    """Multi-strategy robust QRS detection for difficult cases.

    Candidate peaks are collected from three strategies, each gated by a
    "sharpness" score (squared first difference of the rectified 10-40 Hz
    band-passed signal), then merged and de-duplicated.

    Args:
        data_clean: 1-D signal to analyse (copied, not modified).
        sampling_rate: Samples per second.

    Returns:
        Sorted array of candidate peak sample indices, or ``np.array([])``
        when no strategy produced a candidate.
    """
    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate
    
    # Calculate sharpness threshold
    # strict_score = (d/dn |bandpass 10-40 Hz|)^2, padded with one zero so it
    # has the same length as the input.
    low_strict = 10 / nyquist
    high_strict = 40 / nyquist
    b2, a2 = sp_signal.butter(2, [low_strict, high_strict], btype='band')
    filtered_strict = sp_signal.filtfilt(b2, a2, data_clean)
    diff_strict = np.diff(np.abs(filtered_strict))
    diff_strict = np.append(diff_strict, 0)
    strict_score = diff_strict ** 2
    
    # The reference threshold is the 94th percentile of the sharpness score.
    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0
    
    all_candidate_peaks = []
    
    # Strategy 1: Multi-band detection with multiple thresholds
    freq_bands = [(5, 15), (8, 24), (10, 30), (12, 40)]
    
    for low_freq, high_freq in freq_bands:
        low = low_freq / nyquist
        high = high_freq / nyquist
        b, a = sp_signal.butter(2, [low, high], btype='band')
        filtered = sp_signal.filtfilt(b, a, data_clean)
        
        # Energy envelope: squared signal smoothed by a 0.15 s moving average.
        squared = filtered ** 2
        window_size = int(0.15 * sampling_rate)
        integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
        
        mean_val = np.mean(integrated)
        std_val = np.std(integrated)
        
        # Three detection heights: mean + 0.1/0.2/0.3 standard deviations.
        thresholds = [mean_val + 0.1 * std_val, mean_val + 0.2 * std_val, mean_val + 0.3 * std_val]
        
        for threshold in thresholds:
            candidates, _ = sp_signal.find_peaks(
                integrated, 
                height=threshold,
                distance=int(0.2 * sampling_rate)
            )
            
            search_window = int(0.1 * sampling_rate)
            sharp_window = int(0.18 * sampling_rate)
            
            for peak in candidates:
                # Accept only if something sharp enough lies within +/-0.18 s.
                start_sharp = max(0, peak - sharp_window)
                end_sharp = min(len(strict_score), peak + sharp_window)
                if start_sharp < end_sharp:
                    local_sharpness = np.max(strict_score[start_sharp:end_sharp])
                    
                    if local_sharpness > sharpness_threshold:
                        # Snap to the largest |amplitude| within +/-0.1 s.
                        start = max(0, peak - search_window)
                        end = min(len(original_data), peak + search_window)
                        if start < end:
                            local_segment = original_data[start:end]
                            local_max_idx = np.argmax(np.abs(local_segment))
                            refined_peak = start + local_max_idx
                            all_candidate_peaks.append(refined_peak)
    
    # Strategy 2: Prominence-based detection
    # Runs find_peaks directly on the signal (local maxima only), with the
    # sharpness gate relaxed to 80 % of the threshold; no position refinement.
    peaks_prom, properties = sp_signal.find_peaks(
        original_data,
        distance=int(0.2 * sampling_rate),
        prominence=0.02
    )
    
    sharp_window = int(0.18 * sampling_rate)
    for peak in peaks_prom:
        start_sharp = max(0, peak - sharp_window)
        end_sharp = min(len(strict_score), peak + sharp_window)
        if start_sharp < end_sharp:
            local_sharpness = np.max(strict_score[start_sharp:end_sharp])
            if local_sharpness > sharpness_threshold * 0.8:
                all_candidate_peaks.append(peak)
    
    # Strategy 3: Derivative-based detection
    # Peaks of the squared first difference of the unfiltered signal.
    diff_signal = np.diff(original_data)
    diff_squared = diff_signal ** 2
    diff_squared = np.append(diff_squared, 0)
    
    mean_diff = np.mean(diff_squared)
    std_diff = np.std(diff_squared)
    
    diff_peaks, _ = sp_signal.find_peaks(
        diff_squared,
        height=mean_diff + 0.5 * std_diff,
        distance=int(0.15 * sampling_rate)
    )
    
    search_window = int(0.08 * sampling_rate)
    
    for peak in diff_peaks:
        start_sharp = max(0, peak - sharp_window)
        end_sharp = min(len(strict_score), peak + sharp_window)
        if start_sharp < end_sharp:
            local_sharpness = np.max(strict_score[start_sharp:end_sharp])
            
            # Sharpness gate relaxed to 70 % for this strategy.
            if local_sharpness > sharpness_threshold * 0.7:
                # Snap to the largest |amplitude| within +/-0.08 s.
                start = max(0, peak - search_window)
                end = min(len(original_data), peak + search_window)
                if start < end:
                    local_segment = original_data[start:end]
                    local_max_idx = np.argmax(np.abs(local_segment))
                    refined_peak = start + local_max_idx
                    all_candidate_peaks.append(refined_peak)
    
    # Merge and deduplicate peaks
    if len(all_candidate_peaks) > 0:
        all_candidate_peaks = np.unique(all_candidate_peaks)
        
        min_distance = int(0.15 * sampling_rate)
        sorted_peaks = np.sort(all_candidate_peaks)
        
        if len(sorted_peaks) > 0:
            keep_mask = [True]
            for i in range(1, len(sorted_peaks)):
                if sorted_peaks[i] - sorted_peaks[i-1] >= min_distance:
                    keep_mask.append(True)
                else:
                    # Too close to the preceding candidate: keep whichever has
                    # the larger windowed sharpness (ties keep the earlier one).
                    # NOTE(review): the comparison is always against
                    # sorted_peaks[i-1], even if that candidate was itself
                    # already discarded.
                    start1 = max(0, sorted_peaks[i-1] - sharp_window)
                    end1 = min(len(strict_score), sorted_peaks[i-1] + sharp_window)
                    start2 = max(0, sorted_peaks[i] - sharp_window)
                    end2 = min(len(strict_score), sorted_peaks[i] + sharp_window)
                    
                    sharp1 = np.max(strict_score[start1:end1]) if start1 < end1 else 0
                    sharp2 = np.max(strict_score[start2:end2]) if start2 < end2 else 0
                    
                    if sharp2 > sharp1:
                        keep_mask[-1] = False
                        keep_mask.append(True)
                    else:
                        keep_mask.append(False)
            
            sorted_peaks = sorted_peaks[keep_mask]
    
    # sorted_peaks only exists when at least one candidate was collected.
    return sorted_peaks if len(all_candidate_peaks) > 0 else np.array([])


def qrs_detect(data, sampling_rate, segment_duration=None, raw_segment=None):
    """Detect R-peaks in an ECG segment and estimate the heart rate.

    Steps:
      1. Flatline gate - if the variance is below a threshold the segment is
         treated as asystole and no detection is attempted.
      2. Stream 1 - candidates from the smoothed, squared derivative of the
         rectified 8-24 Hz band-passed signal.
      3. Stream 2 - each candidate must have a sharp enough 10-40 Hz feature
         nearby; its position is snapped to the largest |amplitude|.
      4. Close peaks are merged and large RR gaps are searched again with
         lower thresholds plus an amplitude check.
      5. If the peak count is outside the 30-180 BPM range for the segment
         duration, ``robust_qrs_detect_internal`` is used instead.
      6. BPM is 60 / mean RR over intervals between 0.2 s and 4.0 s.

    Args:
        data: 1-D signal used for detection.
        sampling_rate: Samples per second.
        segment_duration: Segment length in seconds; derived from ``data``
            when ``None``.
        raw_segment: Optional un-normalized samples of the same segment. When
            given, the flatline gate uses their variance instead of the
            variance of ``data``.

    Returns:
        ``(data, r_peaks, bpm, r_peaks)`` - the input signal, the detected
        peak indices, the BPM estimate (0 when it cannot be computed) and the
        same peak array again.
    """
    # Flatline thresholds for raw and for normalized amplitudes.
    raw_flatline_var = 0.005
    norm_flatline_var = 0.00015

    if raw_segment is not None:
        var_raw = np.var(raw_segment)
        if var_raw < raw_flatline_var:
            # The message is built from the constant so it cannot drift from
            # the threshold actually applied (it used to print 0.0095).
            print(f"Raw variance {var_raw:.6f} < {raw_flatline_var} → treating as asystole / flatline")
            return data, np.array([]), 0.0, np.array([])
    else:
        var = np.var(data)
        if var < norm_flatline_var:
            print(f"Normalized variance {var:.6f} too low → possible asystole")
            return data, np.array([]), 0.0, np.array([])

    # Baseline-wander removal is intentionally not applied here.
    data_clean = data

    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate

    # --- STREAM 1: Standard Detection ---
    low = 8 / nyquist
    high = 24 / nyquist
    b, a = sp_signal.butter(2, [low, high], btype='band')
    filtered_standard = sp_signal.filtfilt(b, a, data_clean)

    filtered_abs = np.abs(filtered_standard)
    diff = np.diff(filtered_abs)
    diff = np.append(diff, 0)
    squared = diff ** 2

    # 0.15 s moving-average integration of the squared derivative.
    window_size = int(0.15 * sampling_rate)
    integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')

    mean_val = np.mean(integrated)
    std_val = np.std(integrated)
    threshold = mean_val + 0.20 * std_val

    candidates, _ = sp_signal.find_peaks(
        integrated,
        height=threshold,
        distance=int(0.12 * sampling_rate)
    )

    # --- STREAM 2: Sharpness Validator ---
    low_strict = 10 / nyquist
    high_strict = 40 / nyquist
    b2, a2 = sp_signal.butter(2, [low_strict, high_strict], btype='band')
    filtered_strict = sp_signal.filtfilt(b2, a2, data_clean)
    diff_strict = np.diff(np.abs(filtered_strict))
    diff_strict = np.append(diff_strict, 0)
    strict_score = diff_strict ** 2

    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0

    confirmed_peaks = []
    search_window = int(0.18 * sampling_rate)

    for peak in candidates:
        start_check = max(0, peak - search_window)
        end_check = min(len(strict_score), peak + search_window)
        if start_check >= end_check:
            continue

        local_sharpness = np.max(strict_score[start_check:end_check])

        if local_sharpness > sharpness_threshold:
            # Snap the candidate to the largest |amplitude| in the window.
            local_segment = original_data[start_check:end_check]
            if len(local_segment) > 0:
                abs_local_segment = np.abs(local_segment)
                local_max_idx = np.argmax(abs_local_segment)
                confirmed_peaks.append(start_check + local_max_idx)

    r_peaks = np.array(confirmed_peaks)

    # Remove close peaks
    min_dist = int(0.15 * sampling_rate)
    r_peaks = remove_close_peaks(r_peaks, original_data, min_dist)

    cleaned_r = np.sort(np.array([x for x in r_peaks if not (isinstance(x, float) and np.isnan(x))]))

    # =================================================================
    # GAP FILLING WITH AMPLITUDE GUARDRAILS
    # =================================================================
    if len(cleaned_r) >= 2:
        # Reference height: median |amplitude| of the peaks found so far.
        existing_heights = np.abs(original_data[cleaned_r.astype(int)])
        median_r_height = np.median(existing_heights) if len(existing_heights) > 0 else 0

        rr_intervals = np.diff(cleaned_r) / sampling_rate
        median_rr = np.median(rr_intervals) if len(rr_intervals) > 0 else 1.0
        new_peaks = list(cleaned_r)

        # Only fill gaps if median_rr suggests a normal rhythm (< 1.5s).
        # If median_rr is already 2.0s (bradycardia), huge gaps are normal.
        if median_rr < 1.5:
            for i in range(len(rr_intervals)):
                if rr_intervals[i] > 1.4 * median_rr:
                    gap_start = cleaned_r[i]
                    gap_end = cleaned_r[i+1]
                    if gap_start >= gap_end:
                        continue

                    gap_integrated = integrated[gap_start:gap_end]
                    # Lower threshold slightly for gap search
                    low_thresh = mean_val * 0.6

                    gap_candidates, _ = sp_signal.find_peaks(
                        gap_integrated,
                        height=low_thresh,
                        distance=int(0.10 * sampling_rate)
                    )

                    for gc in gap_candidates:
                        abs_idx = gap_start + gc
                        sw_start = max(0, abs_idx - search_window)
                        sw_end = min(len(strict_score), abs_idx + search_window)
                        if sw_start >= sw_end:
                            continue

                        # 1. Check sharpness (relaxed to 40 % of the threshold)
                        local_sharp_max = np.max(strict_score[sw_start:sw_end])
                        if local_sharp_max > sharpness_threshold * 0.4:

                            # 2. Refine position
                            local_segment = original_data[sw_start:sw_end]
                            abs_local_segment = np.abs(local_segment)
                            refine_idx = np.argmax(abs_local_segment)
                            candidate_peak = sw_start + refine_idx

                            # 3. Amplitude check: even a sharp deflection must
                            # be tall enough (AV-block P-waves are sharp but
                            # short).
                            candidate_amp = np.abs(original_data[candidate_peak])

                            # Must exceed 45 % of the median R-peak height
                            if candidate_amp > 0.45 * median_r_height:
                                new_peaks.append(candidate_peak)

        new_peaks = np.sort(np.unique(new_peaks))
        cleaned_r = remove_close_peaks(new_peaks, original_data, min_dist)

    # =================================================================

    # Determine expected peak count range
    if segment_duration is None:
        segment_duration = len(data_clean) / sampling_rate

    # Relaxed expectations for Bradycardia/AV Block (30-180 BPM)
    min_expected_peaks = int(30/60 * segment_duration)
    max_expected_peaks = int(180/60 * segment_duration)

    # Fallback to robust only if counts are extremely off
    if len(cleaned_r) < min_expected_peaks or len(cleaned_r) > max_expected_peaks:
        initial_peaks = robust_qrs_detect_internal(data_clean, sampling_rate)
        initial_peaks = remove_t_waves(data_clean, initial_peaks, sampling_rate)
        cleaned_r, _ = amplitude_based_filtering(data_clean, initial_peaks, "Segment")
    else:
        cleaned_r = remove_t_waves(data_clean, cleaned_r, sampling_rate)

    # Calculate BPM
    if len(cleaned_r) > 1:
        rr_intervals = np.diff(cleaned_r) / sampling_rate

        # Valid intervals widened to account for Bradycardia/Pauses
        valid_rr = rr_intervals[(rr_intervals > 0.2) & (rr_intervals < 4.0)]

        if len(valid_rr) > 0:
            mean_rr = np.mean(valid_rr)
            bpm = 60 / mean_rr if mean_rr > 0 else 0
        else:
            bpm = 0
    else:
        bpm = 0

    return data, cleaned_r, bpm, cleaned_r



def process_ecg_segments(ecg_data, sampling_rate, num_segments=7, min_segment_length=3500, raw_ecg=None):
    """Run ``qrs_detect`` on evenly spaced, possibly overlapping windows.

    ``num_segments`` windows of ``min_segment_length`` samples are spread
    from the start to the end of ``ecg_data``. Windows shorter than 100
    samples are reported with empty results and a BPM of 0.

    Args:
        ecg_data: 1-D signal to analyse.
        sampling_rate: Samples per second.
        num_segments: Number of windows.
        min_segment_length: Window length in samples.
        raw_ecg: Optional un-normalized signal aligned with ``ecg_data``; the
            matching slice is handed to ``qrs_detect`` as ``raw_segment``.
            When omitted, a module-level ``raw_ecg`` is used if one exists
            (the behaviour earlier callers relied on); otherwise
            ``qrs_detect`` receives ``raw_segment=None``.

    Returns:
        List of dicts, one per window, with keys ``segment_num``,
        ``start_idx``, ``end_idx``, ``ecg_filtered``, ``r_peaks``, ``bpm``,
        ``cleaned_r`` and ``ecg_raw``. Peak indices are relative to
        ``ecg_data``, not to the window.
    """
    if raw_ecg is None:
        # Backward compatibility with callers that set a global instead of
        # passing the raw signal explicitly.
        raw_ecg = globals().get('raw_ecg')

    max_len = len(ecg_data)

    if num_segments > 1:
        window_step = (max_len - min_segment_length) / (num_segments - 1)
        window_step = round(window_step)
    else:
        window_step = 0

    results = []

    for i in range(num_segments):
        start_idx = i * window_step
        end_idx = start_idx + min_segment_length

        # Clamp the window to the signal bounds.
        if end_idx > max_len:
            start_idx = max_len - min_segment_length
            end_idx = max_len

        if start_idx < 0:
            start_idx = 0
            end_idx = min(min_segment_length, max_len)

        segment = ecg_data[start_idx:end_idx]

        if len(segment) < 100:
            results.append({
                'segment_num': i + 1,
                'start_idx': start_idx,
                'end_idx': end_idx,
                'ecg_filtered': np.array([]),
                'r_peaks': np.array([]),
                'bpm': 0,
                'cleaned_r': np.array([]),
                'ecg_raw': segment
            })
            continue

        segment_duration = len(segment) / sampling_rate

        # Real (un-normalized) amplitudes for the flatline check, if available.
        raw_segment = raw_ecg[start_idx:end_idx] if raw_ecg is not None else None
        ecg_filtered, r_peaks, bpm, cleaned_r = qrs_detect(
            segment,
            sampling_rate,
            segment_duration,
            raw_segment=raw_segment
        )

        print(f"Segment {i+1}: Detected {len(r_peaks)} R-peaks, BPM: {bpm:.1f}")

        # Shift window-relative indices to positions in ecg_data.
        adjusted_r_peaks = r_peaks + start_idx if len(r_peaks) > 0 else np.array([])
        adjusted_cleaned_r = np.array(cleaned_r) + start_idx if len(cleaned_r) > 0 else np.array([])

        results.append({
            'segment_num': i + 1,
            'start_idx': start_idx,
            'end_idx': end_idx,
            'ecg_filtered': ecg_filtered,
            'r_peaks': adjusted_r_peaks,
            'bpm': bpm,
            'cleaned_r': adjusted_cleaned_r,
            'ecg_raw': segment
        })

    return results

import numpy as np

def compute_ecg_stats(signal, fs=500):
    """Summarise a signal segment with basic descriptive statistics.

    Args:
        signal: 1-D numpy array of samples.
        fs: Sampling rate in Hz, used for ``duration_s``.

    Returns:
        Dict with ``nsamples``, ``mean``, ``std``, ``var``, ``min``, ``max``,
        ``median``, ``rms`` and ``duration_s``. For an empty signal the
        statistics are NaN, ``nsamples`` is 0 and ``duration_s`` is 0.0.
    """
    n_samples = len(signal)

    if n_samples == 0:
        blanks = dict.fromkeys(
            ('mean', 'std', 'var', 'min', 'max', 'median', 'rms'), np.nan
        )
        return {'nsamples': 0, **blanks, 'duration_s': 0.0}

    reducers = (
        ('mean', np.mean),
        ('std', np.std),
        ('var', np.var),
        ('min', np.min),
        ('max', np.max),
        ('median', np.median),
    )

    stats = {'nsamples': n_samples}
    for label, reducer in reducers:
        stats[label] = float(reducer(signal))
    stats['rms'] = float(np.sqrt(np.mean(signal**2)))
    stats['duration_s'] = n_samples / fs
    return stats


def format_stats_text(stats, prefix=""):
    """Render a stats dict as a newline-separated block for plot annotations.

    Args:
        stats: Dict with the keys ``duration_s``, ``nsamples``, ``mean``,
            ``std``, ``var``, ``min``, ``max``, ``median`` and ``rms``.
        prefix: Text placed in front of the first (duration) line only.

    Returns:
        Eight lines joined with ``"\\n"``.
    """
    return "\n".join((
        f"{prefix}Duration: {stats['duration_s']:.2f} s",
        f"Samples:   {stats['nsamples']}",
        f"Mean:      {stats['mean']:.4f}",
        f"Std:       {stats['std']:.4f}",
        f"Var:       {stats['var']:.6f}",
        f"Min / Max: {stats['min']:.4f} / {stats['max']:.4f}",
        f"Median:    {stats['median']:.4f}",
        f"RMS:       {stats['rms']:.4f}",
    ))


def plot_ecg_segments(ecg_data, sampling_rate, results, title="ECG Segments with R-peaks and BPM", raw_ecg=None):
    """Plot one panel per analysed segment and print a statistics summary.

    Each panel shows the segment samples, the detected R-peaks and a text box
    with the segment statistics and BPM. Afterwards a one-line summary per
    segment is printed.

    Args:
        ecg_data: Full signal the segment indices refer to.
        sampling_rate: Samples per second (used for the time axis).
        results: List of per-segment dicts with the keys ``segment_num``,
            ``start_idx``, ``end_idx``, ``bpm``, ``r_peaks`` and ``ecg_raw``.
        title: Figure title prefix.
        raw_ecg: Optional signal aligned with ``ecg_data``; when given, the
            statistics are computed from it instead of the plotted samples.
    """
    panel_count = len(results)
    fig, axes = plt.subplots(panel_count, 1, figsize=(15, 3.5 * panel_count), sharex=False)
    if panel_count == 1:
        # A single subplot is returned bare; wrap it so the loop below works.
        axes = [axes]

    time = np.arange(len(ecg_data)) / sampling_rate

    global_stats = compute_ecg_stats(ecg_data, sampling_rate)
    fig.suptitle(f"{title}\nFull signal stats: {global_stats['duration_s']:.1f}s | "
                 f"mean={global_stats['mean']:.4f}  std={global_stats['std']:.4f}",
                 fontsize=13, y=0.98)

    use_raw = raw_ecg is not None

    for ax, result in zip(axes, results):
        segment_num = result['segment_num']
        start_idx, end_idx = result['start_idx'], result['end_idx']
        bpm = result['bpm']
        r_peaks = result['r_peaks']

        segment_time = time[start_idx:end_idx]
        segment_data = result['ecg_raw']

        ax.plot(segment_time, segment_data, 'b-', alpha=0.8, linewidth=1.1, label='ECG')

        if len(r_peaks) > 0:
            ax.plot(r_peaks / sampling_rate, ecg_data[r_peaks.astype(int)],
                    'ro', markersize=7, label='R-peaks', alpha=0.85)

        # Statistics box for this segment (raw amplitudes when available)
        if use_raw:
            seg_stats = compute_ecg_stats(raw_ecg[start_idx:end_idx], sampling_rate)
            box_prefix = "Raw "
        else:
            seg_stats = compute_ecg_stats(segment_data, sampling_rate)
            box_prefix = ""
        stats_text = format_stats_text(seg_stats, box_prefix + f"Seg {segment_num}  ")
        stats_text += f"\nBPM:       {bpm:.1f}"

        ax.text(0.02, 0.98, stats_text,
                transform=ax.transAxes,
                fontsize=9.5,
                verticalalignment='top',
                bbox=dict(facecolor='white', alpha=0.82, edgecolor='gray', boxstyle='round,pad=0.4'))

        ax.set_title(f'Segment {segment_num}: {start_idx:,} – {end_idx:,}  |  BPM: {bpm:.1f}')
        ax.set_ylabel('Amplitude (norm)')
        ax.grid(True, alpha=0.35, linestyle='--')
        ax.set_xlim([segment_time[0], segment_time[-1]])
        ax.legend(loc='upper right', fontsize=9)

    axes[-1].set_xlabel('Time (seconds)')
    plt.tight_layout(rect=[0, 0, 1, 0.96])   # leave room for the suptitle
    plt.show()

    # Console summary
    rule = "═" * 70
    source_tag = "Raw" if use_raw else "Norm"
    print(rule)
    print("ECG SEGMENT STATISTICS SUMMARY")
    print(rule)
    for res in results:
        if use_raw:
            s = compute_ecg_stats(raw_ecg[res['start_idx']:res['end_idx']], sampling_rate)
        else:
            s = compute_ecg_stats(res['ecg_raw'], sampling_rate)
        print(f"Segment {res['segment_num']:2d} | {s['duration_s']:5.2f}s | "
              f"mean={s['mean']:8.4f}  std={s['std']:7.4f}  BPM={res['bpm']:5.1f} ({source_tag})")
    print(rule)
    
    
def plot_full_ecg(ecg_data, sampling_rate, title="Full ECG Signal Analysis", raw_ecg=None):
    """Run detection on the entire signal and plot a single continuous view.

    Args:
        ecg_data: 1-D signal to analyse and plot.
        sampling_rate: Samples per second.
        title: Plot title prefix.
        raw_ecg: Optional signal aligned with ``ecg_data``. When given, its
            first ``len(ecg_data)`` samples are passed to ``qrs_detect`` as
            ``raw_segment`` and used for the printed statistics; when
            ``None``, ``qrs_detect`` gets ``raw_segment=None`` and the
            statistics come from ``ecg_data``.
    """
    # Slice only when a raw signal was supplied: subscripting the default
    # ``None`` would raise TypeError before detection even starts.
    raw_part = raw_ecg[:len(ecg_data)] if raw_ecg is not None else None

    _, r_peaks, global_bpm, _ = qrs_detect(
        ecg_data,
        sampling_rate,
        raw_segment=raw_part
    )

    if raw_part is not None:
        stats = compute_ecg_stats(raw_part, sampling_rate)
        prefix = "Raw "
    else:
        stats = compute_ecg_stats(ecg_data, sampling_rate)
        prefix = ""

    plt.figure(figsize=(20, 6))  # wide figure so long recordings stay readable

    # Create time axis
    time_axis = np.arange(len(ecg_data)) / sampling_rate

    # Plot the signal
    plt.plot(time_axis, ecg_data, 'b-', linewidth=0.8, alpha=0.8, label='Filtered ECG')

    # Plot the peaks
    if len(r_peaks) > 0:
        # Filter out peaks that might be out of bounds (safety check)
        valid_peaks = r_peaks[r_peaks < len(ecg_data)].astype(int)

        peak_times = valid_peaks / sampling_rate
        peak_values = ecg_data[valid_peaks]

        plt.plot(peak_times, peak_values, 'ro', markersize=4, label='R-peaks')

        # Annotate every 5th peak with its time to help navigation
        for i, (t, v) in enumerate(zip(peak_times, peak_values)):
            if i % 5 == 0:
                plt.annotate(f'{t:.1f}s', (t, v), xytext=(0, 10),
                             textcoords='offset points', ha='center', fontsize=8, color='red')

    plt.title(f"{title} | Global BPM: {global_bpm:.1f} | Total Peaks: {len(r_peaks)}")
    plt.xlabel("Time (seconds)")
    plt.ylabel("Normalized Amplitude")
    plt.legend(loc='upper right')
    plt.grid(True, which='both', alpha=0.5)
    plt.tight_layout()
    plt.show()

    print(f"Global Analysis: {len(r_peaks)} peaks detected over {len(ecg_data)/sampling_rate:.2f} seconds.")
    print(f"{prefix}Full signal stats →  mean={stats['mean']:.4f}  std={stats['std']:.4f}  var={stats['var']:.6f}")



# ============================================================================

 
# Input selection: exactly one `input_json` assignment is active; the others
# are alternative recordings kept commented out for quick switching.
# input_json = r"simulator\contec\trigeminy_1756103085272.json" 
input_json = r"simulator\contec\asystl_1756103447146.json" 
# input_json = r"simulator\contec\1d av_1756104504294.json"  
# input_json = r"simulator\contec\3d av_1756104633918.json"  
# input_json = r"simulator\contec\280bpm_1756100716422.json" 
# input_json = r"simulator\contec\av sequence_1756106676125.json"  #####
# input_json = r"simulator\contec\dmnd freq_1756106571373.json"  
#    
# input_json = r"simulator\fluke\trigeminy_1754543043205.json"   
# input_json = r"simulator\fluke\3d av_1754545068278.json"   
# input_json = r"simulator\fluke\asystole_1754544406847.json"   

# The file is read as plain JSON; the parsed object is later indexed with
# the key 'dataL2' by data_process().
with open(input_json, 'r') as file:
    file_data = json.load(file)


# ============================================================================


# input_json = r"v01_prob\teton_ecg.ecgdatas.json"  ####
# with open(input_json, 'r') as file:
#     all_id_data = json.load(file)

# file_data = all_id_data[3]['ecgValue']   



# # input_json = r"bpms\afib_1766471694144.json"
# # input_json = r"bpms\bigeminy_1766467666407.json"
# # input_json = r"bpms\pvc 6_1766467718685.json"    ########
# # input_json = r"bpms\tri_1766467618314.json"
# # input_json = r"v01_prob/220_1767858669130.json"
# # input_json = r"v01_prob/240bpm_1767858615562.json"
# # input_json = r"v01_prob\25 contec_1768375918389.json"
# # input_json = r"v01_prob\30bpm contec_1768375716454.json"
# # input_json = r"v01_prob\2d av_1754545008828.json"  #####
# input_json = r"v01_prob\3rd_davb_1768554217066.json"  #####

# with open(input_json, 'r') as file:
#     file_data = json.load(file)



# input_json = r"exception\L2_1759207950416.json"  #####
# # input_json = r"0_bpm\L2_1760767200872.json"  
# # input_json = r"0_bpm\L2_1760254470484.json"  
# # input_json = r"0_bpm\L2_1760354748658.json"  
# # input_json = r"0_bpm\L2_1760770290911.json"  
# # input_json = r"issues\L2_1757064122874.json"  #####
# # input_json = r"v01_prob\run 5 pvc.json"  #####
# # input_json = r"issues\L2_1757579288752.json"
# # input_json = r"issues\L2_1757737806463.json"  #####
# # input_json = r"v01_prob\L2_1765984517025.json"  #####
# # input_json = r"1st-last-peaks\L2_1759908627949.json"
# # input_json = r"1st-last-peaks\L2_1759908888619.json"

# doubles = []
# with open(input_json, "rb") as f:
#     while chunk := f.read(8):
#         if len(chunk) < 8:
#             break
#         value = struct.unpack("<d", chunk)[0]
#         doubles.append(value)

# file_data = {'dataL2': doubles}   



# def decrypt(input_file):
#     """Decrypt encrypted JSON file (optional - commented out in your version)"""
#     private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
#     cipher = AES.new(private_key.encode('utf-8'), AES.MODE_ECB)
#     with open(input_file, 'rb') as f:
#         encrypted_data = f.read()
#     enc = base64.b64decode(encrypted_data[24:])
#     data = unpad(cipher.decrypt(enc), 16)
#     decoded_string = data.decode('utf-8')
#     return json.loads(decoded_string)

# # input_json = r"NHF2\DATA_1750689015865.json"
# # input_json = r"NHF2\DATA_1750689460556.json"
# # input_json = r"NHF2\DATA_1750851207409.json"
# # input_json = r"NHF2\DATA_1750858856842.json"
# # input_json = r"NHF2\DATA_1750862721789.json"
# input_json = r"NHF2\DATA_1750996455820.json"
# file_data = decrypt(input_json)

# input_json = r"NHF\DATA_1752067426678.json"  #####
# # input_json = r"NHF\DATA_1752121970835.json"  ########
# # input_json = r"NHF\DATA_1754709586876.json"  #####
# # input_json = r"NHF\DATA_1754729551054.json"
# file_data = decrypt(input_json)



# def decrypt(input_file):
#     """Decrypt encrypted CSV file using AES ECB mode"""
#     Private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
    
#     cipher = AES.new(Private_key.encode(), AES.MODE_ECB)
#     with open(input_file, 'rb') as f:
#         encrypted_data = f.read()

#     enc = base64.b64decode(encrypted_data[24:])
#     cipher = AES.new(Private_key.encode('utf-8'), AES.MODE_ECB)
#     data = unpad(cipher.decrypt(enc), 16)

#     decoded_string = data.decode('utf-8')
#     data_list = decoded_string.split(",")
#     float_list = [float(x) for x in data_list]

#     return float_list

# # selected_path = "v01_prob\ECG_1735798172211.csv"  ####
# selected_path = "v01_prob\ECG_L2_1738637533455.csv"
# file_data = decrypt(selected_path)
# file_data = {'dataL2': file_data}


def low_pass_filter(data):
    """Apply the low-pass filter to ``data``, falling back to the raw input.

    The filter is run forward and backward with ``scipy.signal.filtfilt``
    using the module-level coefficients ``b_lp`` / ``a_lp``.

    Args:
        data: Samples to filter (anything ``filtfilt`` accepts).

    Returns:
        The filtered samples, or ``data`` itself, unchanged, if filtering
        fails for any reason (missing coefficients, input too short,
        non-numeric input, ...).
    """
    import warnings

    try:
        return sp_signal.filtfilt(b_lp, a_lp, data)
    except Exception as exc:
        # Stay best-effort, but do not hide the failure: a silent fallback
        # makes it impossible to tell whether the data was actually filtered.
        warnings.warn(
            f"low_pass_filter skipped, returning input unfiltered: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return data


def notch_filter(data):
    """Apply the notch filter to ``data``, falling back to the raw input.

    The filter is run forward and backward with ``scipy.signal.filtfilt``
    using the module-level coefficients ``b_notch`` / ``a_notch``.

    Args:
        data: Samples to filter (anything ``filtfilt`` accepts).

    Returns:
        The filtered samples, or ``data`` itself, unchanged, if filtering
        fails for any reason (missing coefficients, input too short,
        non-numeric input, ...).
    """
    import warnings

    try:
        return sp_signal.filtfilt(b_notch, a_notch, data)
    except Exception as exc:
        # Stay best-effort, but do not hide the failure: a silent fallback
        # makes it impossible to tell whether the data was actually filtered.
        warnings.warn(
            f"notch_filter skipped, returning input unfiltered: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return data

    

# data = data_process(file_data)
# Notch-filter, then low-pass-filter `file_data`, and pass the result to
# data_process, which here is expected to return three values.
# NOTE(review): both filters return their input unchanged when filtfilt
# raises. The commented-out loaders above build `file_data` as a dict
# ({'dataL2': ...}); if that is still the case the filters are effectively
# skipped -- confirm.
processed_data, raw_global_stats, raw_ecg = data_process(
    low_pass_filter(notch_filter(file_data))
)

# First item, first channel, at most the first 15000 samples
# (30 s at the 500 Hz rate set below).
ecg_full = processed_data[0, :15000, 0]
sampling_rate = 500

# Analyse the recording in 4 segments; process_ecg_segments is defined
# elsewhere in the notebook.
results = process_ecg_segments(
    ecg_data=ecg_full,
    sampling_rate=sampling_rate,
    num_segments=4,
    min_segment_length=4500
)


# Per-segment plot, then a plot of the full trace.
plot_ecg_segments(ecg_full, sampling_rate, results, "ECG Analysis: 4 Segments with R-peak Detection", raw_ecg=raw_ecg)

print("\n--- Plotting Full Data ---")
plot_full_ecg(ecg_full, sampling_rate, "Final Full Data View", raw_ecg=raw_ecg)

## interval

In [None]:
import json
import struct
import numpy as np
import scipy.signal as sp_signal
import matplotlib.pyplot as plt
import base64
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
from scipy.interpolate import CubicSpline

# ==========================================
# 1. UTILITY FUNCTIONS
# ==========================================

def baseline_wander(data, sampling_rate=500, knot_spacing=0.4):
    """
    Remove baseline wander by subtracting a cubic-spline baseline estimate.

    Knots are placed every ``knot_spacing`` seconds, plus one on the final
    sample. Each knot value is the median (50th percentile) of the samples
    in a window of roughly one knot interval centred on the knot, which keeps
    narrow spikes from pulling the baseline. A cubic spline through the knots
    is evaluated at every sample and subtracted from ``data``.

    Args:
        data: 1-D sequence of samples.
        sampling_rate: Samples per second.
        knot_spacing: Distance between knots in seconds.

    Returns:
        Array of the same length as ``data`` with the baseline removed.
        Inputs with fewer than two samples cannot be splined; an array of
        zeros of the same length is returned for them.
    """
    data = np.asarray(data)
    n_samples = len(data)

    # A spline needs at least two knots; with 0 or 1 samples the only
    # sensible baseline is the signal itself.
    if n_samples < 2:
        return np.zeros(n_samples, dtype=float)

    # Guard against a zero step when sampling_rate * knot_spacing < 1.
    knot_interval = max(1, int(sampling_rate * knot_spacing))

    # Create knot points at regular intervals, always including the last sample
    knot_indices = np.arange(0, n_samples, knot_interval)
    if knot_indices[-1] != n_samples - 1:
        knot_indices = np.append(knot_indices, n_samples - 1)

    # Median of the region around each knot estimates the baseline there
    half_window = knot_interval // 2
    knot_values = []
    for idx in knot_indices:
        start = max(0, idx - half_window)
        # Keep at least one sample in the window even when half_window == 0.
        end = min(n_samples, max(idx + half_window, start + 1))
        knot_values.append(np.percentile(data[start:end], 50))

    # Fit cubic spline and subtract
    spline = CubicSpline(knot_indices, knot_values)
    baseline = spline(np.arange(n_samples))

    return data - baseline

def normalize(signal, min_val, max_val):
    """Min-max scale ``signal`` using the given bounds.

    Returns ``(signal - min_val) / (max_val - min_val)``, or an array of
    zeros shaped like ``signal`` when the two bounds are equal.
    """
    span = max_val - min_val
    if span != 0:
        return (signal - min_val) / span
    return np.zeros_like(signal)

def process_signal(signal_data, min_val, max_val):
    """Min-max scale ``signal_data`` by delegating to ``normalize``."""
    return normalize(signal_data, min_val, max_val)

def data_process(input_data):
    """
    Min-max scale one lead and return it as a 3-D array.

    ``input_data`` may be a dict holding the samples under ``'dataL2'`` or
    the samples themselves. The samples are cast to float32, scaled with
    their own global minimum and maximum via ``process_signal``, and
    returned with shape ``(1, n_samples, 1)`` for a 1-D input.
    """
    has_lead_key = isinstance(input_data, dict) and 'dataL2' in input_data
    raw_samples = input_data['dataL2'] if has_lead_key else input_data

    lead = np.array(raw_samples).astype('float32')
    stacked = np.array([lead])

    lowest = np.min(stacked)
    highest = np.max(stacked)

    scaled_rows = [
        process_signal(stacked[row, :], lowest, highest)
        for row in range(stacked.shape[0])
    ]

    # (rows, samples) -> (1, rows, samples) -> (1, samples, rows)
    batched = np.expand_dims(np.stack(scaled_rows), axis=0)
    return batched.transpose(0, 2, 1)


def remove_close_peaks(r_peaks, validation_signal, min_dist_samples):
    """Remove peaks that are too close together, keeping the stronger one.

    Peaks are visited in ascending order. A peak at least
    ``min_dist_samples`` after the last kept peak is kept. A closer peak
    replaces the last kept peak only if ``|validation_signal|`` is larger
    at the new peak than at the last kept peak.

    Args:
        r_peaks: Sample indices of candidate peaks (any order).
        validation_signal: Signal whose absolute value ranks the peaks.
        min_dist_samples: Minimum allowed spacing between kept peaks.

    Returns:
        Array of kept peak indices in ascending order (empty array when
        ``r_peaks`` is empty).
    """
    r_peaks = np.asarray(r_peaks)
    if len(r_peaks) == 0:
        return np.array([])

    r_peaks = np.sort(r_peaks)
    strengths = np.abs(np.asarray(validation_signal)[r_peaks.astype(int)])

    keep = []
    last_kept = None
    last_strength = None

    for current, strength in zip(r_peaks, strengths):
        if not keep or current - last_kept >= min_dist_samples:
            keep.append(current)
        elif strength > last_strength:
            # Compare against the strength of the peak actually kept, not
            # against whichever input peak shares its position in `keep`.
            keep[-1] = current
        else:
            continue
        last_kept = current
        last_strength = strength

    return np.array(keep)


def amplitude_based_filtering(ecg_signal, peaks, segment_num="Unknown"):
    """Filter out high amplitude outlier peaks using IQR method.

    A peak is an outlier when ``|ecg_signal[peak]|`` exceeds
    ``Q3 + 1.5 * IQR`` of all peak amplitudes.

    Args:
        ecg_signal: Signal the peaks index into.
        peaks: Sample indices of detected peaks.
        segment_num: Unused; kept for call compatibility.

    Returns:
        ``(cleaned_peaks, cleaned_amplitudes)``. The input is returned
        unfiltered when the IQR is not positive or when filtering would
        remove every peak. For empty ``peaks`` the result is
        ``(peaks, empty array)``.
    """
    if len(peaks) == 0:
        return peaks, np.array([])

    peaks = np.asarray(peaks)
    peak_amplitudes = np.abs(np.asarray(ecg_signal)[peaks.astype(int)])

    q75, q25 = np.percentile(peak_amplitudes, [75, 25])
    iqr = q75 - q25

    # No spread (or NaN): no meaningful outlier threshold.
    if not iqr > 0:
        return peaks, peak_amplitudes

    high_amp_threshold = q75 + 1.5 * iqr
    keep_mask = ~(peak_amplitudes > high_amp_threshold)

    # Never return an empty set because of the filter.
    if not keep_mask.any():
        return peaks, peak_amplitudes

    return peaks[keep_mask], peak_amplitudes[keep_mask]


def remove_t_waves(ecg_signal, peaks, sampling_rate):
    """Drop peaks that look like T waves rather than R peaks.

    A peak is dropped when all of the following hold:
      * it follows the previous peak (in sorted order) by 160-450 ms,
      * its absolute amplitude is below half that of the previous peak,
      * its width at half amplitude, searched up to 100 samples each way,
        exceeds 40 ms.

    Fewer than three peaks are returned as given.
    """
    if len(peaks) < 3:
        return peaks

    ordered = np.sort(peaks)
    last_sample = len(ecg_signal) - 1

    def looks_like_t_wave(position):
        # The first peak has no predecessor to be compared with.
        if position == 0:
            return False

        candidate = ordered[position]
        predecessor = ordered[position - 1]

        gap_ms = (candidate - predecessor) / sampling_rate * 1000
        if not 160 < gap_ms < 450:
            return False

        predecessor_amp = abs(ecg_signal[int(predecessor)])
        candidate_amp = abs(ecg_signal[int(candidate)])
        if not candidate_amp < predecessor_amp * 0.5:
            return False

        half_height = candidate_amp * 0.5

        # Walk outwards until the signal falls below half height.
        lower = candidate
        while (lower > 0 and lower > candidate - 100
               and not abs(ecg_signal[int(lower)]) < half_height):
            lower -= 1

        upper = candidate
        while (upper < last_sample and upper < candidate + 100
               and not abs(ecg_signal[int(upper)]) < half_height):
            upper += 1

        return (upper - lower) / sampling_rate * 1000 > 40

    survivors = [
        candidate
        for position, candidate in enumerate(ordered)
        if not looks_like_t_wave(position)
    ]
    return np.array(survivors)


def robust_qrs_detect_internal(data_clean, sampling_rate):
    """Collect R-peak candidates from three detectors and merge them.

    A "sharpness" score (squared first difference of the rectified
    10-40 Hz band-passed signal) gates every candidate; its 94th percentile
    is the reference threshold. Candidates come from:

      1. Moving-average energy of four band-passed versions of the signal,
         each thresholded at mean + 0.1/0.2/0.3 std.
      2. Prominence-based peaks of the raw signal (sharpness gate x 0.8).
      3. Peaks of the squared first difference of the raw signal
         (sharpness gate x 0.7).

    Candidates from 1 and 3 are moved to the largest |signal| nearby.
    Finally, neighbouring candidates closer than 0.15 s are merged, keeping
    the one with the higher local sharpness.

    Args:
        data_clean: 1-D numpy array of samples (``.copy()`` is called on it).
        sampling_rate: Samples per second.

    Returns:
        Sorted array of candidate sample indices, or an empty array when no
        candidate passed the gates.
    """
    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate
    
    # Calculate sharpness threshold
    low_strict = 10 / nyquist
    high_strict = 40 / nyquist
    b2, a2 = sp_signal.butter(2, [low_strict, high_strict], btype='band')
    filtered_strict = sp_signal.filtfilt(b2, a2, data_clean)
    diff_strict = np.diff(np.abs(filtered_strict))
    # np.diff shortens the array by one; pad so indices line up with the signal
    diff_strict = np.append(diff_strict, 0)
    strict_score = diff_strict ** 2
    
    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0
    
    all_candidate_peaks = []
    
    # Strategy 1: Multi-band detection with multiple thresholds
    freq_bands = [(5, 15), (8, 24), (10, 30), (12, 40)]
    
    for low_freq, high_freq in freq_bands:
        low = low_freq / nyquist
        high = high_freq / nyquist
        b, a = sp_signal.butter(2, [low, high], btype='band')
        filtered = sp_signal.filtfilt(b, a, data_clean)
        
        # Energy envelope: square, then 0.15 s moving average
        squared = filtered ** 2
        window_size = int(0.15 * sampling_rate)
        integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
        
        mean_val = np.mean(integrated)
        std_val = np.std(integrated)
        
        thresholds = [mean_val + 0.1 * std_val, mean_val + 0.2 * std_val, mean_val + 0.3 * std_val]
        
        for threshold in thresholds:
            candidates, _ = sp_signal.find_peaks(
                integrated,
                height=threshold,
                distance=int(0.2 * sampling_rate)
            )
            
            search_window = int(0.1 * sampling_rate)
            sharp_window = int(0.18 * sampling_rate)
            
            for peak in candidates:
                start_sharp = max(0, peak - sharp_window)
                end_sharp = min(len(strict_score), peak + sharp_window)
                if start_sharp < end_sharp:
                    local_sharpness = np.max(strict_score[start_sharp:end_sharp])
                    
                    if local_sharpness > sharpness_threshold:
                        # Snap the envelope peak to the largest |signal| within +/-0.1 s
                        start = max(0, peak - search_window)
                        end = min(len(original_data), peak + search_window)
                        if start < end:
                            local_segment = original_data[start:end]
                            local_max_idx = np.argmax(np.abs(local_segment))
                            refined_peak = start + local_max_idx
                            all_candidate_peaks.append(refined_peak)
    
    # Strategy 2: Prominence-based detection
    peaks_prom, properties = sp_signal.find_peaks(
        original_data,
        distance=int(0.2 * sampling_rate),
        prominence=0.02
    )
    
    sharp_window = int(0.18 * sampling_rate)
    for peak in peaks_prom:
        start_sharp = max(0, peak - sharp_window)
        end_sharp = min(len(strict_score), peak + sharp_window)
        if start_sharp < end_sharp:
            local_sharpness = np.max(strict_score[start_sharp:end_sharp])
            if local_sharpness > sharpness_threshold * 0.8:
                all_candidate_peaks.append(peak)
    
    # Strategy 3: Derivative-based detection
    diff_signal = np.diff(original_data)
    diff_squared = diff_signal ** 2
    diff_squared = np.append(diff_squared, 0)
    
    mean_diff = np.mean(diff_squared)
    std_diff = np.std(diff_squared)
    
    diff_peaks, _ = sp_signal.find_peaks(
        diff_squared,
        height=mean_diff + 0.5 * std_diff,
        distance=int(0.15 * sampling_rate)
    )
    
    search_window = int(0.08 * sampling_rate)
    
    # sharp_window here is the value set before the Strategy 2 loop
    for peak in diff_peaks:
        start_sharp = max(0, peak - sharp_window)
        end_sharp = min(len(strict_score), peak + sharp_window)
        if start_sharp < end_sharp:
            local_sharpness = np.max(strict_score[start_sharp:end_sharp])
            
            if local_sharpness > sharpness_threshold * 0.7:
                start = max(0, peak - search_window)
                end = min(len(original_data), peak + search_window)
                if start < end:
                    local_segment = original_data[start:end]
                    local_max_idx = np.argmax(np.abs(local_segment))
                    refined_peak = start + local_max_idx
                    all_candidate_peaks.append(refined_peak)
    
    # Merge and deduplicate peaks
    if len(all_candidate_peaks) > 0:
        all_candidate_peaks = np.unique(all_candidate_peaks)
        
        min_distance = int(0.15 * sampling_rate)
        sorted_peaks = np.sort(all_candidate_peaks)
        
        if len(sorted_peaks) > 0:
            # keep_mask[k] tells whether sorted_peaks[k] survives
            keep_mask = [True]
            for i in range(1, len(sorted_peaks)):
                if sorted_peaks[i] - sorted_peaks[i-1] >= min_distance:
                    keep_mask.append(True)
                else:
                    # Too close to the previous candidate: keep the sharper one.
                    # NOTE(review): the comparison is against sorted_peaks[i-1]
                    # even when that candidate was itself already dropped.
                    start1 = max(0, sorted_peaks[i-1] - sharp_window)
                    end1 = min(len(strict_score), sorted_peaks[i-1] + sharp_window)
                    start2 = max(0, sorted_peaks[i] - sharp_window)
                    end2 = min(len(strict_score), sorted_peaks[i] + sharp_window)
                    
                    sharp1 = np.max(strict_score[start1:end1]) if start1 < end1 else 0
                    sharp2 = np.max(strict_score[start2:end2]) if start2 < end2 else 0
                    
                    if sharp2 > sharp1:
                        keep_mask[-1] = False
                        keep_mask.append(True)
                    else:
                        keep_mask.append(False)
            
            sorted_peaks = sorted_peaks[keep_mask]
    
    # sorted_peaks only exists when at least one candidate was collected
    return sorted_peaks if len(all_candidate_peaks) > 0 else np.array([])


def qrs_detect(data, sampling_rate, segment_duration=None):
    """Enhanced QRS detection with Amplitude Guardrails for AV Blocks

    Pipeline:
      1. Stream 1 finds candidates as peaks of a 0.15 s moving average of
         the squared derivative of the rectified 8-24 Hz band-passed signal.
      2. Stream 2 confirms a candidate when the sharpness score (same
         construction on the 10-40 Hz band) within +/-0.18 s exceeds its
         94th percentile; the peak is moved to the largest |signal| there.
      3. Peaks closer than 0.15 s are merged with ``remove_close_peaks``.
      4. Gaps longer than 1.4x the median RR interval are searched again
         with relaxed thresholds; a gap candidate is accepted only if its
         amplitude is above 45% of the median R amplitude.
      5. If the peak count falls outside the 30-180 BPM range for the
         segment, the result is replaced by ``robust_qrs_detect_internal``
         followed by T-wave removal and amplitude filtering; otherwise only
         T-wave removal is applied.

    Args:
        data: 1-D numpy array of samples (``.copy()`` is called on it).
        sampling_rate: Samples per second.
        segment_duration: Duration in seconds used for the plausibility
            check; defaults to ``len(data) / sampling_rate``.

    Returns:
        Tuple ``(data, cleaned_r, bpm, cleaned_r)``: the input signal, the
        R-peak indices (returned twice), and the heart rate in BPM computed
        from RR intervals between 0.2 s and 4.0 s (0 if none qualify).
    """
    data_clean = data
    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate
    
    # --- STREAM 1: Standard Detection ---
    low = 8 / nyquist
    high = 24 / nyquist
    b, a = sp_signal.butter(2, [low, high], btype='band')
    filtered_standard = sp_signal.filtfilt(b, a, data_clean)
    
    filtered_abs = np.abs(filtered_standard)
    diff = np.diff(filtered_abs)
    # Pad by one so indices line up with the signal
    diff = np.append(diff, 0)
    squared = diff ** 2
    
    window_size = int(0.15 * sampling_rate)
    integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
    
    mean_val = np.mean(integrated)
    std_val = np.std(integrated)
    threshold = mean_val + 0.20 * std_val
    
    candidates, _ = sp_signal.find_peaks(
        integrated,
        height=threshold,
        distance=int(0.12 * sampling_rate)
    )
    
    # --- STREAM 2: Sharpness Validator ---
    low_strict = 10 / nyquist
    high_strict = 40 / nyquist
    b2, a2 = sp_signal.butter(2, [low_strict, high_strict], btype='band')
    filtered_strict = sp_signal.filtfilt(b2, a2, data_clean)
    diff_strict = np.diff(np.abs(filtered_strict))
    diff_strict = np.append(diff_strict, 0)
    strict_score = diff_strict ** 2
    
    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0
    
    confirmed_peaks = []
    search_window = int(0.18 * sampling_rate)
    
    for peak in candidates:
        start_check = max(0, peak - search_window)
        end_check = min(len(strict_score), peak + search_window)
        if start_check >= end_check:
            continue
        
        local_sharpness = np.max(strict_score[start_check:end_check])
        
        if local_sharpness > sharpness_threshold:
            # Snap to the largest |signal| inside the same window
            local_segment = original_data[start_check:end_check]
            if len(local_segment) > 0:
                abs_local_segment = np.abs(local_segment)
                local_max_idx = np.argmax(abs_local_segment)
                confirmed_peaks.append(start_check + local_max_idx)
    
    r_peaks = np.array(confirmed_peaks)
    
    # Remove close peaks
    min_dist = int(0.15 * sampling_rate)
    r_peaks = remove_close_peaks(r_peaks, original_data, min_dist)
    
    # Drop any NaN float entries, then sort
    cleaned_r = np.sort(np.array([x for x in r_peaks if not (isinstance(x, float) and np.isnan(x))]))
    
    # GAP FILLING WITH AMPLITUDE GUARDRAILS
    if len(cleaned_r) >= 2:
        existing_heights = np.abs(original_data[cleaned_r.astype(int)])
        median_r_height = np.median(existing_heights) if len(existing_heights) > 0 else 0
        
        rr_intervals = np.diff(cleaned_r) / sampling_rate
        median_rr = np.median(rr_intervals) if len(rr_intervals) > 0 else 1.0
        new_peaks = list(cleaned_r)
        
        # Gap filling is skipped entirely when the median RR is 1.5 s or more
        if median_rr < 1.5:
            for i in range(len(rr_intervals)):
                if rr_intervals[i] > 1.4 * median_rr:
                    gap_start = cleaned_r[i]
                    gap_end = cleaned_r[i+1]
                    if gap_start >= gap_end:
                        continue
                    
                    # Re-scan the Stream 1 envelope inside the gap with a lower bar
                    gap_integrated = integrated[gap_start:gap_end]
                    low_thresh = mean_val * 0.6
                    
                    gap_candidates, _ = sp_signal.find_peaks(
                        gap_integrated,
                        height=low_thresh,
                        distance=int(0.10 * sampling_rate)
                    )
                    
                    for gc in gap_candidates:
                        abs_idx = gap_start + gc
                        sw_start = max(0, abs_idx - search_window)
                        sw_end = min(len(strict_score), abs_idx + search_window)
                        if sw_start >= sw_end:
                            continue
                        
                        local_sharp_max = np.max(strict_score[sw_start:sw_end])
                        if local_sharp_max > sharpness_threshold * 0.4:
                            
                            local_segment = original_data[sw_start:sw_end]
                            abs_local_segment = np.abs(local_segment)
                            refine_idx = np.argmax(abs_local_segment)
                            candidate_peak = sw_start + refine_idx
                            
                            candidate_amp = np.abs(original_data[candidate_peak])
                            
                            # Amplitude guardrail: reject small deflections
                            if candidate_amp > 0.45 * median_r_height:
                                new_peaks.append(candidate_peak)
        
        new_peaks = np.sort(np.unique(new_peaks))
        cleaned_r = remove_close_peaks(new_peaks, original_data, min_dist)
    
    # Determine expected peak count range
    if segment_duration is None:
        segment_duration = len(data_clean) / sampling_rate
    
    # Peak counts corresponding to 30 BPM and 180 BPM over the segment
    min_expected_peaks = int(30/60 * segment_duration)
    max_expected_peaks = int(180/60 * segment_duration)
    
    if len(cleaned_r) < min_expected_peaks or len(cleaned_r) > max_expected_peaks:
        # Implausible count: discard the result above and use the fallback detector
        initial_peaks = robust_qrs_detect_internal(data_clean, sampling_rate)
        initial_peaks = remove_t_waves(data_clean, initial_peaks, sampling_rate)
        cleaned_r, peak_amplitudes = amplitude_based_filtering(data_clean, initial_peaks, "Segment")
    else:
        cleaned_r = remove_t_waves(data_clean, cleaned_r, sampling_rate)
    
    # Calculate BPM
    if len(cleaned_r) > 1:
        rr_intervals = np.diff(cleaned_r) / sampling_rate
        
        # Ignore RR intervals outside 0.2-4.0 s when averaging
        valid_rr = rr_intervals[(rr_intervals > 0.2) & (rr_intervals < 4.0)]
        
        if len(valid_rr) > 0:
            mean_rr = np.mean(valid_rr)
            bpm = 60 / mean_rr if mean_rr > 0 else 0
        else:
            bpm = 0
    else:
        bpm = 0
    
    # cleaned_r is intentionally (or historically) returned in two slots
    return data, cleaned_r, bpm, cleaned_r

# ==========================================
# 2. P-WAVE DETECTION
# ==========================================

def adaptive_noise_filter(segment, sampling_rate):
    """Band-pass a segment only when it is noisy.

    Noise is measured as the standard deviation of the first difference.
    Above 0.05 the segment is filtered forward-backward with a 3rd-order
    1-25 Hz Butterworth band-pass; otherwise the same object is returned
    untouched.
    """
    sample_to_sample_std = np.std(np.diff(segment))
    if not sample_to_sample_std > 0.05:
        return segment

    half_rate = 0.5 * sampling_rate
    passband = [1 / half_rate, 25 / half_rate]
    numerator, denominator = sp_signal.butter(3, passband, btype='band')
    return sp_signal.filtfilt(numerator, denominator, segment)

def calculate_signal_quality(segment):
    """Score a segment's quality as a weighted sum of three terms.

    * noise (weight 0.4): ``exp(-std(diff(segment)) / 0.1)``
    * drift (weight 0.4): ``exp(-std(segment) / 0.3)``
    * amplitude (weight 0.2): 1.0 when the peak-to-peak range lies in
      [0.2, 1.5], otherwise 0.5

    Higher is better.
    """
    sample_jitter = np.std(np.diff(segment))
    overall_spread = np.std(segment)
    peak_to_peak = np.max(segment) - np.min(segment)

    noise_score = np.exp(-sample_jitter / 0.1)
    drift_score = np.exp(-overall_spread / 0.3)
    amplitude_score = 1.0 if 0.2 <= peak_to_peak <= 1.5 else 0.5

    return noise_score * 0.4 + drift_score * 0.4 + amplitude_score * 0.2

def enhanced_p_wave_detection(signal, r_peaks, sampling_rate, segment_num):
    """Enhanced P-wave detection with Bradycardia fix

    For every R peak except the first, searches the interval between the
    estimated end of the previous T wave and shortly before the R peak for
    the best-scoring P-wave candidate. Candidates are scored on PR interval,
    amplitude and slope symmetry. Once at least three P waves are found,
    those whose PR interval deviates too far from the median PR are dropped.

    Prints one summary line (signal quality and average heart rate).

    Args:
        signal: 1-D array of samples. NOTE(review): the amplitude limits
            used for scoring (0.015, 0.5, 0.08) suggest a signal normalised
            to roughly [0, 1] -- confirm.
        r_peaks: R-peak sample indices, in ascending order.
        sampling_rate: Samples per second.
        segment_num: Label used only in the printed summary.

    Returns:
        Float array of ``len(r_peaks)`` holding the P-peak sample index for
        each R peak, or NaN where none was accepted. Entry 0 is always NaN;
        with fewer than three R peaks every entry is NaN.
    """
    if len(r_peaks) < 3:
        return np.full(len(r_peaks), np.nan)
    
    signal_quality = calculate_signal_quality(signal)
    
    p_peaks = np.full(len(r_peaks), np.nan)
    p_qualities = np.zeros(len(r_peaks))
    
    if len(r_peaks) > 1:
        rr_intervals = np.diff(r_peaks) / sampling_rate
        avg_rr = np.mean(rr_intervals)
        avg_hr = 60 / avg_rr if avg_rr > 0 else 0
    else:
        avg_hr = 0
    
    print(f"Segment {segment_num}: Signal Quality = {signal_quality:.2f}, Avg HR = {avg_hr:.0f} BPM")
    
    # Pick search mode and the minimum score a candidate must reach.
    # Faster rhythms use the short-cycle parameters and a lower bar.
    if avg_hr > 180:
        # print(f"  High heart rate detected - using adaptive short-cycle parameters")
        use_adaptive_short_cycle = True
        min_quality_threshold = 20  
    elif avg_hr > 120:
        use_adaptive_short_cycle = True
        min_quality_threshold = 30
    else:
        use_adaptive_short_cycle = False
        if signal_quality > 0.7:
            min_quality_threshold = 50
        elif signal_quality > 0.5:
            min_quality_threshold = 35
        else:
            min_quality_threshold = 25
    
    preliminary_pr_intervals = []
    preliminary_p_amps = []
    
    # Estimate where the T wave following each R peak ends, as a fraction of
    # the RR interval that grows with the interval (capped at 0.5 s).
    t_wave_ends = []
    for i in range(len(r_peaks) - 1):
        r_curr = int(r_peaks[i])
        r_next = int(r_peaks[i + 1])
        rr_interval = r_next - r_curr
        
        if rr_interval < 0.4 * sampling_rate:  
            estimated_t_end = r_curr + int(0.45 * rr_interval)
        elif rr_interval < 0.6 * sampling_rate: 
            estimated_t_end = r_curr + int(0.55 * rr_interval)
        else:  
            estimated_t_end = r_curr + min(int(0.5 * sampling_rate), int(0.65 * rr_interval))
        
        t_wave_ends.append(estimated_t_end)
    
    # The last R peak has no successor; assume its T wave ends 0.5 s later
    if len(r_peaks) > 0:
        last_r = int(r_peaks[-1])
        t_wave_ends.append(last_r + int(0.5 * sampling_rate))
    
    for i, r in enumerate(r_peaks):
        # No previous beat to bound the search for the first R peak
        if i == 0: continue
            
        r = int(r)
        rr_prev = r - int(r_peaks[i-1])
        
        if use_adaptive_short_cycle and rr_prev < 0.5 * sampling_rate:  
            # Short cycle: start right after the previous T wave, end 20 ms before R
            if i - 1 < len(t_wave_ends):
                t_end_prev = t_wave_ends[i - 1]
                search_start = t_end_prev + int(0.02 * sampling_rate)
            else:
                search_start = int(r_peaks[i-1] + 0.25 * rr_prev)
            
            search_end = int(r - 0.02 * sampling_rate)
            min_pr_ms = 80
            max_pr_ms = 300
            
        else:
            # Normal cycle: look back at most 0.40 s from R, end 30 ms before R
            max_lookback_samples = int(0.40 * sampling_rate)
            earliest_allowed_start = r - max_lookback_samples
            
            if i - 1 < len(t_wave_ends):
                t_end_prev = t_wave_ends[i - 1]
                search_start = max(t_end_prev + int(0.05 * sampling_rate), earliest_allowed_start)
            else:
                search_start = max(int(r_peaks[i-1] + 0.4 * sampling_rate), earliest_allowed_start)
            
            search_end = int(r - 0.03 * sampling_rate)
            min_pr_ms = 80
            max_pr_ms = 400
        
        search_start = max(0, search_start)
        search_end = min(len(signal)-1, search_end)
        
        # Skip beats whose search window is shorter than 50 ms
        min_window_size = int(0.05 * sampling_rate)
        if search_end - search_start < min_window_size:
            continue
        
        segment = signal[search_start:search_end]
        segment_filtered = adaptive_noise_filter(segment, sampling_rate)
        
        if use_adaptive_short_cycle:
            min_prominence = 0.002
            min_distance = int(0.05 * sampling_rate)
            max_width = int(0.12 * sampling_rate)
        else:
            min_prominence = 0.003
            min_distance = int(0.08 * sampling_rate)
            max_width = int(0.15 * sampling_rate)
        
        # NOTE(review): the bare except treats any find_peaks failure as
        # "no candidates"; it also hides programming errors.
        try:
            candidate_peaks, properties = sp_signal.find_peaks(
                segment_filtered,
                distance=min_distance,
                prominence=min_prominence,
                width=(int(0.02*sampling_rate), max_width)
            )
        except:
            candidate_peaks = []
        
        if len(candidate_peaks) == 0:
            continue
        
        # Convert window-relative positions to indices into `signal`
        candidate_peaks = search_start + candidate_peaks
        
        best_score = -np.inf
        best_peak = None
        
        for cp in candidate_peaks:
            cp = int(cp)
            
            pr_interval = (r - cp) / sampling_rate * 1000
            if pr_interval < min_pr_ms or pr_interval > max_pr_ms:
                continue
            
            score = 0
            
            if use_adaptive_short_cycle:
                ideal_pr = 120
                sigma = 40
            else:
                ideal_pr = 160
                sigma = 60
            
            # PR-interval term: Gaussian around the ideal PR, up to 150 points
            score += np.exp(-((pr_interval - ideal_pr) ** 2) / (2 * sigma ** 2)) * 150
            
            # Amplitude term: Gaussian around 0.08 (up to 400 points) in the
            # 0.015-0.5 range, flat 50 above 0.5, candidate rejected below 0.015
            p_amp = abs(signal[cp])
            if p_amp > 0.5:
                score += 50 
            elif 0.015 <= p_amp <= 0.5:
                ideal_amp = 0.08
                score += np.exp(-((p_amp - ideal_amp) ** 2) / (2 * 0.10 ** 2)) * 400
            else:
                continue
            
            # Symmetry term from the slopes 5 samples either side, up to 60 points.
            # NOTE(review): for cp < 5 the index cp - 5 is negative and wraps
            # to the end of the signal instead of raising; an out-of-range
            # cp + 5 is swallowed by the bare except.
            try:
                left_slope = signal[cp] - signal[cp - 5]
                right_slope = signal[cp + 5] - signal[cp]
                symmetry = 1 - abs(left_slope - right_slope) / (abs(left_slope) + abs(right_slope) + 1e-6)
                score += symmetry * 60
            except:
                pass
            
            # Constant bonus given to every candidate that reaches this point
            score += 50 
            
            if score > best_score:
                best_score = score
                best_peak = cp
        
        if best_peak is not None and best_score > min_quality_threshold:
             p_peaks[i] = best_peak
             p_qualities[i] = best_score
             preliminary_pr_intervals.append((r - best_peak) / sampling_rate * 1000)
             preliminary_p_amps.append(abs(signal[best_peak]))

    # Consistency pass: with at least three detections, drop P waves whose
    # PR interval is more than the tolerance (ms) away from the median PR
    if len(preliminary_pr_intervals) >= 3:
        median_pr = np.median(preliminary_pr_intervals)
        
        for i in range(1, len(r_peaks)):
            if np.isnan(p_peaks[i]): continue
            
            pr = (r_peaks[i] - p_peaks[i]) / sampling_rate * 1000
            
            tolerance = 100 if use_adaptive_short_cycle else 120
            if abs(pr - median_pr) > tolerance:
                p_peaks[i] = np.nan
                p_qualities[i] = 0

    return p_peaks

# ==========================================
# 4. P-WAVE ONSET DETECTION
# ==========================================

def find_p_onset_constrained(signal, p_peak_idx, sampling_rate, prev_t_offset=None):
    """
    Finds P-onset with strict constraints

    The onset is searched backwards from 12 ms before the P peak, no further
    than 120 ms before the peak and no earlier than 20 ms after the previous
    T-wave offset (when given). It is the latest sample in that range whose
    gradient magnitude is below 10% of the largest gradient in the range.

    Args:
        signal: 1-D array of samples.
        p_peak_idx: Sample index of the P peak, or NaN.
        sampling_rate: Samples per second.
        prev_t_offset: Sample index where the previous T wave ends, or
            None / NaN when unknown.

    Returns:
        Sample index of the onset; NaN if ``p_peak_idx`` is NaN. The result
        is never negative. When the allowed range is empty, its upper end
        (12 ms before the peak, clamped to 0) is returned; when no flat
        sample is found, the lower limit is returned.
    """
    if np.isnan(p_peak_idx):
        return np.nan
    p_idx = int(p_peak_idx)

    # Clamp to the start of the signal: a negative bound would otherwise be
    # returned as-is or be read as an index from the end when slicing.
    start_search = max(0, p_idx - int(0.012 * sampling_rate))

    limit_idx = p_idx - int(0.12 * sampling_rate)
    if prev_t_offset is not None and not np.isnan(prev_t_offset):
        # Never look into the previous T wave (plus a 20 ms buffer)
        t_end_buffer = int(prev_t_offset) + int(0.02 * sampling_rate)
        limit_idx = max(limit_idx, t_end_buffer)
    limit_idx = max(0, limit_idx)

    if limit_idx >= start_search:
        return start_search

    segment = signal[limit_idx:start_search + 1]

    if len(segment) < 3:
        return limit_idx

    grads = np.gradient(segment)
    threshold = 0.10 * np.max(np.abs(grads))

    # Walk back from the peak side to the first flat sample
    for i in range(len(grads) - 1, -1, -1):
        if np.abs(grads[i]) < threshold:
            return limit_idx + i

    return limit_idx


# ==========================================
# 5. WAVE DELINEATION
# ==========================================


def find_t_wave_offset_with_stability(signal, t_peak_idx, sampling_rate, 
                                      next_r_peak=None, limit_idx=None):
    if np.isnan(t_peak_idx):
        return np.nan
    
    t_peak_idx = int(t_peak_idx)
    t_peak_value = signal[t_peak_idx]
    
    # ===== STAGE 1: Determine Search Limits =====
    if next_r_peak is not None:
        max_search = int(0.70 * (next_r_peak - t_peak_idx))
        max_search = max(int(0.100 * sampling_rate), max_search)
    else:
        max_search = int(0.200 * sampling_rate)
    
    if limit_idx is not None and not np.isnan(limit_idx):
        limit_idx = int(limit_idx)
        dist_to_limit = limit_idx - t_peak_idx
        if dist_to_limit <= 5:
            return t_peak_idx + 5
        max_search = min(max_search, dist_to_limit - 3)
    
    max_idx = min(len(signal) - 1, t_peak_idx + max_search)
    
    if max_idx <= t_peak_idx + 8:
        return t_peak_idx + 5
    
    # ===== STAGE 2: Setup Parameters =====
    skip_samples = max(3, int(0.008 * sampling_rate))  # Skip 8ms from peak
    
    if max_idx <= t_peak_idx + skip_samples + 5:
        return t_peak_idx + skip_samples
    
    segment = signal[t_peak_idx:max_idx]
    
    # Window size for variance calculation (30 samples ≈ 60ms at 500Hz)
    window_size = min(15, int(0.060 * sampling_rate))                                                      
    
    # Detect T-wave polarity
    if t_peak_value > 0:
        target_descent_sign = -1
    else:
        target_descent_sign = 1
    
    # ===== STAGE 3: Calculate Derivatives (for inflection backup) =====
    derivative_1st = np.gradient(segment)
    
    if len(derivative_1st) > 5:
        kernel_size = 3
        derivative_1st_smooth = np.convolve(derivative_1st, 
                                            np.ones(kernel_size)/kernel_size, 
                                            mode='same')
    else:
        derivative_1st_smooth = derivative_1st
    
    derivative_2nd = np.gradient(derivative_1st_smooth)
    
    # Adaptive thresholds for slope-based methods
    slope_magnitudes = np.abs(derivative_1st_smooth[skip_samples:])
    
    if len(slope_magnitudes) > 0:
        slope_75th = np.percentile(slope_magnitudes, 75)
        slope_median = np.median(slope_magnitudes)
        steep_threshold = max(slope_75th * 0.4, 0.002)
        flat_threshold = max(slope_median * 0.15, 0.0008)
    else:
        steep_threshold = 0.003
        flat_threshold = 0.001
    
    # ===== STAGE 4: NEW METHOD - Local Variance Stability Detection =====
    stability_idx = None
    
    # Start search after skip_samples
    search_start = skip_samples + int(0.010 * sampling_rate)  # At least 10ms from peak
    search_end = len(segment) - window_size
    
    if search_start < search_end:
        variance_ratios = []
        candidate_points = []
        
        for i in range(search_start, search_end):
            # Get windows before and after current point
            before_window = segment[max(0, i-window_size):i]
            after_window = segment[i:min(len(segment), i+window_size)]
            
            if len(before_window) < 10 or len(after_window) < 10:
                continue
            
            # Calculate variance (spread of values)
            var_before = np.var(before_window)
            var_after = np.var(after_window)
            
            # Also calculate standard deviation for robustness
            std_before = np.std(before_window)
            std_after = np.std(after_window)
            
            # Calculate ratio (how much does variance drop?)
            if var_before > 0:
                var_ratio = var_after / var_before
                std_ratio = std_after / std_before
                
                # Store results
                variance_ratios.append(var_ratio)
                candidate_points.append(i)
        
        # Find where variance drops significantly (baseline is more stable)
        if len(variance_ratios) > 0:
            variance_ratios = np.array(variance_ratios)
            candidate_points = np.array(candidate_points)
            
            # Threshold: variance after should be < 40% of variance before
            # This means we've transitioned from T-wave to stable baseline
            stability_mask = variance_ratios < 0.15                                                      
            
            if np.any(stability_mask):
                # Take the FIRST point where stability is achieved
                first_stable = candidate_points[stability_mask][0]
                stability_idx = first_stable
    
    # ===== STAGE 5: METHOD 2 - Inflection Point Detection (Backup) =====
    inflection_idx = None
    in_steep_descent = False
    
    for i in range(skip_samples, len(derivative_1st_smooth) - 2):
        current_slope = derivative_1st_smooth[i]
        
        if target_descent_sign * current_slope < -steep_threshold:
            if not in_steep_descent:
                in_steep_descent = True
        
        elif in_steep_descent:
            next_slopes = derivative_1st_smooth[i:min(i+4, len(derivative_1st_smooth))]
            
            # Option A: Slope becomes flat
            if np.all(np.abs(next_slopes) < flat_threshold * 1.5):
                inflection_idx = i
                break
            
            # Option B: Slope reverses direction
            if target_descent_sign * current_slope > 0:
                if i + 2 < len(derivative_1st_smooth):
                    if target_descent_sign * derivative_1st_smooth[i+1] > 0:
                        inflection_idx = i
                        break
            
            # Option C: Slope magnitude drops below threshold
            if np.abs(current_slope) < flat_threshold:
                if i + 3 < len(derivative_1st_smooth):
                    if np.mean(np.abs(derivative_1st_smooth[i:i+3])) < flat_threshold * 1.3:
                        inflection_idx = i
                        break
                else:
                    inflection_idx = i
                    break
    
    # ===== STAGE 6: METHOD 3 - Second Derivative Zero-Crossing =====
    curvature_inflection_idx = None
    
    if len(derivative_2nd) > skip_samples + 10:
        for i in range(skip_samples + 5, len(derivative_2nd) - 1):
            if derivative_2nd[i] * derivative_2nd[i+1] <= 0:
                if i > skip_samples + int(0.015 * sampling_rate):
                    if i + 3 < len(derivative_1st_smooth):
                        avg_slope_after = np.mean(np.abs(derivative_1st_smooth[i:i+3]))
                        if avg_slope_after < steep_threshold * 0.6:
                            curvature_inflection_idx = i
                            break
    
    # ===== STAGE 7: METHOD 4 - Minimum Slope Magnitude =====
    min_slope_idx = None
    
    if stability_idx is None and inflection_idx is None and curvature_inflection_idx is None:
        search_start_min = skip_samples + int(0.020 * sampling_rate)
        search_end_min = min(len(derivative_1st_smooth), skip_samples + int(0.100 * sampling_rate))
        
        if search_start_min < search_end_min:
            slope_window = np.abs(derivative_1st_smooth[search_start_min:search_end_min])
            if len(slope_window) > 0:
                local_min = np.argmin(slope_window)
                min_slope_idx = search_start_min + local_min
    
    # ===== STAGE 8: Select Best Detection (Prioritized) =====
    candidates = []
    
    # HIGHEST PRIORITY: Variance stability (your method!)
    if stability_idx is not None:
        candidates.append(('variance_stability', t_peak_idx + stability_idx, 120))
    
    # HIGH PRIORITY: Inflection point
    if inflection_idx is not None:
        candidates.append(('inflection', t_peak_idx + inflection_idx, 100))
    
    # MEDIUM PRIORITY: Curvature change
    if curvature_inflection_idx is not None:
        candidates.append(('curvature', t_peak_idx + curvature_inflection_idx, 80))
    
    # LOW PRIORITY: Minimum slope
    if min_slope_idx is not None:
        candidates.append(('min_slope', t_peak_idx + min_slope_idx, 60))
    
    if len(candidates) == 0:
        # Ultimate fallback
        return min(t_peak_idx + int(0.040 * sampling_rate), max_idx)
    
    # Use the highest priority detection
    best_method, best_offset, best_score = max(candidates, key=lambda x: x[2])
    
    # Debug info (optional - can be removed in production)
    # print(f"  T-offset method used: {best_method} (score: {best_score})")
    
    # ===== STAGE 9: Validation =====
    best_offset = max(best_offset, t_peak_idx + skip_samples)
    best_offset = min(best_offset, max_idx)
    
    # Sanity check on duration
    duration_ms = (best_offset - t_peak_idx) / sampling_rate * 1000
    
    if duration_ms < 15:
        best_offset = t_peak_idx + int(0.030 * sampling_rate)
    elif duration_ms > 300:
        best_offset = t_peak_idx + int(0.100 * sampling_rate)
    
    return int(best_offset)

def improved_delineate_ecg_waves(signal, r_peaks, sampling_rate):
    """
    Delineate P, Q, S and T wave landmarks around every R-peak.

    For each R-peak one entry is appended to every landmark list, so all
    returned arrays have ``len(r_peaks)`` elements and line up beat by beat.

    Args:
        signal: 1-D ECG samples, indexable by integer position and slice.
        r_peaks: R-peak sample indices into ``signal``.
        sampling_rate: Samples per second; converts the fixed time windows
            (e.g. 50 ms before R for Q) into sample counts.

    Returns:
        dict mapping 'p_peak', 'p_onset', 'p_offset', 'q_peak', 'q_onset',
        's_peak', 's_offset', 't_peak', 't_onset' and 't_offset' to NumPy
        arrays of sample indices. Entries are ``np.nan`` where a landmark's
        search window was empty or no P-peak was reported for the beat.
    """
    waves = {
        'p_peak': [], 'p_onset': [], 'p_offset': [],
        'q_peak': [], 'q_onset': [],
        's_peak': [], 's_offset': [],
        't_peak': [], 't_onset': [], 't_offset': []
    }
    
    signal_len = len(signal)
    
    # Detect P-peaks (one candidate per R-peak; NaN marks a missing P-wave)
    p_peaks = enhanced_p_wave_detection(signal, r_peaks, sampling_rate, "Segment")
    
    def find_boundary_local(peak_idx, direction, max_search_samples, thresh_factor=0.05):
        """
        Walk away from ``peak_idx`` until the signal flattens.

        ``direction`` is +1 (forward) or -1 (backward). The boundary is the
        first sample whose absolute first difference drops below
        ``thresh_factor`` times the steepest difference in the search window.
        Returns ``np.nan`` for a NaN peak, the peak itself when the window is
        shorter than 3 samples, and the (clamped) window edge when no flat
        sample is found.
        """
        if np.isnan(peak_idx): return np.nan
        peak_idx = int(peak_idx)
        limit = peak_idx + (direction * max_search_samples)
        limit = max(0, min(signal_len, limit))
        if abs(limit - peak_idx) < 3: return peak_idx
        start, end = sorted([peak_idx, limit])
        segment = signal[start:end]
        # Reverse so that index 0 is always the sample nearest the peak.
        if direction == -1: segment = segment[::-1] 
        diff = np.diff(segment)
        if len(diff) == 0: return peak_idx
        max_slope = np.max(np.abs(diff))
        thresh = max_slope * thresh_factor
        for i in range(1, len(diff)):
            if np.abs(diff[i]) < thresh:
                return peak_idx + (direction * i)
        # `limit` doubles as an exclusive slice bound and may equal
        # signal_len; clamp so the returned value is always a valid index.
        return max(0, min(limit, signal_len - 1))

    # T-offset of the previous beat, handed to the P-onset search.
    last_t_offset = None

    for i, r in enumerate(r_peaks):
        r = int(r)
        # Reference height for the T-peak prominence threshold; very small
        # R amplitudes fall back to 1.0.
        r_height = abs(signal[r]) if abs(signal[r]) > 0.05 else 1.0
        
        # --- P-WAVE ---
        p_peak_val = p_peaks[i] if i < len(p_peaks) else np.nan
        waves['p_peak'].append(p_peak_val)
        
        if not np.isnan(p_peak_val):
            p_onset = find_p_onset_constrained(signal, p_peak_val, sampling_rate, last_t_offset)
            waves['p_onset'].append(p_onset)
            waves['p_offset'].append(find_boundary_local(int(p_peak_val), 1, int(0.08 * sampling_rate)))
        else:
            waves['p_onset'].append(np.nan)
            waves['p_offset'].append(np.nan)

        # --- Q-WAVE ---
        # Q peak = minimum in the 50 ms before R; onset searched 40 ms further back.
        win_q = int(0.05 * sampling_rate)
        q_search_start = max(0, r - win_q)
        q_window = signal[q_search_start:r]
        q_idx = q_search_start + np.argmin(q_window) if len(q_window) > 0 else np.nan
        waves['q_peak'].append(q_idx)
        anchor_q = q_idx if not np.isnan(q_idx) else r
        waves['q_onset'].append(find_boundary_local(anchor_q, -1, int(0.04 * sampling_rate)))
        
        # --- S-WAVE ---
        # S peak = minimum in the 60 ms after R; offset searched 40 ms further on.
        win_s = int(0.06 * sampling_rate)
        s_search_end = min(signal_len, r + win_s)
        s_window = signal[r:s_search_end]
        s_idx = r + np.argmin(s_window) if len(s_window) > 0 else np.nan
        waves['s_peak'].append(s_idx)
        anchor_s = s_idx if not np.isnan(s_idx) else r
        waves['s_offset'].append(find_boundary_local(anchor_s, 1, int(0.04 * sampling_rate)))
        
        # --- T-WAVE ---
        # Search from 100 ms after R up to min(600 ms, 65% of the next RR
        # interval); the last beat assumes a 1 s RR interval.
        rr_next = (int(r_peaks[i+1]) - r) if i < len(r_peaks) - 1 else 1.0 * sampling_rate
        dyn_t_start = int(0.10 * sampling_rate)
        dyn_t_end = int(min(0.600 * sampling_rate, 0.65 * rr_next))
        t_search_start = min(signal_len, r + dyn_t_start)
        t_search_end = min(signal_len, r + dyn_t_end)
        t_idx = np.nan
        
        if t_search_start < t_search_end:
            t_window = signal[t_search_start:t_search_end]
            if len(t_window) > 0:
                # Prefer the tallest sufficiently prominent local maximum;
                # otherwise fall back to the window's global maximum.
                local_peaks, _ = sp_signal.find_peaks(t_window, prominence=(0.05 * r_height))
                if len(local_peaks) > 0:
                    best_peak = local_peaks[np.argmax(t_window[local_peaks])]
                    t_idx = t_search_start + best_peak
                else:
                    t_idx = t_search_start + np.argmax(t_window)
        
        waves['t_peak'].append(t_idx)
        waves['t_onset'].append(find_boundary_local(t_idx, -1, int(0.08 * sampling_rate)))
        
        # --- T-OFFSET ---
        # The search is bounded by the next R-peak and, when the next beat
        # has a P-peak, by that P-wave's estimated onset.
        next_r = int(r_peaks[i+1]) if i < len(r_peaks) - 1 else None
        next_p_onset_limit = None
        if i + 1 < len(p_peaks):
            next_p_peak = p_peaks[i+1]
            if not np.isnan(next_p_peak):
                next_p_onset_limit = find_boundary_local(int(next_p_peak), -1, int(0.08 * sampling_rate))
        
        t_offset = find_t_wave_offset_with_stability(
            signal, 
            t_idx, 
            sampling_rate, 
            next_r_peak=next_r,
            limit_idx=next_p_onset_limit
        )

        waves['t_offset'].append(t_offset)
        
        last_t_offset = t_offset

    for k in waves:
        waves[k] = np.array(waves[k])
    
    return waves


def calculate_intervals(waves, sampling_rate):
    """
    Derive per-beat PR, QRS and QT durations in milliseconds.

    PR is q_onset - p_onset, QRS is s_offset - q_onset and QT is
    t_offset - q_onset, each converted from samples to milliseconds.
    Values outside the accepted range (PR 40-600 ms, QRS 30-200 ms,
    QT 100-600 ms, bounds exclusive) are replaced by NaN.

    Returns:
        Tuple ``(pr_intervals, qrs_durations, qt_intervals)`` of arrays.
    """
    def span_ms(start_key, end_key, lower, upper):
        # Duration between two landmark arrays, NaN where implausible.
        span = (waves[end_key] - waves[start_key]) / sampling_rate * 1000
        plausible = (span > lower) & (span < upper)
        return np.where(plausible, span, np.nan)

    return (
        span_ms('p_onset', 'q_onset', 40, 600),
        span_ms('q_onset', 's_offset', 30, 200),
        span_ms('q_onset', 't_offset', 100, 600),
    )


def process_ecg_segments(ecg_raw, ecg_filtered, sampling_rate, num_segments=7, min_segment_length=3500):
    """
    Analyse an ECG recording window by window.

    ``num_segments`` windows of up to ``min_segment_length`` samples are
    spread evenly over the recording. R-peaks are detected on the raw
    window, waves are delineated on the filtered window, and all detected
    indices are shifted back to positions in the full recording. Windows
    shorter than 100 samples are skipped.

    Returns:
        List of dicts, one per analysed window, holding the window bounds,
        both signal slices, BPM, global R-peak and wave indices, and the
        NaN-aware mean PR/QRS/QT values in milliseconds.
    """
    total_samples = len(ecg_raw)
    if num_segments > 1:
        stride = round((total_samples - min_segment_length) / (num_segments - 1))
    else:
        stride = 0

    def shift_to_global(indices, offset):
        # Empty results are normalised to an empty array.
        return indices + offset if len(indices) != 0 else np.array([])

    segment_summaries = []
    for seg_no in range(num_segments):
        nominal_start = seg_no * stride
        # The end is derived from the unclamped start, then the start is
        # clamped to the beginning of the recording.
        end_idx = min(nominal_start + min_segment_length, total_samples)
        start_idx = max(nominal_start, 0)

        raw_window = ecg_raw[start_idx:end_idx]
        filtered_window = ecg_filtered[start_idx:end_idx]
        if len(raw_window) < 100:
            continue

        window_seconds = len(raw_window) / sampling_rate
        _, r_peaks, bpm, _ = qrs_detect(raw_window, sampling_rate, window_seconds)

        waves = improved_delineate_ecg_waves(filtered_window, r_peaks, sampling_rate)
        pr, qrs, qt = calculate_intervals(waves, sampling_rate)

        summary = {
            'segment_num': seg_no + 1,
            'start_idx': start_idx,
            'end_idx': end_idx,
            'ecg_raw': raw_window,
            'ecg_filtered': filtered_window,
            'bpm': bpm,
            'r_peaks': shift_to_global(r_peaks, start_idx),
            'avg_pr': np.nanmean(pr),
            'avg_qrs': np.nanmean(qrs),
            'avg_qt': np.nanmean(qt),
        }
        summary.update(
            (name, shift_to_global(values, start_idx))
            for name, values in waves.items()
        )
        segment_summaries.append(summary)

    return segment_summaries


# ==========================================
# 6. PLOTTING FUNCTIONS
# ==========================================

def plot_ecg_segments(ecg_raw, ecg_filtered, sampling_rate, results, title="ECG Analysis"):
    """
    Plot every analysed ECG segment with its detected waves and intervals.

    One subplot is drawn per entry of ``results``. Each shows the raw and
    filtered slice, R/P/T peak markers, P-start, QRS-start, QRS-end and
    T-end markers, and horizontal PR (green) and QT (blue) bars below the
    trace. The subplot title carries the segment's BPM and mean intervals.

    Args:
        ecg_raw: Full raw recording; its length defines the valid index range.
        ecg_filtered: Full filtered recording; marker heights are read from it.
        sampling_rate: Samples per second, used to build the time axis.
        results: Per-segment dicts as produced by ``process_ecg_segments``
            (wave indices are positions in the full recording).
        title: Accepted for interface compatibility; not currently rendered.

    Returns:
        None. Prints a notice and draws nothing when ``results`` is empty.
    """
    num_segments = len(results)
    if num_segments == 0:
        # plt.subplots(0, 1) raises ValueError, and there is nothing to draw.
        print("No ECG segments to plot.")
        return

    fig, axes = plt.subplots(num_segments, 1, figsize=(20, 4*num_segments))
    if num_segments == 1: axes = [axes]
   
    n_samples = len(ecg_raw)
    time = np.arange(n_samples) / sampling_rate
   
    def get_valid(indices):
        """Drop NaN and out-of-range entries and return integer indices."""
        indices = np.asarray(indices, dtype=float)
        if len(indices) == 0: return np.array([], dtype=int)
        valid = indices[~np.isnan(indices)]
        # Negative values would silently wrap around when used as indices.
        valid = valid[(valid >= 0) & (valid < n_samples)]
        return valid.astype(int)

    for i, (ax, res) in enumerate(zip(axes, results)):
        seg_time = time[res['start_idx']:res['end_idx']]
        
        ax.plot(seg_time, res['ecg_filtered'], 'b-', alpha=0.7, linewidth=0.8, label='Filtered')
        ax.plot(seg_time, res['ecg_raw'], 'k-', alpha=0.3, linewidth=0.5, label='Raw')
        
        peaks = [('r_peaks', 'ro', 'R'), ('p_peak', 'g^', 'P'), ('t_peak', 'bD', 'T')]
        for key, style, lbl in peaks:
            valid = get_valid(res[key])
            if len(valid): 
                ax.plot(time[valid], ecg_filtered[valid], style, markersize=6, label=lbl)

        p_onsets = res['p_onset']
        q_onsets = res['q_onset']
        t_offsets = res['t_offset']
        
        # Interval bars sit at fixed offsets below the segment's minimum.
        y_min = np.min(res['ecg_filtered'])
        bar_y_pr = y_min - 0.05
        bar_y_qt = y_min - 0.10
        
        # Label only the first bar of each kind that is actually drawn.
        pr_labelled = False
        qt_labelled = False
        for j in range(len(q_onsets)):
            # Both bars are anchored on the QRS onset.
            if np.isnan(q_onsets[j]):
                continue
            qon = int(q_onsets[j])

            if j < len(p_onsets) and not np.isnan(p_onsets[j]):
                pon = int(p_onsets[j])
                if 0 <= pon < qon < n_samples:
                    ax.hlines(y=bar_y_pr, xmin=time[pon], xmax=time[qon], colors='green', linewidth=4, alpha=0.7)
                    if not pr_labelled:
                        ax.text(time[pon], bar_y_pr, 'PR', color='green', fontsize=8, ha='right', va='center')
                        pr_labelled = True

            if j < len(t_offsets) and not np.isnan(t_offsets[j]):
                toff = int(t_offsets[j])
                if 0 <= qon < toff < n_samples:
                    ax.hlines(y=bar_y_qt, xmin=time[qon], xmax=time[toff], colors='blue', linewidth=4, alpha=0.7)
                    if not qt_labelled:
                        ax.text(time[toff], bar_y_qt, 'QT', color='blue', fontsize=8, ha='left', va='center')
                        qt_labelled = True

        valid = get_valid(res['p_onset'])
        if len(valid): ax.plot(time[valid], ecg_filtered[valid], 'gx', markersize=8, label='P-start')
        
        valid = get_valid(res['q_onset'])
        if len(valid): ax.plot(time[valid], ecg_filtered[valid], 'm|', markersize=12, markeredgewidth=2, label='QRS-start')
        
        # QRS end (S-offset)
        valid = get_valid(res['s_offset'])
        if len(valid): ax.plot(time[valid], ecg_filtered[valid], 'c|', markersize=12, markeredgewidth=2, label='QRS-end')

        valid = get_valid(res['t_offset'])
        if len(valid): ax.plot(time[valid], ecg_filtered[valid], 'b|', markersize=12, markeredgewidth=2, label='T-end')

        info = f"Seg {res['segment_num']} | BPM: {res['bpm']:.0f} | "
        info += f"PR: {res['avg_pr']:.0f}ms | QRS: {res['avg_qrs']:.0f}ms | QT: {res['avg_qt']:.0f}ms"
        
        ax.set_title(info, fontsize=11, fontweight='bold')
        ax.set_xlim([seg_time[0], seg_time[-1]])
        ax.grid(True, alpha=0.3)
        # The legend is shown once, on the first subplot only.
        if i == 0: ax.legend(loc='upper right', ncol=7, fontsize='small')

    plt.tight_layout()
    plt.show()
    plt.close()


def plot_p_wave_quality(signal, r_peaks, p_peaks, sampling_rate, segment_num=""):
    """
    Visualize P-wave detection quality and print a short summary.

    Four stacked panels are drawn: the signal with R-peaks, P-waves and
    beats lacking a P-wave; the per-beat R-minus-P interval; the absolute
    P-wave amplitude; and a signal-quality index computed over consecutive
    2-second windows. The figure is shown and then closed.

    Args:
        signal: 1-D ECG samples.
        r_peaks: R-peak sample indices (NaN entries are ignored).
        p_peaks: P-peak sample indices aligned with ``r_peaks``; NaN marks
            a beat without a detected P-wave.
        sampling_rate: Samples per second.
        segment_num: Label used in the title and the printed summary.

    Note:
        The "missing P" markers and the detection rate skip the first beat
        (they use ``p_peaks[1:]`` and ``len(r_peaks) - 1``).
    """
    signal = np.asarray(signal)
    r_peaks = np.asarray(r_peaks, dtype=float)
    p_peaks = np.asarray(p_peaks, dtype=float)

    fig, axes = plt.subplots(4, 1, figsize=(15, 12))
    
    # --- Panel 1: signal with detections ---
    time = np.arange(len(signal)) / sampling_rate
    axes[0].plot(time, signal, 'k-', alpha=0.6, linewidth=0.8)
    
    valid_r = r_peaks[~np.isnan(r_peaks)].astype(int)
    axes[0].plot(time[valid_r], signal[valid_r], 'ro', markersize=6, label='R-peaks')
    
    valid_p = p_peaks[~np.isnan(p_peaks)].astype(int)
    axes[0].plot(time[valid_p], signal[valid_p], 'g^', markersize=6, label='P-waves')
    
    # Beats (from the second one on) without a P-wave are flagged at their R-peak.
    missing_p = np.where(np.isnan(p_peaks[1:]))[0] + 1
    if len(missing_p) > 0:
        missing_r = r_peaks[missing_p].astype(int)
        axes[0].plot(time[missing_r], signal[missing_r], 'rx', markersize=10, 
                    markeredgewidth=2, label=f'Missing P ({len(missing_p)})')
    
    axes[0].set_title(f'P-wave Detection - Segment {segment_num}')
    axes[0].set_ylabel('Amplitude (mV)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # --- Panel 2: P-peak to R-peak interval per beat ---
    pr_intervals = []
    beat_numbers = []
    for beat_idx, (r, p) in enumerate(zip(r_peaks, p_peaks)):
        if not np.isnan(p) and not np.isnan(r):
            pr_intervals.append((r - p) / sampling_rate * 1000)
            beat_numbers.append(beat_idx)
    
    if pr_intervals:
        axes[1].plot(beat_numbers, pr_intervals, 'bo-', markersize=4)
        axes[1].axhline(y=120, color='g', linestyle='--', alpha=0.5, label='Normal PR min (120ms)')
        axes[1].axhline(y=200, color='orange', linestyle='--', alpha=0.5, label='1st° AVB threshold (200ms)')
        axes[1].axhline(y=np.mean(pr_intervals), color='b', linestyle='-', alpha=0.7, 
                               label=f'Mean: {np.mean(pr_intervals):.1f}ms')
        axes[1].set_ylabel('PR Interval (ms)')
        axes[1].set_xlabel('Beat Number')
        axes[1].set_title(f'PR Intervals (mean: {np.mean(pr_intervals):.1f}ms ± {np.std(pr_intervals):.1f}ms)')
        axes[1].grid(True, alpha=0.3)
        axes[1].legend()
    
    # --- Panel 3: absolute amplitude at each detected P-peak ---
    p_amplitudes = [abs(signal[p]) for p in valid_p]
    beat_numbers_amp = list(range(len(p_amplitudes)))
    
    if p_amplitudes:
        axes[2].plot(beat_numbers_amp, p_amplitudes, 'go-', markersize=4)
        axes[2].axhline(y=0.1, color='g', linestyle='--', alpha=0.5, label='Typical P (0.1mV)')
        axes[2].axhline(y=0.05, color='orange', linestyle='--', alpha=0.5, label='Low amplitude threshold')
        axes[2].axhline(y=np.mean(p_amplitudes), color='g', linestyle='-', alpha=0.7,
                               label=f'Mean: {np.mean(p_amplitudes):.3f}mV')
        axes[2].set_ylabel('Amplitude (mV)')
        axes[2].set_xlabel('Beat Number')
        axes[2].set_title(f'P-wave Amplitudes (mean: {np.mean(p_amplitudes):.3f}mV ± {np.std(p_amplitudes):.3f}mV)')
        axes[2].grid(True, alpha=0.3)
        axes[2].legend()
    
    # --- Panel 4: signal quality over consecutive 2-second windows ---
    window_size = int(2 * sampling_rate)
    num_windows = len(signal) // window_size
    quality_over_time = []
    time_points = []
    
    for w in range(num_windows):
        start = w * window_size
        end = min((w + 1) * window_size, len(signal))
        window_signal = signal[start:end]
        quality = calculate_signal_quality(window_signal)
        quality_over_time.append(quality)
        # Each quality value is plotted at the centre of its window.
        time_points.append((start + end) / 2 / sampling_rate)
    
    if quality_over_time:
        axes[3].plot(time_points, quality_over_time, 'r-', linewidth=2, label='Signal Quality')
        axes[3].axhline(y=0.7, color='g', linestyle='--', alpha=0.5, label='Good (>0.7)')
        axes[3].axhline(y=0.5, color='orange', linestyle='--', alpha=0.5, label='Moderate (>0.5)')
        axes[3].axhline(y=0.3, color='r', linestyle='--', alpha=0.5, label='Poor (<0.3)')
        axes[3].fill_between(time_points, 0, quality_over_time, alpha=0.3, color='red')
        axes[3].set_ylabel('Quality Index')
        axes[3].set_xlabel('Time (s)')
        axes[3].set_title('Signal Quality Over Time')
        axes[3].set_ylim([0, 1])
        axes[3].grid(True, alpha=0.3)
        axes[3].legend()
    
    plt.tight_layout()
    # Render before closing; otherwise the figure is built and discarded
    # without ever being displayed or saved.
    plt.show()
    plt.close()
    
    total_beats = len(r_peaks) - 1
    detected_p = np.sum(~np.isnan(p_peaks[1:]))
    detection_rate = detected_p / total_beats * 100 if total_beats > 0 else 0
    # Signals shorter than one 2-second window have no quality samples;
    # report NaN without triggering NumPy's empty-mean warning.
    mean_quality = np.mean(quality_over_time) if quality_over_time else float('nan')
    
    print(f"\n--- P-Wave Detection Summary (Segment {segment_num}) ---")
    print(f"Total R-peaks: {len(r_peaks)}")
    print(f"P-waves detected: {detected_p}/{total_beats} ({detection_rate:.1f}%)")
    print(f"Mean signal quality: {mean_quality:.2f}")
    if pr_intervals:
        print(f"PR interval: {np.mean(pr_intervals):.1f} ± {np.std(pr_intervals):.1f} ms")
    if p_amplitudes:
        print(f"P amplitude: {np.mean(p_amplitudes):.3f} ± {np.std(p_amplitudes):.3f} mV")
    print("-" * 50)


def plot_full_ecg(ecg_raw, ecg_filtered, sampling_rate, waves, r_peaks, bpm, title="Full ECG Analysis"):
    """
    Plot the whole recording with detected waves, interval bars and statistics.

    A single axis shows the raw and filtered traces, R/P/T peak markers,
    PR (green) and QT (blue) bars below the trace, onset/offset markers,
    and a text box with the heart rate and the NaN-aware mean PR, QRS and
    QT intervals. The figure is shown and then closed.

    Args:
        ecg_raw: Raw samples; their count defines the valid index range.
        ecg_filtered: Filtered samples; marker heights are read from them.
        sampling_rate: Samples per second.
        waves: Landmark arrays keyed by name, as used by ``calculate_intervals``.
        r_peaks: R-peak sample indices.
        bpm: Heart rate shown in the statistics box.
        title: Axis title.
    """
    n_samples = len(ecg_raw)

    # Figure width scales with the recording length, bounded to 15..100.
    fig_width = max(15, min(100, int(n_samples / sampling_rate * 0.8)))
    fig, ax = plt.subplots(figsize=(fig_width, 6))
    time = np.arange(n_samples) / sampling_rate

    ax.plot(time, ecg_filtered, 'b-', alpha=0.7, linewidth=0.8, label='Filtered')
    ax.plot(time, ecg_raw, 'k-', alpha=0.3, linewidth=0.5, label='Raw')

    def get_valid(indices):
        # Keep only non-NaN indices that fall inside the recording.
        if len(indices) == 0:
            return np.array([], dtype=int)
        finite = indices[~np.isnan(indices)]
        inside = (finite >= 0) & (finite < n_samples)
        return finite[inside].astype(int)

    def mark(indices, fmt, label, **style):
        # Draw one marker series at the filtered-signal height.
        points = get_valid(indices)
        if len(points):
            ax.plot(time[points], ecg_filtered[points], fmt, label=label, **style)

    # Peak markers: R first, then P and T when present in `waves`.
    mark(r_peaks, 'ro', 'R_peak', markersize=6)
    for key, fmt, label in (('p_peak', 'g^', 'P_peak'), ('t_peak', 'bD', 'T_peak')):
        if key in waves:
            mark(waves[key], fmt, label, markersize=6)

    # Interval bars are placed 5% and 10% of the signal range below its minimum.
    amplitude = np.ptp(ecg_filtered)
    y_min = np.min(ecg_filtered)
    bar_y_pr = y_min - (amplitude * 0.05)
    bar_y_qt = y_min - (amplitude * 0.10)

    p_onsets = waves.get('p_onset', [])
    q_onsets = waves.get('q_onset', [])
    t_offsets = waves.get('t_offset', [])

    def draw_bar(start, end, y, color, tag, tag_at_end, first):
        # Draw one interval bar; the first bar of a kind also gets its tag.
        # Returns True when a bar was drawn.
        if np.isnan(start) or np.isnan(end):
            return False
        start, end = int(start), int(end)
        if not (start < end and end < len(time)):
            return False
        ax.hlines(y=y, xmin=time[start], xmax=time[end],
                  colors=color, linewidth=4, alpha=0.7)
        if first:
            anchor = time[end] if tag_at_end else time[start]
            align = 'left' if tag_at_end else 'right'
            ax.text(anchor, y, tag, color=color,
                    fontsize=8, ha=align, va='center', fontweight='bold')
        return True

    pr_drawn = 0
    qt_drawn = 0
    for beat in range(len(r_peaks)):
        has_q = beat < len(q_onsets)

        if beat < len(p_onsets) and has_q:
            if draw_bar(p_onsets[beat], q_onsets[beat], bar_y_pr, 'green', 'PR',
                        False, pr_drawn == 0):
                pr_drawn += 1

        if has_q and beat < len(t_offsets):
            if draw_bar(q_onsets[beat], t_offsets[beat], bar_y_qt, 'blue', 'QT',
                        True, qt_drawn == 0):
                qt_drawn += 1

    # Onset/offset markers, in legend order.
    boundary_markers = (
        ('p_onset', 'gx', 'P-start', {'markersize': 8}),
        ('q_onset', 'm|', 'QRS-start', {'markersize': 12, 'markeredgewidth': 2}),
        ('s_offset', 'c|', 'QRS-end', {'markersize': 12, 'markeredgewidth': 2}),
        ('t_offset', 'b|', 'T-end', {'markersize': 12, 'markeredgewidth': 2}),
    )
    for key, fmt, label, style in boundary_markers:
        if key in waves:
            mark(waves[key], fmt, label, **style)

    # Statistics box in the top-left corner of the axes.
    pr, qrs, qt = calculate_intervals(waves, sampling_rate)
    stats_text = (
        f"HEART RATE: {bpm:.0f} BPM\n"
        f"Avg PR:  {np.nanmean(pr):.0f} ms\n"
        f"Avg QRS: {np.nanmean(qrs):.0f} ms\n"
        f"Avg QT:  {np.nanmean(qt):.0f} ms"
    )
    box_style = dict(boxstyle='round', facecolor='wheat', alpha=0.9)
    ax.text(0.005, 0.95, stats_text, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', bbox=box_style, fontfamily='monospace')

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel("Time (seconds)")
    ax.set_ylabel("Amplitude")
    ax.set_xlim([0, time[-1]])
    ax.grid(True, which='both', alpha=0.3)

    # Collapse duplicate legend entries by label.
    handles, labels = ax.get_legend_handles_labels()
    unique_entries = dict(zip(labels, handles))
    ax.legend(unique_entries.values(), unique_entries.keys(),
              loc='upper right', ncol=5, fontsize='small')

    plt.tight_layout()
    plt.show()
    plt.close()

# ==========================================
# 7. MAIN EXECUTION
# ==========================================

if __name__ == "__main__":
    # Filter setup: 7th-order Butterworth low-pass at 40 Hz and a 50 Hz
    # notch with Q = 50 / 20, both designed for a 500 Hz sampling rate.
    freq = 500
    low_pass_cutoff = 40
    low_pass_order = 7
    b_lp, a_lp = sp_signal.butter(low_pass_order, low_pass_cutoff / (freq / 2), btype="low")
    b_notch, a_notch = sp_signal.iirnotch(50, 50 / 20, freq)

    def low_pass_filter(data):
        """Zero-phase low-pass; returns the input unchanged if filtering fails."""
        try:
            return sp_signal.filtfilt(b_lp, a_lp, data)
        except Exception as exc:
            # Best effort: keep going with unfiltered data, but say so.
            print(f"Warning: low-pass filter skipped ({exc})")
            return data

    def notch_filter(data):
        """Zero-phase 50 Hz notch; returns the input unchanged if filtering fails."""
        try:
            return sp_signal.filtfilt(b_notch, a_notch, data)
        except Exception as exc:
            # Best effort: keep going with unfiltered data, but say so.
            print(f"Warning: notch filter skipped ({exc})")
            return data

    # # input_json = r"simulator\contec\30bpm_1756099931979.json"
    # input_json = r"simulator\contec\40bpm_1756099996286.json"
    # # input_json = r"simulator\contec\60bpm_1756100055600.json"
    # # input_json = r"simulator\contec\80bpm_1756100133843.json"
    # # input_json = r"simulator\contec\100bpm_1756100188299.json"
    # input_json = r"simulator\contec\120bpm_1756100243822.json"
    # # input_json = r"simulator\contec\140bpm_1756100303625.json"
    # # input_json = r"simulator\contec\180bpm_1756100430492.json"
    # # input_json = r"simulator\contec\200bpm_1756100489086.json"
    # # input_json = r"simulator\contec\220bpm_1756100542987.json"

    # input_json = r"intervals\Antor_1769669515402.json"
    # input_json = r"intervals\Asif_1769673034525.json"
    # input_json = r"intervals\Atiur_1769684477237.json"     #### 89
    input_json = r"intervals\Aupo_1769678570784.json"     #### 98
    # input_json = r"intervals\Faizur_1769672384762.json"
    # input_json = r"intervals\Alauddin_1769763819330.json"

    try:
        # Loading happens inside the try so that a wrong path is reported by
        # the FileNotFoundError handler below instead of an unhandled traceback.
        with open(input_json, 'r') as file:
            file_data = json.load(file)


        # # input_json = r"exception\L2_1759207950416.json"
        # input_json = r"issues\L2_1757064122874.json"
        # # input_json = r"issues\L2_1757579288752.json"
        # # input_json = r"issues\L2_1757737806463.json"
        # # input_json = r"v01_prob\L2_1765984517025.json"
        # # input_json = r"1st-last-peaks\L2_1759908627949.json"
        # # input_json = r"1st-last-peaks\L2_1759908888619.json"

        # # input_json = r"issues2\1757998097068\L2_1757998097068.json"
        # # input_json = r"issues2\1758943573744\L2_1758943573744.json"
        # # input_json = r"issues2\1759059066184\L2_1759059066184.json"
        # # input_json = r"issues2\1759117739887\L2_1759117739887.json"
        # # input_json = r"issues2\1759118709079\L2_1759118709079.json"
        # # input_json = r"issues2\1759202739736\L2_1759202739736.json"
        # # input_json = r"issues2\1759639059357\L2_1759639059357.json"  ####

        # doubles = []
        # with open(input_json, "rb") as f:
        #     while chunk := f.read(8):
        #         if len(chunk) < 8:
        #             break
        #         value = struct.unpack("<d", chunk)[0]
        #         doubles.append(value)

        # file_data = {'dataL2': doubles}



        # # def decrypt(input_file):
        # #     """Decrypt encrypted JSON file (optional - commented out in your version)"""
        # #     private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
        # #     cipher = AES.new(private_key.encode('utf-8'), AES.MODE_ECB)
        # #     with open(input_file, 'rb') as f:
        # #         encrypted_data = f.read()
        # #     enc = base64.b64decode(encrypted_data[24:])
        # #     data = unpad(cipher.decrypt(enc), 16)
        # #     decoded_string = data.decode('utf-8')
        # #     return json.loads(decoded_string)

        # input_json = r"NHF2\DATA_1750689015865.json"  ####
        # # input_json = r"NHF2\DATA_1750689460556.json"  ####
        # # input_json = r"NHF2\DATA_1750851207409.json"  ####
        # # input_json = r"NHF2\DATA_1750858856842.json"  ####
        # # input_json = r"NHF2\DATA_1750862721789.json"  ####
        # # input_json = r"NHF2\DATA_1750996455820.json"
        # file_data = decrypt(input_json)

        # # input_json = r"NHF\DATA_1752067426678.json"
        # # input_json = r"NHF\DATA_1752121970835.json"
        # # input_json = r"NHF\DATA_1754709586876.json"
        # input_json = r"NHF\DATA_1754729551054.json"
        # file_data = decrypt(input_json)

        # Process raw data for R-peak detection (first 15000 samples)
        print("Processing raw data for R-peak detection...")
        raw_data = data_process(file_data)
        ecg_raw = raw_data[0, :15000, 0]
        
        # Process filtered data for P, Q, S, T wave detection:
        # baseline removal -> 50 Hz notch -> 40 Hz low-pass
        print("Processing filtered data for wave delineation...")
        filtered_data = data_process(low_pass_filter(notch_filter(baseline_wander(np.array(file_data["dataL2"])))))
        ecg_filtered = filtered_data[0, :15000, 0]
        
        # Same rate the filters above were designed for.
        sampling_rate = freq

        print("\n" + "="*80)
        print("ULTRA-ACCURATE ECG ANALYSIS")
        print("="*80)
        print("✓ Using raw data for R-peak detection")
        print("✓ Using filtered data for P, Q, S, T wave delineation")
        print("✓ Enhanced T-wave offset detection (first steep descent)")
        print("="*80 + "\n")
        
        segment_results = process_ecg_segments(
            ecg_raw=ecg_raw,
            ecg_filtered=ecg_filtered,
            sampling_rate=sampling_rate,
            num_segments=4,
            min_segment_length=5000
        )

        # Plot results
        plot_ecg_segments(
            ecg_raw,
            ecg_filtered,
            sampling_rate,
            segment_results,
            "Ultra-Accurate Clinical Interval Analysis"
        )

        print("\n" + "="*80)
        print("ULTRA-ACCURATE FULL ECG ANALYSIS")
        print("="*80)
        
        # --- 2. PEAK DETECTION ---
        print("Detecting R-peaks on signal...")
        _, r_peaks, bpm, _ = qrs_detect(ecg_raw, sampling_rate)
        print(f"✓ Detected {len(r_peaks)} R-peaks. BPM: {bpm:.1f}")

        # --- 3. WAVE DELINEATION ---
        print("Delineating waves (P, Q, S, T)...")
        waves = improved_delineate_ecg_waves(ecg_filtered, r_peaks, sampling_rate)
        
        # --- 4. CALCULATE STATISTICS ---
        pr, qrs, qt = calculate_intervals(waves, sampling_rate)
        
        print("\n" + "="*40)
        print("FULL SIGNAL STATISTICS")
        print("="*40)
        print(f"Avg PR Interval:  {np.nanmean(pr):.1f} ms")
        print(f"Avg QRS Duration: {np.nanmean(qrs):.1f} ms")
        print(f"Avg QT Interval:  {np.nanmean(qt):.1f} ms")
        print("="*40 + "\n")
        
        # --- 5. PLOT ---
        print("Generating full plot with statistics...")
        plot_full_ecg(
            ecg_raw, 
            ecg_filtered, 
            sampling_rate, 
            waves, 
            r_peaks, 
            bpm
        )
        
    except FileNotFoundError:
        print(f"Error: File not found ({input_json}). Please check the path.")
    except Exception as e:
        print(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()

### qrs corrected

In [None]:
import json
import struct
import numpy as np
import scipy.signal as sp_signal
import matplotlib.pyplot as plt
import base64
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
from scipy.interpolate import CubicSpline

# ==========================================
# 1. UTILITY FUNCTIONS
# ==========================================

def baseline_wander(data, sampling_rate=500, knot_spacing=0.4):
    """
    Remove baseline wander by subtracting a cubic-spline baseline estimate.

    Knots are placed every ``knot_spacing`` seconds (plus one on the final
    sample). The baseline value at each knot is the 50th percentile of the
    samples in a window centred on the knot, which keeps narrow, tall
    deflections from pulling the estimate. A cubic spline through the knots
    is then subtracted from the signal.

    Args:
        data: 1-D sequence of samples.
        sampling_rate: samples per second.
        knot_spacing: distance between knots in seconds.

    Returns:
        Array with the same length as ``data`` holding ``data - baseline``.
        Empty input yields an empty float array; a single sample yields ``[0]``.
    """
    data = np.asarray(data)
    n_samples = len(data)
    if n_samples == 0:
        # Nothing to fit; the original code failed on knot_indices[-1] here.
        return data.astype(float)

    # A step of 0 would make np.arange raise, so never go below one sample.
    knot_interval = max(1, int(sampling_rate * knot_spacing))

    # Create knot points at regular intervals, always including the last sample
    knot_indices = np.arange(0, n_samples, knot_interval)
    if knot_indices[-1] != n_samples - 1:
        knot_indices = np.append(knot_indices, n_samples - 1)

    # At least one sample on each side so the window is never empty.
    half_window = max(1, knot_interval // 2)
    knot_values = []
    for idx in knot_indices:
        start = max(0, idx - half_window)
        end = min(n_samples, idx + half_window)
        knot_values.append(np.percentile(data[start:end], 50))

    if len(knot_indices) < 2:
        # CubicSpline needs two knots; with one, the baseline is a constant.
        return data - knot_values[0]

    # Fit cubic spline and subtract
    spline = CubicSpline(knot_indices, knot_values)
    baseline = spline(np.arange(n_samples))

    return data - baseline

def normalize(signal, min_val, max_val):
    """
    Min-max scale ``signal`` with externally supplied bounds.

    Returns ``(signal - min_val) / (max_val - min_val)``; when the two bounds
    coincide, an all-zero array shaped like ``signal`` is returned instead.
    """
    span = max_val - min_val
    if span == 0:
        return np.zeros_like(signal)
    shifted = signal - min_val
    return shifted / span

def process_signal(signal_data, min_val, max_val):
    """Scale one signal with the given bounds by delegating to ``normalize``."""
    scaled = normalize(signal_data, min_val, max_val)
    return scaled

def data_process(input_data):
    """
    Convert raw lead data into a min-max scaled array.

    ``input_data`` may be a dict holding the samples under ``'dataL2'`` or the
    samples themselves. The samples are cast to float32, scaled with their own
    global minimum and maximum via ``process_signal``, and reshaped; a 1-D
    input of N samples comes back with shape ``(1, N, 1)``.
    """
    # Pull the lead out of the dict when present; otherwise treat the
    # argument as the sample sequence itself.
    has_lead = isinstance(input_data, dict) and 'dataL2' in input_data
    raw_data = input_data['dataL2'] if has_lead else input_data

    leads = np.array([np.array(raw_data).astype('float32')])
    lowest = np.min(leads)
    highest = np.max(leads)

    scaled = [
        process_signal(leads[row, :], lowest, highest)
        for row in range(leads.shape[0])
    ]

    stacked = np.stack(scaled)
    return np.expand_dims(stacked, axis=0).transpose(0, 2, 1)


def remove_close_peaks(r_peaks, validation_signal, min_dist_samples):
    """
    Remove peaks that are too close together, keeping the stronger one.

    Peaks are visited in ascending order. A peak at least ``min_dist_samples``
    after the last kept peak is kept; a closer peak replaces the last kept
    peak only when ``|validation_signal|`` is larger at the new position.

    Args:
        r_peaks: peak sample indices (any order).
        validation_signal: signal whose absolute value ranks competing peaks.
        min_dist_samples: minimum allowed spacing between kept peaks.

    Returns:
        Sorted array of kept peak indices (empty array for empty input).
    """
    r_peaks = np.asarray(r_peaks)
    if len(r_peaks) == 0:
        return np.array([])

    r_peaks = np.sort(r_peaks)
    strengths = np.abs(np.asarray(validation_signal)[r_peaks.astype(int)])

    keep = []
    last_kept = -min_dist_samples
    # Strength of the peak currently at keep[-1]. The previous code looked this
    # up as validation_abs[len(keep) - 1], which indexes by position in `keep`
    # rather than by the peak's own index and so compared against the wrong
    # peak as soon as any earlier peak had been dropped.
    last_strength = -np.inf

    for current, strength in zip(r_peaks, strengths):
        if not keep or current - last_kept >= min_dist_samples:
            keep.append(current)
        elif strength > last_strength:
            keep[-1] = current
        else:
            continue
        last_kept = current
        last_strength = strength

    return np.array(keep)


def amplitude_based_filtering(ecg_signal, peaks, segment_num="Unknown"):
    """
    Filter out high amplitude outlier peaks using the IQR method.

    A peak is an outlier when ``|ecg_signal[peak]|`` exceeds
    ``Q3 + 1.5 * IQR`` of all peak amplitudes. If the IQR is not positive, or
    if filtering would remove every peak, the input is returned unfiltered.

    Args:
        ecg_signal: signal the peaks index into.
        peaks: array of peak sample indices.
        segment_num: label kept for interface compatibility; not used here.

    Returns:
        ``(cleaned_peaks, cleaned_amplitudes)``; the amplitudes are absolute
        values aligned with ``cleaned_peaks``.
    """
    if len(peaks) == 0:
        return peaks, np.array([])

    peaks = np.asarray(peaks)
    peak_amplitudes = np.abs(np.asarray(ecg_signal)[peaks.astype(int)])

    q75, q25 = np.percentile(peak_amplitudes, [75, 25])
    iqr = q75 - q25

    if not iqr > 0:
        return peaks, peak_amplitudes

    # The previous code set the outlier positions of an all-True mask to True,
    # so nothing was ever removed; keep only the non-outliers instead.
    keep_mask = ~(peak_amplitudes > q75 + 1.5 * iqr)

    if not keep_mask.any():
        return peaks, peak_amplitudes

    return peaks[keep_mask], peak_amplitudes[keep_mask]


def remove_t_waves(ecg_signal, peaks, sampling_rate):
    """
    Drop peaks that look like T-waves rather than R-peaks.

    A peak is rejected when all of the following hold relative to the peak
    immediately before it (in sorted order):
      * it follows by more than 160 ms and less than 450 ms,
      * its absolute amplitude is below half of the preceding peak's,
      * its width at half of its own amplitude (scanned at most 100 samples
        to each side) exceeds 40 ms.

    Inputs with fewer than three peaks are returned untouched; otherwise a
    sorted array of the surviving peaks is returned.
    """
    if len(peaks) < 3:
        return peaks

    ordered = np.sort(peaks)
    last_sample = len(ecg_signal) - 1

    def _half_height_width_ms(center, half_level):
        # Walk outwards from the peak until the signal drops below half height
        # or the scan limit / signal boundary is reached.
        left_edge = center
        while (left_edge > 0 and left_edge > center - 100
               and not abs(ecg_signal[int(left_edge)]) < half_level):
            left_edge -= 1

        right_edge = center
        while (right_edge < last_sample and right_edge < center + 100
               and not abs(ecg_signal[int(right_edge)]) < half_level):
            right_edge += 1

        return (right_edge - left_edge) / sampling_rate * 1000

    def _looks_like_t_wave(previous, candidate):
        gap_ms = (candidate - previous) / sampling_rate * 1000
        if not 160 < gap_ms < 450:
            return False

        previous_amp = abs(ecg_signal[int(previous)])
        candidate_amp = abs(ecg_signal[int(candidate)])
        if not candidate_amp < previous_amp * 0.5:
            return False

        return _half_height_width_ms(candidate, candidate_amp * 0.5) > 40

    # The first peak has no predecessor and is always kept.
    survivors = [ordered[0]]
    for previous, candidate in zip(ordered[:-1], ordered[1:]):
        if not _looks_like_t_wave(previous, candidate):
            survivors.append(candidate)

    return np.array(survivors)


def robust_qrs_detect_internal(data_clean, sampling_rate):
    """
    Collect QRS peak candidates with three detection strategies and merge them.

    A "sharpness" score (squared first difference of the rectified 10-40 Hz
    band-passed signal) gates every candidate: a candidate is accepted only if
    the maximum score within +/-0.18 s of it exceeds a fraction of the 94th
    percentile of the score.

    Strategies:
      1. Four band-pass filters; for each, the squared output is averaged over
         0.15 s and peaks above mean + {0.1, 0.2, 0.3} * std are taken, then
         moved to the largest |signal| within +/-0.1 s.
      2. Prominence-based local maxima of the input signal itself.
      3. Peaks of the squared first difference of the input signal, moved to
         the largest |signal| within +/-0.08 s.

    Candidates closer than 0.15 s to their predecessor are resolved in favour
    of the one with the larger local sharpness.

    Args:
        data_clean: 1-D array of samples (``.copy()`` is called on it).
        sampling_rate: samples per second.

    Returns:
        Sorted array of candidate peak indices, or an empty array when no
        candidate passed the sharpness gate.
    """
    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate
    
    # Sharpness score and its global threshold (94th percentile)
    low_strict = 10 / nyquist
    high_strict = 40 / nyquist
    b2, a2 = sp_signal.butter(2, [low_strict, high_strict], btype='band')
    filtered_strict = sp_signal.filtfilt(b2, a2, data_clean)
    diff_strict = np.diff(np.abs(filtered_strict))
    # np.diff shortens by one sample; pad so indices line up with the signal
    diff_strict = np.append(diff_strict, 0)
    strict_score = diff_strict ** 2
    
    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0
    
    all_candidate_peaks = []
    
    # Strategy 1: Multi-band detection with multiple thresholds
    freq_bands = [(5, 15), (8, 24), (10, 30), (12, 40)]
    
    for low_freq, high_freq in freq_bands:
        low = low_freq / nyquist
        high = high_freq / nyquist
        b, a = sp_signal.butter(2, [low, high], btype='band')
        filtered = sp_signal.filtfilt(b, a, data_clean)
        
        # Energy envelope: squared signal averaged over a 0.15 s window
        squared = filtered ** 2
        window_size = int(0.15 * sampling_rate)
        integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
        
        mean_val = np.mean(integrated)
        std_val = np.std(integrated)
        
        thresholds = [mean_val + 0.1 * std_val, mean_val + 0.2 * std_val, mean_val + 0.3 * std_val]
        
        for threshold in thresholds:
            candidates, _ = sp_signal.find_peaks(
                integrated,
                height=threshold,
                distance=int(0.2 * sampling_rate)
            )
            
            search_window = int(0.1 * sampling_rate)
            sharp_window = int(0.18 * sampling_rate)
            
            for peak in candidates:
                start_sharp = max(0, peak - sharp_window)
                end_sharp = min(len(strict_score), peak + sharp_window)
                if start_sharp < end_sharp:
                    local_sharpness = np.max(strict_score[start_sharp:end_sharp])
                    
                    if local_sharpness > sharpness_threshold:
                        # Snap the envelope peak to the largest |sample| nearby
                        start = max(0, peak - search_window)
                        end = min(len(original_data), peak + search_window)
                        if start < end:
                            local_segment = original_data[start:end]
                            local_max_idx = np.argmax(np.abs(local_segment))
                            refined_peak = start + local_max_idx
                            all_candidate_peaks.append(refined_peak)
    
    # Strategy 2: Prominence-based detection
    # (find_peaks locates local maxima, so only upward deflections are found here)
    peaks_prom, properties = sp_signal.find_peaks(
        original_data,
        distance=int(0.2 * sampling_rate),
        prominence=0.02
    )
    
    sharp_window = int(0.18 * sampling_rate)
    for peak in peaks_prom:
        start_sharp = max(0, peak - sharp_window)
        end_sharp = min(len(strict_score), peak + sharp_window)
        if start_sharp < end_sharp:
            local_sharpness = np.max(strict_score[start_sharp:end_sharp])
            # Relaxed gate (80 % of the threshold); these peaks are not re-snapped
            if local_sharpness > sharpness_threshold * 0.8:
                all_candidate_peaks.append(peak)
    
    # Strategy 3: Derivative-based detection
    diff_signal = np.diff(original_data)
    diff_squared = diff_signal ** 2
    diff_squared = np.append(diff_squared, 0)
    
    mean_diff = np.mean(diff_squared)
    std_diff = np.std(diff_squared)
    
    diff_peaks, _ = sp_signal.find_peaks(
        diff_squared,
        height=mean_diff + 0.5 * std_diff,
        distance=int(0.15 * sampling_rate)
    )
    
    search_window = int(0.08 * sampling_rate)
    
    for peak in diff_peaks:
        start_sharp = max(0, peak - sharp_window)
        end_sharp = min(len(strict_score), peak + sharp_window)
        if start_sharp < end_sharp:
            local_sharpness = np.max(strict_score[start_sharp:end_sharp])
            
            # Most relaxed gate (70 % of the threshold)
            if local_sharpness > sharpness_threshold * 0.7:
                start = max(0, peak - search_window)
                end = min(len(original_data), peak + search_window)
                if start < end:
                    local_segment = original_data[start:end]
                    local_max_idx = np.argmax(np.abs(local_segment))
                    refined_peak = start + local_max_idx
                    all_candidate_peaks.append(refined_peak)
    
    # Merge and deduplicate peaks
    if len(all_candidate_peaks) > 0:
        all_candidate_peaks = np.unique(all_candidate_peaks)
        
        min_distance = int(0.15 * sampling_rate)
        sorted_peaks = np.sort(all_candidate_peaks)
        
        if len(sorted_peaks) > 0:
            keep_mask = [True]
            for i in range(1, len(sorted_peaks)):
                if sorted_peaks[i] - sorted_peaks[i-1] >= min_distance:
                    keep_mask.append(True)
                else:
                    # Too close to the previous candidate: compare the local
                    # sharpness around each and keep the sharper one.
                    # NOTE(review): the comparison is against the immediately
                    # preceding candidate even if that one was already dropped.
                    start1 = max(0, sorted_peaks[i-1] - sharp_window)
                    end1 = min(len(strict_score), sorted_peaks[i-1] + sharp_window)
                    start2 = max(0, sorted_peaks[i] - sharp_window)
                    end2 = min(len(strict_score), sorted_peaks[i] + sharp_window)
                    
                    sharp1 = np.max(strict_score[start1:end1]) if start1 < end1 else 0
                    sharp2 = np.max(strict_score[start2:end2]) if start2 < end2 else 0
                    
                    if sharp2 > sharp1:
                        keep_mask[-1] = False
                        keep_mask.append(True)
                    else:
                        keep_mask.append(False)
            
            sorted_peaks = sorted_peaks[keep_mask]
    
    # sorted_peaks only exists when at least one candidate was collected
    return sorted_peaks if len(all_candidate_peaks) > 0 else np.array([])


def qrs_detect(data, sampling_rate, segment_duration=None):
    """
    Enhanced QRS detection with Amplitude Guardrails for AV Blocks.

    Pipeline:
      1. Candidate detection on an 8-24 Hz band-passed, rectified,
         differentiated, squared and 0.15 s averaged signal.
      2. Each candidate must have a sharpness score (10-40 Hz band) above the
         94th percentile within +/-0.18 s; it is then moved to the largest
         |sample| in that window.
      3. Peaks closer than 0.15 s are merged via ``remove_close_peaks``.
      4. RR gaps longer than 1.4x the median RR are searched again with
         relaxed thresholds; a recovered peak must reach 45 % of the median
         R amplitude.
      5. If the peak count is outside what 30-180 BPM would give for the
         segment duration, the result is replaced by the output of
         ``robust_qrs_detect_internal`` + ``remove_t_waves`` +
         ``amplitude_based_filtering``; otherwise only ``remove_t_waves`` runs.

    Args:
        data: 1-D array of samples (``.copy()`` is called on it).
        sampling_rate: samples per second.
        segment_duration: duration in seconds used for the plausibility check;
            derived from ``len(data)`` when None.

    Returns:
        Tuple ``(data, r_peaks, bpm, r_peaks)``: the unchanged input, the peak
        indices (returned twice), and the heart rate from the mean of RR
        intervals between 0.2 s and 4.0 s (0 when none qualify).
    """
    data_clean = data
    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate
    
    # --- STREAM 1: Standard Detection ---
    low = 8 / nyquist
    high = 24 / nyquist
    b, a = sp_signal.butter(2, [low, high], btype='band')
    filtered_standard = sp_signal.filtfilt(b, a, data_clean)
    
    filtered_abs = np.abs(filtered_standard)
    diff = np.diff(filtered_abs)
    # Pad so the differentiated signal keeps the original length
    diff = np.append(diff, 0)
    squared = diff ** 2
    
    # Moving-average integration over 0.15 s
    window_size = int(0.15 * sampling_rate)
    integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')
    
    mean_val = np.mean(integrated)
    std_val = np.std(integrated)
    threshold = mean_val + 0.20 * std_val
    
    candidates, _ = sp_signal.find_peaks(
        integrated,
        height=threshold,
        distance=int(0.12 * sampling_rate)
    )
    
    # --- STREAM 2: Sharpness Validator ---
    low_strict = 10 / nyquist
    high_strict = 40 / nyquist
    b2, a2 = sp_signal.butter(2, [low_strict, high_strict], btype='band')
    filtered_strict = sp_signal.filtfilt(b2, a2, data_clean)
    diff_strict = np.diff(np.abs(filtered_strict))
    diff_strict = np.append(diff_strict, 0)
    strict_score = diff_strict ** 2
    
    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0
    
    confirmed_peaks = []
    search_window = int(0.18 * sampling_rate)
    
    for peak in candidates:
        start_check = max(0, peak - search_window)
        end_check = min(len(strict_score), peak + search_window)
        if start_check >= end_check:
            continue
        
        local_sharpness = np.max(strict_score[start_check:end_check])
        
        if local_sharpness > sharpness_threshold:
            # Snap to the largest absolute sample in the same window
            local_segment = original_data[start_check:end_check]
            if len(local_segment) > 0:
                abs_local_segment = np.abs(local_segment)
                local_max_idx = np.argmax(abs_local_segment)
                confirmed_peaks.append(start_check + local_max_idx)
    
    r_peaks = np.array(confirmed_peaks)
    
    # Remove close peaks
    min_dist = int(0.15 * sampling_rate)
    r_peaks = remove_close_peaks(r_peaks, original_data, min_dist)
    
    # Drop any NaN float entries and sort
    cleaned_r = np.sort(np.array([x for x in r_peaks if not (isinstance(x, float) and np.isnan(x))]))
    
    # GAP FILLING WITH AMPLITUDE GUARDRAILS
    if len(cleaned_r) >= 2:
        existing_heights = np.abs(original_data[cleaned_r.astype(int)])
        median_r_height = np.median(existing_heights) if len(existing_heights) > 0 else 0
        
        rr_intervals = np.diff(cleaned_r) / sampling_rate
        median_rr = np.median(rr_intervals) if len(rr_intervals) > 0 else 1.0
        new_peaks = list(cleaned_r)
        
        # Gap filling is skipped when the median RR is 1.5 s or longer
        if median_rr < 1.5:
            for i in range(len(rr_intervals)):
                if rr_intervals[i] > 1.4 * median_rr:
                    gap_start = cleaned_r[i]
                    gap_end = cleaned_r[i+1]
                    if gap_start >= gap_end:
                        continue
                    
                    # Re-scan the gap with a lower envelope threshold
                    gap_integrated = integrated[gap_start:gap_end]
                    low_thresh = mean_val * 0.6
                    
                    gap_candidates, _ = sp_signal.find_peaks(
                        gap_integrated,
                        height=low_thresh,
                        distance=int(0.10 * sampling_rate)
                    )
                    
                    for gc in gap_candidates:
                        abs_idx = gap_start + gc
                        sw_start = max(0, abs_idx - search_window)
                        sw_end = min(len(strict_score), abs_idx + search_window)
                        if sw_start >= sw_end:
                            continue
                        
                        # Relaxed sharpness gate: 40 % of the global threshold
                        local_sharp_max = np.max(strict_score[sw_start:sw_end])
                        if local_sharp_max > sharpness_threshold * 0.4:
                            
                            local_segment = original_data[sw_start:sw_end]
                            abs_local_segment = np.abs(local_segment)
                            refine_idx = np.argmax(abs_local_segment)
                            candidate_peak = sw_start + refine_idx
                            
                            candidate_amp = np.abs(original_data[candidate_peak])
                            
                            # Amplitude guardrail: at least 45 % of the median R height
                            if candidate_amp > 0.45 * median_r_height:
                                new_peaks.append(candidate_peak)
        
        new_peaks = np.sort(np.unique(new_peaks))
        cleaned_r = remove_close_peaks(new_peaks, original_data, min_dist)
    
    # Determine expected peak count range
    if segment_duration is None:
        segment_duration = len(data_clean) / sampling_rate
    
    # Peak counts corresponding to 30 BPM and 180 BPM over the segment
    min_expected_peaks = int(30/60 * segment_duration)
    max_expected_peaks = int(180/60 * segment_duration)
    
    if len(cleaned_r) < min_expected_peaks or len(cleaned_r) > max_expected_peaks:
        # Implausible count: discard the result above and use the fallback detector
        initial_peaks = robust_qrs_detect_internal(data_clean, sampling_rate)
        initial_peaks = remove_t_waves(data_clean, initial_peaks, sampling_rate)
        # peak_amplitudes is not used further in this function
        cleaned_r, peak_amplitudes = amplitude_based_filtering(data_clean, initial_peaks, "Segment")
    else:
        cleaned_r = remove_t_waves(data_clean, cleaned_r, sampling_rate)
    
    # Calculate BPM
    if len(cleaned_r) > 1:
        rr_intervals = np.diff(cleaned_r) / sampling_rate
        
        # Only RR intervals strictly between 0.2 s and 4.0 s contribute
        valid_rr = rr_intervals[(rr_intervals > 0.2) & (rr_intervals < 4.0)]
        
        if len(valid_rr) > 0:
            mean_rr = np.mean(valid_rr)
            bpm = 60 / mean_rr if mean_rr > 0 else 0
        else:
            bpm = 0
    else:
        bpm = 0
    
    return data, cleaned_r, bpm, cleaned_r

# ==========================================
# 2. P-WAVE DETECTION
# ==========================================

def adaptive_noise_filter(segment, sampling_rate):
    """
    Apply stronger filtering in noisy regions.

    Noise is estimated as the standard deviation of the first difference.
    Above 0.05 the segment is zero-phase band-pass filtered (1-25 Hz,
    3rd-order Butterworth); otherwise it is returned unchanged.

    A noisy segment that is too short for ``filtfilt``'s default edge padding
    is also returned unchanged rather than raising ValueError.

    Args:
        segment: 1-D array of samples.
        sampling_rate: samples per second.

    Returns:
        The filtered segment, or ``segment`` itself when no filtering applies.
    """
    noise_std = np.std(np.diff(segment))
    if not noise_std > 0.05:
        return segment

    nyquist = 0.5 * sampling_rate
    low = 1 / nyquist
    high = 25 / nyquist
    b, a = sp_signal.butter(3, [low, high], btype='band')

    # filtfilt pads each end with 3 * max(len(a), len(b)) samples by default
    # and requires the input to be strictly longer than that.
    default_padlen = 3 * max(len(a), len(b))
    if len(segment) <= default_padlen:
        return segment

    return sp_signal.filtfilt(b, a, segment)

def calculate_signal_quality(segment):
    """
    Score a segment's quality as a weighted sum of three terms.

    * noise term (weight 0.4): ``exp(-std(diff(segment)) / 0.1)``
    * drift term (weight 0.4): ``exp(-std(segment) / 0.3)``
    * amplitude term (weight 0.2): 1.0 when the peak-to-peak range lies in
      [0.2, 1.5], otherwise 0.5

    Higher is better; the maximum attainable value is 1.0.
    """
    sample_jitter = np.std(np.diff(segment))
    overall_spread = np.std(segment)
    peak_to_peak = np.max(segment) - np.min(segment)

    noise_score = np.exp(-sample_jitter / 0.1)
    drift_score = np.exp(-overall_spread / 0.3)
    amplitude_score = 1.0 if 0.2 <= peak_to_peak <= 1.5 else 0.5

    return noise_score * 0.4 + drift_score * 0.4 + amplitude_score * 0.2

def enhanced_p_wave_detection(signal, r_peaks, sampling_rate, segment_num):
    """
    Enhanced P-wave detection with Bradycardia fix.

    For every R-peak except the first, a search window is opened between the
    estimated end of the previous T-wave and shortly before the R-peak. Peak
    candidates in that window are scored on PR interval, amplitude and
    symmetry; the best one is accepted if its score exceeds a threshold that
    depends on heart rate and signal quality. Accepted P-peaks whose PR
    interval is far from the median PR interval are discarded afterwards.

    Args:
        signal: 1-D array of samples.
        r_peaks: ascending R-peak sample indices.
        sampling_rate: samples per second.
        segment_num: label used only in the printed status line.

    Returns:
        Float array with one entry per R-peak: the P-peak sample index, or NaN
        when none was accepted. Entry 0 is always NaN. With fewer than three
        R-peaks every entry is NaN.
    """
    if len(r_peaks) < 3:
        return np.full(len(r_peaks), np.nan)

    signal_quality = calculate_signal_quality(signal)

    p_peaks = np.full(len(r_peaks), np.nan)

    # At least three R-peaks are present here, so RR intervals always exist.
    rr_intervals = np.diff(r_peaks) / sampling_rate
    avg_rr = np.mean(rr_intervals)
    avg_hr = 60 / avg_rr if avg_rr > 0 else 0

    print(f"Segment {segment_num}: Signal Quality = {signal_quality:.2f}, Avg HR = {avg_hr:.0f} BPM")

    # Pick search mode and acceptance threshold from the heart rate
    if avg_hr > 180:
        use_adaptive_short_cycle = True
        min_quality_threshold = 20
    elif avg_hr > 120:
        use_adaptive_short_cycle = True
        min_quality_threshold = 30
    else:
        use_adaptive_short_cycle = False
        if signal_quality > 0.7:
            min_quality_threshold = 50
        elif signal_quality > 0.5:
            min_quality_threshold = 35
        else:
            min_quality_threshold = 25

    preliminary_pr_intervals = []

    # Estimated T-wave end after each R-peak, as a fraction of the RR interval
    t_wave_ends = []
    for i in range(len(r_peaks) - 1):
        r_curr = int(r_peaks[i])
        r_next = int(r_peaks[i + 1])
        rr_interval = r_next - r_curr

        if rr_interval < 0.4 * sampling_rate:
            estimated_t_end = r_curr + int(0.45 * rr_interval)
        elif rr_interval < 0.6 * sampling_rate:
            estimated_t_end = r_curr + int(0.55 * rr_interval)
        else:
            estimated_t_end = r_curr + min(int(0.5 * sampling_rate), int(0.65 * rr_interval))

        t_wave_ends.append(estimated_t_end)

    # The last beat has no following R-peak; assume 0.5 s
    t_wave_ends.append(int(r_peaks[-1]) + int(0.5 * sampling_rate))

    for i, r in enumerate(r_peaks):
        if i == 0:
            continue

        r = int(r)
        rr_prev = r - int(r_peaks[i-1])
        # t_wave_ends has one entry per R-peak, so index i - 1 always exists.
        t_end_prev = t_wave_ends[i - 1]

        if use_adaptive_short_cycle and rr_prev < 0.5 * sampling_rate:
            search_start = t_end_prev + int(0.02 * sampling_rate)
            search_end = int(r - 0.02 * sampling_rate)
            min_pr_ms = 80
            max_pr_ms = 300
        else:
            # Never look further back than 0.40 s before the R-peak
            max_lookback_samples = int(0.40 * sampling_rate)
            earliest_allowed_start = r - max_lookback_samples
            search_start = max(t_end_prev + int(0.05 * sampling_rate), earliest_allowed_start)
            search_end = int(r - 0.03 * sampling_rate)
            min_pr_ms = 80
            max_pr_ms = 400

        search_start = max(0, search_start)
        search_end = min(len(signal)-1, search_end)

        min_window_size = int(0.05 * sampling_rate)
        if search_end - search_start < min_window_size:
            continue

        segment = signal[search_start:search_end]
        segment_filtered = adaptive_noise_filter(segment, sampling_rate)

        if use_adaptive_short_cycle:
            min_prominence = 0.002
            min_distance = int(0.05 * sampling_rate)
            max_width = int(0.12 * sampling_rate)
        else:
            min_prominence = 0.003
            min_distance = int(0.08 * sampling_rate)
            max_width = int(0.15 * sampling_rate)

        try:
            candidate_peaks, properties = sp_signal.find_peaks(
                segment_filtered,
                distance=min_distance,
                prominence=min_prominence,
                width=(int(0.02*sampling_rate), max_width)
            )
        except ValueError:
            # find_peaks rejects invalid arguments (e.g. distance < 1 at very
            # low sampling rates); treat the beat as having no candidates.
            candidate_peaks = []

        if len(candidate_peaks) == 0:
            continue

        candidate_peaks = search_start + candidate_peaks

        best_score = -np.inf
        best_peak = None

        for cp in candidate_peaks:
            cp = int(cp)

            pr_interval = (r - cp) / sampling_rate * 1000
            if pr_interval < min_pr_ms or pr_interval > max_pr_ms:
                continue

            score = 0

            if use_adaptive_short_cycle:
                ideal_pr = 120
                sigma = 40
            else:
                ideal_pr = 160
                sigma = 60

            # Gaussian preference for a PR interval near the ideal value
            score += np.exp(-((pr_interval - ideal_pr) ** 2) / (2 * sigma ** 2)) * 150

            p_amp = abs(signal[cp])
            if p_amp > 0.5:
                score += 50
            elif 0.015 <= p_amp <= 0.5:
                ideal_amp = 0.08
                score += np.exp(-((p_amp - ideal_amp) ** 2) / (2 * 0.10 ** 2)) * 400
            else:
                continue

            # Symmetry bonus from the slopes 5 samples to each side. Skipped
            # near the array edges: a negative index would silently wrap to
            # the end of the signal and an index past the end would raise.
            if cp - 5 >= 0 and cp + 5 < len(signal):
                left_slope = signal[cp] - signal[cp - 5]
                right_slope = signal[cp + 5] - signal[cp]
                symmetry = 1 - abs(left_slope - right_slope) / (abs(left_slope) + abs(right_slope) + 1e-6)
                score += symmetry * 60

            score += 50

            if score > best_score:
                best_score = score
                best_peak = cp

        if best_peak is not None and best_score > min_quality_threshold:
            p_peaks[i] = best_peak
            preliminary_pr_intervals.append((r - best_peak) / sampling_rate * 1000)

    # Consistency pass: reject P-peaks whose PR interval is far from the median
    if len(preliminary_pr_intervals) >= 3:
        median_pr = np.median(preliminary_pr_intervals)
        tolerance = 100 if use_adaptive_short_cycle else 120

        for i in range(1, len(r_peaks)):
            if np.isnan(p_peaks[i]):
                continue

            pr = (r_peaks[i] - p_peaks[i]) / sampling_rate * 1000
            if abs(pr - median_pr) > tolerance:
                p_peaks[i] = np.nan

    return p_peaks

# ==========================================
# 4. P-WAVE ONSET DETECTION
# ==========================================

def find_p_onset_constrained(signal, p_peak_idx, sampling_rate, prev_t_offset=None):
    """
    Locate the P-wave onset inside a bounded window before the P-peak.

    The window ends 12 ms before the peak and starts 120 ms before it, or
    20 ms after ``prev_t_offset`` if that is later. Scanning from the peak
    side backwards, the first sample whose absolute gradient is below 10 % of
    the window's maximum absolute gradient is returned.

    Returns NaN for a NaN peak, the window end when the window is empty, and
    the window start when the window has fewer than 3 samples or no sample
    is flat enough.
    """
    if np.isnan(p_peak_idx):
        return np.nan

    peak = int(p_peak_idx)
    window_end = peak - int(0.012 * sampling_rate)
    window_start = peak - int(0.12 * sampling_rate)

    # Do not search into the previous T-wave when its offset is known
    t_offset_known = prev_t_offset is not None and not np.isnan(prev_t_offset)
    if t_offset_known:
        after_t_wave = int(prev_t_offset) + int(0.02 * sampling_rate)
        window_start = max(window_start, after_t_wave)

    if window_start >= window_end:
        return window_end

    window_start = max(0, window_start)
    window = signal[window_start:window_end + 1]
    if len(window) < 3:
        return window_start

    steepness = np.abs(np.gradient(window))
    flat_positions = np.flatnonzero(steepness < 0.10 * np.max(steepness))
    if flat_positions.size == 0:
        return window_start

    # The last flat position is the one closest to the P-peak
    return window_start + int(flat_positions[-1])


# ==========================================
# 5. WAVE DELINEATION
# ==========================================


def find_t_wave_offset_with_stability(signal, t_peak_idx, sampling_rate, 
                                      next_r_peak=None, limit_idx=None):
    """
    Locate the end (offset) of a T-wave after its peak.

    Several detectors run on the segment following the peak and the one with
    the highest fixed priority wins:
      1. variance stability (120): first point where the variance of the
         following window is below 15 % of the preceding window's,
      2. inflection (100): after a steep return toward baseline, the first
         point where the slope flattens or reverses,
      3. curvature (80): a sign change of the second derivative followed by
         a shallow slope,
      4. minimum slope (60): only tried when the three above found nothing.

    The inflection detector orients slopes by the T-wave polarity so that a
    slope heading back toward baseline is negative for both upright and
    inverted T-waves. (Earlier the sign was inverted, so the "steep descent"
    state was entered on slopes moving away from baseline.)

    Args:
        signal: 1-D array of samples.
        t_peak_idx: T-peak sample index (NaN allowed).
        sampling_rate: samples per second.
        next_r_peak: index of the following R-peak; bounds the search to 70 %
            of the distance to it (at least 100 ms). Without it 200 ms is used.
        limit_idx: optional hard upper bound for the search.

    Returns:
        Sample index of the T-wave offset, or NaN when ``t_peak_idx`` is NaN.
    """
    if np.isnan(t_peak_idx):
        return np.nan
    
    t_peak_idx = int(t_peak_idx)
    t_peak_value = signal[t_peak_idx]
    
    # ===== STAGE 1: Determine Search Limits =====
    if next_r_peak is not None:
        max_search = int(0.70 * (next_r_peak - t_peak_idx))
        max_search = max(int(0.100 * sampling_rate), max_search)
    else:
        max_search = int(0.200 * sampling_rate)
    
    if limit_idx is not None and not np.isnan(limit_idx):
        limit_idx = int(limit_idx)
        dist_to_limit = limit_idx - t_peak_idx
        if dist_to_limit <= 5:
            return t_peak_idx + 5
        max_search = min(max_search, dist_to_limit - 3)
    
    max_idx = min(len(signal) - 1, t_peak_idx + max_search)
    
    if max_idx <= t_peak_idx + 8:
        return t_peak_idx + 5
    
    # ===== STAGE 2: Setup Parameters =====
    skip_samples = max(3, int(0.008 * sampling_rate))  # Skip 8ms from peak
    
    if max_idx <= t_peak_idx + skip_samples + 5:
        return t_peak_idx + skip_samples
    
    segment = signal[t_peak_idx:max_idx]
    
    # Window for the variance comparison: 60 ms, capped at 15 samples
    window_size = min(15, int(0.060 * sampling_rate))
    
    # +1 for an upright T-wave, -1 for an inverted one. polarity * slope is
    # negative while the signal moves from the peak back toward baseline.
    polarity = 1 if t_peak_value > 0 else -1
    
    # ===== STAGE 3: Calculate Derivatives =====
    derivative_1st = np.gradient(segment)
    
    if len(derivative_1st) > 5:
        kernel_size = 3
        derivative_1st_smooth = np.convolve(derivative_1st, 
                                            np.ones(kernel_size)/kernel_size, 
                                            mode='same')
    else:
        derivative_1st_smooth = derivative_1st
    
    derivative_2nd = np.gradient(derivative_1st_smooth)
    
    # Adaptive thresholds for slope-based methods
    slope_magnitudes = np.abs(derivative_1st_smooth[skip_samples:])
    
    if len(slope_magnitudes) > 0:
        slope_75th = np.percentile(slope_magnitudes, 75)
        slope_median = np.median(slope_magnitudes)
        steep_threshold = max(slope_75th * 0.4, 0.002)
        flat_threshold = max(slope_median * 0.15, 0.0008)
    else:
        steep_threshold = 0.003
        flat_threshold = 0.001
    
    # ===== STAGE 4: METHOD 1 - Local Variance Stability Detection =====
    stability_idx = None
    
    # Start at least 10 ms beyond the skipped samples
    search_start = skip_samples + int(0.010 * sampling_rate)
    search_end = len(segment) - window_size
    
    if search_start < search_end:
        variance_ratios = []
        candidate_points = []
        
        for i in range(search_start, search_end):
            # Windows immediately before and after the current point
            before_window = segment[max(0, i-window_size):i]
            after_window = segment[i:min(len(segment), i+window_size)]
            
            # Windows shorter than 10 samples are not compared
            if len(before_window) < 10 or len(after_window) < 10:
                continue
            
            var_before = np.var(before_window)
            var_after = np.var(after_window)
            
            # Ratio of variance after the point to variance before it
            if var_before > 0:
                variance_ratios.append(var_after / var_before)
                candidate_points.append(i)
        
        if len(variance_ratios) > 0:
            variance_ratios = np.array(variance_ratios)
            candidate_points = np.array(candidate_points)
            
            # Stable when the variance after is below 15 % of the variance before
            stability_mask = variance_ratios < 0.15
            
            if np.any(stability_mask):
                # Take the FIRST point where stability is achieved
                stability_idx = candidate_points[stability_mask][0]
    
    # ===== STAGE 5: METHOD 2 - Inflection Point Detection (Backup) =====
    inflection_idx = None
    in_steep_descent = False
    
    for i in range(skip_samples, len(derivative_1st_smooth) - 2):
        current_slope = derivative_1st_smooth[i]
        oriented_slope = polarity * current_slope
        
        if oriented_slope < -steep_threshold:
            # Steep return toward baseline
            in_steep_descent = True
        
        elif in_steep_descent:
            next_slopes = derivative_1st_smooth[i:min(i+4, len(derivative_1st_smooth))]
            
            # Option A: Slope becomes flat
            if np.all(np.abs(next_slopes) < flat_threshold * 1.5):
                inflection_idx = i
                break
            
            # Option B: Slope reverses direction for two consecutive samples
            if oriented_slope > 0:
                if i + 2 < len(derivative_1st_smooth):
                    if polarity * derivative_1st_smooth[i+1] > 0:
                        inflection_idx = i
                        break
            
            # Option C: Slope magnitude drops below threshold
            if np.abs(current_slope) < flat_threshold:
                if i + 3 < len(derivative_1st_smooth):
                    if np.mean(np.abs(derivative_1st_smooth[i:i+3])) < flat_threshold * 1.3:
                        inflection_idx = i
                        break
                else:
                    inflection_idx = i
                    break
    
    # ===== STAGE 6: METHOD 3 - Second Derivative Zero-Crossing =====
    curvature_inflection_idx = None
    
    if len(derivative_2nd) > skip_samples + 10:
        for i in range(skip_samples + 5, len(derivative_2nd) - 1):
            if derivative_2nd[i] * derivative_2nd[i+1] <= 0:
                if i > skip_samples + int(0.015 * sampling_rate):
                    if i + 3 < len(derivative_1st_smooth):
                        avg_slope_after = np.mean(np.abs(derivative_1st_smooth[i:i+3]))
                        if avg_slope_after < steep_threshold * 0.6:
                            curvature_inflection_idx = i
                            break
    
    # ===== STAGE 7: METHOD 4 - Minimum Slope Magnitude =====
    min_slope_idx = None
    
    if stability_idx is None and inflection_idx is None and curvature_inflection_idx is None:
        search_start_min = skip_samples + int(0.020 * sampling_rate)
        search_end_min = min(len(derivative_1st_smooth), skip_samples + int(0.100 * sampling_rate))
        
        if search_start_min < search_end_min:
            slope_window = np.abs(derivative_1st_smooth[search_start_min:search_end_min])
            if len(slope_window) > 0:
                local_min = np.argmin(slope_window)
                min_slope_idx = search_start_min + local_min
    
    # ===== STAGE 8: Select Best Detection (Prioritized) =====
    candidates = []
    
    # HIGHEST PRIORITY: Variance stability
    if stability_idx is not None:
        candidates.append(('variance_stability', t_peak_idx + stability_idx, 120))
    
    # HIGH PRIORITY: Inflection point
    if inflection_idx is not None:
        candidates.append(('inflection', t_peak_idx + inflection_idx, 100))
    
    # MEDIUM PRIORITY: Curvature change
    if curvature_inflection_idx is not None:
        candidates.append(('curvature', t_peak_idx + curvature_inflection_idx, 80))
    
    # LOW PRIORITY: Minimum slope
    if min_slope_idx is not None:
        candidates.append(('min_slope', t_peak_idx + min_slope_idx, 60))
    
    if len(candidates) == 0:
        # Ultimate fallback: 40 ms after the peak, capped at the search limit
        return min(t_peak_idx + int(0.040 * sampling_rate), max_idx)
    
    # Use the highest priority detection
    best_method, best_offset, best_score = max(candidates, key=lambda x: x[2])
    
    # ===== STAGE 9: Validation =====
    best_offset = max(best_offset, t_peak_idx + skip_samples)
    best_offset = min(best_offset, max_idx)
    
    # Sanity check on duration: replace implausible results with fixed offsets
    duration_ms = (best_offset - t_peak_idx) / sampling_rate * 1000
    
    if duration_ms < 15:
        best_offset = t_peak_idx + int(0.030 * sampling_rate)
    elif duration_ms > 300:
        best_offset = t_peak_idx + int(0.100 * sampling_rate)
    
    return int(best_offset)

def improved_delineate_ecg_waves(signal, r_peaks, sampling_rate):
    """
    Locate P, Q, S and T wave landmarks around every R-peak.

    Each loop iteration appends exactly one value to every list of the
    result, so all returned arrays have one entry per R-peak. P-wave
    landmarks, the T-peak and the T-onset are stored as ``np.nan`` when they
    could not be located.

    Args:
        signal: 1-D ECG samples; indexed directly with the R-peak positions.
        r_peaks: R-peak sample indices (each one is cast with ``int``).
        sampling_rate: Samples per second, used to turn the second-based
            search windows below into sample counts.

    Returns:
        dict: Keys 'p_peak', 'p_onset', 'p_offset', 'q_peak', 'q_onset',
        's_peak', 's_offset', 't_peak', 't_onset' and 't_offset', each mapped
        to a NumPy array of sample indices.
    """
    waves = {
        'p_peak': [], 'p_onset': [], 'p_offset': [],
        'q_peak': [], 'q_onset': [],
        's_peak': [], 's_offset': [],
        't_peak': [], 't_onset': [], 't_offset': []
    }
    
    signal_len = len(signal)
    
    # Detect P-peaks
    # NOTE(review): presumably returns one entry per R-peak with NaN for
    # undetected P-waves (the loop below tolerates a shorter result) -- confirm.
    p_peaks = enhanced_p_wave_detection(signal, r_peaks, sampling_rate, "Segment")
    
    # Helper for generic boundaries
    def find_boundary_local(peak_idx, direction, max_search_samples, thresh_factor=0.05):
        """Walk away from ``peak_idx`` until the sample-to-sample slope flattens.

        ``direction`` is +1 (search forward) or -1 (search backward). The
        first sample whose absolute difference falls below ``thresh_factor``
        times the steepest difference in the search window is returned; the
        clamped search limit is returned when no such sample exists.
        """
        if np.isnan(peak_idx): return np.nan
        peak_idx = int(peak_idx)
        limit = peak_idx + (direction * max_search_samples)
        # NOTE(review): the upper clamp is signal_len, i.e. one past the last
        # valid index, and this value can be returned as-is at the bottom.
        limit = max(0, min(signal_len, limit))
        # Window too short to estimate a slope: keep the peak as the boundary.
        if abs(limit - peak_idx) < 3: return peak_idx
        start, end = sorted([peak_idx, limit])
        segment = signal[start:end]
        # Reverse so that index 0 is always the sample nearest the peak.
        if direction == -1: segment = segment[::-1] 
        diff = np.diff(segment)
        if len(diff) == 0: return peak_idx
        max_slope = np.max(np.abs(diff))
        thresh = max_slope * thresh_factor
        # Start at 1: the first difference is never accepted as the boundary.
        for i in range(1, len(diff)):
            if np.abs(diff[i]) < thresh:
                return peak_idx + (direction * i)
        return limit

    # T-offset of the previous beat; handed to the P-onset search of the next one.
    last_t_offset = None

    for i, r in enumerate(r_peaks):
        r = int(r)
        # Reference amplitude for the T-peak prominence test; falls back to
        # 1.0 when the R amplitude is at most 0.05.
        r_height = abs(signal[r]) if abs(signal[r]) > 0.05 else 1.0
        
        # --- P-WAVE ---
        p_peak_val = p_peaks[i] if i < len(p_peaks) else np.nan
        waves['p_peak'].append(p_peak_val)
        
        if not np.isnan(p_peak_val):
            p_onset = find_p_onset_constrained(signal, p_peak_val, sampling_rate, last_t_offset)
            waves['p_onset'].append(p_onset)
            # P-offset: slope-flattening search up to 80 ms after the P-peak.
            waves['p_offset'].append(find_boundary_local(int(p_peak_val), 1, int(0.08 * sampling_rate)))
        else:
            waves['p_onset'].append(np.nan)
            waves['p_offset'].append(np.nan)

        # --- Q-WAVE ---
        # Q-peak: minimum of the 50 ms immediately before the R-peak.
        win_q = int(0.05 * sampling_rate)
        q_search_start = max(0, r - win_q)
        q_window = signal[q_search_start:r]
        q_idx = q_search_start + np.argmin(q_window) if len(q_window) > 0 else np.nan
        waves['q_peak'].append(q_idx)
        # Without a Q-peak the onset search is anchored at the R-peak itself.
        anchor_q = q_idx if not np.isnan(q_idx) else r
        waves['q_onset'].append(find_boundary_local(anchor_q, -1, int(0.04 * sampling_rate)))
        
        # --- S-WAVE ---
        # S-peak: minimum of the 60 ms starting at the R-peak.
        win_s = int(0.06 * sampling_rate)
        s_search_end = min(signal_len, r + win_s)
        s_window = signal[r:s_search_end]
        s_idx = r + np.argmin(s_window) if len(s_window) > 0 else np.nan
        waves['s_peak'].append(s_idx)
        anchor_s = s_idx if not np.isnan(s_idx) else r
        waves['s_offset'].append(find_boundary_local(anchor_s, 1, int(0.04 * sampling_rate)))
        
        # --- T-WAVE ---
        # RR distance to the next beat; the last beat assumes a 1 s interval.
        rr_next = (int(r_peaks[i+1]) - r) if i < len(r_peaks) - 1 else 1.0 * sampling_rate
        # Search from 100 ms after R up to min(600 ms, 65 % of the RR interval).
        dyn_t_start = int(0.10 * sampling_rate)
        dyn_t_end = int(min(0.600 * sampling_rate, 0.65 * rr_next))
        t_search_start = min(signal_len, r + dyn_t_start)
        t_search_end = min(signal_len, r + dyn_t_end)
        t_idx = np.nan
        
        if t_search_start < t_search_end:
            t_window = signal[t_search_start:t_search_end]
            if len(t_window) > 0:
                # Prefer the tallest local maximum whose prominence is at
                # least 5 % of the R reference; otherwise use the window maximum.
                local_peaks, _ = sp_signal.find_peaks(t_window, prominence=(0.05 * r_height))
                if len(local_peaks) > 0:
                    best_peak = local_peaks[np.argmax(t_window[local_peaks])]
                    t_idx = t_search_start + best_peak
                else:
                    t_idx = t_search_start + np.argmax(t_window)
        
        waves['t_peak'].append(t_idx)
        waves['t_onset'].append(find_boundary_local(t_idx, -1, int(0.08 * sampling_rate)))
        
        # --- T-OFFSET (ULTRA ACCURATE) ---
        # The next R-peak and the estimated onset of the next P-wave are
        # passed to the offset search as limits.
        next_r = int(r_peaks[i+1]) if i < len(r_peaks) - 1 else None
        next_p_onset_limit = None
        if i + 1 < len(p_peaks):
            next_p_peak = p_peaks[i+1]
            if not np.isnan(next_p_peak):
                next_p_onset_limit = find_boundary_local(int(next_p_peak), -1, int(0.08 * sampling_rate))
        
        # NOTE(review): t_idx may be NaN here; confirm the helper handles it.
        t_offset = find_t_wave_offset_with_stability(
            signal, 
            t_idx, 
            sampling_rate, 
            next_r_peak=next_r,
            limit_idx=next_p_onset_limit
        )

        waves['t_offset'].append(t_offset)
        
        last_t_offset = t_offset

    # Convert the per-beat lists to arrays so callers can do vector arithmetic.
    for k in waves:
        waves[k] = np.array(waves[k])
    
    return waves


def calculate_intervals(waves, sampling_rate):
    """Derive per-beat PR, QRS and QT durations in milliseconds.

    PR is measured from P-onset to Q-onset, QRS from Q-onset to S-peak and
    QT from Q-onset to T-offset. Values outside the accepted window of each
    interval (PR 40-600 ms, QRS 30-200 ms, QT 100-600 ms, bounds exclusive)
    are replaced by ``np.nan``; NaN inputs therefore stay NaN.

    Args:
        waves: Mapping with the arrays 'p_onset', 'q_onset', 's_peak' and
            't_offset' (sample indices, one entry per beat).
        sampling_rate: Samples per second.

    Returns:
        tuple: ``(pr_intervals, qrs_durations, qt_intervals)`` arrays.
    """
    def _gated_span_ms(end_key, start_key, lower_ms, upper_ms):
        # Sample distance -> milliseconds, then blank out implausible values.
        span_ms = (waves[end_key] - waves[start_key]) / sampling_rate * 1000
        plausible = (span_ms > lower_ms) & (span_ms < upper_ms)
        return np.where(plausible, span_ms, np.nan)

    return (
        _gated_span_ms('q_onset', 'p_onset', 40, 600),
        _gated_span_ms('s_peak', 'q_onset', 30, 200),
        _gated_span_ms('t_offset', 'q_onset', 100, 600),
    )


def process_ecg_segments(ecg_raw, ecg_filtered, sampling_rate, num_segments=7, min_segment_length=3500):
    """Analyse the recording in ``num_segments`` evenly spaced windows.

    R-peaks are detected on the raw window, waves are delineated on the
    filtered window, and all returned indices are shifted back to positions
    in the full recording.

    Args:
        ecg_raw: Raw samples; used for R-peak detection.
        ecg_filtered: Filtered samples aligned with ``ecg_raw``; used for
            wave delineation.
        sampling_rate: Samples per second.
        num_segments: Number of windows to analyse.
        min_segment_length: Window length in samples.

    Returns:
        list[dict]: One dict per analysed window with the window bounds, the
        raw/filtered slices, 'bpm', 'r_peaks', the NaN-aware mean intervals
        ('avg_pr', 'avg_qrs', 'avg_qt') and every wave array produced by
        ``improved_delineate_ecg_waves``. Windows shorter than 100 samples
        are skipped.
    """
    max_len = len(ecg_raw)
    if num_segments > 1:
        # A recording shorter than one window would give a negative step and
        # therefore negative start indices; clamp so that every window then
        # simply covers the whole recording.
        window_step = max(0, round((max_len - min_segment_length) / (num_segments - 1)))
    else:
        window_step = 0
    results = []

    for i in range(num_segments):
        start_idx = i * window_step
        end_idx = min(start_idx + min_segment_length, max_len)

        segment_raw = ecg_raw[start_idx:end_idx]
        segment_filtered = ecg_filtered[start_idx:end_idx]

        # Too little data for a meaningful analysis.
        if len(segment_raw) < 100:
            continue

        _, r_peaks, bpm, _ = qrs_detect(segment_raw, sampling_rate, len(segment_raw)/sampling_rate)

        waves = improved_delineate_ecg_waves(segment_filtered, r_peaks, sampling_rate)
        pr, qrs, qt = calculate_intervals(waves, sampling_rate)

        def adj(arr):
            # Shift window-relative indices to recording-relative ones.
            if len(arr) == 0:
                return np.array([])
            return arr + start_idx

        res = {
            'segment_num': i + 1,
            'start_idx': start_idx,
            'end_idx': end_idx,
            'ecg_raw': segment_raw,
            'ecg_filtered': segment_filtered,
            'bpm': bpm,
            'r_peaks': adj(r_peaks),
            # nanmean yields NaN when no beat of the window had a valid interval.
            'avg_pr': np.nanmean(pr),
            'avg_qrs': np.nanmean(qrs),
            'avg_qt': np.nanmean(qt)
        }
        for k, v in waves.items():
            res[k] = adj(v)

        results.append(res)

    return results


# ==========================================
# 6. PLOTTING FUNCTIONS
# ==========================================

def plot_ecg_segments(ecg_raw, ecg_filtered, sampling_rate, results, title="ECG Analysis"):
    """Plot every analysed segment with its detected waves and interval bars.

    One subplot is drawn per entry of ``results``. Each shows the filtered
    and raw traces, R/P/T peak markers, PR (green) and QT (blue) bars below
    the trace, and markers for P-start, QRS-start, QRS-end and T-end.

    Args:
        ecg_raw: Raw samples of the full recording.
        ecg_filtered: Filtered samples aligned with ``ecg_raw``.
        sampling_rate: Samples per second; converts indices to seconds.
        results: Per-segment dicts as returned by ``process_ecg_segments``.
        title: Accepted for interface compatibility; not drawn at present.

    Returns:
        None. The figure is shown and then closed. Nothing is drawn when
        ``results`` is empty.
    """
    num_segments = len(results)
    if num_segments == 0:
        # plt.subplots() rejects a zero-row grid, so there is nothing to draw.
        return

    # squeeze=False always returns a 2-D axes array, also for one segment.
    fig, axes = plt.subplots(num_segments, 1, figsize=(20, 4*num_segments), squeeze=False)
    axes = axes[:, 0]

    n_samples = len(ecg_raw)
    time = np.arange(n_samples) / sampling_rate

    def get_valid(indices):
        """Return the finite, in-range entries of ``indices`` as integers."""
        if len(indices) == 0:
            return np.array([], dtype=int)
        valid = indices[~np.isnan(indices)]
        # Negative values would silently wrap around to the end of the trace.
        valid = valid[(valid >= 0) & (valid < n_samples)]
        return valid.astype(int)

    def draw_markers(ax, indices, style, label, **kwargs):
        valid = get_valid(indices)
        if len(valid):
            ax.plot(time[valid], ecg_filtered[valid], style, label=label, **kwargs)

    for i, (ax, res) in enumerate(zip(axes, results)):
        seg_time = time[res['start_idx']:res['end_idx']]

        ax.plot(seg_time, res['ecg_filtered'], 'b-', alpha=0.7, linewidth=0.8, label='Filtered')
        ax.plot(seg_time, res['ecg_raw'], 'k-', alpha=0.3, linewidth=0.5, label='Raw')

        for key, style, lbl in (('r_peaks', 'ro', 'R'), ('p_peak', 'g^', 'P'), ('t_peak', 'bD', 'T')):
            draw_markers(ax, res[key], style, lbl, markersize=6)

        # Interval bars sit at fixed offsets below the lowest filtered sample.
        y_min = np.min(res['ecg_filtered'])
        bar_y_pr = y_min - 0.05
        bar_y_qt = y_min - 0.10

        # Label only the first bar of each kind that is actually drawn.
        pr_labelled = False
        qt_labelled = False
        for pon, qon, toff in zip(res['p_onset'], res['q_onset'], res['t_offset']):
            if not np.isnan(pon) and not np.isnan(qon):
                pon_i, qon_i = int(pon), int(qon)
                if 0 <= pon_i < qon_i < n_samples:
                    ax.hlines(y=bar_y_pr, xmin=time[pon_i], xmax=time[qon_i], colors='green', linewidth=4, alpha=0.7)
                    if not pr_labelled:
                        ax.text(time[pon_i], bar_y_pr, 'PR', color='green', fontsize=8, ha='right', va='center')
                        pr_labelled = True

            if not np.isnan(qon) and not np.isnan(toff):
                qon_i, toff_i = int(qon), int(toff)
                if 0 <= qon_i < toff_i < n_samples:
                    ax.hlines(y=bar_y_qt, xmin=time[qon_i], xmax=time[toff_i], colors='blue', linewidth=4, alpha=0.7)
                    if not qt_labelled:
                        ax.text(time[toff_i], bar_y_qt, 'QT', color='blue', fontsize=8, ha='left', va='center')
                        qt_labelled = True

        draw_markers(ax, res['p_onset'], 'gx', 'P-start', markersize=8)
        draw_markers(ax, res['q_onset'], 'm|', 'QRS-start', markersize=12, markeredgewidth=2)
        # QRS end is marked at the S-offset.
        draw_markers(ax, res['s_offset'], 'c|', 'QRS-end', markersize=12, markeredgewidth=2)
        draw_markers(ax, res['t_offset'], 'b|', 'T-end', markersize=12, markeredgewidth=2)

        info = f"Seg {res['segment_num']} | BPM: {res['bpm']:.0f} | "
        info += f"PR: {res['avg_pr']:.0f}ms | QRS: {res['avg_qrs']:.0f}ms | QT: {res['avg_qt']:.0f}ms"

        ax.set_title(info, fontsize=11, fontweight='bold')
        ax.set_xlim([seg_time[0], seg_time[-1]])
        ax.grid(True, alpha=0.3)
        # A single legend on the first subplot is enough; the styles repeat.
        if i == 0:
            ax.legend(loc='upper right', ncol=7, fontsize='small')

    plt.tight_layout()
    plt.show()
    plt.close()


def plot_p_wave_quality(signal, r_peaks, p_peaks, sampling_rate, segment_num=""):
    """Build a four-panel P-wave quality figure and print a summary.

    Panels: (0) the signal with R-peaks, detected P-waves and the R-peaks of
    beats without a P-wave; (1) the R-minus-P distance per beat in ms;
    (2) absolute P-wave amplitude per detected P-wave; (3) signal quality in
    consecutive 2-second windows.

    Args:
        signal: 1-D NumPy array of samples.
        r_peaks: NumPy array of R-peak sample indices (NaN allowed).
        p_peaks: NumPy array of P-peak sample indices aligned with
            ``r_peaks``; NaN marks a beat without a detected P-wave.
        sampling_rate: Samples per second.
        segment_num: Label used in the panel title and the printed summary.

    Returns:
        None.
    """
    # NOTE(review): the figure is closed without plt.show() or savefig(), so
    # it is never displayed or stored -- confirm this is intended.
    fig, axes = plt.subplots(4, 1, figsize=(15, 12))

    time = np.arange(len(signal)) / sampling_rate
    axes[0].plot(time, signal, 'k-', alpha=0.6, linewidth=0.8)

    valid_r = r_peaks[~np.isnan(r_peaks)].astype(int)
    axes[0].plot(time[valid_r], signal[valid_r], 'ro', markersize=6, label='R-peaks')

    valid_p = p_peaks[~np.isnan(p_peaks)].astype(int)
    axes[0].plot(time[valid_p], signal[valid_p], 'g^', markersize=6, label='P-waves')

    # The first beat is excluded from the missing-P statistics.
    missing_p = np.where(np.isnan(p_peaks[1:]))[0] + 1
    if len(missing_p) > 0:
        # Only mark beats whose R-peak exists and is finite; casting NaN to
        # int would produce a meaningless index.
        candidate_r = r_peaks[missing_p[missing_p < len(r_peaks)]]
        missing_r = candidate_r[~np.isnan(candidate_r)].astype(int)
        axes[0].plot(time[missing_r], signal[missing_r], 'rx', markersize=10,
                    markeredgewidth=2, label=f'Missing P ({len(missing_p)})')

    axes[0].set_title(f'P-wave Detection - Segment {segment_num}')
    axes[0].set_ylabel('Amplitude (mV)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Distance from P-peak to R-peak in ms, for beats where both exist.
    pr_intervals = []
    beat_numbers = []
    for beat_idx, (r, p) in enumerate(zip(r_peaks, p_peaks)):
        if not np.isnan(p) and not np.isnan(r):
            pr_intervals.append((r - p) / sampling_rate * 1000)
            beat_numbers.append(beat_idx)

    if pr_intervals:
        pr_mean = np.mean(pr_intervals)
        pr_std = np.std(pr_intervals)
        axes[1].plot(beat_numbers, pr_intervals, 'bo-', markersize=4)
        axes[1].axhline(y=120, color='g', linestyle='--', alpha=0.5, label='Normal PR min (120ms)')
        axes[1].axhline(y=200, color='orange', linestyle='--', alpha=0.5, label='1st° AVB threshold (200ms)')
        axes[1].axhline(y=pr_mean, color='b', linestyle='-', alpha=0.7,
                               label=f'Mean: {pr_mean:.1f}ms')
        axes[1].set_ylabel('PR Interval (ms)')
        axes[1].set_xlabel('Beat Number')
        axes[1].set_title(f'PR Intervals (mean: {pr_mean:.1f}ms ± {pr_std:.1f}ms)')
        axes[1].grid(True, alpha=0.3)
        axes[1].legend()

    # Amplitudes are numbered by detected P-wave, not by beat.
    p_amplitudes = [abs(signal[p]) for p in valid_p]
    beat_numbers_amp = list(range(len(p_amplitudes)))

    if p_amplitudes:
        amp_mean = np.mean(p_amplitudes)
        amp_std = np.std(p_amplitudes)
        axes[2].plot(beat_numbers_amp, p_amplitudes, 'go-', markersize=4)
        axes[2].axhline(y=0.1, color='g', linestyle='--', alpha=0.5, label='Typical P (0.1mV)')
        axes[2].axhline(y=0.05, color='orange', linestyle='--', alpha=0.5, label='Low amplitude threshold')
        axes[2].axhline(y=amp_mean, color='g', linestyle='-', alpha=0.7,
                               label=f'Mean: {amp_mean:.3f}mV')
        axes[2].set_ylabel('Amplitude (mV)')
        axes[2].set_xlabel('Beat Number')
        axes[2].set_title(f'P-wave Amplitudes (mean: {amp_mean:.3f}mV ± {amp_std:.3f}mV)')
        axes[2].grid(True, alpha=0.3)
        axes[2].legend()

    # Signal quality in consecutive, non-overlapping 2-second windows; a
    # trailing partial window is ignored.
    window_size = int(2 * sampling_rate)
    num_windows = len(signal) // window_size
    quality_over_time = []
    time_points = []

    for w in range(num_windows):
        start = w * window_size
        end = start + window_size
        quality_over_time.append(calculate_signal_quality(signal[start:end]))
        time_points.append((start + end) / 2 / sampling_rate)

    if quality_over_time:
        axes[3].plot(time_points, quality_over_time, 'r-', linewidth=2, label='Signal Quality')
        axes[3].axhline(y=0.7, color='g', linestyle='--', alpha=0.5, label='Good (>0.7)')
        axes[3].axhline(y=0.5, color='orange', linestyle='--', alpha=0.5, label='Moderate (>0.5)')
        axes[3].axhline(y=0.3, color='r', linestyle='--', alpha=0.5, label='Poor (<0.3)')
        axes[3].fill_between(time_points, 0, quality_over_time, alpha=0.3, color='red')
        axes[3].set_ylabel('Quality Index')
        axes[3].set_xlabel('Time (s)')
        axes[3].set_title('Signal Quality Over Time')
        axes[3].set_ylim([0, 1])
        axes[3].grid(True, alpha=0.3)
        axes[3].legend()

    plt.tight_layout()
    # plt.savefig('/mnt/user-data/outputs/p_wave_quality.png', dpi=150, bbox_inches='tight')
    plt.close()

    total_beats = len(r_peaks) - 1
    detected_p = np.sum(~np.isnan(p_peaks[1:]))
    detection_rate = detected_p / total_beats * 100 if total_beats > 0 else 0

    print(f"\n--- P-Wave Detection Summary (Segment {segment_num}) ---")
    print(f"Total R-peaks: {len(r_peaks)}")
    print(f"P-waves detected: {detected_p}/{total_beats} ({detection_rate:.1f}%)")
    if quality_over_time:
        print(f"Mean signal quality: {np.mean(quality_over_time):.2f}")
    else:
        # Signal shorter than one 2-second window: no quality estimate exists.
        print("Mean signal quality: n/a")
    if pr_intervals:
        print(f"PR interval: {np.mean(pr_intervals):.1f} ± {np.std(pr_intervals):.1f} ms")
    if p_amplitudes:
        print(f"P amplitude: {np.mean(p_amplitudes):.3f} ± {np.std(p_amplitudes):.3f} mV")
    print("-" * 50)


def plot_full_ecg(ecg_raw, ecg_filtered, sampling_rate, waves, r_peaks, bpm, title="Full ECG Analysis"):
    """Plot the whole recording with wave markers, interval bars and statistics.

    Args:
        ecg_raw: Raw samples.
        ecg_filtered: Filtered samples aligned with ``ecg_raw``.
        sampling_rate: Samples per second; converts indices to seconds.
        waves: Mapping of wave names to index arrays. Marker groups whose key
            is absent are skipped; the statistics box is computed with
            ``calculate_intervals`` from the same mapping.
        r_peaks: R-peak sample indices.
        bpm: Heart rate shown in the statistics box.
        title: Figure title.

    Returns:
        None. The figure is shown and then closed.
    """
    n_samples = len(ecg_raw)

    # Figure width scales with the duration (0.8 per second), kept in [15, 100].
    fig_width = max(15, min(100, int(n_samples / sampling_rate * 0.8)))
    fig, ax = plt.subplots(figsize=(fig_width, 6))
    time = np.arange(n_samples) / sampling_rate

    ax.plot(time, ecg_filtered, 'b-', alpha=0.7, linewidth=0.8, label='Filtered')
    ax.plot(time, ecg_raw, 'k-', alpha=0.3, linewidth=0.5, label='Raw')

    def get_valid(indices):
        """Return the finite, in-range entries of ``indices`` as integers."""
        if len(indices) == 0:
            return np.array([], dtype=int)
        finite = indices[~np.isnan(indices)]
        in_range = finite[(finite >= 0) & (finite < n_samples)]
        return in_range.astype(int)

    def draw_markers(indices, fmt, label, **style):
        """Plot one marker group on the filtered trace when it is non-empty."""
        positions = get_valid(indices)
        if len(positions):
            ax.plot(time[positions], ecg_filtered[positions], fmt, label=label, **style)

    def draw_interval(begin, finish, y, color, tag, tag_at_finish, with_tag):
        """Draw one interval bar; return True when a bar was drawn."""
        if np.isnan(begin) or np.isnan(finish):
            return False
        begin, finish = int(begin), int(finish)
        if not (begin < finish and finish < len(time)):
            return False
        ax.hlines(y=y, xmin=time[begin], xmax=time[finish],
                  colors=color, linewidth=4, alpha=0.7)
        if with_tag:
            tag_x = time[finish] if tag_at_finish else time[begin]
            ax.text(tag_x, y, tag, color=color, fontsize=8,
                    ha='left' if tag_at_finish else 'right',
                    va='center', fontweight='bold')
        return True

    # Peaks: R always, P and T only when delineated.
    draw_markers(r_peaks, 'ro', 'R_peak', markersize=6)
    for key, fmt, label in (('p_peak', 'g^', 'P_peak'), ('t_peak', 'bD', 'T_peak')):
        if key in waves:
            draw_markers(waves[key], fmt, label, markersize=6)

    # Interval bars sit 5 % (PR) and 10 % (QT) of the amplitude range below
    # the lowest filtered sample.
    lowest = np.min(ecg_filtered)
    amp_range = np.ptp(ecg_filtered)
    bar_y_pr = lowest - (amp_range * 0.05)
    bar_y_qt = lowest - (amp_range * 0.10)

    p_onsets = waves.get('p_onset', [])
    q_onsets = waves.get('q_onset', [])
    t_offsets = waves.get('t_offset', [])

    pr_drawn = 0
    qt_drawn = 0
    for beat in range(len(r_peaks)):
        # Only the first bar of each kind that is actually drawn gets a tag.
        if beat < len(p_onsets) and beat < len(q_onsets):
            if draw_interval(p_onsets[beat], q_onsets[beat], bar_y_pr,
                             'green', 'PR', False, pr_drawn == 0):
                pr_drawn += 1
        if beat < len(q_onsets) and beat < len(t_offsets):
            if draw_interval(q_onsets[beat], t_offsets[beat], bar_y_qt,
                             'blue', 'QT', True, qt_drawn == 0):
                qt_drawn += 1

    # Onset/offset markers; QRS end is marked at the S-offset.
    boundary_styles = (
        ('p_onset', 'gx', 'P-start', {'markersize': 8}),
        ('q_onset', 'm|', 'QRS-start', {'markersize': 12, 'markeredgewidth': 2}),
        ('s_offset', 'c|', 'QRS-end', {'markersize': 12, 'markeredgewidth': 2}),
        ('t_offset', 'b|', 'T-end', {'markersize': 12, 'markeredgewidth': 2}),
    )
    for key, fmt, label, style in boundary_styles:
        if key in waves:
            draw_markers(waves[key], fmt, label, **style)

    # Statistics box in the upper-left corner of the axes.
    pr, qrs, qt = calculate_intervals(waves, sampling_rate)
    stats_text = (
        f"HEART RATE: {bpm:.0f} BPM\n"
        f"Avg PR:  {np.nanmean(pr):.0f} ms\n"
        f"Avg QRS: {np.nanmean(qrs):.0f} ms\n"
        f"Avg QT:  {np.nanmean(qt):.0f} ms"
    )
    ax.text(0.005, 0.95, stats_text, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.9))

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel("Time (seconds)")
    ax.set_ylabel("Amplitude")
    ax.set_xlim([0, time[-1]])
    ax.grid(True, which='both', alpha=0.3)

    # Collapse repeated labels so that each appears once in the legend.
    handles, labels = ax.get_legend_handles_labels()
    unique_entries = dict(zip(labels, handles))
    ax.legend(unique_entries.values(), unique_entries.keys(),
              loc='upper right', ncol=5, fontsize='small')

    plt.tight_layout()
    plt.show()
    plt.close()

# ==========================================
# 7. MAIN EXECUTION
# ==========================================

if __name__ == "__main__":
    # Filter setup: low-pass (40 Hz cutoff, order 7) and a 50 Hz notch,
    # both designed for a 500 Hz sampling rate.
    freq = 500
    low_pass_cutoff = 40
    low_pass_order = 7
    b_lp, a_lp = sp_signal.butter(low_pass_order, low_pass_cutoff / (freq / 2), btype="low")
    # Quality factor is 50 / 20 = 2.5.
    b_notch, a_notch = sp_signal.iirnotch(50, 50 / 20, freq)

    def low_pass_filter(data):
        """Apply the zero-phase low-pass; return ``data`` unchanged on failure."""
        try:
            return sp_signal.filtfilt(b_lp, a_lp, data)
        except Exception as exc:
            # Best effort: keep going with unfiltered data, but say so.
            print(f"Warning: low-pass filter skipped ({exc})")
            return data

    def notch_filter(data):
        """Apply the zero-phase 50 Hz notch; return ``data`` unchanged on failure."""
        try:
            return sp_signal.filtfilt(b_notch, a_notch, data)
        except Exception as exc:
            # Best effort: keep going with unfiltered data, but say so.
            print(f"Warning: notch filter skipped ({exc})")
            return data

    # input_json = r"simulator\contec\30bpm_1756099931979.json"   
    # input_json = r"simulator\contec\40bpm_1756099996286.json"   
    # input_json = r"simulator\contec\60bpm_1756100055600.json"   
    # input_json = r"simulator\contec\80bpm_1756100133843.json"   
    # input_json = r"simulator\contec\100bpm_1756100188299.json"   
    # input_json = r"simulator\contec\120bpm_1756100243822.json"     
    # input_json = r"simulator\contec\140bpm_1756100303625.json"   
    # input_json = r"simulator\contec\180bpm_1756100430492.json"   
    # input_json = r"simulator\contec\200bpm_1756100489086.json"   
    # input_json = r"simulator\contec\220bpm_1756100542987.json"   

    # input_json = r"intervals\Antor_1769669515402.json"   
    # input_json = r"intervals\Asif_1769673034525.json"   
    input_json = r"intervals\Atiur_1769684477237.json"     #### 89
    # input_json = r"intervals\Aupo_1769678570784.json"     #### 98
    # input_json = r"intervals\Faizur_1769672384762.json"    
    # input_json = r"intervals\Alauddin_1769763819330.json"   

    # --- Alternative inputs and loaders (kept for reference) ---
    # The active loader is the json.load() call inside the try block below;
    # replace it there when switching to one of these formats.

    # # input_json = r"exception\L2_1759207950416.json"   
    # input_json = r"issues\L2_1757064122874.json"  
    # # input_json = r"issues\L2_1757579288752.json"
    # # input_json = r"issues\L2_1757737806463.json"  
    # # input_json = r"v01_prob\L2_1765984517025.json"  
    # # input_json = r"1st-last-peaks\L2_1759908627949.json"
    # # input_json = r"1st-last-peaks\L2_1759908888619.json"

    # # input_json = r"issues2\1757998097068\L2_1757998097068.json"
    # # input_json = r"issues2\1758943573744\L2_1758943573744.json"
    # # input_json = r"issues2\1759059066184\L2_1759059066184.json"     
    # # input_json = r"issues2\1759117739887\L2_1759117739887.json"
    # # input_json = r"issues2\1759118709079\L2_1759118709079.json"
    # # input_json = r"issues2\1759202739736\L2_1759202739736.json"  
    # # input_json = r"issues2\1759639059357\L2_1759639059357.json"  ####

    # doubles = []
    # with open(input_json, "rb") as f:
    #     while chunk := f.read(8):
    #         if len(chunk) < 8:
    #             break
    #         value = struct.unpack("<d", chunk)[0]
    #         doubles.append(value)

    # file_data = {'dataL2': doubles}   



    # # def decrypt(input_file):
    # #     """Decrypt encrypted JSON file (optional - commented out in your version)"""
    # #     private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
    # #     cipher = AES.new(private_key.encode('utf-8'), AES.MODE_ECB)
    # #     with open(input_file, 'rb') as f:
    # #         encrypted_data = f.read()
    # #     enc = base64.b64decode(encrypted_data[24:])
    # #     data = unpad(cipher.decrypt(enc), 16)
    # #     decoded_string = data.decode('utf-8')
    # #     return json.loads(decoded_string)

    # input_json = r"NHF2\DATA_1750689015865.json"  ####
    # # input_json = r"NHF2\DATA_1750689460556.json"  ####
    # # input_json = r"NHF2\DATA_1750851207409.json"  ####
    # # input_json = r"NHF2\DATA_1750858856842.json"  ####
    # # input_json = r"NHF2\DATA_1750862721789.json"  ####
    # # input_json = r"NHF2\DATA_1750996455820.json"
    # file_data = decrypt(input_json)

    # # input_json = r"NHF\DATA_1752067426678.json"  
    # # input_json = r"NHF\DATA_1752121970835.json"  
    # # input_json = r"NHF\DATA_1754709586876.json"  
    # input_json = r"NHF\DATA_1754729551054.json"
    # file_data = decrypt(input_json)  

    try:
        # Load inside the try block so that a missing file is reported by the
        # FileNotFoundError handler below instead of an uncaught traceback.
        with open(input_json, 'r') as file:
            file_data = json.load(file)

        # Process raw data for R-peak detection
        print("Processing raw data for R-peak detection...")
        raw_data = data_process(file_data)
        # Only the first 15000 samples are analysed.
        ecg_raw = raw_data[0, :15000, 0]

        # Process filtered data for P, Q, S, T wave detection
        print("Processing filtered data for wave delineation...")
        # NOTE(review): data_process receives the loaded mapping above but a
        # bare array here -- confirm that it accepts both kinds of input.
        filtered_data = data_process(low_pass_filter(notch_filter(baseline_wander(np.array(file_data["dataL2"])))))
        ecg_filtered = filtered_data[0, :15000, 0]

        sampling_rate = freq

        print("\n" + "="*80)
        print("ULTRA-ACCURATE ECG ANALYSIS")
        print("="*80)
        print("✓ Using raw data for R-peak detection")
        print("✓ Using filtered data for P, Q, S, T wave delineation")
        print("✓ Enhanced T-wave offset detection (first steep descent)")
        print("="*80 + "\n")

        segment_results = process_ecg_segments(
            ecg_raw=ecg_raw,
            ecg_filtered=ecg_filtered,
            sampling_rate=sampling_rate,
            num_segments=4,
            min_segment_length=5000
        )

        # Plot results
        plot_ecg_segments(
            ecg_raw,
            ecg_filtered,
            sampling_rate,
            segment_results,
            "Ultra-Accurate Clinical Interval Analysis"
        )

        print("\n" + "="*80)
        print("ULTRA-ACCURATE FULL ECG ANALYSIS")
        print("="*80)

        # --- 2. PEAK DETECTION ---
        print("Detecting R-peaks on signal...")
        _, r_peaks, bpm, _ = qrs_detect(ecg_raw, sampling_rate)
        print(f"✓ Detected {len(r_peaks)} R-peaks. BPM: {bpm:.1f}")

        # --- 3. WAVE DELINEATION ---
        print("Delineating waves (P, Q, S, T)...")
        waves = improved_delineate_ecg_waves(ecg_filtered, r_peaks, sampling_rate)

        # --- 4. CALCULATE STATISTICS ---
        pr, qrs, qt = calculate_intervals(waves, sampling_rate)

        print("\n" + "="*40)
        print("FULL SIGNAL STATISTICS")
        print("="*40)
        print(f"Avg PR Interval:  {np.nanmean(pr):.1f} ms")
        print(f"Avg QRS Duration: {np.nanmean(qrs):.1f} ms")
        print(f"Avg QT Interval:  {np.nanmean(qt):.1f} ms")
        print("="*40 + "\n")

        # --- 5. PLOT ---
        print("Generating full plot with statistics...")
        plot_full_ecg(
            ecg_raw,
            ecg_filtered,
            sampling_rate,
            waves,
            r_peaks,
            bpm
        )

    except FileNotFoundError:
        print(f"Error: File not found ({input_json}). Please check the path.")
    except Exception as e:
        # Top-level boundary of the script: report the failure with its traceback.
        print(f"An error occurred: {e}")
        import traceback
        traceback.print_exc()

## flat line

In [None]:
import json
import struct
import numpy as np
import pywt
import scipy.signal as sp_signal
import matplotlib.pyplot as plt
import base64   
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad

# Module-level filter design shared by this cell. Cutoff frequencies are
# expressed as fractions of the Nyquist frequency.
fs = 500  # sampling rate
nyq = fs / 2

# 4th-order Butterworth low-pass with a 40 Hz cutoff
low_cutoff = 40 / nyq
b_lp, a_lp = sp_signal.butter(N=4, Wn=low_cutoff, btype='low')

# 50 Hz notch with quality factor 30
q = 30
w0 = 50 / nyq
b_notch, a_notch = sp_signal.iirnotch(w0=w0, Q=q)

def decrypt(input_file):
    """Read an encrypted recording and return the parsed JSON content.

    The file is read as bytes, its first 24 bytes are skipped, the remainder
    is base64-decoded, decrypted with AES in ECB mode, stripped of its
    16-byte-block padding and parsed as UTF-8 JSON.

    Args:
        input_file: Path of the encrypted file.

    Returns:
        The object produced by ``json.loads`` for the decrypted text.
    """
    # NOTE(review): the key is hard-coded in source and ECB mode is used;
    # both are worth revisiting if these files are sensitive.
    key_bytes = "Msz377xMbcn++vrcDel9vxOuEss8fsWO".encode('utf-8')
    cipher = AES.new(key_bytes, AES.MODE_ECB)

    with open(input_file, 'rb') as handle:
        # presumably the skipped 24 bytes are a fixed header -- TODO confirm
        payload = handle.read()[24:]

    plaintext = unpad(cipher.decrypt(base64.b64decode(payload)), 16)
    return json.loads(plaintext.decode('utf-8'))

    
def baseline_wander(X, sampling_rate=500):
    """Remove baseline wander with two cascaded median filters.

    The baseline is estimated by median-filtering ``X`` with a 0.2 s window
    and then filtering that result with a 0.6 s window; the estimate is
    subtracted from the original signal.

    Args:
        X: 1-D sequence of samples.
        sampling_rate: Samples per second used to size the filter windows.
            Defaults to 500, the rate that used to be hard-coded.

    Returns:
        numpy.ndarray: ``X`` minus the estimated baseline.
    """
    def _odd_kernel(duration):
        # medfilt needs an odd kernel: an even width is reduced by one, and
        # the result is never allowed to drop below a single sample.
        width = int(sampling_rate * duration)
        width += (width % 2) - 1
        return max(width, 1)

    baseline = X
    for duration in (0.2, 0.6):
        baseline = sp_signal.medfilt(baseline, _odd_kernel(duration))
    return np.subtract(X, baseline)


def normalize(signal, min_val, max_val):
    """Min-max scale ``signal`` using the supplied bounds.

    Values equal to ``min_val`` map to 0 and values equal to ``max_val`` map
    to 1. When both bounds are equal an all-zero array shaped like
    ``signal`` is returned instead of dividing by zero.
    """
    span = max_val - min_val
    if span != 0:
        return (signal - min_val) / span
    return np.zeros_like(signal)


def process_signal(signal_data, min_val, max_val):
    """Scale ``signal_data`` by delegating to ``normalize`` with the given bounds."""
    return normalize(signal_data, min_val, max_val)


def data_process(filename):
    """Extract the 'dataL2' channel, report its statistics and normalize it.

    Args:
        filename: Despite the name, a mapping that provides the samples
            under the key 'dataL2'.

    Returns:
        tuple: ``(final_data, stats, raw)`` where ``final_data`` is the
        min-max normalized signal arranged as (1, length, channels),
        ``stats`` holds the pre-normalization statistics ('raw_min',
        'raw_max', 'raw_mean', 'raw_std', 'raw_var', 'raw_median',
        'raw_range') and ``raw`` is the first channel as float32.
    """
    channel_keys = ['dataL2']
    # shape: (channels, length)
    datas_array = np.array(
        [np.array(filename[key]).astype('float32') for key in channel_keys]
    )

    # Statistics of the raw signal, taken before any scaling.
    stats = {
        'raw_min': np.min(datas_array),
        'raw_max': np.max(datas_array),
        'raw_mean': np.mean(datas_array),
        'raw_std': np.std(datas_array),
        'raw_var': np.var(datas_array),
        'raw_median': np.median(datas_array),
    }
    stats['raw_range'] = stats['raw_max'] - stats['raw_min']

    print("\nRaw (pre-normalized) signal statistics:")
    print(f"  Min    = {stats['raw_min']:12.4f}")
    print(f"  Max    = {stats['raw_max']:12.4f}")
    print(f"  Mean   = {stats['raw_mean']:12.4f}")
    print(f"  Std    = {stats['raw_std']:12.4f}")
    print(f"  Var    = {stats['raw_var']:14.6f}")
    print(f"  Median = {stats['raw_median']:12.4f}")
    print(f"  Range  = {stats['raw_range']:.4f}\n")

    # Every channel is scaled with the global minimum and maximum.
    scaled = np.stack(
        [normalize(row, stats['raw_min'], stats['raw_max']) for row in datas_array]
    )
    # (channels, length) -> (1, channels, length) -> (1, length, channels)
    final_data = scaled[np.newaxis, ...].transpose(0, 2, 1)

    return final_data, stats, datas_array[0]


def remove_close_peaks(r_peaks, validation_signal, min_dist_samples):
    """Thin out peaks closer than ``min_dist_samples``, keeping the stronger one.

    Peaks are visited in ascending order. A peak at least ``min_dist_samples``
    after the last kept peak is accepted. Otherwise it replaces the last kept
    peak only if its absolute amplitude in ``validation_signal`` is strictly
    larger.

    Args:
        r_peaks: Sample indices of candidate peaks (any order).
        validation_signal: Signal whose absolute value at each peak index is
            used as the peak strength.
        min_dist_samples: Minimum allowed spacing between kept peaks, in samples.

    Returns:
        np.ndarray: Kept peak indices in ascending order (empty array when
        ``r_peaks`` is empty).
    """
    r_peaks = np.asarray(r_peaks)
    if r_peaks.size == 0:
        return np.array([])

    r_peaks = np.sort(r_peaks)
    strengths = np.abs(np.asarray(validation_signal)[r_peaks.astype(int)])

    keep = []
    last_kept = None
    # Strength of the peak currently at keep[-1]. The previous implementation
    # looked this up as validation_abs[len(keep) - 1], which points at an
    # unrelated peak as soon as any earlier peak has been discarded.
    last_strength = None

    for current, strength in zip(r_peaks, strengths):
        if not keep or current - last_kept >= min_dist_samples:
            keep.append(current)
            last_kept = current
            last_strength = strength
        elif strength > last_strength:
            keep[-1] = current
            last_kept = current
            last_strength = strength

    return np.array(keep)


def amplitude_based_filtering(ecg_signal, peaks, segment_num="Unknown", remove_outliers=False):
    """Compute peak amplitudes and optionally drop high-amplitude IQR outliers.

    A peak is an outlier when its absolute amplitude exceeds
    ``Q3 + 1.5 * IQR`` of all peak amplitudes.

    NOTE: outlier removal is disabled by default. The previous code set the
    outlier mask entries to ``True``, so every peak was always returned. That
    behaviour is kept as the default; pass ``remove_outliers=True`` to
    actually discard the outliers.

    Args:
        ecg_signal: Signal array indexed by the peak positions.
        peaks: ``np.ndarray`` of peak sample indices.
        segment_num: Unused; kept for interface compatibility.
        remove_outliers: When true, peaks above the IQR threshold are removed
            (as long as at least one peak would remain).

    Returns:
        tuple: ``(peaks, amplitudes)`` with the absolute amplitude of each
        returned peak. For empty ``peaks`` the amplitudes are an empty array.
    """
    if len(peaks) == 0:
        return peaks, np.array([])

    peak_amplitudes = np.abs(ecg_signal[peaks.astype(int)])

    q75, q25 = np.percentile(peak_amplitudes, [75, 25])
    iqr = q75 - q25

    # No spread in amplitudes: nothing can be classified as an outlier.
    if not iqr > 0:
        return peaks, peak_amplitudes

    high_amp_threshold = q75 + 1.5 * iqr
    high_amp_indices = np.where(peak_amplitudes > high_amp_threshold)[0]

    # Never filter if that would leave no peaks at all.
    if len(peaks) - len(high_amp_indices) <= 0:
        return peaks, peak_amplitudes

    mask = np.ones(len(peaks), dtype=bool)
    if remove_outliers:
        mask[high_amp_indices] = False

    return peaks[mask], peak_amplitudes[mask]


def remove_t_waves(ecg_signal, peaks, sampling_rate):
    """Drop peaks that look like T-waves following the preceding peak.

    A peak is discarded when all of the following hold:

    * it follows the previous peak (in sorted order) by 160-450 ms,
    * its absolute amplitude is below half of the previous peak's, and
    * its width at half of its own amplitude, searched up to 100 samples to
      each side, exceeds 40 ms.

    Args:
        ecg_signal: Signal indexed by the peak positions.
        peaks: Peak sample indices.
        sampling_rate: Sampling frequency in Hz, used to convert samples to ms.

    Returns:
        ``peaks`` unchanged when fewer than three peaks are given, otherwise
        an array with the surviving peaks in ascending order.
    """
    if len(peaks) < 3:
        return peaks

    ordered = np.sort(peaks)
    last_index = len(ecg_signal) - 1

    def width_at_level_ms(center, level):
        """Walk outwards from ``center`` until the signal drops below ``level``."""
        lo = center
        while lo > 0 and lo > center - 100 and not (abs(ecg_signal[int(lo)]) < level):
            lo -= 1
        hi = center
        while hi < last_index and hi < center + 100 and not (abs(ecg_signal[int(hi)]) < level):
            hi += 1
        return (hi - lo) / sampling_rate * 1000

    # The first peak has no predecessor and is always kept.
    survivors = [ordered[0]]

    for previous, current in zip(ordered[:-1], ordered[1:]):
        gap_ms = (current - previous) / sampling_rate * 1000

        # Only peaks inside the T-wave timing window are inspected further.
        if 160 < gap_ms < 450:
            prev_amp = abs(ecg_signal[int(previous)])
            curr_amp = abs(ecg_signal[int(current)])

            # Small relative to its predecessor and wide: treat it as a T-wave.
            if curr_amp < prev_amp * 0.5 and width_at_level_ms(current, curr_amp * 0.5) > 40:
                continue

        survivors.append(current)

    return np.array(survivors)


def robust_qrs_detect_internal(data_clean, sampling_rate):
    """Collect QRS candidates from three detectors and merge them.

    A "sharpness" score (squared first difference of the rectified 10-40 Hz
    band-passed signal) gates every candidate. The 94th percentile of that
    score is the reference threshold. The detectors are:

    1. Band-pass energy in four bands, moving-average integrated and
       thresholded at mean + 0.1/0.2/0.3 std. Candidates are snapped to the
       largest absolute sample within +/-0.1 s.
    2. Prominence-based peaks of the signal itself (gate at 0.8 x threshold).
    3. Peaks of the squared derivative (gate at 0.7 x threshold), snapped to
       the largest absolute sample within +/-0.08 s.

    Candidates closer than 0.15 s to their sorted neighbour are resolved in
    favour of the one with the larger local sharpness.

    Args:
        data_clean: 1-D signal array (must support ``.copy()``).
        sampling_rate: Sampling frequency in Hz.

    Returns:
        np.ndarray: Sorted candidate peak indices, or an empty array when no
        candidate passed the gates.
    """
    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate

    def bandpass(low_hz, high_hz):
        """Zero-phase 2nd-order Butterworth band-pass of ``data_clean``."""
        coeff_b, coeff_a = sp_signal.butter(
            2, [low_hz / nyquist, high_hz / nyquist], btype='band'
        )
        return sp_signal.filtfilt(coeff_b, coeff_a, data_clean)

    # Sharpness score shared by all strategies.
    strict_score = np.append(np.diff(np.abs(bandpass(10, 40))), 0) ** 2
    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0

    sharp_window = int(0.18 * sampling_rate)

    def sharpness_near(center):
        """Max sharpness within +/-sharp_window of ``center`` (None if empty)."""
        lo = max(0, center - sharp_window)
        hi = min(len(strict_score), center + sharp_window)
        if lo < hi:
            return np.max(strict_score[lo:hi])
        return None

    def snap_to_extremum(center, half_width):
        """Index of the largest |sample| within +/-half_width (None if empty)."""
        lo = max(0, center - half_width)
        hi = min(len(original_data), center + half_width)
        if lo < hi:
            return lo + np.argmax(np.abs(original_data[lo:hi]))
        return None

    found_peaks = []

    # Strategy 1: multi-band energy detection with several thresholds.
    for low_freq, high_freq in [(5, 15), (8, 24), (10, 30), (12, 40)]:
        energy = bandpass(low_freq, high_freq) ** 2
        window_size = int(0.15 * sampling_rate)
        integrated = np.convolve(energy, np.ones(window_size) / window_size, mode='same')

        mean_val = np.mean(integrated)
        std_val = np.std(integrated)

        for std_factor in (0.1, 0.2, 0.3):
            band_peaks, _ = sp_signal.find_peaks(
                integrated,
                height=mean_val + std_factor * std_val,
                distance=int(0.2 * sampling_rate),
            )
            for peak in band_peaks:
                sharpness = sharpness_near(peak)
                if sharpness is None or not (sharpness > sharpness_threshold):
                    continue
                refined = snap_to_extremum(peak, int(0.1 * sampling_rate))
                if refined is not None:
                    found_peaks.append(refined)

    # Strategy 2: prominence-based peaks of the signal itself.
    prominent_peaks, _ = sp_signal.find_peaks(
        original_data,
        distance=int(0.2 * sampling_rate),
        prominence=0.02,
    )
    for peak in prominent_peaks:
        sharpness = sharpness_near(peak)
        if sharpness is not None and sharpness > sharpness_threshold * 0.8:
            found_peaks.append(peak)

    # Strategy 3: peaks of the squared first derivative.
    diff_squared = np.append(np.diff(original_data) ** 2, 0)
    steep_peaks, _ = sp_signal.find_peaks(
        diff_squared,
        height=np.mean(diff_squared) + 0.5 * np.std(diff_squared),
        distance=int(0.15 * sampling_rate),
    )
    for peak in steep_peaks:
        sharpness = sharpness_near(peak)
        if sharpness is not None and sharpness > sharpness_threshold * 0.7:
            refined = snap_to_extremum(peak, int(0.08 * sampling_rate))
            if refined is not None:
                found_peaks.append(refined)

    if len(found_peaks) == 0:
        return np.array([])

    # Merge: unique + sorted, then resolve neighbours that are too close.
    merged = np.sort(np.unique(found_peaks))
    min_distance = int(0.15 * sampling_rate)
    keep = np.ones(len(merged), dtype=bool)

    for i in range(1, len(merged)):
        earlier, later = merged[i - 1], merged[i]
        if later - earlier >= min_distance:
            continue

        earlier_sharp = sharpness_near(earlier)
        later_sharp = sharpness_near(later)
        earlier_sharp = 0 if earlier_sharp is None else earlier_sharp
        later_sharp = 0 if later_sharp is None else later_sharp

        # Each peak is compared with its immediate sorted neighbour, whether
        # or not that neighbour itself survived.
        if later_sharp > earlier_sharp:
            keep[i - 1] = False
        else:
            keep[i] = False

    return merged[keep]


def qrs_detect(data, sampling_rate, segment_duration=None, raw_segment=None,
               flatline_var_threshold=0.005):
    """Detect R-peaks in an ECG segment and estimate the heart rate.

    Pipeline:

    1. Optional flatline gate: if ``raw_segment`` is given and its variance is
       below ``flatline_var_threshold``, the segment is reported as having no
       beats.
    2. Stream 1 finds candidates on the integrated squared derivative of the
       rectified 8-24 Hz band-passed signal.
    3. Stream 2 confirms candidates whose local 10-40 Hz "sharpness" exceeds
       the 94th percentile, then snaps them to the largest absolute sample.
    4. Gaps longer than 1.4 x the median RR interval are re-searched with
       relaxed thresholds (only when the median RR is below 1.5 s). A gap
       candidate must reach 45 % of the median peak height.
    5. If the peak count lies outside what 30-180 BPM would produce, the
       result is replaced by ``robust_qrs_detect_internal`` followed by
       ``remove_t_waves`` and ``amplitude_based_filtering``. Otherwise
       ``remove_t_waves`` is applied to the peaks found so far.
    6. BPM is ``60 / mean(RR)`` over RR intervals between 0.2 s and 4.0 s.

    Args:
        data: 1-D signal array to analyse (must support ``.copy()``).
        sampling_rate: Sampling frequency in Hz.
        segment_duration: Segment length in seconds; derived from ``data``
            when omitted.
        raw_segment: Optional un-normalised samples used only for the
            flatline variance gate.
        flatline_var_threshold: Variance below which ``raw_segment`` is
            treated as a flatline.

    Returns:
        tuple: ``(data, peaks, bpm, peaks)``. The second and fourth items are
        the same array of peak indices. ``bpm`` is 0 when it cannot be
        computed.
    """
    if raw_segment is not None:
        var_raw = np.var(raw_segment)
        if var_raw < flatline_var_threshold:
            # Report the threshold actually applied; the message used to
            # hard-code a stale value (0.0095) that no longer matched the test.
            print(f"Raw variance {var_raw:.6f} < {flatline_var_threshold} → treating as asystole / flatline")
            return data, np.array([]), 0.0, np.array([])

    # baseline_wander() is currently not applied; detection runs on the
    # input exactly as given.
    data_clean = data

    original_data = data_clean.copy()
    nyquist = 0.5 * sampling_rate

    # --- STREAM 1: Standard Detection ---
    b, a = sp_signal.butter(2, [8 / nyquist, 24 / nyquist], btype='band')
    filtered_standard = sp_signal.filtfilt(b, a, data_clean)

    diff = np.append(np.diff(np.abs(filtered_standard)), 0)
    squared = diff ** 2

    window_size = int(0.15 * sampling_rate)
    integrated = np.convolve(squared, np.ones(window_size) / window_size, mode='same')

    mean_val = np.mean(integrated)
    std_val = np.std(integrated)
    threshold = mean_val + 0.20 * std_val

    candidates, _ = sp_signal.find_peaks(
        integrated,
        height=threshold,
        distance=int(0.12 * sampling_rate)
    )

    # --- STREAM 2: Sharpness Validator ---
    b2, a2 = sp_signal.butter(2, [10 / nyquist, 40 / nyquist], btype='band')
    filtered_strict = sp_signal.filtfilt(b2, a2, data_clean)
    diff_strict = np.append(np.diff(np.abs(filtered_strict)), 0)
    strict_score = diff_strict ** 2

    if len(strict_score) > 0:
        sharpness_threshold = np.percentile(strict_score, 94)
    else:
        sharpness_threshold = 0

    confirmed_peaks = []
    search_window = int(0.18 * sampling_rate)

    for peak in candidates:
        start_check = max(0, peak - search_window)
        end_check = min(len(strict_score), peak + search_window)
        if start_check >= end_check:
            continue

        local_sharpness = np.max(strict_score[start_check:end_check])

        if local_sharpness > sharpness_threshold:
            local_segment = original_data[start_check:end_check]
            if len(local_segment) > 0:
                # Snap the candidate to the largest absolute sample nearby.
                local_max_idx = np.argmax(np.abs(local_segment))
                confirmed_peaks.append(start_check + local_max_idx)

    r_peaks = np.array(confirmed_peaks)

    # Remove close peaks
    min_dist = int(0.15 * sampling_rate)
    r_peaks = remove_close_peaks(r_peaks, original_data, min_dist)

    cleaned_r = np.sort(np.array([x for x in r_peaks if not (isinstance(x, float) and np.isnan(x))]))

    # --- Gap filling with amplitude guardrails ---
    if len(cleaned_r) >= 2:
        # Reference height: median of the peaks found so far.
        existing_heights = np.abs(original_data[cleaned_r.astype(int)])
        median_r_height = np.median(existing_heights) if len(existing_heights) > 0 else 0

        rr_intervals = np.diff(cleaned_r) / sampling_rate
        median_rr = np.median(rr_intervals) if len(rr_intervals) > 0 else 1.0
        new_peaks = list(cleaned_r)

        # Only fill gaps if median_rr suggests a normal rhythm (< 1.5 s).
        # With a slow rhythm, long gaps are expected and left alone.
        if median_rr < 1.5:
            for i in range(len(rr_intervals)):
                if rr_intervals[i] > 1.4 * median_rr:
                    gap_start = cleaned_r[i]
                    gap_end = cleaned_r[i + 1]
                    if gap_start >= gap_end:
                        continue

                    gap_integrated = integrated[gap_start:gap_end]
                    # Lower threshold slightly for gap search
                    low_thresh = mean_val * 0.6

                    gap_candidates, _ = sp_signal.find_peaks(
                        gap_integrated,
                        height=low_thresh,
                        distance=int(0.10 * sampling_rate)
                    )

                    for gc in gap_candidates:
                        abs_idx = gap_start + gc
                        sw_start = max(0, abs_idx - search_window)
                        sw_end = min(len(strict_score), abs_idx + search_window)
                        if sw_start >= sw_end:
                            continue

                        # 1. Relaxed sharpness check
                        local_sharp_max = np.max(strict_score[sw_start:sw_end])
                        if local_sharp_max > sharpness_threshold * 0.4:

                            # 2. Refine position to the largest absolute sample
                            local_segment = original_data[sw_start:sw_end]
                            refine_idx = np.argmax(np.abs(local_segment))
                            candidate_peak = sw_start + refine_idx

                            # 3. Amplitude check: a sharp but short deflection
                            # (e.g. a P-wave) must not be accepted as a beat.
                            candidate_amp = np.abs(original_data[candidate_peak])
                            if candidate_amp > 0.45 * median_r_height:
                                new_peaks.append(candidate_peak)

        new_peaks = np.sort(np.unique(new_peaks))
        cleaned_r = remove_close_peaks(new_peaks, original_data, min_dist)

    # Determine expected peak count range
    if segment_duration is None:
        segment_duration = len(data_clean) / sampling_rate

    # Plausible beat counts for 30-180 BPM over this duration.
    min_expected_peaks = int(30 / 60 * segment_duration)
    max_expected_peaks = int(180 / 60 * segment_duration)

    # Fallback to the robust detector only if counts are outside that range.
    if len(cleaned_r) < min_expected_peaks or len(cleaned_r) > max_expected_peaks:
        initial_peaks = robust_qrs_detect_internal(data_clean, sampling_rate)
        initial_peaks = remove_t_waves(data_clean, initial_peaks, sampling_rate)
        cleaned_r, _ = amplitude_based_filtering(data_clean, initial_peaks, "Segment")
    else:
        cleaned_r = remove_t_waves(data_clean, cleaned_r, sampling_rate)

    # Calculate BPM
    bpm = 0
    if len(cleaned_r) > 1:
        rr_intervals = np.diff(cleaned_r) / sampling_rate

        # Valid intervals are kept wide to allow for slow rhythms and pauses.
        valid_rr = rr_intervals[(rr_intervals > 0.2) & (rr_intervals < 4.0)]

        if len(valid_rr) > 0:
            mean_rr = np.mean(valid_rr)
            bpm = 60 / mean_rr if mean_rr > 0 else 0

    return data, cleaned_r, bpm, cleaned_r



def process_ecg_segments(ecg_data, sampling_rate, num_segments=7, min_segment_length=3500,
                         raw_ecg=None):
    """Run QRS detection on evenly spaced, possibly overlapping windows.

    ``num_segments`` windows of ``min_segment_length`` samples are spread
    from the start to the end of ``ecg_data``. Windows are clamped to the
    signal, and a window shorter than 100 samples yields an empty result.

    Args:
        ecg_data: 1-D signal array passed to the detector.
        sampling_rate: Sampling frequency in Hz.
        num_segments: Number of windows to analyse.
        min_segment_length: Window length in samples.
        raw_ecg: Un-normalised signal aligned with ``ecg_data``; used for the
            quality score and the flatline gate in ``qrs_detect``. This
            function previously read an undeclared module-level ``raw_ecg``.
            When the argument is omitted that global is still used if it
            exists. If no raw signal is available at all, quality is assessed
            on the window itself and the flatline gate is skipped.

    Returns:
        list[dict]: One dict per window with keys ``segment_num``,
        ``start_idx``, ``end_idx``, ``ecg_filtered``, ``r_peaks``, ``bpm``,
        ``cleaned_r`` and ``ecg_raw``. Peak indices are relative to
        ``ecg_data``, not to the window.
    """
    if raw_ecg is None:
        # Backward compatibility with callers that relied on the global.
        raw_ecg = globals().get('raw_ecg')

    def empty_result(segment_num, start_idx, end_idx, segment):
        """Result entry for a window that is not analysed."""
        return {
            'segment_num': segment_num,
            'start_idx': start_idx,
            'end_idx': end_idx,
            'ecg_filtered': np.array([]),
            'r_peaks': np.array([]),
            'bpm': 0,
            'cleaned_r': np.array([]),
            'ecg_raw': segment
        }

    max_len = len(ecg_data)

    if num_segments > 1:
        window_step = round((max_len - min_segment_length) / (num_segments - 1))
    else:
        window_step = 0

    results = []

    for i in range(num_segments):
        start_idx = i * window_step
        end_idx = start_idx + min_segment_length

        # Clamp the window so it stays inside the signal.
        if end_idx > max_len:
            start_idx = max_len - min_segment_length
            end_idx = max_len

        if start_idx < 0:
            start_idx = 0
            end_idx = min(min_segment_length, max_len)

        segment = ecg_data[start_idx:end_idx]

        if len(segment) < 100:
            results.append(empty_result(i + 1, start_idx, end_idx, segment))
            continue

        segment_duration = len(segment) / sampling_rate

        # Real (un-normalised) amplitudes for this window, when available.
        raw_segment = raw_ecg[start_idx:end_idx] if raw_ecg is not None else None

        quality, sqi, details = assess_ecg_quality(
            raw_segment if raw_segment is not None else segment,
            sampling_rate
        )

        print(f"Segment {i+1} SQI: {sqi:.2f} → {quality}")

        # NOTE(review): this gate only fires for this exact sentinel label;
        # presumably it is effectively disabled — confirm against the labels
        # assess_ecg_quality actually returns.
        if quality == "BADDDDDDDDDDDDDDDDD":
            results.append(empty_result(i + 1, start_idx, end_idx, segment))
            continue

        ecg_filtered, r_peaks, bpm, cleaned_r = qrs_detect(
            segment,
            sampling_rate,
            segment_duration,
            raw_segment=raw_segment
        )

        print(f"Segment {i+1}: Detected {len(r_peaks)} R-peaks, BPM: {bpm:.1f}")

        # Shift window-relative peak indices to full-signal indices.
        adjusted_r_peaks = r_peaks + start_idx if len(r_peaks) > 0 else np.array([])
        adjusted_cleaned_r = np.array(cleaned_r) + start_idx if len(cleaned_r) > 0 else np.array([])

        results.append({
            'segment_num': i + 1,
            'start_idx': start_idx,
            'end_idx': end_idx,
            'ecg_filtered': ecg_filtered,
            'r_peaks': adjusted_r_peaks,
            'bpm': bpm,
            'cleaned_r': adjusted_cleaned_r,
            'ecg_raw': segment
        })

    return results

import numpy as np

def compute_ecg_stats(signal, fs=500):
    """Compute summary statistics for an ECG segment.

    Args:
        signal: Sequence or array of samples. Non-floating input (lists,
            integer arrays) is promoted to float64 first, so that squaring
            for the RMS neither fails on plain lists nor overflows narrow
            integer dtypes such as int16.
        fs: Sampling frequency in Hz, used for ``duration_s``.

    Returns:
        dict: ``nsamples``, ``mean``, ``std``, ``var``, ``min``, ``max``,
        ``median``, ``rms`` and ``duration_s``. For empty input the
        statistics are NaN, ``nsamples`` is 0 and ``duration_s`` is 0.0.
    """
    samples = np.asarray(signal)

    if len(samples) == 0:
        return {
            'nsamples': 0,
            'mean': np.nan,
            'std': np.nan,
            'var': np.nan,
            'min': np.nan,
            'max': np.nan,
            'median': np.nan,
            'rms': np.nan,
            'duration_s': 0.0
        }

    # Floating inputs keep their dtype, so their results are unchanged.
    if samples.dtype.kind != 'f':
        samples = samples.astype(np.float64)

    return {
        'nsamples': len(samples),
        'mean': float(np.mean(samples)),
        'std': float(np.std(samples)),
        'var': float(np.var(samples)),
        'min': float(np.min(samples)),
        'max': float(np.max(samples)),
        'median': float(np.median(samples)),
        'rms': float(np.sqrt(np.mean(samples ** 2))),
        'duration_s': len(samples) / fs
    }


def format_stats_text(stats, prefix=""):
    """Render a stats dict as an aligned multi-line label for plot annotation.

    ``prefix`` is prepended to the first (duration) line only. ``stats`` must
    provide ``duration_s``, ``nsamples``, ``mean``, ``std``, ``var``, ``min``,
    ``max``, ``median`` and ``rms``.
    """
    header = f"{prefix}Duration: {stats['duration_s']:.2f} s"
    body = (
        f"Samples:   {stats['nsamples']}\n"
        f"Mean:      {stats['mean']:.4f}\n"
        f"Std:       {stats['std']:.4f}\n"
        f"Var:       {stats['var']:.6f}\n"
        f"Min / Max: {stats['min']:.4f} / {stats['max']:.4f}\n"
        f"Median:    {stats['median']:.4f}\n"
        f"RMS:       {stats['rms']:.4f}"
    )
    return header + "\n" + body


def plot_ecg_segments(ecg_data, sampling_rate, results, title="ECG Segments with R-peaks and BPM", raw_ecg=None):
    """Plot each analysed segment with its R-peaks and a statistics box.

    One subplot is drawn per entry of ``results`` (dicts with
    ``segment_num``, ``start_idx``, ``end_idx``, ``bpm``, ``r_peaks`` and
    ``ecg_raw``). Statistics are computed from ``raw_ecg`` when it is given,
    otherwise from the segment stored in the result. A one-line-per-segment
    summary is printed after the figure is shown.

    Args:
        ecg_data: Full signal; supplies the time axis, the global stats in the
            figure title and the values at the R-peak markers.
        sampling_rate: Sampling frequency in Hz.
        results: Per-segment result dicts.
        title: Figure title (global stats are appended to it).
        raw_ecg: Optional un-normalised signal aligned with ``ecg_data``.
    """
    num_segments = len(results)
    fig, axes = plt.subplots(num_segments, 1, figsize=(15, 3.5 * num_segments), sharex=False)

    # plt.subplots returns a bare Axes (not a sequence) for a single row.
    if num_segments == 1:
        axes = [axes]

    time = np.arange(len(ecg_data)) / sampling_rate
    use_raw = raw_ecg is not None

    global_stats = compute_ecg_stats(ecg_data, sampling_rate)
    fig.suptitle(f"{title}\nFull signal stats: {global_stats['duration_s']:.1f}s | "
                 f"mean={global_stats['mean']:.4f}  std={global_stats['std']:.4f}",
                 fontsize=13, y=0.98)

    for ax, result in zip(axes, results):
        segment_num = result['segment_num']
        start_idx, end_idx = result['start_idx'], result['end_idx']
        bpm = result['bpm']
        r_peaks = result['r_peaks']

        segment_time = time[start_idx:end_idx]
        segment_data = result['ecg_raw']

        ax.plot(segment_time, segment_data, 'b-', alpha=0.8, linewidth=1.1, label='ECG')

        if len(r_peaks) > 0:
            # Peak indices refer to the full signal.
            ax.plot(r_peaks / sampling_rate, ecg_data[r_peaks.astype(int)],
                    'ro', markersize=7, label='R-peaks', alpha=0.85)

        # Per-segment statistics box (raw amplitudes preferred).
        if use_raw:
            seg_stats = compute_ecg_stats(raw_ecg[start_idx:end_idx], sampling_rate)
            prefix = "Raw "
        else:
            seg_stats = compute_ecg_stats(segment_data, sampling_rate)
            prefix = ""
        stats_text = format_stats_text(seg_stats, prefix + f"Seg {segment_num}  ")
        stats_text += f"\nBPM:       {bpm:.1f}"

        ax.text(0.02, 0.98, stats_text,
                transform=ax.transAxes,
                fontsize=9.5,
                verticalalignment='top',
                bbox=dict(facecolor='white', alpha=0.82, edgecolor='gray', boxstyle='round,pad=0.4'))

        ax.set_title(f'Segment {segment_num}: {start_idx:,} – {end_idx:,}  |  BPM: {bpm:.1f}')
        ax.set_ylabel('Amplitude (norm)')
        ax.grid(True, alpha=0.35, linestyle='--')
        ax.set_xlim([segment_time[0], segment_time[-1]])
        ax.legend(loc='upper right', fontsize=9)

    axes[-1].set_xlabel('Time (seconds)')
    plt.tight_layout(rect=[0, 0, 1, 0.96])   # leave room for the suptitle
    plt.show()

    # Console summary
    rule = "═" * 70
    print(rule)
    print("ECG SEGMENT STATISTICS SUMMARY")
    print(rule)
    for res in results:
        if use_raw:
            s = compute_ecg_stats(raw_ecg[res['start_idx']:res['end_idx']], sampling_rate)
            prefix = "Raw "
        else:
            s = compute_ecg_stats(res['ecg_raw'], sampling_rate)
            prefix = "Norm "
        print(f"Segment {res['segment_num']:2d} | {s['duration_s']:5.2f}s | "
              f"mean={s['mean']:8.4f}  std={s['std']:7.4f}  BPM={res['bpm']:5.1f} ({prefix.strip()})"
            )
    print(rule)
    
    
def plot_full_ecg(ecg_data, sampling_rate, title="Full ECG Signal Analysis", raw_ecg=None):
    """Run QRS detection on the whole signal and plot it with its R-peaks.

    Args:
        ecg_data: 1-D signal array to analyse and plot.
        sampling_rate: Sampling frequency in Hz.
        title: Plot title; global BPM and peak count are appended.
        raw_ecg: Optional un-normalised signal aligned with ``ecg_data``.
            When given, its first ``len(ecg_data)`` samples feed the flatline
            gate of ``qrs_detect`` and the printed statistics. When omitted,
            the gate is skipped and statistics come from ``ecg_data``.
    """
    # Slicing None used to raise TypeError with the default raw_ecg=None,
    # so the raw part is only taken when a raw signal was actually supplied.
    raw_part = raw_ecg[:len(ecg_data)] if raw_ecg is not None else None

    _, r_peaks, global_bpm, _ = qrs_detect(
        ecg_data,
        sampling_rate,
        raw_segment=raw_part
    )

    if raw_part is not None:
        stats = compute_ecg_stats(raw_part, sampling_rate)
        prefix = "Raw "
    else:
        stats = compute_ecg_stats(ecg_data, sampling_rate)
        prefix = ""

    plt.figure(figsize=(20, 6))  # wide figure keeps long recordings readable

    # Create time axis
    time_axis = np.arange(len(ecg_data)) / sampling_rate

    # Plot the signal
    plt.plot(time_axis, ecg_data, 'b-', linewidth=0.8, alpha=0.8, label='Filtered ECG')

    # Plot the peaks
    if len(r_peaks) > 0:
        # Safety check: ignore peaks beyond the end of the signal.
        valid_peaks = r_peaks[r_peaks < len(ecg_data)].astype(int)

        peak_times = valid_peaks / sampling_rate
        peak_values = ecg_data[valid_peaks]

        plt.plot(peak_times, peak_values, 'ro', markersize=4, label='R-peaks')

        # Annotate every 5th peak with its time to help navigation.
        for i, (t, v) in enumerate(zip(peak_times, peak_values)):
            if i % 5 == 0:
                plt.annotate(f'{t:.1f}s', (t, v), xytext=(0, 10),
                             textcoords='offset points', ha='center', fontsize=8, color='red')

    plt.title(f"{title} | Global BPM: {global_bpm:.1f} | Total Peaks: {len(r_peaks)}")
    plt.xlabel("Time (seconds)")
    plt.ylabel("Normalized Amplitude")
    plt.legend(loc='upper right')
    plt.grid(True, which='both', alpha=0.5)
    plt.tight_layout()
    plt.show()

    print(f"Global Analysis: {len(r_peaks)} peaks detected over {len(ecg_data)/sampling_rate:.2f} seconds.")
    print(f"{prefix}Full signal stats →  mean={stats['mean']:.4f}  std={stats['std']:.4f}  var={stats['var']:.6f}")


# def assess_ecg_quality(ecg_raw, fs=500):
#     ecg = np.asarray(ecg_raw)

#     if len(ecg) < fs:
#         return "BADDDDDDDDDDDDDDDDD", 0.0, {"reason": "Too short"}

#     # ── 1. Variance (flatline / saturation)
#     var = np.var(ecg)

#     # ── 2. Baseline wander (0–0.5 Hz)
#     f, pxx = sp_signal.welch(ecg, fs=fs, nperseg=2048)
#     baseline_power = np.sum(pxx[(f >= 0.05) & (f <= 0.5)])
#     total_power = np.sum(pxx)
#     baseline_ratio = baseline_power / total_power if total_power > 0 else 1.0

#     # ── 3. Powerline noise (50 Hz)
#     power_50hz = np.sum(pxx[(f >= 48) & (f <= 52)])
#     power_5_40hz = np.sum(pxx[(f >= 5) & (f <= 40)])
#     powerline_ratio = power_50hz / power_5_40hz if power_5_40hz > 0 else 1.0

#     # ── 4. QRS energy dominance
#     qrs_band_power = np.sum(pxx[(f >= 8) & (f <= 25)])
#     qrs_ratio = qrs_band_power / total_power if total_power > 0 else 0

#     # ── 5. Clipping detection
#     clip_ratio = np.mean((ecg == np.max(ecg)) | (ecg == np.min(ecg)))

#     # ── Scoring (0–1)
#     score = 1.0
#     score -= 0.4 if var < 0.005 else 0
#     score -= min(baseline_ratio * 2, 0.3)
#     score -= min(powerline_ratio * 2, 0.3)
#     score -= 0.2 if clip_ratio > 0.02 else 0
#     score += min(qrs_ratio * 1.5, 0.3)

#     score = np.clip(score, 0, 1)

#     if score >= 0.75:  # 75
#         quality = "GOOD"
#     # elif score >= 0.45:
#     #     quality = "MARGINAL"
#     else:
#         quality = "BADDDDDDDDDDDDDDDDD"

#     details = {
#         "variance": float(var),
#         "baseline_ratio": float(baseline_ratio),
#         "powerline_ratio": float(powerline_ratio),
#         "qrs_ratio": float(qrs_ratio),
#         "clip_ratio": float(clip_ratio)
#     }

#     return quality, score, details

# BADDDDDDDDDDDDDDDDD


def assess_ecg_quality(ecg_raw, fs=500):
    """Score an ECG trace heuristically and label it GOOD, MARGINAL or BAD.

    The score starts at 1.0 and penalties are subtracted for:

    * extrema patterns: uniformly small steps between consecutive extrema
      (-0.3), a long run of same-direction steps (-0.4), very regular
      spacing between extrema (-0.2);
    * a dominant straight-line trend over the whole trace (-0.5);
    * variance below 0.005 (-0.4);
    * baseline-band power, 0.05-0.5 Hz (twice the power ratio, capped at 0.3);
    * power in 48-52 Hz relative to 5-40 Hz (twice the ratio, capped at 0.3);
    * more than 2 % of samples sitting exactly on the min or max (-0.2).

    The 8-25 Hz power ratio (``qrs_ratio``) is reported but does not
    influence the score.

    Args:
        ecg_raw: 1-D sequence or array of samples.
        fs: Sampling rate in Hz.

    Returns:
        Tuple ``(quality, score, details)``: ``quality`` is ``"GOOD"``
        (score >= 0.7), ``"MARGINAL"`` (score >= 0.4) or ``"BAD"``;
        ``score`` is a plain ``float`` in [0, 1]; ``details`` is a dict of
        the individual metrics. On an early rejection (fewer than ``fs``
        samples, or non-finite samples) ``details`` only holds ``"reason"``.
    """
    ecg = np.asarray(ecg_raw)
    # Promote integer/bool input to float: negating an unsigned array below
    # would wrap around instead of mirroring the signal.
    if not np.issubdtype(ecg.dtype, np.floating):
        ecg = ecg.astype(float)

    if len(ecg) < fs:
        return "BAD", 0.0, {"reason": "Too short"}

    # NaN/inf would poison the variance, the PSD and the line fit below.
    if not np.all(np.isfinite(ecg)):
        return "BAD", 0.0, {"reason": "Non-finite samples"}

    signal_range = np.max(ecg) - np.min(ecg)

    # ── 1. Variance (flatline / saturation)
    var = np.var(ecg)

    # ── 2. Extrema-based pattern checks ──
    # Local maxima and minima, at least 50 samples apart within each set
    # (50 samples is 100 ms at the default fs of 500 Hz).
    maxima, _ = sp_signal.find_peaks(ecg, distance=50)
    minima, _ = sp_signal.find_peaks(-ecg, distance=50)
    all_extrema = np.sort(np.concatenate([maxima, minima]))

    monotonic_score = 1.0
    monotonic_details = {
        "peak_difference_pattern": "variable",
        "time_regularity": "normal",
        "linear_trend": "none",
    }

    if len(all_extrema) >= 3:  # Need at least 3 extrema to analyze a pattern
        value_steps = np.diff(ecg[all_extrema])
        amplitude_diffs = np.abs(value_steps)
        time_diffs = np.diff(all_extrema) / fs  # in seconds

        # 2a. Steps between extrema that are both uniform and small relative
        #     to the full signal range.
        if len(amplitude_diffs) > 2:
            amp_std = np.std(amplitude_diffs)
            amp_mean = np.mean(amplitude_diffs)
            if amp_std < 0.1 * amp_mean and amp_mean < 0.05 * signal_range:
                monotonic_score -= 0.3
                monotonic_details["peak_difference_pattern"] = "consistent_small"

        # 2b. Long run of steps in the same direction. The counter counts
        #     adjacent equal-sign pairs, so a value of 4 corresponds to five
        #     consecutive steps with the same sign.
        step_signs = np.sign(value_steps)
        if len(step_signs) > 3:
            longest_run = 0
            current_run = 0
            for prev_sign, next_sign in zip(step_signs[:-1], step_signs[1:]):
                if next_sign == prev_sign:
                    current_run += 1
                    longest_run = max(longest_run, current_run)
                else:
                    current_run = 0

            if longest_run >= 4:
                monotonic_score -= 0.4
                monotonic_details["peak_difference_pattern"] = "strong_trend"

        # 2c. Spacing between extrema that is almost perfectly regular.
        if len(time_diffs) > 3:
            time_std = np.std(time_diffs)
            time_mean = np.mean(time_diffs)
            if time_std < 0.1 * time_mean:
                monotonic_score -= 0.2
                monotonic_details["time_regularity"] = "high"

    # ── 3. Baseline wander (0.05–0.5 Hz)
    # Clamp the segment length ourselves: SciPy would do the same for short
    # inputs, but it emits a warning on every call.
    f, pxx = sp_signal.welch(ecg, fs=fs, nperseg=min(2048, len(ecg)))
    baseline_power = np.sum(pxx[(f >= 0.05) & (f <= 0.5)])
    total_power = np.sum(pxx)
    baseline_ratio = baseline_power / total_power if total_power > 0 else 1.0

    # ── 4. Powerline noise (50 Hz)
    power_50hz = np.sum(pxx[(f >= 48) & (f <= 52)])
    power_5_40hz = np.sum(pxx[(f >= 5) & (f <= 40)])
    powerline_ratio = power_50hz / power_5_40hz if power_5_40hz > 0 else 1.0

    # ── 5. QRS energy dominance (reported only, not scored)
    qrs_band_power = np.sum(pxx[(f >= 8) & (f <= 25)])
    qrs_ratio = qrs_band_power / total_power if total_power > 0 else 0

    # ── 6. Clipping detection: share of samples exactly at the min or max
    clip_ratio = np.mean((ecg == np.max(ecg)) | (ecg == np.min(ecg)))

    # ── 7. Overall linear trend ──
    # trend_strength is the rise of the fitted line over the whole trace,
    # expressed as a fraction of the signal range (epsilon avoids 0/0).
    x = np.arange(len(ecg))
    slope, intercept = np.polyfit(x, ecg, 1)
    trend_residuals = ecg - (slope * x + intercept)
    trend_strength = np.abs(slope) * len(ecg) / (signal_range + 1e-10)

    # Strong trend that also explains most of the signal's spread.
    if trend_strength > 0.5 and np.std(trend_residuals) < 0.2 * np.std(ecg):
        monotonic_score -= 0.5
        monotonic_details["linear_trend"] = "strong"

    # ── Combined scoring (0–1) ──
    score = max(0, 1.0 - (1 - monotonic_score))
    score -= 0.4 if var < 0.005 else 0
    score -= min(baseline_ratio * 2, 0.3)
    score -= min(powerline_ratio * 2, 0.3)
    score -= 0.2 if clip_ratio > 0.02 else 0

    # Plain float so both return paths hand back the same type.
    score = float(np.clip(score, 0, 1))

    # Quality classification
    if score >= 0.7:
        quality = "GOOD"
    elif score >= 0.4:
        quality = "MARGINAL"
    else:
        quality = "BAD"

    details = {
        "variance": float(var),
        "baseline_ratio": float(baseline_ratio),
        "powerline_ratio": float(powerline_ratio),
        "qrs_ratio": float(qrs_ratio),
        "clip_ratio": float(clip_ratio),
        "monotonic_pattern": monotonic_details["peak_difference_pattern"],
        "time_regularity": monotonic_details["time_regularity"],
        "linear_trend": monotonic_details["linear_trend"],
        "score": score,
        "extrema_count": len(all_extrema),
        "trend_strength": float(trend_strength),
    }

    return quality, score, details



# ============================================================================

 
# # input_json = r"simulator\contec\trigeminy_1756103085272.json" 
# # input_json = r"simulator\contec\asystl_1756103447146.json" 
# # input_json = r"simulator\contec\1d av_1756104504294.json"  
# # input_json = r"simulator\contec\3d av_1756104633918.json"  
# # input_json = r"simulator\contec\280bpm_1756100716422.json" 
# # input_json = r"simulator\contec\av sequence_1756106676125.json"  
# # input_json = r"simulator\contec\dmnd freq_1756106571373.json"  
# #    
# # input_json = r"simulator\fluke\trigeminy_1754543043205.json"   
# # input_json = r"simulator\fluke\3d av_1754545068278.json"   
# # input_json = r"simulator\fluke\asystole_1754544406847.json"   

# # input_json = r"simulator\fluke\80bpm_1754288606201.json"   
# input_json = r"0_bpm\asystole_jan_2_1770026113214.json"   

# with open(input_json, 'r') as file:
#     file_data = json.load(file)


# ============================================================================


# input_json = r"v01_prob\teton_ecg.ecgdatas.json"  
# with open(input_json, 'r') as file:
#     all_id_data = json.load(file)

# file_data = all_id_data[3]['ecgValue']   



# # input_json = r"bpms\afib_1766471694144.json"
# # input_json = r"bpms\bigeminy_1766467666407.json"
# # input_json = r"bpms\pvc 6_1766467718685.json"    
# # input_json = r"bpms\tri_1766467618314.json"
# # input_json = r"v01_prob/220_1767858669130.json"
# # input_json = r"v01_prob/240bpm_1767858615562.json"
# # input_json = r"v01_prob\25 contec_1768375918389.json"
# # input_json = r"v01_prob\30bpm contec_1768375716454.json"
# # input_json = r"v01_prob\2d av_1754545008828.json"  
# input_json = r"v01_prob\3rd_davb_1768554217066.json"  

# with open(input_json, 'r') as file:
#     file_data = json.load(file)




input_json = r"exception\L2_1759207950416.json"   
# input_json = r"issues\L2_1757064122874.json"  
# input_json = r"v01_prob\run 5 pvc.json"  
# input_json = r"issues\L2_1757579288752.json"
# input_json = r"issues\L2_1757737806463.json"  
# input_json = r"v01_prob\L2_1765984517025.json"  
# input_json = r"1st-last-peaks\L2_1759908627949.json"
# input_json = r"1st-last-peaks\L2_1759908888619.json"

# input_json = r"issues2\1757998097068\L2_1757998097068.json"
# input_json = r"issues2\1758943573744\L2_1758943573744.json"
# input_json = r"issues2\1759059066184\L2_1759059066184.json"     ###########
# input_json = r"issues2\1759117739887\L2_1759117739887.json"
# input_json = r"issues2\1759118709079\L2_1759118709079.json"
# input_json = r"issues2\1759202739736\L2_1759202739736.json"
# input_json = r"issues2\1759639059357\L2_1759639059357.json"
# input_json = r"issues2\L2_1759208381248.json"

# The selected file is read as raw bytes and decoded as back-to-back
# little-endian float64 values. Trailing bytes that do not make up a whole
# 8-byte value are ignored.
doubles = []
with open(input_json, "rb") as f:
    payload = f.read()

whole_bytes = (len(payload) // 8) * 8
doubles.extend(sample for (sample,) in struct.iter_unpack("<d", payload[:whole_bytes]))

# Wrap the samples under the key that data_process() looks up.
file_data = {'dataL2': doubles}   




# def decrypt(input_file):
#     """Decrypt encrypted JSON file (optional - commented out in your version)"""
#     private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
#     cipher = AES.new(private_key.encode('utf-8'), AES.MODE_ECB)
#     with open(input_file, 'rb') as f:
#         encrypted_data = f.read()
#     enc = base64.b64decode(encrypted_data[24:])
#     data = unpad(cipher.decrypt(enc), 16)
#     decoded_string = data.decode('utf-8')
#     return json.loads(decoded_string)

# # # input_json = r"NHF2\DATA_1750689015865.json"
# # # input_json = r"NHF2\DATA_1750689460556.json"
# # # input_json = r"NHF2\DATA_1750851207409.json"
# # # input_json = r"NHF2\DATA_1750858856842.json"
# # # input_json = r"NHF2\DATA_1750862721789.json"
# # input_json = r"NHF2\DATA_1750996455820.json"
# # file_data = decrypt(input_json)

# # input_json = r"NHF\DATA_1752067426678.json"  #####
# # input_json = r"NHF\DATA_1752121970835.json"  ########
# input_json = r"NHF\DATA_1754709586876.json"  #####
# # input_json = r"NHF\DATA_1754729551054.json"
# file_data = decrypt(input_json)



# def decrypt(input_file):
#     """Decrypt encrypted CSV file using AES ECB mode"""
#     Private_key = "Msz377xMbcn++vrcDel9vxOuEss8fsWO"
    
#     cipher = AES.new(Private_key.encode(), AES.MODE_ECB)
#     with open(input_file, 'rb') as f:
#         encrypted_data = f.read()

#     enc = base64.b64decode(encrypted_data[24:])
#     cipher = AES.new(Private_key.encode('utf-8'), AES.MODE_ECB)
#     data = unpad(cipher.decrypt(enc), 16)

#     decoded_string = data.decode('utf-8')
#     data_list = decoded_string.split(",")
#     float_list = [float(x) for x in data_list]

#     return float_list

# selected_path = "v01_prob\ECG_1735798172211.csv"  ####
# # selected_path = "v01_prob\ECG_L2_1738637533455.csv"
# file_data = decrypt(selected_path)
# file_data = {'dataL2': file_data}



def low_pass_filter(data):
    """Apply the low-pass filter to ``data``, best effort.

    Runs ``scipy.signal.filtfilt`` with the coefficients ``b_lp`` / ``a_lp``,
    which this function expects to find at module level.

    Args:
        data: Samples to filter.

    Returns:
        The filtered samples, or ``data`` itself, unchanged, if filtering
        fails for any reason.
    """
    try:
        return sp_signal.filtfilt(b_lp, a_lp, data)
    except Exception:
        # Deliberate fallback: hand back the unfiltered input on any failure.
        # Exception (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return data


def notch_filter(data):
    """Apply the notch filter to ``data``, best effort.

    Runs ``scipy.signal.filtfilt`` with the coefficients ``b_notch`` /
    ``a_notch``, which this function expects to find at module level.

    Args:
        data: Samples to filter.

    Returns:
        The filtered samples, or ``data`` itself, unchanged, if filtering
        fails for any reason.
    """
    try:
        return sp_signal.filtfilt(b_notch, a_notch, data)
    except Exception:
        # Deliberate fallback: hand back the unfiltered input on any failure.
        # Exception (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return data

     

# =========================================================================

# processed_data, raw_global_stats, raw_ecg = data_process(
#     low_pass_filter(notch_filter(file_data))
# )

# ecg_full = processed_data[0, :15000, 0]


# processed_data, raw_global_stats, raw_ecg = data_process(
#     low_pass_filter(notch_filter(file_data))
# )
# ecg_full = processed_data[0, :15000, 0]

# data = data_process(file_data)
# ecg_full = data[0, :15000, 0]

# Run the analysis on the recording loaded into file_data.
# data_process() is unpacked into three values here; only `processed` and
# `raw_ecg` are used below.
processed, stats, raw_ecg = data_process(file_data)
# Take index 0 on the first and last axes and at most the first 15000
# entries along the middle axis.
ecg_full = processed[0, :15000, 0]

# Sampling rate (Hz) passed to the segment analysis and both plots.
# NOTE(review): 15000 samples corresponds to 30 s at this rate — presumably
# the intended analysis window; confirm.
sampling_rate = 500

results = process_ecg_segments(
    ecg_data=ecg_full,
    sampling_rate=sampling_rate,
    num_segments=20,
    min_segment_length=1500
)

# NOTE(review): the title below says "4 Segments" while num_segments=20 is
# requested above — confirm which one is intended.
plot_ecg_segments(ecg_full, sampling_rate, results, "ECG Analysis: 4 Segments with R-peak Detection", raw_ecg=raw_ecg)

print("\n--- Plotting Full Data ---")
plot_full_ecg(ecg_full, sampling_rate, "Final Full Data View", raw_ecg=raw_ecg)