Save torch encodings with 1.0.0 export
Signed-off-by: Kevin Hsieh <quic_klhsieh@quicinc.com>
quic-klhsieh committed May 22, 2024
1 parent f7f72d8 commit dd13da4
Showing 2 changed files with 69 additions and 9 deletions.
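Context for the diff below: previously, _export_encodings_to_files wrote the PyTorch-format encodings file (<prefix>_torch.encodings) only on the 0.6.1 branch of the encoding-version check; the 1.0.0 branch delegated to _export_to_1_0_0 and skipped it. This commit hoists that export out of the if/else so both encoding versions produce the file. A minimal sketch of the resulting behavior — assuming the version switch lives at aimet_common.quantsim.encoding_version, which the diff's quantsim.encoding_version reference suggests, and using a hypothetical toy model:

import os
import tempfile

import torch
from aimet_common import quantsim as common_quantsim  # assumed home of encoding_version
from aimet_torch.v2.quantsim import QuantizationSimModel

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
dummy_input = torch.randn(1, 8)

qsim = QuantizationSimModel(model, dummy_input)
qsim.compute_encodings(lambda m, _: m(dummy_input), None)

common_quantsim.encoding_version = '1.0.0'  # the path that previously skipped torch encodings
with tempfile.TemporaryDirectory() as tmp:
    qsim.export(tmp, 'model', dummy_input=dummy_input)
    assert os.path.isfile(os.path.join(tmp, 'model.encodings'))        # base encodings file
    assert os.path.isfile(os.path.join(tmp, 'model_torch.encodings'))  # now written for 1.0.0 too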
20 changes: 12 additions & 8 deletions TrainingExtensions/torch/src/python/aimet_torch/quantsim.py
@@ -984,26 +984,30 @@ def _export_encodings_to_files(sim_model: torch.nn.Module, path: str, filename_p
                                    'param_encodings': param_encodings,
                                    'excluded_layers': excluded_layer_names}
 
-            encodings_dict_pytorch = {'version': quantsim.encoding_version,
-                                      'activation_encodings': activation_encodings_torch,
-                                      'param_encodings': param_encodings,
-                                      'excluded_layers': excluded_layer_names}
-
             if quantizer_args:
-                encodings_dict_pytorch.update({'quantizer_args': quantizer_args})
                 encodings_dict_onnx.update({'quantizer_args': quantizer_args})
 
             logger.info("Layers excluded from quantization: %s", excluded_layer_names)
 
             # export weight encodings to output json file
             encoding_file_path = os.path.join(path, filename_prefix + '.encodings')
-            encoding_file_path_pytorch = os.path.join(path, filename_prefix + '_torch' + '.encodings')
             save_json_yaml(encoding_file_path, encodings_dict_onnx)
-            save_json_yaml(encoding_file_path_pytorch, encodings_dict_pytorch)
         else:
             _export_to_1_0_0(path, filename_prefix, activation_encodings_onnx, param_encodings, tensor_to_quantizer_map,
                              excluded_layer_names, quantizer_args)
+
+        # Export torch.encodings used for saving/loading common to 0.6.1 and 1.0.0 versions
+        encodings_dict_pytorch = {'version': quantsim.encoding_version,
+                                  'activation_encodings': activation_encodings_torch,
+                                  'param_encodings': param_encodings,
+                                  'excluded_layers': excluded_layer_names}
+
+        if quantizer_args:
+            encodings_dict_pytorch.update({'quantizer_args': quantizer_args})
+
+        encoding_file_path_pytorch = os.path.join(path, filename_prefix + '_torch' + '.encodings')
+        save_json_yaml(encoding_file_path_pytorch, encodings_dict_pytorch)
 
     @staticmethod
     def _get_tensor_to_consumer_map(op_to_io_tensor_map: Dict[str, Dict]) -> Dict[str, str]:
         """
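For reference, the <prefix>_torch.encodings file written by the new common path is plain JSON holding the four keys assembled into encodings_dict_pytorch above, plus quantizer_args when supplied. A hedged sketch of the payload shape; the per-tensor entry fields and values are illustrative, not taken from a real export:

# Illustrative structure only; actual values come from compute_encodings.
torch_encodings = {
    'version': '1.0.0',                # quantsim.encoding_version at export time
    'activation_encodings': {          # keyed by torch module names
        'fc': [{'bitwidth': 8, 'dtype': 'int', 'is_symmetric': 'False',
                'min': -2.5, 'max': 2.5, 'offset': -128, 'scale': 0.0196}],
    },
    'param_encodings': {               # keyed by parameter names, e.g. 'fc.weight'
        'fc.weight': [{'bitwidth': 4, 'dtype': 'int', 'is_symmetric': 'True',
                       'min': -1.0, 'max': 0.875, 'offset': -8, 'scale': 0.125}],
    },
    'excluded_layers': [],             # names logged by "Layers excluded from quantization"
    # 'quantizer_args': {...}          # present only when quantizer_args is truthy
}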
58 changes: 57 additions & 1 deletion (second changed file — a test module; file path not shown in this view)
@@ -44,7 +44,7 @@
 from aimet_torch.v2.quantsim import QuantizationSimModel
 from aimet_torch.v2.quantization.encoding_analyzer import PercentileEncodingAnalyzer
 from aimet_torch.v2.quantization.base import QuantizerBase
-from aimet_torch.v2.quantization.affine import AffineQuantizerBase
+from aimet_torch.v2.quantization.affine import AffineQuantizerBase, GroupedBlockQuantizeDequantize
 from aimet_torch.v2.nn import BaseQuantizationMixin
 from ..models_ import test_models
 
@@ -719,6 +719,62 @@ def test_legacy_load_encodings_to_disabled_quantizer(self, load_encodings_fn):
         with pytest.raises(RuntimeError):
             load_encodings_fn(qsim, fname)
 
+    def test_save_and_load_gbbq(self):
+        torch.manual_seed(0)
+        model = test_models.SingleResidualWithAvgPool()
+        dummy_input = torch.randn(1, 3, 28, 28)
+        dummy_input_2 = torch.randn(1, 3, 28, 28)
+        qsim = QuantizationSimModel(model, dummy_input)
+        qsim.model.fc.param_quantizers['weight'] = GroupedBlockQuantizeDequantize(shape=(10, 6),
+                                                                                  bitwidth=4,
+                                                                                  symmetric=True,
+                                                                                  decompressed_bw=8,
+                                                                                  block_size=(1, 12),
+                                                                                  block_grouping=(1, 6))
+        qsim.compute_encodings(lambda m, _: m(dummy_input), None)
+        out1 = qsim.model(dummy_input)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            qsim.save_encodings_to_json(temp_dir, 'saved_encodings')
+            qsim.export(temp_dir, 'exported_encodings', dummy_input=dummy_input)
+
+            with open(os.path.join(temp_dir, 'saved_encodings.json'), 'r') as enc_file:
+                encodings = json.load(enc_file)
+
+            assert len(encodings['param_encodings']['fc.weight']) == 60
+
+            with open(os.path.join(temp_dir, 'exported_encodings_torch.encodings'), 'r') as enc_file:
+                encodings = json.load(enc_file)
+
+            assert len(encodings['param_encodings']['fc.weight']) == 60
+
+            old_weight = qsim.model.fc.weight
+            old_max = qsim.model.fc.param_quantizers['weight'].get_max()[0][0]
+            qsim.model.fc.weight = torch.nn.Parameter(torch.randn(old_weight.shape))
+            qsim.compute_encodings(lambda m, _: m(dummy_input_2), None)
+            assert qsim.model.fc.param_quantizers['weight'].get_max()[0][0] != old_max
+            out2 = qsim.model(dummy_input)
+
+            assert not torch.equal(out1, out2)
+
+            # Test loading of encodings saved using save_encodings_to_json
+            qsim.model.fc.weight = old_weight
+            qsim.load_encodings(os.path.join(temp_dir, 'saved_encodings.json'))
+
+            assert qsim.model.fc.param_quantizers['weight'].get_max()[0][0] == old_max
+            out3 = qsim.model(dummy_input)
+            assert torch.equal(out1, out3)
+
+            qsim.model.fc.weight = torch.nn.Parameter(torch.randn(old_weight.shape))
+            qsim.compute_encodings(lambda m, _: m(dummy_input_2), None)
+
+            # Test loading of encodings from sim.export
+            qsim.model.fc.weight = old_weight
+            qsim.load_encodings(os.path.join(temp_dir, 'exported_encodings_torch.encodings'))
+
+            out4 = qsim.model(dummy_input)
+            assert torch.equal(out1, out4)
+
+
 class TestQuantsimUtilities:
 
     def test_populate_marker_map(self):
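A note on the == 60 assertions in the new test above: they follow from the quantizer construction alone. shape=(10, 6) yields 10 * 6 = 60 encoding blocks for fc.weight, and block_size=(1, 12) implies the quantized weight tensor has shape (10, 72). A standalone check of that arithmetic (plain Python, no AIMET imports needed):

quantizer_shape = (10, 6)  # shape= argument to GroupedBlockQuantizeDequantize
block_size = (1, 12)       # block_size= argument

num_blocks = quantizer_shape[0] * quantizer_shape[1]
weight_shape = tuple(s * b for s, b in zip(quantizer_shape, block_size))

assert num_blocks == 60          # matches len(encodings['param_encodings']['fc.weight'])
assert weight_shape == (10, 72)  # implied shape of model.fc.weight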
