In [2]:
import torch
import time
import os

In [3]:
# naive batched matrix multiplication using explicit Python loops
def func_matmul(A, B):
    """Multiply two batches of matrices element by element.

    This is a deliberately unoptimized implementation: every output entry
    is accumulated with a Python-level loop over the shared inner dimension.

    Args:
        A: 3-D tensor of shape (batch, rows, inner).
        B: 3-D tensor of shape (batch, inner, cols).

    Returns:
        Tensor of shape (batch, rows, cols) on ``A.device`` whose dtype is
        ``torch.result_type(A, B)`` and where ``C[b] = A[b] @ B[b]``.

    Raises:
        ValueError: if either input is not 3-D, or if the batch sizes or
            the inner dimensions of ``A`` and ``B`` do not match.
    """
    if A.dim() != 3 or B.dim() != 3:
        raise ValueError(
            f"expected 3-D tensors, got shapes {tuple(A.shape)} and {tuple(B.shape)}"
        )

    batch_size, rows, inner = A.shape
    batch_b, inner_b, cols = B.shape
    if batch_b != batch_size or inner_b != inner:
        raise ValueError(
            f"incompatible shapes for batched matmul: {tuple(A.shape)} and {tuple(B.shape)}"
        )

    # zeros already provide the starting value of every accumulator; use the
    # promoted input dtype so e.g. float64 inputs are not truncated to float32
    C = torch.zeros(
        batch_size, rows, cols, dtype=torch.result_type(A, B), device=A.device
    )

    for b in range(batch_size):
        for i in range(rows):
            for j in range(cols):
                # reduce over the shared inner dimension (A's columns / B's
                # rows), which differs from `rows` for non-square inputs
                for k in range(inner):
                    C[b, i, j] += A[b, i, k] * B[b, k, j]

    return C

In [4]:
# time the naive matmul with PyTorch limited to a single CPU thread
def single_threaded_matmul(batch_size, size, device):
    """Time ``func_matmul`` on random square matrices using one CPU thread.

    Args:
        batch_size: number of matrix pairs in the batch.
        size: side length of each square matrix.
        device: torch device on which the random inputs are created.

    Returns:
        Elapsed time of the ``func_matmul`` call in seconds (float).
    """
    # create random square matrices A and B (not included in the timing)
    A = torch.randn(batch_size, size, size, device=device)
    B = torch.randn(batch_size, size, size, device=device)

    # Pin the thread count to 1 explicitly: an earlier
    # torch.set_num_threads(n) call (e.g. from multi_threaded_matmul) would
    # otherwise still be in effect and this baseline would not be
    # single-threaded.
    previous_threads = torch.get_num_threads()
    torch.set_num_threads(1)
    try:
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start_time = time.perf_counter()
        func_matmul(A, B)
        return time.perf_counter() - start_time
    finally:
        # restore the caller's setting so this helper leaves no hidden state
        torch.set_num_threads(previous_threads)

In [5]:
# time the naive matmul with PyTorch set to a given number of CPU threads
def multi_threaded_matmul(batch_size, size, num_threads, device):
    """Time ``func_matmul`` on random square matrices with ``num_threads`` threads.

    Args:
        batch_size: number of matrix pairs in the batch.
        size: side length of each square matrix.
        num_threads: value passed to ``torch.set_num_threads`` for the
            duration of the measurement.
        device: torch device on which the random inputs are created.

    Returns:
        Elapsed time of the ``func_matmul`` call in seconds (float).
    """
    # NOTE(review): func_matmul does its work in Python loops over single
    # elements, so the thread setting may have little effect on the measured
    # time -- confirm by measurement before drawing conclusions.
    A = torch.randn(batch_size, size, size, device=device)
    B = torch.randn(batch_size, size, size, device=device)

    previous_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start_time = time.perf_counter()
        func_matmul(A, B)
        return time.perf_counter() - start_time
    finally:
        # restore the previous setting so later measurements are not
        # silently run with this call's thread count
        torch.set_num_threads(previous_threads)

In [6]:
# main function to run experiments with varying batch sizes
def run_experiment(batch_sizes=(1, 10, 50, 100, 500), matrix_size=64, num_threads_list=(2, 4)):
    """Print single- vs multi-threaded timings for each batch size.

    For every batch size, one timing is taken with
    ``single_threaded_matmul`` and one with ``multi_threaded_matmul`` per
    entry of ``num_threads_list``. All runs use the CPU device.

    The timed helpers call ``func_matmul``, which loops in Python, so the
    larger default batch sizes can take a long time; pass smaller values for
    a quick run.

    Args:
        batch_sizes: iterable of batch sizes to test.
        matrix_size: side length of the square matrices.
        num_threads_list: iterable of thread counts to test.

    Returns:
        None. Results are printed.
    """
    # force computation to run on CPU
    device = torch.device('cpu')

    for batch_size in batch_sizes:
        print(f"\nBatch Size: {batch_size}")

        # single-threaded baseline
        single_time = single_threaded_matmul(batch_size, matrix_size, device)
        print(f"Single-threaded time: {single_time:.4f} seconds")

        # one measurement per requested thread count
        for num_threads in num_threads_list:
            parallel_time = multi_threaded_matmul(batch_size, matrix_size, num_threads, device)
            print(f"Multi-threaded ({num_threads} threads) time: {parallel_time:.4f} seconds")

        print("-" * 50)

In [None]:
# entry point: run the benchmark when this code is executed as the main module
if __name__ == "__main__":
    run_experiment()


Batch Size: 1
