In [None]:
'''
Make sure your runtime has a GPU enabled.
'''
# For building CUDA
!pip install Ninja

In [None]:
import os
import math

import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load_inline

# CUDA source for the block-sparse attention forward pass. It is compiled at
# runtime by `load_inline`; the `forward` definition at the bottom must match
# the C++ declaration passed as `cpp_sources`.
cuda_src = '''
#include <ATen/cuda/CUDAContext.h>

// Block-sparse attention forward kernel.
//
// Launch layout
//   grid  = (number of query blocks, B * H)
//   block = (block_size) -- one thread per query row of the block
//
// Tensors (all contiguous)
//   Q, output     : [B * H, Nq, D] float
//   K, V          : [B * H, Nk, D] float
//   block_indices : [B * H, gridDim.x, block_count] int, the key/value block
//                   ids each query block attends to. Ids outside
//                   [0, ceil(Nk / block_size)) are skipped.
//
// Dynamic shared memory (floats): 4 * block_size * D + block_size * block_size
//   shared_q   [D][block_size]          query rows, one column per thread
//   shared_k   [block_size][D]          current key tile
//   shared_v   [block_size][D]          current value tile
//   shared_acc [D][block_size]          un-normalised output, one column per thread
//   shared_p   [block_size][block_size] scores of the current tile, one column per thread
// The per-thread buffers are stored column-wise so that consecutive threads
// touch consecutive shared-memory words.
__global__
void forward_kernel(
    const float* __restrict__ Q,
    const float* __restrict__ K,
    const float* __restrict__ V,
    const int* __restrict__ block_indices,
    const int Nq,
    const int Nk,
    const int D,
    const int block_count,
    const float softmax_scale,
    float* __restrict__ output,
    const bool use_causal_mask
) {
    const int tx = threadIdx.x;         // query row inside the block
    const int blockSize = blockDim.x;   // rows per query/key block
    const int bx = blockIdx.x;          // query block being processed
    const int by = blockIdx.y;          // flattened batch * head index

    const int q_row = bx * blockSize + tx;
    const bool q_valid = q_row < Nq;    // the last query block may be ragged
    const size_t q_offset = ((size_t)by * Nq + q_row) * D;

    extern __shared__ float shared_memory[];
    const int tile_size = blockSize * D;
    float* shared_q = shared_memory;
    float* shared_k = shared_q + tile_size;
    float* shared_v = shared_k + tile_size;
    float* shared_acc = shared_v + tile_size;
    float* shared_p = shared_acc + tile_size;

    // Each thread owns exactly one column of shared_q / shared_acc / shared_p,
    // so these buffers need no synchronisation between threads.
    for (int x = 0; x < D; x++) {
        shared_q[x * blockSize + tx] = q_valid ? Q[q_offset + x] : 0.0f;
        shared_acc[x * blockSize + tx] = 0.0f;
    }

    float running_max = -INFINITY;
    float running_sum = 0.0f;

    const int num_k_blocks = (Nk + blockSize - 1) / blockSize;
    const int* my_blocks = block_indices + ((size_t)by * gridDim.x + bx) * block_count;

    for (int j = 0; j < block_count; j++) {
        const int kb = my_blocks[j];
        const bool block_ok = kb >= 0 && kb < num_k_blocks;
        const int key_base = block_ok ? kb * blockSize : 0;

        // Cooperative, coalesced load of the K/V tile: consecutive threads read
        // consecutive floats. Rows past the end of the sequence are zero-filled.
        for (int idx = tx; idx < tile_size; idx += blockSize) {
            const int key_row = key_base + idx / D;
            float k_val = 0.0f;
            float v_val = 0.0f;
            if (block_ok && key_row < Nk) {
                const size_t src = ((size_t)by * Nk + key_row) * D + idx % D;
                k_val = K[src];
                v_val = V[src];
            }
            shared_k[idx] = k_val;
            shared_v[idx] = v_val;
        }
        __syncthreads();

        if (q_valid) {
            // Pass 1: scaled scores for this tile and the updated row maximum.
            float new_max = running_max;
            for (int i = 0; i < blockSize; i++) {
                const int key_row = key_base + i;
                const bool visible = block_ok && key_row < Nk
                    && !(use_causal_mask && key_row > q_row);
                float score = -INFINITY;
                if (visible) {
                    float dot_product = 0.0f;
                    for (int x = 0; x < D; x++) {
                        dot_product += shared_q[x * blockSize + tx] * shared_k[i * D + x];
                    }
                    score = dot_product * softmax_scale;
                }
                shared_p[i * blockSize + tx] = score;
                new_max = fmaxf(new_max, score);
            }

            // Pass 2: online-softmax update. Skipped while every key seen so far
            // is masked, which would otherwise produce (-inf) - (-inf) = NaN.
            if (new_max > -INFINITY) {
                // Rescale what was accumulated under the old maximum, once per
                // tile. expf(-inf) == 0 covers the very first visible tile.
                const float correction = expf(running_max - new_max);
                running_sum *= correction;
                for (int x = 0; x < D; x++) {
                    shared_acc[x * blockSize + tx] *= correction;
                }
                for (int i = 0; i < blockSize; i++) {
                    // Masked keys have score -inf and therefore weight 0.
                    const float token_weight = expf(shared_p[i * blockSize + tx] - new_max);
                    running_sum += token_weight;
                    for (int x = 0; x < D; x++) {
                        shared_acc[x * blockSize + tx] += token_weight * shared_v[i * D + x];
                    }
                }
                running_max = new_max;
            }
        }
        // Nobody may overwrite the K/V tile while another thread still reads it.
        __syncthreads();
    }

    if (q_valid) {
        // A query with no visible key at all yields a zero row instead of NaN.
        const float inv_sum = running_sum > 0.0f ? 1.0f / running_sum : 0.0f;
        for (int x = 0; x < D; x++) {
            output[q_offset + x] = shared_acc[x * blockSize + tx] * inv_sum;
        }
    }
}

// Host entry point exposed to Python.
//
//   queries      : [B, H, Tq, D] float32 CUDA tensor
//   keys, values : [B, H, Tk, D] float32 CUDA tensors
//   query_blocks : [B, H, ceil(Tq / block_size), block_count] integer CUDA
//                  tensor with the key/value block ids for every query block
//   block_size   : rows per query/key block (also the CUDA thread-block size)
//
// Returns the attention output with the shape of `queries`.
torch::Tensor forward(
    torch::Tensor queries,
    torch::Tensor keys,
    torch::Tensor values,
    torch::Tensor query_blocks,
    int64_t block_size) {

    TORCH_CHECK(queries.is_cuda() && keys.is_cuda() && values.is_cuda() && query_blocks.is_cuda(),
                "queries, keys, values and query_blocks must be CUDA tensors");
    TORCH_CHECK(queries.dim() == 4 && keys.dim() == 4 && values.dim() == 4,
                "queries, keys and values must have shape [B, H, T, D]");
    TORCH_CHECK(queries.scalar_type() == torch::kFloat32
                    && keys.scalar_type() == torch::kFloat32
                    && values.scalar_type() == torch::kFloat32,
                "queries, keys and values must be float32");
    TORCH_CHECK(block_size > 0 && block_size <= 1024,
                "block_size must be in [1, 1024], got ", block_size);

    const int64_t B = queries.size(0);
    const int64_t H = queries.size(1);
    const int64_t Tq = queries.size(2);
    const int64_t D = queries.size(3);
    const int64_t Tk = keys.size(2);

    TORCH_CHECK(keys.size(0) == B && keys.size(1) == H && keys.size(3) == D,
                "keys must have shape [B, H, Tk, D] matching queries");
    TORCH_CHECK(values.size(0) == B && values.size(1) == H
                    && values.size(2) == Tk && values.size(3) == D,
                "values must have the same shape as keys");

    const int64_t num_q_blocks = (Tq + block_size - 1) / block_size;
    TORCH_CHECK(query_blocks.dim() == 4
                    && query_blocks.size(0) == B
                    && query_blocks.size(1) == H
                    && query_blocks.size(2) == num_q_blocks,
                "query_blocks must have shape [B, H, ceil(T / block_size), block_count]");
    const int64_t block_count = query_blocks.size(3);

    // The kernel does raw pointer arithmetic, so it needs contiguous inputs and
    // 32-bit block ids (torch.randint produces int64 by default).
    auto q = queries.contiguous();
    auto k = keys.contiguous();
    auto v = values.contiguous();
    auto indices = query_blocks.to(torch::kInt32).contiguous();

    auto output = torch::zeros_like(q);
    if (B * H == 0 || num_q_blocks == 0) {
        return output;
    }
    TORCH_CHECK(B * H <= 65535, "B * H exceeds the CUDA grid y-dimension limit (65535)");

    // Q, K, V and accumulator tiles plus one score per (query, key) pair.
    const int64_t shared_floats = 4 * block_size * D + block_size * block_size;
    const size_t shared_bytes = static_cast<size_t>(shared_floats) * sizeof(float);
    TORCH_CHECK(shared_bytes <= 48 * 1024,
                "block_size * D is too large: the kernel needs ", shared_bytes,
                " bytes of shared memory but only 48 KiB is available without opt-in");

    const dim3 grid(static_cast<unsigned int>(num_q_blocks), static_cast<unsigned int>(B * H), 1);
    const dim3 block(static_cast<unsigned int>(block_size), 1, 1);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    forward_kernel<<<grid, block, shared_bytes, stream>>>(
        q.data_ptr<float>(),
        k.data_ptr<float>(),
        v.data_ptr<float>(),
        indices.data_ptr<int>(),
        static_cast<int>(Tq),
        static_cast<int>(Tk),
        static_cast<int>(D),
        static_cast<int>(block_count),
        1.0f / sqrtf(static_cast<float>(D)),
        output.data_ptr<float>(),
        false
    );

    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess, "forward_kernel launch failed: ", cudaGetErrorString(err));

    return output;
}
'''
# C++ declaration of the entry point defined in `cuda_src`; `load_inline`
# generates the Python binding for every name listed in `functions` from it.
cpp_src = 'torch::Tensor forward(torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor query_blocks, int64_t block_size);'

# Directory that receives the generated sources and build artefacts.
# `exist_ok=True` makes the cell safe to re-run and avoids the
# check-then-create race of `os.path.exists` followed by `os.mkdir`.
build_dir = 'cuda'
os.makedirs(build_dir, exist_ok=True)

# Compile the extension and load it as a Python module exposing `forward`.
block_sparse_attention = load_inline(
    name='block_sparse_attention',
    cpp_sources=cpp_src,
    cuda_sources=cuda_src,
    functions=['forward'],
    with_cuda=True,
    extra_cuda_cflags=['-O2'],
    build_directory=f'./{build_dir}'
)

In [None]:
# Problem size. T is deliberately not a multiple of block_size so that the
# ragged last block is exercised as well.
B = 2
H = 4
T = 100
D = 128
block_size = 16

# Fixed seed so the profiled inputs and the sanity check are reproducible.
torch.manual_seed(0)

q = torch.randn(B, H, T, D, device='cuda')
k = torch.randn(B, H, T, D, device='cuda')
v = torch.randn(B, H, T, D, device='cuda')

# One row of key-block ids per *query block*. The number of blocks is rounded
# up (7 for T=100, block_size=16); `T // block_size` would leave the last,
# partial query block without any ids.
num_blocks = (T + block_size - 1) // block_size
block_indices = torch.randint(0, num_blocks, (B, H, num_blocks, 4), device='cuda')

print('=== profiling manual attention ===')

# Our minimal flash attention needs to be faster than this.
def baseline_block_sparse_attention(q, k, v, block_indices, block_size):
    """Reference block-sparse attention written with plain PyTorch ops.

    The query sequence is cut into blocks of ``block_size`` rows. Each query
    block attends only to the key/value blocks listed for it in
    ``block_indices``; softmax is taken over the concatenation of those blocks.

    Args:
        q: Queries of shape ``(B, H, T, D)``.
        k: Keys, indexed as ``k[b, h, rows]``.
        v: Values, indexed the same way as ``k``.
        block_indices: Integer tensor indexed as
            ``block_indices[b, h, query_block]``, giving the ids of the
            key/value blocks that query block attends to.
        block_size: Number of rows per query/key block.

    Returns:
        Tensor shaped ``(B, H, T, v.shape[-1])`` holding the attention output.
    """
    B, H, T, D = q.shape
    num_query_blocks = (T + block_size - 1) // block_size
    out = v.new_zeros((B, H, T, v.shape[-1]))

    for b in range(B):
        for h in range(H):
            for query_block_index in range(num_query_blocks):
                q_start = query_block_index * block_size
                q_stop = q_start + block_size
                query_block = q[b, h, q_start:q_stop]

                # Use a separate name for the per-block ids: rebinding
                # `block_indices` here would clobber the index tensor for
                # every following query block.
                selected = block_indices[b, h, query_block_index].tolist()
                key_block = torch.cat(
                    [k[b, h, i * block_size:(i + 1) * block_size] for i in selected], dim=0
                )
                value_block = torch.cat(
                    [v[b, h, i * block_size:(i + 1) * block_size] for i in selected], dim=0
                )

                attention = torch.matmul(query_block, key_block.transpose(-2, -1))
                attention = attention / (D ** 0.5)
                attention = torch.nn.functional.softmax(attention, dim=-1)

                # Slicing clamps at T, so a ragged final block is handled too.
                out[b, h, q_start:q_stop] = torch.matmul(attention, value_block)
    return out

# Profile the pure-PyTorch reference implementation.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    manual_result = baseline_block_sparse_attention(q, k, v, block_indices, block_size)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

print('=== profiling custom cuda block flash attention === ')

# Profile the CUDA extension. `forward` is declared with five parameters
# (Q, K, V, query_blocks, block_size), so the block ids and the block size
# have to be passed along with the tensors.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    minimal_result = block_sparse_attention.forward(q, k, v, block_indices, block_size)
print(prof.key_averages().table(sort_by='cuda_time_total', row_limit=10))

# Both implementations must agree up to float32 accumulation error.
print('attn values sanity check:', torch.allclose(minimal_result, manual_result, rtol=0, atol=1e-02))