From 95a61d00d18870aac67b7e9f58c66ff0fd33a989 Mon Sep 17 00:00:00 2001 From: Maggie Moss Date: Sat, 25 Oct 2025 18:20:54 -0700 Subject: [PATCH] Fix error suppressions in utils and nn --- torch/nn/attention/bias.py | 2 +- torch/nn/attention/flex_attention.py | 42 +++--- torch/nn/cpp.py | 2 +- torch/nn/functional.py | 130 +++++++++--------- torch/nn/init.py | 2 +- torch/nn/modules/_functions.py | 8 +- torch/nn/modules/container.py | 20 +-- torch/nn/modules/conv.py | 36 ++--- torch/nn/modules/lazy.py | 4 +- torch/nn/modules/linear.py | 12 +- torch/nn/modules/padding.py | 24 ++-- torch/nn/modules/transformer.py | 14 +- torch/nn/modules/utils.py | 2 +- torch/nn/parallel/comm.py | 4 +- torch/nn/parallel/data_parallel.py | 8 +- torch/nn/parallel/distributed.py | 10 +- torch/nn/parallel/scatter_gather.py | 14 +- torch/nn/parameter.py | 20 +-- .../conv_expanded_weights.py | 4 +- .../nn/utils/_expanded_weights/conv_utils.py | 2 +- .../embedding_expanded_weights.py | 4 +- .../expanded_weights_impl.py | 2 +- .../expanded_weights_utils.py | 2 +- .../group_norm_expanded_weights.py | 6 +- .../instance_norm_expanded_weights.py | 4 +- .../layer_norm_expanded_weights.py | 4 +- .../linear_expanded_weights.py | 4 +- torch/nn/utils/_named_member_accessor.py | 2 +- torch/nn/utils/clip_grad.py | 6 +- torch/nn/utils/memory_format.py | 4 +- torch/nn/utils/parametrizations.py | 6 +- torch/nn/utils/parametrize.py | 14 +- torch/nn/utils/prune.py | 8 +- torch/nn/utils/spectral_norm.py | 2 +- torch/utils/_contextlib.py | 2 +- torch/utils/_cxx_pytree.py | 2 +- torch/utils/_debug_mode.py | 2 +- torch/utils/_device.py | 2 +- torch/utils/_functools.py | 4 +- torch/utils/_pytree.py | 18 +-- .../_strobelight/cli_function_profiler.py | 2 +- torch/utils/_sympy/functions.py | 42 +++--- torch/utils/_sympy/numbers.py | 4 +- torch/utils/_sympy/printers.py | 2 +- torch/utils/_sympy/reference.py | 2 +- torch/utils/_sympy/value_ranges.py | 18 +-- .../benchmark/examples/sparse/compare.py | 2 +- torch/utils/benchmark/utils/compile.py | 4 +- torch/utils/benchmark/utils/cpp_jit.py | 2 +- torch/utils/benchmark/utils/sparse_fuzzer.py | 8 +- torch/utils/benchmark/utils/timer.py | 2 +- .../utils/valgrind_wrapper/timer_interface.py | 8 +- torch/utils/checkpoint.py | 6 +- torch/utils/cpp_extension.py | 6 +- torch/utils/data/_utils/collate.py | 2 +- torch/utils/data/_utils/pin_memory.py | 2 +- torch/utils/data/datapipes/_typing.py | 4 +- .../data/datapipes/dataframe/dataframes.py | 28 ++-- .../data/datapipes/dataframe/datapipes.py | 2 +- torch/utils/data/datapipes/datapipe.py | 8 +- torch/utils/data/datapipes/iter/callable.py | 2 +- .../data/datapipes/iter/combinatorics.py | 2 +- torch/utils/data/datapipes/iter/combining.py | 16 +-- torch/utils/data/datapipes/iter/grouping.py | 8 +- torch/utils/data/datapipes/map/callable.py | 2 +- .../utils/data/datapipes/map/combinatorics.py | 4 +- torch/utils/data/datapipes/map/combining.py | 8 +- torch/utils/data/datapipes/utils/common.py | 2 +- torch/utils/data/datapipes/utils/snapshot.py | 2 +- torch/utils/data/distributed.py | 2 +- torch/utils/data/graph.py | 2 +- torch/utils/file_baton.py | 2 +- torch/utils/flop_counter.py | 2 +- torch/utils/hooks.py | 2 +- torch/utils/model_dump/__init__.py | 8 +- torch/utils/tensorboard/_convert_np.py | 4 +- torch/utils/tensorboard/_pytorch_graph.py | 4 +- torch/utils/tensorboard/_utils.py | 6 +- torch/utils/tensorboard/summary.py | 10 +- torch/utils/tensorboard/writer.py | 14 +- torch/utils/viz/_cycles.py | 2 +- 81 files changed, 358 insertions(+), 358 deletions(-) diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index fceec1272c16..2a1a97fc756d 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -153,7 +153,7 @@ def _lower_right(self, device: torch.device) -> torch.Tensor: diagonal=diagonal_offset, ) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor: """ Materializes the causal bias into a tensor form. diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index b68b010ef43d..01f5fe84356b 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -84,7 +84,7 @@ def _warn_once( _mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] -# pyrefly: ignore # invalid-inheritance +# pyrefly: ignore [invalid-inheritance] class FlexKernelOptions(TypedDict, total=False): """Options for controlling the behavior of FlexAttention kernels. @@ -128,97 +128,97 @@ class FlexKernelOptions(TypedDict, total=False): """ # Performance tuning options - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] num_warps: NotRequired[int] """Number of warps to use in the CUDA kernel. Higher values may improve performance but increase register pressure. Default is determined by autotuning.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] num_stages: NotRequired[int] """Number of pipeline stages in the CUDA kernel. Higher values may improve performance but increase shared memory usage. Default is determined by autotuning.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] BLOCK_M: NotRequired[int] """Thread block size for the sequence length dimension of Q in forward pass. Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] BLOCK_N: NotRequired[int] """Thread block size for the sequence length dimension of K/V in forward pass. Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning.""" # Backward-specific block sizes (when prefixed with 'bwd_') - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] BLOCK_M1: NotRequired[int] """Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'. Default is determined by autotuning.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] BLOCK_N1: NotRequired[int] """Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'. Default is determined by autotuning.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] BLOCK_M2: NotRequired[int] """Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'. Default is determined by autotuning.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] BLOCK_N2: NotRequired[int] """Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'. Default is determined by autotuning.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] PRESCALE_QK: NotRequired[bool] """Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but may have more numerical error. Default: False.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] ROWS_GUARANTEED_SAFE: NotRequired[bool] """If True, guarantees that at least one value in each row is not masked out. Allows skipping safety checks for better performance. Only set this if you are certain your mask guarantees this property. For example, causal attention is guaranteed safe because each query has at least 1 key-value to attend to. Default: False.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] BLOCKS_ARE_CONTIGUOUS: NotRequired[bool] """If True, guarantees that all blocks in the mask are contiguous. Allows optimizing block traversal. For example, causal masks would satisfy this, but prefix_lm + sliding window would not. Default: False.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] WRITE_DQ: NotRequired[bool] """Controls whether gradient scatters are done in the DQ iteration loop of the backward pass. Setting this to False will force this to happen in the DK loop which depending on your specific score_mod and mask_mod might be faster. Default: True.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] FORCE_USE_FLEX_ATTENTION: NotRequired[bool] """If True, forces the use of the flex attention kernel instead of potentially using the more optimized flex-decoding kernel for short sequences. This can be a helpful option for debugging. Default: False.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] USE_TMA: NotRequired[bool] """Whether to use Tensor Memory Accelerator (TMA) on supported hardware. This is experimental and may not work on all hardware, currently specific to NVIDIA GPUs Hopper+. Default: False.""" # ROCm-specific options - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] kpack: NotRequired[int] """ROCm-specific kernel packing parameter.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] matrix_instr_nonkdim: NotRequired[int] """ROCm-specific matrix instruction non-K dimension.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] waves_per_eu: NotRequired[int] """ROCm-specific waves per execution unit.""" - # pyrefly: ignore # invalid-annotation + # pyrefly: ignore [invalid-annotation] force_flash: NotRequired[bool] """ If True, forces use of the cute-dsl flash attention kernel. @@ -644,7 +644,7 @@ def as_tuple(self, flatten: bool = True): block_size = (self.BLOCK_SIZE,) # type: ignore[assignment] seq_lengths = (self.seq_lengths,) # type: ignore[assignment] - # pyrefly: ignore # not-iterable + # pyrefly: ignore [not-iterable] return ( *seq_lengths, self.kv_num_blocks, @@ -817,7 +817,7 @@ def to_dense(self) -> Tensor: partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices) if self.full_kv_num_blocks is not None: assert self.full_kv_indices is not None - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return partial_dense | _ordered_to_dense( self.full_kv_num_blocks, self.full_kv_indices ) diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py index 5d01f7f16a4a..e447284ad82b 100644 --- a/torch/nn/cpp.py +++ b/torch/nn/cpp.py @@ -78,7 +78,7 @@ def _apply(self, fn, recurse=True): # nn.Module defines training as a boolean @property # type: ignore[override] - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def training(self): return self.cpp_module.training diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 9f1438d3780c..c562bc63dc47 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1275,7 +1275,7 @@ def adaptive_max_pool2d_with_indices( output_size, return_indices=return_indices, ) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool2d(input, output_size) @@ -1333,7 +1333,7 @@ def adaptive_max_pool3d_with_indices( output_size, return_indices=return_indices, ) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_max_pool3d(input, output_size) @@ -1392,7 +1392,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> T """ if has_torch_function_unary(input): return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool2d(input, _output_size) @@ -1408,7 +1408,7 @@ def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> T """ if has_torch_function_unary(input): return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _output_size = _list_with_default(output_size, input.size()) return torch._C._nn.adaptive_avg_pool3d(input, _output_size) @@ -2444,7 +2444,7 @@ def _no_grad_embedding_renorm_( input: Tensor, max_norm: float, norm_type: float, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> tuple[Tensor, Tensor]: torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type) @@ -2698,7 +2698,7 @@ def embedding_bag( if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested: include_last_offset = True - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] offsets = input.offsets() input = input.values().reshape(-1) if per_sample_weights is not None: @@ -2833,7 +2833,7 @@ def batch_norm( eps=eps, ) if training: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _verify_batch_size(input.size()) return torch.batch_norm( @@ -2889,7 +2889,7 @@ def instance_norm( eps=eps, ) if use_input_stats: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _verify_spatial_size(input.size()) return torch.instance_norm( input, @@ -3015,13 +3015,13 @@ def local_response_norm( div = input.mul(input) if dim == 3: div = div.unsqueeze(1) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] div = pad(div, (0, 0, size // 2, (size - 1) // 2)) div = avg_pool2d(div, (size, 1), stride=1).squeeze(1) else: sizes = input.size() div = div.view(sizes[0], 1, sizes[1], sizes[2], -1) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2)) div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1) div = div.view(sizes) @@ -3173,7 +3173,7 @@ def nll_loss( input, target, weight, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _Reduction.get_enum(reduction), ignore_index, ) @@ -3320,7 +3320,7 @@ def gaussian_nll_loss( var.clamp_(min=eps) # Calculate the loss - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var) if full: loss += 0.5 * math.log(2 * math.pi) @@ -3496,7 +3496,7 @@ def cross_entropy( input, target, weight, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _Reduction.get_enum(reduction), ignore_index, label_smoothing, @@ -3561,7 +3561,7 @@ def binary_cross_entropy( new_size = _infer_size(target.size(), weight.size()) weight = weight.expand(new_size) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum) @@ -3692,14 +3692,14 @@ def smooth_l1_loss( return torch._C._nn.l1_loss( expanded_input, expanded_target, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _Reduction.get_enum(reduction), ) else: return torch._C._nn.smooth_l1_loss( expanded_input, expanded_target, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _Reduction.get_enum(reduction), beta, ) @@ -3761,7 +3761,7 @@ def huber_loss( return torch._C._nn.huber_loss( expanded_input, expanded_target, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _Reduction.get_enum(reduction), delta, ) @@ -3773,7 +3773,7 @@ def huber_loss( unweighted_loss = torch._C._nn.huber_loss( expanded_input, expanded_target, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _Reduction.get_enum("none"), delta, ) @@ -3864,7 +3864,7 @@ def l1_loss( return torch._C._nn.l1_loss( expanded_input, expanded_target, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _Reduction.get_enum(reduction), ) @@ -3942,7 +3942,7 @@ def mse_loss( return torch._C._nn.mse_loss( expanded_input, expanded_target, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _Reduction.get_enum(reduction), ) @@ -4080,7 +4080,7 @@ def multilabel_margin_loss( reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum) @@ -4122,7 +4122,7 @@ def soft_margin_loss( reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._nn.soft_margin_loss(input, target, reduction_enum) @@ -4292,7 +4292,7 @@ def multi_margin_loss( p, margin, weight, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] reduction_enum, ) @@ -4439,7 +4439,7 @@ def upsample( # noqa: F811 scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: # noqa: B950 pass @@ -4451,7 +4451,7 @@ def upsample( # noqa: F811 scale_factor: Optional[float] = None, mode: str = "nearest", align_corners: Optional[bool] = None, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: # noqa: B950 pass @@ -4554,7 +4554,7 @@ def interpolate( # noqa: F811 align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: # noqa: B950 pass @@ -4568,7 +4568,7 @@ def interpolate( # noqa: F811 align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: # noqa: B950 pass @@ -4582,7 +4582,7 @@ def interpolate( # noqa: F811 align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: # noqa: B950 pass @@ -4596,7 +4596,7 @@ def interpolate( # noqa: F811 align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: pass @@ -4771,7 +4771,7 @@ def interpolate( # noqa: F811 ( torch.floor( ( - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] input.size(i + 2).float() * torch.tensor(scale_factors[i], dtype=torch.float32) ).float() @@ -4796,28 +4796,28 @@ def interpolate( # noqa: F811 ) if input.dim() == 3 and mode == "nearest": - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors) if input.dim() == 4 and mode == "nearest": - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors) if input.dim() == 5 and mode == "nearest": - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors) if input.dim() == 3 and mode == "nearest-exact": - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors) if input.dim() == 4 and mode == "nearest-exact": - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors) if input.dim() == 5 and mode == "nearest-exact": - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors) if input.dim() == 3 and mode == "area": assert output_size is not None - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return adaptive_avg_pool1d(input, output_size) if input.dim() == 4 and mode == "area": assert output_size is not None @@ -4830,7 +4830,7 @@ def interpolate( # noqa: F811 assert align_corners is not None return torch._C._nn.upsample_linear1d( input, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_size, align_corners, scale_factors, @@ -4840,7 +4840,7 @@ def interpolate( # noqa: F811 if antialias: return torch._C._nn._upsample_bilinear2d_aa( input, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_size, align_corners, scale_factors, @@ -4857,7 +4857,7 @@ def interpolate( # noqa: F811 )._upsample_linear_vec(input, output_size, align_corners, scale_factors) return torch._C._nn.upsample_bilinear2d( input, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_size, align_corners, scale_factors, @@ -4876,7 +4876,7 @@ def interpolate( # noqa: F811 )._upsample_linear_vec(input, output_size, align_corners, scale_factors) return torch._C._nn.upsample_trilinear3d( input, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_size, align_corners, scale_factors, @@ -4886,14 +4886,14 @@ def interpolate( # noqa: F811 if antialias: return torch._C._nn._upsample_bicubic2d_aa( input, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_size, align_corners, scale_factors, ) return torch._C._nn.upsample_bicubic2d( input, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_size, align_corners, scale_factors, @@ -4928,7 +4928,7 @@ def upsample_nearest( # noqa: F811 input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: pass @@ -4938,7 +4938,7 @@ def upsample_nearest( # noqa: F811 input: Tensor, size: Optional[list[int]] = None, scale_factor: Optional[float] = None, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: pass @@ -4980,7 +4980,7 @@ def upsample_bilinear( # noqa: F811 input: Tensor, size: Optional[int] = None, scale_factor: Optional[float] = None, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: pass @@ -4990,7 +4990,7 @@ def upsample_bilinear( # noqa: F811 input: Tensor, size: Optional[list[int]] = None, scale_factor: Optional[float] = None, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: pass @@ -5000,7 +5000,7 @@ def upsample_bilinear( # noqa: F811 input: Tensor, size: Optional[int] = None, scale_factor: Optional[list[float]] = None, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: pass @@ -5010,7 +5010,7 @@ def upsample_bilinear( # noqa: F811 input: Tensor, size: Optional[list[int]] = None, scale_factor: Optional[list[float]] = None, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Tensor: pass @@ -5817,7 +5817,7 @@ def _in_projection_packed( .squeeze(-2) .contiguous() ) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return proj[0], proj[1], proj[2] else: # encoder-decoder attention @@ -5836,7 +5836,7 @@ def _in_projection_packed( .squeeze(-2) .contiguous() ) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return (q_proj, kv_proj[0], kv_proj[1]) else: w_q, w_k, w_v = w.chunk(3) @@ -5844,7 +5844,7 @@ def _in_projection_packed( b_q = b_k = b_v = None else: b_q, b_k, b_v = b.chunk(3) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) @@ -6475,10 +6475,10 @@ def multi_head_attention_forward( k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) if attn_mask is not None: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] key_padding_mask = pad(key_padding_mask, (0, 1)) else: assert bias_k is None @@ -6487,10 +6487,10 @@ def multi_head_attention_forward( # # reshape q, k, v for multihead attention and make them batch first # - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) if static_k is None: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed @@ -6502,7 +6502,7 @@ def multi_head_attention_forward( ) k = static_k if static_v is None: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed @@ -6518,20 +6518,20 @@ def multi_head_attention_forward( if add_zero_attn: zero_attn_shape = (bsz * num_heads, 1, head_dim) k = torch.cat( - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1, ) v = torch.cat( - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1, ) if attn_mask is not None: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] attn_mask = pad(attn_mask, (0, 1)) if key_padding_mask is not None: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] key_padding_mask = pad(key_padding_mask, (0, 1)) # update source sequence length after adjustments @@ -6581,7 +6581,7 @@ def multi_head_attention_forward( attn_output = torch.bmm(attn_output_weights, v) attn_output = ( - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) ) attn_output = linear(attn_output, out_proj_weight, out_proj_bias) @@ -6608,16 +6608,16 @@ def multi_head_attention_forward( attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) q = q.view(bsz, num_heads, tgt_len, head_dim) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] k = k.view(bsz, num_heads, src_len, head_dim) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] v = v.view(bsz, num_heads, src_len, head_dim) attn_output = scaled_dot_product_attention( q, k, v, attn_mask, dropout_p, is_causal ) attn_output = ( - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) ) diff --git a/torch/nn/init.py b/torch/nn/init.py index e033198d4e5e..18358dbabbbf 100644 --- a/torch/nn/init.py +++ b/torch/nn/init.py @@ -500,7 +500,7 @@ def xavier_normal_( def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] mode = mode.lower() valid_modes = ["fan_in", "fan_out"] if mode not in valid_modes: diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 407fcc7e279f..408e6ef42f12 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -6,7 +6,7 @@ class SyncBatchNorm(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward( self, input, @@ -211,7 +211,7 @@ def backward(self, grad_output): class CrossMapLRN2d(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): ctx.size = size ctx.alpha = alpha @@ -267,7 +267,7 @@ def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): return output @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): input, output = ctx.saved_tensors grad_input = grad_output.new() @@ -309,7 +309,7 @@ def backward(ctx, grad_output): class BackwardHookFunction(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, *args): ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad]) return args diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index 711b8d5c1906..a93843f859a7 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -109,7 +109,7 @@ class Sequential(Module): def __init__(self, *args: Module) -> None: ... @overload - # pyrefly: ignore # inconsistent-overload + # pyrefly: ignore [inconsistent-overload] def __init__(self, arg: OrderedDict[str, Module]) -> None: ... def __init__(self, *args): @@ -624,11 +624,11 @@ def update(self, modules: Mapping[str, Module]) -> None: "ModuleDict update sequence element " "#" + str(j) + " should be Iterable; is" + type(m).__name__ ) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] if not len(m) == 2: raise ValueError( "ModuleDict update sequence element " - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] "#" + str(j) + " has length " + str(len(m)) + "; 2 is required" ) # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] @@ -687,7 +687,7 @@ def _get_abs_string_index(self, idx): def __getitem__(self, idx: int) -> Any: ... @overload - # pyrefly: ignore # inconsistent-overload + # pyrefly: ignore [inconsistent-overload] def __getitem__(self: T, idx: slice) -> T: ... def __getitem__(self, idx): @@ -773,11 +773,11 @@ def extra_repr(self) -> str: size_str, device_str, ) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] child_lines.append(" (" + str(k) + "): " + parastr) else: child_lines.append( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] " (" + str(k) + "): Object of type: " + type(p).__name__ ) @@ -985,11 +985,11 @@ def update(self, parameters: Union[Mapping[str, Any], ParameterDict]) -> None: "ParameterDict update sequence element " "#" + str(j) + " should be Iterable; is" + type(p).__name__ ) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] if not len(p) == 2: raise ValueError( "ParameterDict update sequence element " - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] "#" + str(j) + " has length " + str(len(p)) + "; 2 is required" ) # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment @@ -1010,11 +1010,11 @@ def extra_repr(self) -> str: size_str, device_str, ) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] child_lines.append(" (" + str(k) + "): " + parastr) else: child_lines.append( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] " (" + str(k) + "): Object of type: " + type(p).__name__ ) tmpstr = "\n".join(child_lines) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index d8af4862697e..f06e38c2abae 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -1514,7 +1514,7 @@ def __init__( dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] super().__init__( 0, 0, @@ -1529,11 +1529,11 @@ def __init__( padding_mode, **factory_kwargs, ) - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1586,7 +1586,7 @@ def __init__( dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] super().__init__( 0, 0, @@ -1601,11 +1601,11 @@ def __init__( padding_mode, **factory_kwargs, ) - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1659,7 +1659,7 @@ def __init__( dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] super().__init__( 0, 0, @@ -1674,11 +1674,11 @@ def __init__( padding_mode, **factory_kwargs, ) - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1730,7 +1730,7 @@ def __init__( dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] super().__init__( 0, 0, @@ -1746,11 +1746,11 @@ def __init__( padding_mode, **factory_kwargs, ) - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1802,7 +1802,7 @@ def __init__( dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] super().__init__( 0, 0, @@ -1818,11 +1818,11 @@ def __init__( padding_mode, **factory_kwargs, ) - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: @@ -1874,7 +1874,7 @@ def __init__( dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] super().__init__( 0, 0, @@ -1890,11 +1890,11 @@ def __init__( padding_mode, **factory_kwargs, ) - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.weight = UninitializedParameter(**factory_kwargs) self.out_channels = out_channels if bias: - # pyrefly: ignore # bad-override, bad-argument-type + # pyrefly: ignore [bad-override, bad-argument-type] self.bias = UninitializedParameter(**factory_kwargs) def _get_num_spatial_dims(self) -> int: diff --git a/torch/nn/modules/lazy.py b/torch/nn/modules/lazy.py index 1984eb0d0e15..d4c192ee8ce4 100644 --- a/torch/nn/modules/lazy.py +++ b/torch/nn/modules/lazy.py @@ -172,9 +172,9 @@ class LazyModuleMixin: def __init__(self: _LazyProtocol, *args, **kwargs): # Mypy doesn't like this super call in a mixin super().__init__(*args, **kwargs) # type: ignore[misc] - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self._initialize_hook = self.register_forward_pre_hook( self._infer_parameters, with_kwargs=True ) diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index 0d17e3174615..c58bdcefd0e0 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -286,7 +286,7 @@ class LazyLinear(LazyModuleMixin, Linear): """ cls_to_become = Linear # type: ignore[assignment] - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] weight: UninitializedParameter bias: UninitializedParameter # type: ignore[assignment] @@ -296,20 +296,20 @@ def __init__( factory_kwargs = {"device": device, "dtype": dtype} # bias is hardcoded to False to avoid creating tensor # that will soon be overwritten. - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] super().__init__(0, 0, False) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.weight = UninitializedParameter(**factory_kwargs) self.out_features = out_features if bias: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.bias = UninitializedParameter(**factory_kwargs) def reset_parameters(self) -> None: """ Resets parameters based on their initialization used in ``__init__``. """ - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] if not self.has_uninitialized_params() and self.in_features != 0: super().reset_parameters() @@ -317,7 +317,7 @@ def initialize_parameters(self, input) -> None: # type: ignore[override] """ Infers ``in_features`` based on ``input`` and initializes parameters. """ - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] if self.has_uninitialized_params(): with torch.no_grad(): self.in_features = input.shape[-1] diff --git a/torch/nn/modules/padding.py b/torch/nn/modules/padding.py index 2300a498acaa..d5aa1e0d4255 100644 --- a/torch/nn/modules/padding.py +++ b/torch/nn/modules/padding.py @@ -84,7 +84,7 @@ class CircularPad1d(_CircularPadNd): [5., 6., 7., 4., 5., 6., 7., 4.]]]) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int] def __init__(self, padding: _size_2_t) -> None: @@ -145,7 +145,7 @@ class CircularPad2d(_CircularPadNd): [8., 6., 7., 8., 6.]]]]) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> None: @@ -196,7 +196,7 @@ class CircularPad3d(_CircularPadNd): >>> output = m(input) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t) -> None: @@ -268,7 +268,7 @@ class ConstantPad1d(_ConstantPadNd): [ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]]) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int] def __init__(self, padding: _size_2_t, value: float) -> None: @@ -320,7 +320,7 @@ class ConstantPad2d(_ConstantPadNd): """ __constants__ = ["padding", "value"] - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int, int, int] def __init__(self, padding: _size_4_t, value: float) -> None: @@ -361,7 +361,7 @@ class ConstantPad3d(_ConstantPadNd): >>> output = m(input) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t, value: float) -> None: @@ -415,7 +415,7 @@ class ReflectionPad1d(_ReflectionPadNd): [7., 6., 5., 4., 5., 6., 7., 6.]]]) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int] def __init__(self, padding: _size_2_t) -> None: @@ -469,7 +469,7 @@ class ReflectionPad2d(_ReflectionPadNd): [7., 6., 7., 8., 7.]]]]) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> None: @@ -525,7 +525,7 @@ class ReflectionPad3d(_ReflectionPadNd): [1., 0., 1., 0.]]]]]) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t) -> None: @@ -579,7 +579,7 @@ class ReplicationPad1d(_ReplicationPadNd): [4., 4., 4., 4., 5., 6., 7., 7.]]]) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int] def __init__(self, padding: _size_2_t) -> None: @@ -633,7 +633,7 @@ class ReplicationPad2d(_ReplicationPadNd): [6., 6., 7., 8., 8.]]]]) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int, int, int] def __init__(self, padding: _size_4_t) -> None: @@ -676,7 +676,7 @@ class ReplicationPad3d(_ReplicationPadNd): >>> output = m(input) """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] padding: tuple[int, int, int, int, int, int] def __init__(self, padding: _size_6_t) -> None: diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index d5f489c7c56a..2f69d89b19eb 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -138,7 +138,7 @@ def __init__( d_model, eps=layer_norm_eps, bias=bias, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] **factory_kwargs, ) self.encoder = TransformerEncoder( @@ -164,7 +164,7 @@ def __init__( d_model, eps=layer_norm_eps, bias=bias, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] **factory_kwargs, ) self.decoder = TransformerDecoder( @@ -768,9 +768,9 @@ def __init__( self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) self.norm_first = norm_first - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.dropout1 = Dropout(dropout) self.dropout2 = Dropout(dropout) @@ -1062,11 +1062,11 @@ def __init__( self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs) self.norm_first = norm_first - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs) self.dropout1 = Dropout(dropout) self.dropout2 = Dropout(dropout) diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index d8d8783b06b4..cfe621983dc2 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -36,7 +36,7 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]: import torch if isinstance(out_size, (int, torch.SymInt)): - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return out_size if len(defaults) <= len(out_size): raise ValueError(f"Input dimension should be at least {len(out_size) + 1}") diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index 01ed3030fb84..3df1b4b4eadc 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -43,7 +43,7 @@ def broadcast(tensor, devices=None, *, out=None): devices = [_get_device_index(d) for d in devices] return torch._C._broadcast(tensor, devices) else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch._C._broadcast_out(tensor, out) @@ -201,7 +201,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out= """ tensor = _handle_complex(tensor) if out is None: - # pyrefly: ignore # not-iterable + # pyrefly: ignore [not-iterable] devices = [_get_device_index(d) for d in devices] return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams)) else: diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index 22cc3044c221..56ad3b8b2015 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -160,7 +160,7 @@ def __init__( self.module = module self.device_ids = [_get_device_index(x, True) for x in device_ids] self.output_device = _get_device_index(output_device, True) - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self.src_device_obj = torch.device(device_type, self.device_ids[0]) if device_type == "cuda": @@ -174,7 +174,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: if not self.device_ids: return self.module(*inputs, **kwargs) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] for t in chain(self.module.parameters(), self.module.buffers()): if t.device != self.src_device_obj: raise RuntimeError( @@ -261,10 +261,10 @@ def data_parallel( device_ids = [_get_device_index(x, True) for x in device_ids] output_device = _get_device_index(output_device, True) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] src_device_obj = torch.device(device_type, device_ids[0]) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] for t in chain(module.parameters(), module.buffers()): if t.device != src_device_obj: raise RuntimeError( diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 73e0deec5e4c..4444f557f4af 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -241,7 +241,7 @@ class _BufferCommHook: # is completed. class _DDPSink(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, ddp_weakref, *inputs): # set_materialize_grads(False) will ensure that None gradients stay as # None and are not filled with zeros. @@ -692,7 +692,7 @@ def __init__( elif process_group is None and device_mesh is None: self.process_group = _get_default_group() elif device_mesh is None: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.process_group = process_group else: if device_mesh.ndim != 1: @@ -780,13 +780,13 @@ def __init__( self.device_ids = None self.output_device = None else: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.device_ids = [_get_device_index(x, True) for x in device_ids] if output_device is None: output_device = device_ids[0] - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.output_device = _get_device_index(output_device, True) self.static_graph = False @@ -936,7 +936,7 @@ def __init__( # enabled. self._accum_grad_hooks: list[RemovableHandle] = [] if self._use_python_reducer: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] torch._inductor.config._fuse_ddp_communication = True torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb # Directly adding this to the trace rule will disturb the users diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index cb167b80b809..a2917bddd032 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -56,16 +56,16 @@ def scatter_map(obj): if isinstance(obj, torch.Tensor): return Scatter.apply(target_gpus, None, dim, obj) if _is_namedtuple(obj): - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return [type(obj)(*args) for args in zip(*map(scatter_map, obj))] if isinstance(obj, tuple) and len(obj) > 0: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return list(zip(*map(scatter_map, obj))) if isinstance(obj, list) and len(obj) > 0: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return [list(i) for i in zip(*map(scatter_map, obj))] if isinstance(obj, dict) and len(obj) > 0: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))] return [obj for _ in target_gpus] @@ -127,12 +127,12 @@ def gather_map(outputs): if isinstance(out, dict): if not all(len(out) == len(d) for d in outputs): raise ValueError("All dicts must have the same number of keys") - # pyrefly: ignore # not-callable + # pyrefly: ignore [not-callable] return type(out)((k, gather_map([d[k] for d in outputs])) for k in out) if _is_namedtuple(out): - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return type(out)._make(map(gather_map, zip(*outputs))) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return type(out)(map(gather_map, zip(*outputs))) # Recursive function calls like this create reference cycles. diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index 39758f3efd15..c03c85f48fc3 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -81,7 +81,7 @@ def __deepcopy__(self, memo): memo[id(self)] = result return result - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def __repr__(self): return "Parameter containing:\n" + super().__repr__() @@ -144,7 +144,7 @@ def materialize(self, shape, device=None, dtype=None): if dtype is None: dtype = self.data.dtype self.data = torch.empty(shape, device=device, dtype=dtype) - # pyrefly: ignore # bad-override, missing-attribute + # pyrefly: ignore [bad-override, missing-attribute] self.__class__ = self.cls_to_become @property @@ -168,7 +168,7 @@ def __repr__(self): def __reduce_ex__(self, proto): # See Note [Don't serialize hooks] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return (self.__class__, (self.requires_grad,)) @classmethod @@ -178,7 +178,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper": if kwargs is None: kwargs = {} - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return super().__torch_function__(func, types, args, kwargs) raise ValueError( f"Attempted to use an uninitialized parameter in {func}. " @@ -220,7 +220,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter): def __new__(cls, requires_grad=True, device=None, dtype=None) -> None: factory_kwargs = {"device": device, "dtype": dtype} data = torch.empty(0, **factory_kwargs) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return torch.Tensor._make_subclass(cls, data, requires_grad) def __deepcopy__(self, memo): @@ -266,9 +266,9 @@ def __new__(cls, data=None, *, persistent=True): data = torch.empty(0) t = data.detach().requires_grad_(data.requires_grad) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] t.persistent = persistent - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] t._is_buffer = True return t @@ -299,9 +299,9 @@ def __new__( factory_kwargs = {"device": device, "dtype": dtype} data = torch.empty(0, **factory_kwargs) ret = torch.Tensor._make_subclass(cls, data, requires_grad) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] ret.persistent = persistent - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] ret._is_buffer = True - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return ret diff --git a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py index dba0cd27132d..da7d8f3dfabb 100644 --- a/torch/nn/utils/_expanded_weights/conv_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/conv_expanded_weights.py @@ -24,7 +24,7 @@ @implements_per_sample_grads(F.conv3d) class ConvPerSampleGrad(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward( ctx: Any, kwarg_names: list[str], @@ -57,7 +57,7 @@ def forward( f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}" ) - # pyrefly: ignore # invalid-type-var + # pyrefly: ignore [invalid-type-var] ctx.conv_fn = conv_fn ctx.batch_size = orig_input.shape[0] diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index 463d7efb6467..e755362a4f20 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -237,7 +237,7 @@ def conv_unfold_weight_grad_sample( # n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input) # rearrange the above tensor and extract diagonals. - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] weight_grad_sample = weight_grad_sample.view( n, groups, diff --git a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py index e1c9dc04d8cf..3b4f0ce46b95 100644 --- a/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/embedding_expanded_weights.py @@ -14,7 +14,7 @@ @implements_per_sample_grads(F.embedding) class EmbeddingPerSampleGrad(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward( ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any ) -> torch.Tensor: @@ -35,7 +35,7 @@ def forward( return output @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward( ctx: Any, grad_output: torch.Tensor ) -> tuple[Optional[torch.Tensor], ...]: diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index dd6c6107fe22..78ceec1c785f 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -131,7 +131,7 @@ def __torch_function__(cls, func, _, args=(), kwargs=None): # in aten, choosing the input or data variants is done by parsing logic. This mimics some of that decomp_opts = expanded_weights_rnn_decomps[func] use_input_variant = isinstance( - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] args[2], list, ) # data variant uses a list here diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index 5f99e468767d..b3f674d3233d 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -8,7 +8,7 @@ def is_batch_first(expanded_args_and_kwargs): batch_first = None - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for arg in expanded_args_and_kwargs: if not isinstance(arg, ExpandedWeight): continue diff --git a/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py index 1439593408c8..9ddf60e0a54e 100644 --- a/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py @@ -18,7 +18,7 @@ @implements_per_sample_grads(F.group_norm) class GroupNormPerSampleGrad(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): expanded_args, expanded_kwargs = standard_kwargs( kwarg_names, expanded_args_and_kwargs @@ -47,7 +47,7 @@ def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): return output @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): input, num_groups = ctx.input, ctx.num_groups weight, bias, eps = ctx.weight, ctx.bias, ctx.eps @@ -97,7 +97,7 @@ def backward(ctx, grad_output): weight, lambda _: torch.einsum( "ni...->ni", - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] F.group_norm(input, num_groups, eps=eps) * grad_output, ), ) diff --git a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py index 7f7fc02dc905..613ce90431b8 100644 --- a/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py @@ -17,7 +17,7 @@ @implements_per_sample_grads(F.instance_norm) class InstanceNormPerSampleGrad(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): instance_norm = partial(torch.instance_norm, cudnn_enabled=True) expanded_args, expanded_kwargs = standard_kwargs( @@ -37,7 +37,7 @@ def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): return output @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var weight, bias, eps = ctx.weight, ctx.bias, ctx.eps diff --git a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py index a53ee8a52dab..ff5b5a61e7f5 100644 --- a/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py @@ -17,7 +17,7 @@ @implements_per_sample_grads(F.layer_norm) class LayerNormPerSampleGrad(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): expanded_args, expanded_kwargs = standard_kwargs( kwarg_names, expanded_args_and_kwargs @@ -43,7 +43,7 @@ def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs): return output @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): def weight_per_sample_grad(weight): return sum_over_all_but_batch_and_last_n( diff --git a/torch/nn/utils/_expanded_weights/linear_expanded_weights.py b/torch/nn/utils/_expanded_weights/linear_expanded_weights.py index e617c79bb1c4..80903782db18 100644 --- a/torch/nn/utils/_expanded_weights/linear_expanded_weights.py +++ b/torch/nn/utils/_expanded_weights/linear_expanded_weights.py @@ -16,7 +16,7 @@ @implements_per_sample_grads(F.linear) class LinearPerSampleGrad(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, _, __, *expanded_args_and_kwargs): if len(expanded_args_and_kwargs[0].shape) <= 1: raise RuntimeError( @@ -36,7 +36,7 @@ def forward(ctx, _, __, *expanded_args_and_kwargs): return output @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): input, weight = ctx.args bias = ctx.kwargs["bias"] diff --git a/torch/nn/utils/_named_member_accessor.py b/torch/nn/utils/_named_member_accessor.py index 111a24ec1863..e815265fec63 100644 --- a/torch/nn/utils/_named_member_accessor.py +++ b/torch/nn/utils/_named_member_accessor.py @@ -77,7 +77,7 @@ def swap_tensor( setattr(module, name, tensor) elif hasattr(module, name): delattr(module, name) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return orig_tensor diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py index 42cf898bfdf0..99c2abe4e56c 100644 --- a/torch/nn/utils/clip_grad.py +++ b/torch/nn/utils/clip_grad.py @@ -41,11 +41,11 @@ def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]: def _no_grad_wrapper(*args, **kwargs): with torch.no_grad(): - # pyrefly: ignore # invalid-param-spec + # pyrefly: ignore [invalid-param-spec] return func(*args, **kwargs) functools.update_wrapper(_no_grad_wrapper, func) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return _no_grad_wrapper @@ -283,7 +283,7 @@ def clip_grad_value_( clip_value = float(clip_value) grads = [p.grad for p in parameters if p.grad is not None] - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] grouped_grads = _group_tensors_by_device_and_dtype([grads]) for (device, _), ([grads], _) in grouped_grads.items(): diff --git a/torch/nn/utils/memory_format.py b/torch/nn/utils/memory_format.py index 757b0bb272c8..06eb55a02572 100644 --- a/torch/nn/utils/memory_format.py +++ b/torch/nn/utils/memory_format.py @@ -84,7 +84,7 @@ def convert_conv2d_weight_memory_format( ) for child in module.children(): convert_conv2d_weight_memory_format(child, memory_format) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return module @@ -164,7 +164,7 @@ def convert_conv3d_weight_memory_format( ) for child in module.children(): convert_conv3d_weight_memory_format(child, memory_format) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return module diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index 5a48b690cfe0..7706be61e39f 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -98,7 +98,7 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: ) # Q is now orthogonal (or unitary) of size (..., n, n) if n != k: - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] Q = Q[..., :k] # Q is now the size of the X (albeit perhaps transposed) else: @@ -111,10 +111,10 @@ def forward(self, X: torch.Tensor) -> torch.Tensor: Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2) if hasattr(self, "base"): - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] Q = self.base @ Q if transposed: - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] Q = Q.mT return Q # type: ignore[possibly-undefined] diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index ed298dece3ac..88eeb3aaf50c 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -179,28 +179,28 @@ def __init__( # Register the tensor(s) if self.is_tensor: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if original.dtype != new.dtype: raise ValueError( "When `right_inverse` outputs one tensor, it may not change the dtype.\n" f"original.dtype: {original.dtype}\n" - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] f"right_inverse(original).dtype: {new.dtype}" ) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if original.device != new.device: raise ValueError( "When `right_inverse` outputs one tensor, it may not change the device.\n" f"original.device: {original.device}\n" - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] f"right_inverse(original).device: {new.device}" ) # Set the original to original so that the user does not need to re-register the parameter # manually in the optimiser with torch.no_grad(): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _maybe_set(original, new) _register_parameter_or_buffer(self, "original", original) else: @@ -401,7 +401,7 @@ def get_parametrized(self) -> Tensor: if torch.jit.is_scripting(): raise RuntimeError("Parametrization is not working with scripting.") parametrization = self.parametrizations[tensor_name] - # pyrefly: ignore # redundant-condition + # pyrefly: ignore [redundant-condition] if _cache_enabled: if torch.jit.is_scripting(): # Scripting @@ -701,7 +701,7 @@ def remove_parametrizations( # Fetch the original tensor assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy parametrizations = module.parametrizations[tensor_name] - # pyrefly: ignore # invalid-argument + # pyrefly: ignore [invalid-argument] if parametrizations.is_tensor: original = parametrizations.original assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor" diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index aa0d5c2e7248..99a1439ec5c8 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -274,11 +274,11 @@ def __init__(self, *args): if not isinstance(args, Iterable): # only 1 item self._tensor_name = args._tensor_name self.add_pruning_method(args) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] elif len(args) == 1: # only 1 item in a tuple - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] self._tensor_name = args[0]._tensor_name - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] self.add_pruning_method(args[0]) else: # manual construction from list or other iterable (or no args) for method in args: @@ -1100,7 +1100,7 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw # flatten importance scores to consider them all at once in global pruning relevant_importance_scores = torch.nn.utils.parameters_to_vector( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] [ importance_scores.get((module, name), getattr(module, name)) for (module, name) in parameters diff --git a/torch/nn/utils/spectral_norm.py b/torch/nn/utils/spectral_norm.py index 9cf39cc5bda7..d40e3a35e55e 100644 --- a/torch/nn/utils/spectral_norm.py +++ b/torch/nn/utils/spectral_norm.py @@ -332,7 +332,7 @@ def spectral_norm( else: dim = 0 SpectralNorm.apply(module, name, n_power_iterations, dim, eps) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return module diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index 65e0674f3d48..a21d1bcb0f21 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -117,7 +117,7 @@ def ctx_factory(): @functools.wraps(func) def decorate_context(*args, **kwargs): - # pyrefly: ignore # bad-context-manager + # pyrefly: ignore [bad-context-manager] with ctx_factory(): return func(*args, **kwargs) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index a0865b0c9bd7..80dc1776bc35 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -1074,7 +1074,7 @@ def key_get(obj: Any, kp: KeyPath) -> Any: with python_pytree._NODE_REGISTRY_LOCK: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] python_pytree._cxx_pytree_imported = True args, kwargs = (), {} # type: ignore[var-annotated] for args, kwargs in python_pytree._cxx_pytree_pending_imports: diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 4251c983a420..7c9d1a850a46 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -262,7 +262,7 @@ def __enter__(self): self.module_tracker.__enter__() # type: ignore[attribute, union-attr] return self - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def __exit__(self, *args): super().__exit__(*args) if self.record_nn_module: diff --git a/torch/utils/_device.py b/torch/utils/_device.py index 8a2f409c728c..bc6072517d5b 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -60,7 +60,7 @@ def _device_constructors(): # NB: This is directly called from C++ in torch/csrc/Device.cpp class DeviceContext(TorchFunctionMode): def __init__(self, device): - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self.device = torch.device(device) def __enter__(self): diff --git a/torch/utils/_functools.py b/torch/utils/_functools.py index 37f0a1d17a22..ac8f0d3f111f 100644 --- a/torch/utils/_functools.py +++ b/torch/utils/_functools.py @@ -35,12 +35,12 @@ def wrap(self: _C, *args: _P.args, **kwargs: _P.kwargs) -> _T: if not (cache := getattr(self, cache_name, None)): cache = {} setattr(self, cache_name, cache) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] cached_value = cache.get(args, _cache_sentinel) if cached_value is not _cache_sentinel: return cached_value value = f(self, *args, **kwargs) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] cache[args] = value return value diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 759e0e611384..2ed1ba60a593 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -708,7 +708,7 @@ def __init_subclass__(cls) -> NoReturn: def __new__( cls: type[Self], sequence: Iterable[_T_co], - # pyrefly: ignore # bad-function-definition + # pyrefly: ignore [bad-function-definition] dict: dict[str, Any] = ..., ) -> Self: raise NotImplementedError @@ -755,7 +755,7 @@ def _tuple_flatten_with_keys( d: tuple[T, ...], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _tuple_flatten(d) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -769,7 +769,7 @@ def _list_flatten(d: list[T]) -> tuple[list[T], Context]: def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _list_flatten(d) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -785,7 +785,7 @@ def _dict_flatten_with_keys( d: dict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _dict_flatten(d) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -801,7 +801,7 @@ def _namedtuple_flatten_with_keys( d: NamedTuple, ) -> tuple[list[tuple[KeyEntry, Any]], Context]: values, context = _namedtuple_flatten(d) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return ( [(GetAttrKey(field), v) for field, v in zip(context._fields, values)], context, @@ -851,7 +851,7 @@ def _ordereddict_flatten_with_keys( d: OrderedDict[Any, T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _ordereddict_flatten(d) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -876,7 +876,7 @@ def _defaultdict_flatten_with_keys( ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _defaultdict_flatten(d) _, dict_context = context - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context @@ -925,7 +925,7 @@ def _deque_flatten_with_keys( d: deque[T], ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _deque_flatten(d) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -1827,7 +1827,7 @@ def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]: for attr in classname.split("."): enum_cls = getattr(enum_cls, attr) enum_cls = cast(type[Enum], enum_cls) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] return enum_cls[obj["name"]] return obj diff --git a/torch/utils/_strobelight/cli_function_profiler.py b/torch/utils/_strobelight/cli_function_profiler.py index 7825f784e2f3..024cd93b3578 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -305,7 +305,7 @@ def strobelight_inner( ) -> Callable[_P, Optional[_R]]: @functools.wraps(work_function) def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return profiler.profile(work_function, *args, **kwargs) return wrapper_function diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index d7f65dd0c16e..d152b719bcde 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -105,7 +105,7 @@ def _keep_float( ) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]: @functools.wraps(f) def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] r: Union[_T, sympy.Float] = f(*args) if any(isinstance(a, sympy.Float) for a in args) and not isinstance( r, sympy.Float @@ -113,7 +113,7 @@ def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]: r = sympy.Float(float(r)) return r - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return inner @@ -200,12 +200,12 @@ class FloorDiv(sympy.Function): @property def base(self) -> sympy.Basic: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return self.args[0] @property def divisor(self) -> sympy.Basic: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return self.args[1] def _sympystr(self, printer: sympy.printing.StrPrinter) -> str: @@ -374,7 +374,7 @@ def eval( return None def _eval_is_nonnegative(self) -> Optional[bool]: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] p, q = self.args[:2] return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined] @@ -455,7 +455,7 @@ def eval(cls, p: sympy.Expr, q: sympy.Expr) -> Optional[sympy.Expr]: # - floor(p / q) = 0 # - p % q = p - floor(p / q) * q = p less = p < q - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if less.is_Boolean and bool(less) and r.is_positive: return p @@ -472,11 +472,11 @@ def _eval_is_nonpositive(self) -> Optional[bool]: return True if self.args[1].is_negative else None # type: ignore[attr-defined] def _ccode(self, printer): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] abs_q = str(q) if self.args[1].is_positive else f"abs({q})" return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}" @@ -559,7 +559,7 @@ def eval(cls, number): return sympy.Integer(math.ceil(float(number))) def _ccode(self, printer): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5) return f"ceil({number})" @@ -830,7 +830,7 @@ def do(ai, a): if not cond: return ai.func(*[do(i, a) for i in ai.args], evaluate=False) if isinstance(ai, cls): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False) return a @@ -1008,7 +1008,7 @@ def _eval_is_nonnegative(self): # type:ignore[override] return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] def _eval_is_negative(self): # type:ignore[override] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return fuzzy_and(a.is_negative for a in self.args) @@ -1027,7 +1027,7 @@ def _eval_is_nonnegative(self): # type:ignore[override] return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined] def _eval_is_negative(self): # type:ignore[override] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return fuzzy_or(a.is_negative for a in self.args) @@ -1165,9 +1165,9 @@ def eval(cls, base, divisor): return sympy.Float(int(base) / int(divisor)) def _ccode(self, printer): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5) return f"((int){base}/(int){divisor})" @@ -1331,16 +1331,16 @@ class Identity(sympy.Function): precedence = 10 def __repr__(self): # type: ignore[override] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return f"Identity({self.args[0]})" def _sympystr(self, printer): """Controls how sympy's StrPrinter prints this""" - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return f"({printer.doprint(self.args[0])})" def _eval_is_real(self): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return self.args[0].is_real def _eval_is_integer(self): @@ -1348,15 +1348,15 @@ def _eval_is_integer(self): def _eval_expand_identity(self, **hints): # Removes the identity op. - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return self.args[0] def __int__(self) -> int: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return int(self.args[0]) def __float__(self) -> float: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return float(self.args[0]) diff --git a/torch/utils/_sympy/numbers.py b/torch/utils/_sympy/numbers.py index 01aee8b29f10..f675de25ad8a 100644 --- a/torch/utils/_sympy/numbers.py +++ b/torch/utils/_sympy/numbers.py @@ -9,7 +9,7 @@ from sympy.core.singleton import S, Singleton -# pyrefly: ignore # invalid-inheritance +# pyrefly: ignore [invalid-inheritance] class IntInfinity(Number, metaclass=Singleton): r"""Positive integer infinite quantity. @@ -204,7 +204,7 @@ def ceiling(self): int_oo = S.IntInfinity -# pyrefly: ignore # invalid-inheritance +# pyrefly: ignore [invalid-inheritance] class NegativeIntInfinity(Number, metaclass=Singleton): """Negative integer infinite quantity. diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 475eed67c381..526443577b3f 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -66,7 +66,7 @@ def _print_Float(self, expr: sympy.Expr) -> str: # NB: this pow by natural, you should never have used builtin sympy.pow # for FloatPow, and a symbolic exponent should be PowByNatural. These # means exp is guaranteed to be integer. - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def _print_Pow(self, expr: sympy.Expr) -> str: base, exp = expr.args if exp != int(exp): diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 9012f80cfc6e..c3a3878f3c8c 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -176,7 +176,7 @@ def sqrt(x): @staticmethod def pow(a, b): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return _keep_float(FloatPow)(a, b) @staticmethod diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index b0a99dd4887c..ef7c1696480b 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -126,9 +126,9 @@ def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]: class ValueRanges(Generic[_T]): if TYPE_CHECKING: # ruff doesn't understand circular references but mypy does - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] ExprVR = ValueRanges[sympy.Expr] # noqa: F821 - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] BoolVR = ValueRanges[SympyBoolean] # noqa: F821 AllVR = Union[ExprVR, BoolVR] @@ -484,7 +484,7 @@ def constant(value, dtype): @staticmethod def to_dtype(a, dtype, src_dtype=None): if dtype == torch.float64: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return ValueRanges.increasing_map(a, ToFloat) elif dtype == torch.bool: return ValueRanges.unknown_bool() @@ -494,7 +494,7 @@ def to_dtype(a, dtype, src_dtype=None): @staticmethod def trunc_to_int(a, dtype): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return ValueRanges.increasing_map(a, TruncToInt) @staticmethod @@ -652,7 +652,7 @@ def int_truediv(a, b): return ValueRanges.coordinatewise_monotone_map( a, b, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _keep_float(IntTrueDiv), ) @@ -668,7 +668,7 @@ def truediv(a, b): return ValueRanges.coordinatewise_monotone_map( a, b, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _keep_float(FloatTrueDiv), ) @@ -748,7 +748,7 @@ def pow_by_natural(cls, a, b): # We should know that b >= 0 but we may have forgotten this fact due # to replacements, so don't assert it, but DO clamp it to prevent # degenerate problems - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return ValueRanges.coordinatewise_increasing_map( a, b & ValueRanges(0, int_oo), PowByNatural ) @@ -915,7 +915,7 @@ def round_decimal(cls, number, ndigits): @classmethod def round_to_int(cls, number, dtype): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return ValueRanges.increasing_map(number, RoundToInt) # It's used in some models on symints @@ -1032,7 +1032,7 @@ def atan(x): @staticmethod def trunc(x): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return ValueRanges.increasing_map(x, TruncToFloat) diff --git a/torch/utils/benchmark/examples/sparse/compare.py b/torch/utils/benchmark/examples/sparse/compare.py index 91e30e68054a..fa00fb1818cd 100644 --- a/torch/utils/benchmark/examples/sparse/compare.py +++ b/torch/utils/benchmark/examples/sparse/compare.py @@ -63,7 +63,7 @@ def generate_coo_data(size, sparse_dim, nnz, dtype, device): indices = torch.rand(sparse_dim, nnz, device=device) indices.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(indices)) indices = indices.to(torch.long) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] values = torch.rand([nnz, ], dtype=dtype, device=device) return indices, values diff --git a/torch/utils/benchmark/utils/compile.py b/torch/utils/benchmark/utils/compile.py index 9127b14c99b3..777120c81105 100644 --- a/torch/utils/benchmark/utils/compile.py +++ b/torch/utils/benchmark/utils/compile.py @@ -170,7 +170,7 @@ def bench_all( _disable_tensor_cores() table.append([ ("Training" if optimizer else "Inference"), - # pyrefly: ignore # redundant-condition + # pyrefly: ignore [redundant-condition] backend if backend else "-", mode if mode is not None else "-", f"{compilation_time} ms " if compilation_time else "-", @@ -191,5 +191,5 @@ def bench_all( ]) - # pyrefly: ignore # not-callable + # pyrefly: ignore [not-callable] return tabulate(table, headers=field_names, tablefmt="github") diff --git a/torch/utils/benchmark/utils/cpp_jit.py b/torch/utils/benchmark/utils/cpp_jit.py index 27699d9ee21e..969eb6abb695 100644 --- a/torch/utils/benchmark/utils/cpp_jit.py +++ b/torch/utils/benchmark/utils/cpp_jit.py @@ -35,7 +35,7 @@ def _get_build_root() -> str: global _BUILD_ROOT if _BUILD_ROOT is None: _BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build") - # pyrefly: ignore # missing-argument + # pyrefly: ignore [missing-argument] atexit.register(shutil.rmtree, _BUILD_ROOT) return _BUILD_ROOT diff --git a/torch/utils/benchmark/utils/sparse_fuzzer.py b/torch/utils/benchmark/utils/sparse_fuzzer.py index 735b40c3b5e4..cd84900c5b43 100644 --- a/torch/utils/benchmark/utils/sparse_fuzzer.py +++ b/torch/utils/benchmark/utils/sparse_fuzzer.py @@ -92,7 +92,7 @@ def sparse_tensor_constructor(size, dtype, sparse_dim, nnz, is_coalesced): return x def _make_tensor(self, params, state): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] size, _, _ = self._get_size_and_steps(params) density = params['density'] nnz = math.ceil(sum(size) * density) @@ -102,10 +102,10 @@ def _make_tensor(self, params, state): is_coalesced = params['coalesced'] sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size) sparse_dim = min(sparse_dim, len(size)) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if self._cuda: tensor = tensor.cuda() sparse_dim = tensor.sparse_dim() @@ -121,7 +121,7 @@ def _make_tensor(self, params, state): "sparse_dim": sparse_dim, "dense_dim": dense_dim, "is_hybrid": is_hybrid, - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] "dtype": str(self._dtype), } return tensor, properties diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py index acd9e5f96205..3dc17edeb796 100644 --- a/torch/utils/benchmark/utils/timer.py +++ b/torch/utils/benchmark/utils/timer.py @@ -234,7 +234,7 @@ def __init__( setup = textwrap.dedent(setup) setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip() - # pyrefly: ignore # bad-instantiation + # pyrefly: ignore [bad-instantiation] self._timer = self._timer_cls( stmt=stmt, setup=setup, diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py index e80416482271..9080f8272160 100644 --- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py +++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -449,13 +449,13 @@ def construct(self) -> str: load_lines = [] for name, wrapped_value in self._globals.items(): if wrapped_value.setup is not None: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] load_lines.append(textwrap.dedent(wrapped_value.setup)) if wrapped_value.serialization == Serialization.PICKLE: path = os.path.join(self._data_dir, f"{name}.pkl") load_lines.append( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)") with open(path, "wb") as f: pickle.dump(wrapped_value.value, f) @@ -465,13 +465,13 @@ def construct(self) -> str: # TODO: Figure out if we can use torch.serialization.add_safe_globals here # Using weights_only=False after the change in # https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573 - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)") torch.save(wrapped_value.value, path) elif wrapped_value.serialization == Serialization.TORCH_JIT: path = os.path.join(self._data_dir, f"{name}.pt") - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] load_lines.append(f"{name} = torch.jit.load({repr(path)})") with open(path, "wb") as f: torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call] diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 005314377929..d3c41b8fb9e7 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -222,7 +222,7 @@ def _get_autocast_kwargs(device_type="cuda"): class CheckpointFunction(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, run_function, preserve_rng_state, *args): check_backward_validity(args) ctx.run_function = run_function @@ -785,7 +785,7 @@ def __init__(self): class _NoopSaveInputs(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(*args): return torch.empty((0,)) @@ -1008,7 +1008,7 @@ def get_context_manager(self): def logging_mode(): with LoggingTensorMode(), \ capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.logs, self.tbs = logs_and_tb yield logs_and_tb return logging_mode() diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 2fa5eda7fffa..b069167cc6b5 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -788,7 +788,7 @@ def unix_wrap_ninja_compile(sources, # Use absolute path for output_dir so that the object file paths # (`objects`) get generated with absolute paths. - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] output_dir = os.path.abspath(output_dir) # See Note [Absolute include_dirs] @@ -979,7 +979,7 @@ def win_wrap_ninja_compile(sources, is_standalone=False): if not self.compiler.initialized: self.compiler.initialize() - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] output_dir = os.path.abspath(output_dir) # Note [Absolute include_dirs] @@ -2573,7 +2573,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]: def _get_vc_env(vc_arch: str) -> dict[str, str]: try: from setuptools import distutils # type: ignore[attr-defined] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return distutils._msvccompiler._get_vc_env(vc_arch) except AttributeError: try: diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index b9a04644f331..efe50ba22e8e 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -204,7 +204,7 @@ def collate( # check to make sure that the elements in batch have consistent size it = iter(batch) elem_size = len(next(it)) - # pyrefly: ignore # not-iterable + # pyrefly: ignore [not-iterable] if not all(len(elem) == elem_size for elem in it): raise RuntimeError("each element in list of batch should be of equal size") transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. diff --git a/torch/utils/data/_utils/pin_memory.py b/torch/utils/data/_utils/pin_memory.py index c0a9416c45fe..223962fc04ba 100644 --- a/torch/utils/data/_utils/pin_memory.py +++ b/torch/utils/data/_utils/pin_memory.py @@ -70,7 +70,7 @@ def pin_memory(data, device=None): return clone else: return type(data)( - # pyrefly: ignore # bad-argument-count + # pyrefly: ignore [bad-argument-count] {k: pin_memory(sample, device) for k, sample in data.items()} ) # type: ignore[call-arg] except TypeError: diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index c8972b005dd9..32777cfd01d3 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -265,7 +265,7 @@ def issubtype_of_instance(self, other): # Default type for DataPipe without annotation _T_co = TypeVar("_T_co", covariant=True) -# pyrefly: ignore # invalid-annotation +# pyrefly: ignore [invalid-annotation] _DEFAULT_TYPE = _DataPipeType(Generic[_T_co]) @@ -284,7 +284,7 @@ def __new__(cls, name, bases, namespace, **kwargs): return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] # TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now. - # pyrefly: ignore # no-access + # pyrefly: ignore [no-access] cls.__origin__ = None if "type" in namespace: return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload] diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index f5a4ebaf2703..8908721bccd7 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -80,7 +80,7 @@ def __str__(self): def _ops_str(self): res = "" - # pyrefly: ignore # not-iterable + # pyrefly: ignore [not-iterable] for op in self.ctx["operations"]: if len(res) > 0: res += "\n" @@ -90,7 +90,7 @@ def _ops_str(self): def __getstate__(self): # TODO(VitalyFedyunin): Currently can't pickle (why?) self.ctx["schema_df"] = None - # pyrefly: ignore # not-iterable + # pyrefly: ignore [not-iterable] for var in self.ctx["variables"]: var.calculated_value = None state = {} @@ -114,13 +114,13 @@ def __getitem__(self, key): return CaptureGetItem(self, key, ctx=self.ctx) def __setitem__(self, key, value): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx)) def __add__(self, add_val): res = CaptureAdd(self, add_val, ctx=self.ctx) var = CaptureVariable(res, ctx=self.ctx) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.ctx["operations"].append( CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) ) @@ -129,7 +129,7 @@ def __add__(self, add_val): def __sub__(self, add_val): res = CaptureSub(self, add_val, ctx=self.ctx) var = CaptureVariable(res, ctx=self.ctx) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.ctx["operations"].append( CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) ) @@ -139,19 +139,19 @@ def __mul__(self, add_val): res = CaptureMul(self, add_val, ctx=self.ctx) var = CaptureVariable(res, ctx=self.ctx) t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.ctx["operations"].append(t) return var def _is_context_empty(self): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0 def apply_ops_2(self, dataframe): # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] self.ctx["variables"][0].calculated_value = dataframe - # pyrefly: ignore # not-iterable + # pyrefly: ignore [not-iterable] for op in self.ctx["operations"]: op.execute() @@ -184,7 +184,7 @@ def __call__(self, *args, **kwargs): res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs) var = CaptureVariable(None, ctx=self.ctx) t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.ctx["operations"].append(t) return var @@ -283,9 +283,9 @@ def execute(self): def apply_ops(self, dataframe): # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] self.ctx["variables"][0].calculated_value = dataframe - # pyrefly: ignore # not-iterable + # pyrefly: ignore [not-iterable] for op in self.ctx["operations"]: op.execute() return self.calculated_value @@ -385,7 +385,7 @@ def get_val(capture): class CaptureInitial(CaptureVariable): def __init__(self, schema_df=None): - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] new_ctx: dict[str, list[Any]] = { "operations": [], "variables": [], @@ -401,7 +401,7 @@ class CaptureDataFrame(CaptureInitial): class CaptureDataFrameWithDataPipeOps(CaptureDataFrame): def as_datapipe(self): - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self) def raw_iterator(self): diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py index 0c1b416e99c2..edb08d77a81d 100644 --- a/torch/utils/data/datapipes/dataframe/datapipes.py +++ b/torch/utils/data/datapipes/dataframe/datapipes.py @@ -92,7 +92,7 @@ def __iter__(self): size = None all_buffer = [] filter_res = [] - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for df in self.source_datapipe: if size is None: size = len(df.index) diff --git a/torch/utils/data/datapipes/datapipe.py b/torch/utils/data/datapipes/datapipe.py index 22e324e0ae2c..f0811ac81b61 100644 --- a/torch/utils/data/datapipes/datapipe.py +++ b/torch/utils/data/datapipes/datapipe.py @@ -135,7 +135,7 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta): _fast_forward_iterator: Optional[Iterator] = None def __iter__(self) -> Iterator[_T_co]: - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return self def __getattr__(self, attribute_name): @@ -380,7 +380,7 @@ def __getstate__(self): value = pickle.dumps(self._datapipe) except Exception: if HAS_DILL: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] value = dill.dumps(self._datapipe) use_dill = True else: @@ -390,7 +390,7 @@ def __getstate__(self): def __setstate__(self, state): value, use_dill = state if use_dill: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self._datapipe = dill.loads(value) else: self._datapipe = pickle.loads(value) @@ -407,7 +407,7 @@ def __len__(self): class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe): def __init__(self, datapipe: IterDataPipe[_T_co]): super().__init__(datapipe) - # pyrefly: ignore # invalid-type-var + # pyrefly: ignore [invalid-type-var] self._datapipe_iter: Optional[Iterator[_T_co]] = None def __iter__(self) -> "_IterDataPipeSerializationWrapper": diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index bfff0d19f4cf..1ce1c9c07196 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -118,7 +118,7 @@ def _apply_fn(self, data): for idx in sorted(self.input_col[1:], reverse=True): del data[idx] else: - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] data[self.input_col] = res else: if self.output_col == -1: diff --git a/torch/utils/data/datapipes/iter/combinatorics.py b/torch/utils/data/datapipes/iter/combinatorics.py index bd10ff2a6785..ff76e995f0ad 100644 --- a/torch/utils/data/datapipes/iter/combinatorics.py +++ b/torch/utils/data/datapipes/iter/combinatorics.py @@ -43,7 +43,7 @@ def __init__( "Sampler class requires input datapipe implemented `__len__`" ) super().__init__() - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.datapipe = datapipe self.sampler_args = () if sampler_args is None else sampler_args self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 22f27327b2ee..b6dda4552c24 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -59,7 +59,7 @@ def __iter__(self) -> Iterator: def __len__(self) -> int: if all(isinstance(dp, Sized) for dp in self.datapipes): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return sum(len(dp) for dp in self.datapipes) else: raise TypeError(f"{type(self).__name__} instance doesn't have valid length") @@ -180,7 +180,7 @@ def __init__( self._child_stop: list[bool] = [True for _ in range(num_instances)] def __len__(self): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return len(self.main_datapipe) def get_next_element_by_instance(self, instance_id: int): @@ -240,7 +240,7 @@ def is_every_instance_exhausted(self) -> bool: return self.end_ptr is not None and all(self._child_stop) def get_length_by_instance(self, instance_id: int) -> int: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return len(self.main_datapipe) def reset(self) -> None: @@ -327,7 +327,7 @@ def __init__(self, main_datapipe: IterDataPipe, instance_id: int): if not isinstance(main_datapipe, _ContainerTemplate): raise AssertionError("main_datapipe must implement _ContainerTemplate") - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.main_datapipe: IterDataPipe = main_datapipe self.instance_id = instance_id @@ -454,7 +454,7 @@ def __init__( drop_none: bool, buffer_size: int, ): - # pyrefly: ignore # invalid-type-var + # pyrefly: ignore [invalid-type-var] self.main_datapipe = datapipe self._datapipe_iterator: Optional[Iterator[Any]] = None self.num_instances = num_instances @@ -466,9 +466,9 @@ def __init__( UserWarning, ) self.current_buffer_usage = 0 - # pyrefly: ignore # invalid-type-var + # pyrefly: ignore [invalid-type-var] self.child_buffers: list[deque[_T_co]] = [deque() for _ in range(num_instances)] - # pyrefly: ignore # invalid-type-var + # pyrefly: ignore [invalid-type-var] self.classifier_fn = classifier_fn self.drop_none = drop_none self.main_datapipe_exhausted = False @@ -706,7 +706,7 @@ def __iter__(self) -> Iterator[tuple[_T_co]]: def __len__(self) -> int: if all(isinstance(dp, Sized) for dp in self.datapipes): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return min(len(dp) for dp in self.datapipes) else: raise TypeError(f"{type(self).__name__} instance doesn't have valid length") diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 9bd6ab7f819d..865feb9953e3 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -204,9 +204,9 @@ def __init__( drop_remaining: bool = False, ): _check_unpickable_fn(group_key_fn) - # pyrefly: ignore # invalid-type-var + # pyrefly: ignore [invalid-type-var] self.datapipe = datapipe - # pyrefly: ignore # invalid-type-var + # pyrefly: ignore [invalid-type-var] self.group_key_fn = group_key_fn self.keep_key = keep_key @@ -218,14 +218,14 @@ def __init__( if group_size is not None and buffer_size is not None: if not (0 < group_size <= buffer_size): raise AssertionError("group_size must be > 0 and <= buffer_size") - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.guaranteed_group_size = group_size if guaranteed_group_size is not None: if group_size is None or not (0 < guaranteed_group_size <= group_size): raise AssertionError( "guaranteed_group_size must be > 0 and <= group_size and group_size must be set" ) - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.guaranteed_group_size = guaranteed_group_size self.drop_remaining = drop_remaining self.wrapper_class = DataChunk diff --git a/torch/utils/data/datapipes/map/callable.py b/torch/utils/data/datapipes/map/callable.py index 983ef41748d7..3696d34b2a81 100644 --- a/torch/utils/data/datapipes/map/callable.py +++ b/torch/utils/data/datapipes/map/callable.py @@ -60,7 +60,7 @@ def __init__( self.fn = fn # type: ignore[assignment] def __len__(self) -> int: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return len(self.datapipe) def __getitem__(self, index) -> _T_co: diff --git a/torch/utils/data/datapipes/map/combinatorics.py b/torch/utils/data/datapipes/map/combinatorics.py index b49619c12fd7..4876ce3fd1cb 100644 --- a/torch/utils/data/datapipes/map/combinatorics.py +++ b/torch/utils/data/datapipes/map/combinatorics.py @@ -64,7 +64,7 @@ def __init__( ) -> None: super().__init__() self.datapipe = datapipe - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.indices = list(range(len(datapipe))) if indices is None else indices self._enabled = True self._seed = None @@ -96,7 +96,7 @@ def reset(self) -> None: self._shuffled_indices = self._rng.sample(self.indices, len(self.indices)) def __len__(self) -> int: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return len(self.datapipe) def __getstate__(self): diff --git a/torch/utils/data/datapipes/map/combining.py b/torch/utils/data/datapipes/map/combining.py index b4cb1add714f..21a412ff9160 100644 --- a/torch/utils/data/datapipes/map/combining.py +++ b/torch/utils/data/datapipes/map/combining.py @@ -49,16 +49,16 @@ def __init__(self, *datapipes: MapDataPipe): def __getitem__(self, index) -> _T_co: # type: ignore[type-var] offset = 0 for dp in self.datapipes: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] if index - offset < len(dp): return dp[index - offset] else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] offset += len(dp) raise IndexError(f"Index {index} is out of range.") def __len__(self) -> int: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return sum(len(dp) for dp in self.datapipes) @@ -105,5 +105,5 @@ def __getitem__(self, index) -> tuple[_T_co, ...]: return tuple(res) def __len__(self) -> int: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return min(len(dp) for dp in self.datapipes) diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 6edcee5e35b2..2390434c3ef5 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -196,7 +196,7 @@ def onerror(err: OSError): if match_masks(fname, masks): yield path else: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for path, dirs, files in os.walk(root, onerror=onerror): if abspath: path = os.path.abspath(path) diff --git a/torch/utils/data/datapipes/utils/snapshot.py b/torch/utils/data/datapipes/utils/snapshot.py index 5d0f1c0dc84d..42aec1aa308a 100644 --- a/torch/utils/data/datapipes/utils/snapshot.py +++ b/torch/utils/data/datapipes/utils/snapshot.py @@ -43,7 +43,7 @@ def _simple_graph_snapshot_restoration( # simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`, # the first reset will not actually reset. datapipe.reset() # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`. - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] apply_random_seed(datapipe, rng) remainder = n_iterations diff --git a/torch/utils/data/distributed.py b/torch/utils/data/distributed.py index a7f8b61beabe..b2f4eb04e8e2 100644 --- a/torch/utils/data/distributed.py +++ b/torch/utils/data/distributed.py @@ -137,7 +137,7 @@ def __iter__(self) -> Iterator[_T_co]: f"Number of subsampled indices ({len(indices)}) does not match num_samples ({self.num_samples})" ) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return iter(indices) def __len__(self) -> int: diff --git a/torch/utils/data/graph.py b/torch/utils/data/graph.py index 63ac99c49268..a08421f9b68d 100644 --- a/torch/utils/data/graph.py +++ b/torch/utils/data/graph.py @@ -72,7 +72,7 @@ def reduce_hook(obj): p.dump(scan_obj) except (pickle.PickleError, AttributeError, TypeError): if dill_available(): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] d.dump(scan_obj) else: raise diff --git a/torch/utils/file_baton.py b/torch/utils/file_baton.py index b493441db23a..c7ce437ab9bf 100644 --- a/torch/utils/file_baton.py +++ b/torch/utils/file_baton.py @@ -31,7 +31,7 @@ def try_acquire(self): True if the file could be created, else False. """ try: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL) return True except FileExistsError: diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index c665bb634c5f..8de220f58dd8 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -149,7 +149,7 @@ def conv_flop_count( @register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward]) def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int: """Count flops for convolution.""" - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index a4431c8cc349..9ee3dbe18e9a 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -145,7 +145,7 @@ def hook(grad_input, _): res = out - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.grad_outputs = None return self._unpack_none(self.input_tensors_index, res) diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index 253301b31121..9b39c303ac39 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -237,7 +237,7 @@ def get_model_info( with zipfile.ZipFile(path_or_file) as zf: path_prefix = None zip_files = [] - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for zi in zf.infolist(): prefix = re.sub("/.*", "", zi.filename) if path_prefix is None: @@ -392,12 +392,12 @@ def get_inline_skeleton(): import importlib.resources - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] skeleton = importlib.resources.read_text(__package__, "skeleton.html") - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] js_code = importlib.resources.read_text(__package__, "code.js") for js_module in ["preact", "htm"]: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs") js_url = "data:application/javascript," + urllib.parse.quote(js_lib) js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url) diff --git a/torch/utils/tensorboard/_convert_np.py b/torch/utils/tensorboard/_convert_np.py index 21290a8b0ced..f0e8910580de 100644 --- a/torch/utils/tensorboard/_convert_np.py +++ b/torch/utils/tensorboard/_convert_np.py @@ -31,7 +31,7 @@ def make_np(x: torch.Tensor) -> np.ndarray: def _prepare_pytorch(x: torch.Tensor) -> np.ndarray: if x.dtype == torch.bfloat16: x = x.to(torch.float16) - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] x = x.detach().cpu().numpy() - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return x diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index 31ae14919315..b3ef6a468dca 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -188,7 +188,7 @@ def populate_namespace_from_OP_to_IO(self): for key, node in self.nodes_io.items(): if type(node) is NodeBase: - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName if hasattr(node, "input_or_output"): self.unique_name_to_scoped_name[key] = ( @@ -199,7 +199,7 @@ def populate_namespace_from_OP_to_IO(self): self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName if node.scope == "" and self.shallowest_scope_name: self.unique_name_to_scoped_name[node.debugName] = ( - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] self.shallowest_scope_name + "/" + node.debugName ) diff --git a/torch/utils/tensorboard/_utils.py b/torch/utils/tensorboard/_utils.py index 6c44576d4cb7..ac06b8c3986f 100644 --- a/torch/utils/tensorboard/_utils.py +++ b/torch/utils/tensorboard/_utils.py @@ -57,14 +57,14 @@ def is_power2(num): return num != 0 and ((num & (num - 1)) == 0) # pad to nearest power of 2, all at once - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] if not is_power2(V.shape[0]): - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0]) V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0) n_rows = 2 ** ((b.bit_length() - 1) // 2) - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] n_cols = V.shape[0] // n_rows V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w)) diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index e9322279c963..ae3b6a7a19a5 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -498,7 +498,7 @@ def make_histogram(values, bins, max_bins=None): subsampling = num_bins // max_bins subsampling_remainder = num_bins % subsampling if subsampling_remainder != 0: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] counts = np.pad( counts, pad_width=[[0, subsampling - subsampling_remainder]], @@ -838,21 +838,21 @@ def compute_curve(labels, predictions, num_thresholds=None, weights=None): weights = 1.0 # Compute bins of true positives and false positives. - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1))) float_labels = labels.astype(np.float64) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] histogram_range = (0, num_thresholds - 1) tp_buckets, _ = np.histogram( bucket_indices, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] bins=num_thresholds, range=histogram_range, weights=float_labels * weights, ) fp_buckets, _ = np.histogram( bucket_indices, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] bins=num_thresholds, range=histogram_range, weights=(1.0 - float_labels) * weights, diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index 51646362bceb..e100ddb179f6 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -254,9 +254,9 @@ def __init__( buckets = [] neg_buckets = [] while v < 1e20: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] buckets.append(v) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] neg_buckets.append(-v) v *= 1.1 self.default_bins = neg_buckets[::-1] + [0] + buckets @@ -264,19 +264,19 @@ def __init__( def _get_file_writer(self): """Return the default FileWriter instance. Recreates it if closed.""" if self.all_writers is None or self.file_writer is None: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.file_writer = FileWriter( self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix ) - # pyrefly: ignore # bad-assignment, missing-attribute + # pyrefly: ignore [bad-assignment, missing-attribute] self.all_writers = {self.file_writer.get_logdir(): self.file_writer} if self.purge_step is not None: most_recent_step = self.purge_step - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.file_writer.add_event( Event(step=most_recent_step, file_version="brain.Event:2") ) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.file_writer.add_event( Event( step=most_recent_step, @@ -1207,7 +1207,7 @@ def close(self): for writer in self.all_writers.values(): writer.flush() writer.close() - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.file_writer = self.all_writers = None def __enter__(self): diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index f18225d62859..7a63977b861a 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -461,7 +461,7 @@ def to_html(nodes): if n.context is None: continue s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}')) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] listeners.append(s) dot = to_dot(nodes) return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners))