4 files changed, +24 -6 lines changed

@@ -36,8 +36,13 @@ class FullAttention(nn.Module):
         scale (float): Scaling factor for attention scores.
         attention_dropout (float): Dropout rate for attention scores.
         output_attention (bool): Whether to output attention weights.
-        use_efficient_attention (bool): Whether to use torch's native efficient
-            scaled dot product attention implementation.
+        use_efficient_attention (bool): Whether to use PyTorch's native,
+            optimized Scaled Dot Product Attention implementation, which can
+            reduce computation time and memory consumption for longer sequences.
+            PyTorch automatically selects the optimal backend (FlashAttention-2,
+            Memory-Efficient Attention, or its own C++ implementation) based
+            on the input properties, hardware capabilities, and build
+            configuration.
     """

     def __init__(
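For context, here is a minimal sketch of what a flag like `use_efficient_attention` typically switches between inside such an attention module: an explicit softmax(QK^T / sqrt(d)) V computation versus PyTorch's `torch.nn.functional.scaled_dot_product_attention`, which picks its backend automatically. This is not the repository's actual code; the function name `attention_forward` and the tensor shapes are illustrative.

```python
import math

import torch
import torch.nn.functional as F


def attention_forward(q, k, v, use_efficient_attention=False, dropout_p=0.1):
    """Sketch only: q, k, v are (batch, heads, seq_len, head_dim) tensors."""
    if use_efficient_attention:
        # PyTorch selects the backend (FlashAttention-2, memory-efficient
        # attention, or the C++ math fallback) based on dtypes, shapes,
        # hardware, and how the binary was built.
        return F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
    # Explicit reference implementation of scaled dot-product attention.
    scale = 1.0 / math.sqrt(q.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    weights = F.dropout(F.softmax(scores, dim=-1), p=dropout_p)
    return torch.matmul(weights, v)
```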
@@ -121,7 +121,11 @@ def __init__(
             ('relu' or 'gelu').
         use_efficient_attention (bool, optional): If set to True, will use
             PyTorch's native, optimized Scaled Dot Product Attention
-            Implementation.
+            implementation, which can reduce computation time and memory
+            consumption for longer sequences. PyTorch automatically selects the
+            optimal backend (FlashAttention-2, Memory-Efficient Attention, or
+            its own C++ implementation) based on the input properties,
+            hardware capabilities, and build configuration.
         patch_length (int, optional): Length of each non-overlapping patch for
             endogenous variable tokenization.
         use_norm (bool, optional): Whether to apply normalization to input data.
@@ -59,7 +59,11 @@ class TimeXer(TslibBaseModel):
         Activation function to use in the feed-forward network. Common choices are 'relu', 'gelu', etc.
     use_efficient_attention: bool, default=False
         If set to True, will use PyTorch's native, optimized Scaled Dot Product
-        Attention Implementation.
+        Attention implementation, which can reduce computation time and memory
+        consumption for longer sequences. PyTorch automatically selects the
+        optimal backend (FlashAttention-2, Memory-Efficient Attention, or its
+        own C++ implementation) based on the input properties, hardware
+        capabilities, and build configuration.
     endogenous_vars: Optional[list[str]], default=None
         List of endogenous variable names to be used in the model. If None, all historical values
         for the target variable are used.
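The automatic backend selection mentioned in these docstrings can also be constrained explicitly for debugging or benchmarking. The following self-contained sketch is independent of this repository's code and assumes PyTorch 2.3 or newer for `torch.nn.attention.sdpa_kernel`.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # PyTorch >= 2.3 assumed

q = k = v = torch.randn(2, 4, 256, 32)  # (batch, heads, seq_len, head_dim)

# Default: PyTorch chooses among FlashAttention-2, memory-efficient attention,
# and the C++ math fallback based on inputs, hardware, and build configuration.
out_auto = F.scaled_dot_product_attention(q, k, v)

# Restrict dispatch to the always-available math backend; on supported GPUs one
# could instead pass SDPBackend.FLASH_ATTENTION or SDPBackend.EFFICIENT_ATTENTION
# to verify that a faster backend is actually usable for these inputs.
with sdpa_kernel(SDPBackend.MATH):
    out_math = F.scaled_dot_product_attention(q, k, v)
```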
@@ -37,8 +37,13 @@ class FullAttention(nn.Module):
         scale (float): Scaling factor for attention scores.
         attention_dropout (float): Dropout rate for attention scores.
         output_attention (bool): Whether to output attention weights.
-        use_efficient_attention (bool): Whether to use torch's native efficient
-            scaled dot product attention implementation.
+        use_efficient_attention (bool): Whether to use PyTorch's native,
+            optimized Scaled Dot Product Attention implementation, which can
+            reduce computation time and memory consumption for longer sequences.
+            PyTorch automatically selects the optimal backend (FlashAttention-2,
+            Memory-Efficient Attention, or its own C++ implementation) based
+            on the input properties, hardware capabilities, and build
+            configuration.
     """

     def __init__(