For dynamic peak counting and depolarization start-time detection

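Version 6b changelog:
- APD90 outlier filtering now uses each ROI's own median instead of a global median across ROIs
- Added troubleshooting features: events that are detected but removed by the filters are drawn in grey on the plot and excluded from the data table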
#
TODO:
- Fix occasional signal drift
- Fix occasional problems with low-amplitude peak detection
- Depolarization start (`start_idx`) detection may need better-tuned settings
Questions: huynh.trung@mayo.edu


In [None]:
# ============================================================================
# FIXED PHOTOBLEACHING CORRECTION FOR CARDIOLAMINOPATHY iPSC-CM STUDIES
# ============================================================================
# Key fixes:
# 1. Proper array length handling
# 2. Robust trend interpolation
# 3. Better baseline preservation
# 4. Enhanced error recovery

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
from scipy.ndimage import gaussian_filter1d
from scipy import optimize
from itertools import groupby

def robust_polynomial_correction_fixed(time, signal, degree=2):
    """
    Remove a slow polynomial trend (e.g. photobleaching) from a signal.

    The trend is fitted on normalized time, ``(t - mean(t)) / ptp(t)``, so the
    powers of t stay below 1 in magnitude, and is then mapped back onto every
    sample of the original grid.

    Parameters
    ----------
    time, signal : array_like
        Sample times and values. If their lengths differ, both are truncated
        to the shorter length. Non-finite entries are excluded from the fit.
    degree : int, default 2
        Polynomial degree of the fitted trend.

    Returns
    -------
    corrected : ndarray
        ``signal - (trend - min(trend))``; positions where ``signal`` is NaN
        stay NaN.
    trend : ndarray
        Fitted trend on the full grid. Samples whose signal was invalid get
        values linearly interpolated (in time order) from the fitted samples;
        samples with no usable time get the mean trend.
    method : str
        ``'polynomial_deg<degree>'`` on success, ``'insufficient_data'`` or
        ``'zero_time_range'`` when no fit is possible, or the label returned
        by ``_linear_fallback`` if an unexpected error occurred.

    References
    ----------
    - Shinnawi et al. (2015) Nat Protoc - iPSC-CM calcium imaging
    - Yang et al. (2014) Nat Protoc - photobleaching correction methods
    """
    try:
        time = np.array(time, dtype=float)
        signal = np.array(signal, dtype=float)

        # Ensure equal lengths
        min_len = min(len(time), len(signal))
        time = time[:min_len]
        signal = signal[:min_len]

        print(f"   Processing {len(signal)} points with polynomial degree {degree}")

        # Only samples where both time and value are finite take part in the fit.
        valid_mask = np.isfinite(time) & np.isfinite(signal)
        n_valid = np.sum(valid_mask)

        if n_valid < degree + 1:
            print(f"   ❌ Insufficient valid points: {n_valid} < {degree + 1}")
            return signal, np.full_like(signal, np.nanmean(signal)), 'insufficient_data'

        time_clean = time[valid_mask]
        signal_clean = signal[valid_mask]

        time_range = np.ptp(time_clean)
        if time_range == 0:
            print("   ❌ Time array has zero range")
            return signal, np.full_like(signal, np.nanmean(signal)), 'zero_time_range'

        # Centre and scale time so |t_norm| < 1; this keeps the powers of t
        # comparable in size and the least-squares system well conditioned.
        time_mean = np.mean(time_clean)
        time_normalized = (time_clean - time_mean) / time_range

        try:
            coeffs = np.polyfit(time_normalized, signal_clean, degree)
            trend_clean = np.polyval(coeffs, time_normalized)
        except np.linalg.LinAlgError as e:
            print(f"   ❌ Polynomial fit failed: {e}, falling back to linear")
            # Two-point line through the earliest and latest samples. Using
            # argmin/argmax instead of positional ends keeps the denominator
            # non-zero even if time is unsorted (time_range > 0 here).
            first, last = np.argmin(time_clean), np.argmax(time_clean)
            slope = (signal_clean[last] - signal_clean[first]) / (time_clean[last] - time_clean[first])
            trend_clean = signal_clean[first] + slope * (time_clean - time_clean[first])

        # Map the fitted trend back onto the full sample grid.
        trend_full = np.full_like(signal, np.nan)
        trend_full[valid_mask] = trend_clean

        invalid_indices = np.flatnonzero(~valid_mask)
        if invalid_indices.size:
            # np.interp requires increasing sample points, so order the fitted
            # samples by time before interpolating the gaps.
            order = np.argsort(time_clean, kind='stable')
            trend_full[invalid_indices] = np.interp(
                time[invalid_indices], time_clean[order], trend_clean[order]
            )

        # Samples with a non-finite time cannot be placed on the curve; give
        # them the mean trend so the returned trend contains no NaNs.
        if np.any(np.isnan(trend_full)):
            trend_full = np.nan_to_num(trend_full, nan=np.nanmean(trend_full))

        # Subtract the trend relative to its minimum: (trend - min) >= 0, so
        # the correction only lowers samples towards the trend's lowest level.
        corrected = signal - (trend_full - np.nanmin(trend_full))

        # Goodness of fit, reported for diagnostics only.
        ss_res = np.sum((signal_clean - trend_clean) ** 2)
        ss_tot = np.sum((signal_clean - np.mean(signal_clean)) ** 2)
        r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0
        print(f"   ✅ Polynomial correction: R² = {r_squared:.3f}, degree = {degree}")

        return corrected, trend_full, f'polynomial_deg{degree}'

    except Exception as e:
        print(f"   ❌ Polynomial correction failed: {e}")
        return _linear_fallback(time, signal)

def _linear_fallback(time, signal):
    """Robust linear fallback for when polynomial fails."""
    try:
        valid_mask = np.isfinite(time) & np.isfinite(signal)
        
        if np.sum(valid_mask) < 2:
            print("   ❌ Insufficient valid data for linear fit")
            return signal, np.zeros_like(signal), 'failed'
        
        time_valid = time[valid_mask]
        signal_valid = signal[valid_mask]
        
        if len(np.unique(time_valid)) < 2:
            print("   ❌ Non-unique time values")
            return signal, np.full_like(signal, np.mean(signal_valid)), 'constant'
        
        # Simple linear regression
        slope = (signal_valid[-1] - signal_valid[0]) / (time_valid[-1] - time_valid[0])
        intercept = signal_valid[0] - slope * time_valid[0]
        
        trend = slope * time + intercept
        corrected = signal - (trend - trend[0])
        
        print("   ✅ Linear fallback successful")
        return corrected, trend, 'linear_fallback'
        
    except Exception as e:
        print(f"   ❌ Even linear fallback failed: {e}")
        return signal, np.zeros_like(signal), 'failed'

def smart_photobleach_correction_fixed(time, signal, method='auto'):
    """
    Detect and remove photobleaching drift from a calcium-imaging trace.

    Steps:
    1. Coerce the inputs to float arrays, truncating both to the shorter length.
    2. Sort by time if needed.
    3. Characterise the trace (SNR, start-vs-end drift relative to its range).
    4. Pick a correction method when ``method='auto'``.
    5. Apply the method and validate the result.

    Parameters
    ----------
    time, signal : array_like
        Sample times and values.
    method : {'auto', 'none', 'linear', 'polynomial', 'rolling', 'exponential'}
        Correction to apply; any other value falls back to 'linear'.

    Returns
    -------
    corrected, trend : ndarray
        Same length as the (truncated) input and in the caller's original
        sample order, even when the data had to be sorted internally.
    used_method : str
        Label reported by the applied correction, or a status string:
        'insufficient_data' (< 10 valid samples), 'failed_validation' (the
        correction produced non-finite values where the input was valid), or
        'complete_failure' (an exception was raised).

    References
    ----------
    - Shinnawi et al. (2015) Nat Protoc 10:1889-1902
    - Yang et al. (2014) Nat Protoc 9:1028-1037
    """
    signal = np.array(signal, dtype=float)
    time = np.array(time, dtype=float)

    # Ensure equal lengths
    min_len = min(len(time), len(signal))
    if min_len != len(time) or min_len != len(signal):
        print(f"   ⚠️ Array length mismatch: time={len(time)}, signal={len(signal)}")
        time = time[:min_len]
        signal = signal[:min_len]

    print(f"   Input: {len(signal)} points")

    valid_mask = np.isfinite(signal) & np.isfinite(time)
    n_valid = np.sum(valid_mask)

    if n_valid < 10:
        print(f"   ❌ Insufficient valid data: {n_valid} points")
        return signal, np.zeros_like(signal), 'insufficient_data'

    if not np.all(valid_mask):
        print(f"   ⚠️ Removed {len(signal) - n_valid} invalid points")

    # The corrections assume increasing time. If we sort, keep the
    # permutation so results can be handed back in the caller's order: the
    # caller plots and differentiates them against its own (unsorted) time.
    order = None
    if not np.all(np.diff(time[valid_mask]) > 0):
        print("   ⚠️ Non-monotonic time detected, sorting data")
        order = np.argsort(time, kind='stable')  # stable: keep order of ties
        time = time[order]
        signal = signal[order]
        valid_mask = np.isfinite(signal) & np.isfinite(time)

    def _in_caller_order(arr):
        """Undo the internal time sort (no-op when no sort happened)."""
        if order is None or len(arr) != len(order):
            return arr
        restored = np.empty_like(arr)
        restored[order] = arr
        return restored

    # Signal characterization for method selection
    signal_valid = signal[valid_mask]

    signal_range = np.ptp(signal_valid)
    noise_level = np.std(np.diff(signal_valid)) * np.sqrt(2)
    snr = signal_range / noise_level if noise_level > 0 else 1000

    # Drift between the first and last segment, as a fraction of the range.
    n_segments = min(10, len(signal_valid) // 20)
    if n_segments >= 2:
        segment_size = len(signal_valid) // n_segments
        start_mean = np.mean(signal_valid[:segment_size])
        end_mean = np.mean(signal_valid[-segment_size:])
        bleaching_severity = abs(start_mean - end_mean) / signal_range if signal_range > 0 else 0
    else:
        bleaching_severity = 0

    print(f"   Signal characteristics: SNR={snr:.1f}, bleaching={bleaching_severity:.1%}")

    if method == 'auto':
        if len(signal_valid) < 20:
            method = 'linear'
        elif bleaching_severity < 0.03:  # Very minimal bleaching
            method = 'none'
        elif snr < 2:  # Very noisy data
            method = 'rolling'
        elif bleaching_severity > 0.2 and len(signal_valid) > 100:  # Severe bleaching
            method = 'exponential'
        elif len(signal_valid) > 50 and snr > 3:  # Good conditions
            method = 'polynomial'
        else:
            method = 'linear'

    print(f"   Selected method: {method}")

    try:
        if method == 'none':
            corrected = signal.copy()
            trend = np.full_like(signal, np.nanmean(signal_valid))
            used_method = 'none'
        elif method == 'polynomial':
            corrected, trend, used_method = robust_polynomial_correction_fixed(time, signal, degree=2)
        elif method == 'rolling':
            corrected, trend, used_method = _apply_rolling_correction(time, signal, valid_mask)
        elif method == 'exponential':
            corrected, trend, used_method = _apply_exponential_correction(time, signal, valid_mask)
        else:
            # 'linear' and any unrecognised method
            corrected, trend, used_method = _apply_linear_correction(time, signal, valid_mask)

        # Validate only where the input was usable: samples that were NaN on
        # input legitimately stay NaN and must not veto the whole correction.
        if not np.all(np.isfinite(corrected[valid_mask])):
            print("   ❌ Correction produced invalid values, using original")
            corrected = signal.copy()
            trend = np.full_like(signal, np.nanmean(signal_valid))
            used_method = 'failed_validation'

        finite_trend = trend[np.isfinite(trend)]
        trend_span = np.ptp(finite_trend) if finite_trend.size else 0.0
        correction_magnitude = trend_span / signal_range if signal_range > 0 else 0
        print(f"   ✅ Correction applied: {correction_magnitude:.1%} of signal range")

        return _in_caller_order(corrected), _in_caller_order(trend), used_method

    except Exception as e:
        print(f"   ❌ All correction methods failed: {e}")
        return _in_caller_order(signal), np.zeros_like(signal), 'complete_failure'

def _apply_linear_correction(time, signal, valid_mask):
    """Apply robust linear detrending."""
    try:
        time_valid = time[valid_mask]
        signal_valid = signal[valid_mask]
        
        if len(time_valid) < 2:
            return signal, np.zeros_like(signal), 'insufficient_linear_data'
        
        # Linear regression
        time_range = time_valid[-1] - time_valid[0]
        if time_range == 0:
            slope = 0
        else:
            slope = (signal_valid[-1] - signal_valid[0]) / time_range
        
        intercept = signal_valid[0] - slope * time_valid[0]
        trend = slope * time + intercept
        
        # Preserve baseline
        corrected = signal - (trend - trend[0])
        
        return corrected, trend, 'linear'
        
    except Exception as e:
        print(f"   ❌ Linear correction failed: {e}")
        return signal, np.zeros_like(signal), 'linear_failed'

def _apply_rolling_correction(time, signal, valid_mask, window_percent=15):
    """Apply rolling baseline correction."""
    try:
        signal_valid = signal[valid_mask]
        
        window_size = max(5, len(signal_valid) // (100 // window_percent))
        
        # Rolling percentile baseline
        baseline = np.array([
            np.percentile(signal[max(0, i-window_size//2):min(len(signal), i+window_size//2)], 5)
            for i in range(len(signal))
        ])
        
        # Smooth the baseline
        if len(baseline) > 10:
            baseline = gaussian_filter1d(baseline, sigma=max(1, window_size//10))
        
        # Preserve original baseline level
        corrected = signal - (baseline - baseline[0])
        
        return corrected, baseline, 'rolling'
        
    except Exception as e:
        print(f"   ❌ Rolling correction failed: {e}")
        return _apply_linear_correction(time, signal, valid_mask)

def _apply_exponential_correction(time, signal, valid_mask):
    """Apply exponential decay correction for severe photobleaching."""
    try:
        time_valid = time[valid_mask]
        signal_valid = signal[valid_mask]
        
        if len(signal_valid) < 10:
            return _apply_linear_correction(time, signal, valid_mask)
        
        # Find baseline regions (low variance windows)
        window_size = max(10, len(signal_valid) // 20)
        baseline_indices = []
        
        for i in range(0, len(signal_valid) - window_size, window_size//2):
            window = signal_valid[i:i + window_size]
            if np.var(window) < np.var(signal_valid) * 0.1:  # Low variance region
                baseline_indices.extend(range(i, i + window_size))
        
        if len(baseline_indices) < 5:
            return _apply_linear_correction(time, signal, valid_mask)
        
        # Fit exponential decay
        baseline_times = time_valid[baseline_indices]
        baseline_values = signal_valid[baseline_indices]
        
        def exp_func(t, a, b, c):
            return a * np.exp(-b * (t - t[0])) + c
        
        # Initial parameter estimates
        p0 = [
            np.max(baseline_values) - np.min(baseline_values),
            0.01,
            np.min(baseline_values)
        ]
        
        bounds = ([0, 0, -np.inf], [np.inf, 1, np.inf])
        
        popt, _ = optimize.curve_fit(
            exp_func, baseline_times, baseline_values,
            p0=p0, bounds=bounds, maxfev=1000
        )
        
        trend = exp_func(time, *popt)
        corrected = signal - (trend - trend[0])
        
        print(f"   ✅ Exponential fit: decay = {popt[1]:.4f}")
        return corrected, trend, 'exponential'
        
    except Exception as e:
        print(f"   ❌ Exponential correction failed: {e}")
        return _apply_linear_correction(time, signal, valid_mask)

# ============================================================================
# UPDATED ANALYSIS FUNCTION WITH FIXED CORRECTION
# ============================================================================

def analyze_roi_signal_with_fixed_correction(time, signal, sample_name, roi_label,
                                           apply_filters=True,
                                           grey_out_unfiltered=True,
                                           photobleach_method='auto',
                                           band_factor=1.5,
                                           upstroke_min=0.06):
    """
    Detect calcium transients in one ROI trace and measure APD50/APD90.

    Pipeline:
    1. Pair each signal sample with its time point and drop samples where
       either value is non-numeric.
    2. Convert to ΔF/F₀, using the trace minimum as F₀ (the trace is shifted
       first if that minimum is <= 0).
    3. Remove photobleaching with ``smart_photobleach_correction_fixed``.
    4. Smooth with a Gaussian (sigma=1).
    5. Detect rising edges and peaks.
    6. Measure each event relative to its depolarization start.
    7. Optionally filter events, then draw a two-panel figure.

    Parameters
    ----------
    time : array_like
        Sample times (the plots label them as seconds).
    signal : array_like or pandas.Series
        Raw fluorescence; non-numeric entries are coerced to NaN and dropped.
    sample_name, roi_label : str
        Used in log output, plot titles and the return tuple.
    apply_filters : bool
        Keep only events with ``Upstroke_Time_s > upstroke_min`` and
        ``APD90_s`` within ``[median / band_factor, median * band_factor]``.
        The median is taken over this ROI's events that passed the upstroke
        filter.
    grey_out_unfiltered : bool
        Draw events removed by the filters in grey.
    photobleach_method : str
        Passed to ``smart_photobleach_correction_fixed``.
    band_factor, upstroke_min : float
        Filter parameters (see ``apply_filters``).

    Returns
    -------
    tuple
        ``(sample_name, roi_label, df_res, fig)``. ``df_res`` has one row per
        kept event, with index reset to 0..n-1. ``fig`` is ``None`` if the
        analysis stopped before plotting (too short, flat, no peaks, no
        measurable events, or an error); ``df_res`` is empty in those cases.
    """
    fig = None
    try:
        print(f"\n🔍 Analyzing {sample_name} | {roi_label}")

        # Data preparation: coerce to float and keep each signal sample paired
        # with its own time point. Samples where either value is missing are
        # dropped together, so a mid-trace gap cannot shift later samples
        # against their times.
        signal = pd.to_numeric(pd.Series(signal), errors='coerce').to_numpy(dtype=float, na_value=np.nan)
        time = pd.to_numeric(pd.Series(time), errors='coerce').to_numpy(dtype=float, na_value=np.nan)
        n_common = min(len(time), len(signal))
        time, signal = time[:n_common], signal[:n_common]
        paired = np.isfinite(time) & np.isfinite(signal)
        time, signal = time[paired], signal[paired]

        if len(signal) < 10:
            print(f"   ❌ Signal too short: {len(signal)} points")
            return sample_name, roi_label, pd.DataFrame(), None

        # Handle non-positive values for ΔF/F calculation
        Fmin = np.min(signal)
        if Fmin <= 0:
            signal = signal - Fmin + 0.001
            Fmin = 0.001

        normalized = (signal - Fmin) / Fmin

        print(f"   📈 Applying photobleaching correction...")
        corrected, trend, method_used = smart_photobleach_correction_fixed(
            time, normalized, method=photobleach_method
        )

        # Smooth the corrected signal
        smoothed = gaussian_filter1d(corrected, sigma=1)

        # Peak detection
        print(f"   🔍 Detecting calcium transients...")
        s_max, s_med = np.max(smoothed), np.median(smoothed)
        dr = s_max - s_med

        if dr < 0.001:
            print(f"   ❌ Signal too flat for peak detection")
            return sample_name, roi_label, pd.DataFrame(), None

        # Detection thresholds (fractions of the max-to-median range)
        rising_thresh = 0.05 * dr
        depol_thresh = 0.02 * dr
        slope = np.gradient(smoothed, time)

        # Rising edges: runs of >= 5 consecutive samples above the slope threshold.
        edges = np.where(slope > rising_thresh)[0]
        groups = []
        for _, g in groupby(enumerate(edges), lambda x: x[0] - x[1]):
            grp = [i for _, i in g]
            if len(grp) >= 5:
                groups.append(grp)

        # Peak = maximum from the start of each run to 20 samples past its end.
        peaks = []
        for grp in groups:
            start = grp[0]
            end = min(grp[-1] + 20, len(smoothed))
            peak_idx = np.argmax(smoothed[start:end]) + start
            peaks.append(peak_idx)

        # De-duplicate peaks closer than 41 samples to the previous kept one.
        filtered_peaks = []
        for p in peaks:
            if not filtered_peaks or p - filtered_peaks[-1] > 40:
                filtered_peaks.append(p)

        if not filtered_peaks:
            print(f"   ❌ No peaks detected")
            return sample_name, roi_label, pd.DataFrame(), None

        print(f"   ✅ Found {len(filtered_peaks)} calcium transients")

        # Calculate APD metrics
        rows = []
        for pk in filtered_peaks:
            try:
                # Depolarization start: last sample within 100 before the
                # peak whose slope is below the depolarization threshold.
                window = range(max(0, pk - 100), pk)
                candidates = [i for i in window if slope[i] < depol_thresh]
                if not candidates:
                    continue
                start_idx = candidates[-1]

                # Baseline = minimum over the 50 samples before the start.
                pre_window = smoothed[max(0, start_idx - 50): start_idx]
                baseline = np.min(pre_window) if len(pre_window) > 0 else smoothed[start_idx]
                peak_val = smoothed[pk]
                amplitude = peak_val - baseline

                if amplitude < 0.001:
                    continue

                # APD90 (10% repolarization)
                level_90 = baseline + 0.1 * amplitude
                repol_90_idx = None
                for i in range(pk + 1, len(smoothed)):
                    if smoothed[i] <= level_90:
                        repol_90_idx = i
                        break

                if repol_90_idx is None:
                    continue

                # APD50 (50% repolarization)
                level_50 = baseline + 0.5 * amplitude
                repol_50_idx = None
                for i in range(pk + 1, len(smoothed)):
                    if smoothed[i] <= level_50:
                        repol_50_idx = i
                        break

                if repol_50_idx is None:
                    continue

                # Calculate durations
                apd90 = time[repol_90_idx] - time[start_idx]
                apd50 = time[repol_50_idx] - time[start_idx]
                ratio_50_90 = apd50 / apd90 if apd90 > 0 else np.nan
                upstroke = time[pk] - time[start_idx]

                rows.append({
                    'Depolarization_Start_Time_s': time[start_idx],
                    'Peak_Time_s': time[pk],
                    'Amplitude': amplitude,
                    'APD50_s': apd50,
                    'APD90_s': apd90,
                    'APD50_90_Ratio': ratio_50_90,
                    'Upstroke_Time_s': upstroke,
                    'Repolarization_Time_s': time[repol_90_idx],
                    'Repolarization_Level': smoothed[repol_90_idx]
                })

            except Exception as peak_error:
                print(f"   ⚠️ Peak analysis failed: {peak_error}")
                continue

        df_res_raw = pd.DataFrame(rows)
        df_res = df_res_raw.copy()

        if df_res.empty:
            print(f"   ❌ No valid APD measurements")
            return sample_name, roi_label, pd.DataFrame(), None

        # Labels (in df_res_raw) of the events that survive filtering. They
        # must be captured before reset_index, otherwise the grey-out step
        # below would compare against renumbered labels.
        kept_raw_index = df_res.index

        if apply_filters:
            initial_count = len(df_res)

            # Upstroke filter
            df_res = df_res[df_res['Upstroke_Time_s'] > upstroke_min]

            # APD90 filter (per ROI median-based)
            if not df_res.empty:
                median_apd90 = df_res['APD90_s'].median()
                lower = median_apd90 / band_factor
                upper = median_apd90 * band_factor
                df_res = df_res[(df_res['APD90_s'] >= lower) & (df_res['APD90_s'] <= upper)]

            kept_raw_index = df_res.index
            df_res = df_res.reset_index(drop=True)
            final_count = len(df_res)

            print(f"   📊 Filters applied: {initial_count} → {final_count} events")

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))

        # Top panel: Photobleaching correction
        ax1.plot(time, normalized, 'b-', alpha=0.6, linewidth=1, label='Original ΔF/F')
        ax1.plot(time, trend, 'r--', alpha=0.8, linewidth=2, label=f'Trend ({method_used})')
        ax1.plot(time, corrected, 'g-', alpha=0.9, linewidth=1.5, label='Corrected')
        ax1.set_ylabel("ΔF/F₀")
        ax1.set_title(f"{sample_name} | {roi_label} - Photobleaching Correction")
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Bottom panel: APD analysis
        ax2.plot(time, smoothed, 'k-', alpha=0.8, linewidth=1.5, label='Smoothed Signal')

        # Kept events
        if not df_res.empty:
            for i, row in enumerate(df_res.itertuples()):
                ax2.axvline(row.Peak_Time_s, color='red', linestyle='--', alpha=0.8,
                           label='Peak' if i == 0 else None)
                ax2.axvline(row.Depolarization_Start_Time_s, color='blue', linestyle=':', alpha=0.6,
                           label='Depol Start' if i == 0 else None)
                ax2.hlines(y=row.Repolarization_Level,
                          xmin=row.Depolarization_Start_Time_s,
                          xmax=row.Repolarization_Time_s,
                          color='orange', linestyle='-', linewidth=3, alpha=0.8,
                          label='APD90' if i == 0 else None)

        # Events rejected by the filters, drawn in grey for troubleshooting
        if grey_out_unfiltered and apply_filters and not df_res_raw.empty:
            filtered_out = df_res_raw[~df_res_raw.index.isin(kept_raw_index)]
            for _, row in filtered_out.iterrows():
                try:
                    pk_time = row['Peak_Time_s']
                    start_time = row['Depolarization_Start_Time_s']
                    repol_time = row['Repolarization_Time_s']
                    repol_level = row['Repolarization_Level']

                    ax2.axvline(pk_time, color='grey', linestyle='--', alpha=0.3)
                    ax2.hlines(y=repol_level, xmin=start_time, xmax=repol_time,
                              color='grey', linestyle='-', linewidth=1, alpha=0.3)
                except Exception:
                    # Best effort: a bad row must not abort the whole plot.
                    continue

        ax2.set_xlabel("Time (s)")
        ax2.set_ylabel("ΔF/F₀")
        ax2.set_title(f"APD Analysis Results ({len(df_res)} valid events)")
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        plt.tight_layout()

        return sample_name, roi_label, df_res, fig

    except Exception as e:
        print(f"   ❌ Complete analysis failed: {e}")
        # Don't leak a half-built figure; callers never receive it.
        if fig is not None:
            plt.close(fig)
        return sample_name, roi_label, pd.DataFrame(), None

# ============================================================================
# USAGE EXAMPLE
# ============================================================================

def run_fixed_analysis(excel_path,
                      band_factor=1.5,
                      upstroke_min=0.06,
                      photobleach_method='auto',
                      apply_filters=True):
    """
    Run the APD analysis over every sheet and ROI of an Excel workbook.

    Sheet layout expected by the parsing below:
    - cell A1 holds the sample name (a trailing '.nd2' is stripped);
    - row 2 holds the column headers;
    - the first column whose header contains 'time' (case-insensitive) is the
      time axis, and every column whose header contains 'Mono' is an ROI trace.
    Sheets that are empty or lack those columns are skipped.

    Parameters
    ----------
    excel_path : str or path-like
        Workbook to analyse; its file handle is closed when processing ends.
    band_factor, upstroke_min, photobleach_method, apply_filters
        Forwarded to ``analyze_roi_signal_with_fixed_correction``.

    Returns
    -------
    tuple
        ``(summary_df, plots, failed_rois)``:

        - ``summary_df``: concatenated per-event results, with 'Sample' and
          'ROI' as the first two columns, or ``None`` if no ROI succeeded;
        - ``plots``: list of ``(sample, roi_label, fig)``, or ``None``;
        - ``failed_rois``: list of ``(sample_name, roi, reason)``.

        If the workbook cannot be opened, ``(None, None, None)`` is returned.
    """
    print("🔧 FIXED APD Analysis for Cardiolaminopathy Studies")
    print("=" * 60)
    print(f"📁 Input: {excel_path}")
    print(f"🔬 Method: {photobleach_method}")
    print(f"🎛️ Filters: {apply_filters}")

    try:
        xls = pd.ExcelFile(excel_path)
        print(f"✅ Loaded {len(xls.sheet_names)} sheets")
    except Exception as e:
        print(f"❌ Failed to load Excel file: {e}")
        return None, None, None

    all_results = []
    plots = []
    failed_rois = []

    # Close the workbook's file handle once every sheet has been parsed.
    with xls:
        for sheet_name in xls.sheet_names:
            print(f"\n📄 Processing: {sheet_name}")

            try:
                # Sample name lives in cell A1
                raw = xls.parse(sheet_name, header=None, nrows=1)
                if raw.dropna(how='all').empty:
                    continue
                sample_name = str(raw.iloc[0, 0])
                if sample_name.lower().endswith('.nd2'):
                    sample_name = sample_name[:-4]

                # Data table with headers on row 2
                df = xls.parse(sheet_name, header=1)
                if df.dropna(how='all').empty:
                    continue
                df.columns = df.columns.astype(str).str.strip()

                time_cols = [c for c in df.columns if 'time' in c.lower()]
                roi_cols = [c for c in df.columns if 'Mono' in c]

                if not time_cols or not roi_cols:
                    print(f"   ❌ Missing required columns")
                    continue

                time = pd.to_numeric(df[time_cols[0]], errors='coerce').values
                print(f"   📊 Found {len(roi_cols)} ROI columns")

                for roi in roi_cols:
                    try:
                        print(f"      🔍 {roi}... ", end="")

                        sample, roi_label, res_df, fig = analyze_roi_signal_with_fixed_correction(
                            time, df[roi], sample_name, roi,
                            apply_filters=apply_filters,
                            grey_out_unfiltered=True,
                            photobleach_method=photobleach_method,
                            band_factor=band_factor,
                            upstroke_min=upstroke_min
                        )

                        if res_df.empty:
                            print("❌ No events")
                            failed_rois.append((sample_name, roi, "No events detected"))
                            if fig is not None:
                                plt.close(fig)
                            continue

                        # Add metadata
                        res_df.insert(0, 'Sample', sample)
                        res_df.insert(1, 'ROI', roi_label)

                        all_results.append(res_df)
                        if fig is not None:
                            plots.append((sample, roi_label, fig))

                        print(f"✅ {len(res_df)} events")

                    except Exception as roi_error:
                        print(f"❌ {roi_error}")
                        failed_rois.append((sample_name, roi, str(roi_error)))
                        continue

            except Exception as sheet_error:
                print(f"   ❌ Sheet error: {sheet_error}")
                continue

    print("\n" + "=" * 60)
    print("RESULTS SUMMARY")
    print("=" * 60)

    if all_results:
        summary_df = pd.concat(all_results, ignore_index=True)
        # One result frame per successful ROI; never zero in this branch.
        n_rois = len(all_results)

        print(f"✅ SUCCESS:")
        print(f"   • Samples: {summary_df['Sample'].nunique()}")
        print(f"   • ROIs: {n_rois}")
        print(f"   • Total events: {len(summary_df)}")
        print(f"   • Events/ROI: {len(summary_df)/n_rois:.1f}")

        if 'APD90_s' in summary_df.columns:
            apd90_stats = summary_df['APD90_s'].describe()
            print(f"\n📈 APD90 Statistics:")
            print(f"   • Mean ± SD: {apd90_stats['mean']:.3f} ± {apd90_stats['std']:.3f} s")
            print(f"   • Range: {apd90_stats['min']:.3f} - {apd90_stats['max']:.3f} s")

        return summary_df, plots, failed_rois

    print("❌ NO SUCCESSFUL ANALYSES")
    return None, None, failed_rois

# ============================================================================
# VALIDATION AND COMPARISON FUNCTIONS
# ============================================================================

def compare_correction_methods(time, signal, sample_name="Test"):
    """
    Plot every photobleaching correction side by side for one trace.

    Panel 0 shows the raw trace. Each of the next five panels covers one of
    'none', 'linear', 'polynomial', 'rolling' and 'exponential', showing:
    - the raw trace, the fitted trend and the corrected trace;
    - the trend's span as a fraction of the raw span.
    A method that raises gets a "FAILED" panel instead. Intended for
    troubleshooting and tuning.

    Returns
    -------
    matplotlib.figure.Figure
        The 2x3 comparison figure.
    """
    fig, grid = plt.subplots(2, 3, figsize=(18, 12))
    panels = grid.flatten()

    raw_ax = panels[0]
    raw_ax.plot(time, signal, 'b-', alpha=0.8, linewidth=1.5)
    raw_ax.set_title("Original Signal")
    raw_ax.set_ylabel("ΔF/F₀")
    raw_ax.grid(True, alpha=0.3)

    candidates = ('none', 'linear', 'polynomial', 'rolling', 'exponential')
    for name, panel in zip(candidates, panels[1:]):
        try:
            corrected, trend, used_method = smart_photobleach_correction_fixed(
                time, signal, method=name
            )

            panel.plot(time, signal, 'b-', alpha=0.5, label='Original')
            panel.plot(time, trend, 'r--', alpha=0.8, label='Trend')
            panel.plot(time, corrected, 'g-', alpha=0.9, label='Corrected')
            panel.set_title(f"{name.upper()} → {used_method}")
            panel.legend()
            panel.grid(True, alpha=0.3)

            # Share of the raw range removed by the trend (0 for a flat trace).
            raw_span = np.ptp(signal)
            share = np.ptp(trend) / raw_span if raw_span > 0 else 0
            panel.text(0.02, 0.98, f"Correction: {share:.1%}",
                       transform=panel.transAxes, va='top', ha='left',
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        except Exception as err:
            panel.text(0.5, 0.5, f"FAILED\n{name}\n{err}",
                       transform=panel.transAxes,
                       ha='center', va='center')
            panel.set_title(f"{name.upper()} - FAILED")

    plt.suptitle(f"Photobleaching Correction Comparison - {sample_name}")
    plt.tight_layout()

    return fig

def validate_apd_measurements(time, signal, results_df, sample_name="Test"):
    """
    Validate APD measurements with detailed visualization.

    Top panel: the trace with, for every event row, vertical lines at the peak
    (dashed), depolarization start (dotted) and repolarization time (dash-dot),
    a horizontal bar from start to repolarization at ``Repolarization_Level``,
    and a numbered marker at the peak.
    Bottom panel: per-event APD90 / APD50 scatter and a mean ± std summary box.

    Parameters
    ----------
    time, signal : array-like
        Time axis (same units as the ``*_Time_s`` columns) and the ΔF/F₀ trace
        the events were measured on. Indexed positionally.
    results_df : pandas.DataFrame
        One row per event. Requires ``Peak_Time_s``,
        ``Depolarization_Start_Time_s``, ``Repolarization_Time_s`` and
        ``Repolarization_Level``. The ``APD90_s``, ``APD50_s`` and
        ``APD50_90_Ratio`` columns are plotted/summarised when present.
    sample_name : str
        Used in the top-panel title.

    Returns
    -------
    matplotlib.figure.Figure or None
        ``None`` (after printing a message) when ``results_df`` is empty.
    """

    if results_df.empty:
        print("No results to validate")
        return None

    # Use positional arrays: for a pandas Series with a non-default index,
    # ``signal[idx]`` would be a label lookup, and a plain list does not
    # support ``time - peak_time``.
    time = np.asarray(time, dtype=float)
    signal = np.asarray(signal, dtype=float)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 10))

    # Top: Signal with all detected events
    ax1.plot(time, signal, 'k-', alpha=0.8, linewidth=1.5, label='Signal')

    colors = plt.cm.Set1(np.linspace(0, 1, len(results_df)))

    for i, (_, row) in enumerate(results_df.iterrows()):
        color = colors[i]

        # Mark key timepoints
        peak_time = row['Peak_Time_s']
        start_time = row['Depolarization_Start_Time_s']
        repol_time = row['Repolarization_Time_s']

        ax1.axvline(peak_time, color=color, linestyle='--', alpha=0.8)
        ax1.axvline(start_time, color=color, linestyle=':', alpha=0.6)
        ax1.axvline(repol_time, color=color, linestyle='-.', alpha=0.6)

        # APD90 span
        ax1.hlines(y=row['Repolarization_Level'],
                   xmin=start_time, xmax=repol_time,
                   color=color, linewidth=3, alpha=0.7)

        # Numbered label at the peak. Skipped when the peak time is missing:
        # argmin over an all-NaN distance array would silently return sample 0.
        if pd.notna(peak_time) and time.size:
            peak_idx = int(np.argmin(np.abs(time - peak_time)))
            ax1.text(peak_time, signal[peak_idx],
                     f'{i+1}', ha='center', va='bottom',
                     bbox=dict(boxstyle='circle', facecolor=color, alpha=0.7))

    ax1.set_ylabel("ΔF/F₀")
    ax1.set_title(f"APD Measurements Validation - {sample_name}")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Bottom: APD metrics (only the columns that exist in results_df)
    event_idx = np.arange(len(results_df))
    if 'APD90_s' in results_df:
        ax2.scatter(event_idx, results_df['APD90_s'],
                    c=colors, s=100, alpha=0.8, label='APD90')
    if 'APD50_s' in results_df:
        ax2.scatter(event_idx, results_df['APD50_s'],
                    c=colors, s=60, alpha=0.6, marker='s', label='APD50')

    ax2.set_xlabel("Event Number")
    ax2.set_ylabel("APD (s)")
    ax2.set_title("APD Measurements")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Add statistics (mean ± std per available metric, then the event count)
    stat_lines = []
    for label, column, unit in (("APD90", "APD90_s", " s"),
                                ("APD50", "APD50_s", " s"),
                                ("Ratio", "APD50_90_Ratio", "")):
        if column in results_df:
            values = results_df[column]
            stat_lines.append(f"{label}: {values.mean():.3f} ± {values.std():.3f}{unit}")
    stat_lines.append(f"Events: {len(results_df)}")
    stats_text = "\n".join(stat_lines)

    ax2.text(0.02, 0.98, stats_text, transform=ax2.transAxes,
             va='top', ha='left', fontfamily='monospace',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

    plt.tight_layout()
    return fig

# ============================================================================
# MAIN EXECUTION EXAMPLE
# ============================================================================

def derive_output_path(source_path, suffix, new_ext):
    """
    Return a sibling path: ``source_path`` with ``suffix`` appended to its stem
    and its extension replaced by ``new_ext``.

    ``os.path.splitext`` is used rather than ``str.replace('.xlsx', ...)``: the
    replace form is a no-op when the input ends in e.g. ``.xls`` or ``.XLSX``,
    which made the results path identical to the input path and overwrote the
    source workbook.

    >>> derive_output_path("data/run.xlsx", "_FIXED_Results", ".csv")
    'data/run_FIXED_Results.csv'
    """
    import os

    stem = os.path.splitext(source_path)[0]
    return f"{stem}{suffix}{new_ext}"


# Only the first N per-ROI figures are embedded in the quick-export workbook.
MAX_EXPORTED_PLOTS = 10

if __name__ == "__main__":
    import io

    # Example usage with your file
    excel_path = r"C:\Users\m254292\Downloads\Sup Rep BeRST pilot excel.xlsx"

    print("🚀 Running FIXED APD Analysis")
    print("=" * 50)

    # Run analysis
    summary_df, plots, failed_rois = run_fixed_analysis(
        excel_path=excel_path,
        band_factor=1.5,
        upstroke_min=0.06,
        photobleach_method='auto',  # Try 'polynomial', 'linear', 'rolling'
        apply_filters=True
    )

    if summary_df is not None:
        print("\n✅ Analysis complete!")
        print(f"   • Total events: {len(summary_df)}")
        print(f"   • Total plots: {len(plots)}")

        # Export results next to the input; never the input path itself.
        output_path = derive_output_path(excel_path, '_FIXED_Results', '.xlsx')

        try:
            with pd.ExcelWriter(output_path, engine='xlsxwriter') as writer:
                summary_df.to_excel(writer, sheet_name='APD_Results', index=False)

                # Add plots
                workbook = writer.book
                worksheet = workbook.add_worksheet('Plots')

                img_row = 0
                for sample, roi, fig in plots[:MAX_EXPORTED_PLOTS]:
                    try:
                        buf = io.BytesIO()
                        fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
                        buf.seek(0)

                        worksheet.write(img_row, 0, f"{sample} | {roi}")
                        worksheet.insert_image(img_row + 1, 0, f"{sample}_{roi}.png",
                                               {'image_data': buf, 'x_scale': 0.7, 'y_scale': 0.7})
                        img_row += 35
                        plt.close(fig)
                    except Exception as e:
                        # One bad figure should not abort the rest of the export.
                        print(f"Plot export failed for {sample}-{roi}: {e}")

            print(f"📁 Results exported: {output_path}")

        except Exception as e:
            print(f"❌ Export failed: {e}")
            # Fallback to CSV
            csv_path = derive_output_path(excel_path, '_FIXED_Results', '.csv')
            summary_df.to_csv(csv_path, index=False)
            print(f"📁 CSV backup: {csv_path}")

    else:
        print("❌ Analysis failed completely")
        if failed_rois:
            print("Failed ROIs:")
            for sample, roi, reason in failed_rois[:5]:
                print(f"   • {sample}-{roi}: {reason}")

# Banner shown when this cell is executed: a summary of the fixes it contains
# and a reminder of the entry point.
_BANNER_LINES = (
    "\n🎉 FIXED PHOTOBLEACHING CORRECTION READY!",
    "Key improvements:",
    "  ✅ Fixed array length mismatches",
    "  ✅ Robust interpolation for missing data",
    "  ✅ Better baseline preservation",
    "  ✅ Enhanced error handling",
    "  ✅ Optimized for cardiolaminopathy studies",
    "\nTo use: Run run_fixed_analysis() with your Excel file path",
)
print(*_BANNER_LINES, sep="\n")

In [None]:
import os
import io

# === Export the batch workbook ===
# Runs after the analysis cell above and reuses its notebook globals:
#   excel_path  - path of the input workbook
#   summary_df  - per-event results table (None when the analysis failed)
#   plots       - list of (sample, roi, fig) tuples; this requires
#                 analyze_roi_signal(...) to return its Figure so the caller
#                 can collect plots.append((samp, roi, fig)) alongside the
#                 per-ROI result frames that are concatenated into summary_df.

# Output sits next to the input as "<name> batch analysis.xlsx". The extension
# is always .xlsx because the xlsxwriter engine only writes that format.
folder, basename = os.path.split(excel_path)
name, ext       = os.path.splitext(basename)
new_name        = f"{name} batch analysis.xlsx"
out_path        = os.path.join(folder, new_name)

if summary_df is None:
    # Nothing to export; the analysis cell has already printed its failure report.
    print("❌ No summary table to export - skipping batch workbook")
else:
    # === Write summary_df and plots into the new workbook ===
    with pd.ExcelWriter(out_path, engine='xlsxwriter') as writer:
        # 1) summary table
        summary_df.to_excel(writer, sheet_name='Summary', index=False)

        # 2) embed plots: a caption row, then the image, 31 rows per plot
        workbook  = writer.book
        worksheet = workbook.add_worksheet('Plots')

        img_row = 0
        for sample, roi, fig in plots:
            # Render first so one broken figure doesn't abort the whole
            # workbook (same per-plot handling as the analysis cell's export).
            try:
                buf = io.BytesIO()
                fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
                buf.seek(0)
            except Exception as e:
                print(f"Plot export failed for {sample}-{roi}: {e}")
                continue

            worksheet.write(img_row, 0, f"{sample} | {roi}")
            img_row += 1

            worksheet.insert_image(
                img_row, 0,
                f"{sample}_{roi}.png",
                {'image_data': buf, 'x_scale': 0.8, 'y_scale': 0.8}
            )
            img_row += 30

    print(f"✅ Saved batch analysis to:\n  {out_path}")

TO DO:
- Add APD50
- Add the APD50/APD90 ratio
