1 file changed, +13 -12 lines changed
 import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from tqdm import tqdm
 from transformers import AutoTokenizer
...
     TextPart,
     VQPart,
 )
-from fish_speech.models.text2semantic.llama import BaseModelArgs
+from fish_speech.models.text2semantic.llama import (
+    BaseModelArgs,
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
 from fish_speech.text import clean_text, split_text
 from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
...
 torch._inductor.config.fx_graph_cache = True


-from torch.nn.attention import SDPBackend, sdpa_kernel
-
-from fish_speech.models.text2semantic.llama import (
-    BaseTransformer,
-    DualARTransformer,
-    NaiveTransformer,
-)
-
-
 def multinomial_sample_one_no_sync(
     probs_sort,
 ):  # Does multinomial sampling without a cuda synchronization
@@ -372,8 +369,12 @@ def decode_n_tokens(
             window = previous_tokens[:, i - win_size : i]

         with (
-            torch.backends.cuda.sdp_kernel(
-                enable_flash=False, enable_mem_efficient=False, enable_math=True
+            sdpa_kernel(
+                [
+                    SDPBackend.FLASH_ATTENTION,
+                    SDPBackend.EFFICIENT_ATTENTION,
+                    SDPBackend.MATH,
+                ]
             )
             if torch.cuda.is_available()
             else nullcontext()
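
The second hunk swaps the deprecated torch.backends.cuda.sdp_kernel context manager, which was configured with boolean enable_* flags, for torch.nn.attention.sdpa_kernel, which instead takes a list of allowed SDPBackend values. The snippet below is a minimal standalone sketch of the new API, assuming PyTorch 2.3 or later; the tensor shapes and the dummy attention call are illustrative placeholders, not code from this repository.

# Minimal sketch of the sdpa_kernel API used in this change (assumes PyTorch >= 2.3).
# Listing several backends lets PyTorch pick the fastest supported one inside the
# context, instead of forcing the math backend as the old
# sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True) call did.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Placeholder tensors: (batch, heads, seq_len, head_dim)
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])

In the diff itself the context manager stays wrapped in the existing "if torch.cuda.is_available() ... else nullcontext()" expression, so CPU-only runs bypass the backend restriction entirely.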