2 changes: 1 addition & 1 deletion torch/nn/attention/bias.py
@@ -153,7 +153,7 @@ def _lower_right(self, device: torch.device) -> torch.Tensor:
diagonal=diagonal_offset,
)

-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
"""
Materializes the causal bias into a tensor form.
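The change in this file, and throughout the PR, is mechanical: each pyrefly suppression moves the error code from a second trailing comment into brackets. A minimal Python sketch of the two styles, using the bad-return code from the hunk above (the deliberately wrong return type is only there to trigger the diagnostic):

# Old style: the error code rides in a second trailing comment.
def old_style() -> int:
    # pyrefly: ignore # bad-return
    return "not an int"

# New style adopted by this PR: the error code is bracketed.
def new_style() -> int:
    # pyrefly: ignore [bad-return]
    return "not an int"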
42 changes: 21 additions & 21 deletions torch/nn/attention/flex_attention.py
@@ -84,7 +84,7 @@ def _warn_once(
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]


-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
class FlexKernelOptions(TypedDict, total=False):
"""Options for controlling the behavior of FlexAttention kernels.

@@ -128,97 +128,97 @@ class FlexKernelOptions(TypedDict, total=False):
"""

# Performance tuning options
-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
num_warps: NotRequired[int]
"""Number of warps to use in the CUDA kernel. Higher values may improve performance
but increase register pressure. Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
num_stages: NotRequired[int]
"""Number of pipeline stages in the CUDA kernel. Higher values may improve performance
but increase shared memory usage. Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_M: NotRequired[int]
"""Thread block size for the sequence length dimension of Q in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_N: NotRequired[int]
"""Thread block size for the sequence length dimension of K/V in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""

# Backward-specific block sizes (when prefixed with 'bwd_')
-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_M1: NotRequired[int]
"""Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'.
Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_N1: NotRequired[int]
"""Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'.
Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_M2: NotRequired[int]
"""Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'.
Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_N2: NotRequired[int]
"""Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'.
Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
PRESCALE_QK: NotRequired[bool]
"""Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but
may have more numerical error. Default: False."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
ROWS_GUARANTEED_SAFE: NotRequired[bool]
"""If True, guarantees that at least one value in each row is not masked out.
Allows skipping safety checks for better performance. Only set this if you are certain
your mask guarantees this property. For example, causal attention is guaranteed safe
because each query has at least 1 key-value to attend to. Default: False."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]
"""If True, guarantees that all blocks in the mask are contiguous.
Allows optimizing block traversal. For example, causal masks would satisfy this,
but prefix_lm + sliding window would not. Default: False."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
WRITE_DQ: NotRequired[bool]
"""Controls whether gradient scatters are done in the DQ iteration loop of the backward pass.
Setting this to False will force this to happen in the DK loop which depending on your
specific score_mod and mask_mod might be faster. Default: True."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
FORCE_USE_FLEX_ATTENTION: NotRequired[bool]
"""If True, forces the use of the flex attention kernel instead of potentially using
the more optimized flex-decoding kernel for short sequences. This can be a helpful
option for debugging. Default: False."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
USE_TMA: NotRequired[bool]
"""Whether to use Tensor Memory Accelerator (TMA) on supported hardware.
This is experimental and may not work on all hardware, currently specific
to NVIDIA GPUs Hopper+. Default: False."""

# ROCm-specific options
-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
kpack: NotRequired[int]
"""ROCm-specific kernel packing parameter."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
matrix_instr_nonkdim: NotRequired[int]
"""ROCm-specific matrix instruction non-K dimension."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
waves_per_eu: NotRequired[int]
"""ROCm-specific waves per execution unit."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
force_flash: NotRequired[bool]
""" If True, forces use of the cute-dsl flash attention kernel.

@@ -644,7 +644,7 @@ def as_tuple(self, flatten: bool = True):
block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
seq_lengths = (self.seq_lengths,) # type: ignore[assignment]

-# pyrefly: ignore # not-iterable
+# pyrefly: ignore [not-iterable]
return (
*seq_lengths,
self.kv_num_blocks,
@@ -817,7 +817,7 @@ def to_dense(self) -> Tensor:
partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
if self.full_kv_num_blocks is not None:
assert self.full_kv_indices is not None
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return partial_dense | _ordered_to_dense(
self.full_kv_num_blocks, self.full_kv_indices
)
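The options documented above are passed as a plain dict at call time. A minimal usage sketch, with illustrative shapes, an assumed CUDA device, and arbitrarily chosen option values (only flex_attention and its kernel_options parameter come from the file being diffed):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Pin forward-pass block sizes instead of autotuning; backward-pass sizes
# would use the 'bwd_' prefix, e.g. {"bwd_BLOCK_M1": 32}.
kernel_options = {"BLOCK_M": 64, "BLOCK_N": 64, "PRESCALE_QK": False}

# These options tune the Triton kernels generated under torch.compile.
flex_attention_compiled = torch.compile(flex_attention)
out = flex_attention_compiled(q, k, v, kernel_options=kernel_options)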
2 changes: 1 addition & 1 deletion torch/nn/cpp.py
@@ -78,7 +78,7 @@ def _apply(self, fn, recurse=True):

# nn.Module defines training as a boolean
@property # type: ignore[override]
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def training(self):
return self.cpp_module.training

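For context on the bad-override suppression: nn.Module declares training as a plain bool attribute, while this wrapper re-exposes it as a property that proxies the wrapped C++ module, which is what pyrefly flags. A self-contained sketch of the same pattern, with a hypothetical _CppModule standing in for the real C++ binding and a setter added so nn.Module.__init__ can still assign the flag:

import torch.nn as nn

class _CppModule:
    """Hypothetical stand-in for a bound C++ module with its own training flag."""
    def __init__(self) -> None:
        self.training = True

class CppWrapper(nn.Module):
    def __init__(self, cpp_module: _CppModule) -> None:
        # Assign before super().__init__(), which sets `self.training` and
        # would otherwise reach the property setter before cpp_module exists.
        self.cpp_module = cpp_module
        super().__init__()

    # nn.Module defines `training` as a boolean attribute; re-declaring it as
    # a property is what pyrefly reports as bad-override.
    @property  # type: ignore[override]
    # pyrefly: ignore [bad-override]
    def training(self):
        return self.cpp_module.training

    @training.setter
    def training(self, value) -> None:
        self.cpp_module.training = value

wrapper = CppWrapper(_CppModule())
print(wrapper.training)  # True, read through the wrapped C++ object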