Skip to content

Commit

Permalink
Added unit tests for DivisionToZeroFP16Resolver.py; made some improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Nov 12, 2021
1 parent 85ff1c8 commit 39fb150
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 71 deletions.
99 changes: 30 additions & 69 deletions model-optimizer/extensions/middle/DivisionToZeroFP16Resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,16 @@
from mo.middle.replacement import MiddleReplacementPattern


class DivisionToZeroFP16ResolverMaximumEps(MiddleReplacementPattern):
class DivisionToZeroFP16Resolver(MiddleReplacementPattern):
"""
Patterns input_1/Maximum(input_2, eps) and input_1/Add(input_2, eps) are used
to prevent division by zero. But usually in FP32 networks eps is so
small (e.g. 1e-9, 1e-12, ...) that after casting to FP16 it collapses to zero.
This can lead to division by zero if input_2 is also zero.
To prevent that, we change eps to the smallest normal FP16 value in such patterns.
"""
enabled = True
# graph_condition = [lambda graph: graph.graph['cmd_params'].data_type == 'FP16']
graph_condition = [lambda graph: graph.graph['cmd_params'].compress_fp16]

def run_after(self):
from extensions.middle.fusings import Fusing
Expand All @@ -29,18 +34,18 @@ def pattern(self):
('input', dict(kind='data')),
('eps_or_input_data_1', dict(kind='data')), # one of these inputs is eps
('eps_or_input_data_2', dict(kind='data')),
('max', dict(kind='op', op='Maximum')),
('max_data', dict(kind='data')),
('pow_exp', dict(kind='data', value=lambda x: np.all(x < -0) if x is not None else False)),
('max_or_add', dict(kind='op', op=lambda x: x in ['Maximum', 'Add'])),
('max_or_add_data', dict(kind='data')),
('pow_exp', dict(kind='data', value=lambda x: np.all(x < 0) if x is not None else False)),
('pow', dict(kind='op', op='Pow')),
('pow_data', dict(kind='data')),
('multiplicative_inverse', dict(kind='op', op='Mul')),
],
edges=[
('eps_or_input_data_1', 'max'),
('eps_or_input_data_2', 'max'),
('max', 'max_data'),
('max_data', 'pow', {'in': 0}),
('eps_or_input_data_1', 'max_or_add'),
('eps_or_input_data_2', 'max_or_add'),
('max_or_add', 'max_or_add_data'),
('max_or_add_data', 'pow', {'in': 0}),
('pow_exp', 'pow', {'in': 1}),
('pow', 'pow_data'),
('pow_data', 'multiplicative_inverse'),
Expand All @@ -49,63 +54,19 @@ def pattern(self):
)

def replace_pattern(self, graph: Graph, match: dict):
change_const_value(match['max'])


class DivisionToZeroFP16ResolverAddEpsilon(MiddleReplacementPattern):
"""
"""
enabled = True
# graph_condition = [lambda graph: graph.graph['cmd_params'].data_type == 'FP16']

def run_after(self):
from extensions.middle.fusings import Fusing
return [Fusing]

def run_before(self):
from extensions.middle.L2NormFusing import L2NormToNorm
return [L2NormToNorm]

def pattern(self):
return dict(
nodes=[
('input', dict(kind='data')),
('eps_or_input_data_1', dict(kind='data')), # one of these inputs is eps
('eps_or_input_data_2', dict(kind='data')),
('add', dict(kind='op', op='Add')),
('add_data', dict(kind='data')),
('pow_exp', dict(kind='data', value=lambda x: np.all(x < -0) if x is not None else False)),
('pow', dict(kind='op', op='Pow')),
('pow_data', dict(kind='data')),
('multiplicative_inverse', dict(kind='op', op='Mul')),
],
edges=[
('eps_or_input_data_1', 'add', {'in': 0}),
('eps_or_input_data_2', 'add', {'in': 1}),
('add', 'add_data'),
('add_data', 'pow', {'in': 0}),
('pow_exp', 'pow', {'in': 1}),
('pow', 'pow_data'),
('pow_data', 'multiplicative_inverse'),
('input', 'multiplicative_inverse'),
]
)

def replace_pattern(self, graph: Graph, match: dict):
change_const_value(match['add'])


def change_const_value(node: Node): # node is either max or add
is_port_1_const = node.in_port(1).get_source().node.soft_get('op') == 'Const'
port = 1 if is_port_1_const else 0
const_node = node.in_port(port).get_source().node
const_name = const_node.soft_get('name', const_node.id)

fp16_machine_eps = np.finfo(np.float16).eps
if const_node.value is not None and const_node.value < fp16_machine_eps:
fp16_machine_eps = np.array(fp16_machine_eps, dtype=const_node.value.dtype)
log.error('changing value of constant {} from {} to {} to '
'prevent division to zero'.format(const_name, const_node.value, fp16_machine_eps),
extra={'is_warning': True})
const_node.value = fp16_machine_eps
const_node.out_port(0).data.set_value(fp16_machine_eps)
is_port_1_const = match['max_or_add'].in_port(1).get_source().node.soft_get('op') == 'Const'
port = 1 if is_port_1_const else 0

const_node = match['max_or_add'].in_port(port).get_source().node
value = const_node.value
# we use FP16 smallest normal value, because arithmetic of subnormal values is slower
fp16_smallest_positive = np.finfo(np.float16).tiny

if value is not None and np.all(value < fp16_smallest_positive):
new_eps = np.full_like(value, fp16_smallest_positive)
const_node.out_port(0).data.set_value(new_eps)

const_name = const_node.soft_get('name', const_node.id)
log.error("Changing value of constant '{}' from {} -> {} to "
"prevent division to zero when casted to FP16".format(const_name, value, new_eps),
extra={'is_warning': True})
5 changes: 3 additions & 2 deletions model-optimizer/extensions/middle/L2NormFusing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def run_before(self):
return [PostMiddleStart]

def run_after(self):
from extensions.middle.DivisionToZeroFP16Resolver import DivisionToZeroFP16ResolverMaximumEps, DivisionToZeroFP16ResolverAddEpsilon
return [DivisionToZeroFP16ResolverMaximumEps, DivisionToZeroFP16ResolverAddEpsilon]
from extensions.middle.DivisionToZeroFP16Resolver import DivisionToZeroFP16Resolver
    # because DivisionToZeroFP16Resolver should match the Pow(x, -1)/Div part of L2Norm first
return [DivisionToZeroFP16Resolver]

def pattern(self):
return dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import unittest

import numpy as np

from extensions.middle.DivisionToZeroFP16Resolver import DivisionToZeroFP16Resolver
from mo.front.common.partial_infer.utils import shape_array
from mo.graph.graph import Node
from unit_tests.utils.graph import build_graph, result, regular_op_with_empty_data, connect, shaped_parameter, \
valued_const_with_data


class ChangeOutputTypeAttributesTests(unittest.TestCase):
    """Unit tests for DivisionToZeroFP16Resolver.

    The transformation is expected to raise tiny eps constants in
    input_1 * Pow(Maximum/Add(input_2, eps), negative_pow) patterns up to
    the FP16 smallest normal value (np.finfo(np.float16).tiny).

    NOTE(review): the class name looks copied from another test module;
    consider renaming it to DivisionToZeroFP16ResolverTests.
    """

    def test_division_maximum(self):
        self.build_and_test_division_graph(eps=np.array(1e-12), pow_value=np.array(-1), preventing_type='Maximum')

    def test_division_add(self):
        self.build_and_test_division_graph(eps=np.array(1e-12), pow_value=np.array(-1), preventing_type='Add')

    def test_division_arbitrary_negative_pow_1(self):
        self.build_and_test_division_graph(eps=np.array(1e-12), pow_value=np.array(-1 / 2), preventing_type='Add')

    def test_division_arbitrary_negative_pow_2(self):
        self.build_and_test_division_graph(eps=np.array(1e-12), pow_value=np.array(-0.2), preventing_type='Add')

    def test_division_eps_as_array_1(self):
        self.build_and_test_division_graph(eps=np.array([1e-12, 1e-12]), pow_value=np.array(-1), preventing_type='Add')

    def test_division_eps_as_array_2(self):
        self.build_and_test_division_graph(eps=np.array([1e-12]), pow_value=np.array(-1), preventing_type='Add')

    def test_division_graph_not_changed_1(self):
        # eps=1e-2 does not underflow in FP16, so the graph must be left unchanged;
        # build_and_test_division_graph() then fails its final assertion, which is the expected outcome
        with self.assertRaises(AssertionError):
            self.build_and_test_division_graph(eps=np.array(1e-2), pow_value=np.array(-1),
                                               preventing_type='Maximum')

    def test_division_graph_not_changed_2(self):
        # if at least one value is greater than the FP16 smallest normal value, the graph must be
        # left unchanged ([1e-2, 1e-12]); build_and_test_division_graph() then fails its final
        # assertion, which is the expected outcome
        with self.assertRaises(AssertionError):
            self.build_and_test_division_graph(eps=np.array([1e-2, 1e-12]), pow_value=np.array(-1),
                                               preventing_type='Maximum')

    def build_and_test_division_graph(self, eps, pow_value, preventing_type):
        """Build input_1 * Pow(preventing_type(input_2, eps), pow_value) -> Result,
        run DivisionToZeroFP16Resolver on it and assert that eps was changed
        to the FP16 smallest normal value.

        :param eps: numpy scalar or array with the original epsilon value(s)
        :param pow_value: numpy scalar with the (negative) Pow exponent
        :param preventing_type: 'Maximum' or 'Add' -- the op guarding against zero
        """
        nodes = {
            **shaped_parameter('input_1', shape_array((1, 3, 10, 10))),
            **shaped_parameter('input_2', shape_array((1, 3, 10, 10))),
            **regular_op_with_empty_data(preventing_type, {'type': preventing_type, 'op': preventing_type}),
            **regular_op_with_empty_data('negative_pow', {'type': 'Pow', 'op': 'Pow'}),
            **regular_op_with_empty_data('mul', {'type': 'Mul', 'op': 'Mul'}),

            **valued_const_with_data('negative_pow_const', pow_value),
            **valued_const_with_data('eps', eps),
            **result('res'),
        }

        edges = [
            *connect('input_2', '0:' + preventing_type),
            *connect('eps', '1:' + preventing_type),
            *connect(preventing_type, '0:negative_pow'),
            *connect('negative_pow_const', '1:negative_pow'),
            *connect('negative_pow', '1:mul'),
            *connect('input_1', '0:mul'),
            *connect('mul', 'res'),
        ]
        graph = build_graph(nodes, edges)
        # the transformation is gated on the compress_fp16 command-line flag
        graph.graph['cmd_params'].compress_fp16 = True

        DivisionToZeroFP16Resolver().find_and_replace_pattern(graph)

        self.assertTrue(np.all(Node(graph, 'eps').value == np.finfo(np.float16).tiny))

    def test_l2_norm(self):
        """The resolver must also match the rsqrt part of an (unfused) L2Norm
        sub-graph: input * Pow(Maximum(ReduceSum(Pow(input, 2)), eps), -1/2)."""
        nodes = {
            **shaped_parameter('input', shape_array((1, 3, 10, 10))),
            **regular_op_with_empty_data('square', {'type': 'Pow', 'op': 'Pow'}),
            **regular_op_with_empty_data('sum', {'type': 'ReduceSum', 'op': 'ReduceSum'}),
            **regular_op_with_empty_data('max', {'type': 'Maximum', 'op': 'Maximum'}),
            **regular_op_with_empty_data('rsqrt', {'type': 'Pow', 'op': 'Pow'}),
            **regular_op_with_empty_data('l2norm', {'type': 'Mul', 'op': 'Mul'}),

            **valued_const_with_data('rsqrt_pow_const', np.array(-1 / 2)),
            **valued_const_with_data('square_pow', np.array(2)),
            **valued_const_with_data('eps', np.array(1e-12)),
            **result('res'),
        }

        edges = [
            *connect('input:0', '0:square'),
            *connect('square_pow', '1:square'),
            *connect('square', 'sum'),
            *connect('sum', '0:max'),
            *connect('eps', '1:max'),
            *connect('max', '0:rsqrt'),
            *connect('rsqrt_pow_const', '1:rsqrt'),
            *connect('rsqrt', '0:l2norm'),
            *connect('input:0', '1:l2norm', skip_data=True),
            *connect('l2norm', 'res'),
        ]
        graph = build_graph(nodes, edges)
        # the transformation is gated on the compress_fp16 command-line flag
        graph.graph['cmd_params'].compress_fp16 = True

        DivisionToZeroFP16Resolver().find_and_replace_pattern(graph)

        self.assertTrue(np.all(Node(graph, 'eps').value == np.finfo(np.float16).tiny))

0 comments on commit 39fb150

Please sign in to comment.