In [1]:
import math

import IPython.display
import sympy

sympy.init_printing()

Jm, Am, A0, S0, Sg = sympy.symbols('J_m A_m A_0 S_0 S_g')

def equation(n) :

    T = list()
    J = list()
    A = [sympy.symbols('A_0'),]
    S = [sympy.symbols('S_0'),]
    
    for i in range(n) :
        T.append(sympy.symbols(f"T_{i}"))
        J.append(sympy.symbols(f"J_{i}"))
        A.append( (A[-1] + J[i] * T[i]).simplify() )
        S.append( (S[-1] + A[i] * T[i] + J[i] * T[i]**2 / 2).simplify() )
       
    return T, J, A, S

def dbg(* e_lst) :
    IPython.display.display(
        IPython.display.Math(' = '.join(
            (
                str(e) if isinstance(e, str) else sympy.latex(e)
            ) for e in e_lst ))
    )

In [2]:
def test(jm, am, a0, s0, sg) :
    val = { 'J_m': jm, 'A_m': am, 'A_0': a0, 'S_0': s0, 'S_g': sg }
    cmd, dur, bch = compute(** val)
    
    n = len(dur)
    T, J, A, S = equation(n)
    
    par = dict()
    for i in range(n) :
        par[f"T_{i}"] = dur[i]
        par[f"J_{i}"] = cmd[i] * Jm
        
    def f(e) :
        try :
            return float(e.subs(par).subs(val))
        except AttributeError :
            return e
        
    def s(e) :
        try :
            return e.simplify()
        except AttributeError :
            return e
        
    def n(a, b) :
        return math.isclose(f(a), f(b), rel_tol=1e-5, abs_tol=1e-5)

    dbg("\mathrm{cmd}", [s(i) for i in cmd], [f(i) for i in cmd])
    dbg("\mathrm{dur}", [s(i) for i in dur], [f(i) for i in dur])
    sg_tst = S[-1].subs(par).subs(val).simplify()
    dbg("S_g", S[-1], S[-1].subs(par).simplify(), sg_tst)
    ag_tst = A[-1].subs(par).subs(val).simplify()
    dbg("A_g", A[-1], A[-1].subs(par).simplify(), ag_tst)
    
    print(bch, n(sg_tst, sg), n(ag_tst, 0.0), all((0 < d.subs(val)) for d in dur))

In [30]:
def compute(** val) :
    Ws = 1 if 0.0 <= (Sg - S0).subs(val) else -1
    
    m = 2*Jm*abs(Sg - S0)

    # in 1 step
    Au = sympy.sqrt(m)
    if A0.subs(val) == Ws*Au.subs(val) :
        Ta1 = Au / Jm
        return [ -Ws, ], [ Ta1, ], "A"

    q = sympy.sqrt(2)*Jm

    # in 2 steps
    d = (A0**2 / 2) + Ws*Jm*(Sg - S0)
    dbg('d', d, d.subs(val), float(d.subs(val)))

    A1a = sympy.sqrt(d)
    A1b = -sympy.sqrt(d)
    dbg('A_1^a', A1a, A1a.subs(val), float(A1a.subs(val)))
    dbg('A_1^b', A1b, A1b.subs(val), float(A1b.subs(val)))

    S2a = (-A0 + A1a)/(Jm*k)
    S2b = (-A0 + A1b)/(Jm*k)

    dbg('S_2^a', S2a, S2a.subs(val))
    dbg('S_2^b', S2b, S2b.subs(val))


    Tb2 = sympy.sqrt(A0**2 + m) / q
    Tb1 = Tb2 - Ws*A0/Jm
    Av = abs(A0 + Ws*Tb1*Jm)
    # dbg(Av.subs(val), Am.subs(val))
    #if Av.subs(val) < Am.subs(val) :
    return [ Ws, -Ws ], [ Tb1, Tb2 ], "B"

    Wa = 1 if 0.0 <= (A0).subs(val) else -1

    Sn = Wa*A0**2 / (2*Jm)
    Sd = Sg - S0
    Sc = Sd - Sn

    Wr = 1 if 0.0 <= (Sc * Sn).subs(val) else -1

    # in 3 steps
    Tc1 = Am - Wr*abs(A0) / Jm
    Tc2 = (A0**2 - 2*Am**2 + m) / (2*Am*Jm)
    Tc3 = (Am / Jm)
    return [ Wa*Wr, 0, -Wa*Wr ], [ Tc1, Tc2, Tc3 ], "C"

In [31]:
test(1, 2, 3, -2, 6)

<IPython.core.display.Math object>

TypeError: can't convert complex to float

NameError: name 'val' is not defined

In [27]:
test(1, 1, 1, -sympy.Rational(1, 2), 0)

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

A True True True
