## Gaussian Elimination (Step-by-Step)

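We solve the square linear system $A x = b$, with $A \in \mathbb{R}^{n \times n}$, using **Gaussian elimination**. Indices are 0-based ($0, \dots, n-1$) to match the Python code.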

### Step 1: Form the augmented matrix

Append $b$ to $A$ as an extra column, so that every row operation acts on both at once:
$$
[A | b]
$$
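For the $3 \times 3$ example used in the code below, the augmented matrix is:

$$
[A \mid b] = \left[\begin{array}{ccc|c}
2 & 1 & -1 & 8 \\
-3 & -1 & 2 & -11 \\
-2 & 1 & 2 & -3
\end{array}\right]
$$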

### Step 2: Normalize pivot row and eliminate lower rows

For each pivot $A_{ii}$, $i = 0, \dots, n-1$ (assumed nonzero, since this version performs no row exchanges):
- Normalize row $i$:  
  $$ A_{i,:} \leftarrow \frac{A_{i,:}}{A_{ii}} $$
  $$ b_i \leftarrow \frac{b_i}{A_{ii}} $$

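- For each row $j = i+1, \dots, n-1$ below the pivot, eliminate $A_{ji}$ by subtracting $A_{ji}$ times the normalized pivot row. Here $A_{ji}$ is the value from before the update:  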
  $$ A_{j,:} \leftarrow A_{j,:} - A_{ji} \cdot A_{i,:} $$
  $$ b_j \leftarrow b_j - A_{ji} \cdot b_i $$
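The two updates above can also be applied to every row below the pivot in one vectorized operation. The sketch below is a minimal example and makes these assumptions:

* The augmented matrix is stored as a single array `aug`.
* `pivot_step` is a hypothetical helper name.
* No pivoting is done, so `aug[i, i]` must be nonzero.

```python
import jax.numpy as jnp

def pivot_step(aug, i):
    """Normalize row i of [A | b] and eliminate column i below it.

    Hypothetical helper; assumes aug[i, i] != 0 (no row exchanges).
    """
    aug = aug.at[i].set(aug[i] / aug[i, i])  # pivot becomes 1
    factors = aug[i + 1:, i]                 # A_{ji} for j > i, before the update
    # Row j <- row j - A_{ji} * row i, for all j > i, as one outer-product update.
    aug = aug.at[i + 1:].add(-jnp.outer(factors, aug[i]))
    return aug

aug = jnp.array([[2.0, 1.0, -1.0, 8.0],
                 [-3.0, -1.0, 2.0, -11.0],
                 [-2.0, 1.0, 2.0, -3.0]])
step0 = pivot_step(aug, 0)
print(step0)
```

$b$ is the last column of `aug`, so its updates happen automatically. For $i = 0$, the printed matrix matches the state after "Eliminate row 2 using row 0" in the notebook output.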

### Step 3: Back Substitution

Since every pivot is now $1$, solve for the unknowns from the bottom row upward, $i = n-1, \dots, 0$:
$$
x_i = b_i - \sum_{j=i+1}^{n-1} A_{ij} x_j
$$
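The bottom-up loop can also be written with `jax.lax.fori_loop`, so that it compiles under `jax.jit`. This is a sketch, not part of the original notebook. The names `back_substitute_unit`, `U`, and `c` are illustrative. The inputs are the unit upper-triangular system reached at the end of forward elimination in the example output.

```python
import jax
import jax.numpy as jnp

@jax.jit
def back_substitute_unit(U, c):
    """Solve U x = c for a unit upper-triangular U (sketch, no input checks)."""
    n = c.shape[0]

    def body(k, x):
        i = n - 1 - k  # walk the rows from the bottom up
        # x[j] is still 0 for j <= i, and U[i, j] is 0 for j < i,
        # so U[i] @ x equals sum_{j > i} U[i, j] * x[j].
        return x.at[i].set(c[i] - U[i] @ x)

    return jax.lax.fori_loop(0, n, body, jnp.zeros_like(c))

U = jnp.array([[1.0, 0.5, -0.5],
               [0.0, 1.0, 1.0],
               [0.0, 0.0, 1.0]])
c = jnp.array([4.0, 2.0, -1.0])
x = back_substitute_unit(U, c)
print(x)
```

Using a full-row dot product instead of a slice keeps every array shape fixed, which is what `jax.jit` requires inside the loop.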

---

For the $3 \times 3$ example below, the computed solution is:
$$
x = \begin{bmatrix}
x_0 \\
x_1 \\
x_2
\end{bmatrix}
= \begin{bmatrix}
2 \\
3 \\
-1
\end{bmatrix}
$$
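The steps above divide by $A_{ii}$ without checking it. They therefore break down when a pivot is zero (for example, when $A_{00} = 0$) and lose accuracy when it is tiny. A common remedy, not used in the notebook code below, is *partial pivoting*: before normalizing row $i$, swap it with the row at or below $i$ whose entry in column $i$ has the largest magnitude.

In the sketch below, `solve_partial_pivoting` is a hypothetical name and the test system was made up for illustration:

```python
import jax.numpy as jnp

def solve_partial_pivoting(A, b):
    """Gaussian elimination with partial pivoting (illustrative sketch)."""
    A = jnp.asarray(A, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)
    n = A.shape[0]

    for i in range(n):
        # Choose the row at or below i with the largest |A[j, i]|.
        p = i + int(jnp.argmax(jnp.abs(A[i:, i])))
        if A[p, i] == 0:
            raise ValueError("Matrix is singular.")
        # Swap rows i and p by permuting the row indices.
        perm = jnp.arange(n).at[i].set(p).at[p].set(i)
        A, b = A[perm], b[perm]

        # Normalize the pivot row (b first, while A[i, i] is unchanged).
        b = b.at[i].set(b[i] / A[i, i])
        A = A.at[i].set(A[i] / A[i, i])
        # Eliminate column i from all rows below the pivot.
        factors = A[i + 1:, i]
        A = A.at[i + 1:].add(-jnp.outer(factors, A[i]))
        b = b.at[i + 1:].add(-factors * b[i])

    # Back substitution; the diagonal is all ones.
    x = jnp.zeros(n, dtype=A.dtype)
    for i in reversed(range(n)):
        x = x.at[i].set(b[i] - A[i, i + 1:] @ x[i + 1:])
    return x

# A[0, 0] = 0, so elimination without row exchanges cannot use it as a pivot.
A = jnp.array([[0.0, 1.0, 1.0],
               [1.0, 2.0, 1.0],
               [2.0, 1.0, 1.0]])
b = jnp.array([5.0, 8.0, 7.0])
x = solve_partial_pivoting(A, b)
print(x)  # approximately [1, 2, 3]
```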


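```python
import jax.numpy as jnp

def gaussian_elimination(A, b):
    """Solve A x = b by Gaussian elimination without pivoting, printing each step."""
    # Work on float arrays so integer inputs are not truncated by the divisions.
    # JAX arrays are immutable: .at[...].set(...) returns an updated copy.
    A = jnp.asarray(A, dtype=jnp.float32)
    b = jnp.asarray(b, dtype=jnp.float32)
    n = A.shape[0]

    print("Initial Augmented Matrix [A | b]:")
    print(jnp.hstack([A, b.reshape(-1, 1)]), "\n")

    # Forward elimination
    for i in range(n):
        # Make the pivot equal to 1 by dividing row i by the pivot element.
        pivot = A[i, i]
        if pivot == 0:
            raise ZeroDivisionError(f"Zero pivot in row {i}; row exchanges (pivoting) are required.")
        A = A.at[i].set(A[i] / pivot)
        b = b.at[i].set(b[i] / pivot)

        print(f"Step {i+1}: Normalize row {i}:")
        print(jnp.hstack([A, b.reshape(-1, 1)]), "\n")

        # Eliminate the entries below the pivot.
        for j in range(i + 1, n):
            factor = A[j, i]
            A = A.at[j].set(A[j] - factor * A[i])
            b = b.at[j].set(b[j] - factor * b[i])

            print(f"Eliminate row {j} using row {i}:")
            print(jnp.hstack([A, b.reshape(-1, 1)]), "\n")

    # Back substitution (the diagonal entries are 1, so no division is needed)
    x = jnp.zeros(n, dtype=A.dtype)
    for i in reversed(range(n)):
        x = x.at[i].set(b[i] - jnp.sum(A[i, i + 1:] * x[i + 1:]))

    return x
```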


```python
# Example system: 3x3
A = jnp.array([[2.0, 1.0, -1.0],
               [-3.0, -1.0, 2.0],
               [-2.0, 1.0, 2.0]])

b = jnp.array([8.0, -11.0, -3.0])

solution = gaussian_elimination(A, b)
print("Final Solution x:", solution)

# Verify Ax ≈ b
print("\nVerification A @ x:", A @ solution)
print("Original b:", b)
```


```text
Initial Augmented Matrix [A | b]:
[[  2.   1.  -1.   8.]
 [ -3.  -1.   2. -11.]
 [ -2.   1.   2.  -3.]] 

Step 1: Normalize row 0:
[[  1.    0.5  -0.5   4. ]
 [ -3.   -1.    2.  -11. ]
 [ -2.    1.    2.   -3. ]] 

Eliminate row 1 using row 0:
[[ 1.   0.5 -0.5  4. ]
 [ 0.   0.5  0.5  1. ]
 [-2.   1.   2.  -3. ]] 

Eliminate row 2 using row 0:
[[ 1.   0.5 -0.5  4. ]
 [ 0.   0.5  0.5  1. ]
 [ 0.   2.   1.   5. ]] 

Step 2: Normalize row 1:
[[ 1.   0.5 -0.5  4. ]
 [ 0.   1.   1.   2. ]
 [ 0.   2.   1.   5. ]] 

Eliminate row 2 using row 1:
[[ 1.   0.5 -0.5  4. ]
 [ 0.   1.   1.   2. ]
 [ 0.   0.  -1.   1. ]] 

Step 3: Normalize row 2:
[[ 1.   0.5 -0.5  4. ]
 [ 0.   1.   1.   2. ]
 [-0.  -0.   1.  -1. ]] 

Final Solution x: [ 2.  3. -1.]

Verification A @ x: [  8. -11.  -3.]
Original b: [  8. -11.  -3.]
```

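As an independent check, the same system can be passed to JAX's built-in dense solver `jnp.linalg.solve`. Unlike the notebook's function, it relies on an LU factorization with partial pivoting. Its result should agree with the hand-written version up to floating-point rounding. This comparison is a suggestion, not part of the original notebook:

```python
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0, -1.0],
               [-3.0, -1.0, 2.0],
               [-2.0, 1.0, 2.0]])
b = jnp.array([8.0, -11.0, -3.0])

# Library solver, for comparison with the hand-written elimination.
x_ref = jnp.linalg.solve(A, b)
# Largest absolute entry of the residual A x - b.
residual = jnp.max(jnp.abs(A @ x_ref - b))
print("jnp.linalg.solve:", x_ref)
print("max |A x - b|:", residual)
```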