# Assignment 14

I have provided a Python/NumPy implementation of a [PLU Decomposition](https://johnfoster.pge.utexas.edu/numerical-methods-book/LinearAlgebra_LU.html#Python/NumPy-implementation-of-$\mathbf{PLU}$-decomposition) in the course notes. Because that function makes use of NumPy broadcasting, it is about as fast as a Python implementation can be, but this efficiency comes at the expense of readability.
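
For comparison with the loop-based class below, here is a minimal sketch of what a broadcasting-based $\mathbf{PLU}$ factorization can look like. It is not the course-notes function: the name `plu_vectorized` is hypothetical, and the sketch assumes partial pivoting on the largest-magnitude entry of each column. The key step replaces the inner loop over rows with a single outer-product update of all rows below the pivot.

```python
import numpy as np


def plu_vectorized(A):
    """Return P, L, U with P @ A = L @ U (hypothetical helper, partial pivoting)."""
    U = np.array(A, dtype=np.double)  # float copy; A itself is not modified
    n = U.shape[0]
    P = np.eye(n)
    L = np.eye(n)
    for i in range(n - 1):
        # Bring the largest |entry| in column i (at or below row i) to row i.
        r = i + np.argmax(np.abs(U[i:, i]))
        if U[r, i] == 0:
            raise ValueError("matrix is singular")
        if r != i:
            U[[i, r]] = U[[r, i]]
            P[[i, r]] = P[[r, i]]
            L[[i, r], :i] = L[[r, i], :i]  # swap multipliers already computed
        # All multipliers for the rows below the pivot at once.
        L[i + 1:, i] = U[i + 1:, i] / U[i, i]
        # Broadcast update: U[j] -= L[j, i] * U[i] for every j > i.
        U[i + 1:] -= np.outer(L[i + 1:, i], U[i])
        U[i + 1:, i] = 0.0  # clear round-off left below the pivot
    return P, L, U


A = np.array([[0, 3, 4], [4, 6, 10], [22, 1, 7]])
P, L, U = plu_vectorized(A)
print(np.allclose(P @ A, L @ U))  # True
```

The vectorized update is fast, but the pivoting and bookkeeping are packed into a few dense lines, which is the readability cost the object-oriented version aims to avoid.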

We can make this function more readable and maintainable by using an object-oriented approach. Your assignment is to complete the `LU` class below. Specifically, you need to implement the `decomp` member function to perform the $\mathbf{PLU}$ decomposition and store the resulting matrices in the class attributes `P`, `L`, and `U`, such that $\mathbf{PA} = \mathbf{LU}$ (the convention assumed by `forward_substitution`, which solves $\mathbf{Ly} = \mathbf{Pb}$).
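
Whatever form `decomp` takes, its output can be checked against the defining properties of the factorization: $\mathbf{P}$ is a permutation matrix, $\mathbf{L}$ is unit lower triangular, $\mathbf{U}$ is upper triangular, and $\mathbf{PA} = \mathbf{LU}$. The sketch below tests those properties in plain NumPy with a hypothetical helper, `is_valid_plu`. The factors in the usage lines were worked out by hand for the example matrix used later in this notebook, assuming a row exchange happens only when a pivot is exactly zero (so rows 0 and 1 swap once).

```python
import numpy as np


def is_valid_plu(A, P, L, U, tol=1e-10):
    """Check the properties a PLU decomposition should have (hypothetical helper)."""
    A, P, L, U = (np.asarray(M, dtype=np.double) for M in (A, P, L, U))
    n = A.shape[0]
    # P: 0/1 entries with exactly one 1 in every row and every column.
    is_perm = (np.all((P == 0) | (P == 1))
               and np.all(P.sum(axis=0) == 1)
               and np.all(P.sum(axis=1) == 1))
    # L: lower triangular with ones on the diagonal.
    unit_lower = (np.allclose(L, np.tril(L), atol=tol)
                  and np.allclose(np.diag(L), np.ones(n), atol=tol))
    # U: upper triangular.
    upper = np.allclose(U, np.triu(U), atol=tol)
    # The factorization itself.
    factors = np.allclose(P @ A, L @ U, atol=tol)
    return bool(is_perm and unit_lower and upper and factors)


A = np.array([[0, 3, 4], [4, 6, 10], [22, 1, 7]])
P = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]])
L = np.array([[1, 0, 0], [0, 1, 0], [5.5, -32 / 3, 1]])
U = np.array([[4, 6, 10], [0, 3, 4], [0, 0, -16 / 3]])
print(is_valid_plu(A, P, L, U))  # True
```

Assuming that calling a `Matrix` returns its underlying NumPy array (as `self.U()` does in the class below), the same check could be applied to a finished decomposition with `is_valid_plu(A, sol.P(), sol.L(), sol.U())`.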

After that, implement `det` and `inverse` to compute the determinant and inverse of the input matrix from its $\mathbf{PLU}$ decomposition. I've already implemented the forward and backward substitution methods and call them from a method called `solve`. If the argument to `solve` is a one-dimensional NumPy array, a single solution vector is returned. If the argument is a two-dimensional NumPy array, each *row* is interpreted as a separate right-hand-side vector, and `solve` returns a two-dimensional NumPy array whose *rows* are the solution vectors for the corresponding rows of the input. **Hint**: You should be able to implement `inverse` with a single call to `solve`.
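
The reasoning behind `det` and `inverse` can be sketched with plain NumPy arrays. Since $\mathbf{PA} = \mathbf{LU}$, $\det\mathbf{L} = 1$, and each row swap flips the sign of $\det\mathbf{P}$, the determinant after $s$ swaps is $(-1)^s \prod_i U_{ii}$. For the inverse, using the rows of the identity as right-hand sides yields the solutions of $\mathbf{A}\mathbf{x} = \mathbf{e}_i$, which are the *columns* of $\mathbf{A}^{-1}$, so the stacked solution rows must be transposed. The helper names are hypothetical, and the factors are the hand-computed ones for the notebook's example matrix.

```python
import numpy as np


def forward_sub(L, b):
    """Solve L y = b for lower-triangular L."""
    y = np.zeros(len(b))
    for i in range(len(b)):
        y[i] = (b[i] - L[i, :i] @ y[:i]) / L[i, i]
    return y


def back_sub(U, y):
    """Solve U x = y for upper-triangular U."""
    n = len(y)
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (y[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x


def solve_rows(P, L, U, B):
    """Row i of the (always 2-D) result solves A x = B[i]."""
    return np.array([back_sub(U, forward_sub(L, P @ b)) for b in np.atleast_2d(B)])


def det_from_plu(U, num_swaps):
    # det(P) det(A) = det(L) det(U), det(L) = 1, det(P) = (-1)**num_swaps
    return (-1) ** num_swaps * np.prod(np.diag(U))


def inverse_from_plu(P, L, U):
    # Row i of solve_rows(I) is column i of A^{-1}, hence the transpose.
    return solve_rows(P, L, U, np.eye(U.shape[0])).T


P = np.array([[0.0, 1, 0], [1, 0, 0], [0, 0, 1]])
L = np.array([[1, 0, 0], [0, 1, 0], [5.5, -32 / 3, 1]])
U = np.array([[4, 6, 10], [0, 3, 4], [0, 0, -16 / 3]])
print(det_from_plu(U, num_swaps=1))  # approximately 64.0
print(inverse_from_plu(P, L, U))
```

Without the final `.T`, the same call would return $(\mathbf{A}^{-1})^{\mathsf{T}}$, which coincides with the inverse only for symmetric matrices.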

Use the `Matrix` class that you developed in [assignment12](https://github.com/PGE310-Students/assignment12). The attributes `P`, `L`, and `U` are instances of this class, which supports indexing like Python lists and NumPy arrays as well as the row-operation methods. Please use this class and its member functions to implement your methods where appropriate.
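
If the assignment12 `Matrix` class is not at hand, a hypothetical stand-in such as the one below can be used for experimenting. Its behaviour is an assumption inferred from how the `LU` class calls it, not the actual assignment12 code: indexing reads and writes entries, calling the object returns the underlying array, `row_swap(i, j)` exchanges two rows, and `row_combine(j, i, factor)` subtracts `factor` times row `i` from row `j`.

```python
import numpy as np


class Matrix:
    """Hypothetical stand-in for the assignment12 Matrix class (assumed API)."""

    def __init__(self, array):
        self._a = np.array(array, dtype=np.double)  # private float copy

    def __call__(self):
        # Assumed: calling the object returns the underlying NumPy array.
        return self._a

    def __getitem__(self, index):
        return self._a[index]

    def __setitem__(self, index, value):
        self._a[index] = value

    def row_swap(self, i, j):
        # Exchange rows i and j.
        self._a[[i, j]] = self._a[[j, i]]

    def row_combine(self, j, i, factor):
        # Assumed semantics: row_j <- row_j - factor * row_i.
        self._a[j] -= factor * self._a[i]


M = Matrix([[0, 3, 4], [4, 6, 10], [22, 1, 7]])
M.row_swap(0, 1)
M.row_combine(2, 0, 5.5)  # row 2 becomes [0, -32, -48]
print(M())
```

Saved as `assignment12.py` next to this notebook, a stand-in like this should be enough to run the `LU` class below, though the real class may differ in details such as dtype handling.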

```python
import numpy as np

from assignment12 import Matrix
```

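```python
class LU():
    def __init__(self, A):
        self.n = A.shape[0]
        # Factor a floating-point copy so integer input is not truncated
        # during elimination and the caller's A is left unchanged.
        self.U = Matrix(np.array(A, dtype=np.double))
        self.L = Matrix(np.eye(self.n))
        self.P = Matrix(np.eye(self.n))
        self.number_of_permutations = 0
        self.decomp()

    def decomp(self):
        """Compute P, L, U with P A = L U, pivoting only on zero diagonal entries."""
        for i in range(self.n):
            if self.U[i, i] == 0:
                # Find the first row below i with a nonzero entry in column i.
                r = i + 1
                while r < self.n and self.U[r, i] == 0:
                    r += 1
                if r == self.n:
                    raise ValueError("matrix is singular")
                self.U.row_swap(i, r)
                self.P.row_swap(i, r)
                # Multipliers already stored in L (columns 0..i-1) belong to
                # the swapped rows, so exchange them as well.
                for c in range(i):
                    self.L[i, c], self.L[r, c] = self.L[r, c], self.L[i, c]
                self.number_of_permutations += 1
            # Eliminate every entry below the pivot U[i, i].
            for j in range(i + 1, self.n):
                factor = self.U[j, i] / self.U[i, i]
                self.L[j, i] = factor
                self.U.row_combine(j, i, factor)

    def forward_substitution(self, b):
        """Solve L y = P b for y."""
        b = np.dot(self.P(), b)
        y = np.zeros_like(b, dtype=np.double)
        y[0] = b[0] / self.L[0, 0]
        for i in range(1, self.n):
            y[i] = (b[i] - np.dot(self.L()[i, :i], y[:i])) / self.L[i, i]
        return y

    def back_substitution(self, y):
        """Solve U x = y for x."""
        x = np.zeros_like(y, dtype=np.double)
        x[-1] = y[-1] / self.U()[-1, -1]
        for i in range(self.n - 2, -1, -1):
            x[i] = (y[i] - np.dot(self.U()[i, i+1:], x[i+1:])) / self.U[i, i]
        return x

    def solve(self, b):
        """Solve A x = b; each row of a 2-D b is a separate right-hand side."""
        b = np.array(b, dtype=np.double)
        single_rhs = b.ndim == 1
        if single_rhs:
            b = b.reshape(1, -1)
        x = np.zeros_like(b)
        for i in range(b.shape[0]):
            y = self.forward_substitution(b[i])
            # Row i of x solves A x = b[i], matching the row convention.
            x[i] = self.back_substitution(y)
        return x[0] if single_rhs else x

    def det(self):
        # det(P) det(A) = det(L) det(U) with det(L) = 1, and every row swap
        # flips the sign of det(P).
        determinant = np.prod(np.diag(self.U()))
        if self.number_of_permutations % 2 != 0:
            determinant *= -1
        return determinant

    def inverse(self):
        # Row i of solve(I) solves A x = e_i, i.e. it is column i of A^{-1},
        # so transpose to assemble the inverse.
        return self.solve(np.eye(self.n)).T
```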
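
One subtlety in `decomp`: when a zero pivot is swapped out after the first column, the multipliers already stored in $\mathbf{L}$ for the two exchanged rows must be exchanged too; swapping only the rows of $\mathbf{U}$ and $\mathbf{P}$ leaves $\mathbf{PA} \ne \mathbf{LU}$. The notebook's example matrix only swaps at column 0, where $\mathbf{L}$ holds no multipliers yet, so it cannot reveal the problem. The plain-NumPy sketch below (a hypothetical `plu_zero_pivot` that mirrors the zero-pivot strategy rather than using `Matrix`) uses a matrix that needs a swap at column 1.

```python
import numpy as np


def plu_zero_pivot(A, swap_L=True):
    """PLU that pivots only on exactly-zero diagonal entries (hypothetical).

    swap_L=False skips exchanging the multipliers already stored in L,
    which reproduces the bug described above.
    """
    U = np.array(A, dtype=np.double)
    n = U.shape[0]
    P, L = np.eye(n), np.eye(n)
    for i in range(n):
        if U[i, i] == 0:
            nonzero = np.nonzero(U[i + 1:, i])[0]
            if nonzero.size == 0:
                raise ValueError("matrix is singular")
            r = i + 1 + nonzero[0]  # first usable row below i
            U[[i, r]] = U[[r, i]]
            P[[i, r]] = P[[r, i]]
            if swap_L:
                L[[i, r], :i] = L[[r, i], :i]
        for j in range(i + 1, n):
            L[j, i] = U[j, i] / U[i, i]
            U[j] -= L[j, i] * U[i]
    return P, L, U


A = np.array([[1, 2, 3], [2, 4, 5], [1, 3, 4]])  # zero pivot appears at (1, 1)
for swap_L in (True, False):
    P, L, U = plu_zero_pivot(A, swap_L)
    print(swap_L, np.allclose(P @ A, L @ U))
# True True
# False False
```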

```python
A = np.array([[0, 3, 4], [4, 6, 10], [22, 1, 7]])
sol = LU(A)
sol.inverse()
```

```
array([[ 0.5     , -0.265625,  0.09375 ],
       [ 3.      , -1.375   ,  0.25    ],
       [-2.      ,  1.03125 , -0.1875  ]])
```

```python
np.linalg.inv(A)
```

```
array([[ 0.5     , -0.265625,  0.09375 ],
       [ 3.      , -1.375   ,  0.25    ],
       [-2.      ,  1.03125 , -0.1875  ]])
```