Skip to content

Commit

Permalink
paged attention
Browse files Browse the repository at this point in the history
  • Loading branch information
flozi00 committed May 22, 2024
1 parent 98ac61c commit 6d41fbf
Show file tree
Hide file tree
Showing 8 changed files with 10 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ def forward(
cu_seqlen_prefill,
max_s,
self.softmax_scale,
window_size_left=self.max_past,
)
# Decode
else:
Expand All @@ -438,7 +439,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,19 +281,19 @@ def forward(
)
# Decode
else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
max_s,
)


return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size), adapter_data)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,13 +339,12 @@ def forward(
)
# Decode
else:
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
paged_attention.attention(
attn_output,
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def forward(
query,
kv_cache[0],
kv_cache[1],
self.num_key_value_heads,
self.kv_head_mapping,
self.softmax_scale,
block_tables,
input_lengths,
Expand Down

0 comments on commit 6d41fbf

Please sign in to comment.