## 3.6 Iterative Refinement

In [None]:
import numpy as np
from scripts.lu import backward, forward, lu_pivot

#### Example 3.35 (Iterative Refinement)

We start by defining the linear system and computing a reference solution in double precision with `np.linalg.solve`. We will use it later to measure the error of our computed solution.

In [None]:
A = np.array([[2.3, 1.8, 1], [1.4, 1.1, -0.7], [0.8, 4.3, 2.1]])
b = np.array([1.2, -2.1, 0.6])
x_np = np.linalg.solve(A, b)  # reference solution, computed in double precision

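Now we implement iterative refinement in mixed precision: the LU factorization is computed once in a lower precision `precision2`, while the residual is computed in the precision of `A` and `b`.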
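Written out, the scheme (as we read it from the implementation) is as follows. Let $PA \approx \hat L \hat U$ be the LU factorization with row permutation $P$, computed once in the lower precision. Starting from $x_0 = 0$ and $r_0 = b$, we repeat

$$
\begin{aligned}
\hat L \hat U \, w_k &= P r_k, \\
x_{k+1} &= x_k + w_k, \\
r_{k+1} &= b - A x_{k+1},
\end{aligned}
\qquad k = 0, 1, \dots, n-1,
$$

and return $x_n$. The step $k = 0$ is an ordinary low-precision solve; every later step corrects it using a residual computed in the working precision.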

**Implementation 3.7: Iterative Refinement**

In [None]:
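def iterative_refinement(A, b, n=3, precision2=np.half):
    # Copy of A in the lower precision; as used here, lu_pivot stores the
    # LU factors in A_2 and returns the row interchanges it performed.
    A_2 = A.astype(precision2)
    pivot = lu_pivot(A_2)
    d = b.copy()           # right-hand side of the first solve
    x = np.zeros_like(b)   # initial guess x_0 = 0
    for _ in range(n):
        # Apply the recorded row interchanges to the current right-hand side
        for p in pivot:
            d[[p[0], p[1]]] = d[[p[1], p[0]]]
        y = forward(A_2, d)    # solve L y = P d
        w = backward(A_2, y)   # solve U w = y
        x += w                 # correct the current approximation
        d = b - A.dot(x)       # residual in the working precision of A and b
    return x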
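Since `scripts.lu` is not shown here, the following is a self-contained sketch of the same idea using only NumPy. The helper names `lu_partial_pivot`, `lu_solve` and `mixed_precision_refinement` are hypothetical and not part of `scripts.lu`. The factorization is carried out entirely in `float16`, while the triangular solves upcast the factors to the precision of the residual; we assume this mirrors what NumPy type promotion does in `forward`/`backward` above, and it keeps tiny residuals from underflowing in half precision.

```python
import numpy as np


def lu_partial_pivot(A, dtype):
    """LU factorization with partial pivoting, computed entirely in `dtype`.

    Returns the packed factors (unit lower triangular L strictly below the
    diagonal, U on and above it) and a permutation vector `perm` with
    A[perm] ~= L @ U.
    """
    LU = np.array(A, dtype=dtype)  # low-precision working copy
    n = LU.shape[0]
    perm = np.arange(n)
    for k in range(n - 1):
        # Partial pivoting: bring the largest entry of column k to the diagonal
        p = k + int(np.argmax(np.abs(LU[k:, k])))
        if p != k:
            LU[[k, p]] = LU[[p, k]]
            perm[[k, p]] = perm[[p, k]]
        # Multipliers and rank-1 update of the trailing block, all in `dtype`
        LU[k + 1:, k] /= LU[k, k]
        LU[k + 1:, k + 1:] -= np.outer(LU[k + 1:, k], LU[k, k + 1:])
    return LU, perm


def lu_solve(LU, perm, r):
    """Solve A w = r with the low-precision factors.

    The factors are upcast to the dtype of `r`, so the triangular solves
    run in working precision and small residuals do not underflow.
    """
    F = LU.astype(r.dtype)
    n = F.shape[0]
    w = r[perm]  # apply the row permutation (fancy indexing returns a copy)
    for i in range(n):  # forward substitution with unit lower triangular L
        w[i] -= F[i, :i] @ w[:i]
    for i in range(n - 1, -1, -1):  # back substitution with U
        w[i] = (w[i] - F[i, i + 1:] @ w[i + 1:]) / F[i, i]
    return w


def mixed_precision_refinement(A, b, n_iter=5, low=np.float16):
    """Iterative refinement with an LU factorization computed in `low`."""
    LU, perm = lu_partial_pivot(A, low)  # factor once, in low precision
    x = np.zeros_like(b)
    r = b.copy()
    for _ in range(n_iter):
        x += lu_solve(LU, perm, r)  # correction from the cheap factors
        r = b - A @ x               # residual in the precision of A and b
    return x


A = np.array([[2.3, 1.8, 1], [1.4, 1.1, -0.7], [0.8, 4.3, 2.1]])
b = np.array([1.2, -2.1, 0.6])
x = mixed_precision_refinement(A, b, n_iter=10)
x_ref = np.linalg.solve(A, b)
print(np.linalg.norm(x - x_ref) / np.linalg.norm(x_ref))
```

Storing the permutation as an index vector rather than a list of swaps is a design choice of this sketch; it lets the whole permutation be applied to the right-hand side with a single fancy-indexing operation.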

To emphasise the effect of using mixed precision, we use single precision for the matrix and vectors, while the LU factorization is computed using half precision.

In [None]:
A_single = A.astype(np.single)
b_single = b.astype(np.single)
x = iterative_refinement(A_single, b_single, n=2, precision2=np.half)
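Whether a half-precision factorization is good enough depends on the conditioning of $A$. A commonly quoted rule of thumb for LU-based iterative refinement is that it converges when $\kappa(A)\,u_f$ is well below 1, where $u_f$ is the unit roundoff of the precision used for the factorization. The short check below is a sketch of that heuristic for this matrix (it is not a convergence guarantee); it uses the condition number in the infinity norm.

```python
import numpy as np

A = np.array([[2.3, 1.8, 1], [1.4, 1.1, -0.7], [0.8, 4.3, 2.1]])

u_half = np.finfo(np.half).eps / 2  # unit roundoff of IEEE half precision, 2**-11
kappa = np.linalg.cond(A, np.inf)   # condition number in the infinity norm
print(f"kappa(A) = {kappa:.2f}, kappa(A) * u_half = {kappa * u_half:.2e}")
```

For this small, well-conditioned matrix the product is far below 1, which is consistent with the fast convergence observed here.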

Comparing against the solution computed by `numpy`, we have the relative error

In [None]:
print(np.linalg.norm(x - x_np) / np.linalg.norm(x_np))

This is roughly the machine precision of single precision arithmetic, even though the LU factorization was computed in half precision. We can go further: if we allow more refinement steps and keep the matrix and vectors in double precision, while still computing the LU factorization in half precision, we obtain

In [None]:
x = iterative_refinement(A, b, n=5, precision2=np.half)
print(np.linalg.norm(x - x_np) / np.linalg.norm(x_np))

That is, even with a half-precision LU factorization, we can solve the system to about the machine precision of double precision arithmetic.
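For reference, the machine epsilon (the spacing of floating point numbers just above 1) of the three IEEE formats used in this example can be queried with `np.finfo`, which makes it easy to compare the printed errors against the precision of each format.

```python
import numpy as np

# Machine epsilon of the formats used above (NumPy aliases half/single/double)
for dtype in (np.half, np.single, np.double):
    print(f"{np.dtype(dtype).name:>8}: eps = {float(np.finfo(dtype).eps):.3e}")
```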