<a href="https://colab.research.google.com/github/tanvisharma/AIChip_Paper_List/blob/master/matmul.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PyTorch Matrix Multiplication Demo
In this notebook, we will compare several ways to perform matrix multiplication using PyTorch:

1. **Naive Python for-loops** (for demonstration purposes, typically very slow)
2. **torch.matmul()**
3. **torch.mm()**
4. **torch.bmm()** (batched matrix multiplication)
5. **@ operator** (syntactic sugar for `torch.matmul`)
6. **torch.einsum()**

In [1]:
import torch
import time
print(f"PyTorch version: {torch.__version__}")

PyTorch version: 2.5.1+cu121


## 1. Naive Matrix Multiplication (Python for loops)
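This is only here to illustrate the concept. You would not use it in practice, because it is far slower than the built-in tensor operations. Each of the n·k·m scalar multiply-adds goes through the Python interpreter and creates new 0-dimensional tensors. The built-in operations instead compute the whole product in optimized native kernels.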

In [2]:
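def naive_matrix_multiplication(A, B):
    """Compute the matrix product A x B using Python for loops."""
    # A: shape (n, k)
    # B: shape (k, m)
    n, k = A.shape
    k_b, m = B.shape
    # Without this check, a B with extra rows would be silently truncated
    assert k == k_b, f"inner dimensions must match, got {k} and {k_b}"
    C = torch.zeros(n, m, dtype=A.dtype, device=A.device)

    for i in range(n):
        for j in range(m):
            # C[i, j] is the dot product of row i of A and column j of B
            s = 0
            for z in range(k):
                s += A[i, z] * B[z, j]
            C[i, j] = s

    return C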

# Example sizes
n, k, m = 4, 5, 3
A = torch.randn(n, k)
B = torch.randn(k, m)

C_naive = naive_matrix_multiplication(A, B)
print("Naive result\n", C_naive)

Naive result
 tensor([[ 7.2291, -1.0633,  1.8808],
        [ 2.7907, -1.6227, -0.3606],
        [-1.1496,  0.6015, -0.9736],
        [ 1.6018, -0.3654,  0.5207]])
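It helps to look at a middle ground between the Python loops and a single library call. The sketch below removes the loops by broadcasting: it forms every product `A[i, z] * B[z, j]` at once in an `(n, k, m)` tensor and then sums over `z`. The function name `broadcast_matrix_multiplication` is our own, not a PyTorch API. The intermediate tensor wastes memory, so treat this as an illustration of the formula rather than a practical implementation.

```python
import torch

def broadcast_matrix_multiplication(A, B):
    """Hypothetical helper: compute A @ B by broadcasting instead of Python loops.

    A: (n, k), B: (k, m). The intermediate tensor has shape (n, k, m),
    so memory grows as n*k*m; this is for illustration only.
    """
    # A[:, :, None] is (n, k, 1) and B[None, :, :] is (1, k, m);
    # their product broadcasts to (n, k, m) with entry [i, z, j] = A[i, z] * B[z, j].
    products = A[:, :, None] * B[None, :, :]
    # Summing over the shared index z (dim=1) gives C[i, j].
    return products.sum(dim=1)

torch.manual_seed(0)
A_demo = torch.randn(4, 5)
B_demo = torch.randn(5, 3)
C_broadcast = broadcast_matrix_multiplication(A_demo, B_demo)
print(C_broadcast.shape)  # torch.Size([4, 3])
print(torch.allclose(C_broadcast, A_demo @ B_demo, atol=1e-5))
```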


## 2. torch.matmul()
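`torch.matmul` is the most general-purpose matrix multiplication function in PyTorch. Its behavior depends on how many dimensions its inputs have:

- Two 1D tensors give a dot product.
- Two 2D tensors give a matrix product.
- A 1D and a 2D tensor give a vector-matrix or matrix-vector product.
- Inputs with more than two dimensions are treated as batches of matrices, and their batch dimensions are broadcast.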
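The shape rules described above can be checked with a small sketch. The tensors are random and their sizes are arbitrary choices for illustration:

```python
import torch

torch.manual_seed(0)
v = torch.randn(5)             # 1D vector
w = torch.randn(5)             # 1D vector
M = torch.randn(4, 5)          # 2D matrix
N = torch.randn(5, 3)          # 2D matrix
batch = torch.randn(2, 4, 5)   # 3D: a batch of two (4, 5) matrices

dot = torch.matmul(v, w)          # 1D x 1D -> 0-dim tensor (dot product)
mat_vec = torch.matmul(M, v)      # 2D x 1D -> shape (4,) (matrix-vector product)
vec_mat = torch.matmul(v, N)      # 1D x 2D -> shape (3,) (v acts as a row vector)
batched = torch.matmul(batch, N)  # 3D x 2D -> shape (2, 4, 3); N is broadcast over the batch

print(dot.shape, mat_vec.shape, vec_mat.shape, batched.shape)
# torch.Size([]) torch.Size([4]) torch.Size([3]) torch.Size([2, 4, 3])
```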

In [3]:
C_matmul = torch.matmul(A, B)
print("torch.matmul result\n", C_matmul)

# Let's compare with naive result
difference = torch.sum(torch.abs(C_naive - C_matmul))
print("Difference between naive and torch.matmul:", difference.item())

torch.matmul result
 tensor([[ 7.2291, -1.0633,  1.8808],
        [ 2.7907, -1.6227, -0.3606],
        [-1.1496,  0.6015, -0.9736],
        [ 1.6018, -0.3654,  0.5207]])
Difference between naive and torch.matmul: 1.4007091522216797e-06
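The small nonzero difference is expected. The naive loop and PyTorch's optimized kernel can accumulate the products in a different order, and floating-point addition is not associative. A sum of absolute differences also grows with the matrix size, so a tolerance-based check is usually a better test. The sketch below compares a float32 result against a float64 reference using `torch.allclose` and `torch.testing.assert_close`. The tolerance values passed to `allclose` are illustrative choices, not recommended defaults.

```python
import torch

torch.manual_seed(0)
A_cmp = torch.randn(4, 5)
B_cmp = torch.randn(5, 3)

# Reference computed in float64, then cast back to float32 for comparison.
reference = (A_cmp.double() @ B_cmp.double()).float()
result = torch.matmul(A_cmp, B_cmp)

# allclose checks |result - reference| <= atol + rtol * |reference| elementwise.
print(torch.allclose(result, reference, rtol=1e-5, atol=1e-5))

# assert_close picks dtype-based default tolerances and raises an
# AssertionError with a mismatch report if the tensors differ too much.
torch.testing.assert_close(result, reference)
```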


## 3. torch.mm()
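`torch.mm` only works for 2D matrices. It does not broadcast and does not accept batch dimensions, which makes it more specialized than `torch.matmul`. For two 2D inputs it gives the same result as `torch.matmul`, as the zero difference below shows.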
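Here is a quick sketch of that difference, using an arbitrary batch of two `(4, 5)` matrices. `torch.matmul` broadcasts the 2D right-hand operand over the batch, while `torch.mm` rejects the 3D input:

```python
import torch

A3 = torch.randn(2, 4, 5)  # 3D: a batch of two (4, 5) matrices
B2 = torch.randn(5, 3)

# torch.matmul broadcasts B2 across the batch dimension.
out = torch.matmul(A3, B2)
print(out.shape)  # torch.Size([2, 4, 3])

# torch.mm accepts only 2D inputs, so the same call raises an error.
try:
    torch.mm(A3, B2)
    mm_failed = False
except RuntimeError as err:
    mm_failed = True
    print("torch.mm raised:", type(err).__name__)
```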

In [4]:
C_mm = torch.mm(A, B)
print("torch.mm result\n", C_mm)

difference_mm = torch.sum(torch.abs(C_mm - C_matmul))
print("Difference between torch.mm and torch.matmul:", difference_mm.item())

torch.mm result
 tensor([[ 7.2291, -1.0633,  1.8808],
        [ 2.7907, -1.6227, -0.3606],
        [-1.1496,  0.6015, -0.9736],
        [ 1.6018, -0.3654,  0.5207]])
Difference between torch.mm and torch.matmul: 0.0


## 4. torch.bmm()
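`torch.bmm` performs a batched matrix-matrix product. It expects two 3D tensors with shapes `(batch_size, n, k)` and `(batch_size, k, m)` and returns a tensor of shape `(batch_size, n, m)`. Unlike `torch.matmul`, it does not broadcast, so both inputs must have the same batch size.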

We'll create a batched version of our matrices and multiply them.

In [5]:
batch_size = 2
n, k, m = 3, 4, 5

# Create batched tensors with shapes (batch_size, n, k) and (batch_size, k, m)
A_batched = torch.randn(batch_size, n, k)
B_batched = torch.randn(batch_size, k, m)

# bmm: (batch_size, n, k) x (batch_size, k, m) -> (batch_size, n, m)
C_batched = torch.bmm(A_batched, B_batched)
print("Shape of C_batched:", C_batched.shape)
print("C_batched\n", C_batched)

Shape of C_batched: torch.Size([2, 3, 5])
C_batched
 tensor([[[-1.3760,  1.9335, -0.0937, -2.3668, -4.4848],
         [-0.5382, -1.2826,  1.7685, -1.5827,  4.2771],
         [ 1.0298, -0.2790, -0.3781,  1.8991,  1.9159]],

        [[ 0.0708, -1.3645,  0.0829,  0.9073,  1.0164],
         [-1.6293, -2.5508, -3.6004, -1.9828,  0.8957],
         [-0.9734, -3.0769,  1.3874,  3.3635, -0.1442]]])
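As a sanity check on what `torch.bmm` computes, the sketch below compares it with a Python loop of `torch.mm` calls over the batch. Because `bmm` will not broadcast, the sketch also shows one way to reuse a single `(k, m)` matrix for every batch element: expand it explicitly with `Tensor.expand`. The shapes are illustrative.

```python
import torch

torch.manual_seed(0)
A_b = torch.randn(2, 3, 4)  # (batch_size, n, k)
B_b = torch.randn(2, 4, 5)  # (batch_size, k, m)

# bmm matches multiplying each pair of matrices in the batch separately.
C_bmm = torch.bmm(A_b, B_b)
C_loop = torch.stack([torch.mm(A_b[i], B_b[i]) for i in range(A_b.shape[0])])
print(torch.allclose(C_bmm, C_loop, atol=1e-5))

# bmm does not broadcast: a single (k, m) matrix must be expanded to
# (batch_size, k, m) first, whereas torch.matmul broadcasts it automatically.
W = torch.randn(4, 5)
C_expand = torch.bmm(A_b, W.expand(A_b.shape[0], -1, -1))
C_matmul_bcast = torch.matmul(A_b, W)
print(C_expand.shape, torch.allclose(C_expand, C_matmul_bcast, atol=1e-5))
```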


## 5. @ Operator
Python's `@` operator is syntactic sugar for `torch.matmul()` when applied to PyTorch tensors.

In [6]:
# We'll reuse A and B
C_at = A @ B  # same as torch.matmul(A, B)
print("@ operator result\n", C_at)

difference_at = torch.sum(torch.abs(C_matmul - C_at))
print("Difference between @ operator and torch.matmul:", difference_at.item())

@ operator result
 tensor([[ 7.2291, -1.0633,  1.8808],
        [ 2.7907, -1.6227, -0.3606],
        [-1.1496,  0.6015, -0.9736],
        [ 1.6018, -0.3654,  0.5207]])
Difference between @ operator and torch.matmul: 0.0
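`A @ B` is Python's matrix-multiplication operator (PEP 465, Python 3.5+), and it calls the left operand's `__matmul__` method. The sketch below checks that equivalence. It also shows that a chain of `@` operators is evaluated left to right, like `*`, which has the same precedence. The tensors and their sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 5)
Y = torch.randn(5, 3)
Z = torch.randn(3, 2)

# X @ Y is evaluated as X.__matmul__(Y).
print(torch.allclose(X @ Y, X.__matmul__(Y)))

# Chained @ groups left to right: X @ Y @ Z means (X @ Y) @ Z.
left = (X @ Y) @ Z
print((X @ Y @ Z).shape)  # torch.Size([4, 2])
print(torch.allclose(X @ Y @ Z, left))
```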


## 6. torch.einsum()
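`torch.einsum` is very flexible. It describes a tensor operation in Einstein summation notation, where every index that appears in the inputs but not in the output is summed over. A standard product of two 2D matrices is written `ij,jk->ik`, which sums over `j`.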

In [7]:
C_einsum = torch.einsum('ij,jk->ik', A, B)
print("torch.einsum result\n", C_einsum)

difference_einsum = torch.sum(torch.abs(C_matmul - C_einsum))
print("Difference between torch.matmul and torch.einsum:", difference_einsum.item())

torch.einsum result
 tensor([[ 7.2291, -1.0633,  1.8808],
        [ 2.7907, -1.6227, -0.3606],
        [-1.1496,  0.6015, -0.9736],
        [ 1.6018, -0.3654,  0.5207]])
Difference between torch.matmul and torch.einsum: 1.4007091522216797e-06
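In Einstein summation notation, each input gets one letter per dimension. Letters shared across inputs are multiplied together, and letters missing from the output (after `->`) are summed over. The same notation covers the batched product from the `torch.bmm` section and several other common operations, as in this sketch with arbitrary shapes:

```python
import torch

torch.manual_seed(0)
A_e = torch.randn(2, 3, 4)
B_e = torch.randn(2, 4, 5)

# 'b' appears in both inputs and the output, so it is kept as a batch index;
# 'j' is absent from the output, so it is summed over.
C_e = torch.einsum('bij,bjk->bik', A_e, B_e)
print(torch.allclose(C_e, torch.bmm(A_e, B_e), atol=1e-5))

M = torch.randn(3, 3)
print(torch.allclose(torch.einsum('ij->ji', M), M.T))                      # transpose
print(torch.allclose(torch.einsum('ii->', M), torch.trace(M), atol=1e-6))  # trace

u, v = torch.randn(3), torch.randn(4)
print(torch.einsum('i,j->ij', u, v).shape)  # outer product: torch.Size([3, 4])
```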


## Performance Comparison
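We can compare timings with IPython's `%timeit` magic or with simple wall-clock timing. The cells below use `time.time()` around a single call. That gives only a rough estimate, because one measurement is noisy and the first call may include one-time overhead.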

In [13]:
# Let's do a larger matrix for timing tests
n, k, m = 200, 200, 200
A_large = torch.randn(n, k)
B_large = torch.randn(k, m)



In [14]:
# 1. Naive for loops (warning: can be slow for large sizes)
start = time.time()
C_naive_large = naive_matrix_multiplication(A_large, B_large)
end = time.time()
print(f"Naive multiplication took: {end - start:.4f} seconds")

Naive multiplication took: 124.4101 seconds


In [15]:
# 2. torch.matmul()
start = time.time()
C_matmul_large = torch.matmul(A_large, B_large)
end = time.time()
print(f"torch.matmul took: {end - start:.6f} seconds")

torch.matmul took: 0.001682 seconds


In [16]:
# 3. torch.mm()
start = time.time()
C_mm_large = torch.mm(A_large, B_large)
end = time.time()
print(f"torch.mm took: {end - start:.6f} seconds")

torch.mm took: 0.001086 seconds


In [17]:
# 4. @ operator
start = time.time()
C_at_large = A_large @ B_large
end = time.time()
print(f"@ operator took: {end - start:.6f} seconds")

@ operator took: 0.001489 seconds


In [18]:
# 5. torch.einsum()
start = time.time()
C_einsum_large = torch.einsum('ij,jk->ik', A_large, B_large)
end = time.time()
print(f"torch.einsum took: {end - start:.6f} seconds")

torch.einsum took: 0.001968 seconds


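The naive for-loop implementation is dramatically slower than any of the built-in tensor operations. It took about 124 seconds, versus 1 to 2 milliseconds for the built-ins, a gap of more than four orders of magnitude. The PyTorch operations (`matmul`, `mm`, `@`, `einsum`) are close to each other in performance, since all of them rely on optimized matrix-multiply kernels. The sub-millisecond differences in this single `time.time()` measurement are likely dominated by noise rather than real performance differences.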
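PyTorch ships a benchmarking utility, `torch.utils.benchmark`, that gives more reliable numbers than a single `time.time()` measurement. The sketch below times the four built-in variants on 200 x 200 matrices. Its `Timer` runs with a single thread by default, so we pass `num_threads=torch.get_num_threads()` to match the threading of the cells above. The `min_run_time` value is an arbitrary choice, and absolute timings will vary by machine.

```python
import torch
import torch.utils.benchmark as benchmark

A_bench = torch.randn(200, 200)
B_bench = torch.randn(200, 200)

statements = {
    "torch.matmul": "torch.matmul(A, B)",
    "torch.mm": "torch.mm(A, B)",
    "@ operator": "A @ B",
    "torch.einsum": "torch.einsum('ij,jk->ik', A, B)",
}

results = {}
for label, stmt in statements.items():
    timer = benchmark.Timer(
        stmt=stmt,
        globals={"torch": torch, "A": A_bench, "B": B_bench},
        num_threads=torch.get_num_threads(),  # Timer defaults to a single thread
    )
    # blocked_autorange repeats the statement until at least min_run_time seconds elapse
    measurement = timer.blocked_autorange(min_run_time=0.2)
    results[label] = measurement.median  # median time per run, in seconds
    print(f"{label}: median {measurement.median * 1e6:.1f} us")
```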