# ONSET-24: Two-Stage Pipeline Evaluation

Compare single-stage vs. two-stage RUL prediction on the XJTU-SY bearing dataset.

**Two-stage approach:**
1. **Stage 1 (Onset Detection):** Rule-based detector identifies when degradation begins
2. **Stage 2 (RUL Prediction):** RUL decays linearly from the onset to failure; before onset it is held flat at `max_rul`

**Single-stage approach:**
- Standard piecewise-linear RUL (capped at max_rul=125) from the first sample to failure, ignoring onset

## Sections
1. Load data and setup pipelines
2. Generate comparison table: single-stage vs. two-stage MAE/RMSE/PHM08
3. Plot RUL predictions for sample bearings
4. Analyze how onset detection accuracy affects the final RUL metrics
5. Document when two-stage helps vs. when it doesn't

In [None]:
import sys
import os
sys.path.insert(0, '..')
os.chdir('..')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from src.onset import (
    load_all_bearings_health_series,
    load_bearing_health_series,
    compute_composite_hi,
    smooth_health_indicator,
    ThresholdOnsetDetector,
    CUSUMOnsetDetector,
    EWMAOnsetDetector,
    load_onset_labels,
    add_onset_column,
    TwoStagePipeline,
    OnsetResult,
)
from src.data.rul_labels import (
    piecewise_linear_rul,
    compute_twostage_rul,
)
from src.training.metrics import (
    mae,
    rmse,
    phm08_score,
    phm08_score_normalized,
    onset_detection_metrics,
    onset_timing_mae,
    conditional_rul_metrics,
    twostage_combined_score,
)

# Global plotting defaults shared by every figure in this notebook.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': [14, 6],
    'figure.dpi': 100,
    'font.size': 10,
})

# RUL cap passed as `max_rul` to every label builder below.
MAX_RUL = 125
print('Imports OK')

## 1. Load Data and Setup Pipelines

Load features and onset labels, then compute both single-stage and two-stage RUL labels for all 15 bearings.

In [None]:
# Feature table; bearings are identified by the 'bearing_id' column.
features_path = 'outputs/features/features_v2.csv'
df = pd.read_csv(features_path)
n_samples, n_bearings_in_df = df.shape[0], df['bearing_id'].nunique()
print(f'Dataset: {n_samples} samples, {n_bearings_in_df} bearings')

# Manually annotated onset labels, looked up by bearing id further below.
onset_labels = load_onset_labels()
print(f'Onset labels: {len(onset_labels)} bearings')

# Smoothed health-indicator series per bearing (smooth_window=11).
all_hi = load_all_bearings_health_series(df, smooth=True, smooth_window=11)
print(f'Health indicators computed for {len(all_hi)} bearings')

In [None]:
# Build the oracle RUL labels for every bearing:
#   single-stage: piecewise-linear RUL from the first file (ignores onset)
#   two-stage:    flat pre-onset, linear decay post-onset (manual onset)
bearing_data = {}

for bid in sorted(all_hi):
    hi = all_hi[bid]
    label = onset_labels[bid]
    n_files = len(hi.file_indices)
    onset_idx = label.onset_file_idx

    # Boolean mask: True from the manual onset index to the last file.
    onset_mask = np.zeros(n_files, dtype=bool)
    onset_mask[onset_idx:] = True

    bearing_data[bid] = dict(
        n_files=n_files,
        onset_idx=onset_idx,
        rul_single=piecewise_linear_rul(n_files, max_rul=MAX_RUL),
        rul_twostage=compute_twostage_rul(n_files, onset_idx, max_rul=MAX_RUL),
        onset_mask=onset_mask,
        condition=label.condition,
        confidence=label.confidence,
        kurtosis_h=hi.kurtosis_h,
        kurtosis_v=hi.kurtosis_v,
    )

print(f'Computed RUL labels for {len(bearing_data)} bearings')
print('\nSample (Bearing1_3):')
bd = bearing_data['Bearing1_3']
rul_s, rul_t, onset = bd['rul_single'], bd['rul_twostage'], bd['onset_idx']
print(f'  Files: {bd["n_files"]}, Onset: {onset}')
print(f'  Single-stage RUL[0]: {rul_s[0]:.0f}, RUL[-1]: {rul_s[-1]:.0f}')
print(f'  Two-stage RUL[0]: {rul_t[0]:.0f}, RUL[-1]: {rul_t[-1]:.0f}')
print(f'  Two-stage RUL[onset]: {rul_t[onset]:.0f}')

In [None]:
# Run the rule-based onset detector on every bearing, mimicking what the
# two-stage pipeline would do at inference time.
# NOTE(review): healthy_fraction=0.2 presumably marks the leading 20% of each
# series as the healthy baseline for fitting; confirm in src.onset.
detector = ThresholdOnsetDetector(threshold_sigma=2.0, min_consecutive=5)
healthy_fraction = 0.2

detected_onsets = {}
for bid, hi in sorted(all_hi.items()):
    # Average the horizontal and vertical kurtosis channels into one signal.
    kurtosis_avg = (hi.kurtosis_h + hi.kurtosis_v) / 2.0
    result = detector.fit_detect(kurtosis_avg, healthy_fraction=healthy_fraction)
    detected_onsets[bid] = result.onset_idx  # None means "no onset found"

print('Detected onset indices vs manual labels:')
print(f'{"Bearing":<14} {"Manual":>7} {"Detected":>9} {"Error":>7}')
print('-' * 40)
for bid in sorted(detected_onsets):
    manual = onset_labels[bid].onset_file_idx
    det = detected_onsets[bid]
    err = 'N/A' if det is None else det - manual
    print(f'{bid:<14} {manual:>7} {str(det):>9} {str(err):>7}')

In [None]:
# Add the detector-based two-stage labels next to the oracle ones.
#   Single-stage:         piecewise_linear, no onset knowledge (already stored)
#   Two-stage (oracle):   manual onset labels, best case (already stored)
#   Two-stage (detected): detector-predicted onset (added here)
for bid, bd in bearing_data.items():
    n = bd['n_files']
    det_onset = detected_onsets[bid]

    if det_onset is None:
        # Detector never fired: predict MAX_RUL everywhere, no post-onset samples.
        bd['rul_twostage_detected'] = np.full(n, MAX_RUL, dtype=np.float64)
        bd['onset_mask_detected'] = np.zeros(n, dtype=bool)
        continue

    bd['rul_twostage_detected'] = compute_twostage_rul(n, det_onset, max_rul=MAX_RUL)
    bd['onset_mask_detected'] = np.arange(n) >= det_onset

print('Predicted RUL computed for all 3 strategies')

## 2. Comparison Table: Single-Stage vs. Two-Stage Metrics

Compare MAE, RMSE, and PHM08 scores across all 15 bearings. The ground truth is the **two-stage (oracle)** label: two-stage RUL built from the manual onset labels. Against it we score:
- **Single-stage**: standard piecewise-linear RUL labels (no onset knowledge)
- **Two-stage (detected)**: two-stage RUL built from detector-predicted onset labels

No model is trained here; all comparisons are label-vs-label. The single-stage "prediction" is the piecewise-linear label itself (a perfect-predictor baseline). Its error against the two-stage ground truth is therefore the irreducible error from ignoring onset.

In [None]:
# Per-bearing metrics: evaluate how closely each strategy's RUL labels
# approximate the two-stage ground truth.
#
# Ground truth: two-stage RUL with manual onset (oracle)
# Predictions:
#   single-stage       = piecewise_linear (ignores onset)
#   two-stage detected = two-stage with auto-detected onset

rows = []
for bid in sorted(bearing_data.keys()):
    bd = bearing_data[bid]
    y_true = bd['rul_twostage']  # ground truth (oracle two-stage)
    det_onset = detected_onsets[bid]  # None when the detector reported no onset

    # Single-stage predictions
    y_pred_single = bd['rul_single']
    single_mae = mae(y_true, y_pred_single)
    single_rmse = rmse(y_true, y_pred_single)
    single_phm = phm08_score_normalized(y_true, y_pred_single)

    # Two-stage detected predictions
    y_pred_ts = bd['rul_twostage_detected']
    ts_mae = mae(y_true, y_pred_ts)
    ts_rmse = rmse(y_true, y_pred_ts)
    ts_phm = phm08_score_normalized(y_true, y_pred_ts)

    # Post-onset metrics (two-stage detected), using the manual post-onset mask
    cond_metrics = conditional_rul_metrics(y_true, y_pred_ts, bd['onset_mask'])

    # Onset detection metrics on the per-file binary pre/post-onset masks
    y_true_onset = bd['onset_mask'].astype(int)
    y_pred_onset = bd['onset_mask_detected'].astype(int)
    onset_met = onset_detection_metrics(y_true_onset, y_pred_onset)

    rows.append({
        'bearing': bid,
        'condition': bd['condition'],
        'n_files': bd['n_files'],
        'onset_idx': bd['onset_idx'],
        'onset_pct': bd['onset_idx'] / bd['n_files'] * 100,
        'detected_onset': det_onset,
        # None when undetected; pandas stores it as NaN once the column also
        # contains numbers, so downstream checks must use pd.notna/notna().
        'onset_error': (det_onset - bd['onset_idx'])
                        if det_onset is not None else None,
        'onset_f1': onset_met['f1'],
        # Single-stage
        'single_mae': single_mae,
        'single_rmse': single_rmse,
        'single_phm08': single_phm,
        # Two-stage detected
        'ts_mae': ts_mae,
        'ts_rmse': ts_rmse,
        'ts_phm08': ts_phm,
        # Post-onset only
        'post_onset_mae': cond_metrics['post_onset_mae'],
        # Improvement (positive = two-stage closer to ground truth)
        'mae_improvement': single_mae - ts_mae,
        'mae_improvement_pct': (single_mae - ts_mae) / single_mae * 100
                                if single_mae > 0 else 0,
    })

results_df = pd.DataFrame(rows)

# Display comparison table
print('Per-Bearing Metrics: Single-Stage vs Two-Stage (Detected)')
print('Ground truth: two-stage RUL with manual onset labels')
print('=' * 110)
print(f'{"Bearing":<14} {"Cond":<12} {"Files":>5} {"Onset%":>7} '
      f'{"Single MAE":>11} {"TwoStg MAE":>11} {"Improv%":>8} '
      f'{"Onset F1":>9} {"OnsetErr":>9}')
print('-' * 110)
for _, r in results_df.iterrows():
    # `is not None` would be True for NaN and print '+nan'; pd.notna handles
    # both None and NaN.
    onset_err = f'{r["onset_error"]:+.0f}' if pd.notna(r['onset_error']) else 'N/A'
    print(f'{r["bearing"]:<14} {r["condition"]:<12} {r["n_files"]:>5} {r["onset_pct"]:>6.1f}% '
          f'{r["single_mae"]:>11.2f} {r["ts_mae"]:>11.2f} {r["mae_improvement_pct"]:>7.1f}% '
          f'{r["onset_f1"]:>9.3f} {onset_err:>9}')

print('-' * 110)
print(f'{"MEAN":<14} {"":12} {"":>5} {results_df["onset_pct"].mean():>6.1f}% '
      f'{results_df["single_mae"].mean():>11.2f} {results_df["ts_mae"].mean():>11.2f} '
      f'{results_df["mae_improvement_pct"].mean():>7.1f}% '
      f'{results_df["onset_f1"].mean():>9.3f}')

In [None]:
# Aggregate metrics across all bearings: concatenate every bearing's label
# arrays (same sorted order for each array, so samples stay aligned) and score
# them as one pooled series.
# NOTE(review): if mae/rmse average per sample, longer bearings weigh more here.
bearing_order = sorted(bearing_data)
n_bearings = len(bearing_order)
all_y_true = np.concatenate([bearing_data[b]['rul_twostage'] for b in bearing_order])
all_y_single = np.concatenate([bearing_data[b]['rul_single'] for b in bearing_order])
all_y_ts_det = np.concatenate([bearing_data[b]['rul_twostage_detected'] for b in bearing_order])
all_onset_mask = np.concatenate([bearing_data[b]['onset_mask'] for b in bearing_order])

# Header reflects the actual number of evaluated bearings, not a fixed 15.
print(f'Aggregate Metrics Across All {n_bearings} Bearings')
print('=' * 60)
print(f'{"Metric":<25} {"Single-Stage":>15} {"Two-Stage (Det)":>15}')
print('-' * 60)

metrics = [
    ('MAE', mae(all_y_true, all_y_single), mae(all_y_true, all_y_ts_det)),
    ('RMSE', rmse(all_y_true, all_y_single), rmse(all_y_true, all_y_ts_det)),
    ('PHM08 (normalized)', phm08_score_normalized(all_y_true, all_y_single),
                           phm08_score_normalized(all_y_true, all_y_ts_det)),
]

for name, single_val, ts_val in metrics:
    print(f'{name:<25} {single_val:>15.3f} {ts_val:>15.3f}')

# Post-onset conditional metrics (restricted to the manual post-onset samples)
cond_single = conditional_rul_metrics(all_y_true, all_y_single, all_onset_mask)
cond_ts = conditional_rul_metrics(all_y_true, all_y_ts_det, all_onset_mask)

print(f'\nPost-Onset Only ({int(all_onset_mask.sum())}/{len(all_onset_mask)} samples):')
print(f'{"Post-onset MAE":<25} {cond_single["post_onset_mae"]:>15.3f} {cond_ts["post_onset_mae"]:>15.3f}')
print(f'{"Post-onset RMSE":<25} {cond_single["post_onset_rmse"]:>15.3f} {cond_ts["post_onset_rmse"]:>15.3f}')
print(f'{"Post-onset PHM08":<25} {cond_single["post_onset_phm08_normalized"]:>15.3f} {cond_ts["post_onset_phm08_normalized"]:>15.3f}')

## 3. Plot RUL Predictions for Sample Bearings

Visual comparison of single-stage vs. two-stage RUL labels for representative bearings across all 3 conditions.

In [None]:
# Overlay the single-stage and detected two-stage labels on the oracle labels
# for six representative bearings (two per operating condition).
sample_bearings = ['Bearing1_3', 'Bearing1_4', 'Bearing2_1', 'Bearing2_3', 'Bearing3_1', 'Bearing3_3']

fig, axes = plt.subplots(3, 2, figsize=(16, 12))

# axes.flat walks the 3x2 grid row by row, matching the bearing order above.
for ax, bid in zip(axes.flat, sample_bearings):
    bd = bearing_data[bid]
    x = np.arange(bd['n_files'])
    truth = bd['rul_twostage']
    single = bd['rul_single']
    detected = bd['rul_twostage_detected']
    manual_onset = bd['onset_idx']

    # RUL curves
    ax.plot(x, truth, 'k-', linewidth=2, label='Ground Truth (two-stage)', zorder=3)
    ax.plot(x, single, '--', color='#e74c3c', linewidth=1.5, label='Single-stage', alpha=0.8)
    ax.plot(x, detected, '-.', color='#3498db', linewidth=1.5,
            label='Two-stage (detected)', alpha=0.8)

    # Onset markers; the detected one is drawn only when it differs.
    ax.axvline(manual_onset, color='green', linestyle=':', linewidth=1.5, alpha=0.7,
               label=f'Manual onset ({manual_onset})')
    det = detected_onsets[bid]
    if det is not None and det != manual_onset:
        ax.axvline(det, color='orange', linestyle=':', linewidth=1.5, alpha=0.7,
                   label=f'Detected onset ({det})')

    # Shaded gap between the single-stage labels and the ground truth.
    ax.fill_between(x, single, truth,
                    alpha=0.15, color='#e74c3c', label='Single-stage error')

    ax.set_title(f'{bid} ({bd["condition"]})\n'
                 f'Single MAE={mae(truth, single):.1f}, '
                 f'Two-stage MAE={mae(truth, detected):.1f}', fontsize=10)
    ax.set_xlabel('File Index')
    ax.set_ylabel('RUL')
    ax.legend(fontsize=7, loc='upper right')
    ax.set_ylim(-5, MAX_RUL + 10)

plt.suptitle('Single-Stage vs Two-Stage RUL Labels', fontsize=14, y=1.01)
plt.tight_layout()
plt.show()

In [None]:
# Grid plot of all bearings: one column per operating condition, one row per
# bearing (the 5x3 layout assumes at most five bearings per condition).
fig, axes = plt.subplots(5, 3, figsize=(18, 18))

conditions = ['35Hz12kN', '37.5Hz11kN', '40Hz10kN']
for col, cond in enumerate(conditions):
    cond_bearings = sorted(b for b in bearing_data if bearing_data[b]['condition'] == cond)
    for row, bid in enumerate(cond_bearings):
        ax = axes[row, col]
        bd = bearing_data[bid]
        x = np.arange(bd['n_files'])
        truth = bd['rul_twostage']
        single = bd['rul_single']
        detected = bd['rul_twostage_detected']

        ax.plot(x, truth, 'k-', linewidth=1.5, label='Ground Truth')
        ax.plot(x, single, '--', color='#e74c3c', linewidth=1, alpha=0.8, label='Single')
        ax.plot(x, detected, '-.', color='#3498db', linewidth=1,
                alpha=0.8, label='Two-stage')
        ax.axvline(bd['onset_idx'], color='green', linestyle=':', linewidth=1, alpha=0.6)

        # Top row also carries the condition name above the per-bearing title.
        title = f'{bid}: S={mae(truth, single):.1f}, T={mae(truth, detected):.1f}'
        if row == 0:
            title = f'{cond}\n{title}'
        ax.set_title(title, fontsize=9)
        ax.tick_params(labelsize=7)

        if row == 4:
            ax.set_xlabel('File Index', fontsize=8)
        if col == 0:
            ax.set_ylabel('RUL', fontsize=8)
        if (row, col) == (0, 2):
            ax.legend(fontsize=6, loc='upper right')

plt.suptitle('Single-Stage (S) vs Two-Stage (T) MAE - All Bearings', fontsize=14, y=1.01)
plt.tight_layout()
plt.show()

## 4. Onset Detection Accuracy Impact on RUL Metrics

Analyze how onset detection errors affect the final RUL prediction quality.

In [None]:
# Onset timing analysis: detected vs manual onset indices.
# Undetected bearings are encoded as -1 for onset_timing_mae.
# NOTE(review): assumes onset_timing_mae treats negative predictions as
# "no detection" (it reports n_valid); confirm in src.training.metrics.
bearing_order = sorted(bearing_data)
true_onsets = np.array([bearing_data[b]['onset_idx'] for b in bearing_order])
pred_onsets = np.array([detected_onsets[b] if detected_onsets[b] is not None else -1
                        for b in bearing_order])

timing_metrics = onset_timing_mae(true_onsets, pred_onsets)
print('Onset Detection Timing Metrics:')
print(f'  MAE (samples): {timing_metrics["onset_timing_mae_samples"]:.1f}')
# Denominator is the actual bearing count rather than a hard-coded 15.
print(f'  Valid detections: {int(timing_metrics["n_valid"])}/{len(bearing_order)}')

# One color per operating condition, shared by both scatter plots.
condition_palette = {'35Hz12kN': '#3498db', '37.5Hz11kN': '#2ecc71', '40Hz10kN': '#e74c3c'}

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# 1. |Onset error| vs RUL MAE improvement (bearings with a detected onset only;
#    notna() also catches the NaN that pandas stores for missing errors)
ax = axes[0]
valid = results_df['onset_error'].notna()
detected_df = results_df[valid]
ax.scatter(detected_df['onset_error'].abs(),
           detected_df['mae_improvement_pct'],
           c=[condition_palette[c] for c in detected_df['condition']], s=80, zorder=3)
for _, r in detected_df.iterrows():
    ax.annotate(r['bearing'].replace('Bearing', ''),
                (abs(r['onset_error']), r['mae_improvement_pct']),
                fontsize=7, ha='left', va='bottom')
ax.axhline(0, color='gray', linestyle=':', alpha=0.5)
ax.set_xlabel('|Onset Detection Error| (samples)')
ax.set_ylabel('MAE Improvement (%)')
ax.set_title('Onset Error vs. RUL MAE Improvement')

# 2. Onset position (% of life) vs MAE improvement
ax = axes[1]
ax.scatter(results_df['onset_pct'], results_df['mae_improvement_pct'],
           c=[condition_palette[c] for c in results_df['condition']], s=80, zorder=3)
for _, r in results_df.iterrows():
    ax.annotate(r['bearing'].replace('Bearing', ''),
                (r['onset_pct'], r['mae_improvement_pct']),
                fontsize=7, ha='left', va='bottom')
ax.axhline(0, color='gray', linestyle=':', alpha=0.5)
ax.set_xlabel('Onset Point (% of Total Life)')
ax.set_ylabel('MAE Improvement (%)')
ax.set_title('Onset Timing vs. RUL MAE Improvement')

# 3. Per-bearing MAE comparison (grouped bar)
ax = axes[2]
x_pos = np.arange(len(results_df))
w = 0.35
ax.bar(x_pos - w/2, results_df['single_mae'], w, label='Single-stage', color='#e74c3c', alpha=0.7)
ax.bar(x_pos + w/2, results_df['ts_mae'], w, label='Two-stage (det)', color='#3498db', alpha=0.7)
ax.set_xticks(x_pos)
ax.set_xticklabels([b.replace('Bearing', '') for b in results_df['bearing']],
                    rotation=45, fontsize=7)
ax.set_ylabel('MAE')
ax.set_title('Per-Bearing MAE Comparison')
ax.legend(fontsize=8)

plt.tight_layout()
plt.show()

In [None]:
# Sensitivity: how does an artificial onset error affect two-stage RUL quality?
# Shift the MANUAL onset by -50..+50 samples (step 5), rebuild the two-stage
# labels from the shifted onset, and score them against the oracle labels.
# Negative offsets = early detection, positive offsets = late detection.
# The shifted onset is clipped to [0, n_files - 1], so for bearings whose onset
# lies near either end the effective error is smaller than the nominal offset.
offsets = np.arange(-50, 51, 5)
offset_maes = []

for offset in offsets:
    all_true_shifted = []
    all_pred_shifted = []
    for bid in sorted(bearing_data):
        bd = bearing_data[bid]
        shifted_onset = max(0, min(bd['n_files'] - 1, bd['onset_idx'] + offset))
        pred_rul = compute_twostage_rul(bd['n_files'], shifted_onset, max_rul=MAX_RUL)
        all_true_shifted.append(bd['rul_twostage'])
        all_pred_shifted.append(pred_rul)

    y_true_cat = np.concatenate(all_true_shifted)
    y_pred_cat = np.concatenate(all_pred_shifted)
    offset_maes.append(mae(y_true_cat, y_pred_cat))

# Look up results by offset value instead of by list position, so the prints
# below stay correct if `offsets` is changed.
mae_by_offset = dict(zip(offsets.tolist(), offset_maes))

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(offsets, offset_maes, 'o-', color='#3498db', linewidth=2)
ax.axvline(0, color='green', linestyle='--', alpha=0.7, label='Perfect onset detection')
ax.set_xlabel('Onset Detection Error (samples)', fontsize=11)
ax.set_ylabel('Global MAE', fontsize=11)
ax.set_title('Sensitivity: Onset Detection Error vs. RUL MAE', fontsize=13)

# Shade +/- the actual mean |detector error|. Missing errors are NaN (not None)
# in the DataFrame, so drop them explicitly; skip the span if nothing was detected.
valid_errors = results_df['onset_error'].dropna().astype(float)
if len(valid_errors) > 0:
    mean_abs_err = valid_errors.abs().mean()
    ax.axvspan(-mean_abs_err, mean_abs_err, alpha=0.1, color='orange',
               label=f'Mean |detector error|: {mean_abs_err:.0f} samples')
ax.legend(fontsize=9)
plt.tight_layout()
plt.show()

print(f'MAE at perfect onset: {mae_by_offset[0]:.3f}')
# Report early and late errors separately; the curve is not symmetric.
for err in (10, 25):
    print(f'MAE at -{err} / +{err} sample error (early / late): '
          f'{mae_by_offset[-err]:.3f} / {mae_by_offset[err]:.3f}')

## 5. When Two-Stage Helps vs. When It Doesn't

Identify which bearings and operating conditions benefit most from the two-stage approach, and which are hurt by it.

In [None]:
# Classify bearings by two-stage benefit (percent MAE improvement):
#   < 0                      -> 'Hurts'
#   == 0                     -> 'No change'
#   (0, LARGE_BENEFIT_PCT]   -> 'Small benefit'
#   > LARGE_BENEFIT_PCT      -> 'Large benefit'
# An explicit 'No change' bucket is needed because pd.cut's right-closed
# (-inf, 0] bin would file a 0% improvement under 'Hurts', contradicting the
# "No change" count in the final summary.
LARGE_BENEFIT_PCT = 20
BENEFIT_CATEGORIES = ['Hurts', 'No change', 'Small benefit', 'Large benefit']


def _benefit_category(pct):
    """Map a percent MAE improvement to its benefit bucket (None for NaN)."""
    if pd.isna(pct):
        return None
    if pct < 0:
        return 'Hurts'
    if pct == 0:
        return 'No change'
    if pct <= LARGE_BENEFIT_PCT:
        return 'Small benefit'
    return 'Large benefit'


results_df['category'] = pd.Categorical(
    results_df['mae_improvement_pct'].map(_benefit_category),
    categories=BENEFIT_CATEGORIES,
    ordered=True,
)

print('Two-Stage Impact by Bearing')
print('=' * 80)

for cat in ['Large benefit', 'Small benefit', 'No change', 'Hurts']:
    cat_df = results_df[results_df['category'] == cat].sort_values('mae_improvement_pct', ascending=False)
    print(f'\n--- {cat} ({len(cat_df)} bearings) ---')
    for _, r in cat_df.iterrows():
        # Missing onset_error is NaN in the DataFrame; pd.notna handles NaN and None.
        onset_err = f'{r["onset_error"]:+.0f}' if pd.notna(r['onset_error']) else 'N/A'
        print(f'  {r["bearing"]:<14} onset@{r["onset_pct"]:4.1f}% '
              f'MAE: {r["single_mae"]:5.1f} -> {r["ts_mae"]:5.1f} '
              f'({r["mae_improvement_pct"]:+.1f}%)  onset_err={onset_err}')

In [None]:
# Roll the per-bearing results up by operating condition.
print('\nSummary by Operating Condition')
print('=' * 70)
for cond in conditions:
    cond_df = results_df.loc[results_df['condition'] == cond]
    n_cond = len(cond_df)
    helped = (cond_df['mae_improvement_pct'] > 0).sum()
    report = [
        f'\n{cond}:',
        f'  Bearings: {n_cond}',
        f'  Mean onset: {cond_df["onset_pct"].mean():.1f}% of life',
        f'  Single-stage MAE: {cond_df["single_mae"].mean():.2f}',
        f'  Two-stage MAE:    {cond_df["ts_mae"].mean():.2f}',
        f'  Mean improvement: {cond_df["mae_improvement_pct"].mean():.1f}%',
        f'  Bearings helped:  {helped}/{n_cond}',
    ]
    print('\n'.join(report))

In [None]:
# Horizontal bar chart of per-bearing MAE improvement, colored by condition.
from matplotlib.patches import Patch

cond_colors = {'35Hz12kN': '#3498db', '37.5Hz11kN': '#2ecc71', '40Hz10kN': '#e74c3c'}
sorted_results = results_df.sort_values('mae_improvement_pct', ascending=True)
y_positions = range(len(sorted_results))
bar_colors = [cond_colors[c] for c in sorted_results['condition']]

fig, ax = plt.subplots(figsize=(14, 6))
ax.barh(y_positions, sorted_results['mae_improvement_pct'], color=bar_colors, alpha=0.8)
ax.set_yticks(y_positions)
ax.set_yticklabels(sorted_results['bearing'], fontsize=9)
ax.axvline(0, color='black', linewidth=0.8)
ax.set_xlabel('MAE Improvement (%)', fontsize=11)
ax.set_title('Two-Stage MAE Improvement by Bearing', fontsize=13)

# Legend: one proxy patch per operating condition.
legend_elements = [Patch(facecolor=color, alpha=0.8, label=name)
                   for name, color in cond_colors.items()]
ax.legend(handles=legend_elements, loc='lower right', fontsize=9)

plt.tight_layout()
plt.show()

In [None]:
# Final summary of the two-stage evaluation. All counts and percentages use
# the number of bearings actually present in results_df instead of a
# hard-coded 15.
n_bearings = len(results_df)

print('=' * 70)
print('TWO-STAGE EVALUATION SUMMARY')
print('=' * 70)

print('\n1. OVERALL METRICS (ground truth = two-stage RUL with manual onset)')
print('-' * 50)
print(f'  Single-stage MAE: {results_df["single_mae"].mean():.2f} (mean across {n_bearings} bearings)')
print(f'  Two-stage MAE:    {results_df["ts_mae"].mean():.2f} (with Threshold detector)')
print(f'  Improvement:      {results_df["mae_improvement_pct"].mean():.1f}%')

n_helped = (results_df['mae_improvement_pct'] > 0).sum()
n_hurt = (results_df['mae_improvement_pct'] < 0).sum()
# Everything else: exactly 0% improvement (or NaN, which compares False both ways).
n_unchanged = n_bearings - n_helped - n_hurt
print(f'\n2. BEARING-LEVEL RESULTS')
print('-' * 50)
print(f'  Bearings improved: {n_helped}/{n_bearings} ({n_helped/n_bearings*100:.0f}%)')
print(f'  Bearings hurt:     {n_hurt}/{n_bearings} ({n_hurt/n_bearings*100:.0f}%)')
print(f'  No change:         {n_unchanged}/{n_bearings}')

print(f'\n3. ONSET DETECTION QUALITY')
print('-' * 50)
# dropna() removes undetected bearings (stored as NaN/None).
valid_errors = results_df['onset_error'].dropna()
abs_errors = valid_errors.abs()
print(f'  Detection rate: {len(valid_errors)}/{n_bearings}')
print(f'  Mean |error|: {abs_errors.mean():.1f} samples')
print(f'  Median |error|: {abs_errors.median():.1f} samples')
within_5 = (abs_errors <= 5).sum()
within_10 = (abs_errors <= 10).sum()
# Percentages are relative to all bearings, so missed detections count against them.
print(f'  Within 5 samples: {within_5}/{n_bearings} ({within_5/n_bearings*100:.0f}%)')
print(f'  Within 10 samples: {within_10}/{n_bearings} ({within_10/n_bearings*100:.0f}%)')

print(f'\n4. KEY FINDINGS')
print('-' * 50)
early_onset = results_df[results_df['onset_pct'] < 30]
late_onset = results_df[results_df['onset_pct'] >= 30]
print(f'  Early onset (<30% life): {len(early_onset)} bearings, '
      f'mean improvement={early_onset["mae_improvement_pct"].mean():.1f}%')
print(f'  Late onset (>=30% life): {len(late_onset)} bearings, '
      f'mean improvement={late_onset["mae_improvement_pct"].mean():.1f}%')
# NOTE(review): the bullets below are static text, not derived from the numbers
# above; re-check them against the printed metrics whenever the data changes.
print(f'\n  Two-stage benefits MOST when:')
print(f'  - Onset occurs late in bearing life (large healthy region to correct)')
print(f'  - Onset detector accurately identifies the degradation start point')
print(f'  - Single-stage model wastes capacity fitting the flat pre-onset region')
print(f'\n  Two-stage benefits LEAST when:')
print(f'  - Onset is very early (little difference between approaches)')
print(f'  - Onset detection has large errors (wrong onset -> wrong RUL shape)')
print(f'  - Bearing has gradual degradation that is hard to detect')