# Dependencies

In [1]:
import numpy as np
import cupy as cp
import pymap3d as pm
import matplotlib.pyplot as plt
from scipy.signal import (find_peaks, windows, resample_poly, iirnotch,
                            butter, sosfilt, sosfiltfilt, get_window, firwin, 
                            filtfilt, lfilter, peak_widths)
from spkit import frft
import pandas as pd
from scipy.ndimage import (
        percentile_filter, minimum_filter1d, maximum_filter1d, median_filter,
        uniform_filter1d, binary_opening, binary_closing, label
    )
from numpy.fft import fft, fftfreq
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import DBSCAN
from tqdm import tqdm
from collections import defaultdict, Counter
import time
from geopy.distance import distance

In [2]:
from joblib import Parallel, delayed
import multiprocessing

In [None]:
from numpy.fft import fft, ifft, fftfreq

# Pre-Deinterleaver

## Generate Emitters

In [None]:
def generate_emitters_in_if(n, seed, fmin, fmax):
    """
    Draw a reproducible table of synthetic emitters.

    Emitter kinds come from a fixed eight-entry catalogue truncated to its
    first ``n`` entries, so at most eight rows are produced. Every random
    parameter is taken from one ``np.random.default_rng(seed)`` stream; the
    order of the draws below therefore determines the table obtained for a
    given seed and must not be rearranged.

    Args
    ----
    n    : slice bound applied to the kind catalogue
    seed : seed handed to ``np.random.default_rng``
    fmin : lower edge of the band from which frequencies are drawn
    fmax : upper edge of the band from which frequencies are drawn

    Returns
    -------
    pandas.DataFrame with columns emitter_id, type, EIRP_dBm, f_tx_center_hz,
    f0_tx_hz, f1_tx_hz, k_tx_hz_per_s, BW_sig_hz, PW_s, PRI_s.
    CW rows carry NaN for PW_s and PRI_s; LFMCW rows have PW_s == PRI_s.
    """
    rng = np.random.default_rng(seed)
    catalogue = ("single frequency pulse", "cw", "lfm", "lfm",
                 "lfmcw", "single frequency pulse", "lfmcw", "cw")

    def _sweep_edges(center, width, direction):
        # (start, stop) frequencies of a sweep of `width` centred on `center`
        if direction > 0:
            return center - width/2, center + width/2
        return center + width/2, center - width/2

    records = []
    lfmcw_seen = 0  # number of LFMCW emitters already emitted
    for emitter_id, kind in enumerate(catalogue[:n]):
        # These two draws happen for every kind, even when the carrier is
        # later replaced by a margin-aware centre frequency.
        carrier = rng.uniform(fmin, fmax)
        eirp_dbm = rng.normal(90.0, 5.0)

        if kind == "single frequency pulse":
            pulse_width = rng.uniform(5e-6, 15e-6)
            pulse_interval = rng.uniform(0.8e-3, 5e-3)
            f_start = f_stop = center = carrier
            chirp_rate = 0.0

        elif kind == "cw":
            pulse_width = np.nan
            pulse_interval = np.nan
            f_start = f_stop = center = carrier
            chirp_rate = 0.0

        elif kind == "lfm":
            pulse_width = rng.uniform(20e-6, 80e-6)
            bandwidth = rng.uniform(8e6, 20e6)
            direction = rng.choice([-1.0, +1.0])
            chirp_rate = direction * bandwidth / pulse_width

            # keep the whole sweep inside [fmin, fmax]
            half_band = bandwidth/2
            center = rng.uniform(fmin + half_band, fmax - half_band)
            f_start, f_stop = _sweep_edges(center, bandwidth, direction)

            pulse_interval = rng.uniform(0.8e-3, 3e-3)

        elif kind == "lfmcw":
            pulse_width = pulse_interval = rng.uniform(20e-6, 80e-6)
            bandwidth = rng.uniform(10e6, 30e6)
            half_band = bandwidth/2
            center = rng.uniform(fmin + half_band, fmax - half_band)

            # first LFMCW sweeps up, any later one sweeps down
            direction = +1.0 if lfmcw_seen == 0 else -1.0
            chirp_rate = direction * bandwidth / pulse_width
            f_start, f_stop = _sweep_edges(center, bandwidth, direction)

        records.append(dict(
            emitter_id=emitter_id, type=kind, EIRP_dBm=eirp_dbm,
            f_tx_center_hz=center,
            f0_tx_hz=f_start, f1_tx_hz=f_stop,
            k_tx_hz_per_s=chirp_rate,
            BW_sig_hz=abs(f_stop - f_start),
            PW_s=pulse_width, PRI_s=pulse_interval
        ))
        if kind == "lfmcw":
            lfmcw_seen += 1

    return pd.DataFrame(records)


## Geometry and Earth Models

In [None]:
# Physical constants
c = 299_792_458.0  # speed of light in vacuum [m/s]

# WGS-84 reference ellipsoid
WGS84_A  = 6_378_137.0       # semi-major axis [m]
WGS84_E2 = 0.00669437999014  # first eccentricity squared (dimensionless)


# =============================================================================
# 1. Coordinate Transforms (WGS-84)
# =============================================================================

def lla_to_ecef(lat_deg, lon_deg, alt_m):
    """Geodetic (lat, lon, alt) [deg,deg,m] → ECEF xyz [m] on the WGS-84 ellipsoid.

    Returns a float ndarray whose leading axis holds (x, y, z).
    """
    phi = np.deg2rad(lat_deg)
    lam = np.deg2rad(lon_deg)
    sin_phi = np.sin(phi)

    # Prime-vertical radius of curvature at this latitude
    prime_vertical = WGS84_A / np.sqrt(1.0 - WGS84_E2 * sin_phi**2)

    # Distance from the z axis; common factor of x and y
    axis_dist = (prime_vertical + alt_m) * np.cos(phi)

    return np.array(
        [
            axis_dist * np.cos(lam),
            axis_dist * np.sin(lam),
            (prime_vertical * (1.0 - WGS84_E2) + alt_m) * sin_phi,
        ],
        dtype=float,
    )


def ecef_to_enu(xyz, ref_lat_deg, ref_lon_deg, ref_alt_m=0.0):
    """
    Express ECEF point(s) in the East-North-Up frame anchored at a reference LLA.

    - xyz: one point of shape (3,) or a stack of shape (N,3); the result has
      the same shape, with components ordered (east, north, up).
    """
    phi = np.deg2rad(ref_lat_deg)
    lam = np.deg2rad(ref_lon_deg)
    s_phi, c_phi = np.sin(phi), np.cos(phi)
    s_lam, c_lam = np.sin(lam), np.cos(lam)

    # Rows are the east, north and up unit vectors written in ECEF axes
    ecef_to_local = np.array([
        [-s_lam,          c_lam,         0.0],
        [-s_phi * c_lam, -s_phi * s_lam, c_phi],
        [ c_phi * c_lam,  c_phi * s_lam, s_phi]
    ])

    origin = lla_to_ecef(ref_lat_deg, ref_lon_deg, ref_alt_m)
    offset = np.asarray(xyz) - origin

    if offset.ndim != 1:
        # stack of points: rotate column-wise, then restore (N,3) layout
        return (ecef_to_local @ offset.T).T
    return ecef_to_local @ offset


def ecef_to_lla(x, y, z):
    """Convert ECEF [m] → (lat, lon, h) [deg,deg,m] on the WGS-84 ellipsoid.

    Latitude comes from Bowring's closed-form (single-step) formula.
    Height is normally recovered from the distance to the z axis,
    ``h = p/cos(lat) - N``. That expression degenerates at the poles
    (p → 0 and cos(lat) → 0 give h ≈ -N), so points whose latitude is within
    about 1e-6 rad of ±90° use the equivalent z-based expression
    ``h = z/sin(lat) - N*(1 - e²)`` instead.

    Accepts scalars or broadcastable arrays and returns (lat_deg, lon_deg, h_m).
    """
    a  = WGS84_A
    e2 = WGS84_E2

    b   = a * np.sqrt(1.0 - e2)       # semi-minor axis
    ep2 = (a*a - b*b) / (b*b)         # second eccentricity squared

    lon = np.arctan2(y, x)
    p   = np.sqrt(x*x + y*y)          # distance from the z axis
    th  = np.arctan2(a * z, b * p)    # Bowring's auxiliary (parametric) angle

    sin_th = np.sin(th)
    cos_th = np.cos(th)

    lat = np.arctan2(z + ep2*b*sin_th**3,
                     p - e2*a*cos_th**3)

    sin_lat = np.sin(lat)
    cos_lat = np.cos(lat)
    N = a / np.sqrt(1.0 - e2*sin_lat**2)

    h_from_p = p / cos_lat - N
    # sin(lat) is exactly 0 on the equator; that branch is discarded below,
    # so silence the division warning it would otherwise raise.
    with np.errstate(divide="ignore", invalid="ignore"):
        h_from_z = z / sin_lat - N * (1.0 - e2)

    near_pole = np.abs(cos_lat) < 1e-6
    h = np.where(near_pole, h_from_z, h_from_p)
    if np.ndim(h) == 0:
        h = h[()]  # scalar input: return a NumPy scalar like lat and lon

    return np.rad2deg(lat), np.rad2deg(lon), h

In [None]:
def compute_tau_fd_one_emitter_lla(e_lla, r_sats, v_sats, f_c):
    """
    TDOA/FDOA of a single emitter, given as an LLA dict, with FF1 as reference.

    e_lla  : mapping with keys "lat", "lon", "alt" (passed to lla_to_ecef)
    r_sats : mapping 'FF1'/'FF2'/'FF3' → satellite position vector
    v_sats : mapping 'FF1'/'FF2'/'FF3' → satellite velocity vector
    f_c    : carrier frequency used to scale the Doppler terms

    Returns (tau21, tau31, fD21, fD31): range differences divided by c and
    Doppler differences, each taken as FF2 or FF3 minus FF1.
    """
    emitter = lla_to_ecef(e_lla["lat"], e_lla["lon"], e_lla["alt"])

    slant = {}
    doppler = {}
    for sat in ('FF1', 'FF2', 'FF3'):
        sat_to_emitter = emitter - r_sats[sat]
        slant[sat] = np.linalg.norm(sat_to_emitter)

        # Satellite velocity projected on the satellite→emitter direction:
        # positive when the satellite is closing on the emitter.
        closing_speed = np.dot(v_sats[sat], sat_to_emitter / slant[sat])
        doppler[sat] = (f_c / c) * closing_speed

    tau21 = (slant['FF2'] - slant['FF1']) / c    # time at FF2 minus FF1
    tau31 = (slant['FF3'] - slant['FF1']) / c

    fD21 = doppler['FF2'] - doppler['FF1']
    fD31 = doppler['FF3'] - doppler['FF1']

    return tau21, tau31, fD21, fD31


def build_tau_fd_maps(emitters_df, r_sats, v_sats, f_c):
    """
    Tabulate, per satellite, the (tau_rel, fD_rel) pair of every emitter
    relative to FF1.

    Reads the columns "emitter_id", "tar_lat_deg" and "tar_lon_deg" of
    emitters_df; emitter altitude is fixed at 0.

    Result layout:
        out['FF1'][eid] = (0.0, 0.0)
        out['FF2'][eid] = (tau21, fD21)
        out['FF3'][eid] = (tau31, fD31)
    """
    ff1_map, ff2_map, ff3_map = {}, {}, {}

    for _, emitter in emitters_df.iterrows():
        eid = int(emitter["emitter_id"])
        position = dict(
            lat=float(emitter["tar_lat_deg"]),
            lon=float(emitter["tar_lon_deg"]),
            alt=0.0,
        )

        tau21, tau31, fD21, fD31 = compute_tau_fd_one_emitter_lla(
            position, r_sats, v_sats, f_c
        )

        ff1_map[eid] = (0.0, 0.0)   # reference satellite
        ff2_map[eid] = (tau21, fD21)
        ff3_map[eid] = (tau31, fD31)

    return {'FF1': ff1_map, 'FF2': ff2_map, 'FF3': ff3_map}

## Link Budget

In [8]:
def up_vector_ecef(lat_deg, lon_deg):
    """Return the unit vector (cos φ cos λ, cos φ sin λ, sin φ) for latitude φ
    and longitude λ given in degrees."""
    phi = np.deg2rad(lat_deg)
    lam = np.deg2rad(lon_deg)
    cos_phi = np.cos(phi)

    x_comp = cos_phi * np.cos(lam)
    y_comp = cos_phi * np.sin(lam)
    z_comp = np.sin(phi)
    return np.array([x_comp, y_comp, z_comp])

def elevation_and_range(r_sat_ecef, r_emit_ecef, up_ecef):
    """
    Elevation angle and slant range of a satellite as seen from an emitter.

    r_sat_ecef  : satellite position vector
    r_emit_ecef : emitter position vector (same frame and units)
    up_ecef     : local "up" direction at the emitter (expected to be unit length)

    Returns (elev_deg, R): the angle in degrees between the line of sight and
    the plane normal to up_ecef, and the emitter→satellite distance.
    """
    los = r_sat_ecef - r_emit_ecef
    R = np.linalg.norm(los)
    los_hat = los / R

    # Rounding can push the dot product of two unit vectors a hair outside
    # [-1, 1], which would make arcsin return NaN for (near-)zenith geometry.
    sin_elev = np.clip(los_hat @ up_ecef, -1.0, 1.0)
    elev_deg = np.degrees(np.arcsin(sin_elev))
    return elev_deg, R

def fspl_dB(R_m, f_hz):
    """Free-space path loss in dB for a range in metres and a frequency in hertz.

    Evaluates 92.45 + 20·log10(R / km) + 20·log10(f / GHz).
    """
    range_term = 20*np.log10(R_m / 1e3)
    freq_term = 20*np.log10(f_hz / 1e9)
    return 92.45 + range_term + freq_term

## IQ Generation

In [13]:
def synthesize_iq(
    emitters_df,
    T=1.0,             # capture length (s)
    fs=125e6,          # sample rate (Hz)
    F_LO=9.50e9,       # baseband LO (Hz)
    noise_V=0.1,       # noise voltage std (per real or imag component)
    sat_name="FF1",    # which satellite to generate IQ for
    tau_fd_map=None,   # dict[eid] -> (tau_rel_s, fD_rel_hz) w.r.t FF1
    seed=0,
    use_equal_power=False,  # if False: scale by SNR_dB_{sat_name}
):
    """
    Sum every emitter in emitters_df into one complex-baseband IQ capture as
    seen by a single satellite (sat_name = 'FF1'/'FF2'/'FF3'), then add noise.

    Per-satellite geometry (optional, via tau_fd_map[emitter_id]):
      - tau_rel is added to the emitter's t_start_s, delaying the pulse gate
        (pulsed types) or the sweep origin (lfmcw). CW ignores timing.
      - fD_rel is added to the baseband start frequency:
            f0_bb = (f0_tx_hz - F_LO) + fD_rel

    Amplitudes:
      - use_equal_power=True : every emitter has amplitude 1.0
      - otherwise            : amp = noise_V * sqrt(2 * 10^(SNR_dB/10)), with
        SNR_dB read from column SNR_dB_{sat_name}; a missing or NaN value
        raises ValueError. The resulting amplitudes are printed.

    Waveforms by emitter "type" (case-insensitive):
      - "cw"                     : constant tone at f0_bb
      - "lfmcw"                  : sawtooth chirp, phase restarts every PRI
      - "single frequency pulse" : tone at f0_bb gated on for PW every PRI
      - "lfm"                    : chirp restarted at each pulse, gated by PW/PRI
      Any other type raises ValueError, as do missing PW/PRI where required.

    Returns
    -------
    (y, truth) : complex64 array of round(T*fs) samples, and a DataFrame with
                 one row of ground-truth parameters per emitter.
    """
    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(seed)
    n_samples = int(round(T * fs))
    t = np.arange(n_samples, dtype=np.float64) / fs
    iq = np.zeros(n_samples, dtype=np.complex64)

    snr_col = f"SNR_dB_{sat_name}"

    # ------------------------------------------------------------
    # Per-emitter amplitude table
    # ------------------------------------------------------------
    amps = {}
    if use_equal_power:
        for _, row in emitters_df.iterrows():
            amps[int(row.emitter_id)] = 1.0
    else:
        for _, row in emitters_df.iterrows():
            eid = int(row.emitter_id)
            if snr_col not in row or np.isnan(row[snr_col]):
                raise ValueError(
                    f"emitters_df must include {snr_col} when use_equal_power=False"
                )
            snr_linear = 10.0 ** (float(row[snr_col]) / 10.0)
            # complex noise power is 2*noise_V^2, hence the factor 2
            amps[eid] = noise_V * np.sqrt(2.0 * snr_linear)

        print(f"\n[{sat_name}] Amplitude per emitter (based on {snr_col} and noise_V):")
        for eid, amp in amps.items():
            print(f"  emitter {eid}: amp = {amp:.3f}")

    def _on_off_gate(elapsed, elapsed_in_period, pulse_width):
        # 1.0 while a pulse is on (after the first start time), else 0.0
        return ((elapsed >= 0.0) & (elapsed_in_period < pulse_width)).astype(np.float64)

    # ------------------------------------------------------------
    # Waveform accumulation
    # ------------------------------------------------------------
    truth = []

    for _, e in emitters_df.iterrows():
        eid = int(e["emitter_id"])
        kind = str(e["type"]).lower()

        f0 = float(e["f0_tx_hz"])
        f1 = float(e["f1_tx_hz"])
        k = float(e["k_tx_hz_per_s"])
        PW = None if pd.isna(e["PW_s"]) else float(e["PW_s"])
        PRI = None if pd.isna(e["PRI_s"]) else float(e["PRI_s"])

        # emitter's own start time; defaults to 0 when absent / None / NaN
        t_start = 0.0
        if "t_start_s" in e:
            raw_start = e["t_start_s"]
            if raw_start is not None and not np.isnan(raw_start):
                t_start = float(raw_start)

        # delay and Doppler of this satellite relative to FF1
        tau_rel = 0.0
        fD_rel = 0.0
        if tau_fd_map is not None and eid in tau_fd_map:
            tau_rel, fD_rel = tau_fd_map[eid]

        t_start_total = t_start + tau_rel
        f0_bb = (f0 - F_LO) + fD_rel

        if kind == "cw":
            # continuous tone: start time is irrelevant
            sig = np.exp(1j * (2 * np.pi * (f0_bb * t)))

        elif kind == "lfmcw":
            if PRI is None:
                raise ValueError("lfmcw requires PRI for LFMCW sweeps")

            # time since the start of the current sweep
            in_sweep = np.mod(t - t_start_total, PRI)
            sig = np.exp(1j * (2 * np.pi * (f0_bb * in_sweep + 0.5 * k * in_sweep**2)))

        elif kind == "single frequency pulse":
            if PRI is None or PW is None:
                raise ValueError("single frequency pulse requires non-NaN PW and PRI")

            elapsed = t - t_start_total
            gate = _on_off_gate(elapsed, np.mod(elapsed, PRI), PW)

            # carrier phase runs on absolute time, so pulses are mutually coherent
            sig = np.exp(1j * (2 * np.pi * (f0_bb * t))) * gate

        elif kind == "lfm":
            if PRI is None or PW is None:
                raise ValueError("lfm requires non-NaN PW and PRI")

            elapsed = t - t_start_total
            in_period = np.mod(elapsed, PRI)
            gate = _on_off_gate(elapsed, in_period, PW)

            # chirp phase restarts at the leading edge of every pulse
            in_pulse = np.where(gate > 0.0, in_period, 0.0)
            sig = np.exp(1j * (2 * np.pi * (f0_bb * in_pulse + 0.5 * k * in_pulse**2))) * gate

        else:
            raise ValueError(f"Unknown emitter type: {kind}")

        iq += (amps[eid] * sig).astype(np.complex64)

        truth.append({
            "emitter_id": eid,
            "type": e["type"],
            "t_start_s": t_start_total,
            "f0_bb_hz": f0_bb,
            "f1_bb_hz": (f1 - F_LO) + fD_rel,
            "k_hz_per_s": k,
            "PW_s": PW,
            "PRI_s": PRI,
            "amp": amps[eid],
            snr_col: e.get(snr_col, np.nan),
            "tau_rel_s": tau_rel,
            "fD_rel_hz": fD_rel,
        })

    # ------------------------------------------------------------
    # Complex Gaussian noise (real part drawn first, then imaginary)
    # ------------------------------------------------------------
    noise_re = rng.standard_normal(n_samples, dtype=np.float32)
    noise_im = rng.standard_normal(n_samples, dtype=np.float32)
    iq += noise_V * (noise_re + 1j * noise_im)

    return iq, pd.DataFrame(truth)

# Deinterleaver Helper Functions

## QOL

In [5]:
# Start timestamps of named stages, filled in by mark_stage(name, start=True)
_stage_times = {}

def mark_stage(name, start=False):
    """Stopwatch helper for named pipeline stages.

    mark_stage(name, start=True) records the current time under `name`.
    mark_stage(name) prints the seconds elapsed since that recording and
    raises KeyError if the stage was never started.
    """
    if not start:
        elapsed = time.time() - _stage_times[name]
        print(f"{name}: {elapsed:.2f} s")
        return
    _stage_times[name] = time.time()

In [6]:
def process_window(w, signal_noisy, fs, stride, chunk_len,
                   k_vals, nfft, nfft_mode,
                   nfft_safety_margin, max_df_hz,
                   precomp_refs):
    """
    Run the chirp-rate scan on analysis window number `w`.

    The window covers samples [w*stride, w*stride + chunk_len) of
    signal_noisy. Returns (row_db, t0_global_s): the per-k energies from
    k_scan_energy_window and the window start time in seconds, or
    (None, None) when the window would run past the end of the signal.
    """
    first = w * stride
    last = first + chunk_len
    if last > len(signal_noisy):
        return None, None

    segment = signal_noisy[first:last].astype(np.complex64)
    energies_db = k_scan_energy_window(
        segment, fs, k_vals,
        nfft=nfft,
        nfft_mode=nfft_mode,
        nfft_safety_margin=nfft_safety_margin,
        max_df_hz=max_df_hz,
        precomp_refs=precomp_refs,
    )
    return energies_db, first / fs


## PDW Extraction (Use this for Pulse as well)

In [9]:
def baseband_PDW_extractor(
    signal, fs, *,
    k_hz_per_s,
    t0_global_s,           # window start time (global)
    window_us=5,           # smoothing / local-threshold window
    min_sustain_us=5,      # morphology size (opening/closing)
    threshold_percentile=50,
    amp_floor_pct=90,      # percentile of smoothed amp to floor low-level clutter
    downsample_factor=1,   # NEW
    plot=False, plot_raw=False, plot_smoothed=False, plot_mask=False, plot_thres=False
):
    """
    Detect pulses in the envelope of a BASEBANDED 1-D IQ segment and build a
    pulse descriptor word (PDW) for each one.

    Pipeline
    --------
    1. |signal| envelope, optionally decimated by `downsample_factor`.
    2. Moving-average smoothing over `window_us`.
    3. Local threshold: `threshold_percentile` percent of the running maximum
       of the smoothed envelope (a fraction of the local peak, not a
       statistical percentile).
    4. Binary opening then closing with a `min_sustain_us`-long element.
    5. Any above-threshold run at least `min_sustain_us` long is restored in
       full, undoing edge trimming by the opening.
    6. Samples whose smoothed amplitude is below half of the
       `amp_floor_pct`-th percentile are removed from the mask.
    7. Mask / smoothed envelope / threshold are brought back to full rate.
    8. Each connected mask region of at least 2 samples becomes a pulse.

    PDW fields: TOA_global_us, PW_us, Amp_med, f_start_Hz, f_end_Hz,
    f_center_Hz. The centre frequency is the median phase-slope frequency
    (middle 20 % of the pulse when it has more than 20 slope samples); start
    and end are centre ∓ 0.5 * k_hz_per_s * PW.

    Returns
    -------
    pulses       : list of IQ slices, one per pulse
    pulse_bounds : list of (start, stop) sample indices, stop exclusive
    mask         : boolean pulse mask at the full sample rate
    zeros        : np.zeros_like(signal) (placeholder output)
    pdws         : list of PDW dicts
    """
    n_total = len(signal)
    t_local = np.arange(n_total) / fs
    t_global = t0_global_s + t_local

    envelope = np.abs(signal)

    # ---------- 1: optional decimation ----------
    decim = max(1, int(downsample_factor))
    decimated = decim > 1
    if decimated:
        env_work = envelope[::decim]
        fs_work = fs / decim
    else:
        env_work = envelope
        fs_work = fs

    # ---------- 2: moving-average smoothing ----------
    box = max(1, int(round(window_us * 1e-6 * fs_work)))
    env_smooth_work = uniform_filter1d(env_work, size=box, mode='reflect')

    # ---------- 3: threshold relative to the running maximum ----------
    running_peak = maximum_filter1d(env_smooth_work, size=box, mode='reflect')
    thr_work = (threshold_percentile / 100.0) * running_peak
    above_work = env_smooth_work > thr_work

    # ---------- 4: morphological clean-up ----------
    min_run = max(1, int(round(min_sustain_us * 1e-6 * fs_work)))
    element = np.ones(min_run)
    mask_work = binary_closing(
        binary_opening(above_work, structure=element),
        structure=element,
    )

    # ---------- 5: restore long runs that the opening trimmed ----------
    run_labels, n_runs = label(above_work)
    for run_id in range(1, n_runs + 1):
        members = np.nonzero(run_labels == run_id)[0]
        if len(members) >= min_run:
            mask_work[members[0]:members[-1] + 1] = True

    # ---------- 6: amplitude floor ----------
    floor_level = np.percentile(env_smooth_work, amp_floor_pct)
    mask_work[env_smooth_work < 0.5 * floor_level] = 0

    # ---------- 7: back to the full sample rate ----------
    if decimated:
        full_idx = np.arange(len(envelope))
        mask = np.repeat(mask_work, decim)[:len(envelope)]
        smoothed_amp = np.interp(
            full_idx,
            np.arange(len(env_smooth_work)) * decim,
            env_smooth_work
        )
        threshold = np.interp(
            full_idx,
            np.arange(len(thr_work)) * decim,
            thr_work
        )
    else:
        mask, smoothed_amp, threshold = mask_work, env_smooth_work, thr_work

    # ---------- 8: one PDW per connected mask region ----------
    region_labels, n_regions = label(mask)
    pulses, pulse_bounds, pdws = [], [], []

    for region_id in range(1, n_regions + 1):
        members = np.nonzero(region_labels == region_id)[0]
        if len(members) < 2:
            continue
        s, e = members[0], members[-1] + 1

        # timing on the global clock; the last sample contributes a full period
        t_start = t_global[s]
        t_end = t_global[e-1] + 1/fs
        pw_s = t_end - t_start

        amp_med = float(np.median(envelope[s:e]))

        # instantaneous frequency from the unwrapped phase slope
        inst_freq = (fs/(2*np.pi)) * np.diff(np.unwrap(np.angle(signal[s:e])))
        n_freq = len(inst_freq)
        if n_freq > 20:
            # use the central 20 % to stay clear of edge transients
            f_center_hz = float(np.median(inst_freq[int(0.4*n_freq):int(0.6*n_freq)]))
        elif n_freq:
            f_center_hz = float(np.median(inst_freq))
        else:
            f_center_hz = np.nan

        half_sweep = 0.5 * k_hz_per_s * pw_s
        f_start_hz = f_center_hz - half_sweep
        f_end_hz = f_center_hz + half_sweep

        pulses.append(signal[s:e])
        pulse_bounds.append((s, e))
        pdws.append({
            "TOA_global_us": t_start*1e6,
            "PW_us": pw_s*1e6,
            "Amp_med": amp_med,
            "f_start_Hz": f_start_hz,
            "f_end_Hz": f_end_hz,
            "f_center_Hz": f_center_hz,
        })

    # ---------- optional diagnostics ----------
    if plot:
        t_us = t_local*1e6
        plt.figure(figsize=(14, 6))
        if plot_raw:
            plt.plot(t_us, envelope, label='Raw Amp', alpha=0.35)
        if plot_smoothed:
            plt.plot(t_us, smoothed_amp, label='Smoothed', linewidth=1.2)
        if plot_mask:
            plt.plot(t_us, smoothed_amp*mask, label='Masked', alpha=0.9)
        if plot_thres:
            plt.plot(t_us, threshold, 'r--', label='Local Threshold')
        for (s, e) in pulse_bounds:
            plt.axvspan(s/fs*1e6, e/fs*1e6, color='lime', alpha=0.08)
        plt.xlabel("Time [µs]")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    return pulses, pulse_bounds, mask, np.zeros_like(signal), pdws


## 1. Signal Chirp Scanning

### GPU

In [None]:
def build_energy_map_dechirp_gpu_batched(
    signal_noisy, fs, total_duration_s,
    T_chunk=1e-3, stride_s=None,
    k_range=(-1.0e12, 1.0e12), k_step=0.1e12,
    batch_size=32,
    nfft=None, nfft_mode="downpow2",
    nfft_safety_margin: float = 1.05,
    max_df_hz: float | None = None,
    show_progress=True,
    return_candidates=True,
    row_prom_db=4.5,
    row_skip_margin_db=6.0,
    k0_eps_MHz_per_us=0.05,
    max_peaks_per_row=None,
    k_vals_override=None,
    # NEW
    use_parallel: bool = True,
    n_jobs: int = -1,
):
    """
    GPU build of the dechirp-FFT energy map, batched across windows.

    The signal is cut into windows of T_chunk seconds every stride_s seconds
    (stride defaults to T_chunk). Each window is multiplied by
    exp(-j*pi*k*t^2) for every chirp rate k on the grid, transformed with an
    FFT, and the strongest bin power (dB) is kept, giving one row of
    len(k_vals) energies per window.

    Notes
    -----
    - The FFT length is `nfft` when given, otherwise the window length.
      `nfft_mode`, `nfft_safety_margin` and `max_df_hz` are accepted for
      signature parity with the CPU builder but are not used here.
    - Dechirp phases are evaluated in float64 and only the finished
      references are stored as complex64. At |k|*t^2 around 1e6 a float32
      phase has a resolution of roughly a quarter radian, which would smear
      the dechirped tone.
    - If no window fits inside the signal, an empty (0, len(k_vals)) map is
      returned instead of raising.

    Returns
    -------
    return_candidates=False:
        energy_map_db, time_axis_s, k_vals
    return_candidates=True:
        energy_map_db, time_axis_s, k_vals, row_k_peaks, near_zero_mask
        where row_k_peaks[i] lists the k-indices of prominent peaks in row i
        (empty when the row maximum is less than row_skip_margin_db above
        its median) and near_zero_mask flags |k| <= k0_eps_MHz_per_us*1e12.
        Peak picking runs through joblib when use_parallel is True.
    """

    t_total = time.perf_counter()

    if stride_s is None:
        stride_s = T_chunk

    chunk_len = int(round(T_chunk * fs))
    stride = int(round(stride_s * fs))
    n_total = int(round(total_duration_s * fs))
    n_win = 1 + max(0, int(np.floor((n_total - chunk_len) / stride)))

    # --- k grid ---
    if k_vals_override is not None:
        k_vals = np.array(k_vals_override, dtype=float)
    else:
        k_vals = np.arange(k_range[0], k_range[1] + 0.5*k_step, k_step, dtype=float)

    # --- dechirp references: float64 phase, complex64 storage ---
    k_vals_gpu = cp.asarray(k_vals, dtype=cp.float64)
    t_rel = cp.arange(chunk_len, dtype=cp.float64) / fs
    t2_rel = t_rel * t_rel
    refs = cp.exp(
        -1j * cp.pi * k_vals_gpu[:, None] * t2_rel[None, :]
    ).astype(cp.complex64)  # (K, N)

    # Upload full signal once
    y_gpu = cp.asarray(signal_noisy)

    energy_rows = []
    time_axis = []

    # --- GPU FFT stage ---
    t_fft_start = time.perf_counter()
    for w0 in range(0, n_win, batch_size):
        # Stack the windows of this batch: (B, N)
        chunks = []
        for w in range(w0, min(w0 + batch_size, n_win)):
            start = w * stride
            end = start + chunk_len
            if end > y_gpu.size:
                # windows are visited in increasing order, so every later
                # window overruns the signal as well
                break
            chunks.append(y_gpu[start:end])
            time_axis.append(w * stride / fs)

        if not chunks:
            break

        batch_gpu = cp.stack(chunks, axis=0)                 # (B, N)
        scratch = refs[None, :, :] * batch_gpu[:, None, :]   # (B, K, N)
        Y = cp.fft.fft(scratch, n=nfft or chunk_len, axis=-1)
        row_db_batch = 10 * cp.log10(cp.max(cp.abs(Y)**2, axis=-1) + 1e-30)

        # Back to CPU
        energy_rows.append(cp.asnumpy(row_db_batch))

        if show_progress and (w0 % (100*batch_size) == 0):
            print(f"Processed {w0}/{n_win} windows")

    # np.vstack rejects an empty list; mirror the CPU builder's empty map
    if energy_rows:
        energy_map_db = np.vstack(energy_rows)
    else:
        energy_map_db = np.empty((0, len(k_vals)))
    time_axis_s = np.array(time_axis)
    t_fft_end = time.perf_counter()

    if not return_candidates:
        print(f"[Timing] GPU FFT stage: {t_fft_end - t_fft_start:.2f} s "
              f"(total {time.perf_counter() - t_total:.2f} s)")
        return energy_map_db, time_axis_s, k_vals

    # --- Candidate picking ---
    t_pick_start = time.perf_counter()

    def process_row(row_db):
        # Rows without a clear peak above their median are treated as noise
        if row_skip_margin_db is not None:
            if np.max(row_db) < (np.median(row_db) + row_skip_margin_db):
                return []
        pk, _ = find_peaks(row_db, prominence=row_prom_db)
        if (max_peaks_per_row is not None) and (len(pk) > max_peaks_per_row):
            # keep only the strongest peaks
            top_idx = np.argsort(row_db[pk])[::-1][:max_peaks_per_row]
            pk = pk[top_idx]
        return pk.tolist()

    if use_parallel:
        row_k_peaks = Parallel(n_jobs=n_jobs, prefer="processes")(
            delayed(process_row)(row_db) for row_db in energy_map_db
        )
    else:
        row_k_peaks = [process_row(row_db) for row_db in energy_map_db]

    t_pick_end = time.perf_counter()

    print(f"[Timing] GPU FFT stage: {t_fft_end - t_fft_start:.2f} s")
    print(f"[Timing] CPU candidate picking ({'parallel' if use_parallel else 'serial'} find_peaks): "
          f"{t_pick_end - t_pick_start:.2f} s")
    print(f"[Timing] Total Time: {time.perf_counter() - t_total:.2f} s")

    # MHz/us → Hz/s is a factor of 1e12
    near_zero_mask = (np.abs(k_vals) <= (k0_eps_MHz_per_us * 1e12))

    return energy_map_db, time_axis_s, k_vals, row_k_peaks, near_zero_mask


In [None]:
def k_scan_energy_window_gpu(chunk_gpu, fs, k_vals,
                             nfft=None,
                             nfft_mode: str = "downpow2",
                             nfft_safety_margin: float = 1.05,
                             max_df_hz: float | None = None,
                             precomp_refs=None,
                             scratch=None):
    """
    Dechirp + FFT for all k, fully on GPU.

    chunk_gpu    : cp.ndarray, 1-D signal window already on the GPU
    k_vals       : chirp rates to test; may live on the host or the device
    nfft         : FFT length; when None it is chosen by choose_nfft using
                   nfft_mode / nfft_safety_margin / max_df_hz
    precomp_refs : optional precomputed exp(-j*pi*k*t^2) table, shape (K, N)
    scratch      : optional (K, N) complex64 work buffer, overwritten in place

    Returns a cp.ndarray with the strongest FFT-bin power (dB) for every k.

    When the reference table is built here, k_vals is first moved to the
    device (host and device arrays cannot be mixed in one expression) and the
    phase is evaluated in float64 before the table is stored as complex64,
    because float32 cannot resolve phases of order 1e6 rad.
    """
    N = chunk_gpu.shape[0]
    if nfft is None:
        L = choose_nfft(N, fs,
                        mode=nfft_mode,
                        safety_margin=nfft_safety_margin,
                        max_df_hz=max_df_hz)
    else:
        L = int(nfft)

    if precomp_refs is None:
        k_gpu = cp.asarray(k_vals, dtype=cp.float64)
        t_rel = cp.arange(N, dtype=cp.float64) / fs
        t2_rel = t_rel * t_rel
        refs = cp.exp(
            -1j * cp.pi * k_gpu[:, None] * t2_rel[None, :]
        ).astype(cp.complex64)
    else:
        refs = precomp_refs

    if scratch is None:
        scratch = cp.empty((len(k_vals), N), dtype=cp.complex64)

    # Dechirp every k at once, writing into the reusable buffer
    cp.multiply(refs, chunk_gpu[None, :], out=scratch)
    Y = cp.fft.fft(scratch, n=L, axis=-1)

    row_db = 10 * cp.log10(cp.max(cp.abs(Y) ** 2, axis=1) + 1e-30)
    return row_db



### CPU

In [7]:
def next_pow2(n: int) -> int:
    """Smallest power of two >= n; any n <= 1 maps to 1 (mirrors prev_pow2)."""
    if n <= 1:
        # without this guard n = 0 would give 2 and negative n arbitrary values
        return 1
    return 1 << (int(n - 1).bit_length())

def prev_pow2(n: int) -> int:
    """Largest power of two <= n; any n <= 1 maps to 1."""
    # The top set bit of n is the answer: shift 1 up to that bit position.
    return 1 if n <= 1 else 1 << (int(n).bit_length() - 1)

def choose_nfft(N: int, fs: float,
                mode: str = "downpow2",
                safety_margin: float = 1.05,
                max_df_hz: float | None = None) -> int:
    """
    Decide an FFT length smaller than next_pow2(N) but not too small.

    Args
    ----
    N : slice length (samples)
    fs : sample rate (Hz)
    mode : one of {"downpow2", "nopad", "next_pow2"}
      - "downpow2": nfft = largest power of two <= safety_margin * N
      - "nopad":    nfft = N
      - "next_pow2":nfft = next_pow2(N)  (original behavior)
    safety_margin : multiplier >= 1.0 used by "downpow2"
    max_df_hz : if set, enforce fs/nfft <= max_df_hz by increasing nfft
                (only increases within strategy’s family)

    Returns
    -------
    nfft : int
    """
    assert N >= 8, "N too small"
    assert safety_margin >= 1.0, "safety_margin should be >= 1.0"

    if mode == "nopad":
        nfft = int(N)

    elif mode == "next_pow2":
        nfft = next_pow2(N)

    elif mode == "downpow2":
        # target is slightly above N, then round DOWN to a friendly pow2
        target = max(int(np.ceil(N * safety_margin)), N)
        nfft = prev_pow2(target)
        # ensure we didn't round below something unreasonably small
        # (you can relax/tighten this line if needed)
        nfft = max(nfft, prev_pow2(N))  # don't go below the previous pow2 of N
    else:
        raise ValueError(f"Unknown nfft mode: {mode}")

    # Enforce bin-spacing guard if requested
    if max_df_hz is not None:
        # required nfft to meet df <= max_df_hz
        needed = int(np.ceil(fs / max_df_hz))
        if nfft < needed:
            if mode == "nopad":
                # bump just enough to satisfy df; mixed radix is fine
                nfft = needed
            elif mode in ("downpow2", "next_pow2"):
                # bump to next power-of-two that meets df
                nfft = next_pow2(needed)
    return int(nfft)


def k_scan_energy_window(chunk, fs, k_vals,
                         nfft=None,
                         nfft_mode: str = "downpow2",
                         nfft_safety_margin: float = 1.05,
                         max_df_hz: float | None = None,
                         precomp_refs=None,
                         scratch=None):
    """
    Vectorized version:
      - de-chirp by all k in k_vals using precomputed refs
      - FFT along axis for all k
      - return max-bin power (dB) for each k
    """
    N = len(chunk)
    if nfft is None:
        L = choose_nfft(N, fs,
                        mode=nfft_mode,
                        safety_margin=nfft_safety_margin,
                        max_df_hz=max_df_hz)
    else:
        L = int(nfft)

    if precomp_refs is None:
        # fallback: compute refs on the fly
        t_rel = np.arange(N) / fs
        t2_rel = t_rel * t_rel
        refs = np.exp(-1j * np.pi * k_vals[:, None] * t2_rel[None, :]).astype(np.complex64)
    else:
        refs = precomp_refs

    if scratch is None:
        scratch = np.empty((len(k_vals), N), dtype=np.complex64)

    # Broadcast multiply: dechirp all k in one go
    np.multiply(refs, chunk[None, :], out=scratch)

    # Batched FFT along axis=-1
    Y = np.fft.fft(scratch, n=L, axis=-1)

    # Max bin per row
    row_db = 10 * np.log10((np.abs(Y)**2).max(axis=1) + 1e-30)
    return row_db


def build_energy_map_dechirp(signal_noisy, fs, total_duration_s,
                             T_chunk=1e-3, stride_s=None,
                             k_range=(-1.0e12, 1.0e12), k_step=0.1e12,
                             nfft=None, nfft_mode="downpow2",
                             nfft_safety_margin=1.05,
                             max_df_hz: float | None = None,
                             show_progress=True,
                             return_candidates=True,
                             row_prom_db=4.5,
                             row_skip_margin_db=6.0,
                             k0_eps_MHz_per_us=0.05,
                             max_peaks_per_row=None,
                             k_vals_override=None,
                             n_jobs=None):
    """
    Build the dechirp-FFT energy map on the CPU, one joblib task per window.

    The signal is cut into windows of T_chunk seconds every stride_s seconds
    (stride defaults to T_chunk); each window is scanned over the chirp-rate
    grid by process_window. Windows that would run past the end of
    signal_noisy are dropped. n_jobs defaults to the machine's CPU count.
    `show_progress` is accepted but not used by this function.

    Returns
    -------
    return_candidates=False:
        energy_map_db, time_axis_s, k_vals
    return_candidates=True:
        energy_map_db, time_axis_s, k_vals, row_k_peaks, near_zero_mask
        where row_k_peaks[i] lists the k-indices of peaks with prominence
        >= row_prom_db in row i (empty when the row maximum is less than
        row_skip_margin_db above its median; capped to the strongest
        max_peaks_per_row when set) and near_zero_mask flags
        |k| <= k0_eps_MHz_per_us * 1e12. With no usable window the map has
        shape (0, len(k_vals)) and row_k_peaks is empty.
    """
    stride_s = T_chunk if stride_s is None else stride_s

    chunk_len = int(round(T_chunk * fs))
    stride = int(round(stride_s * fs))
    n_total = int(round(total_duration_s * fs))
    n_win = 1 + max(0, int(np.floor((n_total - chunk_len) / stride)))

    # --- chirp-rate grid ---
    if k_vals_override is None:
        k_vals = np.arange(k_range[0], k_range[1] + 0.5*k_step, k_step, dtype=float)
    else:
        k_vals = np.array(k_vals_override, dtype=float)

    # --- dechirp references, shared by every window ---
    t_rel = np.arange(chunk_len) / fs
    t2_rel = t_rel * t_rel
    precomp_refs = np.exp(-1j * np.pi * k_vals[:, None] * t2_rel[None, :]).astype(np.complex64)

    # --- scan all windows in parallel ---
    workers = multiprocessing.cpu_count() if n_jobs is None else n_jobs
    results = Parallel(n_jobs=workers)(
        delayed(process_window)(
            w, signal_noisy, fs, stride, chunk_len,
            k_vals, nfft, nfft_mode,
            nfft_safety_margin, max_df_hz,
            precomp_refs
        )
        for w in range(n_win)
    )

    # overrunning windows come back as (None, None)
    energy_rows = [row for row, _ in results if row is not None]
    time_axis = [t0 for _, t0 in results if t0 is not None]

    if energy_rows:
        energy_map_db = np.vstack(energy_rows)
    else:
        energy_map_db = np.empty((0, len(k_vals)))
    time_axis_s = np.array(time_axis, dtype=float)

    if not return_candidates:
        return energy_map_db, time_axis_s, k_vals

    # --- candidate gathering ---
    # MHz/us → Hz/s is a factor of 1e12
    near_zero_mask = (np.abs(k_vals) <= (k0_eps_MHz_per_us * 1e12))

    row_k_peaks = []
    for row_db in energy_map_db:
        # rows without a clear peak above their median are treated as noise
        if row_skip_margin_db is not None:
            if np.max(row_db) < (np.median(row_db) + row_skip_margin_db):
                row_k_peaks.append([])
                continue

        pk, _ = find_peaks(row_db, prominence=row_prom_db)

        if (max_peaks_per_row is not None) and (len(pk) > max_peaks_per_row):
            # keep only the strongest peaks
            strongest = np.argsort(row_db[pk])[::-1][:max_peaks_per_row]
            pk = pk[strongest]

        row_k_peaks.append(pk.tolist())

    return energy_map_db, time_axis_s, k_vals, row_k_peaks, near_zero_mask



## 2. Identifying Signal Types (CW, LFMCW, or LFM pulses)
Single Frequency Pulses will be handled later using the same functions after CW removal

In [None]:
def detect_pulse_chirp_rates_nms(
    counts, k_vals, *,
    sigma: float = 3.0,        # MAD threshold multiplier
    min_count: float = 10.0,   # absolute floor on peak height
    nms_bins: int = 3,         # suppress neighbors within ±nms_bins
    smooth_bins: int = 0       # 0 = off, else odd window (e.g., 3 or 5)
):
    """
    Blind chirp-rate peak detection with robust threshold + 1-D NMS.

    - No max-peaks cap.
    - NMS is applied uniformly (including k≈0).
    - Optional moving-average smoothing feeds the detector only; ranking,
      refinement and the returned counts always use the original histogram.

    Parameters
    ----------
    counts : array-like, (N,)
        Histogram value per chirp-rate bin.
    k_vals : array-like, (N,)
        Chirp-rate value of each bin (may be irregularly spaced).
    sigma : float
        Multiplier on the scaled MAD of the non-zero detector values.
    min_count : float
        Absolute lower bound on the detection threshold.
    nms_bins : int
        A picked peak suppresses every candidate within ±nms_bins bins.
    smooth_bins : int
        Moving-average length; values <= 1 disable smoothing, even values
        are bumped to the next odd number.

    Returns
    -------
    peak_indices : (M,) int
    peak_k_vals  : (M,) float  (parabolic sub-bin refinement)
    peak_counts  : (M,) float  (from ORIGINAL counts, not smoothed)
    """
    counts = np.asarray(counts, dtype=float)
    k_vals = np.asarray(k_vals, dtype=float)
    N = counts.size
    if N == 0:
        return np.array([], int), np.array([]), np.array([])

    # ---- optional smoothing for the *detector path* only ----
    if smooth_bins and smooth_bins > 1:
        if smooth_bins % 2 == 0:
            smooth_bins += 1  # ensure odd
        kernel = np.ones(smooth_bins, dtype=float) / smooth_bins
        counts_det = np.convolve(counts, kernel, mode="same")
    else:
        counts_det = counts

    # ---- robust threshold (MAD) on the detector signal ----
    nz = counts_det[counts_det > 0]
    if nz.size == 0:
        return np.array([], int), np.array([]), np.array([])
    med = np.median(nz)
    mad = 1.4826 * np.median(np.abs(nz - med))
    thr = max(float(min_count), med + sigma * mad)

    # ---- initial candidates (no distance; NMS handles clustering) ----
    cand, _ = find_peaks(counts_det, height=thr)
    if cand.size == 0:
        return np.array([], int), np.array([]), np.array([])

    # ---- NMS in 1-D over original counts (strong→weak) ----
    order = np.argsort(counts[cand])[::-1]
    suppressed = np.zeros(N, dtype=bool)
    picked = []
    for idx in cand[order]:
        if suppressed[idx]:
            continue
        picked.append(int(idx))
        lo = max(0, idx - nms_bins)
        hi = min(N, idx + nms_bins + 1)
        suppressed[lo:hi] = True

    picked = np.array(sorted(picked), dtype=int)

    # ---- parabolic refinement (use ORIGINAL counts) ----
    refined = []
    for p in picked:
        if p == 0 or p == N - 1:
            # no neighbour on one side: keep the grid value
            refined.append(k_vals[p])
            continue
        c0, c1, c2 = counts[p - 1], counts[p], counts[p + 1]
        denom = (c0 - 2.0 * c1 + c2)
        if denom < 0.0:
            # Concave 3-point fit: the vertex is a maximum. Clamp to half a
            # bin so a lopsided neighbourhood cannot move the estimate out
            # of the picked bin.
            delta = float(np.clip(0.5 * (c0 - c2) / denom, -0.5, 0.5))
        else:
            # Flat or convex neighbourhood (possible when the peak was found
            # on the smoothed detector signal): the parabola has no maximum,
            # so keep the bin centre instead of extrapolating.
            delta = 0.0
        # handle irregular grids: average left/right steps
        step = 0.5 * ((k_vals[p] - k_vals[p - 1]) + (k_vals[p + 1] - k_vals[p]))
        refined.append(k_vals[p] + delta * step)

    return picked, np.array(refined, dtype=float), counts[picked]


In [None]:
def classify_candidates(
    k_vals,
    row_k_peaks,
    *,
    min_support_frac_lfmcw=0.30,
    min_support_frac_cw=0.30,
    min_support_windows=None,
    zero_band_bins=1,         # how many bins around k=0 to treat as “CW region”
    nms_sigma=3.0,
    nms_min_count=10,
    nms_bins=3,
    smooth_bins=0,
):
    """
    Detect and classify candidate k bins into LFMCW, CW, or pulse types.

    Steps
    -----
    1. Compute support_counts (# of windows where each k bin was active).
    2. Run blind chirp-rate detection with NMS on that histogram.
    3. Classify each detected peak by the support fraction of its bin:
         - CW: peak bin within ±zero_band_bins of the bin closest to k=0
           and support fraction >= min_support_frac_cw
         - LFMCW: any other peak with support fraction
           >= min_support_frac_lfmcw (this includes near-zero peaks that
           failed the CW fraction test)
         - Pulse: all remaining peaks

    Parameters
    ----------
    k_vals : array-like, (K,)
        Chirp-rate grid.
    row_k_peaks : list
        Per-window lists of active k-bin indices; None/empty rows are skipped.
    min_support_frac_lfmcw, min_support_frac_cw : float
        Minimum fraction of windows supporting the peak bin.
    min_support_windows : int or None
        Accepted for interface compatibility; the classification below does
        not consult it.
    zero_band_bins : int
        Half-width (bins) of the CW region around the bin closest to k=0.
    nms_sigma, nms_min_count, nms_bins, smooth_bins :
        Forwarded to detect_pulse_chirp_rates_nms.

    Returns
    -------
    results : list[dict]
        Each dict has keys: {'mode','k_hat','peak_idx','support_sum'}
    support_counts : (K,) int
    support_frac   : (K,) float
    """
    K = len(k_vals)
    n_win = len(row_k_peaks)

    # --- 1. Build support histogram ---
    support_counts = np.zeros(K, dtype=int)
    for idxs in row_k_peaks:
        if idxs is None or len(idxs) == 0:
            continue
        # fancy-index increment: a bin listed twice in one window counts once
        support_counts[idxs] += 1

    support_frac = support_counts / max(1, n_win)

    # An empty grid has no bin closest to zero (np.argmin would raise) and
    # cannot yield any peak, so report "nothing found" explicitly.
    if K == 0:
        return [], support_counts, support_frac

    # --- 2. Detect peaks with NMS ---
    peak_idx, k_refined, peak_counts = detect_pulse_chirp_rates_nms(
        support_counts, k_vals,
        sigma=nms_sigma,
        min_count=nms_min_count,
        nms_bins=nms_bins,
        smooth_bins=smooth_bins,
    )

    # --- 3. Classification ---
    results = []
    k0 = int(np.argmin(np.abs(k_vals)))  # closest-to-zero index
    lo = max(0, k0 - zero_band_bins)
    hi = min(K, k0 + zero_band_bins + 1)

    for idx, k_val, cnt in zip(peak_idx, k_refined, peak_counts):
        mode = "pulse"
        frac = support_frac[idx]

        if lo <= idx < hi and frac >= min_support_frac_cw:
            mode = "cw"
        elif frac >= min_support_frac_lfmcw:
            mode = "lfmcw"

        results.append({
            "mode": mode,
            "k_hat": float(k_val),
            "peak_idx": int(idx),
            "support_sum": int(cnt),
        })

    return results, support_counts, support_frac


## 3. LFMCW/CW Handling

### Filtering and Coupon Slicing

In [None]:
def windows_supporting_khat(k_hat, k_vals, row_k_peaks, *, max_bin_delta=1):
    """
    Return the window indices whose peak list contains a bin near k_hat.

    k_hat is snapped to the closest entry of k_vals; a window counts as a
    hit when any of its peak bins lies within max_bin_delta bins of that
    entry. Rows that are None or empty never match.
    """
    target_bin = int(np.argmin(np.abs(k_vals - k_hat)))
    supporting = []
    for win_no, row in enumerate(row_k_peaks):
        if row is not None and len(row) > 0:
            bin_offsets = np.abs(np.asarray(row) - target_bin)
            if np.any(bin_offsets <= max_bin_delta):
                supporting.append(win_no)
    return supporting

def pick_blocks_from_hits(hits, X):
    """
    Pick up to three blocks of X window indices: first, middle and last.

    Parameters
    ----------
    hits : iterable of int (list, tuple or array) or None
        Window indices; sorted internally.
    X : int
        Block length in windows.

    Returns
    -------
    list[list[int]]
        [] when there are no hits or X < 1; a single block holding every hit
        when there are at most X of them; otherwise the distinct blocks among
        the first X, the X around the midpoint, and the last X hits.
    """
    # Sort before the emptiness test so array inputs work as well as lists.
    hits = sorted(hits) if hits is not None else []
    # X < 1 cannot form a block; without this guard hits[-0:] would return
    # the whole list as the "last" block.
    if not hits or X < 1:
        return []
    if len(hits) <= X:
        return [hits]

    mid_start = max(0, len(hits) // 2 - X // 2)
    candidates = (
        hits[:X],
        hits[mid_start:mid_start + X],
        hits[-X:],
    )

    # dedupe overlapping blocks, keeping first occurrence order
    uniq, seen = [], set()
    for block in candidates:
        key = tuple(block)
        if key not in seen:
            seen.add(key)
            uniq.append(block)
    return uniq

def make_nodes_from_windows(block_windows, T_chunk_s):
    """
    Build one minimal node dict per window index, in ascending order.

    Each node holds the window index under "w" and w * T_chunk_s under "t0".
    """
    nodes = []
    for win in sorted(block_windows):
        nodes.append({"w": int(win), "t0": float(win) * T_chunk_s})
    return nodes


In [None]:
def _build_segment(sig, fs, T_chunk_s, nodes_block): 
    t0 = nodes_block[0]["t0"]
    t1 = nodes_block[-1]["t0"] + T_chunk_s
    i0, i1 = max(0, int(round(t0*fs))), min(len(sig), int(round(t1*fs)))
    if i1 <= i0:
        return np.zeros(1, dtype=np.complex64), float(t0), 0.0
    x = sig[i0:i1].astype(np.complex64, copy=False)
    return x, t0, (i1 - i0) / fs

def _alpha_from_k(k_hz_per_s, fs, T_seg): 
    a = 1.0 + (2.0/np.pi)*np.arctan((k_hz_per_s*T_seg)/fs)
    return float(np.clip(a, 1e-3, 2.0-1e-3))
def _block_medians(pdws): 
    if not pdws:
        return dict(N=0, PRI_us=np.nan, PW_us=np.nan, Amp_med=np.nan, f_center_Hz=np.nan)
    toas = np.sort([p["TOA_global_us"] for p in pdws])
    pris = np.diff(toas) if len(toas) >= 2 else np.array([np.nan])
    return dict(
        N=len(pdws),
        PRI_us=float(np.nanmedian(pris)) if pris.size else np.nan,
        PW_us =float(np.nanmedian([p["PW_us"] for p in pdws])),
        Amp_med=float(np.nanmedian([p["Amp_med"] for p in pdws])),
        f_center_Hz=float(np.nanmedian([p.get("f_center_Hz", np.nan) for p in pdws])),
    )

In [None]:
def summarize_all_nodes_df(
    clusters,
    *,
    k_vals, row_k_peaks,
    signal, fs, T_chunk_s, X,
    peak_frac=0.5,          # CW FFT peak picking threshold
    sanity_threshold=0.8,   # fraction of block span required
    lfmcw_sanity_rel_margin=0.05,   # 5% slack
    lfmcw_sanity_abs_us=1.0,        # +1 µs slack
    downsample_factor=1,
    AMP_MED_MIN=0.1,
    **extract_kwargs
):
    """
    Summarise every CW / LFMCW cluster into rows of one DataFrame.

    For each cluster the supporting windows are located, up to three blocks
    of X windows are selected, PDWs are extracted per block, and per-block
    medians are aggregated into one row per LFMCW cluster or one row per
    detected tone for CW clusters.

    Parameters
    ----------
    clusters : iterable of dict
        Each needs 'mode' ("lfmcw" or "cw") and 'k_hat'.
    k_vals, row_k_peaks :
        Chirp-rate grid and per-window peak bins used to find supporting
        windows.
    signal, fs, T_chunk_s :
        Sample buffer, sampling rate and window duration [s].
    X : int
        Requested block length in windows.
    peak_frac : float
        CW only: FFT peaks must reach this fraction of the maximum magnitude.
    sanity_threshold : float
        CW only: each block PW must cover this fraction of the block span.
    lfmcw_sanity_rel_margin, lfmcw_sanity_abs_us : float
        LFMCW only: relative / absolute slack in the PW <= PRI check.
    downsample_factor :
        Forwarded to the block extractors.
    AMP_MED_MIN : float
        Rows whose median amplitude is non-finite or below this are dropped.
    **extract_kwargs :
        Forwarded unchanged to the block extractors.

    Returns
    -------
    pandas.DataFrame
        Concatenation of all rows; clusters without supporting windows or
        without PDWs contribute a single-row frame with an "error" column.
        Empty DataFrame when nothing was produced.

    Raises
    ------
    ValueError
        If a cluster's mode is neither "lfmcw" nor "cw".
    """
    dfs = []
    for cl in clusters:
        mode  = cl['mode']
        k_hat = float(cl['k_hat'])

        # Find supporting windows
        hits = windows_supporting_khat(k_hat, k_vals, row_k_peaks, max_bin_delta=1)
        if not hits:
            dfs.append(pd.DataFrame([{"error": f"{mode} k={k_hat:.3e}: no supporting windows"}]))
            continue

        # Pick coupon blocks
        X_eff = min(X, len(hits)) if len(hits) < 3*X else X
        blocks = pick_blocks_from_hits(hits, X_eff)

        summaries, all_pdws, all_f0s = [], [], []   # collect PDWs + tone freqs
        for bw in blocks:
            nodes_block = make_nodes_from_windows(bw, T_chunk_s)

            if mode == "lfmcw":
                pdws = _extract_block_lfmcw_pdws(
                    nodes_block,
                    signal=signal, fs=fs, T_chunk_s=T_chunk_s,
                    k_hz_per_s=k_hat,
                    downsample_factor=downsample_factor,
                    **extract_kwargs
                )

            elif mode == "cw":
                # FFT block to detect tones
                x_seg, _t0_start, T_seg = _build_segment(signal, fs, T_chunk_s, nodes_block)
                # _build_segment reports an out-of-range block as a one-sample
                # placeholder with zero duration, so test the duration too;
                # a length check alone never fires.
                if T_seg <= 0 or len(x_seg) == 0:
                    pdws, f0s = [], []
                else:
                    Nfft = 1 << (len(x_seg)-1).bit_length()   # next power of two
                    Xfft = np.fft.fft(x_seg * np.hanning(len(x_seg)), n=Nfft)
                    mag  = np.abs(Xfft)
                    f_axis = np.fft.fftfreq(Nfft, 1/fs)

                    pk, _props = find_peaks(mag, height=np.max(mag)*peak_frac)
                    # fall back to the strongest bin when no interior peak qualifies
                    f0s = [float(f_axis[p]) for p in pk] if pk.size else [float(f_axis[np.argmax(mag)])]

                    pdws = []
                    for f0_hz in f0s:
                        pdws.extend(
                            _extract_block_cw_pdws(
                                nodes_block,
                                signal=signal, fs=fs, T_chunk_s=T_chunk_s,
                                f0_hz=f0_hz,
                                downsample_factor=downsample_factor,
                                **extract_kwargs
                            )
                        )
                all_f0s.extend(f0s)

            else:
                raise ValueError(f"Unknown cluster mode: {mode}")

            if pdws:
                all_pdws.extend(pdws)
                summaries.append(_block_medians(pdws))

        if not summaries:
            dfs.append(pd.DataFrame([{"error": f"{mode} k={k_hat:.3e}: no PDWs extracted"}]))
            continue

        # --- Aggregation ---
        if mode == "lfmcw":
            start_time_us = float(min(p["TOA_global_us"] for p in all_pdws))
            PW_us  = float(np.nanmedian([s["PW_us"] for s in summaries]))
            PRI_us = float(np.nanmedian([s["PRI_us"] for s in summaries]))
            Amp_med = float(np.nanmedian([s["Amp_med"] for s in summaries]))
            if not np.isfinite(Amp_med) or Amp_med < AMP_MED_MIN:
                continue   # too weak: drop the cluster without a row

            vals = [s.get("f_center_Hz", np.nan) for s in summaries]
            finite_vals = [v for v in vals if not np.isnan(v)]
            f_center_Hz = float(np.median(finite_vals)) if finite_vals else np.nan

            k_mhz_per_us = k_hat / 1e12
            f_center_MHz = f_center_Hz / 1e6 if not np.isnan(f_center_Hz) else np.nan
            # signed: a negative chirp rate gives a negative bandwidth, which
            # places the start frequency above the end frequency
            BW_MHz = k_mhz_per_us * PW_us
            f_start_MHz = f_center_MHz - 0.5 * BW_MHz if not np.isnan(f_center_MHz) else np.nan
            f_end_MHz   = f_center_MHz + 0.5 * BW_MHz if not np.isnan(f_center_MHz) else np.nan
            ptype = "LFMCW"

            # ---- LFMCW sanity: PW <= PRI (with slack) ----
            if np.isfinite(PW_us) and np.isfinite(PRI_us) and PRI_us > 0:
                pri_with_margin = PRI_us * (1.0 + lfmcw_sanity_rel_margin) + lfmcw_sanity_abs_us
                lfmcw_sanity_ok = bool(PW_us <= pri_with_margin)
            else:
                lfmcw_sanity_ok = pd.NA  # unknown PRI/PW → indeterminate

            dfs.append(pd.DataFrame([{
                "TOA (us)": start_time_us,
                "PW (us)": PW_us,
                "PRI (us)": PRI_us,
                "Envelope Amplitude": Amp_med,
                "Center Freq (MHz)": f_center_MHz,
                "Chirp Rate (MHz/us)": k_mhz_per_us,
                "Bandwidth (MHz)": BW_MHz,
                "Start Freq (MHz)": f_start_MHz,
                "End Freq (MHz)": f_end_MHz,
                "Pulse Type": ptype,
                "pulse_sanity_ok": lfmcw_sanity_ok,   # now boolean/<NA>
            }]))

        elif mode == "cw":

            # Use FFT-detected frequencies as the true tone list
            finite_f0s = [f for f in all_f0s if not np.isnan(f)]
            if not finite_f0s:
                continue

            # unique baseband tones in MHz (rounded like before)
            f_center_list = sorted(list(set([round(f/1e6, 2) for f in finite_f0s])))

            # block-level sanity check
            block_span_us = X_eff * T_chunk_s * 1e6
            pw_blocks = [s["PW_us"] for s in summaries if not np.isnan(s["PW_us"])]
            pulse_sanity_ok = all(pw >= sanity_threshold * block_span_us for pw in pw_blocks)

            rows = []

            # per-tone summarization
            for f0_MHz in f_center_list:

                # find all PDWs generated for this tone
                tone_pdws = [p for p in all_pdws
                             if "tone_freq_hz" in p
                             and round(p["tone_freq_hz"]/1e6, 2) == f0_MHz]

                if not tone_pdws:
                    continue

                # per-tone stats
                TOA_us = float(min(p["TOA_global_us"] for p in tone_pdws))
                # PW spans first to last supporting window of the cluster
                PW_us  = (hits[-1] - hits[0] + 1) * T_chunk_s * 1e6
                Amp    = float(np.nanmedian([p["Amp_med"]  for p in tone_pdws]))
                if not np.isfinite(Amp) or Amp < AMP_MED_MIN:
                    continue

                rows.append({
                    "TOA (us)": TOA_us,
                    "PW (us)": PW_us,
                    "PRI (us)": np.nan,
                    "Envelope Amplitude": Amp,
                    "Center Freq (MHz)": f0_MHz,
                    "Chirp Rate (MHz/us)": 0.0,
                    "Bandwidth (MHz)": 0.0,
                    "Start Freq (MHz)": f0_MHz,
                    "End Freq (MHz)": f0_MHz,
                    "Pulse Type": "CW",
                    "pulse_sanity_ok": pulse_sanity_ok,
                })

            # Only contribute a frame when at least one tone survived; an
            # empty frame adds nothing and makes pd.concat emit a
            # deprecation warning on recent pandas.
            if rows:
                dfs.append(pd.DataFrame(rows))

    return pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()


#### CW Path

In [15]:
def _extract_block_cw_pdws(
    nodes_block, *,
    signal, fs, T_chunk_s,
    f0_hz,                      # baseband CW frequency estimate
    lpf_bw_hz=200e3,            # low-pass cutoff around DC
    numtaps=513,
    window_us=5.0,
    min_sustain_us=5.0,
    thr_pct=50,
    amp_floor_pct=90,
    downsample_factor=1
):
    """
    Extract PDWs for one CW tone inside one block of windows.

    The block's samples are mixed down by f0_hz, low-pass filtered with a
    firwin FIR filter, and passed to baseband_PDW_extractor with a chirp
    rate of 0. Every returned PDW dict is tagged with "tone_freq_hz" = f0_hz.

    Parameters
    ----------
    nodes_block : list of dict
        Nodes (each with "t0") delimiting the block.
    signal, fs, T_chunk_s :
        Sample buffer, sampling rate [Hz] and window duration [s].
    f0_hz : float
        Tone frequency to shift to DC.
    lpf_bw_hz, numtaps :
        Cutoff and length of the FIR low-pass.
    window_us, min_sustain_us, thr_pct, amp_floor_pct, downsample_factor :
        Forwarded to baseband_PDW_extractor.

    Returns
    -------
    list of dict
        PDWs for this tone; [] when the block lies outside the signal.
    """
    # 1) time span
    x_seg, t0_start, T_seg = _build_segment(signal, fs, T_chunk_s, nodes_block)
    # An out-of-range block comes back as a one-sample placeholder with zero
    # duration, so the duration must be tested; a length check alone never
    # triggers.
    if T_seg <= 0 or len(x_seg) == 0:
        return []

    # 2) heterodyne to DC (phase referenced to global time)
    t_seg = t0_start + np.arange(len(x_seg))/fs
    x_bb  = x_seg * np.exp(-1j*2*np.pi*f0_hz*t_seg)

    # 3) LPF
    # NOTE(review): lfilter is causal, so the linear-phase FIR delays the
    # output by (numtaps - 1) / 2 samples; confirm that downstream TOA
    # values account for this shift.
    lpf_taps = firwin(numtaps, lpf_bw_hz, fs=fs)
    x_filt = lfilter(lpf_taps, 1.0, x_bb)

    # 4) Extract PDWs (same baseband extractor as LFMCW)
    _pulses, _bounds, _mask, _clean, pdws = baseband_PDW_extractor(
        x_filt, fs,
        k_hz_per_s=0.0,
        t0_global_s=t0_start,
        window_us=window_us,
        min_sustain_us=min_sustain_us,
        threshold_percentile=thr_pct,
        amp_floor_pct=amp_floor_pct,
        downsample_factor=downsample_factor,
        plot=False
    )

    # tag each PDW so per-tone aggregation can group them later
    for p in pdws:
        p["tone_freq_hz"] = f0_hz

    return pdws

#### LFMCW Path

In [22]:
def _extract_block_lfmcw_pdws(nodes_block, *, signal, fs, T_chunk_s, k_hz_per_s, half_bin=10,
                              prom_db=12.0, window_us=5.0, min_sustain_us=5.0, thr_pct=50, amp_floor_pct=90,
                              downsample_factor=1):
    """
    Extract PDWs for one chirp rate inside one block of windows.

    The block is transformed with frft at the order given by _alpha_from_k.
    Peaks at least prom_db above the median magnitude and within 6 dB of the
    strongest bin are each gated (±half_bin bins), transformed back with the
    negated order, and passed to baseband_PDW_extractor.

    Parameters
    ----------
    nodes_block : list of dict
        Nodes (each with "t0") delimiting the block.
    signal, fs, T_chunk_s :
        Sample buffer, sampling rate [Hz] and window duration [s].
    k_hz_per_s : float
        Chirp rate of the cluster.
    half_bin : int
        Half-width of the gate around each transform-domain peak.
    prom_db : float
        Required height above the median magnitude [dB].
    window_us, min_sustain_us, thr_pct, amp_floor_pct, downsample_factor :
        Forwarded to baseband_PDW_extractor.

    Returns
    -------
    list of dict
        PDWs gathered over all accepted peaks; [] when the block lies
        outside the signal or no peak qualifies.
    """
    # 1) time span for this block
    x_seg, t0_start, T_seg = _build_segment(signal, fs, T_chunk_s, nodes_block)
    # Same degenerate-block handling as the CW path: a zero-duration
    # placeholder segment carries no signal, so skip the transform entirely.
    if T_seg <= 0 or len(x_seg) == 0:
        return []

    # 2) FrFT at α(k,T_seg); find peaks
    alpha = _alpha_from_k(k_hz_per_s, fs, T_seg)
    Xa = frft(x_seg, alpha)
    mag_db = 20*np.log10(np.abs(Xa)+1e-12)   # epsilon avoids log10(0)
    base = np.median(mag_db)
    peak_dist = int(max(1, round(fs * T_chunk_s)))   # at most one peak per window length
    peaks, _ = find_peaks(mag_db, height=base+prom_db, distance=peak_dist)
    max_val = np.max(mag_db)
    valid = mag_db[peaks] > (max_val - 6)  # within 6 dB of strongest
    peaks = peaks[valid]

    # 3) For each peak: gate → inverse → extract PDWs (may return multiple pulses)
    pdws = []
    for p in peaks:
        lo, hi = max(0, p-half_bin), min(len(Xa), p+half_bin+1)

        G = np.zeros_like(Xa); G[lo:hi] = Xa[lo:hi]
        y = frft(G, -alpha)
        _pulses, _bounds, _mask, _clean, pdw_list = baseband_PDW_extractor(
            y, fs,
            k_hz_per_s=k_hz_per_s,
            t0_global_s=t0_start,
            window_us=window_us,
            min_sustain_us=min_sustain_us,
            threshold_percentile=thr_pct,
            amp_floor_pct=amp_floor_pct,
            downsample_factor=downsample_factor
        )
        pdws.extend(pdw_list)
    return pdws

### CW Notching

In [None]:
def notch_cws_fft_bins_gpu(signal, fs, cw_freqs_hz, *,
                           T_chunk_s, nfft_mode="nopad",
                           nfft_safety_margin=1.05,
                           max_df_hz=None,
                           half_bins=3):
    """
    GPU batched CW notch: zero ±half_bins FFT bins around each CW frequency.

    The signal is cut into full chunks of round(T_chunk_s * fs) samples (the
    remainder past the last full chunk is dropped), each chunk is zero-padded
    to the FFT length, transformed, notched and transformed back. No taper is
    applied. Notch bins wrap circularly (modulo the FFT length), so a tone
    next to DC also removes the adjacent negative-frequency bins.

    Parameters
    ----------
    signal : ndarray (complex)
        Input IQ samples (NumPy or CuPy).
    fs : float
        Sampling rate [Hz].
    cw_freqs_hz : list or array of float, or None
        CW freqs to notch (Hz). None or empty means nothing to notch.
    T_chunk_s : float
        Chunk duration [s].
    nfft_mode, nfft_safety_margin, max_df_hz :
        Passed to choose_nfft.
    half_bins : int
        Half-width of notch (bins).

    Returns
    -------
    residual : CuPy ndarray (complex64)
        Signal after CW notching, living on GPU. If the input is shorter
        than one chunk it is returned unprocessed.
    (half_bins, L_win) : tuple
        Diagnostics; L_win is the FFT length (the chunk length when the
        input is shorter than one chunk).
    """
    # Ensure CuPy array
    signal_gpu = cp.asarray(signal, dtype=cp.complex64)

    N_total = signal_gpu.size
    N_chunk = int(round(T_chunk_s * fs))
    n_chunks = N_total // N_chunk
    if n_chunks == 0:
        return signal_gpu, (half_bins, N_chunk)

    # FFT length
    L_win = choose_nfft(N_chunk, fs, mode=nfft_mode,
                        safety_margin=nfft_safety_margin,
                        max_df_hz=max_df_hz)

    # Normalise the tone list on the host. `cw_freqs_hz or []` would raise
    # for arrays with more than one element, so test None explicitly.
    if cw_freqs_hz is None:
        tones = np.empty(0, dtype=float)
    else:
        tones = cp.asnumpy(cp.asarray(cw_freqs_hz, dtype=float)).ravel()

    n_used = n_chunks * N_chunk
    if tones.size == 0:
        # nothing to notch: skip the FFT round trip, keep the truncation
        return signal_gpu[:n_used].copy(), (half_bins, L_win)

    # Reshape into [n_chunks, N_chunk], zero-pad to [n_chunks, L_win]
    chunks = signal_gpu[:n_used].reshape(n_chunks, N_chunk)
    X_time = cp.zeros((n_chunks, L_win), dtype=cp.complex64)
    X_time[:, :N_chunk] = chunks

    # FFT along axis=1
    X_freq = cp.fft.fft(X_time, n=L_win, axis=1)

    # Build a single boolean notch mask over bins. FFT bins are circular
    # (index L_win-1 is the bin just below DC), so neighbours are taken
    # modulo L_win instead of being clamped at the array ends.
    f_axis = cp.fft.fftfreq(L_win, d=1/fs)
    offsets = np.arange(-half_bins, half_bins + 1)
    mask_host = np.zeros(L_win, dtype=bool)
    for f0 in tones:
        i0 = int(cp.argmin(cp.abs(f_axis - float(f0))))
        mask_host[(i0 + offsets) % L_win] = True
    notch_mask = cp.asarray(mask_host)

    # Apply notch mask to all rows (broadcast)
    X_freq[:, notch_mask] = 0.0

    # IFFT back, discarding the zero-padding region
    X_ifft = cp.fft.ifft(X_freq, n=L_win, axis=1)[:, :N_chunk]

    # Flatten back to 1-D
    out = X_ifft.reshape(-1).astype(cp.complex64, copy=False)

    return out, (half_bins, L_win)


In [None]:
def notch_cws_fft_bins_fast(signal, fs, cw_freqs_hz, *,
                            T_chunk_s, nfft_mode="nopad",
                            nfft_safety_margin=1.05,
                            max_df_hz=None,
                            half_bins=3,
                            taper=True,
                            batch_chunks=64):
    """
    Batched CPU CW notch: zero ±half_bins FFT bins around each CW frequency.

    The signal is cut into full chunks of round(T_chunk_s * fs) samples; the
    remainder past the last full chunk is dropped from the output. Each chunk
    is optionally Hann-tapered, zero-padded to the FFT length, transformed,
    notched and transformed back. The taper is not undone afterwards. Notch
    bins wrap circularly (modulo the FFT length), so a tone next to DC also
    removes the adjacent negative-frequency bins.

    Parameters
    ----------
    signal : array-like (1-D)
        Input IQ samples.
    fs : float
        Sampling rate [Hz].
    cw_freqs_hz : sequence of float or None
        CW freqs to notch (Hz). None or empty returns a typed copy of the
        full-chunk part without running any FFT.
    T_chunk_s : float
        Chunk duration [s].
    nfft_mode, nfft_safety_margin, max_df_hz :
        Passed to choose_nfft.
    half_bins : int
        Half-width of notch (bins).
    taper : bool
        Apply a Hann window to each chunk before the FFT.
    batch_chunks : int
        Number of chunks transformed per batch (bounds peak memory); None or
        values < 1 fall back to 64.

    Returns
    -------
    residual : complex ndarray (complex64)
        Length n_chunks * N_chunk; if the input is shorter than one chunk it
        is returned as an unprocessed complex64 copy.
    (half_bins, L_win) : tuple for diagnostics
    """
    signal = np.asarray(signal)
    N_total = signal.size
    N_chunk = int(round(T_chunk_s * fs))

    n_chunks = N_total // N_chunk
    if n_chunks == 0:
        return signal.astype(np.complex64, copy=True), (half_bins, N_chunk)

    L_win = choose_nfft(N_chunk, fs, mode=nfft_mode,
                        safety_margin=nfft_safety_margin,
                        max_df_hz=max_df_hz)

    n_used = n_chunks * N_chunk

    # Early exit: if no tones, skip all FFTs and just return a typed copy
    if cw_freqs_hz is None or len(cw_freqs_hz) == 0:
        return signal[:n_used].astype(np.complex64, copy=True), (half_bins, L_win)

    # --- One boolean notch mask over FFT bins ---
    # FFT bins are circular (index L_win-1 is the bin just below DC), so the
    # neighbours of the nearest bin are taken modulo L_win rather than being
    # clamped at the array ends.
    f_axis = np.fft.fftfreq(L_win, d=1/fs)
    offsets = np.arange(-half_bins, half_bins + 1)
    notch_mask = np.zeros(L_win, dtype=bool)
    for f0 in cw_freqs_hz:
        i0 = int(np.argmin(np.abs(f_axis - float(f0))))
        notch_mask[(i0 + offsets) % L_win] = True

    win = np.hanning(N_chunk).astype(np.float32) if taper else None

    # View the full-chunk part as [n_chunks, N_chunk]; rows are chunks.
    chunks = signal[:n_used].reshape(n_chunks, N_chunk)
    out = np.empty((n_chunks, N_chunk), dtype=np.complex64)

    if batch_chunks is None or batch_chunks < 1:
        batch_chunks = 64  # safe default

    # --- Batch over chunks to limit memory while keeping vectorization ---
    for b0 in range(0, n_chunks, batch_chunks):
        b1 = min(n_chunks, b0 + batch_chunks)
        block = chunks[b0:b1].astype(np.complex64, copy=False)
        if taper:
            block = block * win   # float32 window keeps the product complex64

        # zero-padded work buffer; columns past N_chunk stay zero
        X_time = np.zeros((b1 - b0, L_win), dtype=np.complex64)
        X_time[:, :N_chunk] = block

        X_freq = np.fft.fft(X_time, n=L_win, axis=1)
        X_freq[:, notch_mask] = 0.0

        # keep only the first N_chunk samples of each inverse transform
        out[b0:b1] = np.fft.ifft(X_freq, n=L_win, axis=1)[:, :N_chunk]

    return out.reshape(-1), (half_bins, L_win)


In [4]:
def notch_cws_fft_bins(signal, fs, cw_freqs_hz, *,
                       T_chunk_s, nfft_mode="nopad",
                       nfft_safety_margin=1.05,
                       max_df_hz=None,
                       half_bins=3,
                       taper=True):
    """
    Remove CW tones in per-window FFTs by zeroing ±half_bins around each tone.

    Each full chunk of round(T_chunk_s * fs) samples is optionally
    Hann-tapered, transformed with an L_win-point FFT, notched and
    transformed back; the taper is not undone afterwards. Notch bins wrap
    circularly (modulo L_win), so a tone next to DC also removes the adjacent
    negative-frequency bins. Samples past the last full chunk are copied to
    the output unprocessed, so the output always has the input's length and
    never contains uninitialised values.

    Parameters
    ----------
    signal : complex ndarray
        Input IQ samples.
    fs : float
        Sampling rate [Hz].
    cw_freqs_hz : list or array of float, or None
        CW center freqs to notch (Hz). None or empty notches nothing.
    T_chunk_s : float
        Chunk (window) duration [s] (same as Stage-4).
    nfft_mode, nfft_safety_margin, max_df_hz :
        Passed to choose_nfft (Stage-4 FFT length choice).
    half_bins : int
        Half-width of notch in FFT bins (>=1).
    taper : bool
        Apply Hann taper before FFT to reduce leakage.

    Returns
    -------
    residual : complex ndarray (complex64), same length as signal
        Signal after per-window CW notching.
    (half_bins, L_win) : tuple returned as well for diagnostics.
    """
    signal = np.asarray(signal)
    N_chunk = int(round(T_chunk_s * fs))
    L_win = choose_nfft(N_chunk, fs, mode=nfft_mode,
                        safety_margin=nfft_safety_margin,
                        max_df_hz=max_df_hz)
    f_axis = np.fft.fftfreq(L_win, d=1/fs)

    # Normalise the tone list; `cw_freqs_hz or []` would raise for arrays
    # with more than one element.
    if cw_freqs_hz is None:
        tones = np.empty(0, dtype=float)
    else:
        tones = np.atleast_1d(np.asarray(cw_freqs_hz, dtype=float)).ravel()

    # The notch bins do not depend on the chunk, so build the mask once.
    # Neighbours are taken modulo L_win because FFT bins are circular
    # (index L_win-1 is the bin just below DC).
    offsets = np.arange(-half_bins, half_bins + 1)
    notch_mask = np.zeros(L_win, dtype=bool)
    for f0 in tones:
        i0 = int(np.argmin(np.abs(f_axis - float(f0))))
        notch_mask[(i0 + offsets) % L_win] = True

    win = np.hanning(N_chunk).astype(np.float32) if taper else None
    n_chunks = len(signal) // N_chunk
    n_used = n_chunks * N_chunk
    out = np.empty_like(signal, dtype=np.complex64)

    for w in range(n_chunks):
        s0, s1 = w * N_chunk, (w + 1) * N_chunk
        seg = signal[s0:s1].astype(np.complex64, copy=False)
        if taper:
            seg = seg * win

        X = np.fft.fft(seg, n=L_win)
        X[notch_mask] = 0.0

        # keep only the first N_chunk samples of the inverse transform
        out[s0:s1] = np.fft.ifft(X, n=L_win)[:N_chunk].astype(np.complex64)

    # np.empty_like leaves the partial tail uninitialised; pass the original
    # samples through so the result is deterministic.
    out[n_used:] = signal[n_used:]

    return out, (half_bins, L_win)


## 4. Node Building for Pulses

In [None]:
def build_k_buckets(peaks, k_vals, k_refined, *, width_bins=1):
    """
    Group raw k-bins into one bucket per detected peak.

    Parameters
    ----------
    peaks : array of ints
        Peak bin index of each bucket.
    k_vals : array of floats
        Chirp-rate grid; only its length is used here.
    k_refined : array of floats
        Refined chirp rate per peak, stored in the bucket unchanged.
    width_bins : int
        Half-width (in bins) of each bucket, clipped to the grid.

    Returns
    -------
    buckets : list of dict
        One dict per peak with keys kid, k_refined, bin_indices (np.ndarray).
        bin_indices keeps the full clipped range even where another bucket
        owns a bin.
    bin2kid : np.ndarray
        Owning bucket id per k-bin, or -1. A bin covered by several buckets
        goes to the strictly nearer peak; on a tie the earlier bucket keeps it.
    """
    n_bins = len(k_vals)
    owner = np.full(n_bins, -1, dtype=int)
    bucket_list = []

    for bucket_id, center in enumerate(peaks):
        first = max(0, center - width_bins)
        last = min(n_bins - 1, center + width_bins)
        members = np.arange(first, last + 1)

        for b in members:
            current = owner[b]
            if current == -1:
                owner[b] = bucket_id                      # unclaimed bin
            elif abs(b - center) < abs(b - peaks[current]):
                owner[b] = bucket_id                      # strictly closer peak wins

        bucket_list.append(
            dict(kid=bucket_id, k_refined=k_refined[bucket_id], bin_indices=members)
        )

    return bucket_list, owner


In [None]:
def assign_buckets_per_window(row_k_peaks_reduced, bin2kid, n_kids):
    """
    For each window, mark which slope buckets (kids) are present.

    Parameters
    ----------
    row_k_peaks_reduced : list of arrays/lists (entries may be None)
        Each entry = active k-bin indices for that window. None or empty
        entries leave the window's row all False, matching how the other
        helpers in this file treat such rows.
    bin2kid : np.ndarray
        Map from bin index to kid (or -1 for bins owned by no bucket).
    n_kids : int
        Number of buckets.

    Returns
    -------
    presence : ndarray of shape (n_windows, n_kids), bool
    """
    bin2kid = np.asarray(bin2kid)
    n_windows = len(row_k_peaks_reduced)
    presence = np.zeros((n_windows, n_kids), dtype=bool)

    for w, active_bins in enumerate(row_k_peaks_reduced):
        if active_bins is None or len(active_bins) == 0:
            continue
        # look up all bins of the window at once, then drop unowned (-1) bins
        kids = bin2kid[np.asarray(active_bins, dtype=int)]
        presence[w, kids[kids >= 0]] = True
    return presence


In [None]:
def group_windows_with_gaps(win_list, max_gap=1):
    """
    Merge window indices into runs, bridging up to max_gap missing windows.

    win_list is walked in the order given (callers pass it ascending); a new
    run starts whenever an index lies more than max_gap + 1 beyond the
    previous one.

    Returns a list of (w_start, w_end) tuples, or [] for an empty list.
    """
    if not win_list:
        return []

    spans = []
    first = last = win_list[0]
    for idx in win_list[1:]:
        if idx - last > max_gap + 1:
            # gap too wide: close the current run and open a new one
            spans.append((first, last))
            first = idx
        last = idx
    spans.append((first, last))
    return spans


In [None]:
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Node:
    # One run of windows in which a chirp-rate bucket is present, plus a
    # padded gate around it. Every field is required (no defaults).
    kid: int                            # bucket id
    k_refined: float                    # refined chirp rate of the bucket
    bin_indices: Optional[List[int]]    # raw k-bins of the bucket, or None
    w_start: int                        # first window of the run
    w_end: int                          # last window of the run
    t_start_s: float                    # run start time
    t_end_s: float                      # run end time
    n_windows: int                      # number of windows
    duration_s: float                   # run duration
    coverage_frac: float                # coverage fraction

    # new padded gate fields
    gate_w_start: int                   # first window of the padded gate
    gate_w_end: int                     # last window of the padded gate
    t_gate_start_s: float               # gate start time
    t_gate_end_s: float                 # gate end time

    def __repr__(self):
        # Compact summary that reports the padded gate range, not the raw run.
        window_span = f"({self.gate_w_start}-{self.gate_w_end})"
        time_span = f"({self.t_gate_start_s:.6f}-{self.t_gate_end_s:.6f})s"
        parts = [
            f"kid={self.kid}",
            f"k_refined={self.k_refined:.2e}",
            f"w_range={window_span}",
            f"t_range={time_span}",
            f"n_windows={self.n_windows}",
            f"coverage={self.coverage_frac:.2f}",
        ]
        return "Node(" + ", ".join(parts) + ")"
        
def make_nodes_from_groups(
    buckets,
    presence,
    time_axis_s,
    T_chunk_s,
    stride_s,              # kept for signature compatibility (not needed if time_axis_s is start times)
    pad=1,                 # padding in window hops (±pad windows)
    min_windows=1,
    max_gap=1              # allow this many missing hops inside a run
):
    """
    Build Node objects from bucket presence over time, with padded time gates.

    Parameters
    ----------
    buckets : list of dict
        Output of build_k_buckets. "k_refined" and "bin_indices" are read
        when present; a bucket without "bin_indices" produces nodes whose
        bin_indices is None.
    presence : ndarray (n_windows, n_kids)
        Boolean/int mask: presence[w, kid] is nonzero if kid is present in
        window w.
    time_axis_s : array-like (n_windows,)
        Start time (global) of each analysis window.
    T_chunk_s : float
        Window length in seconds.
    stride_s : float
        Unused; kept for signature compatibility.
    pad : int
        Number of windows to extend the gate on both sides (in hops),
        clipped to the valid window range.
    min_windows : int
        Minimum number of present windows to accept a node.
    max_gap : int
        Max gap (in hops) allowed when grouping runs.

    Returns
    -------
    nodes : list[Node]
        One node per accepted run, ordered by bucket and then by time.
    """
    nodes = []
    n_windows, n_kids = presence.shape

    # Precompute window indices where each kid is present
    present_lists = [np.flatnonzero(presence[:, kid]) for kid in range(n_kids)]

    for kid, bucket in enumerate(buckets):
        win_idxs = np.asarray(present_lists[kid], dtype=int)
        if win_idxs.size == 0:
            continue

        # Node requires bin_indices, so always supply it; None when the
        # bucket does not carry the key.
        bin_indices = bucket.get("bin_indices")
        if bin_indices is not None:
            bin_indices = (bin_indices.tolist() if hasattr(bin_indices, "tolist")
                           else list(bin_indices))

        # Group into runs allowing small gaps
        runs = group_windows_with_gaps(win_idxs.tolist(), max_gap=max_gap)

        for (w_start, w_end) in runs:
            # windows actually present inside the span
            in_run = win_idxs[(win_idxs >= w_start) & (win_idxs <= w_end)]
            n_present = int(in_run.size)
            span_len  = (w_end - w_start + 1)

            if n_present < min_windows:
                continue

            # ----- padded gate in HOPS (indices) -----
            gate_w_start = max(0, w_start - pad)
            gate_w_end   = min(n_windows - 1, w_end + pad)

            # ----- convert to times (window start + one window length) -----
            t_start_s       = float(time_axis_s[w_start])
            t_end_s         = float(time_axis_s[w_end]   + T_chunk_s)
            t_gate_start_s  = float(time_axis_s[gate_w_start])
            t_gate_end_s    = float(time_axis_s[gate_w_end] + T_chunk_s)

            nodes.append(Node(
                kid=kid,
                k_refined=bucket.get("k_refined", None),
                bin_indices=bin_indices,
                w_start=int(w_start),
                w_end=int(w_end),
                t_start_s=t_start_s,
                t_end_s=t_end_s,
                n_windows=n_present,      # actual present windows
                duration_s=t_end_s - t_start_s,
                coverage_frac=n_present / span_len,
                gate_w_start=int(gate_w_start),
                gate_w_end=int(gate_w_end),
                t_gate_start_s=t_gate_start_s,
                t_gate_end_s=t_gate_end_s,
            ))

    return nodes



In [None]:
def build_nodes_for_candidates(
    candidates, k_vals, row_k_peaks, time_axis_s,
    T_chunk, stride, *,
    width_bins=1, min_windows=2, max_gap=2, margin=1
):
    """
    Turn candidate chirp-rate peaks into nodes (tracks).

    The candidates are converted into k-buckets, every analysis window is
    marked with the buckets it hits, and the hits are grouped into nodes.

    Parameters
    ----------
    candidates : list of dict
        Each entry must provide 'peak_idx' and 'k_hat'.
    k_vals : np.ndarray
        Chirp-rate grid, forwarded to build_k_buckets.
    row_k_peaks : list of lists
        Peak indices per window, forwarded to assign_buckets_per_window.
    time_axis_s : np.ndarray
        Window times, forwarded to make_nodes_from_groups.
    T_chunk, stride : float
        Chunk length and stride, forwarded to make_nodes_from_groups.
    width_bins : int
        Bucket width setting handed to build_k_buckets.
    min_windows, max_gap : int
        Grouping settings handed to make_nodes_from_groups.
    margin : int
        Handed to make_nodes_from_groups as its ``pad`` argument.

    Returns
    -------
    nodes : list
        Nodes produced by make_nodes_from_groups.
    buckets : list of dict
        Bucket definitions produced by build_k_buckets.
    bin2kid : np.ndarray or None
        Second output of build_k_buckets (k-bin -> bucket id mapping).
    presence : np.ndarray or None
        Output of assign_buckets_per_window.

    With no candidates the result is ``([], [], None, None)``.
    """
    # Nothing to track: skip bucket building altogether.
    if not candidates:
        return [], [], None, None

    peak_bins = np.array([cand["peak_idx"] for cand in candidates], dtype=int)
    k_hats = np.array([cand["k_hat"] for cand in candidates], dtype=float)

    buckets, bin2kid = build_k_buckets(
        peaks=peak_bins,
        k_vals=k_vals,
        k_refined=k_hats,
        width_bins=width_bins,
    )

    n_buckets = len(buckets)
    presence = assign_buckets_per_window(row_k_peaks, bin2kid, n_buckets)

    grouping_opts = dict(min_windows=min_windows, max_gap=max_gap, pad=margin)
    nodes = make_nodes_from_groups(
        buckets, presence, time_axis_s, T_chunk, stride, **grouping_opts
    )

    return nodes, buckets, bin2kid, presence


## 5. Pulse PDW Extraction

### Frequency Extraction

In [None]:
def dechirped_fft_for_chunk(chunk, fs, k_Hz_per_s, t0_global_s, pad_factor=4):
    """
    De-chirp a chunk at chirp rate k and return its centered spectrum.

    The chunk is multiplied by exp(-j*pi*k*t^2), where t is the absolute time
    axis starting at ``t0_global_s``, then transformed with a zero-padded FFT
    whose length is the next power of two >= len(chunk) * pad_factor.

    Returns
    -------
    F : np.ndarray
        fftshift-ed frequency axis [Hz].
    mag_db : np.ndarray
        Spectrum magnitude in dB (20*log10 with a 1e-30 floor), same order.
    """
    n_samples = len(chunk)
    t_abs = t0_global_s + np.arange(n_samples)/fs
    dechirped = chunk * np.exp(-1j*np.pi*k_Hz_per_s*(t_abs**2))

    # next power of two at or above the padded length
    nfft = 1 << int((n_samples*pad_factor - 1).bit_length())
    spectrum = np.fft.fftshift(np.fft.fft(dechirped, n=nfft))
    freq_axis = np.fft.fftshift(np.fft.fftfreq(nfft, d=1.0/fs))
    return freq_axis, 20*np.log10(np.abs(spectrum) + 1e-30)


def bandpass_filter_with_fstar(chunk, fs, f_star_hz, lpf_bw_hz, mode="iir"):
    """
    Narrowband-isolate one tone by heterodyning at f*, low-pass filtering,
    then mixing back up. Operates in raw FFT coordinates (no dechirp).

    Parameters
    ----------
    chunk : array-like
        Complex samples; converted to complex128 for processing.
    fs : float
        Sampling rate [Hz].
    f_star_hz : float
        Frequency of the tone to isolate [Hz].
    lpf_bw_hz : float
        Bandwidth setting of the low-pass stage [Hz]; the Butterworth
        cutoff is 0.45 * lpf_bw_hz (capped at 0.99 of Nyquist).
    mode : str
        "iir" applies a 4th-order zero-phase Butterworth low-pass. Any other
        value skips the filter, so the output equals the input up to the
        mix-down / mix-up round trip.

    Returns
    -------
    np.ndarray (complex64)
        Isolated signal, same length as the input. An empty input yields an
        empty complex64 array, so the dtype is the same on every path.
    """
    x = np.asarray(chunk, dtype=np.complex128)
    if x.size == 0:
        # Keep the output dtype identical to the non-empty path.
        return x.astype(np.complex64)

    N = x.size
    t = np.arange(N) / fs

    # 1) heterodyne to DC
    y_bb = x * np.exp(-1j * 2 * np.pi * f_star_hz * t)

    # 2) low-pass
    if mode == "iir":
        wn = min(0.45 * lpf_bw_hz / (0.5 * fs), 0.99)
        sos = butter(N=4, Wn=wn, btype="low", output="sos")
        # Mirror-extend both ends (~8 / lpf_bw_hz seconds, limited by the
        # signal length) so filter edge transients land in the padding,
        # which is cut off again after filtering.
        guard = int(np.ceil(8 * fs / lpf_bw_hz))
        guard = min(guard, len(y_bb) - 1)
        y_ext = np.r_[y_bb[guard:0:-1], y_bb, y_bb[-2:-guard-2:-1]]
        y_lp_ext = sosfiltfilt(sos, y_ext, padlen=0)
        y_lp = y_lp_ext[guard:guard + len(y_bb)]
    else:
        y_lp = y_bb

    # 3) mix back up
    y_up = y_lp * np.exp(+1j * 2 * np.pi * f_star_hz * t)
    return y_up.astype(np.complex64, copy=False)



In [None]:
def detect_residual_tones_moving(signal, fs, *, pad_factor=4, floor_span_hz=1e7, offset_db=15.0,
                                 min_df_hz=5e6, top=None):
    """
    Moving-threshold tone detection on the zero-padded FFT of ``signal``.
    Intended for single frequency pulses only.

    A spectral peak is kept when its level is at least ``offset_db`` above a
    moving-average floor of the dB spectrum computed over ``floor_span_hz``.
    Candidate peaks are separated by at least ``min_df_hz``.

    Returns
    -------
    np.ndarray
        Detected tone frequencies [Hz], strongest first, truncated to the
        ``top`` strongest when ``top`` is given.
    """
    n_samples = signal.size
    nfft = 1 << int((n_samples * pad_factor - 1).bit_length())
    spectrum = np.fft.fftshift(np.fft.fft(signal, n=nfft))
    freq_axis = np.fft.fftshift(np.fft.fftfreq(nfft, d=1.0/fs))
    level_db = 20*np.log10(np.abs(spectrum) + 1e-30)

    # Moving-average floor; the window is an odd number of bins, at least 3.
    bin_hz = abs(freq_axis[1] - freq_axis[0])
    floor_bins = max(3, int(np.ceil(floor_span_hz / bin_hz)) | 1)
    threshold_db = uniform_filter1d(level_db, size=floor_bins, mode="reflect") + float(offset_db)

    # Local maxima with a minimum spacing, then the moving threshold.
    spacing_bins = max(1, int(np.ceil(min_df_hz / bin_hz)))
    candidates, _ = find_peaks(level_db, distance=spacing_bins)
    survivors = candidates[level_db[candidates] >= threshold_db[candidates]]

    # Strongest first.
    ranking = np.argsort(level_db[survivors])[::-1]
    if top is not None:
        ranking = ranking[:top]

    return freq_axis[survivors[ranking]]


### PDW Extraction (Both Tones and LFMs)

In [3]:
def adaptive_envelope_params_from_node(node, *, Tmin_us: float, k_tone_thresh=1e9):
    """
    Choose envelope settings (window_us, min_sustain_us) for a node.

    Nodes whose |k_refined| is below ``k_tone_thresh`` are treated as tones
    and get (Tmin_us, Tmin_us). Otherwise half of the node's gate span (in
    microseconds, floored at 1e-3) is used as a pulse-width proxy: the window
    is 0.5x and the sustain 0.25x of that proxy, neither below Tmin_us.
    """
    tone_like = abs(node.k_refined) < k_tone_thresh
    if tone_like:
        return Tmin_us, Tmin_us

    gate_span_s = node.t_gate_end_s - node.t_gate_start_s
    pw_proxy_us = max(1e-3, gate_span_s/2 * 1e6)

    return max(Tmin_us, 0.5 * pw_proxy_us), max(Tmin_us, 0.25 * pw_proxy_us)


In [None]:
def _clean_and_merge_pdws(df_pdws, Tmin_us=5.0, max_gap_us=5.0):
    """
    Post-process PDWs, independently for every node_id:
      1) Drop pulses shorter than Tmin_us.
      2) Merge time-adjacent pulses whose gap is <= max_gap_us.
    Works for both tone and LFM PDWs.

    A merged pulse starts at the first TOA of its run and ends at the latest
    end time of any pulse in the run. Using the latest end (rather than the
    end of the last row) keeps a long pulse intact when a shorter pulse is
    contained in it, and lets later pulses that still overlap it join the
    same run.

    Aggregation per merged pulse: median for "Envelope Amplitude",
    "Center Freq (MHz)" and "Chirp Rate (MHz/us)"; min for "Start Freq (MHz)";
    max for "End Freq (MHz)" and "Bandwidth (MHz)"; most frequent value for
    "Pulse Type"; logical AND for "pulse_sanity_ok". Other columns are not
    carried over.

    Returns
    -------
    pd.DataFrame
        Merged PDWs sorted by "TOA (us)". The input is returned unchanged
        when it is empty or has no "TOA (us)" column.
    """
    if df_pdws.empty or "TOA (us)" not in df_pdws.columns:
        # Nothing to clean/merge (or unexpected schema).
        return df_pdws

    out_rows = []

    for nid, g in df_pdws.groupby("node_id", dropna=False):
        g = (
            g[g["PW (us)"] >= Tmin_us]
            .sort_values("TOA (us)")
            .reset_index(drop=True)
        )
        if g.empty:
            continue

        # Plain arrays avoid a per-row .iloc lookup in the merge loop.
        starts = g["TOA (us)"].to_numpy(dtype=float)
        ends = starts + g["PW (us)"].to_numpy(dtype=float)

        # Runs are (first_row, stop_row_exclusive, latest_end_us).
        runs = []
        run_first, run_end = 0, ends[0]
        for i in range(1, len(g)):
            if starts[i] - run_end <= max_gap_us:
                run_end = max(run_end, ends[i])
            else:
                runs.append((run_first, i, run_end))
                run_first, run_end = i, ends[i]
        runs.append((run_first, len(g), run_end))

        for first, stop, end_us in runs:
            rows = g.iloc[first:stop]
            toa_start = float(starts[first])

            row_out = {
                "node_id": nid,
                "TOA (us)": toa_start,
                "PW (us)": float(end_us) - toa_start,
                "Envelope Amplitude": rows["Envelope Amplitude"].median(),
            }

            # Frequency-related columns
            if "Center Freq (MHz)" in rows:
                row_out["Center Freq (MHz)"] = rows["Center Freq (MHz)"].median()
            if "Chirp Rate (MHz/us)" in rows:
                row_out["Chirp Rate (MHz/us)"] = rows["Chirp Rate (MHz/us)"].median()
            if "Start Freq (MHz)" in rows:
                row_out["Start Freq (MHz)"] = rows["Start Freq (MHz)"].min()
            if "End Freq (MHz)" in rows:
                row_out["End Freq (MHz)"] = rows["End Freq (MHz)"].max()
            if "Bandwidth (MHz)" in rows:
                row_out["Bandwidth (MHz)"] = rows["Bandwidth (MHz)"].max()

            # Pulse type: mode() is empty when every value is missing.
            if "Pulse Type" in rows:
                type_mode = rows["Pulse Type"].mode()
                row_out["Pulse Type"] = type_mode.iloc[0] if not type_mode.empty else None

            # Sanity flag
            if "pulse_sanity_ok" in rows:
                row_out["pulse_sanity_ok"] = rows["pulse_sanity_ok"].all()

            out_rows.append(row_out)

    if not out_rows:
        return pd.DataFrame(columns=df_pdws.columns)

    return pd.DataFrame(out_rows).sort_values("TOA (us)").reset_index(drop=True)


### PDW Extraction (Tones)

In [None]:
def wrap_to_nyquist(f_hz, fs_hz):
    """Fold a frequency (Hz) into the signed baseband interval [-fs/2, +fs/2).

    Note that +fs/2 itself maps to -fs/2.
    """
    half_band = 0.5*fs_hz
    shifted = f_hz + half_band
    return (shifted % fs_hz) - half_band

In [None]:
def extract_pdws_for_tone_pulses(
    node,
    signal, fs,
    *,
    Tmin_us=5.0,
    window_us=5.0,
    min_sustain_us=5.0,
    pad_factor=2,
    floor_span_hz=4e6,
    min_df_hz=None,
    top=1,
    k_min_abs_MHz_per_us=0.10,
    downsample_factor=1,
    debug=False,
    t0_override=None,
    node_id=None
):
    """
    Extract PDWs for tone (single frequency) pulses inside one node's gate.

    Steps: cut the gated region out of ``signal``, detect tones with
    detect_residual_tones_moving, isolate each tone with
    bandpass_filter_with_fstar, and run baseband_PDW_extractor on the
    isolated signal. Detection and filtering both happen in raw FFT
    coordinates (no dechirp).

    Parameters
    ----------
    node : object
        Must expose t_gate_start_s, t_gate_end_s, gate_w_start, gate_w_end,
        k_refined and coverage_frac; ``kid`` is read if present.
    signal, fs
        Full-rate samples and their sampling rate [Hz].
    Tmin_us : float
        Sets the isolation low-pass bandwidth (3 / Tmin) and is also the
        envelope window handed to baseband_PDW_extractor.
    min_sustain_us : float
        Forwarded to baseband_PDW_extractor.
    k_min_abs_MHz_per_us : float
        |chirp rate| below this labels the pulse "Single Frequency Pulse",
        otherwise "LFM Pulse".
    t0_override : float or None
        Replaces the node's gate start time when given.
    node_id
        Identifier written to the "node_id" column of both outputs.
    window_us, pad_factor, floor_span_hz, min_df_hz, top, downsample_factor
        NOTE(review): accepted but not used by the current implementation;
        the tone detector runs with its own defaults and the envelope window
        is Tmin_us. Confirm whether these should be forwarded.

    Returns
    -------
    (df_pdws, df_meta, timings)
        Two empty DataFrames when no tone is detected, (None, None, timings)
        when tones were found but no pulse was extracted, otherwise the PDW
        table sorted by "TOA (us)" and one metadata row per PDW.
        ``timings`` holds wall-clock durations in seconds.
    """
    timings = {}
    clock = time.perf_counter
    t_begin = clock()

    # ---- 1) ROI ----
    if t0_override is None:
        gate_start_s = float(node.t_gate_start_s)
    else:
        gate_start_s = float(t0_override)
    gate_end_s = float(node.t_gate_end_s)
    gate_w_lo, gate_w_hi = node.gate_w_start, node.gate_w_end
    first_sample = int(np.floor(gate_start_s*fs))
    stop_sample = int(np.ceil(gate_end_s*fs))
    roi = signal[first_sample:stop_sample]
    k_refined = float(node.k_refined)   # may be ~0

    timings["roi_extract"] = clock() - t_begin

    # ---- 2) Tone detection (raw FFT) ----
    t_detect = clock()
    lpf_bw_hz = 3.0 / (Tmin_us * 1e-6)
    tone_freqs = detect_residual_tones_moving(roi, fs)
    timings["fft"] = clock() - t_detect

    if len(tone_freqs) == 0:
        return pd.DataFrame(), pd.DataFrame(), timings

    # ---- 3) Per-tone isolation and pulse extraction ----
    # Values that are identical for every PDW of this node.
    node_kid = getattr(node, "kid", None)
    k_MHz_per_us = k_refined / 1e12   # still recorded for metadata
    if abs(k_MHz_per_us) < k_min_abs_MHz_per_us:
        pulse_type = "Single Frequency Pulse"
    else:
        pulse_type = "LFM Pulse"

    pdw_rows, meta_rows = [], []
    t_extract = clock()

    for raw_freq in tone_freqs:
        f_star = float(wrap_to_nyquist(float(raw_freq), fs))
        if debug:
            print(f"  -> processing f_star={f_star/1e6:.3f} MHz, ROI_len={len(roi)} samples")

        sig_iso = bandpass_filter_with_fstar(
            roi, fs, f_star,
            lpf_bw_hz=lpf_bw_hz, mode="iir"
        )

        _pulses, _bounds, _mask, _cleaned, pdws = baseband_PDW_extractor(
            sig_iso, fs,
            k_hz_per_s=0.0,              # treat as pure tone
            t0_global_s=gate_start_s,
            window_us=Tmin_us,
            min_sustain_us=min_sustain_us,
            threshold_percentile=50,
            amp_floor_pct=90,
        )

        for pdw in pdws:
            start_MHz = pdw["f_start_Hz"] / 1e6
            end_MHz = pdw["f_end_Hz"] / 1e6
            center_MHz = pdw["f_center_Hz"] / 1e6

            pdw_rows.append({
                "node_id": node_id,
                "node_kid": node_kid,
                "TOA (us)": pdw["TOA_global_us"],
                "PW (us)":  pdw["PW_us"],
                "Envelope Amplitude": pdw["Amp_med"],
                "Center Freq (MHz)": center_MHz,
                "Chirp Rate (MHz/us)": k_MHz_per_us,
                "Bandwidth (MHz)": abs(end_MHz - start_MHz),
                "Start Freq (MHz)": start_MHz,
                "End Freq (MHz)": end_MHz,
                "Pulse Type": pulse_type,
                "pulse_sanity_ok": (node.coverage_frac == 1.0)
            })

            meta_rows.append({
                "node_id": node_id,
                "node_kid": node_kid,
                "k_refined_Hz_per_s": k_refined,
                "gate_w_start": gate_w_lo,
                "gate_w_end": gate_w_hi,
                "t_gate_start_s": gate_start_s,
                "t_gate_end_s": gate_end_s,
                "f_star_Hz": f_star,
                "lpf_bw_hz": lpf_bw_hz,
            })

    now = clock()
    timings["pdw_extract"] = now - t_extract
    timings["total"] = clock() - t_begin

    if not pdw_rows:
        return None, None, timings

    df_pdws = pd.DataFrame(pdw_rows).sort_values("TOA (us)", ignore_index=True)
    df_meta = pd.DataFrame(meta_rows).reset_index(drop=True)
    return df_pdws, df_meta, timings


In [None]:
def extract_pdws_all_tones_parallel(
    nodes, signal, fs, *,
    Tmin_us=5.0,
    pad_factor=2,
    floor_span_hz=4e6,
    min_df_hz=None,
    top=1,
    k_min_abs_MHz_per_us=0.10,
    downsample_factor=1,
    debug=False,
    n_jobs=None,
    chunk_size=10,
):
    """
    Run extract_pdws_for_tone_pulses over all single-frequency-pulse nodes
    in parallel.

    Nodes are processed in batches of ``chunk_size`` so that each joblib task
    handles several nodes. Every node gets its position in ``nodes`` as
    ``node_id``, and its envelope settings come from
    adaptive_envelope_params_from_node.

    Parameters
    ----------
    n_jobs : int or None
        Number of joblib workers; None uses multiprocessing.cpu_count().
    chunk_size : int
        Nodes per parallel task.
    (remaining keyword arguments are forwarded to the per-node extractor)

    Returns
    -------
    (df_pdws_all, df_meta_all, all_timings)
        Concatenated PDW and metadata tables (empty DataFrames when nothing
        was extracted) and a list with one timings dict per contributing node.
    """
    if n_jobs is None:
        n_jobs = multiprocessing.cpu_count()

    def _run_batch(batch, first_index):
        kept = []
        for node_index, node in enumerate(batch, start=first_index):
            if debug:
                print(f"[{node_index+1}/{len(nodes)}] kid={node.kid}, "
                      f"k={node.k_refined:.2e}, win={node.gate_w_start}-{node.gate_w_end}")

            window_us, min_sustain_us = adaptive_envelope_params_from_node(
                node, Tmin_us=Tmin_us
            )

            pdws, meta, node_timings = extract_pdws_for_tone_pulses(
                node, signal, fs,
                Tmin_us=Tmin_us,
                window_us=window_us,
                min_sustain_us=min_sustain_us,
                pad_factor=pad_factor,
                floor_span_hz=floor_span_hz,
                min_df_hz=min_df_hz,
                top=top,
                k_min_abs_MHz_per_us=k_min_abs_MHz_per_us,
                downsample_factor=downsample_factor,
                debug=debug,
                node_id=node_index,
            )
            # The extractor may return None or an empty frame; skip both.
            if pdws is None or pdws.empty:
                continue
            kept.append((pdws, meta, node_timings))
        return kept

    # One task per batch; the offset doubles as the first node index.
    batch_offsets = range(0, len(nodes), chunk_size)
    per_batch = Parallel(n_jobs=n_jobs, prefer="processes")(
        delayed(_run_batch)(nodes[start:start+chunk_size], start)
        for start in batch_offsets
    )

    # Flatten the per-batch lists, keeping node order.
    flat = [entry for batch in per_batch for entry in batch]
    all_pdws = [entry[0] for entry in flat]
    all_meta = [entry[1] for entry in flat]
    all_timings = [entry[2] for entry in flat]

    df_pdws_all = pd.concat(all_pdws, ignore_index=True) if all_pdws else pd.DataFrame()
    df_meta_all = pd.concat(all_meta, ignore_index=True) if all_meta else pd.DataFrame()
    return df_pdws_all, df_meta_all, all_timings


### PDW Extraction (LFM Pulses)

In [None]:
def extract_pdws_for_lfm_pulses(
    node,
    signal, fs,
    *,
    Tmin_us=5.0,
    window_us=5.0,
    min_sustain_us=5.0,
    pad_factor=2,
    floor_span_hz=4e6,
    min_df_hz=None,
    top=1,
    downsample_factor=1,
    debug=False,
    t0_override=None,
    node_id=None,
    plot=False, plot_raw=False, plot_smoothed=False, plot_mask=False, plot_thres=False
):
    """
    Extract PDWs over a single node for LFM pulses.
    Uses dechirp -> residual tone isolation -> re-chirp -> baseband envelope
    extraction. Operates on the *raw, unnotched* signal (not signal_after_cw).

    Parameters
    ----------
    node : object
        Must expose t_gate_start_s, t_gate_end_s, gate_w_start, gate_w_end,
        k_refined, coverage_frac and kid.
    signal, fs
        Full-rate samples and their sampling rate [Hz].
    Tmin_us : float
        Sets the isolation low-pass bandwidth (3 / Tmin).
    window_us, min_sustain_us, downsample_factor, plot*
        Forwarded to baseband_PDW_extractor.
    pad_factor : int
        Zero-padding factor of the dechirped FFT.
    t0_override : float or None
        Replaces the node's gate start time when given.
    node_id
        Identifier written to the "node_id" column of both outputs.
    floor_span_hz, min_df_hz, top
        NOTE(review): accepted for signature parity with the tone extractor
        but not used here; the residual tone is simply the strongest FFT bin.

    Returns
    -------
    (df_pdws, df_meta, timings)
        (None, None, timings) when no pulse was extracted; otherwise the PDW
        table sorted by "TOA (us)" and one metadata row per PDW. The metadata
        records the wrapper index as "node_id" and the node's own kid as
        "node_kid" (same convention as the tone extractor), so the kid is no
        longer lost when "node_id" is set. ``timings`` holds wall-clock
        durations in seconds.
    """

    timings = {}
    t0_start = time.perf_counter()

    # ---- 1) ROI ----
    t0 = float(node.t_gate_start_s) if t0_override is None else float(t0_override)
    t1 = float(node.t_gate_end_s)
    w_lo, w_hi = node.gate_w_start, node.gate_w_end
    roi = signal[int(np.floor(t0*fs)) : int(np.ceil(t1*fs))]
    k_refined = float(node.k_refined)   # Hz/s (nonzero for LFM)

    timings["roi_extract"] = time.perf_counter() - t0_start
    if debug:
        print(f"[LFM] node {node.kid}, ROI_len={len(roi)}, k_refined={k_refined:.3e}")

    # ---- 2) Dechirp FFT to find residual tone ----
    t_fft_start = time.perf_counter()
    F, mag_db = dechirped_fft_for_chunk(roi, fs, k_refined, t0_global_s=t0, pad_factor=pad_factor)

    # crude residual peak finder: strongest bin of the dechirped spectrum
    peak_idx = np.argmax(mag_db)
    f_res = float(F[peak_idx])  # Hz residual tone freq
    timings["fft"] = time.perf_counter() - t_fft_start

    # ---- 3) Dechirp in time domain, isolate residual, re-chirp ----
    # Same absolute time axis as the FFT step so the residual frequency matches.
    N = len(roi)
    t = t0 + np.arange(N)/fs
    roi_dec = roi * np.exp(-1j*np.pi*k_refined*(t**2))

    lpf_bw_hz = 3.0 / (Tmin_us * 1e-6)

    y_dc = bandpass_filter_with_fstar(
        roi_dec, fs, f_star_hz=f_res, lpf_bw_hz=lpf_bw_hz, mode="iir"
    )
    sig_iso_lfm = y_dc * np.exp(+1j*np.pi*k_refined*(t**2))

    # ---- 4) Envelope + PDWs ----
    t_pdws_start = time.perf_counter()
    pulses, bounds, mask, cleaned, pdws = baseband_PDW_extractor(
        sig_iso_lfm, fs,
        k_hz_per_s=k_refined,
        t0_global_s=t0,
        window_us=window_us,
        min_sustain_us=min_sustain_us,
        threshold_percentile=50,
        amp_floor_pct=90,
        downsample_factor=downsample_factor,
        plot=plot, plot_raw=plot_raw, plot_smoothed=plot_smoothed,
        plot_mask=plot_mask, plot_thres=plot_thres
    )
    timings["pdw_extract"] = time.perf_counter() - t_pdws_start
    timings["total"] = time.perf_counter() - t0_start

    # ---- 5) Assemble PDW rows ----
    node_kid = getattr(node, "kid", None)
    k_MHz_per_us = k_refined / 1e12   # Hz/s -> MHz/us

    pdw_rows, meta_rows = [], []
    for d in pdws:
        f_start_MHz  = d["f_start_Hz"]  / 1e6
        f_end_MHz    = d["f_end_Hz"]    / 1e6
        f_center_MHz = d["f_center_Hz"] / 1e6

        pdw_rows.append({
            "TOA (us)": d["TOA_global_us"],
            "PW (us)":  d["PW_us"],
            "Envelope Amplitude": d["Amp_med"],
            "Center Freq (MHz)": f_center_MHz,
            "Chirp Rate (MHz/us)": k_MHz_per_us,
            "Bandwidth (MHz)": abs(f_end_MHz - f_start_MHz),
            "Start Freq (MHz)": f_start_MHz,
            "End Freq (MHz)": f_end_MHz,
            "Pulse Type": "LFM Pulse",
            "pulse_sanity_ok": (node.coverage_frac == 1.0)
        })

        meta_rows.append({
            "node_id": node_id,      # wrapper index
            "node_kid": node_kid,    # node's own kid, kept separately
            "k_refined_Hz_per_s": k_refined,
            "gate_w_start": w_lo,
            "gate_w_end": w_hi,
            "t_gate_start_s": t0,
            "t_gate_end_s": t1,
            "f_res_Hz": f_res,
            "lpf_bw_hz": lpf_bw_hz,
        })

    if not pdw_rows:
        return None, None, timings

    df_pdws = pd.DataFrame(pdw_rows).sort_values("TOA (us)", ignore_index=True)
    df_meta = pd.DataFrame(meta_rows).reset_index(drop=True)

    # Tag the PDWs with the passed-in node_id (the metadata already has it).
    df_pdws["node_id"] = node_id
    return df_pdws, df_meta, timings


In [None]:
def extract_pdws_all_lfms_parallel(
    nodes_lfm,
    signal, fs,
    *,
    Tmin_us=5.0,
    window_us=5.0,
    min_sustain_us=5.0,
    pad_factor=2,
    floor_span_hz=4e6,
    min_df_hz=None,
    top=None,
    downsample_factor=1,
    debug=False,
    n_jobs=-1,
    chunk_size=10,
):
    """
    Run extract_pdws_for_lfm_pulses over all LFM nodes in parallel, then
    clean and merge the combined PDWs.

    Nodes are processed in batches of ``chunk_size`` per joblib task and each
    node receives its position in ``nodes_lfm`` as ``node_id``. The combined
    PDW table is sorted by TOA and passed through _clean_and_merge_pdws with
    both the minimum pulse width and the maximum merge gap set to Tmin_us.

    NOTE(review): ``top`` is accepted but not forwarded to the per-node
    extractor.

    Returns
    -------
    (df_pdws, df_meta, timings_all)
        Merged PDWs, concatenated per-node metadata, and a dict of timing
        values summed over the contributing nodes; its "total" entry is
        replaced by this function's wall-clock time when any PDW was found.
        Two empty DataFrames are returned when nothing was extracted.
    """

    wall_start = time.perf_counter()

    def _run_batch(batch, first_index):
        kept = []
        for node_index, node in enumerate(batch, start=first_index):
            pdws, meta, node_timings = extract_pdws_for_lfm_pulses(
                node,
                signal, fs,
                Tmin_us=Tmin_us,
                window_us=window_us,
                min_sustain_us=min_sustain_us,
                pad_factor=pad_factor,
                floor_span_hz=floor_span_hz,
                min_df_hz=min_df_hz,
                downsample_factor=downsample_factor,
                debug=debug,
                node_id=node_index,
            )
            # The extractor may return None or an empty frame; skip both.
            if pdws is None or pdws.empty:
                continue
            kept.append((pdws, meta, node_timings))
        return kept

    # One task per batch; the offset doubles as the first node index.
    batch_offsets = range(0, len(nodes_lfm), chunk_size)
    per_batch = Parallel(n_jobs=n_jobs, prefer="processes")(
        delayed(_run_batch)(nodes_lfm[start:start+chunk_size], start)
        for start in batch_offsets
    )

    # Gather frames and accumulate timings key by key.
    pdw_frames, meta_frames, timing_totals = [], [], {}
    for batch in per_batch:
        for pdws, meta, node_timings in batch:
            pdw_frames.append(pdws)
            meta_frames.append(meta)
            for key, seconds in node_timings.items():
                timing_totals[key] = timing_totals.get(key, 0.0) + seconds

    if not pdw_frames:
        return pd.DataFrame(), pd.DataFrame(), timing_totals

    combined = pd.concat(pdw_frames, ignore_index=True)
    combined = combined.sort_values("TOA (us)").reset_index(drop=True)
    df_meta = pd.concat(meta_frames, ignore_index=True).reset_index(drop=True)

    df_pdws = _clean_and_merge_pdws(combined, Tmin_us=Tmin_us, max_gap_us=Tmin_us)

    timing_totals["total"] = time.perf_counter() - wall_start
    return df_pdws, df_meta, timing_totals


## 6. Cluster into Emitters

In [1]:
def _pri_consistency_from_toas_us(
    toa_us: np.ndarray, tol_frac: float = 0.05, min_pulses: int = 20
) -> "tuple[bool, float | None, float | None]":
    """
    Evaluate PRI consistency from TOAs (µs).

    The PRI estimate is the median of the positive differences between
    sorted TOAs. The set is consistent when more than 80 % of those
    differences lie within ``tol_frac`` of the estimate.

    Parameters
    ----------
    toa_us : array-like
        Times of arrival in microseconds (any order).
    tol_frac : float
        Relative tolerance around the PRI estimate.
    min_pulses : int
        Minimum number of TOAs required to attempt the estimate.

    Returns
    -------
    (is_consistent, pri_est, frac_ok)
      is_consistent : bool, True if the TOAs show a stable PRI
      pri_est       : median PRI (µs), or None if it cannot be estimated
      frac_ok       : fraction of diffs within tol_frac of pri_est, or None

    The return annotation is a string so that defining the function does not
    require any typing import.
    """
    toa_us = np.asarray(toa_us, dtype=float)
    if toa_us.size < min_pulses:
        return False, None, None

    # Zero differences (duplicate TOAs) carry no PRI information.
    diffs = np.diff(np.sort(toa_us))
    diffs = diffs[diffs > 0]
    if diffs.size == 0:
        return False, None, None

    pri_est = float(np.median(diffs))
    frac_ok = float(np.mean(np.abs(diffs - pri_est) < tol_frac * pri_est))
    # Plain bool, as annotated, rather than numpy.bool_.
    return bool(frac_ok > 0.8), pri_est, frac_ok


def cluster_emitters_dbscan(
    df_pdws_all: pd.DataFrame,
    *,
    min_samples: int = 5,
    eps: float = 0.8,
    use_types: bool = True,
    require_sanity: bool = True,
    require_pri_consistency: bool = True,
    pri_tol_frac: float = 0.05,
    pri_min_pulses: int = 20,
) -> "tuple[pd.DataFrame, pd.DataFrame]":
    """
    Cluster pulses (rows) into emitters via DBSCAN and return a per-emitter summary.

    Pulses are clustered on standardized chirp rate, center frequency and
    pulse width (plus an LFM / non-LFM flag when ``use_types`` is set).
    A PRI-consistency check suppresses spurious clusters with unstable TOAs,
    and per-emitter pulse-train start/end times in seconds are computed for
    IQ reconstruction / CAF.

    Expected columns in df_pdws_all:
      "TOA (us)", "PW (us)", "Envelope Amplitude",
      "Center Freq (MHz)", "Chirp Rate (MHz/us)",
      "Bandwidth (MHz)", "Start Freq (MHz)", "End Freq (MHz)",
      "Pulse Type"

    Parameters
    ----------
    min_samples, eps
        DBSCAN settings (applied to the standardized features).
    use_types : bool
        Add a 0/1 "Pulse Type contains LFM" feature.
    require_sanity : bool
        Keep only rows with pulse_sanity_ok == True (if the column exists).
    require_pri_consistency : bool
        Drop clusters that fail _pri_consistency_from_toas_us.
    pri_tol_frac, pri_min_pulses
        Forwarded to the PRI-consistency check.

    Returns
    -------
    df_emitters : pd.DataFrame
        One row per retained cluster (medians of the numeric PDW fields, PRI
        estimate, pulse count, train start/end). Empty, but with the full
        column set, when every cluster is rejected by the PRI check.
    df : pd.DataFrame
        The non-noise input pulses with their ``cluster_id``; this includes
        pulses of clusters that were rejected by the PRI check.

    Raises
    ------
    ValueError
        If no row passes the sanity filter, or DBSCAN labels everything noise.
    KeyError
        If an expected column is missing.
    """
    df = df_pdws_all.copy()

    if require_sanity and "pulse_sanity_ok" in df.columns:
        df = df[df["pulse_sanity_ok"] == True].copy()
        if df.empty:
            raise ValueError("No rows with pulse_sanity_ok == True after filtering.")

    needed = [
        "TOA (us)","PW (us)","Envelope Amplitude","Center Freq (MHz)",
        "Chirp Rate (MHz/us)","Bandwidth (MHz)","Start Freq (MHz)",
        "End Freq (MHz)","Pulse Type"
    ]
    miss = [c for c in needed if c not in df.columns]
    if miss:
        raise KeyError(f"Missing columns: {miss}")

    X_cols = ["Chirp Rate (MHz/us)", "Center Freq (MHz)", "PW (us)"]
    X = df[X_cols].to_numpy(dtype=float)

    if use_types and "Pulse Type" in df.columns:
        pt_flag = (df["Pulse Type"].str.contains("LFM", case=False)).astype(float).to_numpy()[:, None]
        X = np.hstack([X, pt_flag])

    # Standardize so that eps is comparable across features.
    scaler = StandardScaler()
    Z = scaler.fit_transform(X)
    labels = DBSCAN(eps=eps, min_samples=min_samples, n_jobs=-1).fit_predict(Z)

    # Label -1 is DBSCAN noise.
    df = df.assign(cluster_id=labels)
    df = df[df["cluster_id"] >= 0].copy()
    if df.empty:
        raise ValueError("DBSCAN produced no clusters (all noise). Consider relaxing eps or min_samples.")

    rows = []
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()

    for cid, g in df.groupby("cluster_id"):
        # PRI + consistency
        is_ok, pri_est, frac_ok = _pri_consistency_from_toas_us(
            g["TOA (us)"].to_numpy(),
            tol_frac=pri_tol_frac,
            min_pulses=pri_min_pulses,
        )
        if require_pri_consistency and not is_ok:
            continue  # drop spurious cluster

        medians = g[numeric_cols].median(numeric_only=True).to_dict()
        medians.pop("TOA (us)", None)

        # Non-numeric columns are summarized by their most frequent value.
        cat_vals = {}
        for c in df.columns.difference(numeric_cols + ["cluster_id"]):
            mode_val = g[c].mode(dropna=True)
            cat_vals[c] = (mode_val.iloc[0] if not mode_val.empty else g[c].iloc[0])

        # ------------------------------
        # Pulse train time bounds
        # ------------------------------
        toa_us = g["TOA (us)"].to_numpy()
        pw_us  = g["PW (us)"].to_numpy()
        t_train_start_s = float(np.min(toa_us)) * 1e-6
        t_train_end_s   = float(np.max(toa_us + pw_us)) * 1e-6

        out = {
            "PRI (us)": pri_est,
            "PW (us)": medians.get("PW (us)"),
            "Envelope Amplitude": medians.get("Envelope Amplitude"),
            "Center Freq (MHz)": medians.get("Center Freq (MHz)"),
            "Chirp Rate (MHz/us)": medians.get("Chirp Rate (MHz/us)"),
            "Bandwidth (MHz)": medians.get("Bandwidth (MHz)"),
            "Start Freq (MHz)": medians.get("Start Freq (MHz)"),
            "End Freq (MHz)": medians.get("End Freq (MHz)"),
            "Pulse Type": cat_vals.get("Pulse Type"),
            "cluster_id": cid,
            "count": len(g),
            "PRI_frac_ok": frac_ok,   # useful diagnostic
            # train timing for CAF / IQ reconstruction
            "t_train_start_s": t_train_start_s,
            "t_train_end_s": t_train_end_s,
        }
        rows.append(out)

    cols_out = [
        "PRI (us)", "PW (us)", "Envelope Amplitude", "Center Freq (MHz)",
        "Chirp Rate (MHz/us)", "Bandwidth (MHz)", "Start Freq (MHz)",
        "End Freq (MHz)", "Pulse Type", "cluster_id", "count",
        "PRI_frac_ok", "t_train_start_s", "t_train_end_s",
    ]
    # Passing the columns to the constructor keeps the schema even when every
    # cluster was rejected (selecting columns on an empty frame would raise).
    df_emitters = pd.DataFrame(rows, columns=cols_out).sort_values(
        ["Pulse Type","Center Freq (MHz)","Chirp Rate (MHz/us)","PRI (us)"]
    )
    return df_emitters, df


NameError: name 'np' is not defined

## 7. Per Emitter IQ Extraction

### For Pulse Trains (LFMs and Single Frequency Types)

In [5]:
def make_emitter_mask_from_pdws(df_pdws_lfm_tagged, cluster_id, fs, N, guard_us=2.0):
    """
    Build a length-N boolean sample mask that is True wherever the emitter
    with the given ``cluster_id`` is 'on'.

    Every PDW of that cluster marks the samples from (TOA - guard_us) to
    (TOA + PW + guard_us); the span is rounded outward to whole samples and
    clipped to [0, N).
    """
    mask = np.zeros(N, dtype=bool)
    is_member = df_pdws_lfm_tagged["cluster_id"] == cluster_id
    pulses = df_pdws_lfm_tagged[is_member]

    for toa_raw, pw_raw in zip(pulses["TOA (us)"], pulses["PW (us)"]):
        toa_us = float(toa_raw)
        width_us = float(pw_raw)

        on_us = toa_us - guard_us
        off_us = toa_us + width_us + guard_us

        # Round outward, then clip to the valid sample range.
        first = max(0, int(np.floor(on_us * 1e-6 * fs)))
        stop = min(N, int(np.ceil(off_us * 1e-6 * fs)))
        if stop > first:
            mask[first:stop] = True

    return mask

In [7]:
def isolate_lfm_in_roi(roi, fs, k_refined, t0_global_s, Tmin_us=5.0, pad_factor=2):
    """
    Isolate one LFM of chirp rate ``k_refined`` inside a ROI that may also
    contain other signals.

    The ROI is dechirped on the absolute time axis starting at
    ``t0_global_s``, the strongest residual tone of the dechirped spectrum is
    kept with a narrow low-pass (bandwidth 3 / Tmin), and the result is
    re-chirped.

    Returns the isolated complex LFM, same length as the ROI.
    """
    # Residual tone = strongest bin of the dechirped, zero-padded FFT.
    freq_axis, level_db = dechirped_fft_for_chunk(
        roi, fs, k_refined, t0_global_s=t0_global_s, pad_factor=pad_factor
    )
    residual_hz = float(freq_axis[np.argmax(level_db)])

    # Time-domain dechirp on the same absolute time axis.
    n_samples = len(roi)
    t_abs = t0_global_s + np.arange(n_samples)/fs
    dechirped = roi * np.exp(-1j*np.pi*k_refined*(t_abs**2))

    # Keep only a narrow band around the residual tone.
    narrow = bandpass_filter_with_fstar(
        dechirped, fs,
        f_star_hz=residual_hz,
        lpf_bw_hz=3.0 / (Tmin_us * 1e-6),
        mode="iir",
    )

    # Restore the chirp.
    return narrow * np.exp(+1j*np.pi*k_refined*(t_abs**2))

def extract_emitter_iq_lfm(signal, fs, df_pdws_lfm_tagged, cluster_id, k_refined, Tmin_us=5.0):
    """
    Rebuild a full-length IQ stream containing only one LFM emitter.

      - time-gate using the emitter's PDWs (make_emitter_mask_from_pdws with
        its default guard),
      - for each contiguous gated block: dechirp -> LPF -> rechirp
        (isolate_lfm_in_roi, with the block's absolute start time),
      - write each block back at its original position.

    Returns a complex64 array of len(signal); samples outside the gated
    blocks are zero.
    """
    total = len(signal)
    gate = make_emitter_mask_from_pdws(df_pdws_lfm_tagged, cluster_id, fs, total)
    isolated = np.zeros(total, dtype=np.complex64)

    gated_samples = np.nonzero(gate)[0]
    if gated_samples.size == 0:
        return isolated

    # A jump of more than one sample between gated indices starts a new block.
    block_breaks = np.where(np.diff(gated_samples) > 1)[0] + 1

    for block in np.split(gated_samples, block_breaks):
        lo = block[0]
        hi = block[-1] + 1
        block_iq = isolate_lfm_in_roi(
            signal[lo:hi], fs, k_refined, t0_global_s=lo / fs, Tmin_us=Tmin_us
        )
        isolated[lo:hi] += block_iq

    return isolated

In [None]:
def extract_and_decimate_bw(
    iq_raw,
    df_pdws_lfm,
    cluster_id,
    fs,
    f_center,
    k_refined,
    t_start,
    PRI,
    PW,
    bw_coverage_hz,
    guard_us=200.0,
    fir_len=101
):
    """
    Extract an LFM emitter from FF1, then baseband, lowpass filter, and
    decimate using an explicit retained bandwidth (bw_coverage_hz).

    Parameters
    ----------
    iq_raw : ndarray
        Raw IQ at original sampling fs.
    df_pdws_lfm : DataFrame
        PDWs for this LFM emitter (needs "cluster_id", "TOA (us)", "PW (us)").
    cluster_id : int
        Emitter ID.
    fs : float
        Original sampling rate [Hz].
    f_center : float
        LO frequency [Hz].
    k_refined : float
        Chirp slope [Hz/s].
    t_start : float
        First pulse start time. Not used by the current implementation.
    PRI : float
        PRI for this emitter. Not used by the current implementation.
    PW : float
        Pulse width; 0.2 * PW * fs samples are added on each side of the
        gated span when choosing the slice.
    bw_coverage_hz : float
        Bandwidth (Hz) to retain in baseband after filtering; used directly
        as the FIR cutoff. Must be > 0 and below fs / 2.
    guard_us : float
        Time guard window around pulses [us].
    fir_len : int
        FIR filter length.

    Returns
    -------
    y_ds : ndarray
        Decimated LFM baseband (FF1 only).
    fs_ds : float
        Decimated sample rate.
    idx0, idx1 : int
        Slice indices for reuse by FF2/FF3.
    fir_coeffs : ndarray
        FIR kernel.
    decim_factor : int
        Final decimation factor used: the largest integer for which the
        post-decimation Nyquist frequency is strictly above bw_coverage_hz.

    Raises
    ------
    RuntimeError
        If the emitter has no gated samples.
    ValueError
        If bw_coverage_hz is not positive, or is not below fs / 2.
    """

    import numpy as np
    from scipy.signal import firwin

    N_raw = len(iq_raw)

    # -----------------------------------------------------------
    # 1) Build emitter mask (FF1 reference)
    # -----------------------------------------------------------
    mask = make_emitter_mask_from_pdws(
        df_pdws_lfm,
        cluster_id=cluster_id,
        fs=fs,
        N=N_raw,
        guard_us=guard_us
    )

    # Fail before the expensive isolation step when there is nothing to gate.
    on_idxs = np.where(mask)[0]
    if on_idxs.size == 0:
        raise RuntimeError("Empty mask for emitter {}".format(cluster_id))

    y_gate = np.where(mask, iq_raw, 0.0)

    # Chirp-matched isolation inside the gated samples
    y_iso = extract_emitter_iq_lfm(y_gate, fs, df_pdws_lfm, cluster_id, k_refined)
    y_ref = y_iso.astype(np.complex128)

    # -----------------------------------------------------------
    # 2) Determine clean slice indices
    # -----------------------------------------------------------
    extra_guard = int(np.round(0.2 * PW * fs))
    idx0 = max(0, on_idxs[0] - extra_guard)
    idx1 = min(N_raw, on_idxs[-1] + extra_guard)

    # Time base (absolute time!)
    n_slice = np.arange(idx0, idx1)
    t_slice = n_slice / fs
    y_slice = y_ref[idx0:idx1]

    # -----------------------------------------------------------
    # 3) Downconvert with common LO (abs time ensures phase alignment)
    # -----------------------------------------------------------
    mix = np.exp(-1j * 2*np.pi * f_center * t_slice)
    y_bb = y_slice * mix

    # -----------------------------------------------------------
    # 4) Compute maximum allowed decimation factor
    #    Post-decim Nyquist must exceed bw_coverage_hz
    # -----------------------------------------------------------
    if bw_coverage_hz <= 0:
        raise ValueError(
            f"bw_coverage_hz must be positive, got {bw_coverage_hz}."
        )

    decim_factor = max(int(fs // (2 * bw_coverage_hz)), 1)
    # When fs is an exact multiple of 2*bw_coverage_hz the floor above puts
    # Nyquist exactly on bw_coverage_hz; back off until it is strictly above.
    while decim_factor > 1 and bw_coverage_hz >= fs / decim_factor / 2:
        decim_factor -= 1

    fs_ds = fs / decim_factor
    nyquist_ds = fs_ds / 2

    # Only reachable with decim_factor == 1, i.e. bw_coverage_hz >= fs / 2.
    if bw_coverage_hz >= nyquist_ds:
        raise ValueError(
            f"bw_coverage_hz = {bw_coverage_hz} exceeds post-decim Nyquist = {nyquist_ds}. "
            "Reduce bw_coverage_hz or decim_factor."
        )

    # -----------------------------------------------------------
    # 5) LPF: cutoff = bw_coverage_hz (direct Hz!)
    # -----------------------------------------------------------
    fir_coeffs = firwin(
        fir_len,
        bw_coverage_hz,     # direct cutoff
        fs=fs,
    )

    y_filt = np.convolve(y_bb, fir_coeffs, mode="same")

    # -----------------------------------------------------------
    # 6) Decimate
    # -----------------------------------------------------------
    y_ds = y_filt[::decim_factor]

    return y_ds, fs_ds, idx0, idx1, fir_coeffs, decim_factor


def filter_and_decimate_same_slice(
    iq_raw,
    fs,
    f_center,
    idx0,
    idx1,
    fir_coeffs,
    decim_factor,
):
    """
    Re-run an existing slice / mix / low-pass / decimate chain on another
    receiver channel (e.g. FF2 or FF3) so it matches the reference channel.

    Parameters
    ----------
    iq_raw : ndarray (complex)
        Full-rate IQ samples of the channel to process.
    fs : float
        Sampling rate of ``iq_raw`` [Hz].
    f_center : float
        Mixing (LO) frequency [Hz]; should be the one used for the reference.
    idx0, idx1 : int
        Start / stop sample indices of the slice at full rate. They are
        clamped to ``[0, len(iq_raw)]``.
    fir_coeffs : ndarray
        Low-pass FIR taps to apply after mixing.
    decim_factor : int
        Keep every ``decim_factor``-th filtered sample.

    Returns
    -------
    y_ds : ndarray (complex)
        Basebanded, filtered and decimated slice.
    """
    # Clamp the requested slice to the samples that actually exist.
    lo = max(0, idx0)
    hi = min(len(iq_raw), idx1)

    # Sample times are derived from the absolute indices (not restarted at 0),
    # so the mixer phase is tied to the position inside the full recording.
    t_abs = np.arange(lo, hi) / fs
    lo_conj = np.exp(-1j * 2.0 * np.pi * f_center * t_abs)
    baseband = iq_raw[lo:hi] * lo_conj

    # "same" keeps the filtered output the same length as the slice.
    return np.convolve(baseband, fir_coeffs, mode="same")[::decim_factor]


### For CWs

In [14]:
def extract_cw_band(iq, fs, f_center_bb, bw_coverage_hz=100e3, fir_len=101):
    """
    Isolate a continuous-wave emitter: mix to baseband, low-pass, decimate.

    No time gating is applied; the whole record is processed.

    Parameters
    ----------
    iq : ndarray
        Composite complex IQ sampled at ``fs``.
    fs : float
        Sampling rate of ``iq`` [Hz].
    f_center_bb : float
        Frequency [Hz] of the tone within ``iq``; the record is multiplied by
        ``exp(-j*2*pi*f_center_bb*t)`` to move it to 0 Hz.
    bw_coverage_hz : float
        Low-pass cutoff [Hz]. It also sets the decimation factor
        ``fs // (2 * bw_coverage_hz)`` (at least 1).
    fir_len : int
        Number of FIR taps.

    Returns
    -------
    ds : ndarray
        Filtered, decimated baseband IQ.
    fs_ds : float
        Sampling rate of ``ds`` [Hz].
    h : ndarray
        FIR taps that were applied.
    decim_factor : int
        Decimation factor that was applied (reuse it for the other channels).
    """
    from scipy.signal import firwin, lfilter

    # Shift the tone to 0 Hz using a time axis that starts at the first sample.
    sample_times = np.arange(len(iq)) / fs
    baseband = iq * np.exp(-1j * 2*np.pi * f_center_bb * sample_times)

    # Causal FIR low-pass; the cutoff is expressed as a fraction of Nyquist.
    h = firwin(fir_len, bw_coverage_hz/(fs/2))
    smoothed = lfilter(h, 1.0, baseband)

    # Largest integer factor whose output rate is still >= 2 * bw_coverage_hz.
    decim_factor = max(int(fs // (2 * bw_coverage_hz)), 1)

    return smoothed[::decim_factor], fs / decim_factor, h, decim_factor


### For LFMCWs

In [6]:
'''
For simplicity, only the first LFMCW sweep is extracted for CAF and geolocation later.
This can in principle be extended to multiple sweeps for better performance.

Idea: Align and extract first sweeps across all receiver channels.

Alignment: 
- FF1 (mainlobe) relies on deinterleaver output.
- FF2 and FF3 (sidelobes) rely on the estimated emitter location, which affects the TDOA and FDOA offsets
- Take multiple starting positions evenly divided in AOI. Run CAFs for all starting positions and compare
  CAF's Peak to Sidelobe ratio. Proper alignment should yield higher PSR.

Extraction: Same concept as LFM case, using time information to time gate, followed by dechirp, rechirp filtering.

Estimated emitter location: Will affect alignment significantly. 
- Idea is to use AOI seed locations as emitter location proxies. Use CAF's PSR as scoring criteria to get appropriate 
  starting location proxy
- Using starting proxy, geolocate LFMCW emitter location, then feed in this location iteratively for a better 
  CAF/geolocation output
'''
def _aoi_bounds_from_emitters(EMITTER_LLA, margin_frac=0.20):
    """
    Bounding box of the emitters' lat/lon, enlarged by a fractional margin.

    Parameters
    ----------
    EMITTER_LLA : sequence of dict
        Each entry provides ``"lat"`` and ``"lon"``.
    margin_frac : float
        Fraction of the lat (resp. lon) extent added on each side. A zero
        extent is replaced by 0.1 before the margin is computed.

    Returns
    -------
    lat_min, lat_max, lon_min, lon_max : float
    """
    def _padded_range(values):
        lo, hi = float(values.min()), float(values.max())
        extent = hi - lo
        if extent == 0:
            # Degenerate box (single point / identical coordinates).
            extent = 0.1
        pad = margin_frac * extent
        return lo - pad, hi + pad

    lat_vals = np.array([em["lat"] for em in EMITTER_LLA], float)
    lon_vals = np.array([em["lon"] for em in EMITTER_LLA], float)

    lat_min, lat_max = _padded_range(lat_vals)
    lon_min, lon_max = _padded_range(lon_vals)
    return lat_min, lat_max, lon_min, lon_max

def _caf_psr(CAF, guard=3):
    """
    Peak-to-sidelobe ratio of a 2-D CAF surface.

    The peak is the global maximum. The sidelobe level is the maximum of the
    surface outside a ``(2*guard+1)``-wide box (clipped to the array) centred
    on the peak.

    Returns
    -------
    peak : float
    sidelobe : float
        NaN when the guard box covers the whole surface.
    psr : float
        ``peak / sidelobe``; NaN when the sidelobe is NaN, non-finite or <= 0.
    """
    surface = np.asarray(CAF)
    n_rows, n_cols = surface.shape[0], surface.shape[1]
    row_pk, col_pk = np.unravel_index(np.argmax(surface), surface.shape)
    peak = float(surface[row_pk, col_pk])

    # Everything outside the guard box around the peak counts as sidelobe.
    outside = np.ones_like(surface, dtype=bool)
    outside[
        max(0, row_pk - guard):min(n_rows, row_pk + guard + 1),
        max(0, col_pk - guard):min(n_cols, col_pk + guard + 1),
    ] = False

    if not np.any(outside):
        return peak, np.nan, np.nan

    sidelobe = float(np.max(surface[outside]))
    if np.isfinite(sidelobe) and sidelobe > 0:
        psr = peak / sidelobe
    else:
        psr = np.nan
    return peak, sidelobe, psr


def build_common_lfmcw_roi_geom(
    t0_ff1_s, PW_s,
    tau21, tau31,
    fs, N_raw,
    tau_guard_s
):
    """
    Sample range that contains one sweep on all three channels.

    The sweep is taken to start at ``t0_ff1_s`` on FF1 and at
    ``t0_ff1_s + tau21`` / ``t0_ff1_s + tau31`` on the other two channels,
    each lasting ``PW_s``. The union of the three intervals is widened by
    ``tau_guard_s`` on both sides and converted to indices at rate ``fs``.

    Returns
    -------
    idx_start, idx_end : int
        ROI bounds, clamped to ``[0, N_raw]``.
    t_roi_start_s : float
        Time of ``idx_start`` (``idx_start / fs``).
    """
    sweep_starts = (t0_ff1_s, t0_ff1_s + tau21, t0_ff1_s + tau31)
    sweep_ends = tuple(t_begin + PW_s for t_begin in sweep_starts)

    roi_begin_s = min(sweep_starts) - tau_guard_s
    roi_end_s = max(sweep_ends) + tau_guard_s

    # Round outwards (floor / ceil) so the ROI never cuts into the interval.
    idx_start = max(0, int(np.floor(roi_begin_s * fs)))
    idx_end = min(N_raw, int(np.ceil(roi_end_s * fs)))

    return idx_start, idx_end, idx_start / fs


def process_lfmcw_channel(
    iq, fs,
    idx_start, idx_end,
    t_roi_start_s,
    t0_ff1_s,
    PW_s,
    tau_est,
    k_hz_s,
    f_center,
    B_proc,
    fir_coeffs,
    decim_factor
):
    """
    Extract one LFMCW sweep from a single channel inside a common ROI.

    Steps: cut ``iq[idx_start:idx_end]``, zero everything outside the
    expected sweep interval ``[t0_ff1_s + tau_est, t0_ff1_s + tau_est + PW_s)``,
    pass the gated ROI through ``isolate_lfm_in_roi`` (defined elsewhere in
    this file), mix by ``f_center``, low-pass with ``fir_coeffs`` and keep
    every ``decim_factor``-th sample.

    Parameters
    ----------
    iq : ndarray (complex)
        Full-rate IQ of this channel.
    fs : float
        Sampling rate [Hz].
    idx_start, idx_end : int
        ROI sample bounds.
    t_roi_start_s : float
        Time [s] assigned to the first ROI sample.
    t0_ff1_s, PW_s : float
        Sweep start on the reference channel and sweep duration [s].
    tau_est : float
        Expected delay [s] of this channel relative to the reference.
    k_hz_s : float
        Chirp rate handed to ``isolate_lfm_in_roi``.
    f_center : float
        Mixing frequency [Hz].
    B_proc : float
        Accepted for interface compatibility; not used in this function.
    fir_coeffs : ndarray
        FIR low-pass taps.
    decim_factor : int
        Decimation factor.

    Returns
    -------
    sweep_ds : ndarray (complex)
    """
    segment = iq[idx_start:idx_end]
    t_abs = t_roi_start_s + np.arange(len(segment)) / fs

    # Time gate: keep only the samples where this channel's sweep is expected.
    gate_open = t0_ff1_s + tau_est
    gate_close = gate_open + PW_s
    in_sweep = (t_abs >= gate_open) & (t_abs < gate_close)
    gated = segment * in_sweep.astype(segment.dtype)

    # Chirp isolation; the helper is told the time of roi[0].
    isolated = isolate_lfm_in_roi(
        roi=gated,
        fs=fs,
        k_refined=k_hz_s,
        t0_global_s=t_roi_start_s,
    )

    # Mix down using a time axis that restarts at 0 for the isolated sweep.
    t_rel = np.arange(len(isolated)) / fs
    baseband = isolated * np.exp(-1j * 2*np.pi * f_center * t_rel)

    # Causal low-pass, then decimate.
    return lfilter(fir_coeffs, [1.0], baseband)[::decim_factor]


def extract_lfmcw_common_window_geom(
    iq_FF1, iq_FF2, iq_FF3,
    fs,
    lfmcw_row,
    tau21_est, tau31_est,
    tau_guard_s
):
    """
    Extract the same LFMCW sweep from FF1/FF2/FF3 using one shared ROI.

    Sweep parameters are read from ``lfmcw_row`` (``"TOA (us)"``,
    ``"PW (us)"``, ``"Center Freq (MHz)"``, ``"Chirp Rate (MHz/us)"``,
    ``"Bandwidth (MHz)"``) and converted to s / Hz / Hz-per-s. All three
    channels go through ``process_lfmcw_channel`` with identical ROI, FIR
    taps and decimation factor; only the expected delay differs
    (0, ``tau21_est``, ``tau31_est``).

    Returns
    -------
    x1_ds, x2_ds, x3_ds : ndarray
        Decimated sweeps for FF1, FF2, FF3.
    fs_ds : float
        Sample rate after decimation [Hz].
    meta : dict
        ``idx_start``, ``idx_end``, ``fs_ds``, ``PW_s``, ``f_center``,
        ``k_hz_s``.
    """
    # Unit conversions from the PDW-style row.
    t0_ff1_s = lfmcw_row["TOA (us)"] * 1e-6
    PW_s     = lfmcw_row["PW (us)"] * 1e-6
    f_center = lfmcw_row["Center Freq (MHz)"] * 1e6
    k_hz_s   = lfmcw_row["Chirp Rate (MHz/us)"] * 1e12
    B_proc   = abs(lfmcw_row["Bandwidth (MHz)"]) * 1e6

    # One ROI for every channel (its length is taken from FF1).
    roi_lo, roi_hi, roi_t0 = build_common_lfmcw_roi_geom(
        t0_ff1_s, PW_s,
        tau21_est, tau31_est,
        fs, len(iq_FF1),
        tau_guard_s,
    )

    # Shared low-pass (cutoff 0.55 * B_proc) and decimation (target 2.5 * B_proc).
    fir_coeffs = firwin(101, 0.55 * B_proc / (fs/2))
    decim_factor = max(1, int(fs // (2.5 * B_proc)))

    # Same processing per channel; FF1 is the zero-delay reference.
    per_channel = [
        process_lfmcw_channel(
            samples, fs,
            roi_lo, roi_hi,
            roi_t0,
            t0_ff1_s, PW_s,
            tau_est=delay,
            k_hz_s=k_hz_s,
            f_center=f_center,
            B_proc=B_proc,
            fir_coeffs=fir_coeffs,
            decim_factor=decim_factor
        )
        for samples, delay in (
            (iq_FF1, 0.0),
            (iq_FF2, tau21_est),
            (iq_FF3, tau31_est),
        )
    ]
    x1_ds, x2_ds, x3_ds = per_channel

    fs_ds = fs / decim_factor
    meta = {
        "idx_start": roi_lo,
        "idx_end": roi_hi,
        "fs_ds": fs_ds,
        "PW_s": PW_s,
        "f_center": f_center,
        "k_hz_s": k_hz_s,
    }
    return x1_ds, x2_ds, x3_ds, fs_ds, meta

# CAF Helper Functions

## 8. CAF Implementation

### Applying Error Sources

In [None]:
# ------------------------------------------------------------
# Add receiver sync errors (per satellite)
# ---------------------------------------------------------
def apply_sync_errors(dTau_21_true, dTau_31_true, error_cfg, seed=1234):
    """
    Perturb two true TDOAs with per-satellite clock-sync errors.

    One zero-mean Gaussian error with standard deviation
    ``error_cfg["sync_err_rms"]`` is drawn per satellite, in the order
    FF1, FF2, FF3, from a generator seeded with ``seed``. Each TDOA receives
    the difference between its satellite's error and FF1's error.

    Returns
    -------
    dTau_21_err, dTau_31_err : float
        Perturbed TDOAs.
    sync_err : dict
        The drawn errors, keyed ``"FF1"``, ``"FF2"``, ``"FF3"``.
    """
    rng = np.random.default_rng(seed)
    sigma = error_cfg["sync_err_rms"]

    # Draw order matters for reproducibility: FF1, then FF2, then FF3.
    sync_err = {sat: rng.normal(0, sigma) for sat in ("FF1", "FF2", "FF3")}
    ref_err = sync_err["FF1"]

    return (
        dTau_21_true + (sync_err["FF2"] - ref_err),
        dTau_31_true + (sync_err["FF3"] - ref_err),
        sync_err,
    )

In [None]:
def apply_position_errors(r_sats, error_cfg, seed=1234):
    """
    Return a copy of ``r_sats`` with Gaussian noise added to every position.

    Each of the three coordinates gets an independent zero-mean error with
    standard deviation ``error_cfg["pos_err_rms"] / sqrt(3)``. Satellites are
    perturbed in the dict's iteration order using a generator seeded with
    ``seed``.
    """
    rng = np.random.default_rng(seed)
    noisy = {}
    for sat, position in r_sats.items():
        per_axis_sigma = error_cfg["pos_err_rms"]/np.sqrt(3)
        noisy[sat] = position + rng.normal(0, per_axis_sigma, 3)
    return noisy

### Compute CAF Grids

In [None]:
# ============================================================
# 1) Compute TDOA & FDOA for ONE emitter
# ============================================================
def compute_tau_fd_one_emitter(em, r_sats, v_sats, f_c):
    """
    TDOA and FDOA of one emitter for the pairs FF2-FF1 and FF3-FF1.

    Parameters
    ----------
    em : dict
        Emitter position with keys ``"lat"``, ``"lon"``, ``"alt"`` (passed to
        ``lla_to_ecef``).
    r_sats : dict
        Satellite positions keyed ``'FF1'``, ``'FF2'``, ``'FF3'``, in the
        frame returned by ``lla_to_ecef``.
    v_sats : dict
        Satellite velocities, same keys and frame.
    f_c : float
        Carrier frequency [Hz].

    Returns
    -------
    tau21, tau31 : float
        Range differences divided by the module-level constant ``c``.
    fD21, fD31 : float
        Doppler differences, where each satellite's Doppler is
        ``(f_c / c) * dot(v_sat, u)`` and ``u`` is the unit vector from the
        satellite to the emitter.
    """
    e_ecef = lla_to_ecef(em["lat"], em["lon"], em["alt"])

    def _range_and_doppler(sat):
        # Line of sight from the satellite to the emitter.
        los = e_ecef - r_sats[sat]
        dist = np.linalg.norm(los)
        # Velocity component along the line of sight (towards the emitter).
        v_radial = np.dot(v_sats[sat], los / dist)
        return dist, (f_c / c) * v_radial

    # FF1 is always the reference.
    R1, fD1 = _range_and_doppler("FF1")
    R2, fD2 = _range_and_doppler("FF2")
    R3, fD3 = _range_and_doppler("FF3")

    return (R2 - R1) / c, (R3 - R1) / c, fD2 - fD1, fD3 - fD1


# ============================================================
# 2) Build CAF grids (blind) using the full emitter region
# ============================================================
def build_caf_grids(r_sats, v_sats, emitter_lla_list, f_c,
                    fs_caf, df_step=100.0, margin_frac=0.20):
    """
    Build TDOA / FDOA search grids that cover a set of candidate emitters.

    For every entry of ``emitter_lla_list`` the geometric
    (tau21, tau31, fD21, fD31) is computed with
    ``compute_tau_fd_one_emitter``. For each quantity the [min, max] range is
    widened on both sides by ``margin_frac * max(|min|, |max|)``.

    Parameters
    ----------
    r_sats, v_sats : dict
        Satellite positions / velocities, forwarded unchanged.
    emitter_lla_list : sequence of dict
        Candidate emitter locations.
    f_c : float
        Carrier frequency [Hz].
    fs_caf : float
        CAF sample rate [Hz]; the delay grid step is ``1 / fs_caf``.
    df_step : float
        Doppler grid step [Hz].
    margin_frac : float
        Fractional margin described above.

    Returns
    -------
    tau_vals_21, fd_vals_21, tau_vals_31, fd_vals_31 : ndarray
        Search grids.
    limits : tuple of 8 floats
        ``(tau21_min, tau21_max, tau31_min, tau31_max,
        fD21_min, fD21_max, fD31_min, fD31_max)``.
    """
    per_emitter = [
        compute_tau_fd_one_emitter(em, r_sats, v_sats, f_c)
        for em in emitter_lla_list
    ]

    def _column(i):
        return np.array([entry[i] for entry in per_emitter])

    def _padded_limits(values):
        lo, hi = values.min(), values.max()
        pad = margin_frac * max(abs(lo), abs(hi))
        return lo - pad, hi + pad

    # Column order follows the helper's return order: tau21, tau31, fD21, fD31.
    tau21_min, tau21_max = _padded_limits(_column(0))
    tau31_min, tau31_max = _padded_limits(_column(1))
    fD21_min,  fD21_max  = _padded_limits(_column(2))
    fD31_min,  fD31_max  = _padded_limits(_column(3))

    # One extra step on the stop value so the upper limit is included.
    tau_vals_21 = np.arange(tau21_min, tau21_max + 1/fs_caf, 1/fs_caf)
    tau_vals_31 = np.arange(tau31_min, tau31_max + 1/fs_caf, 1/fs_caf)
    fd_vals_21  = np.arange(fD21_min, fD21_max + df_step, df_step)
    fd_vals_31  = np.arange(fD31_min, fD31_max + df_step, df_step)

    limits = (tau21_min, tau21_max,
              tau31_min, tau31_max,
              fD21_min, fD21_max,
              fD31_min, fD31_max)
    return tau_vals_21, fd_vals_21, tau_vals_31, fd_vals_31, limits

### CAF (Pulse Trains)

In [1]:
# Environment sanity check: print the CuPy version, the CUDA runtime version,
# the number of visible devices and the id/name of the device in use.
# NOTE: `cp` and `dev` become notebook-level globals.
import cupy as cp
print("CuPy version:", cp.__version__)
# runtimeGetVersion() yields an integer-encoded version (see printed output).
print("CUDA runtime:", cp.cuda.runtime.runtimeGetVersion())
print("Device count:", cp.cuda.runtime.getDeviceCount())
dev = cp.cuda.Device()
print("Current device:", dev.id)
# The property value is decoded before printing (it is a bytes object here).
print("Device name:", cp.cuda.runtime.getDeviceProperties(dev.id)["name"].decode())


CuPy version: 13.6.0
CUDA runtime: 12090
Device count: 1
Current device: 0
Device name: NVIDIA GeForce RTX 3080 Ti Laptop GPU


In [None]:
# Quick GPU smoke test: time a single 4096 x 4096 float32 matrix product.
# NOTE: `x`, `y` and `t0` become notebook-level globals.
import time, cupy as cp

x = cp.random.randn(4096, 4096, dtype=cp.float32)
# Wait for the default stream before starting the clock, presumably so the
# random-number generation is not included in the measurement.
cp.cuda.Stream.null.synchronize()
t0 = time.perf_counter()
y = x @ x
# Wait again so the elapsed time covers the finished product.
cp.cuda.Stream.null.synchronize()
print("matmul seconds:", time.perf_counter()-t0)


In [None]:
# ================================================================
# 1) Build Pulse Windows
# ================================================================

def make_pulse_windows(x_ref, y_meas, t, centers, half_width_s, fs):
    """
    Cut one fixed-length window per pulse centre out of two aligned channels.

    Parameters
    ----------
    x_ref, y_meas : 1D ndarray
        Reference and measured channel; indexed with the same sample indices
        as ``t``.
    t : 1D ndarray (float)
        Time axis of the two channels [s].
    centers : iterable of float
        Pulse centre times [s]. A centre is converted to a sample index as
        ``round(tc * fs)``, i.e. relative to a time axis starting at 0.
    half_width_s : float
        Half window length [s]; every window holds
        ``N_win = round(2 * half_width_s * fs)`` samples.
    fs : float
        Sample rate [Hz].

    Returns
    -------
    pulse_windows : list of (x_seg, y_seg, t_seg)
        One tuple per usable centre. Windows that run past the end of the
        data are zero-padded (signal) and linearly extended (time) up to
        ``N_win`` samples. Centres whose window would start at or beyond the
        last sample contain no data and are skipped, so the list may be
        shorter than ``centers``.
    N_win : int
        Window length in samples.
    """
    dt = 1.0 / fs
    N = len(t)
    N_win = int(np.round(2 * half_width_s * fs))

    # Named `pulse_windows` (not `windows`) to avoid shadowing the imported
    # scipy.signal.windows module.
    pulse_windows = []

    for tc in centers:
        idx_c = int(np.round(tc * fs))
        start = max(0, idx_c - N_win//2)

        # Window lies entirely after the recorded data: there is nothing to
        # correlate and no last time stamp to extend from. Previously this
        # crashed with IndexError on t_seg[-1].
        if start >= N:
            continue

        end = min(N, start + N_win)

        x_seg = x_ref[start:end]
        y_seg = y_meas[start:end]
        t_seg = t[start:end]

        # Pad the tail so every window has exactly N_win samples.
        if len(x_seg) < N_win:
            pad = N_win - len(x_seg)
            x_seg = np.pad(x_seg, (0, pad))
            y_seg = np.pad(y_seg, (0, pad))
            t_last = t_seg[-1]
            t_seg  = np.concatenate([t_seg,
                                     t_last + dt*np.arange(1, pad+1)])

        pulse_windows.append((x_seg, y_seg, t_seg))

    return pulse_windows, N_win


# ================================================================
# 2) Segmented FFT CAF (multi-window)
# ================================================================
def caf_fft_segmented_multi_gpu(
    x_ref, y_meas, t, centers,
    tau_vals, fd_vals, fs,
    half_width_s,
    fd_batch_size=8,     # safe on 8–24GB GPUs
    win_batch_size=32    # safe batch for windows (M)
):
    """
    GPU version of segmented FFT-domain CAF, with double batching:
    - batch over Doppler (fd)
    - batch over windows (pulse windows)

    This prevents OOM while still gaining GPU parallelism.

    For every pulse window and every Doppler hypothesis, the measured window
    is multiplied by exp(-j*2*pi*fd*t) and circularly cross-correlated with
    the reference window via FFT. The correlation *magnitudes* are then
    summed over the windows (non-coherent accumulation across pulses).

    Parameters
    ----------
    x_ref, y_meas : 1D np.ndarray (complex)
        Reference (FF1) and other channel (FF2/FF3) in the time slice.
    t : 1D np.ndarray (float)
        Absolute time axis for x_ref/y_meas.
    centers : 1D np.ndarray (float)
        Pulse centres (seconds) from PDWs.
    tau_vals : 1D np.ndarray (float)
        TDOA search grid (seconds). Each value is mapped to the nearest
        correlation lag bin of the window, so the effective delay resolution
        is 1/fs and values beyond the window's lag range map to its edge bins.
    fd_vals : 1D np.ndarray (float)
        FDOA search grid (Hz).
    fs : float
        Sample rate of this CAF slice (Hz).
    half_width_s : float
        Half-window length for each pulse (seconds).
    fd_batch_size : int
        Number of Doppler hypotheses processed per GPU batch.
    win_batch_size : int
        Number of pulse windows processed per GPU batch.

    Returns
    -------
    CAF_tot : 2D np.ndarray (float32), shape = (N_fd, N_tau)
        CAF surface accumulated over all pulses and divided by its maximum
        (plus 1e-12), returned as a host (NumPy) array.
    """

    # ================================================================
    # 2.1) Build ALL windows on CPU
    # ================================================================
    windows, N_win = make_pulse_windows(
        x_ref, y_meas, t, centers, half_width_s, fs
    )
    M = len(windows)

    # Stack windows (CPU)
    x_stack = np.stack([w[0] for w in windows], axis=0)  # (M, N_win)
    y_stack = np.stack([w[1] for w in windows], axis=0)
    t_stack = np.stack([w[2] for w in windows], axis=0)

    # ================================================================
    # 2.2) Precompute tau indices (CPU)
    # ================================================================
    # Lag axis of the fftshift-ed correlation: zero lag sits at N_win//2.
    tau_axis_win = (np.arange(N_win) - N_win//2) / fs
    # Nearest lag bin for every requested delay.
    tau_idx = np.array([
        np.argmin(np.abs(tau_axis_win - tau))
        for tau in tau_vals
    ], dtype=np.int32)
    tau_idx_gpu = cp.asarray(tau_idx)

    # ================================================================
    # 2.3) Prepare output CAF array (GPU)
    # ================================================================
    N_fd  = len(fd_vals)
    N_tau = len(tau_vals)
    CAF_tot_gpu = cp.zeros((N_fd, N_tau), dtype=cp.float32)

    # Doppler values as GPU array
    fd_gpu_all = cp.asarray(fd_vals, dtype=cp.float64)

    # ================================================================
    # 2.4) Two-level batching
    # ================================================================
    # batch 1: windows
    for w_start in range(0, M, win_batch_size):
        w_end = min(M, w_start + win_batch_size)
        x_win_gpu = cp.asarray(x_stack[w_start:w_end])   # (Wb, N_win)
        y_win_gpu = cp.asarray(y_stack[w_start:w_end])   # (Wb, N_win)
        t_win_gpu = cp.asarray(t_stack[w_start:w_end])   # (Wb, N_win)

        # FFT reference windows once per window batch
        X_gpu = cp.fft.fft(x_win_gpu, axis=1)            # (Wb, N_win)
        X_gpu = X_gpu[None, :, :]                        # (1, Wb, N_win)

        # batch 2: Doppler bins
        for fd_start in range(0, N_fd, fd_batch_size):
            fd_end = min(N_fd, fd_start + fd_batch_size)

            fd_sub = fd_gpu_all[fd_start:fd_end]         # (Fd_b,)
            Fd_gpu = fd_sub[:, None, None]               # (Fd_b, 1, 1)

            # Build Doppler mixing matrix (uses each window's own time stamps)
            E_gpu = cp.exp(-1j * 2 * cp.pi * Fd_gpu * t_win_gpu[None, :, :])
            # (Fd_b, Wb, N_win)

            # Mix y
            Ymix_gpu = E_gpu * y_win_gpu[None, :, :]     # (Fd_b, Wb, N_win)

            # FFT for all fd × windows
            Y_gpu = cp.fft.fft(Ymix_gpu, axis=2)         # (Fd_b, Wb, N_win)

            # Cross correlation (circular, length N_win); fftshift moves
            # zero lag to the centre so it matches tau_axis_win above
            R_gpu = cp.fft.ifft(Y_gpu * cp.conj(X_gpu), axis=2)
            R_gpu = cp.fft.fftshift(R_gpu, axes=2)

            R_abs_gpu = cp.abs(R_gpu)                    # (Fd_b, Wb, N_win)

            # Gather tau bins and sum windows (magnitudes, i.e. non-coherent)
            CAF_sub = R_abs_gpu[:, :, tau_idx_gpu]       # (Fd_b, Wb, N_tau)
            CAF_sub = cp.sum(CAF_sub, axis=1)            # (Fd_b, N_tau)

            # Accumulate
            CAF_tot_gpu[fd_start:fd_end, :] += CAF_sub.astype(cp.float32)

    # Normalize and return to CPU (1e-12 guards against division by zero)
    CAF_tot_gpu /= cp.max(CAF_tot_gpu) + 1e-12
    return cp.asnumpy(CAF_tot_gpu)

# ================================================================
# 3) Utility to measure peak
# ================================================================

def peak_from_caf(CAF, fd_vals, tau_vals):
    """
    Return the (fd, tau) grid values at the global maximum of ``CAF``.

    ``CAF`` is indexed ``[i_fd, i_tau]``; the first maximum in row-major
    order wins when there are ties.
    """
    flat_peak = np.argmax(CAF)
    row, col = np.unravel_index(flat_peak, CAF.shape)
    fd_at_peak = fd_vals[row]
    tau_at_peak = tau_vals[col]
    return fd_at_peak, tau_at_peak

# ================================================================
# 4) Run segmented CAF for a pair
# ================================================================
def run_segmented_pair_blind_gpu(
    label,
    y1, yx,
    fs, Tp,
    centers,
    tau_vals,
    fd_vals,
    fd_batch_size=8,
    win_batch_size=32,
    seed=1234
):
    """
    Segmented (per-pulse) CAF for one satellite pair, evaluated on the GPU.

    Parameters
    ----------
    label : str
        Pair label; not used in the computation.
    y1 : np.ndarray (complex)
        Reference channel (FF1) in the time slice.
    yx : np.ndarray (complex)
        Other channel (FF2 or FF3) in the same slice.
    fs : float
        Sample rate of both channels [Hz].
    Tp : float
        Pulse width [s].
    centers : np.ndarray
        Pulse centre times [s].
    tau_vals : np.ndarray
        TDOA search grid [s].
    fd_vals : np.ndarray
        FDOA search grid [Hz].
    fd_batch_size, win_batch_size : int
        Batch sizes forwarded to ``caf_fft_segmented_multi_gpu``.
    seed : int
        Not used in the computation.

    Returns
    -------
    fd_hat, tau_hat : float
        Grid values at the CAF peak.
    """
    # Time axis starts at 0 for the first sample of the slice.
    t_axis = np.arange(len(y1)) / fs

    # Each window must hold the pulse, the largest searched delay and a fixed
    # 3 us allowance on top.
    largest_lag = max(abs(tau_vals.min()), abs(tau_vals.max()))
    half_width_s = Tp/2 + largest_lag + 3e-6

    surface = caf_fft_segmented_multi_gpu(
        y1, yx, t_axis,
        centers,
        tau_vals, fd_vals,
        fs,
        half_width_s,
        fd_batch_size,
        win_batch_size,
    )

    fd_hat, tau_hat = peak_from_caf(surface, fd_vals, tau_vals)
    return fd_hat, tau_hat


### CAF (CWs) - 1D Frequency Correlation

In [19]:
def estimate_tone_frequency_gpu(x, fs, pad_factor=4):
    """
    Estimate the frequency (Hz) of a single complex tone in x[n].

    The signal is Hann-windowed, zero-padded to the next power of two of
    ``len(x) * pad_factor`` and transformed with CuPy's FFT. The strongest
    magnitude bin is refined by parabolic interpolation through it and its
    two neighbours. CuPy is required; there is no CPU fallback.

    The neighbours are taken circularly. A tone within half a bin of 0 Hz
    peaks in bin 0, whose lower neighbour is the last bin; without the
    wrap-around the interpolation was skipped for exactly the near-DC tones
    a basebanded CW produces.

    Parameters
    ----------
    x : array-like (complex)
        Samples containing the tone.
    fs : float
        Sample rate [Hz].
    pad_factor : float
        Zero-padding factor applied before rounding up to a power of two.

    Returns
    -------
    f_hat : float
        Estimated tone frequency [Hz], negative for bins in the upper half of
        the FFT (same convention as ``fftfreq``).
    N_fft : int
        FFT length that was used (bin spacing is ``fs / N_fft``).

    Raises
    ------
    ValueError
        If ``x`` is empty.
    """
    N = len(x)
    if N == 0:
        raise ValueError("estimate_tone_frequency_gpu: input signal is empty")

    x_gpu = cp.asarray(x)
    N_fft = int(2**np.ceil(np.log2(N * pad_factor)))

    # Window to reduce leakage, then zero-padded FFT magnitude.
    xw = x_gpu * cp.hanning(N)
    mag = cp.abs(cp.fft.fft(xw, n=N_fft))

    # Locate the peak on the GPU and fetch only the three bins needed for the
    # interpolation instead of copying the whole spectrum to the host.
    idx = int(cp.argmax(mag))
    neighbours = cp.asarray([(idx - 1) % N_fft, idx, (idx + 1) % N_fft])
    y1, y2, y3 = (float(v) for v in cp.asnumpy(mag[neighbours]))

    # Vertex of the parabola through the three points, in bins from idx.
    denom = y1 - 2*y2 + y3
    delta = 0.5 * (y1 - y3) / denom if denom != 0 else 0.0

    # Signed bin number: bins in the upper half are negative frequencies.
    signed_bin = idx if idx < (N_fft + 1) // 2 else idx - N_fft
    df = fs / N_fft
    f_hat = (signed_bin + delta) * df
    return f_hat, N_fft


In [21]:
def estimate_fdoa_from_tones(x1, x2, x3, fs, pad_factor=4, error_cfg=None, seed=None):
    """
    FDOA for the CW case from per-channel tone frequencies:
      fD21 = f2 - f1
      fD31 = f3 - f1

    Each tone frequency comes from ``estimate_tone_frequency_gpu``.

    Optional error injection (only when ``error_cfg`` is given), using a
    generator seeded with ``seed``:
        - ``error_cfg["fd_meas_err_rms"]`` (default 0): a separate Gaussian
          draw with this sigma is added to fD21 and to fD31.
        - ``error_cfg["lo_err_rms_hz"]`` (default 0): a separate Gaussian
          draw with sigma ``sqrt(2) * lo_err_rms_hz`` is added to each pair.
    Draw order is: estimator error 21, estimator error 31, LO error 21,
    LO error 31.

    Parameters
    ----------
    x1, x2, x3 : array-like (complex)
        Basebanded CW samples of FF1, FF2, FF3.
    fs : float
        Sample rate [Hz].
    pad_factor : float
        Zero-padding factor forwarded to the tone estimator.
    error_cfg : dict or None
        Error configuration (see above); ``None`` disables injection.
    seed : int or None
        Seed of the error generator; ``None`` gives non-reproducible draws.

    Returns
    -------
    fD21, fD31 : float
        FDOA estimates [Hz].
    df_step : float
        FFT bin spacing ``fs / N_fft`` of the FF1 estimate.
    """
    f1, N_fft = estimate_tone_frequency_gpu(x1, fs, pad_factor)
    f2 = estimate_tone_frequency_gpu(x2, fs, pad_factor)[0]
    f3 = estimate_tone_frequency_gpu(x3, fs, pad_factor)[0]

    fD21 = f2 - f1
    fD31 = f3 - f1
    df_step = fs / N_fft

    # Nothing to inject: return the raw differences.
    if error_cfg is None:
        return fD21, fD31, df_step

    sigma_est = error_cfg.get("fd_meas_err_rms", 0.0)
    sigma_lo = error_cfg.get("lo_err_rms_hz", 0.0)
    rng = np.random.default_rng(seed)

    # Estimator error, one draw per pair.
    if sigma_est > 0:
        fD21 += rng.normal(0, sigma_est)
        fD31 += rng.normal(0, sigma_est)

    # LO error; a difference of two LOs has sqrt(2) times the single-LO sigma.
    if sigma_lo > 0:
        sigma_pair = np.sqrt(2.0) * sigma_lo
        lo_err_21 = rng.normal(0, sigma_pair)
        lo_err_31 = rng.normal(0, sigma_pair)
        fD21 += lo_err_21
        fD31 += lo_err_31

    return fD21, fD31, df_step


### Single Sweep CAF (LFMCWs)

In [None]:
'''
For simplicity, the CAF is only computed for the first LFMCW sweep.
This can in principle be extended to multiple sweeps for better performance.

Idea: Align and extract first sweeps across all receiver channels then CAF.

Alignment: 
- FF1 (mainlobe) relies on deinterleaver output.
- FF2 and FF3 (sidelobes) rely on the estimated emitter location, which affects the TDOA and FDOA offsets

Extraction: Same concept as LFM case, using time information to time gate, followed by dechirp, rechirp filtering.

Estimated emitter location: Will affect alignment significantly. 
- Idea is to use AOI center location as emitter location proxy
- Geolocate LFMCW emitter location, then feed in this location iteratively for a better CAF/geolocation output
'''

def caf_fft_gpu(x, y, tau_vals, fd_vals, fs, batch_size=128):
    """
    Single-segment FFT-based CAF on the GPU.

    For every Doppler hypothesis ``fd`` the signal ``y`` is multiplied by
    ``exp(-j*2*pi*fd*t)`` (``t`` starting at 0) and circularly
    cross-correlated with ``x``. The magnitude is sampled at the lag bins
    ``round(tau * fs) mod len(x)``.

    Parameters
    ----------
    x, y : 1D array-like (complex)
        Reference and measured signal.
    tau_vals : np.ndarray
        Delay grid [s].
    fd_vals : array-like
        Doppler grid [Hz].
    fs : float
        Sample rate of ``x`` / ``y`` [Hz].
    batch_size : int
        Number of Doppler hypotheses evaluated per GPU batch.

    Returns
    -------
    CAF : np.ndarray (float32), shape (len(fd_vals), len(tau_vals))
        Magnitude surface divided by its maximum (plus 1e-12).
    """
    n_samp = len(x)
    n_fd = len(fd_vals)

    # Quantities shared by all Doppler batches.
    ref_spec_conj = cp.conj(cp.fft.fft(cp.asarray(x)))
    meas = cp.asarray(y)
    t_axis = cp.arange(n_samp, dtype=cp.float64) / fs
    dopplers = cp.asarray(fd_vals, dtype=cp.float64)

    # Negative delays wrap to the end of the circular correlation.
    lag_bins = cp.asarray((np.round(tau_vals * fs).astype(np.int64)) % n_samp)

    surface = cp.zeros((n_fd, len(tau_vals)), dtype=cp.float32)

    for lo in range(0, n_fd, batch_size):
        hi = min(lo + batch_size, n_fd)

        rotation = cp.exp(-1j * 2*np.pi * dopplers[lo:hi, None] * t_axis[None, :])
        spec = cp.fft.fft(meas[None, :] * rotation, axis=1)
        xcorr = cp.fft.ifft(spec * ref_spec_conj[None, :], axis=1)

        surface[lo:hi, :] = cp.abs(xcorr)[:, lag_bins]

    surface /= (cp.max(surface) + 1e-12)
    return cp.asnumpy(surface)

def geom_tdoa_fdoa_from_lla(
    emitter_lla,
    r_sats,
    v_sats,
    f_center_hz
):
    """
    Geometric TDOA / FDOA of an emitter for the pairs FF2-FF1 and FF3-FF1.

    Parameters
    ----------
    emitter_lla : dict
        Emitter position with keys ``"lat"``, ``"lon"`` and optionally
        ``"alt"`` (defaults to 0.0); passed to ``lla_to_ecef``.
    r_sats : dict
        ``{'FF1': np.array([x,y,z]), 'FF2': ..., 'FF3': ...}`` positions.
    v_sats : dict
        ``{'FF1': np.array([vx,vy,vz]), ...}`` velocities.
    f_center_hz : float
        Carrier frequency [Hz] used to scale the Doppler.

    Returns
    -------
    tau21, tau31 : float
        Range differences divided by the module-level constant ``c``.
    fD21, fD31 : float
        Doppler differences relative to FF1.

    Notes
    -----
    ``u`` points from the satellite to the emitter, so ``dot(v, u)`` is the
    closing speed, i.e. minus the range rate (for an emitter at rest in the
    frame of ``r_sats`` / ``v_sats``). The Doppler of an approaching receiver
    is positive: ``fD = -(dR/dt) * f / c = +dot(v, u) * f / c``. The previous
    version applied an extra minus sign, which gave the opposite sign to
    ``compute_tau_fd_one_emitter`` for the same geometry.
    """

    r_em = lla_to_ecef(
        emitter_lla["lat"],
        emitter_lla["lon"],
        emitter_lla.get("alt", 0.0),
    )

    r1 = r_sats['FF1'];  r2 = r_sats['FF2'];  r3 = r_sats['FF3']
    v1 = v_sats['FF1'];  v2 = v_sats['FF2'];  v3 = v_sats['FF3']

    # Ranges
    d1 = np.linalg.norm(r_em - r1)
    d2 = np.linalg.norm(r_em - r2)
    d3 = np.linalg.norm(r_em - r3)

    # TDOA
    tau21 = (d2 - d1) / c
    tau31 = (d3 - d1) / c

    # LOS unit vectors (satellite -> emitter)
    u1 = (r_em - r1) / d1
    u2 = (r_em - r2) / d2
    u3 = (r_em - r3) / d3

    # Closing speeds (positive when the satellite approaches the emitter)
    v_c1 = np.dot(v1, u1)
    v_c2 = np.dot(v2, u2)
    v_c3 = np.dot(v3, u3)

    # Doppler per sat: fD = +(closing speed / c) * f
    fD1 = (v_c1 / c) * f_center_hz
    fD2 = (v_c2 / c) * f_center_hz
    fD3 = (v_c3 / c) * f_center_hz

    # FDOA
    fD21 = fD2 - fD1
    fD31 = fD3 - fD1

    return tau21, tau31, fD21, fD31

def test_lfmcw_caf_blind_geom(
    iq_FF1, iq_FF2, iq_FF3,
    fs,
    lfmcw_row,
    emitter_lla,
    r_sats,
    v_sats,
    F_LO,
    tau21_true=None, tau31_true=None,
    fD21_true=None, fD31_true=None,
    tau_guard_extra_us=20.0,
    tau_search_us=100.0,
    fd_search_hz=10000.0,
    caf_batch_size=256,
):
    """
    Align the first LFMCW sweep across FF1/FF2/FF3 and run a CAF per pair.

    A coarse TDOA/FDOA is predicted from ``emitter_lla`` and the satellite
    states (``geom_tdoa_fdoa_from_lla``) at the RF frequency
    ``F_LO + centre frequency``. The prediction positions the extraction
    window and centres the CAF search grids (401 points per axis, spanning
    ``+/- tau_search_us`` and ``+/- fd_search_hz``).

    Parameters
    ----------
    iq_FF1, iq_FF2, iq_FF3 : ndarray (complex)
        Full-rate IQ of the three channels.
    fs : float
        Sample rate [Hz].
    lfmcw_row : mapping
        Sweep parameters (see ``extract_lfmcw_common_window_geom``).
    emitter_lla : dict
        Assumed emitter location used for the coarse prediction.
    r_sats, v_sats : dict
        Satellite positions / velocities.
    F_LO : float
        Frequency [Hz] added to the row's centre frequency to obtain RF.
    tau21_true, tau31_true, fD21_true, fD31_true : float or None
        Accepted for interface compatibility; not used.
    tau_guard_extra_us : float
        Extra ROI guard [us] on top of ``|tau21| + |tau31|``.
    tau_search_us, fd_search_hz : float
        Half-widths of the delay [us] and Doppler [Hz] search.
    caf_batch_size : int
        Doppler batch size forwarded to ``caf_fft_gpu``.

    Returns
    -------
    dict
        ``tau21_hat``, ``tau31_hat``, ``fD21_hat``, ``fD31_hat``, the two CAF
        surfaces, the four search grids and the extraction ``meta``.
    """
    # Geometry-based coarse prediction at RF.
    f_rf = F_LO + lfmcw_row["Center Freq (MHz)"] * 1e6
    tau21_est, tau31_est, fD21_est, fD31_est = geom_tdoa_fdoa_from_lla(
        emitter_lla, r_sats, v_sats, f_rf
    )

    # The guard grows with the predicted delays so the ROI covers all channels.
    tau_guard_s = abs(tau21_est) + abs(tau31_est) + tau_guard_extra_us*1e-6

    x1_ds, x2_ds, x3_ds, fs_ds, meta = extract_lfmcw_common_window_geom(
        iq_FF1, iq_FF2, iq_FF3,
        fs, lfmcw_row,
        tau21_est, tau31_est,
        tau_guard_s,
    )

    def _search_grid(centre, half_span):
        return np.linspace(centre - half_span, centre + half_span, 401)

    tau_search = tau_search_us * 1e-6
    tau_vals_21 = _search_grid(tau21_est, tau_search)
    tau_vals_31 = _search_grid(tau31_est, tau_search)
    fd_vals_21 = _search_grid(fD21_est, fd_search_hz)
    fd_vals_31 = _search_grid(fD31_est, fd_search_hz)

    def _caf_and_peak(x_other, tau_grid, fd_grid):
        # FF1 is always the reference channel.
        surface = caf_fft_gpu(x1_ds, x_other, tau_grid, fd_grid, fs_ds,
                              batch_size=caf_batch_size)
        row, col = np.unravel_index(np.argmax(surface), surface.shape)
        return surface, tau_grid[col], fd_grid[row]

    CAF21, tau21_hat, fD21_hat = _caf_and_peak(x2_ds, tau_vals_21, fd_vals_21)
    CAF31, tau31_hat, fD31_hat = _caf_and_peak(x3_ds, tau_vals_31, fd_vals_31)

    return {
        "tau21_hat": tau21_hat,
        "tau31_hat": tau31_hat,
        "fD21_hat": fD21_hat,
        "fD31_hat": fD31_hat,
        "CAF21": CAF21,
        "CAF31": CAF31,
        "tau_vals_21": tau_vals_21,
        "fd_vals_21": fd_vals_21,
        "tau_vals_31": tau_vals_31,
        "fd_vals_31": fd_vals_31,
        "meta": meta,
    }

## 9. Geolocation

### 9.1 TDOA (for LFM and Single Frequency Pulses)

#### TDOA geolocation

In [None]:
def tdoa_model_latlon(
    lat_rad: float,
    lon_rad: float,
    r_sats: dict,
    alt_m: float = 0.0
) -> tuple[float, float]:
    """
    Predicted TDOAs (tau21, tau31) for an emitter at the given lat/lon.

    Parameters
    ----------
    lat_rad, lon_rad : float
        Emitter latitude / longitude in radians; converted to degrees before
        being passed to ``lla_to_ecef``.
    r_sats : dict
        Satellite positions keyed 'FF1', 'FF2', 'FF3', in the same frame as
        the output of ``lla_to_ecef``.
    alt_m : float, optional
        Emitter altitude passed to ``lla_to_ecef`` (default 0.0).

    Returns
    -------
    tau21 : float
        (range to FF2 - range to FF1) / c.
    tau31 : float
        (range to FF3 - range to FF1) / c.
    """
    emitter = lla_to_ecef(np.rad2deg(lat_rad), np.rad2deg(lon_rad), alt_m)

    sat_positions = (r_sats['FF1'], r_sats['FF2'], r_sats['FF3'])
    R1, R2, R3 = [np.linalg.norm(emitter - pos) for pos in sat_positions]

    return (R2 - R1) / c, (R3 - R1) / c


def tdoa_jacobian_latlon(
    lat_rad: float,
    lon_rad: float,
    r_sats: dict,
    alt_m: float = 0.0,
    eps_rad: float = 1e-6
) -> np.ndarray:
    """
    Forward-difference Jacobian of (tau21, tau31) with respect to (lat, lon):

        J = [ dtau21/dlat   dtau21/dlon ]
            [ dtau31/dlat   dtau31/dlon ]

    with lat and lon in radians.

    Parameters
    ----------
    lat_rad, lon_rad : float
        Evaluation point [rad].
    r_sats : dict
        Satellite positions (keys 'FF1', 'FF2', 'FF3').
    alt_m : float
        Altitude held fixed while differentiating.
    eps_rad : float
        Step [rad] added to lat, then to lon, for the finite difference.

    Returns
    -------
    J : (2,2) ndarray
    """
    def _tdoa_pair(lat, lon):
        return np.asarray(tdoa_model_latlon(lat, lon, r_sats, alt_m), dtype=float)

    tau_here = _tdoa_pair(lat_rad, lon_rad)
    tau_lat_step = _tdoa_pair(lat_rad + eps_rad, lon_rad)
    tau_lon_step = _tdoa_pair(lat_rad, lon_rad + eps_rad)

    # Columns: derivative w.r.t. latitude, derivative w.r.t. longitude.
    return np.column_stack((
        (tau_lat_step - tau_here) / eps_rad,
        (tau_lon_step - tau_here) / eps_rad,
    ))

def geolocate_from_tdoa(
    tau_hat_21: float,
    tau_hat_31: float,
    r_sats: dict,
    lat0_deg: float,
    lon0_deg: float,
    alt_m: float = 0.0,
    max_iter: int = 100,
    tol: float = 1e-11
) -> tuple[float, float, np.ndarray]:
    """
    Gauss–Newton fit of emitter (lat, lon) to a measured TDOA pair (τ̂21, τ̂31).

    Altitude is held at ``alt_m``. Each iteration linearises
    ``tdoa_model_latlon`` with ``tdoa_jacobian_latlon`` and applies the
    normal-equation step until the step norm drops below ``tol`` or
    ``max_iter`` iterations have run.

    Parameters
    ----------
    tau_hat_21 : float
        Measured TDOA for Sat2-Sat1 [s].
    tau_hat_31 : float
        Measured TDOA for Sat3-Sat1 [s].
    r_sats : dict
        Satellite ECEF positions {'FF1','FF2','FF3'}.
    lat0_deg : float
        Starting latitude [deg].
    lon0_deg : float
        Starting longitude [deg].
    alt_m : float
        Fixed altitude [m].
    max_iter : int
        Upper bound on the number of Gauss–Newton iterations.
    tol : float
        Stop once the norm of the (lat, lon) update [rad] is below this value.

    Returns
    -------
    lat_est_deg : float
        Estimated latitude [deg].
    lon_est_deg : float
        Estimated longitude [deg] (not wrapped to any interval here).
    J_est : (2,2) ndarray
        TDOA Jacobian wrt (lat, lon) [rad], evaluated at the returned point.
    """
    measured = np.array([tau_hat_21, tau_hat_31], dtype=float)

    # The iteration runs in radians; degrees are only used at the interface.
    lat_r = np.deg2rad(lat0_deg)
    lon_r = np.deg2rad(lon0_deg)

    for _ in range(max_iter):
        predicted = np.array(
            tdoa_model_latlon(lat_r, lon_r, r_sats, alt_m), dtype=float
        )
        misfit = measured - predicted          # residual in TDOA space, shape (2,)

        jac = tdoa_jacobian_latlon(lat_r, lon_r, r_sats, alt_m)

        # Normal equations (unweighted): (JᵀJ) step = Jᵀ misfit
        normal_mat = jac.T @ jac
        rhs = jac.T @ misfit
        try:
            step = np.linalg.solve(normal_mat, rhs)
        except np.linalg.LinAlgError:
            # Singular normal matrix: use the pseudo-inverse instead
            step = np.linalg.pinv(normal_mat) @ rhs

        lat_r = lat_r + step[0]
        lon_r = lon_r + step[1]

        # Converged when the parameter update is negligible
        if np.linalg.norm(step) < tol:
            break

    # Jacobian at the final point, returned for covariance propagation
    jac_final = tdoa_jacobian_latlon(lat_r, lon_r, r_sats, alt_m)

    return np.rad2deg(lat_r), np.rad2deg(lon_r), jac_final


#### TDOA Ellipse Derivation

In [None]:
def TDOA_covariance_matrix(
    fs_caf, F_LO,
    B_sig,
    PW,
    snr_db_rx1,
    snr_db_rx2,
    snr_db_rx3,
    M,
    error_cfg,
    mode="lfm"
):
    """
    Build correlated TDOA covariance matrix Στ using:
      - Ulman TDOA CRLB for σ21 and σ31,
      - Synthetic FF1–FF1 Ulman CRLB for σ11,
      - sigma1_sq  = 0.5 * sigma_tau_11^2,
      - Στ = [[var21, sigma1_sq],
             [sigma1_sq, var31]].

    The measurement error is added per receiver (so it also enters the
    off-diagonal term through receiver 1); sync, position and quantization
    errors are added only to the *pairwise* variances.

    Parameters
    ----------
    fs_caf : float
        CAF sample rate [Hz]; also sets the delay quantization step 1/fs_caf.
    F_LO : float
        Not used by this function; kept so the signature stays unchanged.
    B_sig : float
        Signal bandwidth [Hz]. Used as-is in "lfm" mode, ignored in "tone" mode.
    PW : float
        Pulse width [s]. In "tone" mode the bandwidth passed to the CRLB is 2/PW.
    snr_db_rx1, snr_db_rx2, snr_db_rx3 : float
        Per-receiver SNRs [dB] for FF1, FF2, FF3.
    M : int
        Number of pulses, forwarded to ``theoretical_delay_sigma``.
    error_cfg : dict
        Must provide "sync_err_rms", "meas_err_rms" and "pos_err_rms".
        The first two are squared and added directly to delay variances, so
        they must be in the same time unit as the CRLB; "pos_err_rms" is
        divided by ``c`` before squaring.
    mode : str
        "lfm" or "tone".

    Returns
    -------
    Sigma_tau : (2,2) ndarray
        Covariance of (τ21, τ31).

    Raises
    ------
    ValueError
        If ``mode`` is neither "lfm" nor "tone".
    """

    # -------------------------
    # 1) CORE CRLB (Ulman)
    # -------------------------
    # Validate the mode before doing any work: an unknown mode used to fall
    # through silently and fail later with an UnboundLocalError.
    if mode == "lfm":
        B_eff = B_sig
    elif mode == "tone":
        B_eff = 2/PW
    else:
        raise ValueError(
            f"TDOA_covariance_matrix: unknown mode {mode!r} (expected 'lfm' or 'tone')"
        )

    # Synthetic FF1–FF1
    sigma_tau_11 = theoretical_delay_sigma(
        fs_caf, B_eff, PW,
        snr_db_rx1, snr_db_rx1,
        M=M
    )

    # Real pairs
    sigma_tau_21 = theoretical_delay_sigma(
        fs_caf, B_eff, PW,
        snr_db_rx1, snr_db_rx2,
        M=M
    )
    sigma_tau_31 = theoretical_delay_sigma(
        fs_caf, B_eff, PW,
        snr_db_rx1, snr_db_rx3,
        M=M
    )

    # -------------------------
    # 2) Per Receiver TOA variances
    # -------------------------
    # The FF1–FF1 pair splits evenly between two identical receivers, which
    # gives receiver 1's share; the others follow by subtraction.
    sigma1_sq = 0.5 * sigma_tau_11**2
    sigma2_sq = sigma_tau_21**2 - sigma1_sq
    sigma3_sq = sigma_tau_31**2 - sigma1_sq

    # -------------------------
    # 3) External timing errors
    # -------------------------
    sync_var  = (error_cfg["sync_err_rms"])**2
    meas_var  = (error_cfg["meas_err_rms"])**2
    pos_var   = (np.sqrt(2.0/3)*error_cfg["pos_err_rms"]/c)**2
    quant_var = (1.0 / fs_caf / np.sqrt(12.0))**2   # uniform-quantization (1/√12) rule

    # -------------------------
    # 4) Per-receiver measurement error
    # -------------------------
    sigma1_sq += meas_var
    sigma2_sq += meas_var
    sigma3_sq += meas_var

    # -------------------------
    # 5) Pairwise-only errors
    # -------------------------
    pair_var_extra = sync_var + pos_var + quant_var
    var21   = (sigma1_sq + sigma2_sq) + pair_var_extra
    var31   = (sigma1_sq + sigma3_sq) + pair_var_extra

    # -------------------------
    # 6) Final correlated Στ
    # -------------------------
    # Both pairs share receiver 1, so its variance is their covariance.
    Sigma_tau = np.array([
        [var21,   sigma1_sq],
        [sigma1_sq, var31  ]
    ])

    return Sigma_tau


In [None]:
def theoretical_delay_sigma(
    fs,
    B_sig,
    Tp,
    snr_db_rx1,
    snr_db_rx2,
    M,
    Bn=None,
    use_numeric_Brms=False,
    compute_Brms_numerical=None,
):
    """
    CRLB-style TDOA standard deviation (Ulman & Geraniotis Eq. (5),(7),(8)).

        σ_τ ≈ 1 / (β_s * sqrt(Bn * T * γ_eff))

    with
      - β_s = 2π * B_rms   (RMS frequency in rad/s),
      - T   = M * Tp       (integration time),
      - Bn                 (noise bandwidth [Hz]),
      - γ_eff              (Eq. (8) combination of the two receiver SNRs).

    Parameters
    ----------
    fs : float
        Sample rate [Hz].
    B_sig : float
        Signal bandwidth [Hz].
    Tp : float
        Pulse width [s].
    snr_db_rx1, snr_db_rx2 : float
        SNRs of the two receivers in dB (each floored at 1e-15 in linear scale).
    M : int
        Number of pulses integrated.
    Bn : float or None
        Noise bandwidth [Hz]; fs/2 is used when None.
    use_numeric_Brms : bool
        When true, B_rms comes from ``compute_Brms_numerical(fs, Tp, B_sig)``;
        otherwise B_rms = B_sig / (2*sqrt(3)).
    compute_Brms_numerical : callable or None
        Required when ``use_numeric_Brms`` is true.

    Returns
    -------
    sigma_tau : float
        Standard deviation of the delay estimate [s].

    Raises
    ------
    ValueError
        If ``use_numeric_Brms`` is true but no callable was supplied.
    """
    noise_bw = fs / 2.0 if Bn is None else Bn

    # dB -> linear, with a floor so the reciprocals below stay finite
    floor = 1e-15
    g1, g2 = (
        max(10.0 ** (snr_db / 10.0), floor)
        for snr_db in (snr_db_rx1, snr_db_rx2)
    )

    # Eq. (8): effective SNR of the cross-correlation
    snr_eff = 1.0 / (0.5 * (1.0 / g1 + 1.0 / g2 + 2.0 / (g1 * g2)))

    # Azaria-Hertz correction: scale down when the noise band is wider than the signal
    if noise_bw > B_sig:
        snr_eff = snr_eff * (B_sig / noise_bw)

    t_int = M * Tp

    # RMS bandwidth: closed-form flat-spectrum value, or the user's numeric one
    if not use_numeric_Brms:
        rms_bw = B_sig / (2.0 * np.sqrt(3.0))
    elif compute_Brms_numerical is None:
        raise ValueError("compute_Brms_numerical must be provided.")
    else:
        rms_bw = compute_Brms_numerical(fs, Tp, B_sig)

    rms_rad = 2.0 * np.pi * rms_bw

    # Eq. (5)
    return 1.0 / (rms_rad * np.sqrt(noise_bw * t_int * snr_eff))


# Optional numerical Brms (if you want more accuracy)
def compute_Brms_numerical(fs, Tp, B):
    """
    Numerically compute B_rms [Hz] for an LFM pulse of width Tp and sweep B.

    The chirp is generated on a time axis centred on zero, so its
    instantaneous frequency sweeps -B/2..+B/2 and the second spectral moment
    is taken about the band centre. With the pulse starting at t = 0 the sweep
    would run 0..B and the result would come out near B/sqrt(3), i.e. twice
    the analytic value B/(2*sqrt(3)) that ``theoretical_delay_sigma`` uses
    when the numeric path is disabled.

    Parameters
    ----------
    fs : float
        Sample rate [Hz].
    Tp : float
        Pulse width [s].
    B : float
        LFM sweep bandwidth [Hz].

    Returns
    -------
    Brms : float
        RMS bandwidth [Hz]: sqrt(<|ds/dt|^2> / <|s|^2>) / (2π).
        Returns 0.0 when the pulse spans fewer than two samples, since no
        finite difference can be formed.
    """
    dt = 1.0 / fs

    # Same sample count as the original grid t = 0, dt, ... < Tp
    n_samples = len(np.arange(0, Tp, 1/fs))
    if n_samples < 2:
        return 0.0

    # Symmetric time axis -> frequency sweep centred on 0 Hz
    t = (np.arange(n_samples) - 0.5 * (n_samples - 1)) * dt
    k = B / Tp
    s = np.exp(1j * np.pi * k * t**2)

    # First-difference derivative (one sample shorter than s)
    ds_dt = np.diff(s) / dt

    # Use mean powers so the N vs N-1 sample counts do not bias the ratio
    P_sig  = np.mean(np.abs(s)**2)
    P_dsig = np.mean(np.abs(ds_dt)**2)

    return np.sqrt(P_dsig / (4 * np.pi**2 * P_sig))


#### TDOA Wrapper

In [None]:
def geolocate_and_cep_from_tdoa(
    tau21_hat: float,
    tau31_hat: float,
    Sigma_tau: np.ndarray,      # 2×2 TDOA covariance
    r_sats: dict,
    lat0_deg: float,
    lon0_deg: float,
    alt_m: float = 0.0,
    p: float = 0.5
) -> dict:
    """
    TDOA-only geolocation followed by error-ellipse extraction.

    Steps, in order:
      1. ``geolocate_from_tdoa`` — Gauss–Newton fix at fixed altitude.
      2. ``position_covariance_latlon`` — propagate Σ_τ through the Jacobian.
      3. ``latlon_cov_to_EN_cov_wgs84`` — express the covariance in East/North.
      4. ``ellipse_from_cov_EN`` — semi-axes and orientation at probability ``p``.

    Returns
    -------
    dict
        Keys "lat_est_deg", "lon_est_deg", "Sigma_latlon", "Sigma_EN",
        "a", "b", "angle_deg" and "p".
    """
    # 1) Position fix and the Jacobian at that fix
    lat_hat, lon_hat, jac_hat = geolocate_from_tdoa(
        tau21_hat, tau31_hat, r_sats, lat0_deg, lon0_deg, alt_m=alt_m
    )

    # 2) Covariance in (lat, lon)
    cov_latlon = position_covariance_latlon(jac_hat, Sigma_tau)

    # 3) Same covariance in the local East/North plane at the estimate
    cov_en = latlon_cov_to_EN_cov_wgs84(cov_latlon, np.deg2rad(lat_hat))

    # 4) Ellipse for the requested probability
    semi_major, semi_minor, tilt_deg = ellipse_from_cov_EN(cov_en, p=p)

    result = {}
    result["lat_est_deg"] = lat_hat
    result["lon_est_deg"] = lon_hat
    result["Sigma_latlon"] = cov_latlon
    result["Sigma_EN"] = cov_en
    result["a"] = semi_major
    result["b"] = semi_minor
    result["angle_deg"] = tilt_deg
    result["p"] = p
    return result


### 9.2 FDOA (for CWs)

#### FDOA Geolocation

In [None]:
# ============================================================
#  FDOA GEOLOCATION PIPELINE (FDOA-only, alt fixed)
# ============================================================

def fdoa_model_latlon(
    lat_rad: float,
    lon_rad: float,
    r_sats: dict,
    v_sats: dict,
    f_c_hz: float,
    alt_m: float = 0.0,
) -> tuple[float, float]:
    """
    Predict the FDOA pair (fD21, fD31) [Hz] for an emitter at the given lat/lon.

    The emitter is treated as static and only the satellites move, so the
    Doppler seen at satellite i is

        fD_i = -(f_c / c) * (v_i · u_i)

    with u_i the unit line-of-sight vector from the emitter to satellite i.
    The outputs are the differences against satellite 1:

        fD21 = fD_2 - fD_1
        fD31 = fD_3 - fD_1

    Parameters
    ----------
    lat_rad, lon_rad : float
        Emitter latitude / longitude in radians.
    r_sats : dict
        Satellite ECEF positions under keys 'FF1','FF2','FF3'.
    v_sats : dict
        Satellite ECEF velocities under the same keys.
    f_c_hz : float
        Carrier frequency [Hz].
    alt_m : float, optional
        Emitter altitude [m], default 0.

    Returns
    -------
    fD21 : float
        FDOA (Sat2 - Sat1) [Hz].
    fD31 : float
        FDOA (Sat3 - Sat1) [Hz].
    """
    # Emitter position in ECEF (helper takes degrees)
    emitter = lla_to_ecef(np.rad2deg(lat_rad), np.rad2deg(lon_rad), alt_m)

    # Position/velocity of every satellite, in FF1, FF2, FF3 order
    states = [(r_sats[name], v_sats[name]) for name in ('FF1', 'FF2', 'FF3')]

    doppler = []
    for pos, vel in states:
        los = pos - emitter                       # emitter -> satellite
        unit_los = los / np.linalg.norm(los)
        radial_speed = np.dot(vel, unit_los)      # satellite velocity along LOS
        doppler.append(-(f_c_hz / c) * radial_speed)

    fd_1, fd_2, fd_3 = doppler
    return fd_2 - fd_1, fd_3 - fd_1


def fdoa_jacobian_latlon(
    lat_rad: float,
    lon_rad: float,
    r_sats: dict,
    v_sats: dict,
    f_c_hz: float,
    alt_m: float = 0.0,
    eps_rad: float = 1e-6,
) -> np.ndarray:
    """
    Forward-difference Jacobian of the FDOA pair (fD21, fD31) wrt (lat, lon).

    Layout of the result (lat, lon in radians):

        J = [ d(fD21)/dlat   d(fD21)/dlon ]
            [ d(fD31)/dlat   d(fD31)/dlon ]

    Parameters
    ----------
    lat_rad, lon_rad : float
        Evaluation point [rad].
    r_sats, v_sats : dict
        Satellite ECEF positions / velocities, forwarded to ``fdoa_model_latlon``.
    f_c_hz : float
        Carrier frequency [Hz].
    alt_m : float
        Emitter altitude [m], held fixed.
    eps_rad : float
        Finite-difference step [rad] applied to latitude and to longitude.

    Returns
    -------
    J : (2,2) ndarray
        Numerical Jacobian of (fD21, fD31) wrt (lat, lon).
    """
    # Reference FDOAs at the unperturbed point
    base = fdoa_model_latlon(lat_rad, lon_rad, r_sats, v_sats, f_c_hz, alt_m)

    # Column 0 <- latitude shift, column 1 <- longitude shift
    shifted_points = (
        (lat_rad + eps_rad, lon_rad),
        (lat_rad, lon_rad + eps_rad),
    )

    jac = np.zeros((2, 2), dtype=float)
    for col, (lat_s, lon_s) in enumerate(shifted_points):
        shifted = fdoa_model_latlon(lat_s, lon_s, r_sats, v_sats, f_c_hz, alt_m)
        for row in range(2):
            jac[row, col] = (shifted[row] - base[row]) / eps_rad

    return jac

In [None]:
def geolocate_from_fdoa(
    fD21_hat: float,
    fD31_hat: float,
    r_sats: dict,
    v_sats: dict,
    f_c_hz: float,
    lat0_deg: float,
    lon0_deg: float,
    alt_m: float = 0.0,
    max_iter: int = 100,
    tol: float = 1e-11
) -> tuple[float, float, np.ndarray]:
    """
    Gauss–Newton fit of emitter (lat, lon) to a measured FDOA pair.

    Altitude is held at ``alt_m``. Each iteration linearises
    ``fdoa_model_latlon`` with ``fdoa_jacobian_latlon`` and applies the
    normal-equation step until the step norm drops below ``tol`` or
    ``max_iter`` iterations have run.

    Parameters
    ----------
    fD21_hat : float
        Measured FDOA for Sat2-Sat1 [Hz].
    fD31_hat : float
        Measured FDOA for Sat3-Sat1 [Hz].
    r_sats : dict
        Satellite ECEF positions {'FF1','FF2','FF3'}.
    v_sats : dict
        Satellite ECEF velocities {'FF1','FF2','FF3'}.
    f_c_hz : float
        Carrier frequency [Hz].
    lat0_deg, lon0_deg : float
        Starting point [deg].
    alt_m : float
        Fixed altitude [m].
    max_iter : int
        Upper bound on the number of Gauss–Newton iterations.
    tol : float
        Stop once the norm of the (lat, lon) update [rad] is below this value.

    Returns
    -------
    lat_est_deg, lon_est_deg : float
        Estimated emitter latitude / longitude [deg].
    J_est : (2,2) ndarray
        FDOA Jacobian wrt (lat, lon) [rad], evaluated at the returned point.
    """
    measured = np.array([fD21_hat, fD31_hat], dtype=float)

    # The iteration runs in radians; degrees are only used at the interface.
    lat_r = np.deg2rad(lat0_deg)
    lon_r = np.deg2rad(lon0_deg)

    for _ in range(max_iter):
        predicted = np.array(
            fdoa_model_latlon(lat_r, lon_r, r_sats, v_sats, f_c_hz, alt_m),
            dtype=float,
        )
        misfit = measured - predicted          # residual in FDOA space

        jac = fdoa_jacobian_latlon(lat_r, lon_r, r_sats, v_sats, f_c_hz, alt_m)

        # Normal equations (unweighted): (JᵀJ) step = Jᵀ misfit
        normal_mat = jac.T @ jac
        rhs = jac.T @ misfit
        try:
            step = np.linalg.solve(normal_mat, rhs)
        except np.linalg.LinAlgError:
            # Singular normal matrix: use the pseudo-inverse instead
            step = np.linalg.pinv(normal_mat) @ rhs

        lat_r = lat_r + step[0]
        lon_r = lon_r + step[1]

        if np.linalg.norm(step) < tol:
            break

    # Jacobian at the final point, returned for covariance/ellipse work
    jac_final = fdoa_jacobian_latlon(lat_r, lon_r, r_sats, v_sats, f_c_hz, alt_m)

    return np.rad2deg(lat_r), np.rad2deg(lon_r), jac_final


#### FDOA Ellipse Derivation

In [None]:
def theoretical_doppler_sigma(
    fs: float,
    B_sig: float,
    Tp: float,
    snr_db_rx1: float,
    snr_db_rx2: float,
    M: int,
    Bn: float | None = None,
) -> float:
    """
    FDOA CRLB-like σ_v from Ulman & Geraniotis Eq. (6),(8).

    σ_v ≈ 0.55 / (T * sqrt(Bn * T * γ_eff))

    where:
      - T   = M * Tp      (coherent integration time)
      - Bn  = noise bandwidth [Hz]
      - γ_eff from Eq. (8) with per-receiver SNRs γ1, γ2 in Bn
    """
    if Bn is None:
        Bn = fs / 2.0

    # Per-receiver SNRs in linear scale
    gamma1 = 10.0 ** (snr_db_rx1 / 10.0)
    gamma2 = 10.0 ** (snr_db_rx2 / 10.0)

    eps = 1e-15
    gamma1 = max(gamma1, eps)
    gamma2 = max(gamma2, eps)

    # Eq. (8): effective SNR in CAF
    inv_gamma_eff = 0.5 * (1.0 / gamma1 + 1.0 / gamma2 + 2.0 / (gamma1 * gamma2))
    gamma_eff = 1.0 / inv_gamma_eff

    # Azaria–Hertz correction when noise bandwidth > signal bandwidth
    if Bn > B_sig:
        gamma_eff *= (B_sig / Bn)

    # Integration time
    T = M * Tp

    # Eq. (5) doppler std dev [Hz]
    sigma_v = 0.55 / (T * np.sqrt(Bn * T * gamma_eff))
    return sigma_v


In [None]:
def FDOA_covariance_matrix(
    fs: float,
    F_LO: float,
    f_center: float,
    B_sig: float,
    PW: float,
    snr_db_rx1: float,
    snr_db_rx2: float,
    snr_db_rx3: float,
    M: int,
    error_cfg: dict,
    df_step: float,
    mode: str = "cw",
) -> np.ndarray:
    """
    Assemble the 2×2 covariance Σ_f of the FDOA pair (fD21, fD31).

    Only ``mode="cw"`` is supported. The construction is:

      1. Doppler CRLBs from ``theoretical_doppler_sigma`` for the real pairs
         (FF1–FF2, FF1–FF3) and for a synthetic FF1–FF1 pair.
      2. Per-receiver core variances inferred from those:
             sigma1_sq = 0.5 * sigma_v_11^2
             sigma2_sq = sigma_v_21^2 - sigma1_sq
             sigma3_sq = sigma_v_31^2 - sigma1_sq
      3. Per-receiver additions: fd_meas_err_rms² + lo_err_rms_hz².
      4. Pairwise-only additions: satellite velocity error mapped to Doppler
         at the carrier F_LO + f_center, plus Doppler-grid quantization
         (df_step / √12)².
      5. Result:
             Σ_f = [[sigma1_sq + sigma2_sq + extra,  sigma1_sq],
                    [sigma1_sq,  sigma1_sq + sigma3_sq + extra]]

    Parameters
    ----------
    fs, B_sig, PW, M
        Forwarded to ``theoretical_doppler_sigma``.
    F_LO, f_center : float
        Summed to obtain the carrier frequency used for the velocity mapping.
    snr_db_rx1, snr_db_rx2, snr_db_rx3 : float
        Per-receiver SNRs [dB] for FF1, FF2, FF3.
    error_cfg : dict
        Must provide "fd_meas_err_rms", "lo_err_rms_hz" and "vel_err_rms".
    df_step : float
        Doppler grid spacing used for the quantization term.
    mode : str
        Must be "cw".

    Returns
    -------
    Sigma_v : (2,2) ndarray
        Covariance of (fD21, fD31).

    Raises
    ------
    NotImplementedError
        If ``mode`` is anything other than "cw".
    """
    if mode != "cw":
        # Placeholder for LFMCW or other modes later
        raise NotImplementedError("FDOA_covariance_matrix: mode != 'cw' not implemented yet")

    # --- 1) Core Doppler CRLBs: synthetic FF1–FF1 first, then the real pairs ---
    crlb_11, crlb_21, crlb_31 = (
        theoretical_doppler_sigma(fs, B_sig, PW, snr_db_rx1, partner_snr, M=M)
        for partner_snr in (snr_db_rx1, snr_db_rx2, snr_db_rx3)
    )

    # --- 2) Split the pairwise CRLBs into per-receiver shares ---
    rx1_var = 0.5 * crlb_11**2
    rx2_var = crlb_21**2 - rx1_var
    rx3_var = crlb_31**2 - rx1_var

    # --- 3) Per-receiver measurement + LO errors ---
    fd_meas_var = error_cfg["fd_meas_err_rms"]**2
    lo_var = error_cfg["lo_err_rms_hz"]**2
    per_rx_extra = fd_meas_var + lo_var

    rx1_var = rx1_var + per_rx_extra
    rx2_var = rx2_var + per_rx_extra
    rx3_var = rx3_var + per_rx_extra

    # --- 4) Pairwise-only errors ---
    # Velocity error [m/s] maps to Doppler [Hz] through f_c / c
    carrier_hz = F_LO + f_center
    vel_rms = error_cfg["vel_err_rms"]
    vel_pair_rms = np.sqrt(2.0/3) * vel_rms
    vel_var_hz2 = (carrier_hz / c)**2 * vel_pair_rms**2

    # Doppler grid quantization (1/√12 rule)
    grid_var = (df_step / np.sqrt(12.0))**2

    pair_extra = vel_var_hz2 + grid_var

    # --- 5) Assemble: receiver 1 is common to both pairs, hence the covariance ---
    var_21 = rx1_var + rx2_var + pair_extra
    var_31 = rx1_var + rx3_var + pair_extra

    return np.array([
        [var_21,  rx1_var],
        [rx1_var, var_31 ],
    ])


#### FDOA Wrapper

In [None]:
def geolocate_and_cep_from_fdoa(
    fD21_hat: float,
    fD31_hat: float,
    Sigma_f: np.ndarray,       # 2×2 FDOA covariance
    r_sats: dict,
    v_sats: dict,
    f_c_hz: float,
    lat0_deg: float,
    lon0_deg: float,
    alt_m: float = 0.0,
    p: float = 0.5
) -> dict:
    """
    FDOA-only geolocation followed by error-ellipse extraction.

    Steps, in order (same structure as ``geolocate_and_cep_from_tdoa``):
      1. ``geolocate_from_fdoa`` — Gauss–Newton fix at fixed altitude.
      2. ``position_covariance_latlon`` — propagate Σ_f through the Jacobian.
      3. ``latlon_cov_to_EN_cov_wgs84`` — express the covariance in East/North.
      4. ``ellipse_from_cov_EN`` — semi-axes and orientation at probability ``p``.

    Returns
    -------
    dict
        Keys "lat_est_deg", "lon_est_deg", "Sigma_latlon", "Sigma_EN",
        "a", "b", "angle_deg" and "p".
    """
    # 1) Position fix and the Jacobian at that fix
    lat_hat, lon_hat, jac_hat = geolocate_from_fdoa(
        fD21_hat, fD31_hat, r_sats, v_sats, f_c_hz, lat0_deg, lon0_deg, alt_m=alt_m
    )

    # 2) Covariance in (lat, lon)
    cov_latlon = position_covariance_latlon(jac_hat, Sigma_f)

    # 3) Same covariance in the local East/North plane at the estimate
    cov_en = latlon_cov_to_EN_cov_wgs84(cov_latlon, np.deg2rad(lat_hat))

    # 4) Ellipse for the requested probability
    semi_major, semi_minor, tilt_deg = ellipse_from_cov_EN(cov_en, p=p)

    result = {}
    result["lat_est_deg"] = lat_hat
    result["lon_est_deg"] = lon_hat
    result["Sigma_latlon"] = cov_latlon
    result["Sigma_EN"] = cov_en
    result["a"] = semi_major
    result["b"] = semi_minor
    result["angle_deg"] = tilt_deg
    result["p"] = p
    return result


### 9.3 TDOA (for LFMCWs)

In [None]:
def lfmcw_loopback_refine(
    iq_FF1, iq_FF2, iq_FF3,
    fs,
    lfmcw_row,
    r_sats, v_sats,
    F_LO,
    # initial AOI-centre guess
    lat0_deg, lon0_deg,
    n_iter=4,
    tau_search_us=100.0,
    fd_search_hz=10000.0,
    tau_guard_extra_us=20.0,
    shrink_tau=0.35,
    shrink_fd=0.50,
    shrink_guard=0.70,
    caf_batch_size=256,
):
    """
    Iteratively tighten the CAF search around a TDOA-derived position.

    One pass does:
      1. ``test_lfmcw_caf_blind_geom`` centred on the current proxy location,
         using the current search/guard windows.
      2. ``geolocate_from_tdoa`` on the resulting (τ21, τ31), started from the
         previous estimate (or the initial guess on the first pass).
      3. Record a diagnostic snapshot (windows are logged *before* shrinking).
      4. Use the new estimate as next proxy / start point and multiply the
         three windows by their shrink factors.

    Parameters
    ----------
    iq_FF1, iq_FF2, iq_FF3, fs, lfmcw_row, r_sats, v_sats, F_LO
        Passed through unchanged to ``test_lfmcw_caf_blind_geom``
        (``r_sats`` also goes to ``geolocate_from_tdoa``).
    lat0_deg, lon0_deg : float
        Initial location guess [deg].
    n_iter : int
        Number of passes (converted with ``int``).
    tau_search_us, fd_search_hz, tau_guard_extra_us : float
        Initial CAF window sizes for the first pass.
    shrink_tau, shrink_fd, shrink_guard : float
        Per-pass multipliers applied to the three windows.
    caf_batch_size : int
        Forwarded to the CAF routine.

    Returns
    -------
    final : dict or None
        CAF estimates and geolocated lat/lon of the last pass; None when no
        pass was executed.
    debug : list of dict
        One diagnostic record per pass.
    """
    caf_keys = ("tau21_hat", "tau31_hat", "fD21_hat", "fD31_hat")

    # CAF centre and Gauss–Newton start both begin at the caller's guess
    proxy_lla = {"lat": float(lat0_deg), "lon": float(lon0_deg), "alt": 0.0}
    gn_lat, gn_lon = float(lat0_deg), float(lon0_deg)

    debug = []
    final = None

    for pass_no in range(1, int(n_iter) + 1):

        # --- 1) CAF around the current proxy ---
        caf = test_lfmcw_caf_blind_geom(
            iq_FF1, iq_FF2, iq_FF3,
            fs,
            lfmcw_row,
            proxy_lla,
            r_sats, v_sats,
            F_LO,
            tau_guard_extra_us=tau_guard_extra_us,
            tau_search_us=tau_search_us,
            fd_search_hz=fd_search_hz,
            caf_batch_size=caf_batch_size,
        )
        meas = {key: float(caf[key]) for key in caf_keys}

        # --- 2) TDOA geolocation ---
        lat_fix, lon_fix, jac_fix = geolocate_from_tdoa(
            meas["tau21_hat"], meas["tau31_hat"],
            r_sats,
            gn_lat, gn_lon,
            alt_m=0.0,
            max_iter=100,
            tol=1e-11
        )

        # Pull the longitude back by one turn if it left [-180, 180]
        if lon_fix > 180.0:
            lon_fix = lon_fix - 360.0
        elif lon_fix < -180.0:
            lon_fix = lon_fix + 360.0

        fix = {"lat_est_deg": float(lat_fix), "lon_est_deg": float(lon_fix)}
        meta = caf.get("meta")

        # --- 3) Diagnostic snapshot (windows as used in this pass) ---
        debug.append({
            "iter": pass_no,
            "proxy_lat": proxy_lla["lat"],
            "proxy_lon": proxy_lla["lon"],
            "tau_search_us": tau_search_us,
            "fd_search_hz": fd_search_hz,
            "tau_guard_extra_us": tau_guard_extra_us,
            **meas,
            **fix,
            "J_est": jac_fix,
            "meta": meta,
        })

        # --- 4) Feed the estimate back and narrow the windows ---
        proxy_lla = {"lat": fix["lat_est_deg"], "lon": fix["lon_est_deg"], "alt": 0.0}
        gn_lat, gn_lon = fix["lat_est_deg"], fix["lon_est_deg"]

        tau_search_us = tau_search_us * shrink_tau
        fd_search_hz = fd_search_hz * shrink_fd
        tau_guard_extra_us = tau_guard_extra_us * shrink_guard

        final = {**meas, **fix, "meta": meta}

    return final, debug


### Common Functions (Used by both TDOA and FDOA for CEP Ellipse Derivation)

In [None]:
def position_covariance_latlon(
    J_latlon: np.ndarray,
    Sigma_meas: np.ndarray
) -> np.ndarray:
    """
    Propagate a measurement covariance to (lat, lon) via the Fisher information.

        F   = Jᵀ Σ_z⁻¹ J
        Σ_x = F⁻¹

    Nothing here is specific to the measurement type, so the same routine
    serves TDOA (Σ_z = Σ_tau) and FDOA (Σ_z = Σ_fD).

    Parameters
    ----------
    J_latlon : (2,2) ndarray
        Jacobian of the measurement vector wrt [lat, lon].
    Sigma_meas : (2,2) ndarray
        Covariance of the measurement vector.

    Returns
    -------
    Sigma_latlon : (2,2) ndarray
        Covariance of [lat, lon], in the squared units of the Jacobian's
        state variables (rad² when J is taken wrt radians).
    """
    weight = np.linalg.inv(Sigma_meas)          # Σ_z⁻¹
    fisher = J_latlon.T @ weight @ J_latlon     # information matrix
    return np.linalg.inv(fisher)


In [None]:
def latlon_cov_to_EN_cov(
    Sigma_latlon: np.ndarray,
    lat_rad: float
) -> np.ndarray:
    """
    Map a (lat, lon) covariance [rad^2] to local East/North [m^2], spherical Earth.

    The Earth is treated as a sphere of radius ``WGS84_A``, giving the local
    linearisation

        dE ≈ R_E * cos(lat) * dlon
        dN ≈ R_E * dlat

    i.e. [dE, dN]ᵀ = G [dlat, dlon]ᵀ and Σ_EN = G Σ_latlon Gᵀ.

    Parameters
    ----------
    Sigma_latlon : (2,2) ndarray
        Covariance of [lat, lon] in radians^2 (state order: lat, then lon).
    lat_rad : float
        Latitude [rad] at which the scaling is evaluated.

    Returns
    -------
    Sigma_EN : (2,2) ndarray
        Covariance in the local East/North plane [m^2].
    """
    east_per_dlon = WGS84_A * np.cos(lat_rad)   # metres of East per radian of longitude
    north_per_dlat = WGS84_A                    # metres of North per radian of latitude

    # Rows are (E, N); columns are (lat, lon)
    to_en = np.array([
        [0.0, east_per_dlon],
        [north_per_dlat, 0.0],
    ])

    return to_en @ Sigma_latlon @ to_en.T

def latlon_cov_to_EN_cov_wgs84(
    Sigma_latlon: np.ndarray,
    lat_rad: float
) -> np.ndarray:
    """
    Map a (lat, lon) covariance [rad^2] to local East/North [m^2] on the
    WGS-84 ellipsoid.

    Uses the two radii of curvature at latitude phi:

        N(phi) = a / sqrt(1 - e² sin²phi)            (prime vertical)
        M(phi) = a (1 - e²) / (1 - e² sin²phi)^(3/2) (meridian)

    so that

        dE ≈ N(phi) * cos(phi) * dlon
        dN ≈ M(phi) * dlat

    Parameters
    ----------
    Sigma_latlon : (2,2) ndarray
        Covariance of [lat, lon] in radians^2 (state order: lat, then lon).
    lat_rad : float
        Latitude [rad] at which the radii are evaluated.

    Returns
    -------
    (2,2) ndarray
        Covariance in the local East/North plane [m^2].
    """
    root = np.sqrt(1.0 - WGS84_E2 * np.sin(lat_rad)**2)

    prime_vertical = WGS84_A / root
    meridian = WGS84_A * (1.0 - WGS84_E2) / (root**3)

    # Rows are (E, N); columns are (lat, lon)
    to_en = np.array([
        [0.0, prime_vertical * np.cos(lat_rad)],
        [meridian, 0.0],
    ], dtype=float)

    return to_en @ Sigma_latlon @ to_en.T


In [None]:
def ellipse_from_cov_EN(
    Sigma_EN: np.ndarray,
    p: float = 0.5
) -> tuple[float, float, float]:
    """
    Turn an East/North covariance into a probability-p error ellipse.

    The covariance eigenvalues λ1 >= λ2 are scaled by the 2-dof chi-square
    quantile, which has the closed form

        k_p = -2 ln(1 - p)

    giving semi-axes

        a = sqrt(k_p * λ1),   b = sqrt(k_p * λ2).

    Parameters
    ----------
    Sigma_EN : (2,2) ndarray
        Covariance in East/North [m^2].
    p : float
        Probability enclosed by the ellipse (e.g. 0.5 or 0.95).

    Returns
    -------
    a : float
        Semi-major axis [m].
    b : float
        Semi-minor axis [m].
    angle_deg : float
        Direction of the major-axis eigenvector, in degrees CCW from East.
        The eigenvector sign is not normalised here, so treat the angle
        modulo 180°.
    """
    lams, vecs = np.linalg.eigh(Sigma_EN)

    # Largest eigenvalue first
    order = np.argsort(lams)[::-1]
    lam_major, lam_minor = lams[order]
    major_dir = vecs[:, order[0]]          # [E_component, N_component]

    # 2-dof chi-square quantile: invert F(k) = 1 - exp(-k/2)
    chi2_scale = -2.0 * np.log(1.0 - p)

    semi_major = np.sqrt(chi2_scale * lam_major)
    semi_minor = np.sqrt(chi2_scale * lam_minor)

    tilt_deg = np.rad2deg(np.arctan2(major_dir[1], major_dir[0]))

    return semi_major, semi_minor, tilt_deg

In [None]:
def cep_from_axes(a, b, p=0.5):
    """
    Circular CEP(p) radius derived from ellipse semi-axes a and b.

        CEP(p) = sqrt( -2 (a^2 + b^2) ln(1 - p) )

    NOTE(review): the probability scaling -2 ln(1 - p) is applied here, so
    this presumably expects unscaled (1-sigma) axes — confirm against callers,
    since axes from ``ellipse_from_cov_EN`` already carry that factor.

    Parameters
    ----------
    a, b : floats
        Semi-major and semi-minor axes (meters).
    p : float
        Probability level (e.g., 0.5 for CEP50, 0.9 for CEP90, 0.95 for CEP95)

    Returns
    -------
    cep_radius : float
        Circular-equivalent CEP radius (meters)
    """
    sum_sq = a * a + b * b
    log_tail = np.log(1.0 - p)
    return np.sqrt(-2.0 * sum_sq * log_tail)

In [None]:
# ===========================================================
#  CEP ellipse containment checker: WITH LLA INPUT
# ===========================================================
from scipy.stats import chi2
def gt_in_cep_ellipse_lla_percentile(lat_est, lon_est,
                                     lat_gt, lon_gt,
                                     a, b, theta_deg,
                                     cep_percent=50,
                                     alt_est=0.0, alt_gt=0.0):
    """
    Check whether the ground truth LLA lies inside the ellipse for an
    arbitrary percentile (50, 90, 95, 99, ...), centered at the estimated LLA.

    The ground truth is expressed in the ENU frame of the estimate, rotated
    into the ellipse frame, and its normalised squared distance
    (x/a)² + (y/b)² is compared with the 2-dof chi-square quantile at
    ``cep_percent``.

    NOTE(review): comparing against a chi-square quantile implies a, b are
    1-sigma semi-axes. Axes from ``ellipse_from_cov_EN`` are already scaled
    by sqrt(k_p); confirm what callers pass in.

    Parameters
    ----------
    lat_est, lon_est : float
        Estimated location (deg). Used as ENU origin.

    lat_gt, lon_gt : float
        Ground truth location (deg).

    a, b : float
        Semi-major and semi-minor axes (meters)

    theta_deg : float
        Ellipse orientation angle (degrees), CCW

    cep_percent : float
        Percentile of the containment test, in percent (default 50).

    alt_est, alt_gt : float
        Altitudes (meters). Default=0.

    Returns
    -------
    d2 : float
        Mahalanobis distance squared.

    inside : bool
        True if GT lies inside ellipse (d2 <= threshold).

    threshold : float
        Chi-square (2 dof) quantile that d2 was compared against.
    """

    # Ground truth in ECEF, then in ENU relative to the estimate.
    # (The estimate itself is the ENU origin, so it needs no ECEF conversion.)
    ecef_gt = lla_to_ecef(lat_gt, lon_gt, alt_gt)
    gt_enu = ecef_to_enu(ecef_gt, lat_est, lon_est, alt_est)
    gx, gy = gt_enu[0], gt_enu[1]

    # Rotate ENU axes by -theta (ellipse frame)
    theta = np.deg2rad(theta_deg)
    xr =  gx * np.cos(theta) + gy * np.sin(theta)
    yr = -gx * np.sin(theta) + gy * np.cos(theta)

    # Mahalanobis distance squared
    d2 = (xr / a)**2 + (yr / b)**2

    # Compute χ² threshold for the desired percentile
    p = cep_percent / 100.0
    threshold = chi2.ppf(p, df=2)

    inside = d2 <= threshold
    return d2, inside, threshold
