In [None]:
import numpy as np
import obspy
from obspy import UTCDateTime
from obspy.clients.fdsn import Client
from obspy import UTCDateTime, read
from scipy.signal import correlate, resample, butter, filtfilt
import csv
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import os
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
import h5py
from tqdm import tqdm 
import time
from obspy import Stream, UTCDateTime
import pandas as pd
import re
from datetime import datetime, timedelta
from scipy.signal import butter, filtfilt, tukey
from scipy.signal import fftconvolve
from matplotlib import cm
import re
from glob import glob
from tqdm import tqdm
import traceback

# Loading quakes


In [None]:
# --- Station and time window used for the catalog query -------------------
client = Client("IRIS")
network = 'CC'
station = 'PARA'
channel = 'BHZ'
location = '*'
event_time = UTCDateTime(2023, 8, 23, 10, 0, 0)
event_endtime = UTCDateTime(2023, 11, 23, 10, 0, 0)

# --- Search region: centre (degrees) and maximum radius --------------------
latitude = 46.87
longitude = -121.760
max_radius_km = 50  # was 250 km in an earlier run

# get_events() takes the radius in degrees; 111.19 km per degree is the
# conversion factor used here.
max_radius_deg = max_radius_km / 111.19

# Fetch every catalogued event between event_time and event_endtime that
# lies within the search radius.
catalog = client.get_events(
    starttime=event_time,
    endtime=event_endtime,
    minmagnitude=0,  # earlier runs used 1.8 and 2.5
    latitude=latitude,
    longitude=longitude,
    maxradius=max_radius_deg
)

# --- Flatten the catalog into one row per event -----------------------------
data = []
for event in catalog:
    # Events without any origin carry no time/location: nothing to tabulate.
    if not event.origins:
        print(f"Skipping event {event.resource_id} with no origin")
        continue
    origin = event.origins[0]
    # Magnitude is optional; keep the event and leave the fields empty.
    magnitude = event.magnitudes[0] if event.magnitudes else None

    eventid_full = event.resource_id.id if event.resource_id else None
    # Keep only what follows 'eventid=' in the resource identifier.
    event_id = eventid_full.split('eventid=')[-1] if eventid_full else None

    data.append({
        'getstring': origin.time,
        'Latitude': origin.latitude,
        'Longitude': origin.longitude,
        # metres -> kilometres; depth may be missing for some origins
        'Depth (km)': origin.depth / 1000 if origin.depth is not None else None,
        'Magnitude': magnitude.mag if magnitude is not None else None,
        'Magnitude Type': magnitude.magnitude_type if magnitude is not None else None,
        'evid': event_id
    })

# Build the DataFrame and display it as the cell output
df = pd.DataFrame(data)
df

In [None]:
# Show the dtype currently stored in the 'getstring' column
print(df['getstring'].dtype) #object

# Dry run: confirm that every origin time can be turned into a UTCDateTime.
# Nothing is written back to the DataFrame; problems are only reported.
for idx, row in df.iterrows():
    try:
        event_time = row['getstring']

        if not pd.isna(event_time):
            # Raises if the stored value is not a valid time
            event_time = UTCDateTime(event_time)
        else:
            print(f"Skipping event {row['evid']} due to NaT in getstring")

    except Exception as e:
        print(f"Error processing event {row['evid']}: {e}")

# Loading DAS data...

Functions for cross-correlation and plotting

In [None]:
def load_das_data(file_paths, start_time, duration_seconds, chan_min, chan_max,
                  file_duration=60.0):
    """
    Load and concatenate DAS samples covering a requested time window.

    Files are matched by name (``decimator_YYYY-MM-DD_HH.MM.SS_UTC.h5``). The
    timestamp in the name is used as the time of the file's first sample and
    every file is assumed to span ``file_duration`` seconds. The HDF5 dataset
    to read is discovered automatically: it is the first dataset met while
    walking the first file (in sorted path order) that contains one.

    Parameters
    ----------
    file_paths : iterable of str
        Candidate HDF5 file paths.
    start_time : UTCDateTime
        Start of the requested window.
    duration_seconds : float
        Length of the requested window in seconds.
    chan_min, chan_max : int
        Inclusive channel (column) range to read.
    file_duration : float, optional
        Time span of one file in seconds (default 60.0). Used both to select
        overlapping files and to derive the sampling rate.

    Returns
    -------
    data_concat : np.ndarray
        Array of shape (n_samples, n_channels). It has zero rows when the
        overlapping files contain no samples inside the window.
    fs : float
        Sampling rate, derived as rows-per-file / ``file_duration`` from the
        first file that was read.
    used_files : list of str
        Paths that actually contributed samples (empty if none did).

    Raises
    ------
    ValueError
        If no file contains a dataset, or no file overlaps the window.
    """
    # 1) Discover the internal dataset name
    dataset_name = None

    class _Found(Exception):
        """Raised by the visitor to stop the HDF5 walk at the first dataset."""

    for path in sorted(file_paths):
        with h5py.File(path, 'r') as f:
            def _visitor(name, obj):
                nonlocal dataset_name
                if isinstance(obj, h5py.Dataset):
                    dataset_name = name
                    raise _Found()
            try:
                f.visititems(_visitor)
            except _Found:
                pass
        if dataset_name:
            break

    if not dataset_name:
        raise ValueError("No dataset found in any DAS HDF5 file.")

    # 2) Identify candidate files overlapping our window
    end_time = start_time + duration_seconds
    pat = re.compile(r'decimator_(\d{4}-\d{2}-\d{2})_(\d{2}\.\d{2}\.\d{2})_UTC\.h5$')
    candidates = []
    for path in file_paths:
        m = pat.search(path)
        if not m:
            continue
        date_str, time_str = m.groups()
        ts = UTCDateTime(f"{date_str}T{time_str.replace('.', ':')}")
        # Keep the file if [ts, ts + file_duration) intersects the window.
        if ts < end_time and (ts + file_duration) > start_time:
            candidates.append((ts, path))
    candidates.sort(key=lambda x: x[0])
    if not candidates:
        raise ValueError("No DAS files overlap the requested time window.")

    # 3) Load, offset, and concatenate until we have enough samples
    # NOTE(review): chunks are concatenated back to back. A missing file in
    # the middle of the window, or a first file that starts after start_time,
    # is not padded, so later samples would be shifted earlier in time.
    chunks       = []
    used_files   = []
    fs           = None
    total_needed = None
    collected    = 0

    for file_start, path in candidates:
        with h5py.File(path, 'r') as f:
            arr = f[dataset_name][:, chan_min:chan_max+1]

        if fs is None:
            fs = arr.shape[0] / file_duration
            total_needed = int(round(duration_seconds * fs))

        # Samples to skip at the head of the file; 0 for files that start
        # inside the window.
        offset = max(0, int(round((start_time - file_start) * fs)))
        chunk  = arr[offset:]
        take   = min(chunk.shape[0], total_needed - collected)
        if take > 0:
            chunks.append(chunk[:take])
            used_files.append(path)
            collected += take
        if collected >= total_needed:
            break

    if collected < total_needed:
        print(f"Warning: only collected {collected}/{total_needed} samples.")

    if not chunks:
        # np.vstack([]) would raise; hand back an empty array so callers can
        # test `used_files` / the row count instead of catching an exception.
        return np.empty((0, arr.shape[1]), dtype=arr.dtype), fs, used_files

    data_concat = np.vstack(chunks)[:total_needed]
    return data_concat, fs, used_files


#  Events cross correlation

In [None]:
def get_seismic_trace(network: str,
                      station: str,
                      location: str,
                      channel: str,
                      start_time: "UTCDateTime",
                      duration_seconds: float,
                      freqmin: float = 2.0,
                      freqmax: float = 10.0,
                      client=None):
    """
    Download and pre-process a single seismometer trace.

    Steps: fetch the waveform with its response attached, remove the
    instrument response (output is velocity), detrend and demean, apply a
    zero-phase band-pass between ``freqmin`` and ``freqmax``, then taper the
    edges with a Tukey window (alpha=0.1).

    Parameters
    ----------
    network, station, location, channel : str
        SEED identifiers passed to ``get_waveforms`` (wildcards allowed).
    start_time : UTCDateTime
        Start of the requested window.
    duration_seconds : float
        Window length in seconds.
    freqmin, freqmax : float
        Band-pass corner frequencies in Hz.
    client : optional
        An existing FDSN client to reuse. When omitted, a new
        ``Client("IRIS")`` is created for this call; pass one in when calling
        in a loop to avoid rebuilding the client for every event.

    Returns
    -------
    The first trace of the processed stream.

    Raises
    ------
    ValueError
        If the data centre returns an empty stream.
    """
    if client is None:
        client = Client("IRIS")
    end_time = start_time + duration_seconds

    # 1. Fetch raw data together with the response information
    st = client.get_waveforms(network, station, location, channel,
                              start_time, end_time, attach_response=True)
    if len(st) == 0:
        raise ValueError(
            f"No waveform data for {network}.{station}.{location}.{channel}"
        )
    if len(st) > 1:
        # A wildcard location or gappy data can yield several traces; only the
        # first is returned, so make that visible instead of silent.
        print(f"Warning: {len(st)} traces returned for "
              f"{network}.{station}.{location}.{channel}; using the first one.")

    # 2. Remove instrument response to get ground velocity.
    #    `taper` is an on/off switch, so pass a boolean rather than a name.
    st.remove_response(output="VEL", zero_mean=True, taper=True)

    # 3. Detrend & demean
    st.detrend('linear')
    st.detrend('demean')

    # 4. Bandpass between freqmin and freqmax
    st.filter('bandpass', freqmin=freqmin, freqmax=freqmax, corners=4, zerophase=True)

    # 5. Taper edges with a Tukey window (alpha=0.1)
    for tr in st:
        tr.data *= tukey(tr.stats.npts, alpha=0.1)

    return st[0]

In [None]:
def calculate_lag_times(
    seismic_trace,
    das_data: np.ndarray,
    fs: float,
    corr_threshold: float = 0.8,
    lowcut: float = 2.0,
    highcut: float = 10.0
):
    """
    Cross-correlate each DAS channel with the seismic trace.

    The DAS data are band-pass filtered (4th-order Butterworth, zero-phase),
    demeaned, linearly detrended and Tukey-tapered. Both signals are scaled to
    unit L2 norm, so the correlation values lie in [-1, 1]. For each channel
    the lag of the largest absolute correlation is kept if that peak reaches
    ``corr_threshold``.

    Sign convention: a positive lag means the feature arrives later in the
    seismic trace than in the DAS channel (the DAS channel must be shifted
    later by that amount to line up with the seismic trace).

    Parameters
    ----------
    seismic_trace : object with a ``.data`` array
        Reference trace; sample 0 is taken to be simultaneous with row 0 of
        ``das_data``.
    das_data : np.ndarray, shape (n_samples, n_channels)
        DAS samples at the same sampling rate as the seismic trace.
    fs : float
        Sampling rate in Hz shared by both inputs.
    corr_threshold : float
        Minimum absolute normalized correlation for a channel to count.
    lowcut, highcut : float
        Band-pass corner frequencies in Hz applied to the DAS data.

    Returns
    -------
    median_lag : float | None   # median of lag times for channels passing the threshold
    lag_std    : float | None   # standard deviation of those lags
    n_channels : int            # number of channels that exceeded the threshold
    peak_corr  : float          # global correlation peak (max abs across all channels)
    """
    # 1) band-pass filter das data
    b, a = butter(4, [lowcut, highcut], btype='bandpass', fs=fs)
    das_filtered = filtfilt(b, a, das_data, axis=0)

    # 2) taper window
    n_samples, n_ch = das_filtered.shape
    window = tukey(n_samples, alpha=0.1)

    # 3) demean, detrend & taper — only channels that are finite throughout.
    #    A channel containing NaN/Inf cannot yield a meaningful correlation
    #    and would otherwise slip past the norm check below.
    usable = np.flatnonzero(np.isfinite(das_filtered).all(axis=0))
    if usable.size:
        sample_idx = np.arange(n_samples)
        block = das_filtered[:, usable]
        block = block - block.mean(axis=0)
        # One least-squares line per column, fitted in a single call.
        slope, intercept = np.polyfit(sample_idx, block, 1)
        trend = np.outer(sample_idx, slope) + intercept
        das_filtered[:, usable] = (block - trend) * window[:, None]

    # 4) normalize seismic trace
    seismic = seismic_trace.data.astype(float)
    seismic -= np.mean(seismic)
    seismic /= max(np.linalg.norm(seismic), 1e-12)

    # 5) cross-correlate
    # For fftconvolve(seismic, chan[::-1], 'full') the zero-lag sample sits at
    # index len(chan) - 1, i.e. it is set by the DAS length.
    zero_idx = n_samples - 1
    lags = []
    peak_corr = 0.0  # value to report

    for ch_idx in usable:
        chan = das_filtered[:, ch_idx]
        norm = np.linalg.norm(chan)
        if norm < 1e-12:
            continue  # dead channel
        chan = chan / norm

        corr = np.abs(fftconvolve(seismic, chan[::-1], mode='full'))
        peak_idx = int(np.argmax(corr))
        pk = float(corr[peak_idx])
        peak_corr = max(peak_corr, pk)      # update global peak

        if pk < corr_threshold:
            continue

        lags.append((peak_idx - zero_idx) / fs)

    # 6) statistics
    if not lags:
        return None, None, 0, peak_corr

    median_lag = float(np.median(lags))
    lag_std    = float(np.std(lags))
    return median_lag, lag_std, len(lags), peak_corr


In [None]:
def plot_seismic_and_das_with_offset(
    time_axis: np.ndarray,
    seismic_trace,
    filtered_das_data: np.ndarray,
    event_id: str,
    origin_time,
    magnitude: float,
    best_corr_lag: float,
    seismic_scale: float = 3.0,
    das_scaling_factor: float = 5.0,
    channel_offset_step: float = 1.5,
    colormap: str = 'viridis'
):
    """
    Plot the seismic trace together with vertically offset DAS channels.

    The seismic trace is scaled so its peak equals ``seismic_scale``; the DAS
    array is scaled by its global peak to ``das_scaling_factor`` and channel
    ``i`` is drawn at a vertical offset of ``i * channel_offset_step``. A red
    dashed vertical line marks ``best_corr_lag`` on the time axis. The figure
    is shown with ``plt.show()``; nothing is returned.

    Parameters
    ----------
    time_axis : np.ndarray
        X values shared by the seismic trace and every DAS channel.
    seismic_trace : object with a ``.data`` array
    filtered_das_data : np.ndarray, shape (n_samples, n_channels)
    event_id : str
        Shown in the title.
    origin_time : object with ``strftime``
        Shown in the title.
    magnitude : float
        Shown in the title with two decimals.
    best_corr_lag : float
        Lag in seconds, drawn as a vertical line and shown in the title.
    seismic_scale, das_scaling_factor, channel_offset_step : float
        Display scaling and spacing.
    colormap : str
        Matplotlib colormap name used to colour the channels.
    """
    # 1. Normalize seismic trace by its maximum absolute value
    data = seismic_trace.data.astype(float)
    max_seis = np.max(np.abs(data))
    if max_seis > 0:
        seis_norm = (data / max_seis) * seismic_scale
    else:
        seis_norm = data

    # 2. Normalize DAS data by global max and apply scaling
    #    (np.max raises on an empty array, so treat "no data" as all-zero)
    global_max = np.max(np.abs(filtered_das_data)) if filtered_das_data.size else 0.0
    if global_max > 0:
        das_norm = (filtered_das_data / global_max) * das_scaling_factor
    else:
        das_norm = filtered_das_data.copy()

    n_samples, n_ch = das_norm.shape
    offsets = np.arange(n_ch) * channel_offset_step

    # 3. Prepare color map
    #    plt.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib
    #    releases no longer provide.
    cmap = plt.get_cmap(colormap, max(n_ch, 1))
    colors = cmap(np.linspace(0, 1, n_ch))

    # 4. Create figure
    fig, ax = plt.subplots(figsize=(12, 6))

    # Plot seismic trace
    ax.plot(time_axis, seis_norm, 'k', linewidth=1.5, label='Seismic Trace')

    # Plot each DAS channel with its offset and color
    for i in range(n_ch):
        ax.plot(time_axis, das_norm[:, i] + offsets[i],
                color=colors[i], linewidth=0.8)
    # Only include one legend entry for DAS
    ax.plot([], [], color='gray', linewidth=0.8, label='DAS Channels')

    # 5. Annotate best-lag with vertical line; the label sits just above the
    #    top channel (or above zero when there are no channels).
    top_offset = offsets[-1] if n_ch else 0.0
    ax.axvline(best_corr_lag, color='red', linestyle='--', linewidth=1)
    ax.text(best_corr_lag + 0.02,  # small horizontal shift
            top_offset + das_scaling_factor * 0.2,
            f'Lag = {best_corr_lag:.3f} s',
            color='red', fontsize=9, va='bottom')

    # 6. Label roughly ten evenly spaced channel offsets on the y-axis
    tick_idxs = np.arange(0, n_ch, max(1, n_ch // 10))
    tick_locs = offsets[tick_idxs]
    tick_labels = [f'Ch {i}' for i in tick_idxs]
    ax.set_yticks(list(tick_locs) + [0])  # include zero for seismic
    ax.set_yticklabels(tick_labels + ['Seis'])

    # 7. Labels, title, legend and grid
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Normalized Amplitude + Offset")
    ax.set_title(
        f"Event {event_id} | {origin_time.strftime('%Y-%m-%d %H:%M:%S UTC')}\n"
        f"Magnitude {magnitude:.2f} | Best Corr Lag: {best_corr_lag:.3f} s"
    )
    ax.legend(loc='upper right', fontsize=9)
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
def time_shift_array(data: np.ndarray, shift_samples: int) -> np.ndarray:
    """
    Return a copy of ``data`` moved along axis 0 by ``shift_samples``.

    Positive values move the content towards higher indices, negative values
    towards lower indices. Positions vacated by the shift are zero-filled and
    samples pushed past either end are dropped. The input is left untouched.
    """
    out = np.zeros_like(data)

    if shift_samples == 0:
        out[:] = data
        return out

    if shift_samples > 0:
        # delay: head of the input lands further down
        out[shift_samples:] = data[:-shift_samples]
    else:
        # advance: tail of the input moves to the front
        out[:shift_samples] = data[-shift_samples:]
    return out


In [None]:
# ----------------------------------------
# Run configuration (edit here)
# ----------------------------------------

# Correlation threshold handed to calculate_lag_times as corr_threshold
threshold = 0.34

# Root folder scanned recursively for DAS HDF5 files
das_root = '/1-fnp/petasaur/p-jbod1/rainier/2023'
file_paths = glob(os.path.join(das_root, '**', '*.h5'), recursive=True)

# Reference seismometer (alternative used before: UW.LON, channel HHZ)
network = "CC"
station = "PARA"
location = "*"
channel = "BHZ"

# Inclusive DAS channel range (alternative used before: 2000-2030)
chan_min = 60
chan_max = 90

# Length in seconds of the window requested for both seismic and DAS data
duration_seconds = 20

# Results file; the name records the station and the threshold
output_csv = f'./results_lag_times_{station}_{threshold}.csv'


In [None]:
# Folder where the per-event before/after PNG plots are written
output_plot_dir = "/data/data4/veronica-scratch-rainier/shift-time/"
os.makedirs(output_plot_dir, exist_ok=True)

df = df.dropna(subset=['getstring'])
results = []


def _lag_record(mag, ev_id, event_time, median_lag=None, lag_std=None,
                n_channels=None, peak=None, das_files=None):
    """Build one results row; every outcome gets the same set of keys."""
    return {
        'Magnitude': mag,
        'event_id': ev_id,
        'event_time': event_time,
        'median_lag': median_lag,
        'lag_std': lag_std,
        'n_channels': n_channels,
        'peak': peak,
        'das_files': das_files,
    }


def _normalize_for_plot(arr):
    """Scale an array so its largest absolute value is 2 (unchanged if ~0)."""
    epsilon = 1e-10
    mx = np.max(np.abs(arr))
    return (arr / mx) * 2 if mx > epsilon else arr


for _, row in tqdm(df.iterrows(), total=len(df)):
    ev_id = row['evid']
    t_str = row['getstring']
    mag = row['Magnitude']
    # Reset per event: the error handler below reports it, and it must not be
    # undefined or carried over from the previous event.
    peak_corr = None

    if pd.isna(t_str):
        print(f"Skipping event {ev_id} due to NaT in getstring")
        results.append(_lag_record(mag, ev_id, None))
        continue

    # ---- 1. Measurement: download, load, resample, correlate ----------------
    try:
        event_time = UTCDateTime(t_str)
        start = event_time - 5  # window starts 5 s before the origin time
        seismic = get_seismic_trace(
            network, station, location, channel,
            start, duration_seconds
        )

        das_data, fs, used_files = load_das_data(
            file_paths, start, duration_seconds,
            chan_min, chan_max
        )
        if not used_files:
            print(f"No DAS data for event {ev_id}; skipping.")
            results.append(_lag_record(mag, ev_id, event_time.isoformat()))
            continue

        das_files = ','.join(os.path.basename(p) for p in used_files)

        # Bring the DAS data onto the seismic trace's sample count
        das_resampled = resample(das_data, len(seismic.data), axis=0)

        median_lag, lag_std, nch, peak_corr = calculate_lag_times(
            seismic, das_resampled,
            seismic.stats.sampling_rate,
            corr_threshold=threshold
        )
    except Exception as e:
        print(f"Event {ev_id} error: {e}")
        traceback.print_exc()
        # Record the failure in the CSV as well
        results.append(_lag_record(mag, ev_id, str(t_str), peak=peak_corr))
        continue

    if median_lag is None:
        print(f"No channels passed threshold for event {ev_id}; skipping.")
        results.append(_lag_record(
            mag, ev_id, event_time.isoformat(),
            n_channels=0, peak=peak_corr, das_files=das_files
        ))
        continue

    # Successful measurement
    results.append(_lag_record(
        mag, ev_id, event_time.isoformat(),
        median_lag=median_lag, lag_std=lag_std,
        n_channels=nch, peak=peak_corr, das_files=das_files
    ))

    # ---- 2. Before/after plot ------------------------------------------------
    # Kept separate from the measurement so a plotting problem neither loses
    # the result nor adds a second (failure) row for the same event.
    fig = None
    try:
        time_axis = np.linspace(0, duration_seconds, len(seismic.data))
        seis_peak = np.max(np.abs(seismic.data))
        seis_norm = seismic.data / seis_peak if seis_peak > 0 else seismic.data
        shift_samps = int(round(median_lag * seismic.stats.sampling_rate))
        corrected = time_shift_array(das_resampled, shift_samps)

        das_before = _normalize_for_plot(das_resampled)
        das_after = _normalize_for_plot(corrected)

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
        for ax, panel, label in zip([ax1, ax2], [das_before, das_after], ['Before', 'After']):
            ax.plot(time_axis, 3 * seis_norm, 'k', label='Seismic Trace')
            offset = 4
            for ch in panel.T:
                ax.plot(time_axis, 10 * ch + offset, color='teal')
                offset += 1.5
            ax.set_ylabel("Normalized Amplitude")
            ax.set_title(
                f"{label} Correction | Event {ev_id} | "
                f"Lag = {median_lag:.3f}s | Magnitude = {mag}"
            )
            ax.grid(True)
        ax2.set_xlabel("Time (s)")
        fig.tight_layout()
        fig.savefig(os.path.join(output_plot_dir, f'correction_{ev_id}.png'), dpi=300)
    except Exception as e:
        print(f"Event {ev_id} plot error: {e}")
        traceback.print_exc()
    finally:
        # Always release the figure, even when plotting or saving failed
        if fig is not None:
            plt.close(fig)

# Save all results to CSV
pd.DataFrame(results).to_csv(output_csv, index=False)


In [None]:
def _parse_event_time(value):
    """Return a naive UTC datetime for a timestamp value, or NaT if it cannot be parsed."""
    try:
        return UTCDateTime(str(value)).datetime
    except Exception:
        return pd.NaT


# Largest lag (seconds) kept by the filter below; lags must also be > 0.
MAX_LAG_S = 12

# 1. Build DataFrame and drop rows with no event_time at all
results_df = pd.DataFrame(results)
results_df = results_df.dropna(subset=['event_time']).copy()

# 2. Convert event times to pandas datetimes.
#    The rows can hold timestamps written in different textual forms; parsing
#    each value on its own avoids relying on pandas guessing a single format
#    for the whole column, which may turn valid rows into NaT.
results_df['event_time'] = pd.to_datetime(
    results_df['event_time'].map(_parse_event_time),
    errors='coerce'
)

# Drop any rows that still failed parsing
results_df = results_df.dropna(subset=['event_time'])

# Make sure the lag column is numeric: when no event produced a lag the
# column holds only None (object dtype) and comparing it with numbers fails.
results_df['median_lag'] = pd.to_numeric(results_df['median_lag'], errors='coerce')

# 3. Save the complete results (before filtering)
complete_csv = os.path.join(output_plot_dir, "all-quakes-3months-nofilter-034.csv")
results_df.to_csv(complete_csv, index=False)
print(f"\nComplete results saved to {complete_csv}")

# 4. Keep only lag values in (0, MAX_LAG_S]
filtered_results = results_df[
    results_df['median_lag'].notna() &
    (results_df['median_lag'] > 0.0) &
    (results_df['median_lag'] <= MAX_LAG_S)
]

# 5. Compute average and save the filtered table
average_lag_time = filtered_results['median_lag'].mean()
print(f"Average Lag Time (filtered): {average_lag_time:.4f} s")

filtered_results.to_csv(output_csv, index=False)
print(f"\nFiltered results saved to {output_csv}")

# 6. Scatter plot of lag versus event time
if filtered_results.empty:
    # With no rows the y-limit would be NaN and matplotlib would raise
    print("No events passed the lag filter; skipping scatter plot.")
else:
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.scatter(
        filtered_results['event_time'],
        filtered_results['median_lag'],
        s=150, marker='o', edgecolor='k'
    )
    ax.set_ylim(0, filtered_results['median_lag'].max() * 1.1)
    plt.setp(ax.get_xticklabels(), rotation=30, ha='right', fontsize=18)
    ax.set_xlabel('Event Time', fontsize=16)
    ax.set_ylabel('Lag Time (s)', fontsize=16)
    ax.grid(True, linestyle=':', alpha=0.7)
    # Draw a legend only when something on the axes carries a label
    if ax.get_legend_handles_labels()[0]:
        ax.legend(frameon=False)
    fig.tight_layout()
    fig.savefig('Lagtime-allmagnitude', dpi=300)
    plt.show()
