In [None]:
import pandas as pd
import obspy
from obspy import Stream, UTCDateTime
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import matplotlib.dates as mdates
from matplotlib.collections import LineCollection, PolyCollection
import importlib
from io import StringIO
import io
import picking_funcs_new as pf
from picking_funcs_new import (
    arrakis_read, das_read,
    arrakis_preprocess, das_preprocess,
    comb_trigger_simple, das_trigger_simple, AdvOpt, VetoOpt,
)
# importlib.reload(pf)
from typing import Sequence, Tuple, Union, Optional, Literal,List,Dict, Iterable,Mapping,Any
from pathlib import Path

import textwrap
import pyproj
from pyproj import Transformer
from operator import itemgetter
from sklearn.linear_model import RANSACRegressor, LinearRegression
from scipy.signal import find_peaks, hilbert,correlate

import itertools
import os
import datetime
from datetime import timezone, UTC
from scipy.fft import fft, ifft, fftfreq
from scipy.interpolate import pchip_interpolate
from scipy.optimize import curve_fit


# The first section of this notebook cannot be added until the dataset is made publicly available in February 2026. The notebook will be updated at that time.

In [None]:
def extract_event_traces(stream, df_picks, event_id, 
                         pre_pick=0.02, post_pick=0.06,
                         network='SQ', exclude_stations=('5-1',)):
    """
    Extract geophone traces for a single event, windowing each trace 
    individually around its own P arrival pick.
    
    Parameters:
    -----------
    stream : obspy.Stream
        Full stream containing all traces (geophones and DAS)
    df_picks : pandas.DataFrame
        DataFrame with columns: 'eventid', 'station', 'pick', 'origin'
        Times are stored as strings
    event_id : str or int
        Unique event identifier to extract
    pre_pick : float
        Time in seconds BEFORE each station's P pick to include
    post_pick : float
        Time in seconds AFTER each station's P pick to include
    network : str
        Network code passed to ``stream.select`` (default 'SQ')
    exclude_stations : iterable of str or None
        Station codes dropped from the picks before windowing
        (default: station '5-1', poor data quality)
        
    Returns:
    --------
    event_stream : obspy.Stream
        Stream containing windowed geophone traces, each windowed around its own pick
    pick_times : dict
        Dictionary mapping station codes to their P pick times (UTCDateTime).
        Contains every retained pick, including stations for which no trace
        (or no data in the window) was found.

    Raises:
    -------
    ValueError
        If the event has no picks, or no geophone picks after filtering.
    """
    
    # Filter picks for this specific event
    event_picks = df_picks[df_picks['eventid'] == event_id]
    
    if len(event_picks) == 0:
        raise ValueError(f"No picks found for event_id: {event_id}")
    
    # Geophone stations are identified by 3-character station names;
    # explicitly excluded stations are removed in the same pass.
    is_geophone = event_picks['station'].str.len() == 3
    is_excluded = event_picks['station'].isin(list(exclude_stations or ()))
    event_picks = event_picks[is_geophone & ~is_excluded]
    
    if len(event_picks) == 0:
        raise ValueError(f"No geophone picks found for event_id: {event_id}")
    
    # Get list of stations with picks for this event (order of first appearance)
    stations_with_picks = event_picks['station'].unique().tolist()
    
    print(f"Event {event_id}:")
    print(f"  Number of geophone stations with picks: {len(stations_with_picks)}")
    print(f"  Window: {pre_pick} s before to {post_pick} s after each pick")
    print(f"  Total window length: {pre_pick + post_pick} s")
    
    # Station -> pick time. If a station is listed more than once, the last row wins.
    pick_times = {
        station: UTCDateTime(pick)
        for station, pick in zip(event_picks['station'], event_picks['pick'])
    }
    
    # Extract and window each trace individually around its own pick
    event_stream = obspy.Stream()
    
    for station in stations_with_picks:
        pick_time = pick_times[station]
        
        # Window for THIS station
        window_start = pick_time - pre_pick
        window_end = pick_time + post_pick
        
        tr_select = stream.select(network=network, station=station)
        
        if len(tr_select) == 0:
            print(f"  Warning: No trace found for station {station}")
            continue
        
        if len(tr_select) > 1:
            print(f"  Warning: Multiple traces found for station {station}, using first")
        
        # Slice first, then copy: only the short window is duplicated, so the
        # returned trace does not keep a full-length copy of the input alive.
        tr_windowed = tr_select[0].slice(starttime=window_start,
                                         endtime=window_end).copy()
        
        # Check if we got valid data
        if tr_windowed.stats.npts == 0:
            print(f"  Warning: No data in window for station {station}")
            continue
        
        event_stream.append(tr_windowed)
    
    print(f"  Extracted {len(event_stream)} windowed traces")
    
    return event_stream, pick_times

In [None]:

# Demo: window the geophone traces of a single event.
# The event is chosen by row position in df_picks (other rows tried earlier: 10921, 9753).
first_event_id = df_picks['eventid'].iloc[9420]
print(f"Testing with event: {first_event_id}\n")

event_stream, pick_times = extract_event_traces(
    comb_stZ, df_picks, first_event_id, pre_pick=0.02, post_pick=0.06,
)

# Cross-check: stations present in the windowed stream vs. stations that have a pick
stations_in_stream = {tr.stats.station for tr in event_stream}
stations_in_picks = set(pick_times)

missing_pick = stations_in_stream.difference(stations_in_picks)
extra_pick = stations_in_picks.difference(stations_in_stream)

if missing_pick:
    print(f"  Stations in stream WITHOUT picks: {missing_pick}")
if extra_pick:
    print(f"  Stations with picks NOT in stream: {extra_pick}")

# Per-trace summary: sample count and sample interval
print("\nTraces in event_stream:")
for tr in event_stream:
    print(f"  {tr.stats.station}: {tr.stats.npts} samples, dt={tr.stats.delta:.4f} s")

# Pick time of every station retained for this event
print("\nPick times:")
for station, pick_time in pick_times.items():
    print(f"  {station}: {pick_time}")

In [None]:
def create_reference_waveform(event_stream, pick_times, 
                              snr_window_signal=0.02, snr_window_noise=0.02,
                              cc_window_length=0.03, max_lag_sec=0.01, 
                              normalize=True):
    """
    Create a reference waveform by aligning and stacking traces using cross-correlation.

    The trace with the highest SNR is used as the pilot. Every valid trace is
    cross-correlated with the pilot in a short window around its pick, shifted
    by the best lag within +/- max_lag_sec, optionally peak-normalized, and
    averaged into the stack.

    Note: a later cell in this notebook redefines this function with an extra
    return value (alignment_quality).
    
    Parameters:
    -----------
    event_stream : obspy.Stream
        Windowed stream for a single event (each trace windowed around its own pick)
    pick_times : dict
        Dictionary mapping station codes to P pick times (UTCDateTime)
    snr_window_signal : float
        Time window (s) after pick to calculate signal RMS
    snr_window_noise : float
        Time window (s) before pick to calculate noise RMS
    cc_window_length : float
        Length of window (s) to use for cross-correlation alignment
        Should be long enough to capture first arrival pulse only
    max_lag_sec : float
        Maximum allowed lag (in seconds) for cross-correlation alignment
    normalize : bool
        If True, divide each aligned trace by its peak absolute amplitude
        before stacking (all-zero traces are left unchanged)
        
    Returns:
    --------
    reference_trace : obspy.Trace
        Stacked and aligned reference waveform (stats copied from the pilot,
        station set to 'STACK')
    aligned_stream : obspy.Stream
        Stream of aligned individual traces (for QC)
    snr_dict : dict
        SNR values (signal RMS / noise RMS) for each station

    Raises:
    -------
    ValueError
        If the stream is empty or no trace passes the pick/edge checks.
    """
    
    if len(event_stream) == 0:
        raise ValueError("Empty stream provided")
    
    # Get sampling rate (assume all traces have same rate)
    sampling_rate = event_stream[0].stats.sampling_rate
    max_lag_samples = int(max_lag_sec * sampling_rate)
    cc_window_samples = int(cc_window_length * sampling_rate)
    
    print(f"\n  Max allowed lag: {max_lag_sec} s ({max_lag_samples} samples)")
    print(f"  Cross-correlation window: {cc_window_length} s ({cc_window_samples} samples)")
    
    # Calculate SNR for each trace
    snr_dict = {}
    valid_traces = []
    valid_pick_samples = []  # pick index of each entry in valid_traces
    
    for tr in event_stream:
        station = tr.stats.station
        
        # Skip if we don't have a pick for this station
        if station not in pick_times:
            print(f"  Warning: No pick time for station {station}, skipping")
            continue
        
        # Pick index relative to trace start. round() rather than int(): the
        # offset is a float that can sit just below a whole sample, and
        # truncation would then place the pick one sample early and trip the
        # edge check below.
        pick_sample = int(round((pick_times[station] - tr.stats.starttime)
                                * tr.stats.sampling_rate))
        
        noise_samples = int(snr_window_noise * tr.stats.sampling_rate)
        signal_samples = int(snr_window_signal * tr.stats.sampling_rate)
        
        # Make sure we have enough samples on both sides of the pick
        if pick_sample < noise_samples or pick_sample + signal_samples >= tr.stats.npts:
            print(f"  Warning: Pick too close to edge for {station}, skipping")
            continue
        
        # Noise: window before pick; signal: window after pick
        noise = tr.data[pick_sample - noise_samples : pick_sample]
        signal_win = tr.data[pick_sample : pick_sample + signal_samples]
        
        noise_rms = np.sqrt(np.mean(noise**2))
        signal_rms = np.sqrt(np.mean(signal_win**2))
        
        # Avoid division by zero
        if noise_rms > 0:
            snr = signal_rms / noise_rms
        else:
            snr = 0
            
        snr_dict[station] = snr
        valid_traces.append(tr.copy())
        valid_pick_samples.append(pick_sample)
    
    if len(valid_traces) == 0:
        raise ValueError("No valid traces with picks")
    
    print(f"\n  Valid traces for stacking: {len(valid_traces)}")
    
    # Select pilot trace (highest SNR; first matching trace if a station repeats)
    pilot_station = max(snr_dict, key=snr_dict.get)
    pilot_index = next(i for i, tr in enumerate(valid_traces)
                       if tr.stats.station == pilot_station)
    pilot_trace = valid_traces[pilot_index]
    pilot_pick_sample = valid_pick_samples[pilot_index]
    
    print(f"  Pilot trace: {pilot_station} (SNR = {snr_dict[pilot_station]:.1f})")
    
    # Template: window centered on the pilot pick, cc_window_length/2 on each side
    cc_half_window = cc_window_samples // 2
    pilot_cc_start = max(0, pilot_pick_sample - cc_half_window)
    pilot_cc_end = min(len(pilot_trace.data), pilot_pick_sample + cc_half_window)
    pilot_cc_window = pilot_trace.data[pilot_cc_start:pilot_cc_end]
    pilot_norm = np.linalg.norm(pilot_cc_window)
    
    # Align all traces to pilot using cross-correlation with constrained lag
    aligned_stream = obspy.Stream()
    aligned_data = []
    
    for tr, pick_sample in zip(valid_traces, valid_pick_samples):
        station = tr.stats.station
        
        # Extract window around pick for cross-correlation
        tr_cc_start = max(0, pick_sample - cc_half_window)
        tr_cc_end = min(len(tr.data), pick_sample + cc_half_window)
        tr_cc_window = tr.data[tr_cc_start:tr_cc_end]
        
        # Cross-correlate ONLY the windowed segments around the picks
        correlation = correlate(tr_cc_window, pilot_cc_window, mode='same')
        
        # Restrict search to +/- max_lag_samples around zero lag.
        # NOTE(review): taking len // 2 as the zero-lag index presumes both
        # windows have the same length (i.e. neither was clipped at a trace
        # edge) - confirm if short windows are ever used.
        center = len(correlation) // 2
        search_start = max(0, center - max_lag_samples)
        search_end = min(len(correlation), center + max_lag_samples + 1)
        
        # Find best lag within constrained window
        local_max_idx = np.argmax(correlation[search_start:search_end])
        lag_sample = (search_start + local_max_idx) - center
        
        # Normalized correlation at this lag (0 if either window has no energy)
        cc_max = correlation[search_start + local_max_idx]
        cc_denominator = np.linalg.norm(tr_cc_window) * pilot_norm
        cc_norm = cc_max / cc_denominator if cc_denominator > 0 else 0.0
        
        print(f"  {station}: SNR={snr_dict[station]:.1f}, lag={lag_sample} samples ({lag_sample/sampling_rate*1000:.1f} ms), CC={cc_norm:.3f}")
        
        # Shift the ENTIRE trace by the lag, zero-padding the vacated end
        if lag_sample > 0:
            aligned = np.pad(tr.data[lag_sample:], (0, lag_sample), 
                           mode='constant', constant_values=0)
        elif lag_sample < 0:
            aligned = np.pad(tr.data[:lag_sample], (-lag_sample, 0), 
                           mode='constant', constant_values=0)
        else:
            aligned = tr.data.copy()
        
        # Normalize if requested (skip all-zero traces to avoid 0/0 -> NaN)
        if normalize:
            peak = np.max(np.abs(aligned))
            if peak > 0:
                aligned = aligned / peak
        
        # Store aligned trace
        tr_aligned = tr.copy()
        tr_aligned.data = aligned
        aligned_stream.append(tr_aligned)
        aligned_data.append(aligned)
    
    # Stack aligned traces. Windows can differ by a sample, so the stack is
    # computed over the length common to all traces.
    common_len = min(len(data) for data in aligned_data)
    stacked_data = np.mean([data[:common_len] for data in aligned_data], axis=0)
    
    # Create reference trace (copy stats from pilot)
    reference_trace = pilot_trace.copy()
    reference_trace.data = stacked_data
    reference_trace.stats.station = 'STACK'
    
    print(f"\n  Reference waveform created from {len(aligned_data)} traces")
    
    return reference_trace, aligned_stream, snr_dict

In [None]:
# Stack the windowed traces into a reference waveform (three-value variant of the function).
# Alignment uses a 0.03 s window around each pick - 15 samples if the data are
# sampled at 500 Hz (TODO confirm sampling rate) - and lags of at most 0.01 s.
reference_trace, aligned_stream, snr_dict = create_reference_waveform(
    event_stream=event_stream,
    pick_times=pick_times,
    cc_window_length=0.03,
    max_lag_sec=0.01,
)

In [None]:
# QC figure: pick-centred windows before (top) and after (bottom) alignment
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(12, 8))

# Top panel - each raw window scaled by its own peak absolute amplitude
for tr in event_stream:
    axes[0].plot(tr.times(), tr.data / np.max(np.abs(tr.data)),
                 linewidth=0.5, alpha=0.5)
# The pick is expected pre_pick seconds into every window; 0.02 s must match
# the pre_pick value used when the windows were extracted.
axes[0].axvline(x=0.02, color='r', linestyle='--', label='Expected pick position')
axes[0].set(title='Raw Windowed Traces (Before Alignment)',
            ylabel='Normalized Amplitude')
axes[0].legend()

# Bottom panel - aligned traces, plotted exactly as returned by the alignment step
for tr in aligned_stream:
    axes[1].plot(tr.times(), tr.data, linewidth=0.5, alpha=0.5)
axes[1].set(title='After Cross-Correlation Alignment',
            xlabel='Time (s)',
            ylabel='Normalized Amplitude')

plt.tight_layout()
plt.show()

In [None]:
def create_reference_waveform(event_stream, pick_times, 
                              snr_window_signal=0.02, snr_window_noise=0.02,
                              cc_window_length=0.03, max_lag_sec=0.01,
                              min_cc_threshold=0.7, normalize=True):
    """
    Create a reference waveform by aligning and stacking traces using cross-correlation.

    The trace with the highest SNR is used as the pilot. Every valid trace is
    cross-correlated with the pilot in a short window around its pick, shifted
    by the best lag within +/- max_lag_sec, optionally peak-normalized, and
    averaged into the stack.
    
    Parameters:
    -----------
    event_stream : obspy.Stream
        Windowed stream for a single event (each trace windowed around its own pick)
    pick_times : dict
        Dictionary mapping station codes to P pick times (UTCDateTime)
    snr_window_signal : float
        Time window (s) after pick to calculate signal RMS
    snr_window_noise : float
        Time window (s) before pick to calculate noise RMS
    cc_window_length : float
        Length of window (s) to use for cross-correlation alignment
        Should capture ~1-2 cycles of the dominant frequency
        Default 0.03 s works well for 10-100 Hz bandpass
    max_lag_sec : float
        Maximum allowed lag (in seconds) for cross-correlation alignment
        Should be fraction of dominant period to avoid cycle skips
        Default 0.01 s = half cycle at 50 Hz
    min_cc_threshold : float
        Correlation coefficient (0-1) below which a station is flagged as
        poorly aligned in the printed log. Flagged stations are reported
        only; they are still included in the stack.
    normalize : bool
        If True, divide each aligned trace by its peak absolute amplitude
        before stacking (all-zero traces are left unchanged)
        
    Returns:
    --------
    reference_trace : obspy.Trace
        Stacked and aligned reference waveform (stats copied from the pilot,
        station set to 'STACK')
    aligned_stream : obspy.Stream
        Stream of aligned individual traces (for QC)
    snr_dict : dict
        SNR values (signal RMS / noise RMS) for each station
    alignment_quality : dict
        Normalized correlation coefficient of each station's alignment

    Raises:
    -------
    ValueError
        If the stream is empty or no trace passes the pick/edge checks.
    """
    
    if len(event_stream) == 0:
        raise ValueError("Empty stream provided")
    
    # Get sampling rate (assume all traces have same rate)
    sampling_rate = event_stream[0].stats.sampling_rate
    cc_window_samples = int(cc_window_length * sampling_rate)
    max_lag_samples = int(max_lag_sec * sampling_rate)
    
    print(f"\n  CC window: {cc_window_length*1000:.1f} ms ({cc_window_samples} samples)")
    print(f"  Max lag: {max_lag_sec*1000:.1f} ms ({max_lag_samples} samples)")
    
    # Calculate SNR for each trace
    snr_dict = {}
    valid_traces = []
    valid_pick_samples = []  # pick index of each entry in valid_traces
    
    for tr in event_stream:
        station = tr.stats.station
        
        # Skip if we don't have a pick for this station
        if station not in pick_times:
            print(f"  Warning: No pick time for station {station}, skipping")
            continue
        
        # Pick index relative to trace start. round() rather than int(): the
        # offset is a float that can sit just below a whole sample, and
        # truncation would then place the pick one sample early and trip the
        # edge check below.
        pick_sample = int(round((pick_times[station] - tr.stats.starttime)
                                * tr.stats.sampling_rate))
        
        noise_samples = int(snr_window_noise * tr.stats.sampling_rate)
        signal_samples = int(snr_window_signal * tr.stats.sampling_rate)
        
        # Make sure we have enough samples on both sides of the pick
        if pick_sample < noise_samples or pick_sample + signal_samples >= tr.stats.npts:
            print(f"  Warning: Pick too close to edge for {station}, skipping")
            continue
        
        # Noise: window before pick; signal: window after pick
        noise = tr.data[pick_sample - noise_samples : pick_sample]
        signal_win = tr.data[pick_sample : pick_sample + signal_samples]
        
        noise_rms = np.sqrt(np.mean(noise**2))
        signal_rms = np.sqrt(np.mean(signal_win**2))
        
        # Avoid division by zero
        if noise_rms > 0:
            snr = signal_rms / noise_rms
        else:
            snr = 0
            
        snr_dict[station] = snr
        valid_traces.append(tr.copy())
        valid_pick_samples.append(pick_sample)
    
    if len(valid_traces) == 0:
        raise ValueError("No valid traces with picks")
    
    print(f"  Valid traces for stacking: {len(valid_traces)}")
    
    # Select pilot trace (highest SNR; first matching trace if a station repeats)
    pilot_station = max(snr_dict, key=snr_dict.get)
    pilot_index = next(i for i, tr in enumerate(valid_traces)
                       if tr.stats.station == pilot_station)
    pilot_trace = valid_traces[pilot_index]
    pilot_pick_sample = valid_pick_samples[pilot_index]
    
    print(f"  Pilot trace: {pilot_station} (SNR = {snr_dict[pilot_station]:.1f})")
    
    # Template: window centered on the pilot pick, cc_window_length/2 on each side
    cc_half_window = cc_window_samples // 2
    pilot_cc_start = max(0, pilot_pick_sample - cc_half_window)
    pilot_cc_end = min(len(pilot_trace.data), pilot_pick_sample + cc_half_window)
    pilot_cc_window = pilot_trace.data[pilot_cc_start:pilot_cc_end]
    pilot_norm = np.linalg.norm(pilot_cc_window)
    
    # Align all traces to pilot using cross-correlation
    aligned_stream = obspy.Stream()
    aligned_data = []
    alignment_quality = {}
    low_cc_stations = []
    
    for tr, pick_sample in zip(valid_traces, valid_pick_samples):
        station = tr.stats.station
        
        # Extract window around pick for cross-correlation
        tr_cc_start = max(0, pick_sample - cc_half_window)
        tr_cc_end = min(len(tr.data), pick_sample + cc_half_window)
        tr_cc_window = tr.data[tr_cc_start:tr_cc_end]
        
        # Cross-correlate ONLY the windowed segments around the picks
        correlation = correlate(tr_cc_window, pilot_cc_window, mode='same')
        
        # Restrict search to +/- max_lag_samples around zero lag.
        # NOTE(review): taking len // 2 as the zero-lag index presumes both
        # windows have the same length (i.e. neither was clipped at a trace
        # edge) - confirm if short windows are ever used.
        center = len(correlation) // 2
        search_start = max(0, center - max_lag_samples)
        search_end = min(len(correlation), center + max_lag_samples + 1)
        
        # Find best lag within constrained window
        local_max_idx = np.argmax(correlation[search_start:search_end])
        lag_sample = (search_start + local_max_idx) - center
        
        # Normalized correlation at this lag (0 if either window has no energy)
        cc_max = correlation[search_start + local_max_idx]
        cc_denominator = np.linalg.norm(tr_cc_window) * pilot_norm
        cc_norm = cc_max / cc_denominator if cc_denominator > 0 else 0.0
        alignment_quality[station] = cc_norm
        
        # Flag low correlation coefficients (reporting only)
        if cc_norm < min_cc_threshold:
            low_cc_stations.append(station)
            flag = " [LOW CC!]"
        else:
            flag = ""
        
        print(f"  {station}: SNR={snr_dict[station]:.1f}, lag={lag_sample} samples ({lag_sample/sampling_rate*1000:.1f} ms), CC={cc_norm:.3f}{flag}")
        
        # Shift the ENTIRE trace by the lag, zero-padding the vacated end
        if lag_sample > 0:
            aligned = np.pad(tr.data[lag_sample:], (0, lag_sample), 
                           mode='constant', constant_values=0)
        elif lag_sample < 0:
            aligned = np.pad(tr.data[:lag_sample], (-lag_sample, 0), 
                           mode='constant', constant_values=0)
        else:
            aligned = tr.data.copy()
        
        # Normalize if requested (skip all-zero traces to avoid 0/0 -> NaN)
        if normalize:
            peak = np.max(np.abs(aligned))
            if peak > 0:
                aligned = aligned / peak
        
        # Store aligned trace
        tr_aligned = tr.copy()
        tr_aligned.data = aligned
        aligned_stream.append(tr_aligned)
        aligned_data.append(aligned)
    
    if low_cc_stations:
        print(f"\n  WARNING: {len(low_cc_stations)} stations with CC < {min_cc_threshold}: {low_cc_stations}")
    
    # Stack aligned traces. Windows can differ by a sample, so the stack is
    # computed over the length common to all traces.
    common_len = min(len(data) for data in aligned_data)
    stacked_data = np.mean([data[:common_len] for data in aligned_data], axis=0)
    
    # Create reference trace (copy stats from pilot)
    reference_trace = pilot_trace.copy()
    reference_trace.data = stacked_data
    reference_trace.stats.station = 'STACK'
    
    print(f"\n  Reference waveform created from {len(aligned_data)} traces")
    print(f"  Mean CC: {np.mean(list(alignment_quality.values())):.3f}")
    
    return reference_trace, aligned_stream, snr_dict, alignment_quality

In [None]:
reference_trace, aligned_stream, snr_dict, alignment_quality = create_reference_waveform(
    event_stream, 
    pick_times,
    snr_window_signal=0.02, 
    snr_window_noise=0.02,
    cc_window_length=0.03, 
    max_lag_sec=0.01,
    min_cc_threshold=0.7, 
    normalize=True
)

In [None]:
# Pick the two events to compare by row position in df_picks.
# .iloc is positional, so these indices depend on the current row order of df_picks.
first_event_id = df_picks['eventid'].iloc[9420]  
second_event_id = df_picks['eventid'].iloc[10921]  # event identifier is read from 'eventid' (not 'origin')

In [None]:
def get_reference_waveform_for_event(stream, df_picks, event_id, 
                                      pre_pick=0.02, post_pick=0.06,
                                      cc_window_length=0.03, max_lag_sec=0.01):
    """
    Build the stacked reference waveform of one event in a single call.

    Chains extract_event_traces (per-station windowing around the picks) and
    the four-value create_reference_waveform (alignment and stacking), and
    bundles the intermediate products for quality control.
    
    Parameters:
    -----------
    stream : obspy.Stream
        Full stream containing all data
    df_picks : pandas.DataFrame
        DataFrame with picks
    event_id : str or int
        Event identifier
    pre_pick : float
        Time (s) before pick to window
    post_pick : float
        Time (s) after pick to window
    cc_window_length : float
        CC window length for alignment (s)
    max_lag_sec : float
        Maximum lag for alignment (s)
        
    Returns:
    --------
    reference_trace : obspy.Trace
        Reference waveform for this event
    event_info : dict
        QC bundle with keys 'event_stream', 'aligned_stream', 'snr_dict',
        'alignment_quality', 'pick_times', 'mean_snr' and 'mean_cc'
    """
    
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"Processing Event: {event_id}")
    print(banner)
    
    # Step 1: pick-centred windows for every geophone with a pick
    windowed_stream, station_picks = extract_event_traces(
        stream, df_picks, event_id, pre_pick=pre_pick, post_pick=post_pick,
    )
    
    # Step 2: align to the pilot trace and stack
    stack, aligned, snr_by_station, cc_by_station = create_reference_waveform(
        windowed_stream, station_picks,
        cc_window_length=cc_window_length, max_lag_sec=max_lag_sec,
    )
    
    # Step 3: collect everything needed for QC, plus two summary averages
    event_info = dict(
        event_stream=windowed_stream,
        aligned_stream=aligned,
        snr_dict=snr_by_station,
        alignment_quality=cc_by_station,
        pick_times=station_picks,
        mean_snr=np.mean(list(snr_by_station.values())),
        mean_cc=np.mean(list(cc_by_station.values())),
    )
    
    return stack, event_info

In [None]:
# Events to compare, selected by row position in df_picks
first_event_id = df_picks['eventid'].iloc[9420]
second_event_id = df_picks['eventid'].iloc[10921]

# One stacked reference waveform (plus QC info) per event
ref1, info1 = get_reference_waveform_for_event(comb_stZ, df_picks, first_event_id)
ref2, info2 = get_reference_waveform_for_event(comb_stZ, df_picks, second_event_id)

# Overlay the two stacks for a quick visual comparison
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(ref1.times(), ref1.data, 'b', linewidth=2, label=f'Event {first_event_id}')
ax.plot(ref2.times(), ref2.data, 'r', linewidth=2, label=f'Event {second_event_id}')
ax.set(xlabel='Time (s)',
       ylabel='Normalized Amplitude',
       title='Reference Waveforms Comparison')
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
def compare_waveforms_stretching(ref_trace1, ref_trace2, event_id1, event_id2,
                                 stretch_max=0.10, stretch_step=0.01,
                                 pulse_window_start=0.0, pulse_window_end=0.03):
    """
    Compare two reference waveforms using trace stretching method.
    Stretches ref_trace2 (comparison event) to match ref_trace1 (reference event).
    Only stretches a specified window capturing the first P-wave pulse.

    For each trial stretch ε the comparison pulse is resampled at sample
    positions i * (1 + ε) (pchip interpolation) and correlated with the
    reference pulse. Only positions that fall inside the comparison pulse are
    used, so no extrapolated values enter the correlation.
    
    Parameters:
    -----------
    ref_trace1 : obspy.Trace
        Reference waveform for REFERENCE event (typically early event)
    ref_trace2 : obspy.Trace
        Reference waveform for COMPARISON event (typically later event).
        Sample interval and sampling rate are read from ref_trace1 only;
        both traces are presumably sampled identically - verify against caller.
    event_id1 : str or int
        Identifier for reference event
    event_id2 : str or int
        Identifier for comparison event
    stretch_max : float
        Maximum stretch factor to test (e.g., 0.10 = ±10%)
    stretch_step : float
        Step size for stretch factor (e.g., 0.01 = 1% increments)
    pulse_window_start : float
        Start time (s) of pulse window to analyze (relative to trace start)
    pulse_window_end : float
        End time (s) of pulse window to analyze (relative to trace start)
        
    Returns:
    --------
    results : dict
        Dictionary containing:
        - 'epsilon': optimal stretch factor
            ε > 0: comparison event pulse is LONGER (lower fc)
            ε < 0: comparison event pulse is SHORTER (higher fc)
        - 'cc': correlation coefficient at optimal stretch
        - 'eps_array': array of tested stretch values (integer multiples of
          stretch_step, symmetric about 0 and never beyond ±stretch_max)
        - 'cc_array': array of correlation coefficients
        - 'event_id1', 'event_id2': the identifiers passed in
        - 'pulse_window': tuple of (start, end) times used
        - 'u1_pulse', 'u2_pulse': the windowed reference / comparison pulses

    Raises:
    -------
    ValueError
        If stretch_step <= 0, stretch_max < 0, pulse_window_start < 0, or the
        pulse window contains fewer than 2 samples.
    """
    
    if stretch_step <= 0 or stretch_max < 0:
        raise ValueError("stretch_step must be > 0 and stretch_max must be >= 0")
    if pulse_window_start < 0:
        raise ValueError("pulse_window_start must be >= 0")
    
    print(f"\n{'='*60}")
    print(f"Comparing Events (Trace Stretching)")
    print(f"{'='*60}")
    print(f"  Reference event: {event_id1}")
    print(f"  Comparison event: {event_id2}")
    print(f"  Method: Stretch Event {event_id2} to match Event {event_id1}")
    
    dt = ref_trace1.stats.delta
    sampling_rate = ref_trace1.stats.sampling_rate
    
    # Convert time window to sample indices
    start_sample = int(pulse_window_start * sampling_rate)
    end_sample = int(pulse_window_end * sampling_rate)
    
    u1_full = ref_trace1.data
    u2_full = ref_trace2.data
    
    # Make sure we don't exceed either trace length
    end_sample = min(end_sample, len(u1_full), len(u2_full))
    
    u1 = u1_full[start_sample:end_sample]  # Reference (not stretched)
    u2 = u2_full[start_sample:end_sample]  # This will be stretched
    
    n_pulse = len(u1)
    if n_pulse < 2:
        raise ValueError(
            f"Pulse window {pulse_window_start}-{pulse_window_end} s contains "
            f"{n_pulse} samples; at least 2 are required"
        )
    
    pulse_duration = n_pulse * dt
    
    print(f"  Full trace length: {len(u1_full)} samples ({len(u1_full) * dt:.3f} s)")
    print(f"  Pulse window: {pulse_window_start:.3f} - {pulse_window_end:.3f} s")
    print(f"  Pulse samples: {len(u1)} samples ({pulse_duration:.3f} s)")
    print(f"  Sampling rate: {sampling_rate} Hz")
    print(f"  Stretch range: ±{stretch_max*100:.0f}%")
    print(f"  Stretch step: {stretch_step*100:.1f}%")
    
    # Trial stretch factors as integer multiples of the step. Unlike
    # np.arange with a float step, this cannot overshoot stretch_max through
    # rounding, and the zero-stretch entry is exactly 0.
    n_steps = int(np.floor(stretch_max / stretch_step + 1e-9))
    eps_array = np.arange(-n_steps, n_steps + 1) * stretch_step
    cc_array = np.zeros(len(eps_array))
    
    sample_idx = np.arange(n_pulse, dtype=float)
    last_idx = sample_idx[-1]
    
    # Test each stretch factor - stretching u2 (comparison) to match u1 (reference)
    print(f"\n  Testing {len(eps_array)} stretch values...")
    for i, eps in enumerate(eps_array):
        # Positions in u2 that are mapped onto the samples of u1
        t_stretched = sample_idx * (1 + eps)
        
        # Keep only positions inside u2; outside it pchip would extrapolate
        in_range = (t_stretched >= 0) & (t_stretched <= last_idx)
        if np.count_nonzero(in_range) < 2:
            cc_array[i] = np.nan
            continue
        
        u2_stretched = pchip_interpolate(sample_idx, u2, t_stretched[in_range], axis=0)
        
        # Correlation between reference (u1) and stretched comparison (u2)
        cc_array[i] = np.corrcoef(u1[in_range], u2_stretched)[0, 1]
    
    # Optimal stretch = maximum correlation, ignoring undefined (NaN) entries.
    # If every entry is NaN (e.g. a constant pulse), report zero stretch.
    if np.any(np.isfinite(cc_array)):
        best_idx = int(np.nanargmax(cc_array))
    else:
        best_idx = n_steps
    best_eps = eps_array[best_idx]
    best_cc = cc_array[best_idx]
    
    print(f"\n  Optimal stretch: {best_eps*100:.2f}% (ε = {best_eps:.4f})")
    print(f"  Max correlation: {best_cc:.4f}")
    
    # Interpretation of the sign of ε
    if abs(best_eps) < 0.01:
        interp = "Events have very similar pulse durations"
    elif best_eps > 0:
        interp = f"Comparison event pulse is ~{abs(best_eps)*100:.1f}% LONGER (lower fc)"
    else:
        interp = f"Comparison event pulse is ~{abs(best_eps)*100:.1f}% SHORTER (higher fc)"
    print(f"  Interpretation: {interp}")
    
    results = {
        'epsilon': best_eps,
        'cc': best_cc,
        'eps_array': eps_array,
        'cc_array': cc_array,
        'event_id1': event_id1,
        'event_id2': event_id2,
        'pulse_window': (pulse_window_start, pulse_window_end),
        'u1_pulse': u1,  # Reference pulse
        'u2_pulse': u2   # Comparison pulse
    }
    
    return results


def plot_stretching_results(ref_trace1, ref_trace2, results):
    """
    Draw a three-panel summary of a trace-stretching comparison.

    Panels, top to bottom: the full reference waveforms with the pulse window
    shaded, the two pulse windows on their own, and the correlation
    coefficient as a function of stretch with the optimum marked.
    
    Parameters:
    -----------
    ref_trace1, ref_trace2 : obspy.Trace
        Reference waveforms
    results : dict
        Results from compare_waveforms_stretching()
    """
    
    label_1 = f"Event {results['event_id1']}"
    label_2 = f"Event {results['event_id2']}"
    window_start, window_end = results['pulse_window'][0], results['pulse_window'][1]
    eps_percent = results['eps_array'] * 100
    best_eps_percent = results['epsilon'] * 100
    
    fig, axes = plt.subplots(3, 1, figsize=(12, 10))
    ax_full, ax_pulse, ax_cc = axes[0], axes[1], axes[2]
    
    # --- Top: full waveforms, with the stretching window shaded ---
    ax_full.plot(ref_trace1.times(), ref_trace1.data, 'b', label=label_1, linewidth=2)
    ax_full.plot(ref_trace2.times(), ref_trace2.data, 'r', label=label_2,
                 linewidth=2, alpha=0.7)
    ax_full.axvspan(window_start, window_end,
                    alpha=0.2, color='green', label='Pulse window')
    ax_full.set_xlabel('Time (s)')
    ax_full.set_ylabel('Normalized Amplitude')
    ax_full.legend()
    ax_full.set_title('Full Reference Waveforms (pulse window highlighted)')
    ax_full.grid(True, alpha=0.3)
    
    # --- Middle: the pulse windows only, on a time axis offset by the window start ---
    sample_interval = ref_trace1.stats.delta
    pulse_times = np.arange(len(results['u1_pulse'])) * sample_interval + window_start
    ax_pulse.plot(pulse_times, results['u1_pulse'], 'b', label=label_1, linewidth=2)
    ax_pulse.plot(pulse_times, results['u2_pulse'], 'r', label=label_2,
                  linewidth=2, alpha=0.7)
    ax_pulse.set_xlabel('Time (s)')
    ax_pulse.set_ylabel('Normalized Amplitude')
    ax_pulse.legend()
    ax_pulse.set_title('P-Wave Pulse Windows (used for stretching)')
    ax_pulse.grid(True, alpha=0.3)
    
    # --- Bottom: correlation vs. stretch, optimum marked in red ---
    ax_cc.plot(eps_percent, results['cc_array'], 'k-', linewidth=2)
    ax_cc.axvline(best_eps_percent, color='r', linestyle='--',
                  label=f"Best ε = {best_eps_percent:.2f}%", linewidth=2)
    ax_cc.axhline(results['cc'], color='r', linestyle=':', alpha=0.5)
    ax_cc.set_xlabel('Stretch Factor (%)')
    ax_cc.set_ylabel('Correlation Coefficient')
    ax_cc.set_title(f"Trace Stretching: ε = {best_eps_percent:.2f}%, CC = {results['cc']:.4f}")
    ax_cc.legend()
    ax_cc.grid(True, alpha=0.3)
    ax_cc.set_xlim(results['eps_array'].min()*100, results['eps_array'].max()*100)
    
    plt.tight_layout()
    plt.show()

In [None]:
def calculate_corner_frequency(trace, method='brune', f_min=10, f_max=100, 
                                taper_percent=0.05, zero_pad_factor=4):
    """
    Estimate a corner frequency from a trace's amplitude spectrum.

    A copy of the trace is cosine-tapered, zero-padded and transformed with a
    real FFT. The estimate is then read from the [f_min, f_max] band:

    * 'brune': the band is Gaussian-smoothed, the plateau Ω₀ is taken as the
      median of the lowest ~20% of the band (at least 3 points), and fc is the
      frequency whose smoothed amplitude is closest to Ω₀/√2. No model is
      actually least-squares fitted.
      NOTE(review): for A(f) = Ω₀ / (1 + (f/fc)²) the amplitude at fc is Ω₀/2;
      the Ω₀/√2 level used here is the -3 dB point. Confirm which is intended.
    * 'peak': fc is the frequency of the largest amplitude in the band.

    Progress and diagnostic information is printed to stdout.

    Parameters:
    -----------
    trace : obspy.Trace
        Input trace (preferably full reference waveform, not just pulse).
        Only .copy(), .taper(), .data and .stats (npts, delta, sampling_rate)
        are used; the input itself is not modified.
    method : str
        Method for corner frequency estimation ('brune' or 'peak')
    f_min : float
        Minimum frequency (Hz) of the analysis band
    f_max : float
        Maximum frequency (Hz) of the analysis band
    taper_percent : float
        Passed to trace.taper as max_percentage
    zero_pad_factor : int
        Padded length is npts * zero_pad_factor (must be >= 1)

    Returns:
    --------
    fc : float
        Corner frequency estimate (Hz)
    fit_info : dict
        Keys 'freqs', 'amp_spectrum', 'freqs_fit', 'amp_fit', 'method' and
        'freq_resolution'; the 'brune' method adds 'amp_smooth' and 'omega_0'.

    Raises:
    -------
    ValueError
        If method is not 'brune' or 'peak', if zero_pad_factor < 1, or if no
        spectral sample falls inside [f_min, f_max].
    """

    # Fail fast on a bad method name; otherwise the whole spectrum would be
    # computed and the function would only fail at the final return statement.
    if method not in ('brune', 'peak'):
        raise ValueError(f"Unknown method {method!r}; expected 'brune' or 'peak'")
    if zero_pad_factor < 1:
        raise ValueError(f"zero_pad_factor must be >= 1, got {zero_pad_factor}")

    # Taper a copy of the trace to reduce edge effects
    tr = trace.copy()
    tr.taper(max_percentage=taper_percent, type='cosine')

    # Get trace parameters
    npts = tr.stats.npts
    dt = tr.stats.delta
    sampling_rate = tr.stats.sampling_rate

    print("    Trace info:")
    print(f"      Length: {npts} samples ({npts*dt:.4f} s)")
    print(f"      Sampling rate: {sampling_rate} Hz")

    # Zero-pad at the end to get a finer frequency grid
    npts_padded = int(npts * zero_pad_factor)
    data_padded = np.pad(tr.data, (0, npts_padded - npts), mode='constant')

    # One-sided spectrum (named `spectrum` so it does not shadow the
    # module-level `fft` imported from scipy.fft)
    spectrum = np.fft.rfft(data_padded)
    freqs = np.fft.rfftfreq(npts_padded, dt)
    amp_spectrum = np.abs(spectrum)

    freq_resolution = 1.0 / (npts_padded * dt)
    print(f"      Frequency resolution: {freq_resolution:.2f} Hz")
    print(f"      Number of frequency points: {len(freqs)}")

    # Restrict to the analysis band
    freq_mask = (freqs >= f_min) & (freqs <= f_max)
    freqs_fit = freqs[freq_mask]
    amp_fit = amp_spectrum[freq_mask]

    print(f"      Points in fit band ({f_min}-{f_max} Hz): {len(freqs_fit)}")

    if len(freqs_fit) == 0:
        raise ValueError(
            f"No spectral samples between {f_min} and {f_max} Hz "
            f"(Nyquist is {0.5 / dt:.1f} Hz); cannot estimate a corner frequency"
        )
    if len(freqs_fit) < 10:
        print("      WARNING: Very few frequency points! Consider longer trace.")

    if method == 'brune':
        # Local import: scipy.ndimage is only needed for this branch
        from scipy.ndimage import gaussian_filter1d

        # Smooth the band for a more stable crossing estimate
        if len(amp_fit) > 5:
            sigma = max(1, len(amp_fit) // 20)  # Adaptive smoothing
            amp_smooth = gaussian_filter1d(amp_fit, sigma=sigma)
        else:
            amp_smooth = amp_fit

        # Low-frequency plateau (Ω₀): median of the lowest 20% of the band
        n_low = max(3, len(amp_smooth) // 5)
        omega_0 = np.median(amp_smooth[:n_low])

        # Band sample whose smoothed amplitude is closest to Ω₀/√2
        target_amp = omega_0 / np.sqrt(2)
        idx = np.argmin(np.abs(amp_smooth - target_amp))
        fc = freqs_fit[idx]

        # An estimate pinned to the band edge is probably not a real corner
        if fc >= f_max * 0.95:
            print("      WARNING: Corner frequency at upper bound! May be >f_max")
        if fc <= f_min * 1.05:
            print("      WARNING: Corner frequency at lower bound! May be <f_min")

        print("    Brune model fit:")
        print(f"      Ω₀ (low-freq plateau): {omega_0:.2e}")
        print(f"      Corner frequency: {fc:.1f} Hz")

        fit_info = {
            'freqs': freqs,
            'amp_spectrum': amp_spectrum,
            'freqs_fit': freqs_fit,
            'amp_fit': amp_fit,
            'amp_smooth': amp_smooth,
            'omega_0': omega_0,
            'method': 'brune',
            'freq_resolution': freq_resolution
        }

    else:  # method == 'peak' (validated above)
        # Alternative: use the peak frequency of the band as a proxy
        idx_peak = np.argmax(amp_fit)
        fc = freqs_fit[idx_peak]

        print("    Peak frequency method:")
        print(f"      Peak frequency: {fc:.1f} Hz")

        fit_info = {
            'freqs': freqs,
            'amp_spectrum': amp_spectrum,
            'freqs_fit': freqs_fit,
            'amp_fit': amp_fit,
            'method': 'peak',
            'freq_resolution': freq_resolution
        }

    return fc, fit_info

def plot_spectrum_with_fc(trace, fc, fit_info, event_id):
    """
    Draw a log-log amplitude spectrum and mark the corner frequency on it.

    Parameters:
    -----------
    trace : obspy.Trace
        Input trace (accepted for call symmetry; not read by this function)
    fc : float
        Corner frequency (Hz), drawn as a vertical dashed line
    fit_info : dict
        Second return value of calculate_corner_frequency. Reads 'freqs',
        'amp_spectrum', 'freqs_fit', 'amp_fit' and 'method'; for the 'brune'
        method also 'amp_smooth' and 'omega_0'.
    event_id : str/int
        Event identifier used in the title
    """

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    # Background curve (whole spectrum) and foreground curve (analysis band)
    base_curves = (
        ('freqs', 'amp_spectrum', 'gray',
         {'alpha': 0.3, 'linewidth': 0.5, 'label': 'Full spectrum'}),
        ('freqs_fit', 'amp_fit', 'k',
         {'linewidth': 1.5, 'label': 'Fit region'}),
    )
    for x_key, y_key, fmt, line_style in base_curves:
        ax.loglog(fit_info[x_key], fit_info[y_key], fmt, **line_style)

    # Extras that only exist for the plateau-based estimate
    if fit_info['method'] == 'brune':
        ax.loglog(fit_info['freqs_fit'], fit_info['amp_smooth'], 'b',
                  linewidth=2, label='Smoothed')

        # Plateau level Ω₀
        ax.axhline(fit_info['omega_0'], color='green', linestyle='--',
                   alpha=0.5, label=f"Ω₀ = {fit_info['omega_0']:.2e}")

        # Level at which the corner frequency was read off
        crossing_level = fit_info['omega_0'] / np.sqrt(2)
        ax.axhline(crossing_level, color='orange',
                   linestyle=':', alpha=0.5, label="Ω₀/√2")

    # Corner frequency marker
    ax.axvline(fc, color='red', linestyle='--', linewidth=2,
               label=f"fc = {fc:.1f} Hz")

    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Amplitude Spectrum')
    ax.set_title(f'Amplitude Spectrum - Event {event_id}')
    ax.legend()
    ax.grid(True, alpha=0.3, which='both')
    ax.set_xlim([5, 150])

    plt.tight_layout()
    plt.show()

In [None]:
# Quick check: stretch ref2 against ref1 over a 0.01-0.04 s pulse window and
# compare the implied corner-frequency ratio with the spectral one.
# NOTE(review): within the visible part of the notebook, ref1/ref2,
# first_event_id/second_event_id and fc1/fc2 are only assigned in the cell
# below. Unless they are also defined earlier, this cell fails on a fresh
# kernel (Restart & Run All) - confirm the intended cell order.
results = compare_waveforms_stretching(
    ref1, ref2,  # ref1 is reference, ref2 is stretched
    first_event_id, second_event_id,
    stretch_max=0.10,
    stretch_step=0.01,
    pulse_window_start=0.01,
    pulse_window_end=0.04
)

# A pulse stretched by (1 + epsilon) implies fc2/fc1 = 1 / (1 + epsilon);
# this should agree with the ratio of the spectral corner frequencies.
fc_ratio_spectral = fc2 / fc1
fc_ratio_stretching = 1 / (1 + results['epsilon'])

print(f"\nCorner frequency ratio (from spectra): {fc_ratio_spectral:.3f}")
print(f"Corner frequency ratio (from stretching): {fc_ratio_stretching:.3f}")

In [None]:
# End-to-end comparison of two events: stretching vs. spectral corner frequency.

# STEP 1: Get event IDs
# NOTE: .iloc is positional over the rows of df_picks, so 9420 and 10921 are
# row positions in the picks table, not event numbers.
first_event_id = df_picks['eventid'].iloc[9420]  
second_event_id = df_picks['eventid'].iloc[10921]

# STEP 2: Create reference waveforms
ref1, info1 = get_reference_waveform_for_event(comb_stZ, df_picks, first_event_id)
ref2, info2 = get_reference_waveform_for_event(comb_stZ, df_picks, second_event_id)

# STEP 3: Run stretching analysis (ref1 is the reference, ref2 is compared to it)
results = compare_waveforms_stretching(
    ref1, ref2,
    first_event_id, second_event_id,
    stretch_max=0.10,
    stretch_step=0.01,
    pulse_window_start=0.01,
    pulse_window_end=0.048
)

# STEP 4: Plot stretching results
plot_stretching_results(ref1, ref2, results)

# STEP 5: Calculate corner frequencies from full reference traces
fc1, fit1 = calculate_corner_frequency(ref1, method='brune', f_min=15, f_max=80)
fc2, fit2 = calculate_corner_frequency(ref2, method='brune', f_min=15, f_max=80)

# STEP 6: Compare the two methods
# Stretching by (1 + epsilon) implies fc2/fc1 = 1 / (1 + epsilon).
fc_ratio_spectral = fc2 / fc1
fc_ratio_stretching = 1 / (1 + results['epsilon'])

print(f"\n{'='*60}")
print(f"COMPARISON: SPECTRAL vs STRETCHING")
print(f"{'='*60}")
print(f"Event 1 (reference): fc = {fc1:.1f} Hz")
print(f"Event 2 (comparison): fc = {fc2:.1f} Hz")
print(f"\nCorner frequency ratio (from spectra): {fc_ratio_spectral:.3f}")
print(f"Corner frequency ratio (from stretching): {fc_ratio_stretching:.3f}")
print(f"Difference: {abs(fc_ratio_spectral - fc_ratio_stretching):.3f}")
print(f"Percent difference: {abs(fc_ratio_spectral - fc_ratio_stretching)/fc_ratio_spectral * 100:.1f}%")

# The two estimates are treated as consistent when they differ by less than
# 15% relative to the spectral ratio.
if abs(fc_ratio_spectral - fc_ratio_stretching) / fc_ratio_spectral < 0.15:
    print("\n✓ Good agreement! Stretching analysis is valid.")
else:
    print("\n✗ Poor agreement. Investigate further.")

In [None]:
def calculate_stress_drop_ratio(epsilon, fc_ratio=None):
    """
    Calculate relative stress drop from stretch factor or corner frequency ratio.

    Uses the scaling assumed throughout this notebook:
        Δσ ∝ fc³   and   fc2/fc1 = 1 / (1 + ε)
    so that Δσ2/Δσ1 = (fc2/fc1)³. A human-readable report is printed.

    Parameters:
    -----------
    epsilon : float
        Stretch factor from trace stretching; must be greater than -1
        ε > 0: comparison event has longer pulse (lower fc, lower stress drop)
        ε < 0: comparison event has shorter pulse (higher fc, higher stress drop)
    fc_ratio : float, optional
        Corner frequency ratio (fc2/fc1) from spectral analysis
        If provided, also calculates stress drop ratio from spectra

    Returns:
    --------
    results : dict
        Dictionary containing:
        - 'stress_drop_ratio_stretch': Δσ2/Δσ1 from stretching
        - 'fc_ratio_stretch': fc2/fc1 from stretching
        - 'stress_drop_ratio_spectral': Δσ2/Δσ1 from spectra (if fc_ratio provided)
        - 'fc_ratio_spectral': fc2/fc1 from spectra (if provided)

    Raises:
    -------
    ValueError
        If epsilon <= -1: the stretched pulse would have zero or negative
        duration, and 1/(1+ε) would divide by zero or turn negative.
    """

    # NaN deliberately passes this check and propagates as NaN ratios.
    if epsilon <= -1.0:
        raise ValueError(
            f"epsilon must be greater than -1 (got {epsilon}); "
            "1/(1+epsilon) is undefined or negative otherwise"
        )

    # Corner frequency ratio from stretching: fc2/fc1 = 1/(1+ε)
    fc_ratio_stretch = 1.0 / (1.0 + epsilon)

    # Stress drop ratio from stretching: Δσ2/Δσ1 = (fc2/fc1)³
    stress_ratio_stretch = fc_ratio_stretch ** 3

    print(f"\n{'='*60}")
    print("STRESS DROP ANALYSIS")
    print(f"{'='*60}")

    print("\nFrom Trace Stretching:")
    print(f"  Stretch factor (ε): {epsilon:.4f} ({epsilon*100:.2f}%)")
    print(f"  Corner frequency ratio (fc2/fc1): {fc_ratio_stretch:.3f}")
    print(f"  Stress drop ratio (Δσ2/Δσ1): {stress_ratio_stretch:.3f}")

    if epsilon > 0:
        print(f"  → Comparison event has {abs(epsilon)*100:.1f}% LONGER pulse")
        print(f"  → Comparison event has {(1-stress_ratio_stretch)*100:.1f}% LOWER stress drop")
    elif epsilon < 0:
        print(f"  → Comparison event has {abs(epsilon)*100:.1f}% SHORTER pulse")
        print(f"  → Comparison event has {(stress_ratio_stretch-1)*100:.1f}% HIGHER stress drop")
    else:
        print("  → Events have similar stress drops")

    results = {
        'stress_drop_ratio_stretch': stress_ratio_stretch,
        'fc_ratio_stretch': fc_ratio_stretch
    }

    # If a spectral corner frequency ratio is provided, calculate from it too
    if fc_ratio is not None:
        stress_ratio_spectral = fc_ratio ** 3
        abs_difference = abs(stress_ratio_stretch - stress_ratio_spectral)

        print("\nFrom Spectral Analysis:")
        print(f"  Corner frequency ratio (fc2/fc1): {fc_ratio:.3f}")
        print(f"  Stress drop ratio (Δσ2/Δσ1): {stress_ratio_spectral:.3f}")

        print("\nComparison:")
        print(f"  Stress drop ratio difference: {abs_difference:.3f}")
        # The relative difference is normalised by the spectral ratio, so it
        # is undefined when that ratio is zero; report that instead of crashing.
        if stress_ratio_spectral != 0:
            print(f"  Percent difference: {abs_difference/stress_ratio_spectral*100:.1f}%")
        else:
            print("  Percent difference: undefined (spectral stress drop ratio is zero)")

        results['stress_drop_ratio_spectral'] = stress_ratio_spectral
        results['fc_ratio_spectral'] = fc_ratio

    return results

In [None]:
# Stress drop comparison for the event pair analysed above.
# Requires results, fc1, fc2, first_event_id and second_event_id from the
# two-event comparison cell.

# Calculate stress drop ratios
stress_results = calculate_stress_drop_ratio(
    results['epsilon'],
    fc_ratio=fc2/fc1  # Optional: compare with spectral estimate
)

# Store everything together
# ('stress_drop_ratio_spectral' is only present because fc_ratio was passed above)
comparison_summary = {
    'event_id1': first_event_id,
    'event_id2': second_event_id,
    'epsilon': results['epsilon'],
    'cc': results['cc'],
    'fc1': fc1,
    'fc2': fc2,
    'fc_ratio_spectral': fc2/fc1,
    'fc_ratio_stretch': stress_results['fc_ratio_stretch'],
    'stress_ratio_spectral': stress_results['stress_drop_ratio_spectral'],
    'stress_ratio_stretch': stress_results['stress_drop_ratio_stretch']
}

print(f"\n{'='*60}")
print(f"SUMMARY")
print(f"{'='*60}")
# Floats are printed with 4 decimals; everything else (e.g. event IDs) as-is
for key, val in comparison_summary.items():
    if isinstance(val, float):
        print(f"  {key}: {val:.4f}")
    else:
        print(f"  {key}: {val}")

In [None]:
def get_event_peak_amplitude(event_stream, pick_times, window_length=0.03):
    """
    Calculate peak amplitude from raw (unnormalized) traces around P-wave arrival.

    For every trace whose station has a pick, the largest absolute sample in
    the window [pick, pick + window_length) is recorded. Traces are skipped
    when the station has no pick, when the pick falls outside the trace, or
    when the window contains no samples (window_length shorter than one
    sample interval).

    Parameters:
    -----------
    event_stream : obspy.Stream
        Raw windowed traces for an event (before stacking/normalization).
        Any iterable of objects with .data and .stats (station, starttime,
        sampling_rate, npts) works.
    pick_times : dict
        Pick time per station code; each value must support subtraction of
        the trace start time and yield seconds
    window_length : float
        Time window (s) after pick to search for peak amplitude

    Returns:
    --------
    peak_amplitude : float
        Maximum absolute amplitude across all stations, in the units of the
        trace data; NaN if no station could be used
    median_amplitude : float
        Median peak amplitude across stations; NaN if no station could be used
    amplitude_dict : dict
        Peak amplitude for each station that was used
    """

    amplitude_dict = {}

    for tr in event_stream:
        station = tr.stats.station

        if station not in pick_times:
            continue

        sampling_rate = tr.stats.sampling_rate
        npts = tr.stats.npts

        # Sample index of the pick relative to the trace start
        pick_sample = int((pick_times[station] - tr.stats.starttime) * sampling_rate)
        if pick_sample < 0 or pick_sample >= npts:
            continue

        # Window after the pick, clipped to the end of the trace
        window_samples = int(window_length * sampling_rate)
        end_sample = min(pick_sample + window_samples, npts)
        signal_window = tr.data[pick_sample:end_sample]

        # A window shorter than one sample is empty; np.max would raise on it,
        # so treat the station as unusable like the other skip cases.
        if len(signal_window) == 0:
            continue

        amplitude_dict[station] = np.max(np.abs(signal_window))

    if not amplitude_dict:
        return np.nan, np.nan, {}

    amplitudes = list(amplitude_dict.values())
    peak_amplitude = np.max(amplitudes)
    median_amplitude = np.median(amplitudes)

    return peak_amplitude, median_amplitude, amplitude_dict

In [None]:
def analyze_temporal_evolution_robust(stream, df_picks, reference_event_id, 
                                      comparison_event_ids,
                                      pulse_window_start=0.01, pulse_window_end=0.04,
                                      stretch_max=0.20, stretch_step=0.01,
                                      calculate_spectral_fc=True,
                                      min_stretch_cc=0.70,
                                      min_alignment_cc=0.70,
                                      min_snr=3.0):
    """
    Robust temporal evolution analysis with quality control filtering.

    The reference event is stacked once; every event (the reference itself
    first, then each comparison event) is then windowed, stacked, optionally
    given a spectral corner frequency, and stretched against the reference.
    Each event gets one row in the result table together with a QC verdict.
    Events that raise during processing are reported (with traceback) and
    skipped; they produce no row.

    Parameters:
    -----------
    stream : obspy.Stream
        Full data stream
    df_picks : pandas.DataFrame
        Picks dataframe; the 'eventid' and 'origin' columns are read here
    reference_event_id : str/int
        Event ID for reference
    comparison_event_ids : sequence
        Event IDs to analyze (list, tuple, array, ... - copied into a list)
    pulse_window_start : float
        Start of pulse window (s)
    pulse_window_end : float
        End of pulse window (s)
    stretch_max : float
        Maximum stretch factor
    stretch_step : float
        Stretch step size
    calculate_spectral_fc : bool
        If True, calculate corner frequency from spectrum for each event;
        otherwise 'fc_spectral' is NaN for every row
    min_stretch_cc : float
        Minimum correlation coefficient for stretching analysis (0-1)
        Events below this threshold are flagged (not applied to the reference)
    min_alignment_cc : float
        Minimum mean correlation coefficient for trace stacking (0-1)
    min_snr : float
        Minimum mean SNR for event stacking

    Returns:
    --------
    results_df : pandas.DataFrame
        One row per successfully processed event (QC failures included and
        marked via 'qc_pass'/'qc_reasons'). 'stress_drop_proxy'
        (peak_amplitude² × fc_spectral) is added when there is at least one row.
    qc_stats : dict
        'n_total' (events attempted), 'n_accepted', 'n_rejected'
        (n_total - n_accepted, which includes errored events), 'n_errors',
        'n_low_stretch_cc', 'n_low_alignment_cc', 'n_low_snr',
        'acceptance_rate'
    reference_info : dict
        Reference event information ('event_id', 'origin_time',
        'fc_spectral', 'peak_amplitude', 'median_amplitude', 'info')

    Raises:
    -------
    ValueError
        If the reference event has no rows in df_picks.
    """

    import traceback  # local import: only used to report per-event failures

    # Copy into a list: `[ref] + tuple` raises, and `[ref] + ndarray` would be
    # evaluated by NumPy as element-wise addition instead of concatenation.
    comparison_event_ids = list(comparison_event_ids)

    print(f"{'='*60}")
    print("ROBUST TEMPORAL EVOLUTION ANALYSIS")
    print(f"{'='*60}")
    print("\nQuality Control Thresholds:")
    print(f"  Min stretch CC: {min_stretch_cc:.2f}")
    print(f"  Min alignment CC: {min_alignment_cc:.2f}")
    print(f"  Min SNR: {min_snr:.1f}")

    # Get reference event info
    ref_rows = df_picks[df_picks['eventid'] == reference_event_id]
    if len(ref_rows) == 0:
        raise ValueError(
            f"Reference event {reference_event_id!r} has no picks in df_picks"
        )
    ref_origin = UTCDateTime(ref_rows['origin'].iloc[0])

    print(f"\nReference Event: {reference_event_id}")
    print(f"  Origin time: {ref_origin}")

    # Create reference waveform
    print("\nCreating reference waveform...")

    # Get RAW event stream (before stacking) for amplitude
    ref_event_stream, ref_pick_times = extract_event_traces(
        stream, df_picks, reference_event_id,
        pre_pick=0.02, post_pick=0.06
    )

    # Get amplitudes from raw traces
    ref_peak_amp, ref_median_amp, ref_amp_dict = get_event_peak_amplitude(
        ref_event_stream, ref_pick_times, window_length=0.03
    )

    # Create stacked reference for stretching
    ref_trace, ref_info = get_reference_waveform_for_event(
        stream, df_picks, reference_event_id,
        pre_pick=0.02, post_pick=0.06
    )

    # Calculate reference corner frequency from spectrum
    ref_fc_spectral, ref_fit = calculate_corner_frequency(
        ref_trace, method='brune', f_min=15, f_max=80
    )

    print(f"  Reference fc (spectral): {ref_fc_spectral:.1f} Hz")
    print(f"  Reference peak amplitude: {ref_peak_amp:.2e} m/s")
    print(f"  Reference median amplitude: {ref_median_amp:.2e} m/s")

    # Storage
    results_list = []

    # QC counters
    n_total = 0
    n_low_stretch_cc = 0
    n_low_alignment_cc = 0
    n_low_snr = 0
    n_accepted = 0
    n_errors = 0

    # The reference is processed first so that it also gets a result row
    all_event_ids = [reference_event_id] + comparison_event_ids

    for i, event_id in enumerate(all_event_ids):

        # Get origin time
        event_rows = df_picks[df_picks['eventid'] == event_id]
        if len(event_rows) == 0:
            print(f"\nEvent {event_id}: No picks found, skipping")
            continue

        event_origin = UTCDateTime(event_rows['origin'].iloc[0])
        time_diff = event_origin - ref_origin

        is_reference = (event_id == reference_event_id)
        n_total += 1

        print(f"\n{'-'*60}")
        if is_reference:
            print(f"Reference Event: {event_id}")
        else:
            print(f"Event {i}/{len(comparison_event_ids)}: {event_id}")
            print(f"  Time since reference: {time_diff/60:.1f} minutes")

        try:
            # Get RAW event stream for amplitude
            event_stream, event_pick_times = extract_event_traces(
                stream, df_picks, event_id,
                pre_pick=0.02, post_pick=0.06
            )

            # Get amplitudes from raw traces
            peak_amp, median_amp, amp_dict = get_event_peak_amplitude(
                event_stream, event_pick_times, window_length=0.03
            )

            # Create stacked waveform for spectral analysis and stretching
            event_trace, event_info = get_reference_waveform_for_event(
                stream, df_picks, event_id,
                pre_pick=0.02, post_pick=0.06
            )

            # Calculate spectral corner frequency
            if calculate_spectral_fc:
                fc_spectral, _ = calculate_corner_frequency(
                    event_trace, method='brune', f_min=15, f_max=80
                )
            else:
                fc_spectral = np.nan

            # Stretching analysis (skip for reference event)
            if not is_reference:
                stretch_results = compare_waveforms_stretching(
                    ref_trace, event_trace,
                    reference_event_id, event_id,
                    stretch_max=stretch_max,
                    stretch_step=stretch_step,
                    pulse_window_start=pulse_window_start,
                    pulse_window_end=pulse_window_end
                )
                epsilon = stretch_results['epsilon']
                stretch_cc = stretch_results['cc']
                # Stretching by (1 + ε) scales the corner frequency by 1/(1 + ε)
                fc_stretch = ref_fc_spectral / (1.0 + epsilon)
            else:
                # The reference compared with itself: no stretch, perfect match
                epsilon = 0.0
                stretch_cc = 1.0
                fc_stretch = ref_fc_spectral

            # === QUALITY CONTROL CHECKS ===
            qc_pass = True
            qc_reasons = []

            # Check 1: Stretch correlation coefficient
            if stretch_cc < min_stretch_cc and not is_reference:
                qc_pass = False
                qc_reasons.append(f"Low stretch CC ({stretch_cc:.3f} < {min_stretch_cc})")
                n_low_stretch_cc += 1

            # Check 2: Stacking alignment quality
            if event_info['mean_cc'] < min_alignment_cc:
                qc_pass = False
                qc_reasons.append(f"Low alignment CC ({event_info['mean_cc']:.3f} < {min_alignment_cc})")
                n_low_alignment_cc += 1

            # Check 3: Mean SNR
            if event_info['mean_snr'] < min_snr:
                qc_pass = False
                qc_reasons.append(f"Low SNR ({event_info['mean_snr']:.1f} < {min_snr})")
                n_low_snr += 1

            # Print QC status
            if not qc_pass:
                print(f"  ⚠ QC FAILED: {', '.join(qc_reasons)}")
            else:
                print("  ✓ QC PASSED")
                n_accepted += 1

            # Store results (include all, mark QC status)
            results_list.append({
                'event_id': event_id,
                'origin_time': event_origin,
                'time_since_ref_sec': time_diff,
                'time_since_ref_min': time_diff / 60.0,
                'is_reference': is_reference,
                'qc_pass': qc_pass,
                'qc_reasons': '; '.join(qc_reasons) if not qc_pass else '',
                'fc_spectral': fc_spectral,
                'fc_stretch': fc_stretch,
                'epsilon': epsilon,
                'stretch_cc': stretch_cc,
                'peak_amplitude': peak_amp,
                'median_amplitude': median_amp,
                'mean_snr': event_info['mean_snr'],
                'mean_alignment_cc': event_info['mean_cc'],
                'n_stations': len(event_info['snr_dict'])
            })

            print(f"  fc (spectral): {fc_spectral:.1f} Hz")
            print(f"  Peak amplitude: {peak_amp:.2e} m/s")
            print(f"  ε: {epsilon:.4f}, stretch CC: {stretch_cc:.3f}")

        except Exception as e:
            # Best effort: one bad event must not abort the whole run, but it
            # is counted so the summary can tell errors from QC failures.
            n_errors += 1
            print(f"  ERROR: {e}")
            traceback.print_exc()
            continue

    # Convert to DataFrame
    results_df = pd.DataFrame(results_list)

    # Calculate stress drop proxy
    if len(results_df) > 0:
        results_df['stress_drop_proxy'] = (results_df['peak_amplitude']**2) * results_df['fc_spectral']

    # QC Statistics ('n_rejected' keeps its original meaning: everything that
    # was attempted but not accepted, i.e. QC failures plus errored events)
    qc_stats = {
        'n_total': n_total,
        'n_accepted': n_accepted,
        'n_rejected': n_total - n_accepted,
        'n_errors': n_errors,
        'n_low_stretch_cc': n_low_stretch_cc,
        'n_low_alignment_cc': n_low_alignment_cc,
        'n_low_snr': n_low_snr,
        'acceptance_rate': n_accepted / n_total if n_total > 0 else 0
    }

    reference_info = {
        'event_id': reference_event_id,
        'origin_time': ref_origin,
        'fc_spectral': ref_fc_spectral,
        'peak_amplitude': ref_peak_amp,
        'median_amplitude': ref_median_amp,
        'info': ref_info
    }

    print(f"\n{'='*60}")
    print("QUALITY CONTROL SUMMARY")
    print(f"{'='*60}")
    print(f"Total events analyzed: {qc_stats['n_total']}")
    print(f"Passed QC: {qc_stats['n_accepted']} ({qc_stats['acceptance_rate']*100:.1f}%)")
    print(f"Failed QC: {qc_stats['n_rejected']}")
    print(f"  Low stretch CC: {qc_stats['n_low_stretch_cc']}")
    print(f"  Low alignment CC: {qc_stats['n_low_alignment_cc']}")
    print(f"  Low SNR: {qc_stats['n_low_snr']}")
    print(f"  Processing errors: {qc_stats['n_errors']}")
    print(f"{'='*60}")

    return results_df, qc_stats, reference_info

In [None]:
def get_unique_events(df_picks, sort_by_time=True):
    """
    List the distinct event IDs found in a picks table.

    Parameters:
    -----------
    df_picks : pandas.DataFrame
        Picks dataframe with 'eventid' and 'origin' columns
    sort_by_time : bool
        If True, order the events by the first 'origin' value of each event
        (ascending). If False, keep the order of first appearance in df_picks.

    Returns:
    --------
    event_ids : list
        Unique event IDs, ordered as described above. The number of events is
        also printed.

    NOTE(review): the ordering is whatever sort_values gives for the stored
    'origin' values. If those are strings, that is lexicographic order, which
    matches chronological order only for a uniform ISO-8601-like format -
    confirm against the data.
    """

    if not sort_by_time:
        # Order of first appearance in the table
        event_ids = df_picks['eventid'].unique().tolist()
    else:
        # One origin per event, then order the events by that origin
        origin_per_event = df_picks.groupby('eventid')['origin'].first()
        ordered = origin_per_event.sort_values()
        event_ids = ordered.index.tolist()

    print(f"Total unique events: {len(event_ids)}")

    return event_ids

In [None]:
def plot_temporal_evolution_robust(results_df, reference_info, show_rejected=False):
    """
    Four stacked scatter panels of absolute event properties versus time.

    Panels, top to bottom: spectral corner frequency, peak amplitude (log),
    stress drop proxy amp² × fc (log), and stretch factor ε in percent. The
    reference event itself is not plotted as a point; its values appear as
    dashed horizontal lines in the first three panels.

    Parameters:
    -----------
    results_df : pandas.DataFrame
        Results with QC information. Reads 'is_reference', 'qc_pass',
        'time_since_ref_min', 'fc_spectral', 'peak_amplitude',
        'stress_drop_proxy' and 'epsilon'.
    reference_info : dict
        Reference event info; reads 'fc_spectral', 'peak_amplitude' and
        'event_id'
    show_rejected : bool
        If True, show rejected events as faded points
    """

    fig, axes = plt.subplots(4, 1, figsize=(14, 12))

    # Row selections: comparison events only, split by QC verdict
    comparison_rows = results_df['is_reference'] == False
    accepted = comparison_rows & (results_df['qc_pass'] == True)
    rejected = comparison_rows & (results_df['qc_pass'] == False)

    time_pass = results_df.loc[accepted, 'time_since_ref_min'].values
    time_fail = results_df.loc[rejected, 'time_since_ref_min'].values

    ref_fc = reference_info['fc_spectral']
    ref_amp = reference_info['peak_amplitude']

    def _scatter_events(panel, column, colour, factor=None):
        # Accepted events in colour, rejected ones (optionally) in faded gray
        accepted_values = results_df.loc[accepted, column].values
        if factor is not None:
            accepted_values = accepted_values * factor
        panel.scatter(time_pass, accepted_values,
                      c=colour, s=50, alpha=0.7, edgecolors='k', linewidth=0.5, label='QC Pass')
        if show_rejected and len(time_fail) > 0:
            rejected_values = results_df.loc[rejected, column].values
            if factor is not None:
                rejected_values = rejected_values * factor
            panel.scatter(time_fail, rejected_values,
                          c='lightgray', s=30, alpha=0.3, edgecolors='k', linewidth=0.5, label='QC Fail')

    # Panel 1: Corner Frequency (spectral)
    fc_panel = axes[0]
    _scatter_events(fc_panel, 'fc_spectral', 'blue')
    fc_panel.axhline(ref_fc, color='red', linestyle='--', linewidth=2,
                     label=f'Reference fc = {ref_fc:.1f} Hz')
    fc_panel.set_ylabel('Corner Frequency (Hz)')
    fc_panel.set_title(f'Temporal Evolution (Reference: Event {reference_info["event_id"]})')
    fc_panel.legend()
    fc_panel.grid(True, alpha=0.3)

    # Panel 2: Peak Amplitude
    amp_panel = axes[1]
    _scatter_events(amp_panel, 'peak_amplitude', 'green')
    amp_panel.axhline(ref_amp, color='red', linestyle='--', linewidth=2,
                      label=f'Reference amp = {ref_amp:.2e} m/s')
    amp_panel.set_ylabel('Peak Amplitude (m/s)')
    amp_panel.set_yscale('log')
    amp_panel.legend()
    amp_panel.grid(True, alpha=0.3, which='both')

    # Panel 3: Stress Drop Proxy
    stress_panel = axes[2]
    _scatter_events(stress_panel, 'stress_drop_proxy', 'red')
    ref_stress = (ref_amp**2) * ref_fc
    stress_panel.axhline(ref_stress, color='red', linestyle='--', linewidth=2,
                         label='Reference')
    stress_panel.set_ylabel('Stress Drop Proxy\n(amp² × fc)')
    stress_panel.set_yscale('log')
    stress_panel.legend()
    stress_panel.grid(True, alpha=0.3, which='both')

    # Panel 4: Stretch factor (for QC), shown in percent
    eps_panel = axes[3]
    _scatter_events(eps_panel, 'epsilon', 'purple', factor=100)
    eps_panel.axhline(0, color='k', linestyle='--', alpha=0.3)
    eps_panel.set_ylabel('Stretch Factor ε (%)')
    eps_panel.set_xlabel('Time Since Reference Event (minutes)')
    eps_panel.legend()
    eps_panel.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
# Unique event IDs ordered by each event's first 'origin' value; this list is
# indexed positionally when choosing the reference and comparison events.
all_event_ids = get_unique_events(df_picks, sort_by_time=True)



In [None]:
# Select the reference event by its position in the time-ordered event list
# (position 5 here; change ref_ind to pick a different reference).
ref_ind = 5
reference_event_id = all_event_ids[ref_ind]

# Select comparison events: every 3rd event after the reference, through the
# end of the event list
comparison_event_ids = all_event_ids[ref_ind+1::3]  # positions ref_ind+1, ref_ind+4, ref_ind+7, ...

print(f"Reference event: {reference_event_id}")
print(f"Comparison events: {len(comparison_event_ids)}")

# Run with QC filters
# (stretch_step is left at the function default; ref_info receives the
# reference_info dict, the third return value)
results_df, qc_stats, ref_info = analyze_temporal_evolution_robust(
    comb_stZ, df_picks,
    reference_event_id=reference_event_id,
    comparison_event_ids=comparison_event_ids,
    pulse_window_start=0.01,
    pulse_window_end=0.04,
    stretch_max=0.50,
    calculate_spectral_fc=True,
    min_stretch_cc=0.5,      # Adjust these thresholds as needed
    min_alignment_cc=0.50,
    min_snr=3.0
)

# Plot (show rejected events as faded)
plot_temporal_evolution_robust(results_df, ref_info, show_rejected=True)

# Filter to only QC-passed events
# (the reference event's own row is included if it passed QC)
results_qc_pass = results_df[results_df['qc_pass'] == True]

In [None]:
# Diagnostic analysis - run AFTER analyze_temporal_evolution_robust completes
# Uses: results_df from that analysis
# Note: results_df also contains the reference event's own row (stretch_cc is
# set to 1.0 for it), so that row is part of every statistic below.

print("QC Failure Breakdown:")
print(f"Total events: {len(results_df)}")
print(f"\nQC Status:")
print(results_df['qc_pass'].value_counts())

print("\n" + "="*60)
print("QC Metrics Summary:")
print("="*60)

# Thresholds must equal the ones passed to analyze_temporal_evolution_robust,
# otherwise the "below threshold" counts and the histogram markers disagree
# with the qc_pass column. These match the run above (0.5 / 0.50 / 3.0).
min_stretch_cc = 0.5
min_alignment_cc = 0.50
min_snr = 3.0

# Check each metric
print("\nStretch CC:")
print(f"  Min: {results_df['stretch_cc'].min():.3f}")
print(f"  Max: {results_df['stretch_cc'].max():.3f}")
print(f"  Mean: {results_df['stretch_cc'].mean():.3f}")
print(f"  Median: {results_df['stretch_cc'].median():.3f}")
print(f"  Below threshold ({min_stretch_cc}): {(results_df['stretch_cc'] < min_stretch_cc).sum()}")

print("\nAlignment CC:")
print(f"  Min: {results_df['mean_alignment_cc'].min():.3f}")
print(f"  Max: {results_df['mean_alignment_cc'].max():.3f}")
print(f"  Mean: {results_df['mean_alignment_cc'].mean():.3f}")
print(f"  Median: {results_df['mean_alignment_cc'].median():.3f}")
print(f"  Below threshold ({min_alignment_cc}): {(results_df['mean_alignment_cc'] < min_alignment_cc).sum()}")

print("\nMean SNR:")
print(f"  Min: {results_df['mean_snr'].min():.1f}")
print(f"  Max: {results_df['mean_snr'].max():.1f}")
print(f"  Mean: {results_df['mean_snr'].mean():.1f}")
print(f"  Median: {results_df['mean_snr'].median():.1f}")
print(f"  Below threshold ({min_snr}): {(results_df['mean_snr'] < min_snr).sum()}")

# Show histogram of QC metrics, each with its threshold marked
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

ax = axes[0]
ax.hist(results_df['stretch_cc'].dropna(), bins=30, edgecolor='k')
ax.axvline(min_stretch_cc, color='r', linestyle='--', linewidth=2, label=f'Threshold = {min_stretch_cc}')
ax.set_xlabel('Stretch CC')
ax.set_ylabel('Count')
ax.set_title('Stretch Correlation Coefficient')
ax.legend()
ax.grid(True, alpha=0.3)

ax = axes[1]
ax.hist(results_df['mean_alignment_cc'].dropna(), bins=30, edgecolor='k')
ax.axvline(min_alignment_cc, color='r', linestyle='--', linewidth=2, label=f'Threshold = {min_alignment_cc}')
ax.set_xlabel('Alignment CC')
ax.set_ylabel('Count')
ax.set_title('Mean Alignment CC')
ax.legend()
ax.grid(True, alpha=0.3)

ax = axes[2]
ax.hist(results_df['mean_snr'].dropna(), bins=30, edgecolor='k')
ax.axvline(min_snr, color='r', linestyle='--', linewidth=2, label=f'Threshold = {min_snr}')
ax.set_xlabel('Mean SNR')
ax.set_ylabel('Count')
ax.set_title('Mean SNR')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Look at the QC failure reasons (combined reason strings, most frequent first)
print("\n" + "="*60)
print("Most Common QC Failure Reasons:")
print("="*60)
failed = results_df[results_df['qc_pass'] == False]
if len(failed) > 0:
    print(failed['qc_reasons'].value_counts().head(10))
else:
    print("No failures!")