In [None]:
# Dullin–Tong style multi-stage planner (dynamic + geometric phase)
# -----------------------------------------------------------------
# Modes:
#   "twisting"  ->  sin(theta) = sin(alpha)*dn(p*t | m)
#   "wobbling"  ->  sin(theta) = sin(alpha)*cn(q*t | m)
#
# Period spec:
#   {"type":"N","N":int>0}  ->  T = N*K(m)   (so p=2/N, q=4/N)
#   {"type":"T","T":float>0} -> numeric T set; p=2K/T or q=4K/T
#
# Inputs per stage:  A,B,C (A>B>C), M (|angular momentum|), target Δφ, Δψ,
#                    mode, period_spec
# Solved per stage:  alpha, m, p|q, duration τ, energy E
# Planner outputs:   dynamic twist Σ(2E/M τ), geometric twist (−Ω),
#                    total twist about M, and the solved stage parameters.
#
# Dependencies: mpmath, numpy

import mpmath as mp
import numpy as np
# Work with 70 significant decimal digits in mpmath.
# NOTE: this sets the precision on mpmath's global `mp` context, so it applies
# to every mpmath computation in the process, not only to this module.
mp.mp.dps = 70

# Special functions (short aliases used by the residual functions and solvers below)
RF  = mp.elliprf      # Carlson symmetric integral of the first kind, RF(x,y,z)
PiC = mp.ellippi      # called with two arguments (n, m) below, i.e. the complete Π(n|m)
K   = mp.ellipk       # complete elliptic integral of the first kind K(m), parameter m

# ---------- Per-stage inverse solves (from your Δφ, Δψ) ----------

def _equations_dn(alpha, m, A, C, M, Delta_phi, Delta_psi, period_spec):
    """
    Residual pair for a "twisting" (dn) stage.

    Both entries are zero when (alpha, m) reproduces the requested
    Delta_phi and Delta_psi:

        r1 = Π(n | m) - rhs_phi,                 n = -m*tan(alpha)^2
        r2 = cos(alpha)*RF(0, 1 + m*tan(alpha)^2, 1 - m) - rhs_psi

    With D = 1/C - 1/A the right-hand sides are

        period_spec["type"] == "N":  rhs_phi = Delta_phi*A/(M*N)
                                     rhs_psi = Delta_psi/(M*N*D)
        any other type (reads "T"):  rhs_phi = Delta_phi*A*K(m)/(M*T)
                                     rhs_psi = Delta_psi*K(m)/(M*T*D)

    A, C, M, Delta_phi and Delta_psi are converted to mpmath numbers before
    use. Returns the tuple (r1, r2).
    """
    inertia_a = mp.mpf(A)
    inertia_c = mp.mpf(C)
    ang_mom = mp.mpf(M)
    target_phi = mp.mpf(Delta_phi)
    target_psi = mp.mpf(Delta_psi)

    inv_gap = 1/inertia_c - 1/inertia_a
    tan_a = mp.tan(alpha)
    tan_sq = tan_a*tan_a
    char_n = -m*tan_sq

    # Numerators differ by a factor K(m) between the two period conventions;
    # the denominators share the same shape with N or T as the "span".
    if period_spec["type"] == "N":
        span = mp.mpf(period_spec["N"])
        num_phi = target_phi * inertia_a
        num_psi = target_psi
    else:  # "T"
        span = mp.mpf(period_spec["T"])
        quarter_period = K(m)
        num_phi = target_phi * inertia_a * quarter_period
        num_psi = target_psi * quarter_period
    rhs_phi = num_phi / (ang_mom * span)
    rhs_psi = num_psi / (ang_mom * span * inv_gap)

    res_phi = PiC(char_n, m) - rhs_phi
    res_psi = mp.cos(alpha) * RF(0, 1 + m*tan_sq, 1 - m) - rhs_psi
    return res_phi, res_psi

def _equations_cn(alpha, m, A, C, M, Delta_phi, Delta_psi, period_spec):
    """
    Residual pair for a "wobbling" (cn) stage.

    Both entries are zero when (alpha, m) reproduces the requested
    Delta_phi and Delta_psi:

        r1 = Π(n | m) - rhs_phi,      n = -sin(alpha)^2/(1 - sin(alpha)^2)
        r2 = cos(alpha)*RF(0, 1/cos(alpha)^2, 1 - m) - rhs_psi

    With D = 1/C - 1/A the right-hand sides are

        period_spec["type"] == "N":  rhs_phi = Delta_phi*A/(M*N)
                                     rhs_psi = Delta_psi/(M*N*D)
        any other type (reads "T"):  rhs_phi = Delta_phi*A*K(m)/(M*T)
                                     rhs_psi = Delta_psi*K(m)/(M*T*D)

    A, C, M, Delta_phi and Delta_psi are converted to mpmath numbers before
    use. Returns the tuple (r1, r2).
    """
    inertia_a = mp.mpf(A)
    inertia_c = mp.mpf(C)
    ang_mom = mp.mpf(M)
    target_phi = mp.mpf(Delta_phi)
    target_psi = mp.mpf(Delta_psi)

    inv_gap = 1/inertia_c - 1/inertia_a
    sin_a = mp.sin(alpha)
    cos_a = mp.cos(alpha)
    cos_sq = cos_a**2
    sin_sq = sin_a*sin_a
    char_n = -(sin_sq) / (1 - sin_sq)    # = -tan^2(alpha)

    # Numerators differ by a factor K(m) between the two period conventions;
    # the denominators share the same shape with N or T as the "span".
    if period_spec["type"] == "N":
        span = mp.mpf(period_spec["N"])
        num_phi = target_phi * inertia_a
        num_psi = target_psi
    else:  # "T"
        span = mp.mpf(period_spec["T"])
        quarter_period = K(m)
        num_phi = target_phi * inertia_a * quarter_period
        num_psi = target_psi * quarter_period
    rhs_phi = num_phi / (ang_mom * span)
    rhs_psi = num_psi / (ang_mom * span * inv_gap)

    res_phi = PiC(char_n, m) - rhs_phi
    res_psi = cos_a * RF(0, 1/cos_sq, 1 - m) - rhs_psi
    return res_phi, res_psi

def solve_stage(A, B, C, M, Delta_phi, Delta_psi, mode, period_spec,
                alpha_guess=0.9, m_guess=0.5):
    """
    Solve a single rigid stage for (alpha, m), then derive rate, duration and energy.

    Parameters
    ----------
    A, B, C : moments of inertia (converted to float).
    M : magnitude of the angular momentum (converted to float).
    Delta_phi, Delta_psi : target angle accumulations passed to the residuals.
    mode : "twisting" (dn residuals, rate key "p", 2*K(m) per cycle) or
        "wobbling" (cn residuals, rate key "q", 4*K(m) per cycle).
    period_spec : dict; {"type": "N", "N": ...} or a dict providing "T"
        (any "type" other than "N" is handled as the "T" form).
    alpha_guess, m_guess : starting point handed to mp.findroot.

    Returns
    -------
    dict with keys mode, alpha, m, p (twisting) or q (wobbling), tau,
    A, B, C, M, E; all numeric values are Python floats.

    Raises
    ------
    ValueError if mode is neither "twisting" nor "wobbling".
    """
    A, B, C, M = (float(value) for value in (A, B, C, M))

    # Per mode: residual function, label echoed back, rate key, and the number
    # of K(m) units in one cycle.
    if mode == "twisting":   # dn case
        residuals, label, rate_key, k_units = _equations_dn, "twisting", "p", 2
    elif mode == "wobbling": # cn case
        residuals, label, rate_key, k_units = _equations_cn, "wobbling", "q", 4
    else:
        raise ValueError("mode must be 'twisting' (dn) or 'wobbling' (cn)")

    def system(a, mm):
        return residuals(a, mm, A, C, M, Delta_phi, Delta_psi, period_spec)

    root = mp.findroot(system, (alpha_guess, m_guess))
    alpha_sol, m_sol = root
    Km = K(m_sol)

    # Rate: k_units/N for the "N" form, k_units*K(m)/T for the "T" form.
    if period_spec["type"] == "N":
        rate = mp.mpf(k_units) / mp.mpf(period_spec["N"])
    else:
        rate = (k_units*Km) / mp.mpf(period_spec["T"])

    # Energy from Ω^2 = ((A-B)(M^2 - 2 E C))/(A B C), solved for E with Ω = rate.
    E = (M**2 - (rate**2 * A * B * C)/(A - B)) / (2*C)
    tau = k_units*Km/rate

    return {
        "mode": label,
        "alpha": float(alpha_sol),
        "m": float(m_sol),
        rate_key: float(rate),
        "tau": float(tau),
        "A": A, "B": B, "C": C, "M": M,
        "E": float(E),
    }

# ---------- Simulate L̂(t) per stage and compute the geometric phase ----------

def _omega_components_from_EM(A,B,C,E,M,m,Omega,u):
    """
    Body-frame angular velocity components at elliptic argument u.

    Axes are ordered by moment (I1, I2, I3) = (C, B, A) and the result is
        (a1*cn(u|m), a2*sn(u|m), a3*dn(u|m))
    with amplitudes
        a1 = sqrt((2*E*I3 - M^2)/(I1*(I3 - I1)))
        a2 = sqrt((2*E*I3 - M^2)/(I2*(I3 - I2)))
        a3 = sqrt((M^2 - 2*E*I1)/(I3*(I3 - I1)))
    The same expressions are used for both the dn and cn "modes".

    Omega is accepted but not used here; callers pass u = Omega*t themselves.
    Returns a 3-tuple of mpmath numbers.
    """
    smallest, middle, largest = C, B, A
    m_squared = M**2
    excess = 2*E*largest - m_squared          # shared numerator of a1 and a2

    amp_cn = mp.sqrt(excess/(smallest*(largest - smallest)))
    amp_sn = mp.sqrt(excess/(middle*(largest - middle)))
    amp_dn = mp.sqrt((m_squared - 2*E*smallest)/(largest*(largest - smallest)))

    cn_u, sn_u, dn_u = (mp.ellipfun(kind, u, m) for kind in ('cn', 'sn', 'dn'))
    return (amp_cn*cn_u, amp_sn*sn_u, amp_dn*dn_u)

def simulate_Lhat_path(stage, num_points=800):
    """
    Sample the body-frame angular momentum direction L̂(t) for one stage.

    Parameters
    ----------
    stage : dict as returned by solve_stage; reads "A", "B", "C", "M", "E",
        "m", "mode" and the rate key "p" (mode "twisting") or "q" (otherwise).
    num_points : number of samples, both endpoints included.

    Returns
    -------
    (num_points × 3) numpy array of unit vectors. The elliptic argument
    u = Omega*t runs over [0, 2K(m)] for "twisting" and [0, 4K(m)] otherwise.
    Columns follow the axis order of _omega_components_from_EM, i.e. the axes
    with moments (C, B, A).

    NOTE(review): if an amplitude inside _omega_components_from_EM comes out
    complex (negative radicand), the float() conversion below is expected to
    raise — confirm that stages always satisfy 2*E*C <= M^2 <= 2*E*A.
    """
    A,B,C = stage["A"], stage["B"], stage["C"]
    M,E,m = stage["M"], stage["E"], stage["m"]
    if stage["mode"] == "twisting":
        Omega = stage["p"]; period_u = 2*K(m)   # u ∈ [0,2K]
    else:
        Omega = stage["q"]; period_u = 4*K(m)   # u ∈ [0,4K]
    tgrid = np.linspace(0.0, float(period_u/Omega), num_points)

    # _omega_components_from_EM orders its axes as (I1, I2, I3) = (C, B, A):
    # ω1 turns about the C-axis and ω3 about the A-axis.  L_i = I_i * ω_i must
    # therefore pair C with ω1 and A with ω3 (the previous A*ω1, C*ω3 pairing
    # gave |L| != M and a distorted direction).
    moments = np.array([C, B, A], dtype=float)

    Ls = np.empty((num_points,3), dtype=float)
    for i, tt in enumerate(tgrid):
        u = Omega*tt
        omega = _omega_components_from_EM(A,B,C,E,M,m,Omega,u)
        L = moments * np.array([float(w) for w in omega], dtype=float)
        Ls[i,:] = L / np.linalg.norm(L)
    return Ls

def _spherical_triangle_area(u,v,w):
    """
    Signed solid angle of the spherical triangle with unit-vector corners u, v, w.

    Evaluates 2*atan2(u·(v×w), 1 + u·v + v·w + w·u); the sign follows the
    orientation of (u, v, w).
    """
    triple_product = np.dot(u, np.cross(v, w))
    cosine_sum = 1.0
    # Accumulate the pairwise dot products in the order u·v, v·w, w·u.
    for first, second in ((u, v), (v, w), (w, u)):
        cosine_sum = cosine_sum + np.dot(first, second)
    half_angle = np.arctan2(triple_product, cosine_sum)
    return 2.0*half_angle

def spherical_polyline_area(vertices):
    """
    Signed solid angle of the closed geodesic polygon through `vertices`.

    The polygon is closed automatically when the first and last vertices
    differ by more than 1e-12.  It is then fanned into triangles
    (r0, V[i], V[i+1]) about the normalized mean direction r0 of the vertices,
    and the signed triangle areas
        2*atan2(r0·(V[i]×V[i+1]), 1 + r0·V[i] + V[i]·V[i+1] + V[i+1]·r0)
    are summed (the same formula as _spherical_triangle_area, evaluated for
    all segments at once).  The sum is not reduced modulo 4π, and the fan is
    only reliable when the path stays away from the point antipodal to r0.

    Parameters
    ----------
    vertices : sequence or array of shape (N, 3) holding unit vectors.

    Returns
    -------
    float; 0.0 for an empty input or fewer than three vertices (such a path
    encloses no area).

    Raises
    ------
    ValueError if a non-empty input is not of shape (N, 3).
    """
    V = np.asarray(vertices, dtype=float)
    if V.size == 0:
        return 0.0
    if V.ndim != 2 or V.shape[1] != 3:
        raise ValueError("vertices must be an (N, 3) array of unit vectors")
    if len(V) < 3:
        return 0.0

    if np.linalg.norm(V[0] - V[-1]) > 1e-12:
        V = np.vstack([V, V[0]])

    # Reference point of the fan: the mean direction.  When the vertices
    # average out to (numerically) zero the mean has no direction, so fall
    # back to the first vertex instead of dividing by zero.
    r0 = V[:-1].mean(axis=0)
    r0_norm = np.linalg.norm(r0)
    if r0_norm > 1e-12:
        r0 = r0 / r0_norm
    else:
        r0 = V[0]

    start, end = V[:-1], V[1:]
    det = np.cross(start, end) @ r0
    denom = 1.0 + start @ r0 + (start * end).sum(axis=1) + end @ r0
    return float(np.sum(2.0*np.arctan2(det, denom)))

def plan_sequence(stages, num_points=600):
    """
    Combine the per-stage dynamic twist with the geometric phase (−Ω).

    Parameters
    ----------
    stages : list of stage dicts as returned by solve_stage (reads "E", "M",
        "tau" here; simulate_Lhat_path reads the remaining keys).
    num_points : samples per stage used to trace L̂(t); defaults to 600.

    Returns
    -------
    dict with
      - dynamic_twist        = Σ_j (2E_j/M_j) τ_j
      - geometric_twist      = −Ω
      - total_twist_about_M  = dynamic_twist + geometric_twist
      - Omega_solid_angle    = Ω, the signed area returned by
                               spherical_polyline_area for the concatenated path
      - stage_params         = the `stages` object passed in (echo)
    An empty `stages` yields zeros for all numeric entries.
    """
    if not stages:
        # No stages: nothing accumulates and there is no path to enclose area.
        return {
            "dynamic_twist": 0.0,
            "geometric_twist": 0.0,
            "total_twist_about_M": 0.0,
            "Omega_solid_angle": 0.0,
            "stage_params": stages
        }

    dyn = sum( (2*st["E"]/st["M"]) * st["tau"] for st in stages )

    # Build the concatenated L̂ path across stages.  Every sample of every
    # stage is kept, including each stage's first point: dropping it would make
    # the jump between stages land on the next stage's second sample and clip
    # a thin triangle from the enclosed area.  Coincident consecutive points
    # only add zero-area triangles.
    verts = np.vstack([simulate_Lhat_path(st, num_points=num_points)
                       for st in stages])
    Omega = spherical_polyline_area(verts)
    geo = - Omega

    return {
        "dynamic_twist": float(dyn),
        "geometric_twist": float(geo),
        "total_twist_about_M": float(dyn + geo),
        "Omega_solid_angle": float(Omega),
        "stage_params": stages
    }

# ------------------------------ Example ------------------------------
# if __name__ == "__main__":
#     # Two-stage example (instantaneous shape change between stages)
#     A1,B1,C1 = 3.0, 2.0, 1.0
#     A2,B2,C2 = 2.6, 1.7, 0.9      # different “shape” (arms in/out)
#     M = 1.1
#
#     # Target per-stage angle accumulations (dynamic, your convention), one cycle each
#     Dphi1, Dpsi1 = 0.6, 0.4
#     Dphi2, Dpsi2 = 0.6, 0.4
#
#     st1 = solve_stage(A1,B1,C1,M, Dphi1,Dpsi1,
#                       mode="twisting", period_spec={"type":"N","N":5},
#                       alpha_guess=0.9, m_guess=0.6)
#     st2 = solve_stage(A2,B2,C2,M, Dphi2,Dpsi2,
#                       mode="wobbling", period_spec={"type":"T","T":10.0},
#                       alpha_guess=0.9, m_guess=0.5)
#
#     result = plan_sequence([st1, st2])
#     print("Stage 1:", st1)
#     print("Stage 2:", st2)
#     print("\nDynamic twist (Σ 2E/M·τ):", result["dynamic_twist"])
#     print("Geometric phase twist (−Ω):", result["geometric_twist"])
#     print("Total twist about M:", result["total_twist_about_M"])
