In [None]:
import numpy as np

# Tridiagonal matrix solver (Thomas algorithm)

A tridiagonal matrix is a matrix in which all elements are zero except for three bands (vectors): the diagonal, the sub-diagonal (below the diagonal) and the super-diagonal (above the diagonal). This algorithm solves the system of equations $Ax=b$ for the unknown vector $x$ using Gauss elimination. In this algorithm, $A$ is represented only by those three vectors, because storing all the zero elements of the matrix would be wasteful. By disregarding the zero elements of the matrix entirely (which constitute most of the elements in tridiagonal systems), the algorithm achieves significant savings in computation: it needs only $O(n)$ operations instead of the $O(n^3)$ required by general Gauss elimination. Note that no pivoting is performed, so the algorithm is guaranteed to be stable only for well-behaved matrices (for example, diagonally dominant ones).

In [None]:
def tridiagonal(sub, diag, sup, rhs):
    """
    Solve a tridiagonal system of equations Ax=b using Gauss elimination (Thomas algorithm).

    tridiagonal(sub, diag, sup, rhs)
    The coefficient matrix A is given only by its three bands: the diagonal, the
    sub-diagonal (below the diagonal) and the super-diagonal (above the diagonal).
    Row i of the system reads
        sub[i]*x[i-1] + diag[i]*x[i] + sup[i]*x[i+1] = rhs[i]
    so sub[0] and sup[-1] lie outside the matrix and must be zero.
    All input arguments must have the same length n; the output x has length n.
    No pivoting is performed, so a zero pivot raises ZeroDivisionError.

    Input:
        sub:  sub-diagonal vector of the matrix (array or list). The first element must be zero.
        diag: diagonal vector of the matrix (array or list).
        sup:  super-diagonal vector of the matrix (array or list). The last element must be zero.
        rhs:  the right-hand side vector (array or list).
    Output:
        x: the unknown vector as a new NumPy array (float64, or complex if any input is complex).
    Raises:
        AssertionError: if the vector lengths differ, sub[0] != 0 or sup[-1] != 0.
        ZeroDivisionError: if a zero pivot is met (the system needs pivoting or is singular).
    The input vectors are never modified.
    """
    # Accept lists as well as arrays. The working dtype is at least float64 so that
    # integer inputs are not truncated when eliminated values are written back.
    sub, diag, sup, rhs = (np.asarray(v) for v in (sub, diag, sup, rhs))
    dtype = np.result_type(sub, diag, sup, rhs, np.float64)

    n = diag.shape[0]  # length of vectors
    # check that all vector lengths are equal
    assert sub.shape[0] == n and sup.shape[0] == n and rhs.shape[0] == n, \
        "all vector lengths must be equal"
    if n == 0:  # empty system: nothing to solve
        return np.zeros(0, dtype=dtype)
    # sub[0] and sup[-1] are outside the matrix and must be zero
    assert sub[0] == 0, "the first element of sub must be zero"
    assert sup[-1] == 0, "the last element of sup must be zero"

    # Working copies: only diag and rhs are overwritten during elimination.
    # np.array copies by default, so the caller's vectors are left untouched.
    diag1 = np.array(diag, dtype=dtype)
    rhs1 = np.array(rhs, dtype=dtype)
    x = np.zeros(n, dtype=dtype)

    # Forward elimination: use row i-1 to eliminate the sub-diagonal entry of row i.
    # Only diag and rhs of row i change, because row i-1 has no entry in column i+1.
    for i in range(1, n):
        if diag1[i-1] == 0:
            raise ZeroDivisionError(f"zero pivot in row {i-1}; the system needs pivoting or is singular")
        fac = -sub[i] / diag1[i-1]
        diag1[i] += fac * sup[i-1]
        rhs1[i] += fac * rhs1[i-1]
    # At this stage the matrix is upper bidiagonal (diag1 and sup).
    if diag1[n-1] == 0:
        raise ZeroDivisionError(f"zero pivot in row {n-1}; the matrix is singular")

    # Back substitution: solve the last row, then each row from n-2 down to 0.
    x[n-1] = rhs1[n-1] / diag1[n-1]
    for i in range(n-2, -1, -1):
        x[i] = (rhs1[i] - sup[i] * x[i+1]) / diag1[i]
    return x

## Example
Solve the system $Ax=b$ with $A=\begin{bmatrix}
1 & 11 & 0 & 0\\
5 & 2 & 2 & 0 \\
0 & 3 & 3 & 1 \\
0 &0 &9&4
\end{bmatrix}$ and the vector $b=\begin{bmatrix}
2\\
8\\
7\\
6
\end{bmatrix}$ using the tridiagonal matrix algorithm. Verify your results by showing that $Ax=b$ is satisfied. 

In [None]:
# constructing the bands (vectors) of the coefficient matrix
sub = np.array([0.,5.,3.,9.])  # sub-diagonal vector. Note that the first element is zero
diag = np.array([1.,2.,3.,4.]) # diagonal vector
sup = np.array([11.,2.,1.,0.]) # super-diagonal vector. Note that the last element is zero
rhs = np.array([2.,8.,7.,6.])  # the right-hand side vector
x = tridiagonal(sub, diag, sup, rhs)
print('solution: x=', x)
print()
print('... Constructing A for verification ...')
# Build the full matrix from the bands (for verification only): diag on the main
# diagonal, sub[1:] one below it and sup[:-1] one above it. sub[0] and sup[-1]
# lie outside the matrix and are dropped.
A = np.diag(diag) + np.diag(sub[1:], k=-1) + np.diag(sup[:-1], k=1)
print('A=', A)
Ax = A @ x
# Actually check the claim printed below instead of stating it unconditionally.
assert np.allclose(Ax, rhs), "verification failed: Ax != rhs"
print('Verification: Ax=', Ax, '= rhs')

solution: x= [ -1.00546448   0.27322404   6.24043716 -12.54098361]

... Constructing A for verification ...
A= [[ 1. 11.  0.  0.]
 [ 5.  2.  2.  0.]
 [ 0.  3.  3.  1.]
 [ 0.  0.  9.  4.]]
Verification: Ax= [2. 8. 7. 6.] = rhs


As shown, $Ax=b$ is satisfied, which confirms that $x$ is the solution of the system.

## Exercise
Solve the system $Ax=b$ with $A=\begin{bmatrix}
1 & 5 & 0 & 0 & 0\\
3 & -4 & 3 & 0 & 0 \\
0 & 18 & 8 & 3 & 0 \\
0 &0 & 9 & 7 & 6 \\
0 &0 & 0 & -5 & 3 
\end{bmatrix}$ and the vector $b=\begin{bmatrix}
3\\
5\\
9\\
8 \\
15
\end{bmatrix}$ using the tridiagonal matrix algorithm. Verify your results by showing that $Ax=b$ is satisfied. 