In [1]:
import torch
import torch.utils.cpp_extension
from pathlib import Path

In [2]:
# Benchmark input: standard-normal float32 values on the GPU; softmax is taken over the
# last (length 32768) dimension in the cells below.
matrix = torch.randn((4, 8, 32, 32768), dtype=torch.float32, device='cuda')
# Tiny 1x3 tensor with hand-checkable values.
small = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32, device='cuda')

In [3]:
%%timeit
# Baseline: PyTorch's built-in softmax over the last dimension of `matrix`.
torch.nn.functional.softmax(matrix, dim=-1)
# Wait for outstanding GPU work so the timing covers kernel execution, not just the launch.
torch.cuda.synchronize()

128 μs ± 108 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [4]:
def naive_pytorch_softmax(x, dim=None):
    """Numerically stable softmax composed from elementwise PyTorch ops.

    Args:
        x: Input tensor.
        dim: Dimension to normalise over. ``None`` (the default) normalises
            over every element of ``x``, so the whole output sums to 1.

    Returns:
        A tensor with the same shape as ``x`` holding
        ``exp(x - max) / sum(exp(x - max))`` along ``dim``.
    """
    if dim is None:
        # torch.max(x, dim=None, keepdim=True) is rejected by PyTorch, so the
        # default used to raise. Reduce over all elements instead; the 0-d
        # max and sum broadcast against x.
        shifted = torch.exp(x - x.max())
        return shifted / shifted.sum()

    # Subtracting the per-slice maximum keeps exp() from overflowing.
    m = torch.max(x, dim=dim, keepdim=True).values
    e = torch.exp(x - m)
    s = torch.sum(e, dim=dim, keepdim=True)
    return e / s

In [5]:
# The composed-ops softmax must match PyTorch's reference within assert_close's default tolerances.
torch.testing.assert_close(
    actual=naive_pytorch_softmax(matrix, dim=-1),
    expected=torch.nn.functional.softmax(matrix, dim=-1),
)

In [6]:
%%timeit
# Time the softmax built from separate max / exp / sum / divide ops.
naive_pytorch_softmax(matrix, dim=-1)
# Wait for outstanding GPU work so the timing covers kernel execution, not just the launch.
torch.cuda.synchronize()

371 μs ± 321 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
# Shared C++/CUDA prelude that is prepended to each kernel source file before it is
# passed to load_inline: common includes, input-validation macros, and a ceil-division
# helper. cdiv is written as quotient-plus-remainder-check because the usual
# (a + b - 1) / b form wraps around in unsigned arithmetic when a is close to UINT_MAX
# and then yields a count that is too small.
cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

inline unsigned int cdiv(unsigned int a, unsigned int b) { return a / b + (a % b != 0 ? 1u : 0u);}
'''

In [9]:
# Kernel source is kept in softmax1.cu; the shared prelude in cuda_begin goes in front of it.
softmax1_cu = Path('softmax1.cu').read_text()
cuda_src1 = "\n".join((cuda_begin, softmax1_cu))
# C++ declaration of the entry point that is exposed to Python via `functions`.
cpp_src1 = "\ntorch::Tensor softmax1(const torch::Tensor& x);\n"
my1 = torch.utils.cpp_extension.load_inline(
    name="my1",
    cpp_sources=cpp_src1,
    cuda_sources=cuda_src1,
    functions=['softmax1'],
    extra_cuda_cflags=['--ptxas-options=-v'],  # verbose ptxas output in the build log
    verbose=True,
)

Using /home/belevich/.cache/torch_extensions/py312_cu126 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/belevich/.cache/torch_extensions/py312_cu126/my1/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module my1...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module my1...


In [10]:
# softmax1 must match PyTorch's reference softmax within assert_close's default tolerances.
torch.testing.assert_close(
    actual=my1.softmax1(matrix),
    expected=torch.nn.functional.softmax(matrix, dim=-1),
)

In [11]:
%%timeit
# Time the first custom kernel (softmax1) on the benchmark tensor.
my1.softmax1(matrix)
# Wait for outstanding GPU work so the timing covers kernel execution, not just the launch.
torch.cuda.synchronize()

83.5 ms ± 92.9 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
# Kernel source is kept in softmax2.cu; the shared prelude in cuda_begin goes in front of it.
softmax2_cu = Path('softmax2.cu').read_text()
cuda_src2 = "\n".join((cuda_begin, softmax2_cu))
# C++ declaration of the entry point that is exposed to Python via `functions`.
cpp_src2 = "\ntorch::Tensor softmax2(const torch::Tensor& x);\n"
my2 = torch.utils.cpp_extension.load_inline(
    name="my2",
    cpp_sources=cpp_src2,
    cuda_sources=cuda_src2,
    functions=['softmax2'],
    extra_cuda_cflags=['--ptxas-options=-v'],  # verbose ptxas output in the build log
    verbose=True,
)

Using /home/belevich/.cache/torch_extensions/py312_cu126 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/belevich/.cache/torch_extensions/py312_cu126/my2/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module my2...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module my2...


In [14]:
# softmax2 must match PyTorch's reference softmax within assert_close's default tolerances.
torch.testing.assert_close(
    actual=my2.softmax2(matrix),
    expected=torch.nn.functional.softmax(matrix, dim=-1),
)

In [15]:
%%timeit
# Time the second custom kernel (softmax2) on the benchmark tensor.
my2.softmax2(matrix)
# Wait for outstanding GPU work so the timing covers kernel execution, not just the launch.
torch.cuda.synchronize()

136 μs ± 67.8 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [17]:
# Kernel source is kept in softmax3.cu; the shared prelude in cuda_begin goes in front of it.
softmax3_cu = Path('softmax3.cu').read_text()
cuda_src3 = "\n".join((cuda_begin, softmax3_cu))
# C++ declaration of the entry point that is exposed to Python via `functions`.
cpp_src3 = "\ntorch::Tensor softmax3(const torch::Tensor& x);\n"
my3 = torch.utils.cpp_extension.load_inline(
    name="my3",
    cpp_sources=cpp_src3,
    cuda_sources=cuda_src3,
    functions=['softmax3'],
    extra_cuda_cflags=['--ptxas-options=-v'],  # verbose ptxas output in the build log
    verbose=True,
)

Using /home/belevich/.cache/torch_extensions/py312_cu126 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/belevich/.cache/torch_extensions/py312_cu126/my3/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module my3...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


ninja: no work to do.


Loading extension module my3...


In [18]:
# softmax3 must match PyTorch's reference softmax within assert_close's default tolerances.
torch.testing.assert_close(
    actual=my3.softmax3(matrix),
    expected=torch.nn.functional.softmax(matrix, dim=-1),
)

In [19]:
%%timeit
# Time the third custom kernel (softmax3) on the benchmark tensor.
my3.softmax3(matrix)
# Wait for outstanding GPU work so the timing covers kernel execution, not just the launch.
torch.cuda.synchronize()

138 μs ± 20.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [34]:
# Kernel source is kept in softmax2_shfl.cu; the shared prelude in cuda_begin goes in front of it.
softmax2_shfl_cu = Path('softmax2_shfl.cu').read_text()
cuda_src2_shfl = "\n".join((cuda_begin, softmax2_shfl_cu))
# C++ declaration of the entry point that is exposed to Python via `functions`.
cpp_src2_shfl = "\ntorch::Tensor softmax2_shfl(const torch::Tensor& x);\n"
my2_shfl = torch.utils.cpp_extension.load_inline(
    name="my2_shfl",
    cpp_sources=cpp_src2_shfl,
    cuda_sources=cuda_src2_shfl,
    functions=['softmax2_shfl'],
    extra_cuda_cflags=['--ptxas-options=-v'],  # verbose ptxas output in the build log
    verbose=True,
)

Using /home/belevich/.cache/torch_extensions/py312_cu126 as PyTorch extensions root...
The input conditions for extension module my2_shfl have changed. Bumping to version 6 and re-building as my2_shfl_v6...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/belevich/.cache/torch_extensions/py312_cu126/my2_shfl/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module my2_shfl_v6...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=my2_shfl_v6 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /fsxl/belevich/miniconda3/lib/python3.12/site-packages/torch/include -isystem /fsxl/belevich/miniconda3/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda/include -isystem /fsxl/belevich/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -c /home/belevich/.cache/torch_extensions/py312_cu126/my2_shfl/main.cpp -o main.o 
[2/3] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=my2_shfl_v6 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /fsxl/belevich/miniconda3/lib/python3.12/site-packages/torch/include -isystem /fsxl/belevich/miniconda3/lib/python3.12/site-pa

Loading extension module my2_shfl_v6...


In [35]:
# softmax2_shfl must match PyTorch's reference softmax within assert_close's default tolerances.
torch.testing.assert_close(
    actual=my2_shfl.softmax2_shfl(matrix),
    expected=torch.nn.functional.softmax(matrix, dim=-1),
)

In [36]:
%%timeit
# Time the softmax2_shfl kernel on the benchmark tensor.
my2_shfl.softmax2_shfl(matrix)
# Wait for outstanding GPU work so the timing covers kernel execution, not just the launch.
torch.cuda.synchronize()

130 μs ± 40.2 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [39]:
# Kernel source is kept in softmax3_shfl.cu; the shared prelude in cuda_begin goes in front of it.
softmax3_shfl_cu = Path('softmax3_shfl.cu').read_text()
cuda_src3_shfl = "\n".join((cuda_begin, softmax3_shfl_cu))
# C++ declaration of the entry point that is exposed to Python via `functions`.
cpp_src3_shfl = "\ntorch::Tensor softmax3_shfl(const torch::Tensor& x);\n"
my3_shfl = torch.utils.cpp_extension.load_inline(
    name="my3_shfl",
    cpp_sources=cpp_src3_shfl,
    cuda_sources=cuda_src3_shfl,
    functions=['softmax3_shfl'],
    extra_cuda_cflags=['--ptxas-options=-v'],  # verbose ptxas output in the build log
    verbose=True,
)

Using /home/belevich/.cache/torch_extensions/py312_cu126 as PyTorch extensions root...
The input conditions for extension module my3_shfl have changed. Bumping to version 2 and re-building as my3_shfl_v2...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/belevich/.cache/torch_extensions/py312_cu126/my3_shfl/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module my3_shfl_v2...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)


[1/3] c++ -MMD -MF main.o.d -DTORCH_EXTENSION_NAME=my3_shfl_v2 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /fsxl/belevich/miniconda3/lib/python3.12/site-packages/torch/include -isystem /fsxl/belevich/miniconda3/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -isystem /usr/local/cuda/include -isystem /fsxl/belevich/miniconda3/include/python3.12 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -c /home/belevich/.cache/torch_extensions/py312_cu126/my3_shfl/main.cpp -o main.o 
[2/3] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -DTORCH_EXTENSION_NAME=my3_shfl_v2 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1016\" -isystem /fsxl/belevich/miniconda3/lib/python3.12/site-packages/torch/include -isystem /fsxl/belevich/miniconda3/lib/python3.12/site-pa

Loading extension module my3_shfl_v2...


In [40]:
# softmax3_shfl must match PyTorch's reference softmax within assert_close's default tolerances.
torch.testing.assert_close(
    actual=my3_shfl.softmax3_shfl(matrix),
    expected=torch.nn.functional.softmax(matrix, dim=-1),
)

In [41]:
%%timeit
# Time the softmax3_shfl kernel on the benchmark tensor.
my3_shfl.softmax3_shfl(matrix)
# Wait for outstanding GPU work so the timing covers kernel execution, not just the launch.
torch.cuda.synchronize()

130 μs ± 42.7 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
