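## Matrix Multiplication

For $\mathbf{C} = \mathbf{A}\mathbf{B}$, with $\mathbf{A}$ of size $m \times n$ and $\mathbf{B}$ of size $n \times p$, each entry is $C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}$. The same product can be organized by rows or by columns.
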
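### Row-wise multiplication

Row $i$ of $\mathbf{C}$ is row $i$ of $\mathbf{A}$ times all of $\mathbf{B}$, so each row can be computed independently: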
$$
\text{Each row } i: \quad \mathbf{C}_{i,:} = \mathbf{A}_{i,:} \cdot \mathbf{B}
$$

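### Column-wise multiplication

Column $j$ of $\mathbf{C}$ is all of $\mathbf{A}$ times column $j$ of $\mathbf{B}$, so each column can be computed independently: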
$$
\text{Each column } j: \quad \mathbf{C}_{:,j} = \mathbf{A} \cdot \mathbf{B}_{:,j}
$$
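As a quick sanity check of the two views (a sketch with small made-up matrices), the following computes the product from the elementwise definition $C_{ij} = \sum_k A_{ik} B_{kj}$ with `jnp.einsum`, then rebuilds it row by row and column by column. All three results should agree.

```python
import jax.numpy as jnp

# Small illustrative inputs: A is 2x3, B is 3x2
A = jnp.array([[1.0, 2.0, 0.0],
               [0.0, 1.0, 3.0]])
B = jnp.array([[1.0, 4.0],
               [2.0, 5.0],
               [3.0, 6.0]])

# Elementwise definition: C[i, j] = sum_k A[i, k] * B[k, j]
C = jnp.einsum("ik,kj->ij", A, B)

# Row-wise view: row i of C is A[i, :] @ B
row_view = jnp.stack([A[i, :] @ B for i in range(A.shape[0])])

# Column-wise view: column j of C is A @ B[:, j], stacked as columns
col_view = jnp.stack([A @ B[:, j] for j in range(B.shape[1])], axis=1)
```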

---

## Block Matrix Multiplication

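Partition $\mathbf{A}$ and $\mathbf{B}$ into blocks whose sizes conform (the column split of $\mathbf{A}$ matches the row split of $\mathbf{B}$). Each block of $\mathbf{C} = \mathbf{A}\mathbf{B}$ is then a sum of block products, the block analogue of $C_{ij} = \sum_k A_{ik} B_{kj}$:

$$
\mathbf{C}_{IJ} = \sum_{K} \mathbf{A}_{IK} \mathbf{B}_{KJ}
$$

For a $2 \times 2$ partition:

$$
\begin{bmatrix} \mathbf{C}_{11} & \mathbf{C}_{12} \\ \mathbf{C}_{21} & \mathbf{C}_{22} \end{bmatrix}
=
\begin{bmatrix}
\mathbf{A}_{11}\mathbf{B}_{11} + \mathbf{A}_{12}\mathbf{B}_{21} & \mathbf{A}_{11}\mathbf{B}_{12} + \mathbf{A}_{12}\mathbf{B}_{22} \\
\mathbf{A}_{21}\mathbf{B}_{11} + \mathbf{A}_{22}\mathbf{B}_{21} & \mathbf{A}_{21}\mathbf{B}_{12} + \mathbf{A}_{22}\mathbf{B}_{22}
\end{bmatrix}
$$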
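A minimal numerical check of the $2 \times 2$ block case, assuming $4 \times 4$ inputs split into four $2 \times 2$ blocks (sizes chosen only for illustration). Each output block is formed as $\mathbf{C}_{11} = \mathbf{A}_{11}\mathbf{B}_{11} + \mathbf{A}_{12}\mathbf{B}_{21}$ and so on, and `jnp.block` reassembles the blocks into the full product.

```python
import jax.numpy as jnp

# Illustrative 4x4 inputs, each split into four 2x2 blocks
A = jnp.arange(16.0).reshape(4, 4)
B = jnp.arange(16.0, 32.0).reshape(4, 4)

A11, A12, A21, A22 = A[:2, :2], A[:2, 2:], A[2:, :2], A[2:, 2:]
B11, B12, B21, B22 = B[:2, :2], B[:2, 2:], B[2:, :2], B[2:, 2:]

# Each output block is a sum of products of conforming blocks
C11 = A11 @ B11 + A12 @ B21
C12 = A11 @ B12 + A12 @ B22
C21 = A21 @ B11 + A22 @ B21
C22 = A21 @ B12 + A22 @ B22

# Reassemble the four blocks into the full 4x4 result
C = jnp.block([[C11, C12], [C21, C22]])
```

The reassembled `C` should match `A @ B` exactly for these small integer-valued inputs.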

---

## Matrix Inverse

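For a square matrix $\mathbf{A}$, the inverse $\mathbf{A}^{-1}$ exists exactly when $\det \mathbf{A} \neq 0$, and it satisfies

$$
\mathbf{A}^{-1} \mathbf{A} = \mathbf{A} \mathbf{A}^{-1} = \mathbf{I}
$$

It can be computed by Gauss-Jordan elimination on the augmented matrix $[\mathbf{A} \mid \mathbf{I}]$: once the left half has been reduced to $\mathbf{I}$, the right half is $\mathbf{A}^{-1}$.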
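For the $2 \times 2$ case the inverse has a closed form, which gives a concrete check of the definition. This worked example uses the same matrix as `A2` in the code cells:

$$
\begin{bmatrix} a & b \\ c & d \end{bmatrix}^{-1}
= \frac{1}{ad - bc}\begin{bmatrix} d & -b \\ -c & a \end{bmatrix}, \qquad ad - bc \neq 0
$$

$$
\mathbf{A} = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}: \quad
ad - bc = 1 \cdot 4 - 2 \cdot 3 = -2, \quad
\mathbf{A}^{-1} = -\frac{1}{2}\begin{bmatrix} 4 & -2 \\ -3 & 1 \end{bmatrix}
= \begin{bmatrix} -2 & 1 \\ \tfrac{3}{2} & -\tfrac{1}{2} \end{bmatrix}
$$

$$
\mathbf{A}^{-1}\mathbf{A} = \begin{bmatrix} -2 + 3 & -4 + 4 \\ \tfrac{3}{2} - \tfrac{3}{2} & 3 - 2 \end{bmatrix} = \mathbf{I}
$$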

---

## Gauss-Jordan Elimination

To solve $A x = b$, process each pivot index $i = 1, \dots, n$ in order, applying two steps:

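1. Normalize the pivot row so that the pivot becomes 1 (this requires $A_{ii} \neq 0$; no row swaps are performed here):
   $$
   A_{i,:} \leftarrow \frac{A_{i,:}}{A_{ii}}, \quad b_i \leftarrow \frac{b_i}{A_{ii}}
   $$

2. Eliminate column $i$ from every other row $j \neq i$ using the normalized row $i$, where $A_{ji}$ is read before row $j$ is updated:
   $$
   A_{j,:} \leftarrow A_{j,:} - A_{ji} A_{i,:}, \quad b_j \leftarrow b_j - A_{ji} b_i, \quad j \neq i
   $$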

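After every pivot has been processed, the left block is the identity, $[A \mid b] \to [I \mid x]$, so the updated right-hand side $b$ is the solution (shown here for a $3 \times 3$ system):
$$
x = \begin{bmatrix}
x_1 \\
x_2 \\
x_3
\end{bmatrix}
$$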
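The two steps above can be sketched in JAX without printing. This version adds one thing the steps do not include, partial pivoting (swapping in the row with the largest remaining absolute entry in column $i$ before normalizing), so it also copes with a zero on the diagonal. The function name `gauss_jordan_pivoted` is hypothetical, and the sketch assumes a square, nonsingular `A`.

```python
import jax.numpy as jnp

def gauss_jordan_pivoted(A, b):
    """Solve A x = b by Gauss-Jordan elimination with partial pivoting.

    Sketch only: assumes A is square and nonsingular; prints nothing.
    """
    n = A.shape[0]
    # Augmented matrix [A | b]
    aug = jnp.hstack([A, b.reshape(-1, 1)])
    for i in range(n):
        # Partial pivoting: among rows i..n-1, pick the largest |entry| in column i
        p = i + int(jnp.argmax(jnp.abs(aug[i:, i])))
        if p != i:
            row_i, row_p = aug[i], aug[p]
            aug = aug.at[i].set(row_p).at[p].set(row_i)
        # Step 1: normalize the pivot row so the pivot becomes 1
        aug = aug.at[i].set(aug[i] / aug[i, i])
        # Step 2: subtract multiples of row i to zero column i in every other row
        for j in range(n):
            if j != i:
                aug = aug.at[j].set(aug[j] - aug[j, i] * aug[i])
    # The left block is now I, so the last column holds x
    return aug[:, -1]

A = jnp.array([[2.0, 1.0, -1.0],
               [-3.0, -1.0, 2.0],
               [-2.0, 1.0, 2.0]])
b = jnp.array([8.0, -11.0, -3.0])
x = gauss_jordan_pivoted(A, b)
```

Pivoting changes the order of row operations compared with the unpivoted steps, but for a nonsingular system it reaches the same solution.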


# Row-wise Multiplication

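```python
import jax.numpy as jnp

def matmul_rowwise(A, B):
    """Compute C = A @ B one row at a time: C[i, :] = A[i, :] @ B."""
    rows, cols = A.shape[0], B.shape[1]
    result = jnp.zeros((rows, cols))
    for i in range(rows):
        # JAX arrays are immutable: .at[i].set(...) returns an updated copy
        result = result.at[i].set(A[i] @ B)
    return result
```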
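Because each output row depends on only one row of `A`, the Python loop can be replaced by `jax.vmap`, which maps a per-row function over axis 0 of its input. A sketch (the name `matmul_rowwise_vmap` is illustrative):

```python
import jax
import jax.numpy as jnp

def matmul_rowwise_vmap(A, B):
    # vmap maps over axis 0 of A: each call receives one row and returns row @ B
    return jax.vmap(lambda row: row @ B)(A)

A2 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
B2 = jnp.array([[5.0, 6.0], [7.0, 8.0]])
C_rows = matmul_rowwise_vmap(A2, B2)
```

The vectorized form avoids building the result through repeated `.at[...].set(...)` copies.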


# Column-wise Multiplication

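```python
def matmul_columnwise(A, B):
    """Compute C = A @ B one column at a time: C[:, j] = A @ B[:, j]."""
    rows, cols = A.shape[0], B.shape[1]
    result = jnp.zeros((rows, cols))
    for j in range(cols):
        # Column j of the result depends only on column j of B
        result = result.at[:, j].set(A @ B[:, j])
    return result
```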
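The column-wise loop has the same `jax.vmap` counterpart: `in_axes=1` maps over the columns of `B`, and `out_axes=1` stacks the results back as columns. A sketch (the name `matmul_columnwise_vmap` is illustrative):

```python
import jax
import jax.numpy as jnp

def matmul_columnwise_vmap(A, B):
    # in_axes=1: each call gets one column of B; out_axes=1: results become columns of C
    return jax.vmap(lambda col: A @ col, in_axes=1, out_axes=1)(B)

A2 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
B2 = jnp.array([[5.0, 6.0], [7.0, 8.0]])
C_cols = matmul_columnwise_vmap(A2, B2)
```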


# Block Matrix Multiplication

```python
def block_matrix_multiplication(A, B):
    # Placeholder: returns the full product directly. A blocked algorithm
    # partitions A and B into tiles, but it produces this same result.
    return A @ B
```
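The function above returns `A @ B` directly rather than partitioning anything. A minimal sketch of an actual blocked (tiled) product follows, implementing $\mathbf{C}_{IJ} = \sum_K \mathbf{A}_{IK}\mathbf{B}_{KJ}$ with square tiles. The name `block_matmul` and the tile-size parameter `bs` are illustrative assumptions; in JAX the tiling here is for exposition only, since `A @ B` already compiles to an optimized XLA matrix multiply and is likely to be faster.

```python
import jax.numpy as jnp

def block_matmul(A, B, bs=2):
    """Blocked product: C[I, J] = sum over K of A[I, K] @ B[K, J].

    `bs` is an illustrative tile size; tiles at the edges may be smaller.
    """
    m, n = A.shape
    n2, p = B.shape
    assert n == n2, "inner dimensions must match"
    C = jnp.zeros((m, p), dtype=jnp.result_type(A, B))
    for i in range(0, m, bs):              # row blocks of C
        for j in range(0, p, bs):          # column blocks of C
            acc = jnp.zeros((min(bs, m - i), min(bs, p - j)), dtype=C.dtype)
            for k in range(0, n, bs):      # shared inner blocks
                acc = acc + A[i:i + bs, k:k + bs] @ B[k:k + bs, j:j + bs]
            C = C.at[i:i + bs, j:j + bs].set(acc)
    return C

A = jnp.arange(15.0).reshape(5, 3)
B = jnp.arange(12.0).reshape(3, 4)
C = block_matmul(A, B, bs=2)
```

Slices that run past the end of an axis are clipped, so dimensions that are not multiples of `bs` are handled by smaller edge tiles.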

# Matrix Inverse

```python
import jax.numpy as jnp

def matrix_inverse(A):
    """Invert a square matrix by Gauss-Jordan elimination on [A | I].

    No row swaps (pivoting) are performed, so a zero on the diagonal
    yields inf/nan values even when A is invertible.
    """
    n = A.shape[0]
    I = jnp.eye(n)

    # Augmented matrix [A | I]
    aug = jnp.hstack([A, I])

    print("Initial Augmented Matrix [A | I]:")
    print(aug, "\n")

    for i in range(n):
        pivot = aug[i, i]
        # Normalize the pivot row so that aug[i, i] == 1
        aug = aug.at[i].set(aug[i] / pivot)
        print(f"Step {i+1}: Normalize row {i} by dividing by pivot {pivot}:")
        print(aug, "\n")

        # Eliminate all other entries in column i
        for j in range(n):
            if j != i:
                factor = aug[j, i]
                aug = aug.at[j].set(aug[j] - factor * aug[i])
                print(f"Eliminate element at position ({j},{i}) using row {i}:")
                print(aug, "\n")

    # The left half is now I, so the right half is A^{-1}
    A_inv = aug[:, n:]
    return A_inv
```
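Column $j$ of $\mathbf{A}^{-1}$ is the solution of $\mathbf{A}\mathbf{x} = \mathbf{e}_j$, which is why eliminating on $[\mathbf{A} \mid \mathbf{I}]$ produces the inverse: it solves all $n$ systems at once. The sketch below makes that explicit, swapping in the library solver `jnp.linalg.solve` (not the step-by-step elimination above) for each column. The name `inverse_by_columns` is hypothetical.

```python
import jax.numpy as jnp

def inverse_by_columns(A):
    """Build A^{-1} one column at a time: column j solves A x = e_j."""
    n = A.shape[0]
    I = jnp.eye(n, dtype=A.dtype)
    # Solve one system per column of the identity, then stack as columns
    cols = [jnp.linalg.solve(A, I[:, j]) for j in range(n)]
    return jnp.stack(cols, axis=1)

A2 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
A2_inv = inverse_by_columns(A2)
```

The result can be compared against `jnp.linalg.inv(A2)` or checked directly with `A2_inv @ A2`, which should be close to the identity.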


# Gauss-Jordan Elimination (with Step-by-Step Print)

```python
def gauss_jordan_elimination(A, b):
    """Solve A x = b by Gauss-Jordan elimination, printing every step.

    No pivoting is performed: a zero on the diagonal yields inf/nan values.
    """
    n = A.shape[0]

    print("Initial Augmented Matrix [A | b]:")
    print(jnp.hstack([A, b.reshape(-1, 1)]), "\n")

    # For each pivot: normalize its row, then eliminate its column from every other row
    for i in range(n):
        pivot = A[i, i]
        A = A.at[i].set(A[i] / pivot)
        b = b.at[i].set(b[i] / pivot)
        print(f"Normalize row {i}:")
        print(jnp.hstack([A, b.reshape(-1, 1)]), "\n")

        for j in range(n):
            if j != i:
                factor = A[j, i]  # read before row j is updated
                A = A.at[j].set(A[j] - factor * A[i])
                b = b.at[j].set(b[j] - factor * b[i])
                print(f"Eliminate element A[{j},{i}] using row {i}:")
                print(jnp.hstack([A, b.reshape(-1, 1)]), "\n")

    # A is now the identity, so b holds the solution x
    return b
```
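A simple way to check the result of `gauss_jordan_elimination` is the residual $\lVert A x - b \rVert$, which should be close to zero for a correct solution. A sketch on the same example system, comparing the solution reached in the printed steps, $x = (2, 3, -1)$, against `jnp.linalg.solve` (the helper `residual_norm` is hypothetical):

```python
import jax.numpy as jnp

def residual_norm(A, x, b):
    # Euclidean norm of A x - b; near zero when x solves the system
    return jnp.linalg.norm(A @ x - b)

A = jnp.array([[2.0, 1.0, -1.0],
               [-3.0, -1.0, 2.0],
               [-2.0, 1.0, 2.0]])
b = jnp.array([8.0, -11.0, -3.0])

x_ref = jnp.linalg.solve(A, b)            # library solution for comparison
x_expected = jnp.array([2.0, 3.0, -1.0])  # solution reached by the elimination steps
```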


```python
# Example matrices
A = jnp.array([[2.0, 1.0, -1.0],
               [-3.0, -1.0, 2.0],
               [-2.0, 1.0, 2.0]])

b = jnp.array([8.0, -11.0, -3.0])

print("=== Gauss-Jordan Elimination ===")
solution = gauss_jordan_elimination(A, b)
print("Final Solution x:", solution)

print("\n=== Matrix Multiplication ===")
A2 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
B2 = jnp.array([[5.0, 6.0], [7.0, 8.0]])

rowwise = matmul_rowwise(A2, B2)
columnwise = matmul_columnwise(A2, B2)
block_mult = block_matrix_multiplication(A2, B2)
print("\n=== Matrix Inverse ===")
inverse = matrix_inverse(A2)
print("Inverse of A2:\n", inverse)

print("Row-wise multiplication result:\n", rowwise)
print("Column-wise multiplication result:\n", columnwise)
print("Block matrix multiplication result:\n", block_mult)
```



=== Gauss-Jordan Elimination ===
Initial Augmented Matrix [A | b]:
[[  2.   1.  -1.   8.]
 [ -3.  -1.   2. -11.]
 [ -2.   1.   2.  -3.]] 

Normalize row 0:
[[  1.    0.5  -0.5   4. ]
 [ -3.   -1.    2.  -11. ]
 [ -2.    1.    2.   -3. ]] 

Eliminate element A[1,0] using row 0:
[[ 1.   0.5 -0.5  4. ]
 [ 0.   0.5  0.5  1. ]
 [-2.   1.   2.  -3. ]] 

Eliminate element A[2,0] using row 0:
[[ 1.   0.5 -0.5  4. ]
 [ 0.   0.5  0.5  1. ]
 [ 0.   2.   1.   5. ]] 

Normalize row 1:
[[ 1.   0.5 -0.5  4. ]
 [ 0.   1.   1.   2. ]
 [ 0.   2.   1.   5. ]] 

Eliminate element A[0,1] using row 1:
[[ 1.  0. -1.  3.]
 [ 0.  1.  1.  2.]
 [ 0.  2.  1.  5.]] 

Eliminate element A[2,1] using row 1:
[[ 1.  0. -1.  3.]
 [ 0.  1.  1.  2.]
 [ 0.  0. -1.  1.]] 

Normalize row 2:
[[ 1.  0. -1.  3.]
 [ 0.  1.  1.  2.]
 [-0. -0.  1. -1.]] 

Eliminate element A[0,2] using row 2:
[[ 1.  0.  0.  2.]
 [ 0.  1.  1.  2.]
 [-0. -0.  1. -1.]] 

Eliminate element A[1,2] using row 2:
[[ 1.  0.  0.  2.]
 [ 0.  1.  0.  3.]
 [-0