# Single-head scaled dot-product attention (SDPA) using the CUTLASS GEMM API

Simple sandbox example (single batch, single head, row-major) with significant potential for optimization <br>
-> e.g. tiling and kernel fusion (FlashAttention, ThunderKittens, ...).

Input matrices $Q, K, V \in \mathbb{R}^{N \times d_k}$, where $N$ is the sequence length and $d_k$ is the key (and query) vector dimension per head; in this example the value dimension $d_v$ equals $d_k$.

GEMM 1 : $S = Q K^\top \in \mathbb{R}^{N \times N}$ <br>
Scaling + Row-softmax : $P = \mathrm{softmax}\,\left(\, S / \sqrt{d_k} \, \right) \in \mathbb{R}^{N \times N} $ <br>
GEMM 2 : $O = P\,V \in \mathbb{R}^{N \times d_k}$



In [None]:
!nvcc --version
!nvidia-smi --query-gpu=compute_cap --format=csv,noheader | awk -F. '{printf "\nCompute capability : sm_%d%d\n",$1,$2}'

In [None]:
%pip install nvcc4jupyter

In [None]:
%load_ext nvcc4jupyter

In [None]:
!git clone https://github.com/NVIDIA/cutlass.git
%env CUTLASS_PATH=/content/cutlass
%env CUTLASS_INCLUDE=/content/cutlass/include

In [None]:
%%cuda_group_save --name "sdpa.cu" --group "sdpa"

#include <cmath>
#include <cfloat>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <random>
#include <limits>

#include <cuda_runtime.h>
#include <math_constants.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/layout/matrix.h>
#include <cutlass/numeric_types.h>
#include <cutlass/array.h>


// Implementation detail of CUDA_CHECK: if `result` is an error, prints it
// together with the supplied source location and terminates the program with
// EXIT_FAILURE; otherwise returns `result` unchanged.
inline cudaError_t cuda_check_impl(cudaError_t result, const char* file,
                                   int line) {
  if (result != cudaSuccess) {
    std::fprintf(stderr, "CUDA error at %s:%d -> %s\n", file, line,
                 cudaGetErrorString(result));
    std::exit(EXIT_FAILURE);
  }
  return result;
}

// Checks a CUDA runtime call, e.g. CUDA_CHECK(cudaMalloc(&p, bytes)).
// This is a macro rather than a function so that __FILE__/__LINE__ expand at
// the call site; inside a function body they would always name the function
// itself, not the failing call. Evaluates to the call's cudaError_t.
#define CUDA_CHECK(call) cuda_check_impl((call), __FILE__, __LINE__)

// Maps a storage element type to the type that kernel arithmetic is done in.
// By default a type computes in itself (this covers float and double); the
// 16-bit CUTLASS types are widened to float.
template <typename Storage>
struct ComputeType {
  using type = Storage;
};

template <>
struct ComputeType<cutlass::half_t> {
  using type = float;
};

template <>
struct ComputeType<cutlass::bfloat16_t> {
  using type = float;
};

// Device-side "minus infinity" used as the identity of a running max.
//
// Only the explicit specializations below are usable. The primary template
// fails to compile when instantiated, instead of silently falling off the end
// of a non-void function.
template <typename T> __device__ inline T neg_inf() {
  static_assert(sizeof(T) == 0,
                "neg_inf<T>: no specialization for this element type");
  return T();
}

// float / double: true negative infinity.
template <> __device__ inline float neg_inf<float>()
{ return -CUDART_INF_F; }

template <> __device__ inline double neg_inf<double>()
{ return -CUDART_INF; }

// 16-bit types: NOT infinity, but the most negative finite value
// (-65504 for half, about -3.39e38 for bfloat16).
template <> __device__ inline cutlass::half_t
  neg_inf<cutlass::half_t>() { return cutlass::half_t(-65504.0f); }

template <> __device__ inline cutlass::bfloat16_t
  neg_inf<cutlass::bfloat16_t>() {
      return cutlass::bfloat16_t(-3.38953139e38f); }




// Naive device transpose for row-major matrices:
// `out` (cols_in x rows_in) receives the transpose of `in` (rows_in x cols_in).
// Expects a 2-D launch: the x dimension indexes input columns and the y
// dimension indexes input rows. Each thread moves one element, and threads
// outside the matrix return immediately. Writes are strided by rows_in.
template <typename T>
__global__ void transpose_rowmajor(const T* __restrict__ in,
                                   T* __restrict__ out,
                                   int rows_in,
                                   int cols_in) {
  const int col = threadIdx.x + blockIdx.x * blockDim.x;
  const int row = threadIdx.y + blockIdx.y * blockDim.y;

  const bool outside = (row >= rows_in) || (col >= cols_in);
  if (outside) {
    return;
  }

  // Input element (row, col) lands at (col, row) of the output.
  const int src = row * cols_in + col;
  const int dst = col * rows_in + row;
  out[dst] = in[src];
}


// Naive device per-row softmax, applied in place to an N x N row-major
// matrix: scores[i, :] <- softmax(scores[i, :] * inv_sqrt_dk).
//
// Launch: 1-D grid of 1-D blocks. Each thread owns whole rows, visited with a
// grid-stride loop, so any grid size covers all N rows. Because one thread
// walks one row, accesses across a warp are strided by N (uncoalesced), which
// is acceptable only for small N.
//
// Scaled logits, the running max, exponentials and the sum are all kept in
// ComputeType<T>::type (float for half/bfloat16). The score buffer is read,
// never rewritten, until the final pass, which stores the probabilities as T.
// This way fp16 storage adds no rounding step before exponentiation.
template <typename T>
__global__ void row_softmax(T* __restrict__ scores,
                            int N,
                            float inv_sqrt_dk) {

  using ComputeT = typename ComputeType<T>::type;
  const ComputeT scale = static_cast<ComputeT>(inv_sqrt_dk);

  const int stride = gridDim.x * blockDim.x;
  for (int row = blockIdx.x * blockDim.x + threadIdx.x; row < N;
       row += stride) {
    T* row_ptr = scores + static_cast<size_t>(row) * N;

    // Pass 1: maximum of the scaled logits. softmax(x) == softmax(x - c) for
    // any constant c; choosing c = max(x) keeps every exp() argument <= 0.
    ComputeT max_val = neg_inf<ComputeT>();
    for (int j = 0; j < N; ++j) {
      const ComputeT x = static_cast<ComputeT>(row_ptr[j]) * scale;
      if (x > max_val) max_val = x;
    }

    // Pass 2: normalizer.
    ComputeT sum = static_cast<ComputeT>(0);
    for (int j = 0; j < N; ++j) {
      const ComputeT x = static_cast<ComputeT>(row_ptr[j]) * scale;
      sum += std::exp(x - max_val);
    }

    // Pass 3: probabilities. The scaled logit is recomputed from the still
    // unmodified score, so only this final value is rounded to T.
    for (int j = 0; j < N; ++j) {
      const ComputeT x = static_cast<ComputeT>(row_ptr[j]) * scale;
      row_ptr[j] = static_cast<T>(std::exp(x - max_val) / sum);
    }
  }

}


// Double-precision host reference for O = softmax(Q K^T * scale) V.
// It uses the same (already rounded) element values that are uploaded to the
// device. Q, K and V are row-major N x D; the result is row-major N x D.
template <typename T>
static std::vector<double> sdpa_reference(const std::vector<T>& Q,
                                          const std::vector<T>& K,
                                          const std::vector<T>& V,
                                          int N, int D, double scale) {
  // Going through float works for float as well as the CUTLASS 16-bit types.
  auto val = [](const std::vector<T>& m, std::size_t idx) {
    return static_cast<double>(static_cast<float>(m[idx]));
  };

  std::vector<double> out(static_cast<std::size_t>(N) * D, 0.0);
  std::vector<double> p(N);
  for (int i = 0; i < N; ++i) {
    // Scaled logits of row i and their maximum.
    double row_max = -std::numeric_limits<double>::infinity();
    for (int j = 0; j < N; ++j) {
      double s = 0.0;
      for (int k = 0; k < D; ++k) {
        s += val(Q, std::size_t(i) * D + k) * val(K, std::size_t(j) * D + k);
      }
      p[j] = s * scale;
      row_max = std::fmax(row_max, p[j]);
    }
    // Numerically stable softmax of the row.
    double sum = 0.0;
    for (int j = 0; j < N; ++j) {
      p[j] = std::exp(p[j] - row_max);
      sum += p[j];
    }
    // O[i, :] = sum_j P[i, j] * V[j, :]
    for (int j = 0; j < N; ++j) {
      const double w = p[j] / sum;
      for (int k = 0; k < D; ++k) {
        out[std::size_t(i) * D + k] += w * val(V, std::size_t(j) * D + k);
      }
    }
  }
  return out;
}

// Runs out = A * B (alpha = 1, beta = 0; C aliases the output) through a
// CUTLASS device GEMM with row-major operands. The problem is m x n x k, and
// lda/ldb/ldo are row strides in elements. On rejection or launch failure it
// prints a diagnostic and returns false.
template <typename Gemm, typename T>
static bool run_gemm(const char* label, int m, int n, int k,
                     const T* A, int lda,
                     const T* B, int ldb,
                     T* out, int ldo) {
  // alpha/beta must have the epilogue's compute type (float when the GEMM
  // accumulates in float), not the storage element type.
  using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute;

  typename Gemm::Arguments args(
      {m, n, k},                                // GEMM problem dimensions
      {A, lda},                                 // A
      {B, ldb},                                 // B
      {out, ldo},                               // C (ignored: beta = 0)
      {out, ldo},                               // D
      {ElementCompute(1), ElementCompute(0)});  // alpha, beta

  cutlass::Status status = Gemm::can_implement(args);
  if (status == cutlass::Status::kSuccess) {
    Gemm gemm_op;
    status = gemm_op(args);
  }
  if (status != cutlass::Status::kSuccess) {
    std::fprintf(stderr, "%s launch failed: %d (%s)\n", label, int(status),
                 cutlass::cutlassGetStatusString(status));
    return false;
  }
  return true;
}

// Single-head SDPA driver: computes O = softmax(Q K^T / sqrt(d_k)) V on the
// device, prints part of row 0 and checks the result against sdpa_reference().
int main() {

  // Problem sizes (single batch, single head)
  constexpr int N = 128;   // sequence length
  constexpr int D = 64;    // d_k (V also has D columns, i.e. d_v == d_k)

  // Row-major matrix layout everywhere.
  //using Element = float;
  using Element = cutlass::half_t;
  using Layout  = cutlass::layout::RowMajor;
  constexpr auto soEl = sizeof(Element);

  const float inv_sqrt_dk = 1.0f / std::sqrt(static_cast<float>(D));

  // Host buffers (row-major), filled from a fixed seed so that repeated runs
  // use the same inputs and their printed outputs can be compared.
  std::vector<Element> hQ(N * D), hK(N * D), hV(N * D);
  std::mt19937 rng(42u);
  std::uniform_real_distribution<float> dist(-0.1f, 0.1f);
  for (auto* arr : {&hQ, &hK, &hV}) {
    for (auto& x : *arr) {
      x = static_cast<Element>(dist(rng));
    }
  }

  // Device buffers
  Element *dQ = nullptr, *dK = nullptr, *dKt = nullptr, *dV = nullptr;
  Element *dScores = nullptr, *dOut = nullptr;

  CUDA_CHECK(cudaMalloc(&dQ,      N * D * soEl));
  CUDA_CHECK(cudaMalloc(&dK,      N * D * soEl));
  CUDA_CHECK(cudaMalloc(&dKt,     D * N * soEl));  // K^T
  CUDA_CHECK(cudaMalloc(&dV,      N * D * soEl));
  CUDA_CHECK(cudaMalloc(&dScores, N * N * soEl));  // Q K^T, then P in place
  CUDA_CHECK(cudaMalloc(&dOut,    N * D * soEl));

  // Releases every device buffer; also used on the early-exit paths.
  auto free_device = [&]() {
    CUDA_CHECK(cudaFree(dQ));
    CUDA_CHECK(cudaFree(dK));
    CUDA_CHECK(cudaFree(dKt));
    CUDA_CHECK(cudaFree(dV));
    CUDA_CHECK(cudaFree(dScores));
    CUDA_CHECK(cudaFree(dOut));
  };

  CUDA_CHECK(cudaMemcpy(dQ, hQ.data(), N * D * soEl, cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(dK, hK.data(), N * D * soEl, cudaMemcpyHostToDevice));
  CUDA_CHECK(cudaMemcpy(dV, hV.data(), N * D * soEl, cudaMemcpyHostToDevice));

  // Compute K^T on device (row-major transpose)
  dim3 blockT(16, 16);
  dim3 gridT((D + blockT.x - 1) / blockT.x, (N + blockT.y - 1) / blockT.y);

  transpose_rowmajor<<<gridT, blockT>>>(dK, dKt, /*rows_in=*/N, /*cols_in=*/D);
  CUDA_CHECK(cudaGetLastError());

  // CUTLASS GEMM definitions
  // Gemm computes: D = alpha * A * B + beta * C
  // ElementAccumulator is set to float explicitly. Its default is the output
  // element type (half_t here), which would accumulate every length-K dot
  // product in fp16. Operator class and architecture keep their defaults;
  // note that the next template parameters are OperatorClass and then
  // ArchTag, in that order.
  using Gemm = cutlass::gemm::device::Gemm<
      Element, Layout,   // A
      Element, Layout,   // B
      Element, Layout,   // C/D
      float              // accumulator
  >;

  // GEMM 1: Scores = Q [N x D] * K^T [D x N] -> [N x N]
  // (row-major leading dimension = number of columns)
  if (!run_gemm<Gemm>("GEMM 1", N, N, D,
                      dQ, /*lda=*/D, dKt, /*ldb=*/N, dScores, /*ldc=*/N)) {
    free_device();
    return EXIT_FAILURE;
  }

  // Row-wise softmax with scaling by 1/sqrt(d_k), in place: Scores -> P
  {
    const int threads = 128;
    const int blocks = (N + threads - 1) / threads;
    row_softmax<<<blocks, threads>>>(dScores, N, inv_sqrt_dk);
    CUDA_CHECK(cudaGetLastError());
  }

  // GEMM 2: Out = P [N x N] * V [N x D] -> [N x D]
  if (!run_gemm<Gemm>("GEMM 2", N, D, N,
                      dScores, /*lda=*/N, dV, /*ldb=*/D, dOut, /*ldc=*/D)) {
    free_device();
    return EXIT_FAILURE;
  }

  CUDA_CHECK(cudaDeviceSynchronize());

  // Copy back result
  std::vector<Element> hOut(N * D);
  CUDA_CHECK(cudaMemcpy(hOut.data(), dOut, N * D * soEl,
                        cudaMemcpyDeviceToHost));
  free_device();

  // Print the first (up to) 10 entries of row 0; the label matches the count.
  const int shown = D < 10 ? D : 10;
  std::printf("O[0, 0..%d]: ", shown - 1);
  for (int j = 0; j < shown; ++j) {
    std::printf("%.6f ", static_cast<float>(hOut[j]));
  }
  std::printf("\n");

  // Validate against the host reference: |got - ref| <= atol + rtol * |ref|.
  constexpr double kAtol = 1e-3;
  constexpr double kRtol = 1e-2;
  const std::vector<double> ref =
      sdpa_reference(hQ, hK, hV, N, D, static_cast<double>(inv_sqrt_dk));
  double max_abs_err = 0.0;
  bool ok = true;
  for (std::size_t i = 0; i < ref.size(); ++i) {
    const double got = static_cast<double>(static_cast<float>(hOut[i]));
    const double err = std::fabs(got - ref[i]);
    max_abs_err = std::fmax(max_abs_err, err);
    if (err > kAtol + kRtol * std::fabs(ref[i])) ok = false;
  }
  std::printf("max |O - O_ref| = %.3e -> %s\n", max_abs_err,
              ok ? "PASS" : "FAIL");

  return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}


In [None]:
%cuda_group_run --group "sdpa" --compiler-args "--include-path /content/cutlass/include --gpu-architecture sm_75"

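To build directly with nvcc, locate the source-file folder that nvcc4jupyter created; its path is shown in the nvcc4jupyter output above. The temporary directory name changes with every session, so adjust the path in the next cell and run the build from that folder.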

In [None]:
!ls -al /tmp/tmpp7f7xysa/sdpa

In [None]:
!nvcc -O3 -arch=sm_75 -I$CUTLASS_INCLUDE -o sdpa sdpa.cu

In [None]:
!./sdpa