In [7]:
import os
import sys
import numpy as np
from functools import lru_cache
from typing import List, Tuple, Optional
from scipy.stats import ks_2samp, skew
from skimage.filters import gabor_kernel

# ------------------------------------------------------------
# Robust import of audio helpers 
# ------------------------------------------------------------

# Remains None when the optional audio helper cannot be imported, in which
# case endpoint smoothing is simply skipped.
audio_geometric_endpoint_bands = None


def _try_import_audio_utils():
    """
    Best-effort import of `geometric_endpoint_bands` from `transform_audio.py`.

    Looks for a `utilities` directory next to / one level above this file and
    the current working directory. Every such directory that contains
    `transform_audio.py` and is not yet on `sys.path` is pushed to the front of
    `sys.path`. On a successful import the module-level
    `audio_geometric_endpoint_bands` is set; on any failure a warning is
    printed and the global is left untouched.
    """
    global audio_geometric_endpoint_bands

    try:
        here = os.path.dirname(__file__)
    except NameError:
        # `__file__` is undefined in interactive sessions / notebooks.
        here = os.getcwd()
    cwd = os.getcwd()

    search_dirs = (
        os.path.join(here, "..", "utilities"),
        os.path.join(here, "utilities"),
        os.path.join(cwd, "utilities"),
        os.path.join(cwd, "..", "utilities"),
    )
    for raw_dir in search_dirs:
        folder = os.path.normpath(raw_dir)
        if folder in sys.path:
            continue
        if os.path.isfile(os.path.join(folder, "transform_audio.py")):
            sys.path.insert(0, folder)

    try:
        from transform_audio import geometric_endpoint_bands as _ge_bands
    except Exception as e:
        print("[warn] Could not import ../utilities/transform_audio.py; "
              "geometric endpoint smoothing disabled.\n  details:", e)
    else:
        audio_geometric_endpoint_bands = _ge_bands


_try_import_audio_utils()

# ============================================================
#                Gabor Kernel Generation (cached)
# ============================================================

@lru_cache(maxsize=8192)
def _generate_gabor_kernel_skimage_cached(frequency: float,
                                          wave_number: int,
                                          theta: float,
                                          aspect_ratio: float,
                                          dtype_str: str) -> np.ndarray:
    """
    Build (and memoize) a real-valued, zero-mean, L2-normalized Gabor kernel.

    Args:
      frequency:    spatial frequency in cycles/pixel; must be > 0.
      wave_number:  sets the envelope width via sigma_x = wave_number / (4 * frequency).
      theta:        orientation passed straight to `gabor_kernel`.
      aspect_ratio: sigma_y / sigma_x; must be > 0.
      dtype_str:    NumPy dtype name of the returned array (a string so that the
                    argument tuple is hashable for `lru_cache`).

    Returns:
      Read-only 2D array: the real part of the kernel with its mean removed and,
      unless its norm is ~0, scaled to unit L2 norm. The array is shared by
      every cache hit, so it is marked non-writeable; copy it before mutating.

    Raises:
      ValueError: if `frequency`, the derived `sigma_x`, or the derived
                  `sigma_y` is not strictly positive (NaN is rejected too).
    """
    # `not x > 0` (rather than `x <= 0`) also rejects NaN.
    if not frequency > 0:
        raise ValueError("frequency must be positive (cycles/pixel).")

    sigma_x = float(wave_number) / (4.0 * float(frequency))  # wave_number ≈ 4σ·f
    if not sigma_x > 0:
        raise ValueError("Computed sigma_x must be positive; check wave_number and frequency.")
    sigma_y = float(aspect_ratio) * sigma_x
    if not sigma_y > 0:
        raise ValueError("Computed sigma_y must be positive; check aspect_ratio.")

    k_complex = gabor_kernel(
        frequency=float(frequency),  # cycles/pixel
        theta=float(theta),
        sigma_x=float(sigma_x),
        sigma_y=float(sigma_y)
    )
    k = np.real(k_complex).astype(np.dtype(dtype_str), copy=False)
    k -= k.mean()
    nrm = np.linalg.norm(k.ravel())
    if nrm > 1e-12:  # leave a (numerically) all-zero kernel unscaled
        k = k / nrm
    # The same object is handed out on every cache hit: forbid in-place edits.
    k.setflags(write=False)
    return k

def generate_gabor_kernel_skimage(frequency: float,
                                  wave_number: int,
                                  theta: float = 0.0,
                                  aspect_ratio: float = 1.0,
                                  dtype=np.float64) -> np.ndarray:
    """
    Public entry point for Gabor kernel generation.

    Normalizes the arguments to plain hashable Python values, fetches the
    kernel from the memoized builder, and hands back an independent copy so
    that callers can never modify the cached array.
    """
    cache_args = (
        float(frequency),
        int(wave_number),
        float(theta),
        float(aspect_ratio),
        np.dtype(dtype).name,  # dtype objects are passed by name to the cache
    )
    shared_kernel = _generate_gabor_kernel_skimage_cached(*cache_args)
    return shared_kernel.copy()

# ============================================================
#         Coefficient loading (Gabor → flattened vectors)
# ============================================================

def _standardize_vec(v: np.ndarray, standardize: bool) -> np.ndarray:
    """
    Optionally rescale a flattened kernel by sqrt(number of entries).

    Because the kernels are zero-mean with unit L2 norm, multiplying by
    sqrt(n) gives their entries roughly unit variance, which makes kernels of
    different sizes comparable. The input is returned unchanged (same object)
    when `standardize` is false or the vector is empty.
    """
    if standardize and v.size != 0:
        return v * np.sqrt(v.size)
    return v

def _maybe_subsample(v: np.ndarray, sample_n: Optional[int], rng: np.random.Generator) -> np.ndarray:
    """
    Draw `sample_n` entries of `v` without replacement using `rng`.

    When `sample_n` is None or `v` has at most `sample_n` entries, no sampling
    takes place and `v` itself is returned.
    """
    if sample_n is not None and not (v.size <= sample_n):
        return rng.choice(v, size=int(sample_n), replace=False)
    return v

def load_coefs_by_freq_gabor(
    frequencies: np.ndarray,
    wave_number: int,
    theta: float = 0.0,
    aspect_ratio: float = 1.0,
    *,
    standardize: bool = True,
    sample_n: Optional[int] = None,
    cache: bool = False,
    random_state: int = 0,
) -> Tuple[List[np.ndarray], np.ndarray]:
    """
    Build one flattened Gabor coefficient vector per frequency.

    Args:
      frequencies:  iterable of frequencies (cycles/pixel); sorted internally.
      wave_number:  forwarded to the kernel generator.
      theta:        forwarded to the kernel generator.
      aspect_ratio: forwarded to the kernel generator.
      standardize:  if True, scale each vector by sqrt(its length).
      sample_n:     if set, randomly subsample each vector down to this many
                    entries (without replacement).
      cache:        if True, remember the most recent result on the function
                    object and reuse it when called again with the same
                    parameters. The cached list/arrays are returned as-is
                    (shared), so callers should not mutate them.
      random_state: seed for the subsampling RNG.

    Returns:
      coefs_by_freq: list of flattened, (optionally standardized/subsampled) coeff arrays per freq
      freqs_sorted:  sorted 1D numpy array of the input frequencies
    """
    rng = np.random.default_rng(random_state)
    freqs_sorted = np.array(sorted(map(float, frequencies)))

    cache_key = None
    if cache:
        # The seed only influences the result when subsampling is active, so it
        # is part of the key only then; leaving it out entirely would return a
        # stale subsample when the same grid is requested with another seed.
        seed_key = random_state if sample_n is not None else None
        cache_key = (
            tuple(np.round(freqs_sorted, 12)),  # rounding absorbs float noise
            wave_number, theta, aspect_ratio, standardize, sample_n, seed_key,
        )
        if getattr(load_coefs_by_freq_gabor, 'cache_key', None) == cache_key:
            return load_coefs_by_freq_gabor.cached_coefs, load_coefs_by_freq_gabor.cached_freqs

    coefs_by_freq: List[np.ndarray] = []
    for f in freqs_sorted:
        k = generate_gabor_kernel_skimage(f, wave_number, theta=theta, aspect_ratio=aspect_ratio).ravel()
        k = _standardize_vec(k, standardize=standardize)
        k = _maybe_subsample(k, sample_n=sample_n, rng=rng)
        coefs_by_freq.append(k.astype(np.float64, copy=True))

    if cache:
        # Single-slot cache: only the latest parameter set is kept.
        load_coefs_by_freq_gabor.cache_key = cache_key
        load_coefs_by_freq_gabor.cached_coefs = coefs_by_freq
        load_coefs_by_freq_gabor.cached_freqs = freqs_sorted

    return coefs_by_freq, freqs_sorted

# ============================================================
#                   Skew Prefilter (OPTIONAL)
# ============================================================

def filter_frequencies_by_skew(
    frequencies,
    wave_number,
    theta: float = 0.0,
    aspect_ratio: float = 1.0,
    tau: float = 0.03,
    sample_n: int = 4096,
    random_state: int = 0
):
    """
    Keep freqs whose kernel coefficients have |skew| <= tau.

    For every frequency a Gabor kernel is generated and flattened; if it has
    more than `sample_n` entries, `sample_n` of them are drawn without
    replacement (seeded by `random_state`) before the skewness is computed.

    Returns:
      (keep, drop): two lists of floats in input order. A frequency whose
      skewness is undefined (NaN) fails the `<= tau` test and is dropped.
    """
    rng = np.random.default_rng(random_state)
    keep, drop = [], []
    for f in map(float, frequencies):
        v = generate_gabor_kernel_skimage(f, wave_number, theta=theta, aspect_ratio=aspect_ratio).ravel()
        if v.size > sample_n:
            v = rng.choice(v, size=sample_n, replace=False)
        # Test the keep condition directly: `abs(nan) <= tau` is False, so a
        # NaN skew lands in `drop` instead of silently passing the filter.
        (keep if abs(skew(v)) <= tau else drop).append(f)
    return keep, drop

# ============================================================
#          KS-recursive grouping 
# ============================================================

def freq_band_groupings_gabor(
    frequencies: np.ndarray,
    wave_number: int,
    theta: float = 0.0,
    aspect_ratio: float = 1.0,
    *,
    ks_threshold: float = 0.05,
    presplit_depth: int = 1,
    max_depth: Optional[int] = None,
    standardize: bool = True,
    sample_n: Optional[int] = None,
    cache: bool = False,
    debug: bool = False,
    random_state: int = 0,
) -> Tuple[List[Tuple[int, int]], np.ndarray]:
    """
    Same decision rule as audio `freq_band_groupings`:
      - Split [left,right) at midpoint
      - KS test on concatenated left vs right coefficient pools
      - If KS statistic < ks_threshold ⇒ treat as homogeneous ⇒ stop splitting
      - Else recurse on both halves

    Args:
      frequencies, wave_number, theta, aspect_ratio, standardize, sample_n,
      cache, random_state: forwarded to `load_coefs_by_freq_gabor`.
      ks_threshold:   KS statistic below which a range is kept as one band.
      presplit_depth: recursion levels that are always split without testing.
      max_depth:      optional recursion depth at which splitting stops.
      debug:          print each comparison as the recursion proceeds.

    Returns:
      bands: list of (start_idx, end_idx) on the sorted frequency array (right-exclusive)
      freqs: the sorted frequency array
      For an empty `frequencies` input, `bands` is an empty list.
    """
    coefs_by_freq, freqs = load_coefs_by_freq_gabor(
        frequencies,
        wave_number,
        theta=theta,
        aspect_ratio=aspect_ratio,
        standardize=standardize,
        sample_n=sample_n,
        cache=cache,
        random_state=random_state,
    )
    n = len(coefs_by_freq)
    if n == 0:
        # Nothing to group. Without this guard the recursion below would work
        # on the empty range [0, 0): it never hits the single-element stop
        # condition and ends in np.concatenate([]) raising (or, for a large
        # presplit_depth, in unbounded recursion).
        return [], freqs

    def helper(left: int, right: int, depth: int) -> List[Tuple[int, int]]:
        # Stop at a single frequency or at the configured depth limit.
        if left + 1 == right or (max_depth is not None and depth == max_depth):
            return [(left, right)]
        mid = (left + right) // 2
        if debug:
            print(f'{"  " * depth}[{left},{mid}) vs [{mid},{right}): ', end='')
        if depth >= presplit_depth:
            left_vals = np.concatenate(coefs_by_freq[left:mid])
            right_vals = np.concatenate(coefs_by_freq[mid:right])
            # Two-sided test with automatic method selection are the SciPy
            # defaults; relying on them avoids the `mode` keyword, which newer
            # SciPy releases have renamed to `method`.
            stat, pval = ks_2samp(left_vals, right_vals)
            if debug:
                print(f'KS={stat:.5f}, p={pval:.3g}')
            if stat < ks_threshold:  # homogeneous → merge
                return [(left, right)]
        elif debug:
            print('presplit')
        return helper(left, mid, depth + 1) + helper(mid, right, depth + 1)

    return helper(0, n, 0), freqs

def apply_audio_geometric_endpoint_smoothing(bands: List[Tuple[int,int]],
                                             freqs: np.ndarray) -> List[Tuple[int,int]]:
    """
    Post-process `bands` with the audio utility's geometric endpoint smoothing.

    If the audio helper was not importable (the module-level
    `audio_geometric_endpoint_bands` is None) the bands are returned unchanged.
    The helper is always invoked with `visualize=False`, i.e. without plotting.
    """
    smoother = audio_geometric_endpoint_bands
    if smoother is not None:
        return smoother(bands, freqs, visualize=False)
    return bands

# ============================================================
#                   Print-only runner (counts)
# ============================================================

def run_and_print_band_counts(
    frequencies_to_test: np.ndarray,
    wave_numbers: List[int],
    aspect_ratios: List[float],
    theta: float = 0.0,
    *,
    # Skew prefilter:
    use_skew_prefilter: bool = True,
    skew_tau: float = 0.03,
    skew_sample_n: int = 4096,
    # KS grouping params:
    ks_threshold: float = 0.05,
    presplit_depth: int = 1,
    max_depth: Optional[int] = None,
    standardize: bool = True,
    sample_n: Optional[int] = None,
    smooth_endpoints: bool = False,
    random_state: int = 42
) -> None:
    """
    Sweep every (wave number, aspect ratio) pair and print the band count.

    For each pair: optionally discard frequencies via the skew prefilter, run
    the KS-recursive grouping on what remains, optionally smooth the band
    endpoints, then print one line of the form
      "WN=<>, AR=<>: bands=<count>"
    A header line is printed once before the sweep. Nothing is returned.
    """
    freqs_input = np.array(sorted(map(float, frequencies_to_test)))

    print("\n=== Gabor Frequency Banding (counts only) ===")

    # Grouping options that are identical for every (wn, ar) combination.
    shared_grouping_opts = dict(
        theta=theta,
        ks_threshold=ks_threshold,
        presplit_depth=presplit_depth,
        max_depth=max_depth,
        standardize=standardize,
        sample_n=sample_n,
        cache=False,
        debug=False,
        random_state=random_state,
    )

    for wn in wave_numbers:
        for ar in aspect_ratios:
            if use_skew_prefilter:
                kept, _ = filter_frequencies_by_skew(
                    freqs_input, wn, theta=theta, aspect_ratio=ar,
                    tau=skew_tau, sample_n=skew_sample_n, random_state=random_state
                )
                candidate_freqs = np.array(sorted(kept))
                if candidate_freqs.size == 0:
                    print(f"WN={wn}, AR={ar}: bands=0 (all freqs dropped by skew)")
                    continue
            else:
                candidate_freqs = freqs_input

            bands, sorted_freqs = freq_band_groupings_gabor(
                candidate_freqs, wn, aspect_ratio=ar, **shared_grouping_opts
            )

            if smooth_endpoints:
                bands = apply_audio_geometric_endpoint_smoothing(bands, sorted_freqs)

            print(f"WN={wn}, AR={ar}: bands={len(bands)}")

# ============================================================
#                        __main__
# ============================================================

if __name__ == "__main__":
    # Demo sweep. Frequencies are in cycles/pixel and must stay inside
    # (0, 0.5) to respect the Nyquist limit.
    frequencies_to_test = np.linspace(0.05, 0.40, 350)

    # Wave number ≈ number of cycles under ~4σ of the Gaussian envelope.
    wave_numbers = [1, 2, 3, 4]
    aspect_ratios = [0.5, 1.0]

    run_options = dict(
        theta=0.0,
        use_skew_prefilter=False,  # skew prefilter is off here; set True to enable it
        skew_tau=0.03,
        skew_sample_n=4096,
        ks_threshold=0.05,
        presplit_depth=1,
        max_depth=None,
        standardize=True,
        sample_n=4096,            # None would use every coefficient; subsample for speed
        smooth_endpoints=False,   # needs the optional audio helper
        random_state=42,
    )
    run_and_print_band_counts(
        frequencies_to_test=frequencies_to_test,
        wave_numbers=wave_numbers,
        aspect_ratios=aspect_ratios,
        **run_options
    )


[warn] Could not import ../utilities/transform_audio.py; geometric endpoint smoothing disabled.
  details: No module named 'librosa'

=== Gabor Frequency Banding (counts only) ===
WN=1, AR=0.5: bands=341
WN=1, AR=1.0: bands=326
WN=2, AR=0.5: bands=165
WN=2, AR=1.0: bands=100
WN=3, AR=0.5: bands=50
WN=3, AR=1.0: bands=13
WN=4, AR=0.5: bands=18
WN=4, AR=1.0: bands=12
