2 changes: 1 addition & 1 deletion torch/ao/ns/fx/pattern_utils.py
@@ -97,7 +97,7 @@ def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
results.append((new_pattern, default_base_op_idx)) # type: ignore[arg-type]


-# After this point, results countains values such as
+# After this point, results contains values such as
# [..., ((torch.nn.Relu, torch.nn.Conv2d), 0), ...]

# Patterns for matching fp16 emulation are not specified in the quantization
@@ -302,7 +302,7 @@ def hook(module, input):

def squash_mask(self, attach_sparsify_hook=True, **kwargs):
"""
-Unregisters aggreagate hook that was applied earlier and registers sparsification hooks if
+Unregisters aggregate hook that was applied earlier and registers sparsification hooks if
attach_sparsify_hook = True.
"""
for name, configs in self.data_groups.items():
@@ -149,7 +149,7 @@ def step(self):
elif self.data_sparsifier._step_count < 1: # type: ignore[attr-defined]
warnings.warn("Detected call of `scheduler.step()` before `data_sparsifier.step()`. "
"You have to make sure you run the data_sparsifier.step() BEFORE any "
-"calls to the scheduer.step().", UserWarning)
+"calls to the scheduler.step().", UserWarning)
self._step_count += 1

class _enable_get_sp_call:
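As context for the warning fixed above, the expected call order is: the data sparsifier steps first, then its scheduler. A minimal sketch of that ordering, using a stand-in `_Stepper` class rather than the real sparsifier/scheduler objects:

```python
# Illustrative stand-ins only: the real objects are a configured data
# sparsifier and its scheduler, each of which tracks a _step_count.
class _Stepper:
    def __init__(self):
        self._step_count = 0

    def step(self):
        self._step_count += 1


data_sparsifier = _Stepper()
scheduler = _Stepper()

for epoch in range(3):
    data_sparsifier.step()  # update the sparsifier state first
    scheduler.step()        # then advance the sparsity schedule
```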
@@ -128,7 +128,7 @@ def update_mask(self, name, data, sparsity_level,
else:
data_norm = (data * data).squeeze() # square every element for L2

-if len(data_norm.shape) > 2: # only supports 2 dimenstional data at the moment
+if len(data_norm.shape) > 2: # only supports 2 dimensional data at the moment
raise ValueError("only supports 2-D at the moment")

elif len(data_norm.shape) == 1: # in case the data is bias (or 1D)
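For reference, a small standalone sketch of the norm computation and dimensionality check in this hunk (plain tensors only; this is not the sparsifier's actual code path):

```python
import torch

data = torch.randn(4, 6)
data_norm = (data * data).squeeze()  # square every element for an L2-style score

if data_norm.dim() > 2:              # only 1-D and 2-D data is handled
    raise ValueError("only supports 2-D at the moment")
```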
@@ -6,7 +6,7 @@

def _attach_model_to_data_sparsifier(module, data_sparsifier, config=None):
"""Attaches a data sparsifier to all the layers of the module.
-Essentialy, loop over all the weight parameters in the module and
+Essentially, loop over all the weight parameters in the module and
attach it to the data sparsifier.
Note::
The '.' in the layer names are replaced with '_' (refer to _get_valid_name() below)
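The naming convention mentioned in this docstring can be illustrated with a short, hypothetical sketch (the model and loop here are illustrative, not the library's code):

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
for name, param in model.named_parameters():
    if name.endswith("weight"):
        valid_name = name.replace(".", "_")  # e.g. "0.weight" -> "0_weight"
        print(valid_name, tuple(param.shape))
```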
@@ -285,7 +285,7 @@ def prune(self) -> None:
continue

first_module = modules.get(node.target)
-# check if first module exists and has apropriate parameterization, otherwise skip
+# check if first module exists and has appropriate parameterization, otherwise skip
if (
first_module is not None
and parametrize.is_parametrized(first_module)
4 changes: 2 additions & 2 deletions torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py
@@ -10,13 +10,13 @@ class LSTMSaliencyPruner(BaseStructuredSparsifier):
- weight_ih_l{k}
- weight_hh_l{k}

-These tensors pack the weights for the 4 linear layers together for efficency.
+These tensors pack the weights for the 4 linear layers together for efficiency.

[W_ii | W_if | W_ig | W_io]

Pruning this tensor directly will lead to weights being misassigned when unpacked.
To ensure that each packed linear layer is pruned the same amount:
-1. We split the packed weight into the 4 constitutient linear parts
+1. We split the packed weight into the 4 constituent linear parts
2. Update the mask for each individual piece using saliency individually

This applies to both weight_ih_l{k} and weight_hh_l{k}.
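A minimal sketch of the splitting step this docstring describes, assuming a per-row L1 saliency and an illustrative median cutoff (this is not LSTMSaliencyPruner's actual implementation):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16)
packed = lstm.weight_ih_l0                       # shape (4 * hidden_size, input_size)
w_ii, w_if, w_ig, w_io = torch.chunk(packed, 4, dim=0)

masks = []
for block in (w_ii, w_if, w_ig, w_io):
    saliency = block.abs().sum(dim=1)            # per-row L1 saliency
    masks.append(saliency >= saliency.median())  # keep the more salient rows
mask = torch.cat(masks)                          # lines up with the packed tensor again
```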
8 changes: 4 additions & 4 deletions torch/ao/pruning/_experimental/pruner/prune_functions.py
@@ -1,6 +1,6 @@
"""
Collection of conversion functions for linear / conv2d structured pruning
-Also contains utilities for bias propogation
+Also contains utilities for bias propagation
"""
from typing import cast, Optional, Callable, Tuple

@@ -10,7 +10,7 @@
from torch.nn.utils.parametrize import ParametrizationList
from .parametrization import FakeStructuredSparsity, BiasHook

-# BIAS PROPOGATION
+# BIAS PROPAGATION
def _remove_bias_handles(module: nn.Module) -> None:
if hasattr(module, "_forward_hooks"):
bias_hooks = []
@@ -86,15 +86,15 @@ def _prune_module_bias(module: nn.Module, mask: Tensor) -> None:

def _propogate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
r"""
-In the case that we need to propogate biases, this function will return the biases we need
+In the case that we need to propagate biases, this function will return the biases we need
"""
# set current module bias
if module.bias is not None:
module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
elif getattr(module, "_bias", None) is not None:
module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])

-# get pruned biases to propogate to subsequent layer
+# get pruned biases to propagate to subsequent layer
if getattr(module, "_bias", None) is not None:
pruned_biases = cast(Tensor, module._bias)[~mask]
else:
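The bias handling above can be pictured with a small standalone sketch: the rows kept by the mask stay on the pruned module, and the rows removed by the mask are the pruned biases handed to the next layer (values and names here are illustrative):

```python
import torch

bias = torch.arange(6.0)                                     # bias of the pruned layer
mask = torch.tensor([True, False, True, True, False, True])

kept_bias = torch.nn.Parameter(bias[mask])                   # stays on this module
pruned_biases = bias[~mask]                                  # propagated to the subsequent layer
```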
2 changes: 1 addition & 1 deletion torch/ao/pruning/scheduler/base_scheduler.py
@@ -124,7 +124,7 @@ def step(self, epoch=None):
elif self.sparsifier._step_count < 1: # type: ignore[attr-defined]
warnings.warn("Detected call of `scheduler.step()` before `sparsifier.step()`. "
"You have to make sure you run the sparsifier.step() BEFORE any "
-"calls to the scheduer.step().", UserWarning)
+"calls to the scheduler.step().", UserWarning)
self._step_count += 1

class _enable_get_sl_call:
2 changes: 1 addition & 1 deletion torch/ao/pruning/sparsifier/utils.py
@@ -115,7 +115,7 @@ def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str,
# Parametrizations
class FakeSparsity(nn.Module):
r"""Parametrization for the weights. Should be attached to the 'weight' or
-any other parmeter that requires a mask applied to it.
+any other parameter that requires a mask applied to it.

Note::

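For orientation, a minimal sketch in the spirit of FakeSparsity, assuming the standard parametrization API; `MaskedWeight` is an illustrative name, not the class in torch.ao:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class MaskedWeight(nn.Module):
    """Multiply the parametrized tensor by a fixed mask on every access."""
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, x):
        return self.mask * x

linear = nn.Linear(4, 4)
mask = (torch.rand_like(linear.weight) > 0.5).float()
parametrize.register_parametrization(linear, "weight", MaskedWeight(mask))
print(linear.weight)  # masked view; the original tensor lives under linear.parametrizations.weight.original
```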
2 changes: 1 addition & 1 deletion torch/ao/quantization/_equalize.py
@@ -160,7 +160,7 @@ def converged(curr_modules, prev_modules, threshold=1e-4):
being less than the given threshold

Takes two dictionaries mapping names to modules, the set of names for each dictionary
-should be the same, looping over the set of names, for each name take the differnce
+should be the same, looping over the set of names, for each name take the difference
between the associated modules in each dictionary

'''
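A hedged sketch of the convergence test this docstring describes (not the library's implementation): sum the norms of the weight differences between matching modules and compare against the threshold:

```python
import torch
import torch.nn as nn

def converged_sketch(curr_modules, prev_modules, threshold=1e-4):
    assert curr_modules.keys() == prev_modules.keys(), "module name sets must match"
    total = 0.0
    for name in curr_modules:
        diff = curr_modules[name].weight - prev_modules[name].weight
        total += torch.linalg.norm(diff).item()
    return total < threshold

curr = {"fc": nn.Linear(3, 3)}
prev = {"fc": nn.Linear(3, 3)}
print(converged_sketch(curr, prev))  # almost certainly False for independently initialized weights
```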
@@ -25,7 +25,7 @@

__all__: List[str] = []

-# TODO: rename to be more explict, e.g. qat_conv_relu
+# TODO: rename to be more explicit, e.g. qat_conv_relu
_ConvMetadata = namedtuple(
"_ConvMetadata",
["root", "transpose", "bn", "reference", "transpose_reference",
2 changes: 1 addition & 1 deletion torch/ao/quantization/fuse_modules.py
@@ -4,7 +4,7 @@
import torch.nn as nn

from torch.ao.quantization.fuser_method_mappings import get_fuser_method
-# for backward compatiblity
+# for backward compatibility
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn # noqa: F401
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu # noqa: F401
from torch.nn.utils.parametrize import type_before_parametrizations
2 changes: 1 addition & 1 deletion torch/ao/quantization/fuser_method_mappings.py
@@ -235,7 +235,7 @@ def _get_valid_patterns(op_pattern):
def get_fuser_method_new(
op_pattern: Pattern,
fuser_method_mapping: Dict[Pattern, Union[nn.Sequential, Callable]]):
-""" This will be made defult after we deprecate the get_fuser_method
+""" This will be made default after we deprecate the get_fuser_method
Would like to implement this first and have a separate PR for deprecation
"""
op_patterns = _get_valid_patterns(op_pattern)
16 changes: 8 additions & 8 deletions torch/ao/quantization/fx/_decomposed.py
@@ -80,14 +80,14 @@ def quantize_per_tensor_tensor(
Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
scalar values
"""
-assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
+assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)

@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
+assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
return torch.empty_like(input, dtype=dtype)
@@ -161,14 +161,14 @@ def dequantize_per_tensor_tensor(
Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
scalar values
"""
-assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
+assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)

@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
-assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
+assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
if dtype in [torch.uint8, torch.int8, torch.int32]:
return torch.empty_like(input, dtype=torch.float32)
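As a standalone illustration of what the `.tensor` overloads above check (not the decomposed ops themselves): the scale and zero_point arrive as one-element tensors, are unwrapped with `.item()`, and the usual affine quantize/dequantize formulas are applied:

```python
import torch

def quantize_sketch(x, scale, zero_point, quant_min, quant_max, dtype):
    assert scale.numel() == 1 and zero_point.numel() == 1
    q = torch.round(x / scale.item()) + zero_point.item()
    return torch.clamp(q, quant_min, quant_max).to(dtype)

def dequantize_sketch(q, scale, zero_point):
    assert scale.numel() == 1 and zero_point.numel() == 1
    return (q.to(torch.float32) - zero_point.item()) * scale.item()

x = torch.randn(2, 2)
q = quantize_sketch(x, torch.tensor([0.1]), torch.tensor([0]), -128, 127, torch.int8)
x_hat = dequantize_sketch(q, torch.tensor([0.1]), torch.tensor([0]))
```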
4 changes: 2 additions & 2 deletions torch/ao/quantization/fx/_equalize.py
@@ -280,11 +280,11 @@ def get_op_node_and_weight_eq_obs(
""" Gets the following weight equalization observer. There should always
exist a weight equalization observer after an input equalization observer.

-Returns the operation node that follows the input equalizatoin observer node
+Returns the operation node that follows the input equalization observer node
and the weight equalization observer
"""

-# Find the op node that comes directly after the input equaliation observer
+# Find the op node that comes directly after the input equalization observer
op_node = None
for user in input_eq_obs_node.users.keys():
if node_supports_equalization(user, modules):
4 changes: 2 additions & 2 deletions torch/ao/quantization/fx/_lower_to_native_backend.py
@@ -755,7 +755,7 @@ def _lower_weight_only_weighted_ref_module(model: GraphModule):
# TODO: maybe define a WeightedWeightOnlyQuantizedModule
q_module = q_class.from_reference(ref_module) # type: ignore[union-attr]

-# replace reference moduel with dynamically quantized module
+# replace reference module with dynamically quantized module
parent_name, module_name = _parent_name(ref_node.target)
setattr(named_modules[parent_name], module_name, q_module)

@@ -981,7 +981,7 @@ def _lower_quantized_binary_op(
assert bop_node.target in QBIN_OP_MAPPING
binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING
qbin_op = binop_to_qbinop[bop_node.target]
-# prepare the args for quantized bianry op
+# prepare the args for quantized binary op
# (x, y)
qop_node_args = list(bop_node.args)
# (x, y, scale, zero_point)