# EE 508 HW 3 Part 1: Kernel Design

Your task in this Colab notebook is to fill out the sections that are specified by **TODO** (please search the keyword `TODO` to make sure you do not miss any).

Install the `Ninja` package, which is used to build PyTorch kernels in Colab, and import all required packages.

In [1]:
!pip install Ninja

Collecting Ninja
  Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)
Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
Installing collected packages: Ninja
Successfully installed Ninja-1.11.1.3


In [2]:
import torch
import torch.nn.functional as F

from torch.utils.cpp_extension import load_inline
import time

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


## Im2col Algorithm


`Im2col` is a method used in CNNs to rearrange the input data and filters so that convolution can be expressed as a single matrix multiplication. This transformation simplifies the implementation of convolution and lets it use highly optimized matrix multiplication routines, such as those provided by BLAS libraries.

In class and discussion, we covered the `im2col` algorithm for a 2D input with a 2D filter. In this section, we extend and implement the `im2col` algorithm for a 4D input with 4D filters.
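As a reminder of the 2D case, here is a minimal pure-Python sketch (stride 1, no padding). The helper names `im2col_2d` and `conv2d_via_im2col` are hypothetical and exist only for this illustration. Each filter-sized patch of the image becomes one column, so the convolution reduces to a row vector times a matrix.

```python
def im2col_2d(image, kh, kw):
    """Unroll every kh x kw patch of a 2D list into one column (stride 1, no padding)."""
    h, w = len(image), len(image[0])
    oh, ow = h - kh + 1, w - kw + 1
    # One row per filter element, one column per output position.
    cols = [[0] * (oh * ow) for _ in range(kh * kw)]
    for r in range(oh):
        for c in range(ow):
            patch = [image[r + dr][c + dc] for dr in range(kh) for dc in range(kw)]
            for row, value in enumerate(patch):
                cols[row][r * ow + c] = value
    return cols, oh, ow


def conv2d_via_im2col(image, kernel):
    """Cross-correlate image with kernel as a (1 x kh*kw) by (kh*kw x oh*ow) product."""
    kh, kw = len(kernel), len(kernel[0])
    cols, oh, ow = im2col_2d(image, kh, kw)
    flat_kernel = [v for row in kernel for v in row]
    flat_out = [sum(f * cols[i][j] for i, f in enumerate(flat_kernel))
                for j in range(oh * ow)]
    # Reshape the flat row vector back into an oh x ow grid.
    return [flat_out[r * ow:(r + 1) * ow] for r in range(oh)]


image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
kernel = [[1, 0],
          [0, 1]]
result = conv2d_via_im2col(image, kernel)
print(result)  # [[6, 8], [12, 14]]
```

The 4D version follows the same idea, with the channel dimension folded into each column and the batch dimension folded into the column count.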

### **TODO 1**:

Implement the `im2col` operation to transform the 4D input and filter tensors into two 2D matrices. After matrix multiplication, reshape the result back into a 4D tensor to simulate the convolution operation.

In [3]:
def im2col_conv2d(input, filter, stride=1, padding=0):
    """
    Perform a convolution operation using im2col transformation.

    Parameters:
    - input: Tensor of shape (N, C, H, W) -> Input image batch
    - filter: Tensor of shape (K, C, KH, KW) -> Convolution filters
    - stride: Stride of the convolution
    - padding: Padding around the input

    Returns:
    - output_im2col: Tensor of shape (N, K, OH, OW) -> Convolution output
    """

    # Extract dimensions
    N, C, H, W = input.shape
    K, _, KH, KW = filter.shape

    # Pad the input if necessary
    if padding > 0:
        input = F.pad(input, (padding, padding, padding, padding))

    # Compute output height and width
    OH = (H + 2 * padding - KH) // stride + 1
    OW = (W + 2 * padding - KW) // stride + 1

    # Initialize toeplitz tensor (im2col representation), matching the input's dtype and device
    toeplitz = torch.zeros((C*KH*KW, N*OH*OW), dtype=input.dtype, device=input.device)

    # Perform im2col transformation: each receptive field becomes one column
    col = 0
    for i in range(N):
        for j in range(OH):
            for k in range(OW):
                block = input[i, :, j*stride:j*stride+KH, k*stride:k*stride+KW].flatten()
                toeplitz[:, col] = block
                col += 1


    # Reshape filter and toeplitz tensor for matrix multiplication
    filter = filter.reshape(K,C*KH*KW)

    # Compute output by matrix multiplication
    output_im2col = torch.matmul(filter,toeplitz)

    # Reshape output
    output_im2col = output_im2col.reshape([K,N,OH,OW]).transpose(0,1)

    return output_im2col
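The final reshape to `[K, N, OH, OW]` relies on the order in which columns were written: batch index outermost, then output row, then output column. The following sketch makes that ordering explicit. The helper `column_to_position` is a hypothetical name used only here, and the sizes match test case 1 below.

```python
def column_to_position(col, OH, OW):
    """Map a Toeplitz column index back to (batch index, output row, output column)."""
    n, rem = divmod(col, OH * OW)
    oh, ow = divmod(rem, OW)
    return n, oh, ow


N, OH, OW = 4, 3, 3
positions = [column_to_position(col, OH, OW) for col in range(N * OH * OW)]
print(positions[0])   # (0, 0, 0)
print(positions[13])  # (1, 1, 1)
print(positions[35])  # (3, 2, 2)
```

Because the column index decomposes as `(n * OH + oh) * OW + ow`, a `(K, N*OH*OW)` product can be viewed as `(K, N, OH, OW)` and then transposed to `(N, K, OH, OW)`.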

Let's perform some tests. The results returned by `im2col_conv2d` function should match the results returned by PyTorch.

In [4]:
# Test case 1 with stride=1 and padding=0
N, C, H, W = 4, 3, 5, 5  # Input dimensions
K, _, KH, KW = 2, C, 3, 3  # Filter dimensions

# Create random input and filter tensors
torch.manual_seed(508)
input_tensor = torch.randn(N, C, H, W)
filter_tensor = torch.randn(K, C, KH, KW)

# Perform convolution using im2col
my_results = im2col_conv2d(input_tensor, filter_tensor, stride=1, padding=0)

# Perform convolution using PyTorch
pt_results = F.conv2d(input_tensor, filter_tensor, stride=1, padding=0)

# Compare results
print("Results are matched:", torch.allclose(my_results, pt_results, atol=1e-6))

Results are matched: True


In [5]:
# Test case 2 with stride=2 and padding=1
# Perform convolution using im2col
my_results = im2col_conv2d(input_tensor, filter_tensor, stride=2, padding=1)

# Perform convolution using PyTorch
pt_results = F.conv2d(input_tensor, filter_tensor, stride=2, padding=1)

# Compare results
print("Results are matched:", torch.allclose(my_results, pt_results, atol=1e-6))

Results are matched: True
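
As a small check on the shapes involved, the output-size formula used inside `im2col_conv2d` can be evaluated for both test cases. The helper name `conv_output_size` is hypothetical and only illustrates the arithmetic.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# Test case 1: 5x5 input, 3x3 filter, stride 1, padding 0
case1 = conv_output_size(5, 3, stride=1, padding=0)
# Test case 2: 5x5 input, 3x3 filter, stride 2, padding 1
case2 = conv_output_size(5, 3, stride=2, padding=1)
print(case1, case2)  # 3 3
```

Both settings happen to produce a 3 x 3 output, so each result tensor has shape `(4, 2, 3, 3)`.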


## PyTorch C++ Extension

## Reordered Matrix Multiplication
Matrix multiplication is one of the most critical operations in neural networks. While PyTorch provides a highly optimized `matmul` operator, re-implementing it from scratch can provide deeper insight into performance tuning. In this section, we demonstrate how to implement a naive (ijk-ordered) matrix multiplication kernel on the CPU backend and integrate it with PyTorch.

The code below demonstrates how to build PyTorch C++ extensions using just-in-time (JIT) compilation. The JIT compilation mechanism provides you with a way of compiling and loading your extensions on the fly using PyTorch's API `torch.utils.cpp_extension.load()` or `torch.utils.cpp_extension.load_inline()`.

* `torch.utils.cpp_extension.load()` requires writing the C++ source code to a file and loading it from the filesystem.
* `torch.utils.cpp_extension.load_inline()` functions similarly but takes the source code as a string rather than a file, which is the approach we will use.

In [6]:
cpp_source = """
torch::Tensor my_ijk_matmul(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors.");
    TORCH_CHECK(a.dtype() == torch::kFloat32 && b.dtype() == torch::kFloat32, "Inputs must be of type float32.");
    TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match.");

    int m = a.size(0);
    int n = a.size(1);
    int p = b.size(1);

    // Extract raw pointers to input tensors
    const float* a_ptr = a.data_ptr<float>();
    const float* b_ptr = b.data_ptr<float>();

    // Define output tensor using torch::zeros (initialize with 0)
    torch::Tensor output = torch::zeros({m, p}, a.options());
    float* output_ptr = output.data_ptr<float>();

    // Perform ijk-ordered matrix multiplication (output-stationary)
    for (int i = 0; i < m; ++i) {
        for (int j = 0; j < p; ++j) {
            float sum = 0.0;  // Accumulator for C(i, j)
            for (int k = 0; k < n; ++k) {
                sum += a_ptr[i * n + k] * b_ptr[k * p + j];
            }
            output_ptr[i * p + j] = sum;
        }
    }

    return output;
}
"""

# Load the extension
my_kernel_lib = load_inline(
    name="cpp_extension",
    cpp_sources=cpp_source,
    functions="my_ijk_matmul",
    extra_cflags=['-O2'],
)

The kernel library has been loaded. Before testing it, we define some helper functions first.

In [7]:
def check_correctness(func):
    """
    Check correctness of the custom kernel.
    """
    # Define input tensors
    torch.manual_seed(508)
    A = torch.randn(32, 32, dtype=torch.float32)
    B = torch.randn(32, 32, dtype=torch.float32)

    # Perform matrix multiplication using the custom kernel
    my_results = func(A, B)

    # Perform matrix multiplication using PyTorch
    pt_results = torch.matmul(A, B)

    # Compare results
    if torch.allclose(my_results, pt_results, atol=1e-6):
        print("Results match!")
    else:
        print("Results do not match!")
        print(f"my_results: {my_results}")
        print(f"pt_results: {pt_results}")


def benchmark_matmul(func, num_runs=3, warmup_runs=3, size=1024):
    """
    Benchmark a matrix multiplication function and compare it with PyTorch's matmul.

    Parameters:
    - func: The custom function to benchmark.
    - num_runs: Number of timed executions for measurement.
    - warmup_runs: Number of warm-up executions.
    - size: Matrix dimension (size x size).

    Returns:
    - None (prints benchmark results)
    """

    # Define input tensors
    torch.manual_seed(508)
    A = torch.randn(size, size, dtype=torch.float32)
    B = torch.randn(size, size, dtype=torch.float32)

    def measure_flops(kernel_func):
        """Helper function to measure FLOPs per second for a given function."""

        # Warm-up phase
        for _ in range(warmup_runs):
            kernel_func(A, B)

        # Measure execution time over multiple runs (perf_counter is a monotonic, high-resolution clock)
        start_time = time.perf_counter()
        for _ in range(num_runs):
            kernel_func(A, B)
        end_time = time.perf_counter()

        # Compute average time per run
        avg_time = (end_time - start_time) / num_runs

        # Estimate FLOPs: 2 * (m * n * p) for standard matrix multiplication
        flops = 2 * size * size * size
        flops_per_sec = flops / avg_time

        return flops_per_sec, avg_time

    # Benchmark the custom kernel
    print("Benchmarking custom kernel...")
    flops_per_sec, avg_time = measure_flops(func)
    print(f"My kernel GFLOPs per second: {(flops_per_sec * 1e-9):.5f}, Average time: {avg_time:.5f} sec")

    # Benchmark PyTorch's matmul
    print("\nBenchmarking PyTorch matmul...")
    flops_per_sec, avg_time = measure_flops(torch.matmul)
    print(f"PyTorch GFLOPs per second: {(flops_per_sec * 1e-9):.5f}, Average time: {avg_time:.5f} sec")

Let's compare the results returned by our custom kernel implementation with the reference results returned by PyTorch.

In [8]:
check_correctness(my_kernel_lib.my_ijk_matmul)

Results match!


Benchmark the performance. This will take less than one minute to finish, and you will find that this implementation is far slower than PyTorch's built-in kernel (roughly 200 times slower in the run shown below).

In [9]:
benchmark_matmul(my_kernel_lib.my_ijk_matmul)

Benchmarking custom kernel...
My kernel GFLOPs per second: 0.42848, Average time: 5.01189 sec

Benchmarking PyTorch matmul...
PyTorch GFLOPs per second: 88.71087, Average time: 0.02421 sec
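
The throughput figures above can be reproduced by hand from the printed average times. This sketch uses the same `2 * size**3` FLOP estimate as `benchmark_matmul`. The helper name `gflops` is hypothetical, and because the printed times are rounded, the recomputed PyTorch figure differs slightly from the printed one.

```python
def gflops(size, seconds):
    """GFLOP/s for a size x size matmul, counting one multiply and one add per term."""
    return 2 * size**3 / seconds * 1e-9


custom = gflops(1024, 5.01189)   # average time printed for the ijk kernel
pytorch = gflops(1024, 0.02421)  # average time printed for torch.matmul
slowdown = 5.01189 / 0.02421

print(f"{custom:.5f} GFLOP/s, {pytorch:.2f} GFLOP/s, {slowdown:.0f}x slower")
```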


### **TODO 2:**
Implement a **jki**-ordered matmul kernel using the template below.

In [12]:
cpp_source = """
torch::Tensor my_jki_matmul(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors.");
    TORCH_CHECK(a.dtype() == torch::kFloat32 && b.dtype() == torch::kFloat32, "Inputs must be of type float32.");
    TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match.");

    int m = a.size(0);
    int n = a.size(1);
    int p = b.size(1);

    // Extract raw pointers to input tensors
    const float* a_ptr = a.data_ptr<float>();
    const float* b_ptr = b.data_ptr<float>();

    // Define output tensor using torch::zeros (initialize with 0)
    torch::Tensor output = torch::zeros({m, p}, a.options());
    float* output_ptr = output.data_ptr<float>();

    // Perform jki-ordered matrix multiplication (B-stationary)
    for (int j = 0; j < p; ++j) {
        for (int k = 0; k < n; ++k) {
            float b_kj = b_ptr[k * p + j];  // B(k, j) stays fixed across the inner loop
            for (int i = 0; i < m; ++i) {
                output_ptr[i * p + j] += a_ptr[i * n + k] * b_kj;
            }
        }
    }

    return output;
}
"""


# Load the extension
my_kernel_lib = load_inline(
    name="cpp_extension",
    cpp_sources=cpp_source,
    functions="my_jki_matmul",
    extra_cflags=['-O2'],
)

We check correctness first and then benchmark the kernel, as in the previous example.

In [13]:
check_correctness(my_kernel_lib.my_jki_matmul)
benchmark_matmul(my_kernel_lib.my_jki_matmul)

Results match!
Benchmarking custom kernel...
My kernel GFLOPs per second: 0.13617, Average time: 15.77076 sec

Benchmarking PyTorch matmul...
PyTorch GFLOPs per second: 88.11324, Average time: 0.02437 sec
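
One likely reason the jki order is slower than ijk here is memory access stride. The sketch below assumes row-major storage, as in the kernels above, and computes how far each array's address moves when the innermost loop index advances by one. The helper name `inner_loop_strides` is hypothetical.

```python
def inner_loop_strides(order, n, p):
    """Element stride of A, B and C accesses when the innermost index advances by one.

    A is m x n, B is n x p, C is m x p, all row-major, so
    A(i, k) -> i*n + k, B(k, j) -> k*p + j, C(i, j) -> i*p + j.
    """
    inner = order[-1]
    step = {"i": 0, "j": 0, "k": 0}
    step[inner] = 1
    return {
        "A": step["i"] * n + step["k"],
        "B": step["k"] * p + step["j"],
        "C": step["i"] * p + step["j"],
    }


strides = {order: inner_loop_strides(order, n=1024, p=1024)
           for order in ("ijk", "jki", "ikj")}
for order, s in strides.items():
    print(order, s)
```

A stride of 0 means the value can stay in a register, a stride of 1 walks contiguous memory, and a stride of 1024 touches a new cache line on every iteration. The jki order has two stride-1024 accesses in its inner loop, ijk has one, and ikj has none.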


### **TODO 3:**
Implement an **ikj**-ordered matmul kernel using the template below.

In [14]:
cpp_source = """
torch::Tensor my_ikj_matmul(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors.");
    TORCH_CHECK(a.dtype() == torch::kFloat32 && b.dtype() == torch::kFloat32, "Inputs must be of type float32.");
    TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match.");

    int m = a.size(0);
    int n = a.size(1);
    int p = b.size(1);

    // Extract raw pointers to input tensors
    const float* a_ptr = a.data_ptr<float>();
    const float* b_ptr = b.data_ptr<float>();

    // Define output tensor using torch::zeros (initialize with 0)
    torch::Tensor output = torch::zeros({m, p}, a.options());
    float* output_ptr = output.data_ptr<float>();

    // Perform ikj-ordered matrix multiplication (A-stationary)
    for (int i = 0; i < m; i++) {
        for (int k = 0; k < n; k++) {
            float a_ik = a_ptr[i * n + k];
            for (int j = 0; j < p; j++) {
                output_ptr[i * p + j] += a_ik * b_ptr[k * p + j];
            }
        }
    }

    return output;
}
"""

# Load the extension
my_kernel_lib = load_inline(
    name="cpp_extension",
    cpp_sources=cpp_source,
    functions="my_ikj_matmul",
    extra_cflags=['-O2'],
)

In [None]:
check_correctness(my_kernel_lib.my_ikj_matmul)
benchmark_matmul(my_kernel_lib.my_ikj_matmul)

### **TODO 4:**
Blocked matrix multiplication, also known as tiled matrix multiplication, can improve the temporal locality of inner loops. The general idea of blocking is to organize the data structures in a program into large chunks called blocks. (In this context, “block” refers to an application-level chunk of data, not to a cache block.) The program is structured so that it loads a chunk into the L1 cache, does all the reads and writes that it needs to on that chunk, then discards the chunk, loads in the next chunk, and so on.
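
The loop structure of blocking can be sketched in pure Python before writing it in C++. This is only an illustration of the tiling pattern, not a performance demonstration: the function names are hypothetical and the block size of 2 is chosen so the 3 x 3 example exercises partial blocks at the edges.

```python
def matmul_naive(a, b):
    """Reference ijk matmul on nested lists."""
    m, n, p = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(p)]
            for i in range(m)]


def matmul_blocked(a, b, block=2):
    """Blocked ikj matmul: the outer three loops pick a tile, the inner three fill it."""
    m, n, p = len(a), len(b), len(b[0])
    c = [[0] * p for _ in range(m)]
    for i0 in range(0, m, block):
        for j0 in range(0, p, block):
            for k0 in range(0, n, block):
                # min(...) clips the tile when a dimension is not a multiple of block.
                for i in range(i0, min(i0 + block, m)):
                    for k in range(k0, min(k0 + block, n)):
                        a_ik = a[i][k]
                        for j in range(j0, min(j0 + block, p)):
                            c[i][j] += a_ik * b[k][j]
    return c


a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
b = [[9, 8, 7], [6, 5, 4], [3, 2, 1]]
blocked = matmul_blocked(a, b, block=2)
print(blocked)  # [[30, 24, 18], [84, 69, 54], [138, 114, 90]]
```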

Implement a **blocked ikj**-ordered matmul kernel with a block size of 16, using the template below.

In [15]:
cpp_source = """
#include <algorithm>  // std::min

#define BLOCK_SIZE 16

torch::Tensor my_blocked_ikj_matmul(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors.");
    TORCH_CHECK(a.dtype() == torch::kFloat32 && b.dtype() == torch::kFloat32, "Inputs must be of type float32.");
    TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match.");

    int m = a.size(0);
    int n = a.size(1);
    int p = b.size(1);

    // Extract raw pointers to input tensors
    const float* a_ptr = a.data_ptr<float>();
    const float* b_ptr = b.data_ptr<float>();

    // Define output tensor using torch::zeros (initialize with 0)
    torch::Tensor output = torch::zeros({m, p}, a.options());
    float* output_ptr = output.data_ptr<float>();

    // Blocked Matrix Multiplication
    for (int i = 0; i < m; i += BLOCK_SIZE) {
        for (int j = 0; j < p; j += BLOCK_SIZE) {
            for (int k = 0; k < n; k += BLOCK_SIZE){
                for (int ii = i; ii < std::min(i + BLOCK_SIZE, m); ii++) {
                    for (int kk = k; kk < std::min(k + BLOCK_SIZE, n); kk++) {
                        float a_ik = a_ptr[ii * n + kk];
                        for (int jj = j; jj < std::min(j + BLOCK_SIZE, p); jj++) {
                            output_ptr[ii * p + jj] += a_ik * b_ptr[kk * p + jj];
                        }
                    }
                }
            }
        }
    }

    return output;
}
"""

# Load the extension
my_kernel_lib = load_inline(
    name="cpp_extension",
    cpp_sources=cpp_source,
    functions="my_blocked_ikj_matmul",
    extra_cflags=['-O2'],
)

In [16]:
check_correctness(my_kernel_lib.my_blocked_ikj_matmul)
benchmark_matmul(my_kernel_lib.my_blocked_ikj_matmul)

Results match!
Benchmarking custom kernel...
My kernel GFLOPs per second: 4.14006, Average time: 0.51871 sec

Benchmarking PyTorch matmul...
PyTorch GFLOPs per second: 89.60158, Average time: 0.02397 sec


## Data Parallelism with SIMD

Intel AVX (Advanced Vector Extensions) instructions are Single Instruction Multiple Data (SIMD) instructions that operate on 256-bit registers, so a single instruction can process 8 single-precision or 4 double-precision floating-point operands.
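
A common way to apply an 8-wide vector unit to a row of arbitrary length is to process full 8-element chunks first and finish with a scalar loop over the remainder. The sketch below simulates that pattern in pure Python. The function name `axpy_row` is hypothetical, and the inner `for lane` loop stands in for what a single vector multiply-add would do in hardware.

```python
VECTOR_WIDTH = 8  # number of FP32 lanes in one 256-bit AVX register


def axpy_row(a_ik, b_row, c_row):
    """Update c_row += a_ik * b_row in 8-wide chunks plus a scalar tail."""
    p = len(b_row)
    j = 0
    vector_steps = 0
    # Main loop: each iteration stands in for one 8-lane multiply-add.
    while j <= p - VECTOR_WIDTH:
        for lane in range(VECTOR_WIDTH):
            c_row[j + lane] += a_ik * b_row[j + lane]
        j += VECTOR_WIDTH
        vector_steps += 1
    tail_start = j
    # Tail loop: leftover elements when p is not a multiple of 8.
    while j < p:
        c_row[j] += a_ik * b_row[j]
        j += 1
    return vector_steps, tail_start


b_row = list(range(19))
c_row = [0] * 19
vector_steps, tail_start = axpy_row(2, b_row, c_row)
print(vector_steps, tail_start)  # 2 16
```

With 19 columns, two vector steps cover columns 0 to 15 and the scalar tail handles columns 16 to 18.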

### **TODO 5:**

Use Intel AVX SIMD intrinsics (8 x FP32 per instruction) to implement a **non-blocked ikj**-ordered matmul kernel using the template below.

In [17]:
cpp_source = """
#include <immintrin.h>  // AVX intrinsics

torch::Tensor my_avx_ikj_matmul(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors.");
    TORCH_CHECK(a.dtype() == torch::kFloat32 && b.dtype() == torch::kFloat32, "Inputs must be of type float32.");
    TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match.");

    int m = a.size(0);
    int n = a.size(1);
    int p = b.size(1);

    // Extract raw pointers to input tensors
    const float* a_ptr = a.data_ptr<float>();
    const float* b_ptr = b.data_ptr<float>();

    // Define output tensor using torch::zeros (initialize with 0)
    torch::Tensor output = torch::zeros({m, p}, a.options());
    float* output_ptr = output.data_ptr<float>();

    // Perform ikj-ordered matrix multiplication with AVX
    for (int i = 0; i < m; i++) {
        for (int k = 0; k < n; k++) {
            float a_ik = a_ptr[i * n + k];
            __m256 a_vec = _mm256_set1_ps(a_ik);

            int j = 0;
            for (; j <= p - 8; j += 8) {
                __m256 b_vec = _mm256_loadu_ps(&b_ptr[k * p + j]);
                __m256 c_vec = _mm256_loadu_ps(&output_ptr[i * p + j]);
                c_vec = _mm256_fmadd_ps(a_vec, b_vec, c_vec);
                _mm256_storeu_ps(&output_ptr[i * p + j], c_vec);
            }

            for (; j < p; j++) {
                output_ptr[i * p + j] += a_ik * b_ptr[k * p + j];
            }
        }
    }

    return output;
}
"""

# Load the extension
my_kernel_lib = load_inline(
    name="cpp_extension",
    cpp_sources=cpp_source,
    functions="my_avx_ikj_matmul",
    extra_cflags=["-mavx", "-mfma", "-O2"],
)

In [18]:
check_correctness(my_kernel_lib.my_avx_ikj_matmul)
benchmark_matmul(my_kernel_lib.my_avx_ikj_matmul)

Results match!
Benchmarking custom kernel...
My kernel GFLOPs per second: 20.96646, Average time: 0.10242 sec

Benchmarking PyTorch matmul...
PyTorch GFLOPs per second: 92.47290, Average time: 0.02322 sec


## Multi-threading
The `<thread>` library is a part of the C++11 standard library that provides classes and functions to manage threads. Using this library, you can create, manage, and synchronize threads directly from C++ code, which enables concurrent execution paths within your applications.

### **TODO 6:**
To further enhance the performance of the previous AVX implementation, employ multithreading by allocating two threads to execute the matrix multiplication concurrently. Each thread is assigned to process a distinct portion of matrix A: one thread handles the upper half, and the other handles the lower half.
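
The row-partitioning scheme can be sketched with Python's `threading` module. This only illustrates how the work is split: the function names are hypothetical, and CPython's global interpreter lock means pure-Python threads like these do not run the arithmetic in parallel, unlike `std::thread` in C++.

```python
import threading


def row_worker(a, b, out, row_start, row_end):
    """Compute rows [row_start, row_end) of out = a @ b in ikj order."""
    n, p = len(b), len(b[0])
    for i in range(row_start, row_end):
        for k in range(n):
            a_ik = a[i][k]
            for j in range(p):
                out[i][j] += a_ik * b[k][j]


def matmul_two_threads(a, b):
    m, p = len(a), len(b[0])
    out = [[0] * p for _ in range(m)]
    mid = m // 2
    # The two row ranges are disjoint, so the threads never write the same element.
    t1 = threading.Thread(target=row_worker, args=(a, b, out, 0, mid))
    t2 = threading.Thread(target=row_worker, args=(a, b, out, mid, m))
    t1.start()
    t2.start()
    t1.join()
    t2.join()
    return out


a = [[1, 2], [3, 4], [5, 6]]
b = [[1, 0, 2], [0, 1, 3]]
result = matmul_two_threads(a, b)
print(result)  # [[1, 2, 8], [3, 4, 18], [5, 6, 28]]
```

When the row count is odd, as here, the second thread receives the extra row, which matches the `mid = m / 2` split in the C++ template.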

In [None]:
cpp_source = """
#include <immintrin.h>  // AVX intrinsics
#include <torch/extension.h>
#include <vector>
#include <thread>

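void thread_worker(const float* a_ptr, const float* b_ptr, float* output_ptr, int m_start, int m_end, int n, int p) {
    // Each thread computes rows [m_start, m_end) of the output with the AVX ikj kernel.
    // The row ranges are disjoint, so no synchronization is needed on output_ptr.
    for (int i = m_start; i < m_end; i++) {
        for (int k = 0; k < n; k++) {
            float a_ik = a_ptr[i * n + k];
            __m256 a_vec = _mm256_set1_ps(a_ik);

            int j = 0;
            for (; j <= p - 8; j += 8) {
                __m256 b_vec = _mm256_loadu_ps(&b_ptr[k * p + j]);
                __m256 c_vec = _mm256_loadu_ps(&output_ptr[i * p + j]);
                c_vec = _mm256_fmadd_ps(a_vec, b_vec, c_vec);
                _mm256_storeu_ps(&output_ptr[i * p + j], c_vec);
            }

            // Scalar tail for columns left over when p is not a multiple of 8
            for (; j < p; j++) {
                output_ptr[i * p + j] += a_ik * b_ptr[k * p + j];
            }
        }
    }
}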

torch::Tensor my_mt_avx_ikj_matmul(torch::Tensor a, torch::Tensor b) {
    TORCH_CHECK(a.dim() == 2 && b.dim() == 2, "Inputs must be 2D tensors.");
    TORCH_CHECK(a.dtype() == torch::kFloat32 && b.dtype() == torch::kFloat32, "Inputs must be of type float32.");
    TORCH_CHECK(a.size(1) == b.size(0), "Inner dimensions must match.");

    int m = a.size(0);
    int n = a.size(1);
    int p = b.size(1);

    // Extract raw pointers to input tensors
    const float* a_ptr = a.data_ptr<float>();
    const float* b_ptr = b.data_ptr<float>();

    // Define output tensor using torch::zeros (initialize with 0)
    torch::Tensor output = torch::zeros({m, p}, a.options());
    float* output_ptr = output.data_ptr<float>();

    // Create two threads to parallelize over rows of A
    int mid = m / 2;

    std::thread t1(thread_worker, a_ptr, b_ptr, output_ptr, 0, mid, n, p);
    std::thread t2(thread_worker, a_ptr, b_ptr, output_ptr, mid, m, n, p);

    // Join threads
    t1.join();
    t2.join();

    return output;
}
"""

# Load the extension
my_kernel_lib = load_inline(
    name="cpp_extension",
    cpp_sources=cpp_source,
    functions="my_mt_avx_ikj_matmul",
    extra_cflags=["-mavx", "-mfma", "-O2"],
)

In [None]:
check_correctness(my_kernel_lib.my_mt_avx_ikj_matmul)
benchmark_matmul(my_kernel_lib.my_mt_avx_ikj_matmul)