Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
217 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import torch | ||
import triton | ||
import triton.language as tl | ||
|
||
|
||
@triton.jit | ||
def flatten_kernel( | ||
# pointers to matrices | ||
OUT, | ||
LSE, | ||
CU_SEQLENS, | ||
# strides | ||
stride_out_nheads, | ||
stride_out_seqlen, | ||
stride_lse_batch, | ||
stride_lse_nheads, | ||
stride_lse_seqlen, | ||
# meta-parameters | ||
BLOCK_M: tl.constexpr, | ||
): | ||
pid_m = tl.program_id(axis=0) | ||
pid_batch = tl.program_id(axis=1) | ||
pid_head = tl.program_id(axis=2) | ||
|
||
start_idx = tl.load(CU_SEQLENS + pid_batch) | ||
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx | ||
LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads | ||
OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen | ||
|
||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) | ||
|
||
LSE = LSE + rm[:, None] * stride_lse_seqlen | ||
x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) | ||
|
||
OUT = OUT + rm[:, None] * stride_out_seqlen | ||
tl.store(OUT, x, mask=rm[:, None] < seqlen) | ||
|
||
|
||
def flatten_varlen_lse(lse, cu_seqlens): | ||
""" | ||
Arguments: | ||
lse: (batch_size, nheads, max_seqlen) | ||
cu_seqlens: (batch_size + 1,) | ||
Return: | ||
flatten_lse: (nheads, total_seqlen) | ||
""" | ||
total_seqlen = cu_seqlens[-1] | ||
batch_size, nheads, max_seqlen = lse.shape | ||
output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device) | ||
|
||
grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) | ||
BLOCK_M = 4 | ||
|
||
with torch.cuda.device(lse.device.index): | ||
flatten_kernel[grid]( | ||
output, | ||
lse, | ||
cu_seqlens, | ||
# strides | ||
output.stride(0), | ||
output.stride(1), | ||
lse.stride(0), | ||
lse.stride(1), | ||
lse.stride(2), | ||
BLOCK_M, | ||
) | ||
return output | ||
|
||
|
||
@triton.jit | ||
def unflatten_kernel( | ||
# pointers to matrices | ||
OUT, | ||
LSE, | ||
CU_SEQLENS, | ||
# strides | ||
stride_out_batch, | ||
stride_out_nheads, | ||
stride_out_seqlen, | ||
stride_lse_seqlen, | ||
stride_lse_nheads, | ||
# meta-parameters | ||
BLOCK_M: tl.constexpr, | ||
): | ||
pid_m = tl.program_id(axis=0) | ||
pid_batch = tl.program_id(axis=1) | ||
pid_head = tl.program_id(axis=2) | ||
|
||
start_idx = tl.load(CU_SEQLENS + pid_batch) | ||
seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx | ||
LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen | ||
OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads | ||
|
||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) | ||
|
||
LSE = LSE + rm[:, None] * stride_lse_seqlen | ||
x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0) | ||
|
||
OUT = OUT + rm[:, None] * stride_out_seqlen | ||
tl.store(OUT, x, mask=rm[:, None] < seqlen) | ||
|
||
|
||
def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int): | ||
""" | ||
Arguments: | ||
lse: (total_seqlen, nheads, 1) | ||
cu_seqlens: (batch_size + 1,) | ||
max_seqlen: int | ||
Return: | ||
unflatten_lse: (batch_size, nheads, max_seqlen) | ||
""" | ||
lse = lse.unsqueeze(dim=-1) | ||
batch_size = len(cu_seqlens) - 1 | ||
nheads = lse.shape[1] | ||
output = torch.empty( | ||
(batch_size, nheads, max_seqlen), | ||
dtype=lse.dtype, | ||
device=lse.device, | ||
) | ||
|
||
grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads) | ||
BLOCK_M = 4 | ||
|
||
with torch.cuda.device(lse.device.index): | ||
unflatten_kernel[grid]( | ||
output, | ||
lse, | ||
cu_seqlens, | ||
# strides | ||
output.stride(0), | ||
output.stride(1), | ||
output.stride(2), | ||
lse.stride(0), | ||
lse.stride(1), | ||
BLOCK_M, | ||
) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import torch | ||
from ring_flash_attn.utils import ( | ||
flatten_varlen_lse, | ||
unflatten_varlen_lse, | ||
) | ||
from ring_flash_attn.triton_utils import ( | ||
flatten_varlen_lse as triton_flatten_varlen_lse, | ||
unflatten_varlen_lse as triton_unflatten_varlen_lse, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
device = torch.device("cuda:0") | ||
|
||
cu_seqlens = [0, 15, 156, 529] | ||
cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) | ||
batch_size = len(cu_seqlens) - 1 | ||
max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item() | ||
n_head = 5 | ||
|
||
lse = torch.randn( | ||
(batch_size, n_head, max_seqlen), dtype=torch.float32, device=device | ||
) | ||
flatten_lse = flatten_varlen_lse(lse, cu_seqlens_tensor) | ||
triton_flatten_lse = triton_flatten_varlen_lse(lse, cu_seqlens_tensor) | ||
assert torch.all(flatten_lse == triton_flatten_lse) | ||
|
||
flatten_lse = flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) | ||
triton_flatten_lse = triton_flatten_lse.transpose(-2, -1).unsqueeze(dim=-1) | ||
|
||
unflatten_lse = unflatten_varlen_lse(flatten_lse, cu_seqlens_tensor, max_seqlen) | ||
triton_unflatten_lse = triton_unflatten_varlen_lse( | ||
triton_flatten_lse, cu_seqlens_tensor, max_seqlen | ||
) | ||
|
||
for i in range(batch_size): | ||
seqlen = cu_seqlens[i + 1] - cu_seqlens[i] | ||
assert torch.all( | ||
unflatten_lse[i, :, :seqlen] == triton_unflatten_lse[i, :, :seqlen] | ||
), f"{unflatten_lse[i, :seqlen]} vs {triton_unflatten_lse[i, :seqlen]}" |