Commit b3c39d3
Authored Mar 17, 2025

[fix]:fix future warning: sdp_kernel

1 parent b167642 · commit b3c39d3

File tree

1 file changed (+13 −12 lines)

fish_speech/models/text2semantic/inference.py (+13 −12)
@@ -13,6 +13,7 @@
 import torch._dynamo.config
 import torch._inductor.config
 from loguru import logger
+from torch.nn.attention import SDPBackend, sdpa_kernel
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
@@ -23,7 +24,12 @@
     TextPart,
     VQPart,
 )
-from fish_speech.models.text2semantic.llama import BaseModelArgs
+from fish_speech.models.text2semantic.llama import (
+    BaseModelArgs,
+    BaseTransformer,
+    DualARTransformer,
+    NaiveTransformer,
+)
 from fish_speech.text import clean_text, split_text
 from fish_speech.tokenizer import IM_END_TOKEN, FishTokenizer
 
@@ -36,15 +42,6 @@
 torch._inductor.config.fx_graph_cache = True
 
 
-from torch.nn.attention import SDPBackend, sdpa_kernel
-
-from fish_speech.models.text2semantic.llama import (
-    BaseTransformer,
-    DualARTransformer,
-    NaiveTransformer,
-)
-
-
 def multinomial_sample_one_no_sync(
     probs_sort,
 ):  # Does multinomial sampling without a cuda synchronization
 
@@ -372,8 +369,12 @@ def decode_n_tokens(
             window = previous_tokens[:, i - win_size : i]
 
         with (
-            torch.backends.cuda.sdp_kernel(
-                enable_flash=False, enable_mem_efficient=False, enable_math=True
+            sdpa_kernel(
+                [
+                    SDPBackend.FLASH_ATTENTION,
+                    SDPBackend.EFFICIENT_ATTENTION,
+                    SDPBackend.MATH,
+                ]
             )
             if torch.cuda.is_available()
             else nullcontext()
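
For context: torch.backends.cuda.sdp_kernel was deprecated in favor of torch.nn.attention.sdpa_kernel, which takes the set of allowed scaled-dot-product-attention backends explicitly instead of per-backend enable_* flags. Below is a minimal sketch of the migration pattern, assuming a PyTorch version recent enough that sdpa_kernel accepts a list of SDPBackend values; the tensors and shapes are illustrative only, not taken from the repository:

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Illustrative tensors: (batch, heads, seq_len, head_dim)
    q = k = v = torch.randn(1, 4, 16, 8)

    # Deprecated form (triggers the FutureWarning this commit fixes):
    # with torch.backends.cuda.sdp_kernel(
    #     enable_flash=False, enable_mem_efficient=False, enable_math=True
    # ):
    #     out = F.scaled_dot_product_attention(q, k, v)

    # Replacement: list the backends SDPA may dispatch to.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION,
                      SDPBackend.EFFICIENT_ATTENTION,
                      SDPBackend.MATH]):
        out = F.scaled_dot_product_attention(q, k, v)

Note that the two forms in this diff are not equivalent: the old call enabled only the math backend, while the replacement also allows flash and memory-efficient attention, so the commit changes which kernels may run, not just the API call.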
