In [1]:
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import solve, norm
from numpy import sin, cos, pi, arctan

In [23]:
dt = np.float32

In [2]:
def solve2(a, b, c):
    if a == 0:
        if b == 0:
            return -1  # any x is a solution
        return [-c / b, -c / b]

    D = b*b - 4*a*c

    if D > 0:
        x1 = (-b + np.sqrt(D)) / (2*a)
        x2 = (-b - np.sqrt(D)) / (2*a)
        return [x1, x2]
    elif D < 0:
        x1 = complex(-b, np.sqrt(-D)) / (2*a)
        x2 = complex(-b, -np.sqrt(-D)) / (2*a)
        return [x1, x2]
    else:
        return [-b / (2*a), -b / (2*a)]

In [3]:
def mypow(x, p):
    if x >= 0:
        return x ** p
    return -((-x) ** p)

In [10]:
def solve3(a, b, c, d):
    print type(a)
#     a = float(a); b = float(b); c = float(c); d = float(d)
    # reduce to the form y^3 + py + q = 0
    # substitution: x = y - b/3a
    p = c/a - b*b/(3*a*a)
    q = 2*b*b*b/(27*a*a*a) - b*c/(3*a*a) + d/a

    Q = p*p*p/27 + q*q/4
    # Q > 0 - one real root and two complex conjugated roots
    # Q = 0 - one single real root and one double real root, or,
	#         if p = q = 0, then one triple real root
	# Q < 0 - three real roots

    if Q >= 0:
        alpha   = mypow(-q/2 + np.sqrt(Q), 1.0/3)
        beta    = mypow(-q/2 - np.sqrt(Q), 1.0/3)
    else:
        alpha   = complex(-q/2, np.sqrt(-Q)) ** (1.0/3)
        beta    = complex(-q/2, -np.sqrt(-Q)) ** (1.0/3)

    x1 = alpha + beta - b/(3*a)
    x2 = complex(-(alpha+beta)/2, (alpha-beta)*np.sqrt(3)/2) - b/(3*a)
    x3 = complex(-(alpha+beta)/2, -(alpha-beta)*np.sqrt(3)/2) - b/(3*a)
    return [x1, x2, x3]

In [29]:
def solve4(a, b, c, d, e):
    print type(a)
#     a = float(a); b = float(b); c = float(c); d = float(d); e = float(e)
    b /= a; c /= a; d /= a; e /= a;
    a = b; b = c; c = d; d = e;

    # reduce to the form y^4 + p*y^2 + q*y + r = 0
    p = b - 3*a*a/8
    q = a*a*a/8 - a*b/2 + c
    r = - 3*a*a*a*a/256 + a*a*b/16 - c*a/4 + d

    # obtain cubic resolvent A*s^3 + B*s^2 + C*s + D = 0
    A = dt(2)
    B = -p
    C = -2.0*r
    D = r*p - q*q/4
    s1, s2, s3 = solve3(A, B, C, D)

    s = 0
    if np.real(s1) > p/2:
        s = s1
    elif np.real(s2) > p/2:
        s = s2
    elif np.real(s3) > p/2:
        s = s3

    a1 = dt(1); b1 = -np.sqrt(2*s-p); c1 = q/(2*np.sqrt(2*s-p)) + s
    a2 = dt(1); b2 = np.sqrt(2*s-p); c2 = -q/(2*np.sqrt(2*s-p)) + s

    x1, x2 = solve2(a1, b1, c1)
    x1 -= a/4
    x2 -= a/4
    x3, x4 = solve2(a2, b2, c2)
    x3 -= a/4
    x4 -= a/4
    return [x1, x2, x3, x4]

In [30]:
dt = np.float32
a = dt(1.0)
b = dt(6.49999237)
c = dt(13.2862368)
d = dt(9.05017757)
e = dt(-1.74342331e-005)
r = solve4(a, b, c, d, e)
print type(r[0]), r

<type 'numpy.float32'>
<type 'numpy.float32'>
<type 'complex'> [(-1.5914067359802744+0.4423873476428892j), (-1.5914067359802744-0.4423873476428892j), 2.0768199777698726e-06, -3.3171809754648978]


In [31]:
dt = np.float64
a = dt(1.0)
b = dt(6.49999237)
c = dt(13.2862368)
d = dt(9.05017757)
e = dt(-1.74342331e-005)
print solve4(a, b, c, d, e)

<type 'numpy.float64'>
<type 'numpy.float64'>
[(-1.591406724322575+0.4423868713392655j), (-1.591406724322575-0.4423868713392655j), 1.9259133507976145e-06, -3.3171808472682009]
