@@ -13,6 +13,7 @@
 import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
@@ -23,7 +24,12 @@
     TextPart,
     VQPart,
 )
-from fish_speech.models.text2semantic.llama import BaseModelArgs
+from fish_speech.models.text2semantic.llama import (
+    BaseModelArgs,
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
 from fish_speech.text import clean_text, split_text
 from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
 
@@ -36,15 +42,6 @@
 torch._inductor.config.fx_graph_cache = True
 
 
-from torch.nn.attention import SDPBackend, sdpa_kernel
-
-from fish_speech.models.text2semantic.llama import (
-    BaseTransformer,
-    DualARTransformer,
-    NaiveTransformer,
-)
-
-
 def multinomial_sample_one_no_sync(
     probs_sort,
 ):  # Does multinomial sampling without a cuda synchronization
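The hunks above only consolidate imports at the top of the module; `multinomial_sample_one_no_sync` appears as unchanged context, and its body is not part of this diff. For readers wondering about its comment: the standard way to draw a categorical sample without the host-device synchronization that `torch.multinomial` triggers is the exponential-race (Gumbel-max style) trick. The sketch below illustrates that technique; it is an assumption about what the body does, not code copied from the repository.

import torch

def multinomial_sample_one_no_sync_sketch(probs):
    # Hypothetical illustration of the no-sync trick, not the repo's code.
    # Dividing each probability p_i by an independent Exponential(1) draw
    # and taking the argmax selects index i with probability proportional
    # to p_i, while avoiding the CUDA synchronization that
    # torch.multinomial incurs.
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)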
|
@@ -372,8 +369,12 @@ def decode_n_tokens(
         window = previous_tokens[:, i - win_size : i]
 
         with (
-            torch.backends.cuda.sdp_kernel(
-                enable_flash=False, enable_mem_efficient=False, enable_math=True
+            sdpa_kernel(
+                [
+                    SDPBackend.FLASH_ATTENTION,
+                    SDPBackend.EFFICIENT_ATTENTION,
+                    SDPBackend.MATH,
+                ]
             )
             if torch.cuda.is_available()
             else nullcontext()
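This hunk swaps the deprecated `torch.backends.cuda.sdp_kernel` context manager for `torch.nn.attention.sdpa_kernel`, and it is not a pure rename: the old call permitted only the math backend (`enable_math=True`, with flash and memory-efficient attention disabled), whereas the new backend list lets PyTorch dispatch to flash, memory-efficient, or math attention, whichever is eligible and fastest. Below is a minimal, self-contained sketch of the replacement API, assuming a PyTorch version (2.3+) that ships `torch.nn.attention.sdpa_kernel`; the tensor shapes are arbitrary and not taken from this repository.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy inputs shaped (batch, heads, seq_len, head_dim); illustrative only.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

# Inside the context, scaled_dot_product_attention may only dispatch to
# the listed backends; PyTorch picks the fastest eligible one at runtime.
with sdpa_kernel(
    [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
):
    out = F.scaled_dot_product_attention(q, k, v)

As in the diff itself, the context manager is applied only when `torch.cuda.is_available()`, with `nullcontext()` as the CPU fallback.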
|
|