Skip to content

Commit

Permalink
switch default IMX500 TPC to v2
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c committed Apr 16, 2024
1 parent 932a8cb commit 553ce03
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 13 deletions.
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.
# ==============================================================================
from model_compression_toolkit.constants import FOUND_TF, FOUND_TORCH
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tp_model import get_tp_model, generate_tp_model, \
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tp_model import get_tp_model, generate_tp_model, \
get_op_quantization_configs
if FOUND_TF:
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import get_keras_tpc as get_keras_tpc_latest
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_keras import generate_keras_tpc
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import get_keras_tpc as get_keras_tpc_latest
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_keras import generate_keras_tpc
if FOUND_TORCH:
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import get_pytorch_tpc as \
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import get_pytorch_tpc as \
get_pytorch_tpc_latest
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v1.tpc_pytorch import generate_pytorch_tpc
from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v2.tpc_pytorch import generate_pytorch_tpc
Expand Up @@ -19,6 +19,7 @@

import model_compression_toolkit as mct
from sony_custom_layers.keras.object_detection.ssd_post_process import SSDPostProcess
from mct_quantizers.keras.metadata import MetadataLayer

keras = tf.keras
layers = keras.layers
Expand Down Expand Up @@ -56,7 +57,8 @@ def test_custom_layer(self):
target_resource_utilization=mct.core.ResourceUtilization(weights_memory=6000))

# verify the custom layer is in the quantized model
self.assertTrue(isinstance(q_model.layers[-1], SSDPostProcess), 'Custom layer should be in the quantized model')
last_model_layer_index = -2 if isinstance(q_model.layers[-1], MetadataLayer) else -1
self.assertTrue(isinstance(q_model.layers[last_model_layer_index], SSDPostProcess), 'Custom layer should be in the quantized model')
# verify mixed-precision
self.assertTrue(any([q_model.layers[2].weights_quantizers['kernel'].num_bits < 8,
q_model.layers[4].weights_quantizers['kernel'].num_bits < 8]))
Expand Up @@ -28,6 +28,7 @@

import numpy as np
from tests.common_tests.helpers.tensors_compare import cosine_similarity
from mct_quantizers.keras.metadata import MetadataLayer

keras = tf.keras
layers = keras.layers
Expand Down Expand Up @@ -73,7 +74,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
self.unit_test.assertFalse(isinstance(l.layer, Functional) or isinstance(l.layer, Sequential))
else:
self.unit_test.assertFalse(isinstance(l, Functional) or isinstance(l, Sequential))
num_layers = 8
num_layers = 8 + int(isinstance(quantized_model.layers[-1], MetadataLayer))
num_fq_layers = 7
self.unit_test.assertTrue(len(quantized_model.layers) == (num_layers+num_fq_layers))
y = float_model.predict(input_x)
Expand Down
Expand Up @@ -25,6 +25,7 @@
import model_compression_toolkit as mct
import tensorflow as tf
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
from mct_quantizers.keras.metadata import MetadataLayer


import numpy as np
Expand Down Expand Up @@ -88,6 +89,7 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
num_layers = 8
else:
num_layers = 5
num_layers = num_layers + int(isinstance(quantized_model.layers[-1], MetadataLayer))
self.unit_test.assertTrue(len(quantized_model.layers) == num_layers)
y = float_model.predict(input_x)
y_hat = quantized_model.predict(input_x)
Expand Down
10 changes: 9 additions & 1 deletion tests/keras_tests/non_parallel_tests/test_keras_tp_model.py
Expand Up @@ -265,20 +265,28 @@ def rep_data():

def test_get_keras_supported_version(self):
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL) # Latest
self.assertTrue(tpc.version == 'v1')
self.assertTrue(tpc.version == 'v2')

tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v1_pot')
self.assertTrue(tpc.version == 'v1_pot')
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v1_lut')
self.assertTrue(tpc.version == 'v1_lut')
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v1')
self.assertTrue(tpc.version == 'v1')
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v2_lut')
self.assertTrue(tpc.version == 'v2_lut')
tpc = mct.get_target_platform_capabilities(TENSORFLOW, DEFAULT_TP_MODEL, 'v2')
self.assertTrue(tpc.version == 'v2')

tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v1")
self.assertTrue(tpc.version == 'v1')
tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v2")
self.assertTrue(tpc.version == 'v2')

tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v1_lut")
self.assertTrue(tpc.version == 'v1_lut')
tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v2_lut")
self.assertTrue(tpc.version == 'v2_lut')

tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, "v1_pot")
self.assertTrue(tpc.version == 'v1_pot')
Expand Down
Expand Up @@ -72,8 +72,8 @@ def test_adding_holder_instead_quantize_wrapper(self):
activation_quantization_holders_in_model = [m[1] for m in gptq_model.named_modules() if isinstance(m[1], PytorchActivationQuantizationHolder)]
# the last module should be an activation quantization holder
self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
# check that 4 activation quantization holders where generated
self.assertTrue(len(activation_quantization_holders_in_model) == 4)
# check that 3 activation quantization holders where generated
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
for name, module in gptq_model.named_modules():
Expand Down Expand Up @@ -101,8 +101,8 @@ def test_adding_holders_after_reuse(self):
last_module = list(gptq_model.named_modules())[-1][1]
# the last module should be an activation quantization holder
self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
# check that 4 activation quantization holders where generated
self.assertTrue(len(activation_quantization_holders_in_model) == 4)
# check that 3 activation quantization holders where generated
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
for name, module in gptq_model.named_modules():
Expand Down
7 changes: 6 additions & 1 deletion tests/pytorch_tests/function_tests/test_pytorch_tp_model.py
Expand Up @@ -252,12 +252,17 @@ def rep_data():

def test_get_pytorch_supported_version(self):
tpc = mct.get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL) # Latest
self.assertTrue(tpc.version == 'v1')
self.assertTrue(tpc.version == 'v2')

tpc = mct.get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL, 'v1')
self.assertTrue(tpc.version == 'v1')
tpc = mct.get_target_platform_capabilities(PYTORCH, DEFAULT_TP_MODEL, 'v2')
self.assertTrue(tpc.version == 'v2')

tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, "v1")
self.assertTrue(tpc.version == 'v1')
tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, "v2")
self.assertTrue(tpc.version == 'v2')

tpc = mct.get_target_platform_capabilities(PYTORCH, TFLITE_TP_MODEL, "v1")
self.assertTrue(tpc.version == 'v1')
Expand Down

0 comments on commit 553ce03

Please sign in to comment.