Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into const_quantization
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c committed Apr 17, 2024
2 parents 077f37c + 6cbffa7 commit 168ff53
Show file tree
Hide file tree
Showing 12 changed files with 261 additions and 13 deletions.
@@ -0,0 +1,48 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode


def remove_identity_node(graph: Graph,
                         node: BaseNode) -> Graph:
    """
    Bypass an identity node and drop it from the graph.

    The outgoing edges of the node are re-attached to its single predecessor,
    the edge leading from that predecessor into the node is deleted, and then
    the node itself is removed. If the node does not have exactly one
    predecessor, nothing is changed.

    Args:
        graph: The graph of operations that contains the node.
        node: The `BaseNode` that was matched as an Identity operation.

    Returns:
        Graph: The graph that was passed in, after the identity node was removed
        (or unchanged when the node does not have exactly one predecessor).
    """
    predecessors = graph.get_prev_nodes(node)

    # Only a node with a single input can be bypassed.
    if len(predecessors) == 1:
        source_node = predecessors[0]
        # Let the predecessor feed every consumer of the identity node directly.
        graph.reconnect_out_edges(current_node=node, new_node=source_node)
        # Detach the identity node from its predecessor.
        graph.remove_edge(source_node, node)
        # Remove the detached identity node from the graph.
        graph.remove_node(node_to_remove=node)

    return graph
@@ -0,0 +1,51 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import keras
import tensorflow as tf

from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.substitutions.remove_identity import remove_identity_node


class RemoveIdentity(common.BaseSubstitution):
    """
    Substitution that removes Identity operations from a Keras graph.
    """

    def __init__(self):
        # Match both the Keras Identity layer and the tf.identity function.
        layer_matcher = NodeOperationMatcher(keras.layers.Identity)
        function_matcher = NodeOperationMatcher(tf.identity)
        super().__init__(matcher_instance=layer_matcher | function_matcher)

    def substitute(self,
                   graph: Graph,
                   node: BaseNode) -> Graph:
        """
        Remove a matched identity node by delegating to `remove_identity_node`.

        Args:
            graph: The graph of operations that contains the node.
            node: The `BaseNode` that was matched as an Identity operation.

        Returns:
            Graph: The graph returned by `remove_identity_node`.
        """
        updated_graph = remove_identity_node(graph, node)
        return updated_graph

4 changes: 3 additions & 1 deletion model_compression_toolkit/core/keras/keras_implementation.py
Expand Up @@ -22,6 +22,7 @@

from model_compression_toolkit.constants import HESSIAN_NUM_ITERATIONS
from model_compression_toolkit.core.common.hessian import TraceHessianRequest, HessianMode, HessianInfoService
from model_compression_toolkit.core.keras.graph_substitutions.substitutions.remove_identity import RemoveIdentity
from model_compression_toolkit.core.keras.hessian.activation_trace_hessian_calculator_keras import \
ActivationTraceHessianCalculatorKeras
from model_compression_toolkit.core.keras.hessian.weights_trace_hessian_calculator_keras import WeightsTraceHessianCalculatorKeras
Expand Down Expand Up @@ -246,7 +247,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
MatmulToDenseSubstitution(),
MultiHeadAttentionDecomposition(),
ActivationDecomposition(),
DwconvToConv()]
DwconvToConv(),
RemoveIdentity()]

def get_substitutions_pre_statistics_collection(self, quant_config: QuantizationConfig) -> \
List[common.BaseSubstitution]:
Expand Down
@@ -0,0 +1,50 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from model_compression_toolkit.core.common.substitutions.remove_identity import remove_identity_node
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode


class RemoveIdentity(common.BaseSubstitution):
    """
    Substitution that removes `torch.nn.Identity` layers from the graph.
    """

    def __init__(self):
        # Only nodes whose operation is torch.nn.Identity are matched.
        identity_matcher = NodeOperationMatcher(torch.nn.Identity)
        super().__init__(matcher_instance=identity_matcher)

    def substitute(self,
                   graph: Graph,
                   node: BaseNode) -> Graph:
        """
        Remove a matched `torch.nn.Identity` node by delegating to
        `remove_identity_node`.

        Args:
            graph: The graph of operations that contains the node.
            node: The `BaseNode` that was matched as an Identity operation.

        Returns:
            Graph: The graph returned by `remove_identity_node`.
        """
        updated_graph = remove_identity_node(graph, node)
        return updated_graph


Expand Up @@ -58,6 +58,7 @@
FunctionalConvSubstitution
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
ReLUBoundToPowerOfTwo
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.remove_identity import RemoveIdentity
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.reshape_with_static_shapes import \
ReshapeWithStaticShapes
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.residual_collapsing import \
Expand Down Expand Up @@ -238,7 +239,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
PermuteCallMethod(),
FunctionalConvSubstitution(fw_info),
FunctionalBatchNorm(),
FunctionalLayerNorm()]
FunctionalLayerNorm(),
RemoveIdentity()]

def get_substitutions_pre_statistics_collection(self,
quant_config: QuantizationConfig
Expand Down
@@ -0,0 +1,35 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
import keras
import tensorflow as tf

class RemoveIdentityTest(BaseKerasFeatureNetworkTest):
    """
    Feature test for a Keras model that contains Identity operations.

    The float model is Conv2D -> keras Identity layer -> tf.identity ->
    BatchNormalization; the comparison checks the layer count of the
    quantized model.
    """

    def __init__(self, unit_test):
        super().__init__(unit_test)

    def create_networks(self):
        """
        Build the float model under test.

        Returns:
            A `keras.Model` with a convolution followed by two identity
            operations (layer and function forms) and a batch normalization.
        """
        inputs = keras.layers.Input(shape=self.get_input_shapes()[0][1:])
        x = keras.layers.Conv2D(3, 3)(inputs)
        x = keras.layers.Identity()(x)
        x = tf.identity(x)
        outputs = keras.layers.BatchNormalization()(x)
        return keras.Model(inputs=inputs, outputs=outputs)

    def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
        """
        Check that the quantized model has the expected number of layers.

        Args:
            quantized_model: The model produced by the quantization flow.
            float_model: The original float model (unused here).
            input_x: Input data (unused here).
            quantization_info: Quantization information (unused here).
        """
        # Make sure identity and bn layers are not in the final model.
        # There should be 4 layers: input, input_quantizer, conv, conv_quantizer.
        # assertEqual (rather than assertTrue on a comparison) reports the actual
        # layer count when the check fails.
        self.unit_test.assertEqual(len(quantized_model.layers), 4)

Expand Up @@ -99,6 +99,7 @@
QuantizationAwareTrainingQuantizerHolderTest
from tests.keras_tests.feature_networks_tests.feature_networks.relu_replacement_test import ReluReplacementTest, \
SingleReluReplacementTest, ReluReplacementWithAddBiasTest
from tests.keras_tests.feature_networks_tests.feature_networks.remove_identity_test import RemoveIdentityTest
from tests.keras_tests.feature_networks_tests.feature_networks.residual_collapsing_test import ResidualCollapsingTest1, \
ResidualCollapsingTest2
from tests.keras_tests.feature_networks_tests.feature_networks.reused_layer_mixed_precision_test import \
Expand Down Expand Up @@ -140,7 +141,10 @@


class FeatureNetworkTest(unittest.TestCase):


def test_remove_identity(self):
RemoveIdentityTest(self).run_test()

def test_per_tensor_weight_quantization(self):
PerTensorWeightQuantizationTest(self).run_test()

Expand Down
@@ -1,3 +1,5 @@
import copy

import unittest
import torch
from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
Expand All @@ -23,11 +25,9 @@ def __init__(self, num_channels=3, kernel_size=1):
super(BasicModel, self).__init__()
self.conv1 = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False)
self.conv2 = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False)
self.identity = torch.nn.Identity()

def forward(self, inp):
x = self.conv1(inp)
x = self.identity(x)
x = self.conv2(x)
return x

Expand All @@ -51,11 +51,9 @@ class ReuseModel(torch.nn.Module):
def __init__(self, num_channels=3, kernel_size=1):
super(ReuseModel, self).__init__()
self.conv = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False)
self.identity = torch.nn.Identity()

def forward(self, inp):
x = self.conv(inp)
x = self.identity(x)
x = self.conv(x)
return x

Expand All @@ -73,7 +71,7 @@ def test_adding_holder_instead_quantize_wrapper(self):
# the last module should be an activation quantization holder
self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
# check that 3 activation quantization holders were generated
self.assertTrue(len(activation_quantization_holders_in_model) == 4)
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
for name, module in gptq_model.named_modules():
Expand Down Expand Up @@ -102,7 +100,7 @@ def test_adding_holders_after_reuse(self):
# the last module should be an activation quantization holder
self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
# check that 3 activation quantization holders were generated
self.assertTrue(len(activation_quantization_holders_in_model) == 4)
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
for name, module in gptq_model.named_modules():
Expand All @@ -115,13 +113,16 @@ def test_adding_holders_after_reuse(self):

def _get_gptq_model(self, input_shape, in_model):
pytorch_impl = GPTQPytorchImplemantation()
qc = copy.deepcopy(mct.core.DEFAULTCONFIG)
qc.linear_collapsing = False
graph = prepare_graph_with_quantization_parameters(in_model,
pytorch_impl,
DEFAULT_PYTORCH_INFO,
representative_dataset,
generate_pytorch_tpc,
[1] + input_shape,
mixed_precision_enabled=False)
mixed_precision_enabled=False,
qc=qc)
graph = set_bit_widths(mixed_precision_enable=False,
graph=graph)
trainer = PytorchGPTQTrainer(graph,
Expand Down
Expand Up @@ -28,7 +28,7 @@ def __init__(self):

def forward(self, x):
x1 = self.conv1(x)
x2 = self.identity(x1)
x2 = torch.relu(x1)
x3 = self.conv2(x2)
x4 = torch.relu(x3)
return x, x1, x2, x3, x4
Expand Down
Expand Up @@ -38,13 +38,12 @@ def __init__(self):
self.conv3 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv4 = Conv2d(3, 3, kernel_size=1, stride=1)
self.relu2 = ReLU()
self.identity = torch.nn.Identity()

def forward(self, inp):
x = self.conv1(inp)
x = self.relu1(x)
x = self.conv2(x)
x = self.identity(x)
x = relu(x)
x = self.conv3(x)
x = relu6(x)
x = self.conv4(x)
Expand Down
@@ -0,0 +1,50 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest


class RemoveIdentityNet(torch.nn.Module):
    """
    Small network with an Identity layer placed between a convolution and a
    batch normalization: conv1 -> identity -> bn1.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the order they are applied in forward().
        self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
        self.identity = torch.nn.Identity()
        self.bn1 = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        """
        Apply the convolution, the identity layer and the batch normalization.

        Args:
            x: Input tensor passed to `conv1`.

        Returns:
            The output of `bn1`.
        """
        conv_out = self.conv1(x)
        passthrough = self.identity(conv_out)
        return self.bn1(passthrough)


class RemoveIdentityTest(BasePytorchFeatureNetworkTest):
    """
    Feature test checking that the quantized version of `RemoveIdentityNet`
    contains neither Identity nor BatchNorm2d modules.
    """

    def __init__(self, unit_test):
        super().__init__(unit_test)

    def create_networks(self):
        """
        Returns:
            A new `RemoveIdentityNet` instance used as the float model.
        """
        return RemoveIdentityNet()

    def compare(self,
                quantized_model,
                float_model,
                input_x=None,
                quantization_info=None):
        """
        Assert that no module of the quantized model is an Identity or a
        BatchNorm2d.

        Args:
            quantized_model: The model produced by the quantization flow.
            float_model: The original float model (unused here).
            input_x: Input data (unused here).
            quantization_info: Quantization information (unused here).
        """
        unexpected_types = (torch.nn.Identity, torch.nn.BatchNorm2d)
        for _, module in quantized_model.named_modules():
            # make sure identity was removed and bn was folded into the conv
            self.unit_test.assertFalse(isinstance(module, unexpected_types))

0 comments on commit 168ff53

Please sign in to comment.