# Matrix Multiplication 

Mathematically, let **A** be an *m × n* matrix and **B** be an *n × p* matrix. The product is only defined when the number of columns of **A** equals the number of rows of **B**:

$$
\mathbf{A} = \begin{pmatrix}
 a_{11} & a_{12} & \cdots & a_{1n} \\
 a_{21} & a_{22} & \cdots & a_{2n} \\
 \vdots & \vdots & \ddots & \vdots \\
 a_{m1} & a_{m2} & \cdots & a_{mn} \\
\end{pmatrix},\quad
\mathbf{B} = \begin{pmatrix}
 b_{11} & b_{12} & \cdots & b_{1p} \\
 b_{21} & b_{22} & \cdots & b_{2p} \\
 \vdots & \vdots & \ddots & \vdots \\
 b_{n1} & b_{n2} & \cdots & b_{np} \\
\end{pmatrix}
$$

Then, the matrix product **C = AB**, where **C** is an *m × p* matrix, is defined as:

$$
\mathbf{C} = \begin{pmatrix}
 c_{11} & c_{12} & \cdots & c_{1p} \\
 c_{21} & c_{22} & \cdots & c_{2p} \\
 \vdots & \vdots & \ddots & \vdots \\
 c_{m1} & c_{m2} & \cdots & c_{mp} \\
\end{pmatrix}
$$

where each element is calculated by:

$$
c_{ij} = \sum_{k=1}^{n} a_{ik} \cdot b_{kj}
$$

for all $i = 1, 2, \dots, m$ and $j = 1, 2, \dots, p$.
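To make the summation concrete, here is a minimal pure-Python sketch that evaluates $c_{ij}$ exactly as written above, one entry at a time. The helper names `entry` and `product` and the `Matrix` alias are hypothetical, chosen only for illustration. Python indexes from 0, so the 1-based entry $c_{11}$ corresponds to `C[0][0]`.

```python
from typing import List

Matrix = List[List[float]]  # hypothetical alias: a matrix as a list of rows

def entry(a: Matrix, b: Matrix, i: int, j: int) -> float:
    """Hypothetical helper: c_ij = sum over k of a[i][k] * b[k][j] (0-based indices)."""
    n = len(b)  # rows of B, which must equal the number of columns of A
    return sum(a[i][k] * b[k][j] for k in range(n))

def product(a: Matrix, b: Matrix) -> Matrix:
    """Hypothetical helper: build C = AB entry by entry; assumes a is m x n and b is n x p."""
    m, p = len(a), len(b[0])
    return [[entry(a, b, i, j) for j in range(p)] for i in range(m)]

A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
C = product(A, B)
print(C)  # [[19, 22], [43, 50]]
```

For instance, $c_{11} = 1 \cdot 5 + 2 \cdot 7 = 19$ and $c_{21} = 3 \cdot 5 + 4 \cdot 7 = 43$.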


```python
#| default_exp matrix_multiply
```

```python
#| export
from typing import List, Tuple
import numpy as np
import torch
```

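```python
#| exports
def matrix_multiply(a: List[List[float]], # input matrix of size (m, n)
                    b: List[List[float]]  # input matrix of size (n, p)
                    ) -> List[List[float]]: # output matrix of size (m, p)
    "Multiply `a` (m x n) by `b` (n x p) with the naive triple loop over c_ij = sum_k a_ik * b_kj."
    a = np.array(a)
    b = np.array(b)
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError("Both inputs must be 2-D matrices")
    ar, ac = a.shape  # m, n
    br, bc = b.shape  # n, p
    if ac != br:
        raise ValueError(f"Incompatible shapes for matrix multiplication: {a.shape} and {b.shape}")
    result = np.zeros((ar, bc))
    for i in range(ar):          # each row of A
        for j in range(bc):      # each column of B
            for k in range(ac):  # shared dimension (ac == br)
                result[i, j] += a[i, k] * b[k, j]
    return result.tolist()
```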
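The triple loop spells out the definition, but NumPy can evaluate the same sum in compiled code. The sketch below is a minimal cross-check, assuming small float arrays (the variable names are illustrative): `np.einsum` with the subscripts `"ik,kj->ij"` is a direct transcription of $c_{ij} = \sum_k a_{ik} b_{kj}$, the broadcasting version forms every product $a_{ik} b_{kj}$ and sums over $k$, and the `@` operator is NumPy's built-in matrix product.

```python
import numpy as np

a = np.array([[1.0, 0.0, 2.0],
              [0.0, 1.0, 0.0]])  # m x n = 2 x 3
b = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])       # n x p = 3 x 2

# einsum: the subscripts read like the definition, summing over the repeated index k.
c_einsum = np.einsum("ik,kj->ij", a, b)

# Broadcasting: (m, n, 1) * (1, n, p) -> (m, n, p) array of a_ik * b_kj, then sum over k (axis 1).
c_broadcast = (a[:, :, None] * b[None, :, :]).sum(axis=1)

# NumPy's matrix product operator.
c_matmul = a @ b

print(c_einsum)
print(c_broadcast)
print(c_matmul)
```

All three produce $\begin{pmatrix} 11 & 14 \\ 3 & 4 \end{pmatrix}$, and `matrix_multiply(a.tolist(), b.tolist())` should return the same values as a nested list, which makes `a @ b` a convenient reference when testing the loop version. The broadcasting form is only for illustration: it materializes an *m × n × p* intermediate array, so `@` is the practical choice for large matrices.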