## LU Factorization

Given a square matrix $A$, we want to decompose it into a lower-triangular factor $L$ and an upper-triangular factor $U$:
$$
A = L U
$$

### ✅ Method 1: Inverse of Product

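This is Gaussian elimination without pivoting. Start from $U = A$ and $L = I$. For each pivot column $i$ and each row $j > i$, eliminate the entry below the pivot:
$$
U_{j,:} \leftarrow U_{j,:} - \frac{U_{ji}}{U_{ii}} U_{i,:}
$$
and store the multiplier $U_{ji}/U_{ii}$ (taken before the update) in $L_{ji}$. Every pivot $U_{ii}$ must be nonzero.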
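For the $2 \times 2$ matrix used in the code below, there is a single elimination step ($i = 1$, $j = 2$ in 1-based indexing, which the code prints as `U[1,0]`):
$$
A = \begin{pmatrix} 4 & 3 \\ 6 & 3 \end{pmatrix}, \qquad
L_{21} = \frac{U_{21}}{U_{11}} = \frac{6}{4} = 1.5,
$$
$$
U_{2,:} \leftarrow (6,\ 3) - 1.5\,(4,\ 3) = (0,\ -1.5),
$$
$$
L = \begin{pmatrix} 1 & 0 \\ 1.5 & 1 \end{pmatrix}, \qquad
U = \begin{pmatrix} 4 & 3 \\ 0 & -1.5 \end{pmatrix}, \qquad
LU = \begin{pmatrix} 4 & 3 \\ 6 & 3 \end{pmatrix} = A.
$$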

Final result:
$$
A = L U
$$

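Check, assuming $A$ is invertible (the inverse of a product reverses the order of its factors, which gives this method its name):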
$$
A^{-1} = U^{-1} L^{-1}
$$
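In practice the inverse is rarely formed explicitly. Because $A^{-1} = U^{-1} L^{-1}$, solving $Ax = b$ reduces to a forward substitution $Ly = b$ followed by a back substitution $Ux = y$. A minimal sketch using `jax.scipy.linalg.solve_triangular`, with the factors of the worked $2 \times 2$ example hard-coded; the helper name `lu_solve_sketch` is hypothetical:

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def lu_solve_sketch(L, U, b):
    """Solve (L @ U) x = b with two triangular solves instead of an inverse."""
    # Forward substitution: L y = b (L is unit lower triangular)
    y = solve_triangular(L, b, lower=True, unit_diagonal=True)
    # Back substitution: U x = y (U is upper triangular)
    return solve_triangular(U, y, lower=False)

# Factors of A = [[4, 3], [6, 3]] from the worked example
L = jnp.array([[1.0, 0.0], [1.5, 1.0]])
U = jnp.array([[4.0, 3.0], [0.0, -1.5]])
b = jnp.array([1.0, 0.0])

x = lu_solve_sketch(L, U, b)
print(x)  # first column of A^{-1}, approximately [-0.5, 1.0]
```

With $b = e_1$ the result is the first column of $A^{-1}$, so it can be compared directly with the inverse printed further down.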

---

### ✅ Method 2: Transpose of Product

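1. Run the same forward elimination on $A^\top$, which gives $A^\top = L_T U_T$ with $L_T$ unit lower triangular and $U_T$ upper triangular.
2. Transpose back, using $(XY)^\top = Y^\top X^\top$ (the transpose of a product):
   $$
   A = U_T^\top L_T^\top,
   $$
   so $L = U_T^\top$ and $U = L_T^\top$.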

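Final result, with the pivots on the diagonal of $L$ and a unit diagonal on $U$ (the Crout form, so the factors generally differ from those of Method 1 even though both products equal $A$):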
$$
A = L U
$$
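The two results are related by a diagonal rescaling. If $A = L_D U_D$ is the Method 1 (Doolittle) factorization and $D = \operatorname{diag}(U_D)$ holds the pivots, then $A = (L_D D)(D^{-1} U_D)$. When all pivots are nonzero, an LU factorization with a prescribed unit diagonal is unique, so this must be the Method 2 result. The sketch below checks that on the example matrix, with the Method 1 factors hard-coded from the printed output; `doolittle_to_crout` is a hypothetical helper name:

```python
import jax.numpy as jnp

def doolittle_to_crout(L, U):
    """Move the pivots from the diagonal of U onto the diagonal of L."""
    d = jnp.diag(U)             # pivots U[i, i], assumed nonzero
    L_crout = L * d[None, :]    # L @ diag(d): scale column i of L by d[i]
    U_crout = U / d[:, None]    # diag(d)^{-1} @ U: scale row i of U by 1/d[i]
    return L_crout, U_crout

# Method 1 (Doolittle) factors of A = [[4, 3], [6, 3]]
L_doolittle = jnp.array([[1.0, 0.0], [1.5, 1.0]])
U_doolittle = jnp.array([[4.0, 3.0], [0.0, -1.5]])

L_crout, U_crout = doolittle_to_crout(L_doolittle, U_doolittle)

# Method 2 (Crout) factors, as printed by the transpose method
L_expected = jnp.array([[4.0, 0.0], [6.0, -1.5]])
U_expected = jnp.array([[1.0, 0.75], [0.0, 1.0]])
print(jnp.allclose(L_crout, L_expected), jnp.allclose(U_crout, U_expected))  # True True
```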


## Method 1

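```python
import jax.numpy as jnp

def lu_factorization_inverse_method(A):
    """Doolittle LU factorization without pivoting, printing each step.

    Returns L (unit lower triangular) and U (upper triangular) with A = L @ U.
    Assumes a square float matrix whose pivots U[i, i] are all nonzero.
    """
    A = A.copy()
    n = A.shape[0]

    L = jnp.eye(n)  # unit diagonal; multipliers are stored below it
    U = A.copy()    # reduced step by step to upper-triangular form

    print("Initial Matrix A:")
    print(A, "\n")

    # Forward elimination: zero out every entry below each pivot U[i, i]
    for i in range(n):
        for j in range(i+1, n):
            factor = U[j, i] / U[i, i]
            # JAX arrays are immutable: .at[...].set(...) returns an updated copy
            U = U.at[j].set(U[j] - factor * U[i])
            L = L.at[j, i].set(factor)

            print(f"Eliminating element U[{j},{i}] using factor {factor}:")
            print("U:")
            print(U)
            print("L:")
            print(L, "\n")

    print("Final L:")
    print(L)
    print("Final U:")
    print(U)

    # Check A = L @ U
    print("\nReconstructed A = L @ U:")
    print(L @ U)

    # Inverse check (optional): (L U)^{-1} = U^{-1} L^{-1}
    A_inv = jnp.linalg.inv(A)
    L_inv = jnp.linalg.inv(L)
    U_inv = jnp.linalg.inv(U)
    print("\nCheck A^{-1} = U^{-1} @ L^{-1}:")
    print(U_inv @ L_inv)
    print("\nDirect A^{-1}:")
    print(A_inv)

    return L, U
```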
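This function (like the transpose method below) never pivots, so it divides by zero whenever a pivot $U_{ii}$ vanishes, even for an invertible matrix such as $\begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}$. Library routines use partial pivoting and also return a permutation. A minimal sketch with `jax.scipy.linalg.lu`, which follows SciPy's convention of returning `P, L, U` with `A = P @ L @ U`:

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

# Invertible, but U[0, 0] = 0 stops elimination without row swaps
A_bad = jnp.array([[0.0, 1.0],
                   [1.0, 0.0]])
print(A_bad[1, 0] / A_bad[0, 0])  # inf: the no-pivoting multiplier blows up

# Partial pivoting swaps rows so the largest-magnitude entry becomes the pivot
P, L, U = lu(A_bad)
print(jnp.allclose(P @ L @ U, A_bad))  # True

# The notebook's example matrix is also row-swapped, since |6| > |4|
A = jnp.array([[4.0, 3.0],
               [6.0, 3.0]])
P, L, U = lu(A)
print(P)
print(L)
print(U)
```

Because of the row swap, the pivoted factors of the example matrix differ from the ones printed by Method 1, although `P @ L @ U` still reproduces $A$.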


## Method 2

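```python
import jax.numpy as jnp

def lu_factorization_transpose_method(A):
    """LU factorization of A by eliminating on A^T, printing each step.

    Elimination gives A^T = L_T @ U_T, so A = U_T^T @ L_T^T. The returned
    L = U_T^T carries the pivots on its diagonal; U = L_T^T has a unit diagonal.
    Assumes a square float matrix with nonzero pivots (no pivoting).
    """
    print("Factorizing A^T instead of A")

    A_T = A.T
    print("\nMatrix A^T:")
    print(A_T, "\n")

    L_T = jnp.eye(A.shape[0])
    U_T = A_T.copy()

    # Forward elimination on A^T (same loop as Method 1)
    for i in range(A.shape[0]):
        for j in range(i+1, A.shape[0]):
            factor = U_T[j, i] / U_T[i, i]
            U_T = U_T.at[j].set(U_T[j] - factor * U_T[i])
            L_T = L_T.at[j, i].set(factor)

            print(f"Eliminating element U_T[{j},{i}] using factor {factor}:")
            print("U_T:")
            print(U_T)
            print("L_T:")
            print(L_T, "\n")

    # Transpose back: A = (L_T @ U_T)^T = U_T^T @ L_T^T
    U = L_T.T
    L = U_T.T

    print("Final L (from U_T^T):")
    print(L)
    print("Final U (from L_T^T):")
    print(U)

    # Check A = L @ U
    print("\nReconstructed A = L @ U:")
    print(L @ U)

    return L, U
```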
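The Python loops and `print` calls above are convenient for inspecting each step, but under `jax.jit` such loops are unrolled at trace time, so compilation grows with the matrix size. A sketch of the same Doolittle elimination (still without pivoting) as a compiled loop, assuming `jax.lax.fori_loop` and a rank-1 update that eliminates all rows below a pivot at once; `lu_doolittle_jit` is a hypothetical name, not part of the notebook:

```python
import jax
import jax.numpy as jnp

@jax.jit
def lu_doolittle_jit(A):
    """Doolittle LU without pivoting: A = L @ U with L unit lower triangular.

    Assumes a square floating-point A whose pivots U[i, i] are all nonzero.
    """
    n = A.shape[0]
    rows = jnp.arange(n)

    def eliminate_column(i, carry):
        L, U = carry
        # Multipliers for rows strictly below the pivot; zero for the other rows
        factors = jnp.where(rows > i, U[:, i] / U[i, i], 0.0)
        # Rank-1 update: subtract factors[j] * U[i] from every row j > i
        U = U - jnp.outer(factors, U[i])
        # Store the multipliers in column i of L, below the unit diagonal
        L = L.at[:, i].add(factors)
        return L, U

    L0 = jnp.eye(n, dtype=A.dtype)
    return jax.lax.fori_loop(0, n - 1, eliminate_column, (L0, A))

A = jnp.array([[4.0, 3.0],
               [6.0, 3.0]])
L, U = lu_doolittle_jit(A)
print(L)
print(U)
print(jnp.allclose(L @ U, A))  # True
```

The loop body sees `i` as a traced integer, which is why the row mask uses `jnp.where` instead of Python slicing `U[i+1:]`, whose shape would depend on `i`.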


```python
A = jnp.array([[4.0, 3.0],
               [6.0, 3.0]])

print("\n=== LU Factorization (Inverse of Product Method) ===")
L1, U1 = lu_factorization_inverse_method(A)

print("\n=== LU Factorization (Transpose of Product Method) ===")
L2, U2 = lu_factorization_transpose_method(A)
```



```text
=== LU Factorization (Inverse of Product Method) ===
Initial Matrix A:
[[4. 3.]
 [6. 3.]]

Eliminating element U[1,0] using factor 1.5:
U:
[[ 4.   3. ]
 [ 0.  -1.5]]
L:
[[1.  0. ]
 [1.5 1. ]]

Final L:
[[1.  0. ]
 [1.5 1. ]]
Final U:
[[ 4.   3. ]
 [ 0.  -1.5]]

Reconstructed A = L @ U:
[[4. 3.]
 [6. 3.]]

Check A^{-1} = U^{-1} @ L^{-1}:
[[-0.5        0.5      ]
 [ 1.        -0.6666667]]

Direct A^{-1}:
[[-0.5        0.5      ]
 [ 1.        -0.6666667]]

=== LU Factorization (Transpose of Product Method) ===
Factorizing A^T instead of A

Matrix A^T:
[[4. 6.]
 [3. 3.]]

Eliminating element U_T[1,0] using factor 0.75:
U_T:
[[ 4.   6. ]
 [ 0.  -1.5]]
L_T:
[[1.   0.  ]
 [0.75 1.  ]]

Final L (from U_T^T):
[[ 4.   0. ]
 [ 6.  -1.5]]
Final U (from L_T^T):
[[1.   0.75]
 [0.   1.  ]]

Reconstructed A = L @ U:
[[4. 3.]
 [6. 3.]]
```
