Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable FunctionHookMode tracing by default #3338

Draft
wants to merge 13 commits into
base: develop
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -75,7 +75,7 @@ ov_quantized_model = ov.convert_model(stripped_model)

The complete information about compression is defined by a compressed model and a NNCF config.
The model characterizes the weights and topology of the network. The NNCF config - how to restore additional modules introduced by NNCF.
The NNCF config can be obtained by `quantized_model.nncf.get_config()` on saving and passed to the
The NNCF config can be obtained by `nncf.torch.get_config(quantized_model)` on saving and passed to the
`nncf.torch.load_from_config` helper function to load additional modules from the given NNCF config.
The quantized model saving allows to load quantized modules to the target model in a new python process and
requires only example input for the target module, corresponding NNCF config and the quantized model state dict.
@@ -84,8 +84,8 @@ requires only example input for the target module, corresponding NNCF config and
# save part
quantized_model = nncf.quantize(model, calibration_dataset)
checkpoint = {
'state_dict':quantized_model.state_dict(),
'nncf_config': quantized_model.nncf.get_config(),
'state_dict': quantized_model.state_dict(),
'nncf_config': nncf.torch.get_config(quantized_model),
...
}
torch.save(checkpoint, path)
@@ -96,7 +96,7 @@ resuming_checkpoint = torch.load(path)
nncf_config = resuming_checkpoint['nncf_config']
state_dict = resuming_checkpoint['state_dict']

quantized_model = nncf.torch.load_from_config(model, nncf_config, dummy_input)
quantized_model = nncf.torch.load_from_config(model, nncf_config)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update the main Readme:

  • using nncf.torch.get_config() / nncf.torch.load_from_config() in Training-Time Quantization
  • add note that create_compressed API is deprecated in Training-Time Compression.

OpenVINO documentation should be also update regarding the changes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using nncf.torch.get_config() / nncf.torch.load_from_config() in Training-Time Quantization

done

add note that create_compressed API is deprecated in Training-Time Compression.

we have already have deprecation message in create_compressed sine previous release, what a reason to duplicate it in readme now?

quantized_model.load_state_dict(state_dict)
```

Original file line number Diff line number Diff line change
@@ -25,7 +25,6 @@
from PIL import Image
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models.detection.ssd import SSD
from torchvision.models.detection.ssd import GeneralizedRCNNTransform
from torchvision.models.detection.anchor_utils import DefaultBoxGenerator
from rich.progress import track
from functools import partial
@@ -141,7 +140,6 @@ def main():
model.eval()

# Disable NNCF tracing for some methods in order for the model to be properly traced by NNCF
disable_tracing(GeneralizedRCNNTransform.normalize)
disable_tracing(SSD.postprocess_detections)
disable_tracing(DefaultBoxGenerator.forward)

10 changes: 5 additions & 5 deletions examples/quantization_aware_training/torch/resnet18/main.py
Original file line number Diff line number Diff line change
@@ -34,6 +34,8 @@
import nncf
import nncf.torch
from nncf.common.utils.helpers import create_table
from nncf.torch import get_config
from nncf.torch import load_from_config

warnings.filterwarnings("ignore", category=TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
@@ -278,11 +280,11 @@ def transform_fn(data_item):
print(f"Train epoch: {epoch}")
train_epoch(train_loader, quantized_model, criterion, optimizer, device=device)
acc1_int8 = validate(val_loader, quantized_model, device)
print(f"Accyracy@1 of INT8 model after {epoch} epoch finetuning: {acc1_int8:.3f}")
print(f"Accuracy@1 of INT8 model after {epoch} epoch finetuning: {acc1_int8:.3f}")
# Save the compression checkpoint for model with the best accuracy metric.
if acc1_int8 > acc1_int8_best:
state_dict = quantized_model.state_dict()
compression_config = quantized_model.nncf.get_config()
compression_config = get_config(quantized_model)
torch.save(
{
"model_state_dict": state_dict,
@@ -294,9 +296,7 @@ def transform_fn(data_item):

# Load quantization modules and parameters from best checkpoint to the source model.
ckpt = torch.load(ROOT / BEST_CKPT_NAME, weights_only=False)
quantized_model = nncf.torch.load_from_config(
deepcopy(model), ckpt["compression_config"], torch.ones((1, 3, IMAGE_SIZE, IMAGE_SIZE)).to(device)
)
quantized_model = load_from_config(deepcopy(model), ckpt["compression_config"])
quantized_model.load_state_dict(ckpt["model_state_dict"])

# Evaluate on validation set after Quantization-Aware Training (QAT case).
10 changes: 5 additions & 5 deletions nncf/common/factory.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.data.dataset import Dataset
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode

TModel = TypeVar("TModel")

@@ -90,13 +90,13 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer[Any]:
from nncf.openvino.graph.model_transformer import OVModelTransformer

return OVModelTransformer(cast(Model, model), inplace=inplace)
if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
if model_backend == BackendType.TORCH and is_torch_tracing_by_torch_function_mode():
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
from nncf.experimental.torch2.model_transformer import PT2ModelTransformer

return PT2ModelTransformer(cast(GraphModelWrapper, model))

if model_backend == BackendType.TORCH and not is_experimental_torch_tracing_enabled():
if model_backend == BackendType.TORCH and not is_torch_tracing_by_torch_function_mode():
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork

@@ -191,11 +191,11 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
from nncf.openvino.statistics.aggregator import OVStatisticsAggregator

return OVStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH and not is_experimental_torch_tracing_enabled():
if model_backend == BackendType.TORCH and not is_torch_tracing_by_torch_function_mode():
from nncf.torch.statistics.aggregator import PTStatisticsAggregator

return PTStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
if model_backend == BackendType.TORCH and is_torch_tracing_by_torch_function_mode():
from nncf.experimental.torch2.statistics.aggregator import PT2StatisticsAggregator

return PT2StatisticsAggregator(dataset)
4 changes: 2 additions & 2 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@
from packaging import version

import nncf
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode

try:
import openvino # type: ignore # noqa: F401
@@ -58,7 +58,7 @@ def is_torch_model(model: Any) -> bool:

from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
return isinstance(model, (GraphModelWrapper, torch.nn.Module)) and not isinstance(model, torch.fx.GraphModule)

return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)
11 changes: 7 additions & 4 deletions nncf/experimental/common/check_feature.py
Original file line number Diff line number Diff line change
@@ -12,10 +12,13 @@
import os


def is_experimental_torch_tracing_enabled() -> bool:
def is_torch_tracing_by_torch_function_mode() -> bool:
"""
Checks if experimental torch tracing is enabled by environment variable NNCF_EXPERIMENTAL_TORCH_TRACING.
Checks if legacy torch tracing is enabled by environment variable NNCF_TORCH_LEGACY_TRACING.

:return: True if experimental torch tracing is enabled, False otherwise.
True - will wrap model by NNCFNetwork and patch function in torch namespace.
False - will use FunctionHookMode without patching torch namespace.

:return: True if legacy torch tracing is enabled, False otherwise.
"""
return os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is not None
return os.getenv("NNCF_TORCH_LEGACY_TRACING", "").lower() not in ["1", "on", "true"]
15 changes: 15 additions & 0 deletions nncf/experimental/torch2/function_hook/hook_executor_mode.py
Original file line number Diff line number Diff line change
@@ -28,6 +28,7 @@
from torch import Tensor
from torch import nn
from torch.overrides import TorchFunctionMode
from torch.overrides import _get_current_function_mode_stack

from nncf.common.logging import nncf_logger as logger
from nncf.experimental.torch2.function_hook.handle_inner_functions import get_handle_inner_function
@@ -517,3 +518,17 @@ def disable(self) -> Iterator[None]:
self.enabled = False
yield
self.enabled = ret


@contextmanager
def disable_function_hook_mode() -> Iterator[None]:
"""
Temporarily disables the function tracing and execution hooks within a context.
"""
enabled_modes = _get_current_function_mode_stack() # type: ignore[no-untyped-call]
state = {(mode, mode.enabled) for mode in enabled_modes if isinstance(mode, FunctionHookMode)}
for mode, _ in state:
mode.enabled = False
yield
for mode, enabled in state:
mode.enabled = enabled
11 changes: 11 additions & 0 deletions nncf/experimental/torch2/function_hook/serialization.py
Original file line number Diff line number Diff line change
@@ -15,10 +15,13 @@
from torch import nn

import nncf
from nncf.common.logging import nncf_logger
from nncf.experimental.torch2.function_hook.wrapper import get_hook_storage
from nncf.experimental.torch2.function_hook.wrapper import wrap_model
from nncf.torch.layer_utils import COMPRESSION_MODULES
from nncf.torch.layer_utils import StatefulModuleInterface
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_multidevice

COMPRESSION_STATE_ATTR = "compression_state"
TModel = TypeVar("TModel", bound=nn.Module)
@@ -101,11 +104,19 @@ def load_from_config(model: TModel, config: Dict[str, Any]) -> TModel:
:return: The compressed model.
"""
wrapped_model = wrap_model(model)

device = None
if not is_multidevice(wrapped_model):
device = get_model_device(wrapped_model)
else:
nncf_logger.warning("Model is on multiple devices. Cannot determine device for loaded modules.")

hook_storage = get_hook_storage(wrapped_model)
transformation_commands = cast(List[S_COMMAND], config[COMPRESSION_STATE_ATTR])
for command in transformation_commands:
module_cls = COMPRESSION_MODULES.get(command["module_cls_name"])
module = module_cls.from_config(command["module_config"])
module.to(device)
Comment on lines +108 to +119
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I encountered the same problem with the model on various devices:
https://github.com/openvinotoolkit/nncf/pull/3341/files#diff-f257219eb5631518ca7ec37bb9858c77491049dc999ca14d845a46517f263900R135-R136

Is there a reliable method to match the compression module from the insertion command to the target layer and its parameters? With this matching, it would be possible to determine the device.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently we have limited support for multidevice models.
For not multi device model it works.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multidevice case is very popular for LLMs. @AlexanderDokuchaev, do you have any idea how torch2 will support it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not look at it. Currently multidevice support for new tracing works similar like for current.

for target_name in command["hook_names_in_model"]:
hook_type, hook_key, hook_id = target_name.split(".")
storage_dict = getattr(hook_storage, hook_type)
9 changes: 9 additions & 0 deletions nncf/experimental/torch2/model_transformer.py
Original file line number Diff line number Diff line change
@@ -28,6 +28,8 @@
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import set_const_data
from nncf.torch.model_graph_manager import update_fused_bias
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_multidevice

TRANSFORMATION_PAIRS = Tuple[Tuple[Type[Any], Callable[[GraphModelWrapper, List[Any]], GraphModelWrapper]], ...]

@@ -84,11 +86,18 @@ def _apply_insertion_transformations(
:param wrapped_model: Model to apply transformations.
:param command: Insertion transformation command.
"""
device = None
if not is_multidevice(self._model.model):
device = get_model_device(self._model.model)

for command in transformations:
target_points = command.target_points
hook_module = command.hook_module
handle_storage = command.handle_storage

if device is not None:
hook_module.to(device)

for target_point in target_points:
handle = insert_hook(wrapped_model.model, hook_module, target_point)
if handle_storage is not None:
8 changes: 4 additions & 4 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.hardware.config import HWConfig
from nncf.common.quantization.structs import QuantizerConfig
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
from nncf.experimental.common.tensor_statistics.collectors import REDUCERS_MAP
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
from nncf.experimental.torch2.commands import PT2InsertionCommand
@@ -144,7 +144,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION:
input_port_id = None
if (
not is_experimental_torch_tracing_enabled()
not is_torch_tracing_by_torch_function_mode()
and target_type in PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP
):
target_type = PTMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type]
@@ -271,7 +271,7 @@ def create_quantizer_insertion_command(
quantizer = PTMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_point.target_type
)
if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
return PT2InsertionCommand(target_points=[target_point], hook_module=quantizer)

return create_quantizer_insertion_command(target_point, quantizer)
@@ -290,7 +290,7 @@ def create_unified_scales_quantizers_insertion_commands(
quantizer = PTMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_points[0].target_type
)
if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
return [PT2InsertionCommand(target_points=target_points, hook_module=quantizer)]
return [create_shared_quantizer_insertion_command(target_points, quantizer)]

8 changes: 4 additions & 4 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
@@ -136,7 +136,7 @@ def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork, nncf_graph:
def weight_update_command(
node_with_weight: NNCFNode, nncf_graph: NNCFGraph, weight_value: torch.Tensor
) -> PTWeightUpdateCommand:
if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
weight_node = get_const_node(node_with_weight, node_with_weight.metatype.weight_port_ids[0], nncf_graph)
return PT2ConstUpdateCommand(weight_node, weight_value)
return create_command_to_update_weight(node_with_weight, weight_value)
@@ -157,7 +157,7 @@ def scale_insertion_command(
sq_multiply = SQMultiply(scale_value.shape)
sq_multiply.scale = scale_value

if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
return PT2InsertionCommand(target_points=target_points, hook_module=sq_multiply)
return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)

@@ -175,7 +175,7 @@ def get_weight_channel_axis(node: NNCFNode) -> int:

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
weight_node = get_const_node(node, node.metatype.weight_port_ids[0], nncf_graph)
output_edges = nncf_graph.get_next_nodes(weight_node)
return len(output_edges) > 1
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.quantization.structs import QuantizationScheme
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
from nncf.experimental.common.tensor_statistics.collectors import MaxVarianceReducer
from nncf.experimental.common.tensor_statistics.collectors import MeanAbsMaxReducer
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
@@ -232,7 +232,7 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC
def set_weight(
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph, weight: Tensor
):
if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
weight_node = get_const_node(node_with_weight, weight_port_id, graph)
module_name, weight_attr_name = split_const_name(weight_node.layer_attributes.name)
module = get_module_by_name(module_name, model.model)
@@ -431,7 +431,7 @@ def get_dq_insertion_command(
weight.requires_grad = False
weight.data = packed_tensor

if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
return PT2InsertionCommand(
[
PTTargetPoint(
@@ -542,7 +542,7 @@ def scale_insertion_command(
sq_multiply = SQMultiply(scale.shape)
sq_multiply.scale = scale

if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
return PT2InsertionCommand(target_points, sq_multiply)
scale_node_name = f"{source_node.node_name}/awq_mul"
return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)
4 changes: 2 additions & 2 deletions nncf/quantization/quantize_model.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.data import Dataset
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
from nncf.experimental.common.check_feature import is_torch_tracing_by_torch_function_mode
from nncf.parameters import BackupMode
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
@@ -232,7 +232,7 @@ def quantize(
)

if backend == BackendType.TORCH:
if is_experimental_torch_tracing_enabled():
if is_torch_tracing_by_torch_function_mode():
from nncf.experimental.torch2.quantization.quantize_model import quantize_impl
else:
from nncf.torch.quantization.quantize_model import quantize_impl
1 change: 1 addition & 0 deletions nncf/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -50,6 +50,7 @@
from nncf.torch.model_creation import is_wrapped_model
from nncf.torch.model_creation import wrap_model
from nncf.torch.model_creation import load_from_config
from nncf.experimental.torch2.function_hook.serialization import get_config
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.initialization import register_default_init_args
from nncf.torch.layers import register_module
Loading
Oops, something went wrong.