add naive triton kernel for varlen
zhuzilin committed Mar 13, 2024
1 parent 7895974 commit 10d992c
Showing 5 changed files with 217 additions and 32 deletions.
20 changes: 12 additions & 8 deletions ring_flash_attn/ring_flash_attn_varlen.py
@@ -7,10 +7,19 @@
from .utils import (
    RingComm,
    update_out_and_lse,
    flatten_varlen_lse,
    unflatten_varlen_lse,
)

try:
    from .triton_utils import (
        flatten_varlen_lse,
        unflatten_varlen_lse,
    )
except ImportError:
    from .utils import (
        flatten_varlen_lse,
        unflatten_varlen_lse,
    )


def ring_flash_attn_varlen_forward(
    process_group,
@@ -65,12 +74,7 @@ def ring_flash_attn_varlen_forward(
        v = next_v

    out = out.to(q.dtype)
    lse = (
        unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
        .squeeze(dim=-1)
        .transpose(1, 2)
        .contiguous()
    )
    lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
    return out, lse


137 changes: 137 additions & 0 deletions ring_flash_attn/triton_utils.py
@@ -0,0 +1,137 @@
import torch
import triton
import triton.language as tl


@triton.jit
def flatten_kernel(
    # pointers to matrices
    OUT,
    LSE,
    CU_SEQLENS,
    # strides
    stride_out_nheads,
    stride_out_seqlen,
    stride_lse_batch,
    stride_lse_nheads,
    stride_lse_seqlen,
    # meta-parameters
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)

    start_idx = tl.load(CU_SEQLENS + pid_batch)
    seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
    LSE = LSE + pid_batch * stride_lse_batch + pid_head * stride_lse_nheads
    OUT = OUT + pid_head * stride_out_nheads + start_idx * stride_out_seqlen

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    LSE = LSE + rm[:, None] * stride_lse_seqlen
    x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)

    OUT = OUT + rm[:, None] * stride_out_seqlen
    tl.store(OUT, x, mask=rm[:, None] < seqlen)


def flatten_varlen_lse(lse, cu_seqlens):
    """
    Arguments:
        lse: (batch_size, nheads, max_seqlen)
        cu_seqlens: (batch_size + 1,)
    Return:
        flatten_lse: (nheads, total_seqlen)
    """
    total_seqlen = cu_seqlens[-1]
    batch_size, nheads, max_seqlen = lse.shape
    output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device)

    grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads)
    BLOCK_M = 4

    with torch.cuda.device(lse.device.index):
        flatten_kernel[grid](
            output,
            lse,
            cu_seqlens,
            # strides
            output.stride(0),
            output.stride(1),
            lse.stride(0),
            lse.stride(1),
            lse.stride(2),
            BLOCK_M,
        )
    return output
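
For readers comparing against the kernel above: a minimal pure-PyTorch sketch of the same packing, using the imports already in this file (the helper name flatten_varlen_lse_reference is ours, not part of the commit). Each program of the Triton grid copies BLOCK_M positions of one (batch, head) row; the loop below does the same copy sequence by sequence:

def flatten_varlen_lse_reference(lse, cu_seqlens):
    # lse: (batch_size, nheads, max_seqlen) -> (nheads, total_seqlen)
    total_seqlen = int(cu_seqlens[-1])
    nheads = lse.shape[1]
    output = torch.empty((nheads, total_seqlen), dtype=lse.dtype, device=lse.device)
    for i in range(len(cu_seqlens) - 1):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        # copy the valid prefix of sequence i into its packed slot
        output[:, start:end] = lse[i, :, : end - start]
    return output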


@triton.jit
def unflatten_kernel(
    # pointers to matrices
    OUT,
    LSE,
    CU_SEQLENS,
    # strides
    stride_out_batch,
    stride_out_nheads,
    stride_out_seqlen,
    stride_lse_seqlen,
    stride_lse_nheads,
    # meta-parameters
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)

    start_idx = tl.load(CU_SEQLENS + pid_batch)
    seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
    LSE = LSE + pid_head * stride_lse_nheads + start_idx * stride_lse_seqlen
    OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    LSE = LSE + rm[:, None] * stride_lse_seqlen
    x = tl.load(LSE, mask=rm[:, None] < seqlen, other=0.0)

    OUT = OUT + rm[:, None] * stride_out_seqlen
    tl.store(OUT, x, mask=rm[:, None] < seqlen)


def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
    """
    Arguments:
        lse: (total_seqlen, nheads, 1)
        cu_seqlens: (batch_size + 1,)
        max_seqlen: int
    Return:
        unflatten_lse: (batch_size, nheads, max_seqlen)
    """
    lse = lse.unsqueeze(dim=-1)
    batch_size = len(cu_seqlens) - 1
    nheads = lse.shape[1]
    output = torch.empty(
        (batch_size, nheads, max_seqlen),
        dtype=lse.dtype,
        device=lse.device,
    )

    grid = lambda META: (triton.cdiv(max_seqlen, META["BLOCK_M"]), batch_size, nheads)
    BLOCK_M = 4

    with torch.cuda.device(lse.device.index):
        unflatten_kernel[grid](
            output,
            lse,
            cu_seqlens,
            # strides
            output.stride(0),
            output.stride(1),
            output.stride(2),
            lse.stride(0),
            lse.stride(1),
            BLOCK_M,
        )
    return output
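
The equivalent loop in plain PyTorch, as a hedged sketch (unflatten_varlen_lse_reference is an illustrative name, not part of the commit). One caveat worth noting: the Triton kernel's masked store leaves the padding positions of its torch.empty output uninitialized, whereas this sketch zero-fills them, which is why the test at the bottom of this page only compares the first seqlen positions of each sequence:

def unflatten_varlen_lse_reference(lse, cu_seqlens, max_seqlen):
    # lse: (total_seqlen, nheads, 1) -> (batch_size, nheads, max_seqlen)
    batch_size = len(cu_seqlens) - 1
    nheads = lse.shape[1]
    output = torch.zeros(
        (batch_size, nheads, max_seqlen), dtype=lse.dtype, device=lse.device
    )
    for i in range(batch_size):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        # place sequence i back into its padded slot, heads first
        output[i, :, : end - start] = lse[start:end, :, 0].transpose(0, 1)
    return output
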
22 changes: 12 additions & 10 deletions ring_flash_attn/utils.py
@@ -66,7 +66,7 @@ def unflatten_varlen_lse(lse, cu_seqlens, max_seqlen: int):
    for i in range(num_seq):
        start, end = cu_seqlens[i], cu_seqlens[i + 1]
        new_lse[i, : end - start] = lse[start:end]
    return new_lse
    return new_lse.squeeze(dim=-1).transpose(1, 2).contiguous()


class RingComm:
@@ -77,6 +77,13 @@ def __init__(self, process_group: dist.ProcessGroup):
        self.world_size = dist.get_world_size(self._process_group)
        self._reqs = None

        # ring topology: send to the next rank, receive from the previous one
        self.send_rank = (self.rank + 1) % self.world_size
        self.recv_rank = (self.rank - 1) % self.world_size

        if process_group is not None:
            self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
            self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)

    def send_recv(
        self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
@@ -85,15 +92,10 @@ def send_recv(
        else:
            res = recv_tensor

        send_rank = (self.rank + 1) % self.world_size
        recv_rank = (self.rank - 1) % self.world_size

        if self._process_group is not None:
            send_rank = dist.get_global_rank(self._process_group, send_rank)
            recv_rank = dist.get_global_rank(self._process_group, recv_rank)

        send_op = dist.P2POp(dist.isend, to_send, send_rank, group=self._process_group)
        recv_op = dist.P2POp(dist.irecv, res, recv_rank, group=self._process_group)
        send_op = dist.P2POp(
            dist.isend, to_send, self.send_rank, group=self._process_group
        )
        recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
        self._ops.append(send_op)
        self._ops.append(recv_op)
        return res
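
As context for the cached send_rank/recv_rank: a self-contained sketch of the same one-step ring exchange using only public torch.distributed APIs (ring_send_recv is an illustrative name, not part of this commit):

import torch
import torch.distributed as dist

def ring_send_recv(to_send: torch.Tensor) -> torch.Tensor:
    # every rank sends to (rank + 1) % world_size and receives from (rank - 1) % world_size
    rank, world_size = dist.get_rank(), dist.get_world_size()
    res = torch.empty_like(to_send)
    send_op = dist.P2POp(dist.isend, to_send, (rank + 1) % world_size)
    recv_op = dist.P2POp(dist.irecv, res, (rank - 1) % world_size)
    # launch both transfers together so send/recv ordering cannot deadlock the ring
    for req in dist.batch_isend_irecv([send_op, recv_op]):
        req.wait()
    return res
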
30 changes: 16 additions & 14 deletions ring_flash_attn/zigzag_ring_flash_attn_varlen.py
@@ -6,10 +6,19 @@
from .utils import (
    RingComm,
    update_out_and_lse,
    flatten_varlen_lse,
    unflatten_varlen_lse,
)

try:
    from .triton_utils import (
        flatten_varlen_lse,
        unflatten_varlen_lse,
    )
except ImportError:
    from .utils import (
        flatten_varlen_lse,
        unflatten_varlen_lse,
    )


def get_half_index(cu_seqlens, *, front: bool):
    if len(cu_seqlens) == 2:
@@ -137,12 +146,7 @@ def forward(q, k, v, causal):
        v = next_v

    out = out.to(q.dtype)
    lse = (
        unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
        .squeeze(dim=-1)
        .transpose(1, 2)
        .contiguous()
    )
    lse = unflatten_varlen_lse(lse, cu_seqlens, max_seqlen)
    return out, lse


Expand Down Expand Up @@ -314,9 +318,7 @@ def forward(
q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1
)
else:
ctx.save_for_backward(
q, k, v, out, softmax_lse, cu_seqlens
)
ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens)
ctx.half_index0 = half_index0
ctx.half_index1 = half_index1
ctx.max_seqlen = max_seqlen
@@ -332,9 +334,9 @@ def forward(
    @staticmethod
    def backward(ctx, dout, *args):
        if ctx.is_half_index_tensor:
            (
                q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1
            ) = ctx.saved_tensors
            (q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) = (
                ctx.saved_tensors
            )
        else:
            q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors
            half_index0 = ctx.half_index0
40 changes: 40 additions & 0 deletions test/test_triton_kernels.py
@@ -0,0 +1,40 @@
import torch
from ring_flash_attn.utils import (
    flatten_varlen_lse,
    unflatten_varlen_lse,
)
from ring_flash_attn.triton_utils import (
    flatten_varlen_lse as triton_flatten_varlen_lse,
    unflatten_varlen_lse as triton_unflatten_varlen_lse,
)


if __name__ == "__main__":
    device = torch.device("cuda:0")

    cu_seqlens = [0, 15, 156, 529]
    cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
    batch_size = len(cu_seqlens) - 1
    max_seqlen = (cu_seqlens_tensor[1:] - cu_seqlens_tensor[:-1]).max().item()
    n_head = 5

    lse = torch.randn(
        (batch_size, n_head, max_seqlen), dtype=torch.float32, device=device
    )
    flatten_lse = flatten_varlen_lse(lse, cu_seqlens_tensor)
    triton_flatten_lse = triton_flatten_varlen_lse(lse, cu_seqlens_tensor)
    assert torch.all(flatten_lse == triton_flatten_lse)

    flatten_lse = flatten_lse.transpose(-2, -1).unsqueeze(dim=-1)
    triton_flatten_lse = triton_flatten_lse.transpose(-2, -1).unsqueeze(dim=-1)

    unflatten_lse = unflatten_varlen_lse(flatten_lse, cu_seqlens_tensor, max_seqlen)
    triton_unflatten_lse = triton_unflatten_varlen_lse(
        triton_flatten_lse, cu_seqlens_tensor, max_seqlen
    )

    for i in range(batch_size):
        seqlen = cu_seqlens[i + 1] - cu_seqlens[i]
        assert torch.all(
            unflatten_lse[i, :, :seqlen] == triton_unflatten_lse[i, :, :seqlen]
        ), f"{unflatten_lse[i, :, :seqlen]} vs {triton_unflatten_lse[i, :, :seqlen]}"
