# Omega computation helpers

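Utilities for computing vector resonant relaxation (VRR) precession-rate vectors (Omega) of stellar orbits with the `vrr_Omegas` module, including a process-parallel driver that evaluates each star against one or more stellar components.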

In [None]:
import numpy as np
import pandas as pd
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Dict, Iterable, Mapping, Optional, Sequence, Union

from vrr_Omegas import (
    Orbit,
    OrbitPair,
    ExactSeriesEvaluator,
    AsymptoticEvaluator,
    AsymptoticWithCorrectionsEvaluator,
)


In [None]:
# Every supported evaluation method, in the order returned when "all" is
# requested; the first entry alone is the fallback when no method is given.
_DEFAULT_METHOD_ORDER: tuple[str, ...] = ("exact", "asymptotic", "hybrid")


def _normalize_methods(methods: Any) -> Sequence[str]:
    """Turn a user-supplied method selection into a tuple of method names.

    ``methods`` may be ``None`` (selects only the first default method), a
    single string, or an iterable of names. Each name is converted with
    ``str``, stripped and lower-cased. Names that are not supported are
    skipped, and repeats keep only their first occurrence. Encountering
    ``"all"`` immediately selects every supported method.

    Raises:
        ValueError: if no supported method name remains after filtering.
    """
    if methods is None:
        return _DEFAULT_METHOD_ORDER[:1]

    candidates = [methods] if isinstance(methods, str) else list(methods)

    selected: list = []
    for raw in candidates:
        name = str(raw).strip().lower()
        if name == "all":
            return _DEFAULT_METHOD_ORDER
        if name not in _DEFAULT_METHOD_ORDER or name in selected:
            continue
        selected.append(name)

    if selected:
        return tuple(selected)
    raise ValueError("No valid Omega evaluation methods were provided.")


def _extract_vector(row: Mapping[str, Any]) -> Optional[np.ndarray]:
    """Read the ``(Lx, Ly, Lz)`` angular-momentum components from ``row``.

    Returns a length-3 ``float`` array. Returns ``None`` instead if any of
    the three keys is:

    - absent;
    - mapped to ``None``;
    - rejected by ``float()`` with ``TypeError``/``ValueError``;
    - not finite (NaN or infinity).
    """
    coords = []
    for key in ("Lx", "Ly", "Lz"):
        raw = row.get(key) if key in row else None
        if raw is None:
            return None
        try:
            number = float(raw)
        except (TypeError, ValueError):
            return None
        if not np.isfinite(number):
            return None
        coords.append(number)
    return np.array(coords, dtype=float)


def _build_orbit(row: Mapping[str, Any], G: float, default_central_mass: float) -> Orbit:
    """Create an :class:`Orbit` from one row of orbital data.

    ``row`` must provide ``a``, ``e`` and ``m``. Optional per-row ``G`` and
    ``M_central`` entries override the ``G`` and ``default_central_mass``
    arguments. All five values are converted with ``float``. When
    ``_extract_vector`` yields a complete, finite ``(Lx, Ly, Lz)``, those
    components are forwarded to :class:`Orbit` as well.
    """
    orbit_kwargs: Dict[str, Any] = dict(
        a=float(row["a"]),
        e=float(row["e"]),
        m=float(row["m"]),
        G=float(row.get("G", G)),
        M_central=float(row.get("M_central", default_central_mass)),
    )
    L_vec = _extract_vector(row)
    if L_vec is not None:
        orbit_kwargs["Lx"], orbit_kwargs["Ly"], orbit_kwargs["Lz"] = L_vec
    return Orbit(**orbit_kwargs)


def _build_evaluators(
    methods: Sequence[str],
    max_ell: int,
    method_options: Optional[Mapping[str, Mapping[str, Any]]] = None,
) -> Dict[str, Any]:
    """Instantiate one evaluator per requested method.

    Args:
        methods: Method names, each one of ``"exact"``, ``"asymptotic"`` or
            ``"hybrid"``.
        max_ell: Default ``ell_max``. A per-method ``"ell_max"`` entry in
            ``method_options`` overrides it. Values below 2 are raised to 2.
        method_options: Optional mapping from method name to extra keyword
            arguments for that evaluator's constructor.
            - Keys are matched after ``str().strip().lower()``, the same
              normalization applied to method names.
            - The ``"hybrid"`` entry may also set ``"lmax_correction"``
              (default ``min(ell_max, 4)``, raised to at least 2).
            - The caller's mappings are never mutated.

    Returns:
        Dict mapping each method name to its evaluator, in ``methods`` order.

    Raises:
        ValueError: if a method name is not supported.
    """
    # Normalize option keys like method names, and copy each mapping so that
    # popping the consumed keys below never touches the caller's data.
    options: Dict[str, Dict[str, Any]] = {}
    for name, value in (method_options or {}).items():
        options.setdefault(str(name).strip().lower(), {}).update(value)

    evaluator_classes = {
        "exact": ExactSeriesEvaluator,
        "asymptotic": AsymptoticEvaluator,
        "hybrid": AsymptoticWithCorrectionsEvaluator,
    }

    evaluators: Dict[str, Any] = {}
    for method in methods:
        evaluator_cls = evaluator_classes.get(method)
        if evaluator_cls is None:
            raise ValueError(f"Unsupported evaluation method: {method}")
        # Fresh copy per method: popping from the shared dict would make a
        # repeated method name lose its ell_max / lmax_correction settings.
        config = dict(options.get(method, {}))
        ell_max = max(2, int(config.pop("ell_max", max_ell)))
        if method == "hybrid":
            config["lmax_correction"] = max(
                2, int(config.pop("lmax_correction", min(ell_max, 4)))
            )
        evaluators[method] = evaluator_cls(ell_max=ell_max, **config)
    return evaluators


def _store_omega(
    result: Dict[str, Any], comp_name: str, method: str, vector: Iterable[float]
) -> None:
    """Write ``vector`` into ``result`` as ``Omega_{comp}_{method}_{x,y,z}`` floats."""
    for axis, value in zip(("x", "y", "z"), vector):
        result[f"Omega_{comp_name}_{method}_{axis}"] = float(value)


def _pair_omega_vector(pair: Any, omega_val: Any) -> np.ndarray:
    """Return one pair's Omega contribution as a float array.

    Array-like values (e.g. a 3-vector ndarray, list or tuple) are used
    directly. Scalars, including 0-d arrays, are turned into a vector with
    ``pair.omega_from_scalar``. Previously a 0-d array was broadcast into all
    three components, and a list/tuple vector crashed in ``float()``.
    """
    if np.ndim(omega_val) == 0:
        return np.asarray(pair.omega_from_scalar(float(omega_val)), dtype=float)
    return np.asarray(omega_val, dtype=float)


def compute_omega_for_star(args: tuple[Any, ...]) -> Dict[str, Any]:
    """Sum the pairwise Omega vectors acting on one star, per component and method.

    ``args`` is a single tuple, so the function can be passed straight to
    ``Executor.map``. It is unpacked as::

        (star_i, stars_j_components, G, max_ell, component_names,
         methods, method_options)

    where:

    - ``star_i`` is a mapping with at least ``ind``, ``a``, ``e`` and ``m``.
    - ``stars_j_components`` is a sequence of DataFrames, paired element-wise
      with ``component_names``.
    - ``methods`` and ``method_options`` are handled by
      ``_normalize_methods`` / ``_build_evaluators``.

    Rows of ``stars_j`` whose ``ind`` equals ``star_i["ind"]`` are skipped as
    self-interactions. Rows whose ``angular_momentum_vector`` is ``None`` or
    has zero norm are also skipped.

    Returns:
        A dict holding ``"ind"`` plus one float per
        ``Omega_{component}_{method}_{x|y|z}`` key. If star ``i``'s own
        ``angular_momentum_vector`` is ``None`` or has zero norm, every Omega
        entry is ``0.0``.
    """
    (
        star_i,
        stars_j_components,
        G,
        max_ell,
        component_names,
        methods,
        method_options,
    ) = args

    result: Dict[str, Any] = {"ind": star_i["ind"]}
    methods = _normalize_methods(methods)
    evaluators = _build_evaluators(methods, max_ell, method_options)

    # Star i's M_central (default 1.0) also serves as the fallback for
    # j-stars that have no M_central entry of their own.
    default_central_mass = float(star_i.get("M_central", 1.0))
    orbit_i = _build_orbit(star_i, G, default_central_mass)
    L_i_vec = orbit_i.angular_momentum_vector
    norm_L_i = 0.0 if L_i_vec is None else float(np.linalg.norm(L_i_vec))

    if norm_L_i <= 0.0:
        # No usable orientation for star i: emit zeros so the output columns
        # match those of the normal path.
        for comp_name in component_names:
            for method in methods:
                _store_omega(result, comp_name, method, (0.0, 0.0, 0.0))
        return result

    for stars_j, comp_name in zip(stars_j_components, component_names):
        totals = {method: np.zeros(3, dtype=float) for method in methods}

        # Plain dict records are much cheaper than iterrows() Series, and
        # _build_orbit accepts any mapping.
        for star_j in stars_j.to_dict("records"):
            # Skip self-interaction; a row without "ind" never matches.
            if star_i["ind"] == star_j.get("ind", object()):
                continue

            orbit_j = _build_orbit(star_j, G, default_central_mass)
            L_j_vec = orbit_j.angular_momentum_vector
            if L_j_vec is None:
                continue
            norm_L_j = float(np.linalg.norm(L_j_vec))
            if norm_L_j <= 0.0:
                continue

            # Clip to [-1, 1] to absorb floating-point round-off in the
            # normalized dot product.
            cos_theta = float(
                np.clip(np.dot(L_i_vec, L_j_vec) / (norm_L_i * norm_L_j), -1.0, 1.0)
            )
            pair = OrbitPair(orbit_i, orbit_j, cos_theta)

            for method, evaluator in evaluators.items():
                interaction = evaluator.evaluate_pair(pair)
                totals[method] += _pair_omega_vector(pair, interaction.omega)

        for method, vector in totals.items():
            _store_omega(result, comp_name, method, vector)

    return result


def compute_Omega_parallel(
    stars_i: pd.DataFrame,
    stars_j: Union[pd.DataFrame, Sequence[pd.DataFrame]],
    *,
    G: float = 1.0,
    max_ell: int = 100,
    components: Optional[Sequence[str]] = None,
    methods: Any = None,
    method_options: Optional[Mapping[str, Mapping[str, Any]]] = None,
    max_workers: Optional[int] = None,
) -> pd.DataFrame:
    """Compute Omega vectors for every star in ``stars_i`` using a process pool.

    Each row of ``stars_i`` is handled by :func:`compute_omega_for_star`
    against every row of each ``stars_j`` component.

    Args:
        stars_i: Target stars. Each row needs ``a``, ``e`` and ``m`` and may
            carry ``ind``, ``G``, ``M_central`` and ``Lx``/``Ly``/``Lz``.
            The frame is copied, not modified.
        stars_j: One DataFrame, or a sequence of DataFrames (one per
            component). These are copied, not modified.
        G: Gravitational constant used for rows without a ``G`` entry.
        max_ell: Default ``ell_max`` for the evaluators; ``method_options``
            can override it per method.
        components: One name per ``stars_j`` entry. A single string is taken
            as one name. Defaults to ``comp_0``, ``comp_1``, ...
        methods: Method selection accepted by ``_normalize_methods``
            (``None`` -> ``"exact"`` only).
        method_options: Per-method keyword arguments for the evaluators.
        max_workers: Passed to :class:`ProcessPoolExecutor`; ``None`` keeps
            the executor's default.

    Returns:
        The ``stars_i`` columns plus ``Omega_{component}_{method}_{x|y|z}``
        columns, indexed by ``ind``. An empty ``stars_i`` gives an empty
        frame indexed by ``ind``.

    Raises:
        ValueError: if the number of component names differs from the number
            of ``stars_j`` entries, or if no valid method is given.
    """
    import os

    stars_i_copy = stars_i.copy()
    if "ind" not in stars_i_copy.columns:
        stars_i_copy["ind"] = range(len(stars_i_copy))

    if isinstance(stars_j, pd.DataFrame):
        stars_j_list = [stars_j.copy()]
    else:
        stars_j_list = [df.copy() for df in stars_j]

    if components is None:
        component_names = [f"comp_{k}" for k in range(len(stars_j_list))]
    elif isinstance(components, str):
        # list("disk") would split the name into characters.
        component_names = [components]
    else:
        component_names = list(components)

    if len(component_names) != len(stars_j_list):
        raise ValueError("Number of component names must match stars_j entries.")

    for df in stars_j_list:
        if "ind" not in df.columns:
            # NOTE(review): auto-assigned ids 0..n-1 coincide with the
            # auto-assigned stars_i ids, so compute_omega_for_star skips
            # row k of stars_j when evaluating row k of stars_i. That is
            # correct when stars_j is the same population as stars_i, but
            # wrong otherwise. Confirm intent, or supply explicit "ind" columns.
            df["ind"] = range(len(df))

    methods = _normalize_methods(methods)

    args_list = [
        (
            star_i,
            stars_j_list,
            G,
            max_ell,
            component_names,
            methods,
            method_options,
        )
        for star_i in stars_i_copy.to_dict("records")
    ]

    if not args_list:
        # pd.DataFrame([]) has no "ind" column, so merging would raise KeyError.
        return stars_i_copy.set_index("ind")

    workers = max_workers or os.cpu_count() or 1
    # Each chunk is pickled as one work item, so the stars_j_list shared by
    # every task is serialized once per chunk instead of once per star.
    chunksize = max(1, len(args_list) // (4 * workers))

    # NOTE: worker processes must be able to import compute_omega_for_star.
    # Under the "spawn" start method, a function defined only in a notebook
    # cell presumably cannot be imported that way; move it into a module if so.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        results = list(
            executor.map(compute_omega_for_star, args_list, chunksize=chunksize)
        )

    omega_df = pd.DataFrame(results)
    merged = pd.merge(stars_i_copy, omega_df, on="ind", how="left")
    return merged.set_index("ind")
