[PT] change api for get_config and load_from_config #3359

Merged
3 changes: 2 additions & 1 deletion README.md
@@ -245,6 +245,7 @@ Here is an example of Accuracy Aware Quantization pipeline where model weights a

```python
import nncf
+import nncf.torch
import torch
from torchvision import datasets, models

@@ -271,7 +272,7 @@ quantized_model = nncf.quantize(model, calibration_dataset)
# Save quantization modules and the quantized model parameters
checkpoint = {
    'state_dict': model.state_dict(),
-    'nncf_config': model.nncf.get_config(),
+    'nncf_config': nncf.torch.get_config(model),
    ... # the rest of the user-defined objects to save
}
torch.save(checkpoint, path_to_checkpoint)
@@ -75,17 +75,19 @@ ov_quantized_model = ov.convert_model(stripped_model)

The complete information about compression is defined by a compressed model and an NNCF config.
The model characterizes the weights and topology of the network; the NNCF config describes how to restore additional modules introduced by NNCF.
-The NNCF config can be obtained by `quantized_model.nncf.get_config()` on saving and passed to the
+The NNCF config can be obtained by `nncf.torch.get_config` on saving and passed to the
`nncf.torch.load_from_config` helper function to load additional modules from the given NNCF config.
Saving the quantized model this way makes it possible to load the quantized modules into the target model in a new Python process;
it requires only an example input for the target module, the corresponding NNCF config, and the quantized model state dict.

```python
+import nncf.torch
+
# save part
quantized_model = nncf.quantize(model, calibration_dataset)

checkpoint = {
-    'state_dict':quantized_model.state_dict(),
-    'nncf_config': quantized_model.nncf.get_config(),
+    'state_dict': quantized_model.state_dict(),
+    'nncf_config': nncf.torch.get_config(quantized_model),
    ...
}
torch.save(checkpoint, path)
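
For the load side described above, a minimal sketch (assuming `model`, `example_input`, and `path` are the user's own objects from the save snippet) could look like:

```python
import nncf.torch
import torch

# Load part: recover the additional NNCF modules first, then their weights.
resuming_checkpoint = torch.load(path)

quantized_model = nncf.torch.load_from_config(
    model, resuming_checkpoint['nncf_config'], example_input=example_input
)
quantized_model.load_state_dict(resuming_checkpoint['state_dict'])
```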
4 changes: 2 additions & 2 deletions examples/quantization_aware_training/torch/resnet18/main.py
@@ -278,11 +278,11 @@ def transform_fn(data_item):
print(f"Train epoch: {epoch}")
train_epoch(train_loader, quantized_model, criterion, optimizer, device=device)
acc1_int8 = validate(val_loader, quantized_model, device)
print(f"Accyracy@1 of INT8 model after {epoch} epoch finetuning: {acc1_int8:.3f}")
print(f"Accuracy@1 of INT8 model after {epoch} epoch finetuning: {acc1_int8:.3f}")
# Save the compression checkpoint for model with the best accuracy metric.
if acc1_int8 > acc1_int8_best:
state_dict = quantized_model.state_dict()
compression_config = quantized_model.nncf.get_config()
compression_config = nncf.torch.get_config(quantized_model)
torch.save(
{
"model_state_dict": state_dict,
1 change: 1 addition & 0 deletions nncf/torch/__init__.py
@@ -50,6 +50,7 @@
from nncf.torch.model_creation import is_wrapped_model
from nncf.torch.model_creation import wrap_model
from nncf.torch.model_creation import load_from_config
+from nncf.torch.model_creation import get_config
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.initialization import register_default_init_args
from nncf.torch.layers import register_module
35 changes: 31 additions & 4 deletions nncf/torch/model_creation.py
@@ -28,6 +28,8 @@
from nncf.config.extractors import has_input_info_field
from nncf.config.telemetry_extractors import CompressionStartedFromConfig
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled
+from nncf.experimental.torch2.function_hook.serialization import get_config as pt2_get_config
+from nncf.experimental.torch2.function_hook.serialization import load_from_config as pt2_load_from_config
from nncf.telemetry import tracked_function
from nncf.telemetry.events import NNCF_PT_CATEGORY
from nncf.telemetry.extractors import FunctionCallTelemetryExtractor
@@ -397,18 +399,43 @@ def is_wrapped_model(model: Any) -> bool:
        FunctionCallTelemetryExtractor("nncf.torch.load_from_config"),
    ],
)
-def load_from_config(model: torch.nn.Module, config: Dict[str, Any], example_input: Any) -> NNCFNetwork:
+def load_from_config(model: Module, config: Dict[str, Any], example_input: Optional[Any] = None) -> Module:
    """
-    Wraps given model to a NNCFNetwork and recovers additional modules from given NNCFNetwork config.
+    Wraps the given model and recovers additional modules from the given config.
    Does not recover the weights of the additional modules, as they are located in the corresponding state_dict.

    :param model: PyTorch model.
    :param config: NNCFNetwork config.
    :param example_input: An example input that will be used for model tracing. A tuple is interpreted
        as an example input of a set of non-keyword arguments, and a dict as an example input of a set
-        of keyword arguments.
-    :return: NNCFNetwork builded from given model with additional modules recovered from given NNCFNetwork config.
+        of keyword arguments. Required when the legacy tracing mode is enabled.
+    :return: Wrapped model with additional modules recovered from the given config.
    """
+    if is_experimental_torch_tracing_enabled():
+        return pt2_load_from_config(model, config)
+
+    if example_input is None:
+        msg = "The 'example_input' parameter must be specified."
+        raise nncf.InternalError(msg)
+
    nncf_network = wrap_model(model, example_input, trace_parameters=config[NNCFNetwork.TRACE_PARAMETERS_KEY])
    transformation_layout = deserialize_transformations(config)
    return PTModelTransformer(nncf_network).transform(transformation_layout)
+
+
+@tracked_function(
+    NNCF_PT_CATEGORY,
+    [
+        FunctionCallTelemetryExtractor("nncf.torch.get_config"),
+    ],
+)
+def get_config(model: Module) -> Dict[str, Any]:
+    """
+    Returns the configuration object of the compressed model.
+
+    :param model: The compressed model.
+    :return: The configuration object of the compressed model.
+    """
+    if is_experimental_torch_tracing_enabled():
+        return pt2_get_config(model)
+    return model.nncf.get_config()
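
As a usage sketch for the two new entry points (the torchvision model and input shape are hypothetical; in the legacy tracing mode `example_input` is mandatory, otherwise `load_from_config` raises `nncf.InternalError`):

```python
import nncf.torch
import torch
from torchvision import models

example_input = torch.ones(1, 3, 224, 224)  # hypothetical example input for tracing

# Serialize: get_config is now a module-level helper instead of a method on model.nncf.
model = nncf.torch.wrap_model(models.resnet18(), example_input, trace_parameters=True)
config = nncf.torch.get_config(model)

# Restore into a fresh model; example_input is required in the legacy tracing mode.
restored = nncf.torch.load_from_config(models.resnet18(), config, example_input)
```

Both helpers dispatch on `is_experimental_torch_tracing_enabled()`, so callers no longer need to know which tracing backend produced the model.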
4 changes: 2 additions & 2 deletions tests/torch2/function_hook/test_serialization.py
@@ -19,8 +19,8 @@
from nncf.experimental.torch2.function_hook import register_post_function_hook
from nncf.experimental.torch2.function_hook import register_pre_function_hook
from nncf.experimental.torch2.function_hook import wrap_model
-from nncf.experimental.torch2.function_hook.serialization import get_config
-from nncf.experimental.torch2.function_hook.serialization import load_from_config
+from nncf.torch import get_config
+from nncf.torch import load_from_config
from tests.torch2.function_hook.helpers import HookWithState
from tests.torch2.function_hook.helpers import SimpleModel