In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def function_for_roots(x):
    """Evaluate the quadratic f(x) = 1.01*x**2 - 3.04*x + 2.07.

    Only arithmetic operators are used, so ``x`` may be a scalar or a
    NumPy array (evaluated elementwise).
    """
    quad_coeff, lin_coeff, const_term = 1.01, -3.04, 2.07
    return quad_coeff * x**2 + lin_coeff*x + const_term

In [None]:
def check_initial_values(f, x_min, x_max, tol):
    """Classify a candidate bisection bracket ``[x_min, x_max]`` for ``f``.

    Parameters
    ----------
    f : callable
        Function of one variable whose root is sought.
    x_min, x_max : float
        Ends of the candidate bracket.
    tol : float
        A point ``x`` is accepted as a root when ``|f(x)| < tol``.

    Returns
    -------
    int
        0 -- neither end is a root and ``f`` does not change sign across the
             bracket (invalid bracket; a diagnostic is printed),
        1 -- ``x_min`` is already a root,
        2 -- ``x_max`` is already a root,
        3 -- ``f`` changes sign across the bracket, so bisection can proceed.
    """
    y_min = f(x_min)
    y_max = f(x_max)

    # Test the endpoints for roots *before* testing for a sign change.
    # An endpoint that is exactly (or within tol of) a root gives
    # y_min*y_max >= 0. Checked the other way round, it would be reported
    # as "no zero crossing", and flags 1/2 could never fire for an exact root.
    if np.fabs(y_min) < tol:
        return 1
    if np.fabs(y_max) < tol:
        return 2

    # A usable bracket needs f to have strictly opposite signs at its ends.
    if y_min * y_max >= 0.0:
        print("No zero crossing found in the range = ", x_min, x_max)
        s = "f(%f) = %f, f(%f) = %f" % (x_min, y_min, x_max, y_max)
        print(s)
        return 0

    # If we reach this point the bracket is valid.
    return 3

In [None]:
def bisection_root_finding(f, x_min_start, x_max_start, tol, verbose=True, imax=10000):
    """Find a root of ``f`` inside ``[x_min_start, x_max_start]`` by bisection.

    The bracket is halved repeatedly, always keeping the half across which
    ``f`` changes sign, until the midpoint satisfies ``|f(x_mid)| < tol``.

    Parameters
    ----------
    f : callable
        Function of one variable.
    x_min_start, x_max_start : float
        Initial bracket, validated with ``check_initial_values``.
    tol : float
        Convergence threshold on ``|f(x)|``.
    verbose : bool, optional
        If True (default), print the current bracket ends and the function
        values there after every iteration.
    imax : int, optional
        Iteration cap (default 10000). Failure is reported once the iteration
        count exceeds this value.

    Returns
    -------
    tuple of (float, int)
        The root estimate and the number of bisection iterations used.
        If a starting endpoint already satisfies the tolerance, it is
        returned with 0 iterations.

    Raises
    ------
    ValueError
        If the initial bracket is rejected by ``check_initial_values``.
    StopIteration
        If the tolerance is not reached within the iteration cap.
        NOTE(review): a StopIteration escaping an ordinary function can be
        mistaken for normal exhaustion by an enclosing iterator (e.g. when
        this is called through ``map`` inside a for-loop). It is kept only
        for backward compatibility.
    """
    x_min = x_min_start
    x_max = x_max_start

    flag = check_initial_values(f, x_min, x_max, tol)
    if flag == 0:
        print("Error in bisection_root_finding")
        raise ValueError("Initial values invalid", x_min, x_max)
    # An endpoint is already a root. Return the same (root, iterations)
    # pair as the search path so callers can always unpack the result.
    if flag == 1:
        return x_min, 0
    if flag == 2:
        return x_max, 0

    # Cache f at the bracket ends so each iteration evaluates f only once.
    y_min = f(x_min)
    y_max = f(x_max)

    i = 0
    while True:
        x_mid = 0.5 * (x_min + x_max)
        y_mid = f(x_mid)
        converged = np.fabs(y_mid) < tol

        if not converged:
            # Same sign at x_min and x_mid -> the root lies in [x_mid, x_max];
            # otherwise it lies in [x_min, x_mid].
            if y_min * y_mid > 0:
                x_min, y_min = x_mid, y_mid
            else:
                x_max, y_max = x_mid, y_mid

        if verbose:
            print(x_min, y_min, x_max, y_max)

        i += 1

        # Check convergence before the cap, so a root found on the last
        # allowed iteration is returned rather than reported as a failure.
        if converged:
            return x_mid, i

        if i > imax:
            print("Exceeded the max num of iterations = ", i)
            print("Min bracket f(%f) = %15.14e" % (x_min, y_min))
            print("Max bracket f(%f) = %15.14e" % (x_max, y_max))
            print("Mid bracket f(%f) = %15.14e" % (x_mid, y_mid))
            raise StopIteration("Stopping iterations after ", i)

In [None]:
# Bisection stops once |f(x_mid)| drops below this tolerance.
tolerance = 1.0e-6

# ---- Left root: starting bracket [0.0, 1.4] ----
x_min, x_max = 0.0, 1.4

# Show f at both ends of the starting bracket.
print(x_min, function_for_roots(x_min))
print(x_max, function_for_roots(x_max))

x_root1, iterations = bisection_root_finding(
    function_for_roots, x_min, x_max, tolerance
)
y_root1 = function_for_roots(x_root1)

# Blank line separating the two iteration logs.
print()

# ---- Right root: starting bracket [1.6, 3.0] ----
x_min2, x_max2 = 1.6, 3.0

x_root2, iter2 = bisection_root_finding(
    function_for_roots, x_min2, x_max2, tolerance
)
y_root2 = function_for_roots(x_root2)

# ---- Summary ----
s = "Left root found with y({:f}) = {:f}".format(x_root1, y_root1)
print(s)
print("Left root took {:d} iterations".format(iterations))

print()

s2 = "Right root found with y({:f}) = {:f}".format(x_root2, y_root2)
print(s2)
print("Right root took {:d} iterations".format(iter2))

In [None]:
# Plot the searched function, the starting brackets (green) and the roots
# found by bisection (red).
x_curve = np.linspace(0.0, 3.0, 1001)
# Evaluate the very function that was searched instead of re-typing its
# coefficients, so the plot cannot drift out of sync with function_for_roots.
y_curve = function_for_roots(x_curve)

fig, ax = plt.subplots()
ax.axhline(y=0, color='k', linestyle='-', lw=1)
# Draw the curve first so the markers are rendered on top of it.
ax.plot(x_curve, y_curve, label="f(x)")

# Bracket endpoints for both roots (left bracket, then right bracket).
bracket_x = np.array([x_min, x_max, x_min2, x_max2])
ax.plot(bracket_x, function_for_roots(bracket_x), linestyle='none',
        marker='o', markersize=5, color="green", label="initial brackets")
ax.plot([x_root1, x_root2], [y_root1, y_root2], linestyle='none',
        marker='o', markersize=5, color="red", label="bisection roots")

ax.set_xlim(0, 3)
ax.set_ylim(-0.5, 2.1)
ax.set_xlabel("x")
ax.set_ylabel("f(x)")
ax.set_title("Roots of f(x) found by bisection")
ax.legend()
plt.show()