-
Notifications
You must be signed in to change notification settings - Fork 41
/
test_fused_chunk.py
125 lines (108 loc) · 4.02 KB
/
test_fused_chunk.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl
@triton.jit
def attention_fwd_kernel(
q,
k,
v,
h,
o,
s_qh,
s_qt,
s_qd,
s_hh,
s_ht,
T,
scale,
BT: tl.constexpr,
BD: tl.constexpr,
NT: tl.constexpr,
STORE: tl.constexpr,
IFCOND: tl.constexpr
):
i_bh = tl.program_id(0)
# [BD, BD]
b_h = tl.zeros([BD, BD], dtype=tl.float32)
for i in range(0, tl.cdiv(T, BT)):
p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
p_h = tl.make_block_ptr(h + i_bh * s_hh, (NT * BD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))
p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))
if STORE:
tl.store(p_h, b_h.to(p_h.dtype.element_ty))
# [BT, BD]
b_q = tl.load(p_q)
b_q = (b_q * scale).to(b_q.dtype)
# [BD, BT]
b_k = tl.load(p_k)
# [BT, BD]
b_v = tl.load(p_v)
# [BT, BT]
b_s = tl.dot(b_q, b_k, allow_tf32=False)
# [BT, BD]
b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)
if IFCOND:
if i == 0:
b_h = tl.dot(b_k, b_v, allow_tf32=False)
else:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h += tl.dot(b_k, b_v, allow_tf32=False)
else:
b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)
b_h += tl.dot(b_k, b_v, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty))
class AttentionFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, store=False, ifcond=False):
batch_size, n_heads, seq_len, d_head = q.shape
scale = d_head ** -0.5
BD = q.shape[-1]
BT = 32
NT = triton.cdiv(seq_len, BT)
num_stages = 3 if d_head <= 64 else 2
num_warps = 4
h = q.new_empty(batch_size, n_heads, NT * BD, BD)
o = torch.empty_like(q)
grid = (batch_size * n_heads,)
attention_fwd_kernel[grid](
q, k, v, h, o,
q.stride(1), q.stride(2), q.stride(3), h.stride(1), h.stride(2),
seq_len, scale,
BT=BT, BD=BD, NT=NT, STORE=store, IFCOND=ifcond,
num_warps=num_warps,
num_stages=num_stages
)
return o
def sizeof_fmt(num, suffix='B'):
for unit in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
if abs(num) < 1024.0:
return f'{num:3.1f}{unit}{suffix}'
num /= 1024.0
return f'{num:.1f}Yi{suffix}'
if __name__ == '__main__':
B, H, T, D = 2, 8, 1024, 128
dtype = torch.float
torch.manual_seed(42)
# [batch_size, n_heads, seq_len, d_head]
q = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
k = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
v = torch.randn((B, H, T, D), dtype=dtype, device='cuda')
ref = AttentionFunction.apply(q, k, v)
infos = torch.cuda.get_device_properties(q)
def fmt(x):
if isinstance(x, (float, torch.Tensor)):
return f"{x:>16.2f}"
return f"{str(x):>16}"
print(f'{fmt("GPU Type")}{fmt("Memory")}{fmt("Cores")}\n'
f"{fmt(infos.name)}{fmt(sizeof_fmt(infos.total_memory))}{fmt(infos.multi_processor_count)}")
print(f'{"DTYPE":>16}{"STORE":>16}{"INIT CHECK":>16}{"DIFF":>16}{"PASSED":>16}')
for dtype in (torch.float, torch.bfloat16):
q, k, v = q.clone().to(dtype), k.clone().to(dtype), v.clone().to(dtype)
for store in [False, True]:
for check in [False, True]:
tri = AttentionFunction.apply(q, k, v, store, check)
diff = (ref - tri).abs().max()
print(f"{fmt(q.dtype)}{fmt(store)}{fmt(check)}{fmt(diff)}{fmt(bool(diff < 1))}")