Skip to content

Commit

Permalink
add unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Apr 16, 2024
1 parent 78928c8 commit af4f8e0
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
@@ -0,0 +1,50 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest


class RemoveIdentityNet(torch.nn.Module):
def __init__(self):
super(RemoveIdentityNet, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
self.identity = torch.nn.Identity()
self.bn1 = torch.nn.BatchNorm2d(3)

def forward(self, x):
x = self.conv1(x)
x = self.identity(x)
x = self.bn1(x)
return x


class RemoveIdentityTest(BasePytorchFeatureNetworkTest):

def __init__(self, unit_test):
super().__init__(unit_test)

def create_networks(self):
return RemoveIdentityNet()

def compare(self,
quantized_model,
float_model,
input_x=None,
quantization_info=None):
for n,m in quantized_model.named_modules():
# make sure identity was removed and bn was folded into the conv
self.unit_test.assertFalse(isinstance(m, torch.nn.Identity) or isinstance(m, torch.nn.BatchNorm2d))

6 changes: 6 additions & 0 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
Expand Up @@ -88,9 +88,15 @@
from tests.pytorch_tests.model_tests.feature_models.const_representation_test import ConstRepresentationTest, \
ConstRepresentationMultiInputTest
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest


class FeatureModelsTestRunner(unittest.TestCase):
def test_remove_identity(self):
"""
This test checks that identity layers are removed from the model.
"""
RemoveIdentityTest(self).run_test()

def test_single_layer_replacement(self):
"""
Expand Down

0 comments on commit af4f8e0

Please sign in to comment.