In [1]:
import os

import numpy as np
import torch
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.cpp_extension import load_inline

In [10]:
def trace_handler(prof):
    """Print the profiler's averaged statistics and export a Chrome trace.

    The table is sorted by ``self_cuda_time_total`` with no row limit, and the
    trace is written to ``tmp/test_trace_<step_num>.json``.

    Args:
        prof: Profiler object; it must provide ``key_averages()``,
            ``export_chrome_trace(path)`` and a ``step_num`` attribute.
    """
    print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
    # Make sure the output directory exists so the export does not fail on a
    # fresh checkout / fresh kernel.
    os.makedirs("tmp", exist_ok=True)
    prof.export_chrome_trace("tmp/test_trace_" + str(prof.step_num) + ".json")

def profile_func(func, *tensors, trace_handler=trace_handler):
    """Call ``func(*tensors)`` ten times under the PyTorch profiler.

    CPU and CUDA activity are recorded using a schedule of wait=1, warmup=1,
    active=2, repeat=1: the first step is skipped, the second is a warm-up,
    and the third and fourth steps are recorded, after which the trace is
    available and ``trace_handler`` is invoked via ``on_trace_ready``.

    Args:
        func: Callable to profile.
        *tensors: Positional arguments forwarded to ``func`` on every step.
        trace_handler: Callback handed to the profiler as ``on_trace_ready``.
    """
    recorded_activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    step_schedule = torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1)

    # For tensorboard output, pass
    # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log') instead.
    profiler = torch.profiler.profile(
        activities=recorded_activities,
        schedule=step_schedule,
        on_trace_ready=trace_handler,
    )
    with profiler as p:
        for _ in range(10):
            func(*tensors)
            # Tell the profiler that the next iteration has started.
            p.step()

In [3]:
# CUDA/C++ source for the element-wise square extension. The kernel maps one
# thread to one matrix element; the host wrapper validates its input because the
# kernel indexes a raw float pointer in row-major order.
cuda_source = '''
#include <limits>

__global__ void square_matrix_kernel(const float* matrix, float* result, int width, int height) {
    // Global 2-D coordinates of the element handled by this thread.
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    // The grid is rounded up to whole blocks, so edge threads may fall outside.
    if (row < height && col < width) {
        int idx = row * width + col;
        result[idx] = matrix[idx] * matrix[idx];
    }
}

torch::Tensor square_matrix(torch::Tensor matrix) {
    TORCH_CHECK(matrix.is_cuda(), "square_matrix: expected a CUDA tensor");
    TORCH_CHECK(matrix.scalar_type() == at::ScalarType::Float,
                "square_matrix: expected a float32 tensor");
    TORCH_CHECK(matrix.dim() == 2, "square_matrix: expected a 2-D tensor");
    TORCH_CHECK(matrix.numel() <= std::numeric_limits<int>::max(),
                "square_matrix: tensor is too large for int indexing");

    // The kernel assumes densely packed row-major data.
    matrix = matrix.contiguous();

    const auto height = matrix.size(0);
    const auto width = matrix.size(1);

    auto result = torch::empty_like(matrix);
    if (matrix.numel() == 0) {
        // A zero-sized grid is an invalid launch configuration; nothing to do.
        return result;
    }

    dim3 threads_per_block(16, 16);
    dim3 number_of_blocks((width + threads_per_block.x - 1) / threads_per_block.x,
                          (height + threads_per_block.y - 1) / threads_per_block.y);

    square_matrix_kernel<<<number_of_blocks, threads_per_block>>>(
        matrix.data_ptr<float>(), result.data_ptr<float>(),
        static_cast<int>(width), static_cast<int>(height));

    // Launch failures are otherwise silent and would return uninitialised memory.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "square_matrix: kernel launch failed: ", cudaGetErrorString(launch_status));

    return result;
}
'''

# C++ declaration of the function exported to Python.
cpp_source = "torch::Tensor square_matrix(torch::Tensor matrix);"

In [None]:
# Build in a directory of its own: load_inline writes fixed-name source files
# into build_directory, so extensions sharing one directory overwrite each
# other's sources and trigger rebuilds. A user-supplied build directory is not
# created automatically, so create it first.
square_matrix_build_dir = os.path.join('tmp', 'square_matrix_extension')
os.makedirs(square_matrix_build_dir, exist_ok=True)

square_matrix_extension = load_inline(
    name='square_matrix_extension',
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=['square_matrix'],
    with_cuda=True,
    extra_cuda_cflags=["-O2"],
    build_directory=square_matrix_build_dir,
    # extra_cuda_cflags=['--expt-relaxed-constexpr']
)


In [5]:
a = torch.tensor([[1., 2., 3.], [4., 5., 6.]], device='cuda')
squared = square_matrix_extension.square_matrix(a)
# Verify against PyTorch's own element-wise product instead of eyeballing the printout.
assert torch.allclose(squared, a * a)
print(squared)

tensor([[ 1.,  4.,  9.],
        [16., 25., 36.]], device='cuda:0')


### Matrix multiplication: one thread per output row

In [6]:
# CUDA/C++ source for a naive square matrix multiplication in which each thread
# computes one full row of the result.
cuda_source = """
#include <limits>

__global__ void matrix_mul_kernel(const float* matrix_a, const float* matrix_b, float* result,
                                  int width){

    // Global row index, so the rows can be spread over several blocks.
    int row = blockIdx.x * blockDim.x + threadIdx.x;

    // Check for thread out of bound access
    if (row < width){
        for (int col = 0; col < width; ++col) {
            float sum = 0;
            for (int k = 0; k < width; ++k){
                sum += matrix_a[row * width + k] * matrix_b[k * width + col];
            }
            result[row * width + col] = sum;
        }
    }
}

torch::Tensor matrix_mul(torch::Tensor matrix_a, torch::Tensor matrix_b){

    TORCH_CHECK(matrix_a.is_cuda() && matrix_b.is_cuda(), "matrix_mul: expected CUDA tensors");
    TORCH_CHECK(matrix_a.device() == matrix_b.device(),
                "matrix_mul: expected both tensors on the same device");
    TORCH_CHECK(matrix_a.scalar_type() == at::ScalarType::Float &&
                matrix_b.scalar_type() == at::ScalarType::Float,
                "matrix_mul: expected float32 tensors");
    // The kernel only handles square matrices of identical size.
    TORCH_CHECK(matrix_a.dim() == 2 && matrix_a.size(0) == matrix_a.size(1),
                "matrix_mul: matrix_a must be a square 2-D tensor");
    const auto width = matrix_a.size(0);
    TORCH_CHECK(matrix_b.dim() == 2 && matrix_b.size(0) == width && matrix_b.size(1) == width,
                "matrix_mul: matrix_b must have the same shape as matrix_a");
    TORCH_CHECK(matrix_a.numel() <= std::numeric_limits<int>::max(),
                "matrix_mul: matrix is too large for int indexing");

    // The kernel assumes densely packed row-major data.
    matrix_a = matrix_a.contiguous();
    matrix_b = matrix_b.contiguous();

    auto result = torch::empty_like(matrix_a);
    if (width == 0) {
        // A zero-sized grid is an invalid launch configuration; nothing to do.
        return result;
    }

    // A single block cannot hold an arbitrary number of threads, so use
    // fixed-size blocks and as many of them as needed to cover every row.
    const int threads_per_block = 256;
    const int number_of_blocks =
        static_cast<int>((width + threads_per_block - 1) / threads_per_block);

    matrix_mul_kernel<<<number_of_blocks, threads_per_block>>>(
        matrix_a.data_ptr<float>(), matrix_b.data_ptr<float>(), result.data_ptr<float>(),
        static_cast<int>(width));

    // Launch failures are otherwise silent and would return uninitialised memory.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "matrix_mul: kernel launch failed: ", cudaGetErrorString(launch_status));

    return result;
}
"""

# C++ declaration of the function exported to Python.
cpp_source = "torch::Tensor matrix_mul(torch::Tensor matrix_a, torch::Tensor matrix_b);"

In [7]:
# Build in a directory of its own: load_inline writes fixed-name source files
# into build_directory, so extensions sharing one directory overwrite each
# other's sources and trigger rebuilds. A user-supplied build directory is not
# created automatically, so create it first.
matrix_mul_build_dir = os.path.join('tmp', 'matrix_mul_extension')
os.makedirs(matrix_mul_build_dir, exist_ok=True)

matrix_mul_extension = load_inline(
    name='matrix_mul_extension',
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=['matrix_mul'],
    with_cuda=True,
    extra_cuda_cflags=["-O2"],
    build_directory=matrix_mul_build_dir,
    # extra_cuda_cflags=['--expt-relaxed-constexpr']
)


In [8]:
a = torch.tensor([[1., 2., 3], [4., 5., 6.], [7., 8., 9.]], device='cuda')
b = torch.tensor([[1., 2., 3], [4., 5., 6.], [7., 8., 9.]], device='cuda')
row_product = matrix_mul_extension.matrix_mul(a, b)
# Verify against PyTorch's matmul instead of eyeballing the printout.
assert torch.allclose(row_product, a @ b)
print(row_product)

tensor([[ 30.,  36.,  42.],
        [ 66.,  81.,  96.],
        [102., 126., 150.]], device='cuda:0')


In [12]:
# profile_func(matrix_mul_extension.matrix_mul, *[a,b])

In [31]:
# Reference result from PyTorch's built-in matmul, for comparison with the custom kernel.
torch.matmul(a, b)

tensor([[ 30.,  36.,  42.],
        [ 66.,  81.,  96.],
        [102., 126., 150.]], device='cuda:0')

### Matrix multiplication: one thread per output column

In [38]:
# CUDA/C++ source for a naive square matrix multiplication in which each thread
# computes one full column of the result.
cuda_source = """
#include <limits>

__global__ void matrix_mul_kernel(const float* matrix_a, const float* matrix_b, float* result,
                                  int width){

    // Global column index, so the columns can be spread over several blocks.
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    // Check for thread out of bound access
    if (col < width){
        for (int row = 0; row < width; ++row) {
            float sum = 0;
            for (int k = 0; k < width; ++k){
                sum += matrix_a[row * width + k] * matrix_b[k * width + col];
            }
            result[row * width + col] = sum;
        }
    }
}

torch::Tensor matrix_mul(torch::Tensor matrix_a, torch::Tensor matrix_b){

    TORCH_CHECK(matrix_a.is_cuda() && matrix_b.is_cuda(), "matrix_mul: expected CUDA tensors");
    TORCH_CHECK(matrix_a.device() == matrix_b.device(),
                "matrix_mul: expected both tensors on the same device");
    TORCH_CHECK(matrix_a.scalar_type() == at::ScalarType::Float &&
                matrix_b.scalar_type() == at::ScalarType::Float,
                "matrix_mul: expected float32 tensors");
    // The kernel only handles square matrices of identical size.
    TORCH_CHECK(matrix_a.dim() == 2 && matrix_a.size(0) == matrix_a.size(1),
                "matrix_mul: matrix_a must be a square 2-D tensor");
    const auto width = matrix_a.size(0);
    TORCH_CHECK(matrix_b.dim() == 2 && matrix_b.size(0) == width && matrix_b.size(1) == width,
                "matrix_mul: matrix_b must have the same shape as matrix_a");
    TORCH_CHECK(matrix_a.numel() <= std::numeric_limits<int>::max(),
                "matrix_mul: matrix is too large for int indexing");

    // The kernel assumes densely packed row-major data.
    matrix_a = matrix_a.contiguous();
    matrix_b = matrix_b.contiguous();

    auto result = torch::empty_like(matrix_a);
    if (width == 0) {
        // A zero-sized grid is an invalid launch configuration; nothing to do.
        return result;
    }

    // A single block cannot hold an arbitrary number of threads, so use
    // fixed-size blocks and as many of them as needed to cover every column.
    const int threads_per_block = 256;
    const int number_of_blocks =
        static_cast<int>((width + threads_per_block - 1) / threads_per_block);

    matrix_mul_kernel<<<number_of_blocks, threads_per_block>>>(
        matrix_a.data_ptr<float>(), matrix_b.data_ptr<float>(), result.data_ptr<float>(),
        static_cast<int>(width));

    // Launch failures are otherwise silent and would return uninitialised memory.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "matrix_mul: kernel launch failed: ", cudaGetErrorString(launch_status));

    return result;
}
"""

# C++ declaration of the function exported to Python.
cpp_source = "torch::Tensor matrix_mul(torch::Tensor matrix_a, torch::Tensor matrix_b);"

In [39]:
# Build in a directory of its own: load_inline writes fixed-name source files
# into build_directory, so extensions sharing one directory overwrite each
# other's sources and trigger rebuilds. A user-supplied build directory is not
# created automatically, so create it first.
matrix_mul_col_build_dir = os.path.join('tmp', 'matrix_mul_extension_col')
os.makedirs(matrix_mul_col_build_dir, exist_ok=True)

matrix_mul_extension_col = load_inline(
    name='matrix_mul_extension_col',
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=['matrix_mul'],
    with_cuda=True,
    extra_cuda_cflags=["-O2"],
    build_directory=matrix_mul_col_build_dir,
    # extra_cuda_cflags=['--expt-relaxed-constexpr']
)


In [40]:
a = torch.tensor([[1., 2., 3], [4., 5., 6.], [7., 8., 9.]], device='cuda')
b = torch.tensor([[1., 2., 3], [4., 5., 6.], [7., 8., 9.]], device='cuda')
col_product = matrix_mul_extension_col.matrix_mul(a, b)
# Verify against PyTorch's matmul instead of eyeballing the printout.
assert torch.allclose(col_product, a @ b)
print(col_product)

tensor([[ 30.,  36.,  42.],
        [ 66.,  81.,  96.],
        [102., 126., 150.]], device='cuda:0')


In [36]:
%%timeit
matrix_mul_extension_col.matrix_mul(a, b)

5.87 µs ± 37.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [None]:
# profile_func(matrix_mul_extension_col.matrix_mul, *[a,b])

In [37]:
# Reference result from PyTorch's built-in matmul, for comparison with the custom kernel.
torch.matmul(a, b)

tensor([[ 30.,  36.,  42.],
        [ 66.,  81.,  96.],
        [102., 126., 150.]], device='cuda:0')

### Matrix-vector multiplication: one thread per output element

In [54]:
# CUDA/C++ source for a naive square-matrix times vector product in which each
# thread computes one element of the output vector.
cuda_source = """
#include <limits>

__global__ void matrix_vec_mul_kernel(const float* matrix_b, const float* vec_c, float* vec_a, int width){

    // Global output index, so the rows can be spread over several blocks.
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Check for thread out of bound access
    if (idx < width){
        float sum = 0;
        for (int col = 0; col < width; ++col) {
            sum += matrix_b[idx * width + col] * vec_c[col];
        }
        vec_a[idx] = sum;
    }
}

torch::Tensor matrix_vec_mul(torch::Tensor matrix_b, torch::Tensor vec_c){

    TORCH_CHECK(matrix_b.is_cuda() && vec_c.is_cuda(), "matrix_vec_mul: expected CUDA tensors");
    TORCH_CHECK(matrix_b.device() == vec_c.device(),
                "matrix_vec_mul: expected both tensors on the same device");
    TORCH_CHECK(matrix_b.scalar_type() == at::ScalarType::Float &&
                vec_c.scalar_type() == at::ScalarType::Float,
                "matrix_vec_mul: expected float32 tensors");
    // The kernel only handles a square matrix and a vector with one entry per column.
    TORCH_CHECK(matrix_b.dim() == 2 && matrix_b.size(0) == matrix_b.size(1),
                "matrix_vec_mul: matrix_b must be a square 2-D tensor");
    const auto width = matrix_b.size(0);
    TORCH_CHECK(vec_c.numel() == width,
                "matrix_vec_mul: vec_c must have exactly one element per matrix column");
    TORCH_CHECK(matrix_b.numel() <= std::numeric_limits<int>::max(),
                "matrix_vec_mul: matrix is too large for int indexing");

    // The kernel assumes densely packed row-major data.
    matrix_b = matrix_b.contiguous();
    vec_c = vec_c.contiguous();

    // The output keeps the shape of the input vector.
    auto vec_a = torch::empty_like(vec_c);
    if (width == 0) {
        // A zero-sized grid is an invalid launch configuration; nothing to do.
        return vec_a;
    }

    // A single block cannot hold an arbitrary number of threads, so use
    // fixed-size blocks and as many of them as needed to cover every row.
    const int threads_per_block = 256;
    const int number_of_blocks =
        static_cast<int>((width + threads_per_block - 1) / threads_per_block);

    matrix_vec_mul_kernel<<<number_of_blocks, threads_per_block>>>(
        matrix_b.data_ptr<float>(), vec_c.data_ptr<float>(), vec_a.data_ptr<float>(),
        static_cast<int>(width));

    // Launch failures are otherwise silent and would return uninitialised memory.
    const cudaError_t launch_status = cudaGetLastError();
    TORCH_CHECK(launch_status == cudaSuccess,
                "matrix_vec_mul: kernel launch failed: ", cudaGetErrorString(launch_status));

    return vec_a;
}
"""

# C++ declaration of the function exported to Python.
cpp_source = "torch::Tensor matrix_vec_mul(torch::Tensor matrix_b, torch::Tensor vec_c);"

In [55]:
# Build in a directory of its own: load_inline writes fixed-name source files
# into build_directory, so extensions sharing one directory overwrite each
# other's sources and trigger rebuilds. A user-supplied build directory is not
# created automatically, so create it first.
matrix_vec_mul_build_dir = os.path.join('tmp', 'matrix_vec_mul_extension_col')
os.makedirs(matrix_vec_mul_build_dir, exist_ok=True)

matrix_vec_mul_extension_col = load_inline(
    name='matrix_vec_mul_extension_col',
    cpp_sources=cpp_source,
    cuda_sources=cuda_source,
    functions=['matrix_vec_mul'],
    with_cuda=True,
    extra_cuda_cflags=["-O2"],
    build_directory=matrix_vec_mul_build_dir,
    # extra_cuda_cflags=['--expt-relaxed-constexpr']
)


In [56]:
# Here `a` is the square matrix and `b` is the column vector.
a = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], device='cuda')
b = torch.tensor([[1.], [4.], [7]], device='cuda')
vec_product = matrix_vec_mul_extension_col.matrix_vec_mul(a, b)
# Verify against PyTorch's matmul instead of eyeballing the printout.
assert torch.allclose(vec_product, a @ b)
print(vec_product)

tensor([[ 30.],
        [ 66.],
        [102.]], device='cuda:0')


In [57]:
%%timeit
matrix_vec_mul_extension_col.matrix_vec_mul(a, b)

5.24 µs ± 40.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [53]:
# Reference result from PyTorch's built-in matmul, for comparison with the custom kernel.
torch.matmul(a, b)

tensor([[ 30.],
        [ 66.],
        [102.]], device='cuda:0')

In [None]:
# profile_func(matrix_vec_mul_extension_col.matrix_vec_mul, *[a,b])