Skip to content

Commit

Permalink
added DivisionToZeroFP16Resolver.py
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Nov 11, 2021
1 parent 2c7bbf8 commit 85ff1c8
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 6 deletions.
1 change: 1 addition & 0 deletions model-optimizer/automation/package_BOM.txt
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ extensions/middle/DeleteControlFlowEdges.py
extensions/middle/DeleteNotExecutable.py
extensions/middle/dequantize_linear_resolver.py
extensions/middle/DilatedConvolution.py
extensions/middle/DivisionToZeroFP16Resolver.py
extensions/middle/EltwiseChecker.py
extensions/middle/EltwiseInputReshape.py
extensions/middle/FakeSplitOutputs.py
Expand Down
111 changes: 111 additions & 0 deletions model-optimizer/extensions/middle/DivisionToZeroFP16Resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging as log

import numpy as np

from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern


class DivisionToZeroFP16ResolverMaximumEps(MiddleReplacementPattern):
"""
"""
enabled = True
# graph_condition = [lambda graph: graph.graph['cmd_params'].data_type == 'FP16']

def run_after(self):
from extensions.middle.fusings import Fusing
return [Fusing]

def run_before(self):
from extensions.middle.L2NormFusing import L2NormToNorm
return [L2NormToNorm]

def pattern(self):
return dict(
nodes=[
('input', dict(kind='data')),
('eps_or_input_data_1', dict(kind='data')), # one of these inputs is eps
('eps_or_input_data_2', dict(kind='data')),
('max', dict(kind='op', op='Maximum')),
('max_data', dict(kind='data')),
('pow_exp', dict(kind='data', value=lambda x: np.all(x < -0) if x is not None else False)),
('pow', dict(kind='op', op='Pow')),
('pow_data', dict(kind='data')),
('multiplicative_inverse', dict(kind='op', op='Mul')),
],
edges=[
('eps_or_input_data_1', 'max'),
('eps_or_input_data_2', 'max'),
('max', 'max_data'),
('max_data', 'pow', {'in': 0}),
('pow_exp', 'pow', {'in': 1}),
('pow', 'pow_data'),
('pow_data', 'multiplicative_inverse'),
('input', 'multiplicative_inverse'),
]
)

def replace_pattern(self, graph: Graph, match: dict):
change_const_value(match['max'])


class DivisionToZeroFP16ResolverAddEpsilon(MiddleReplacementPattern):
"""
"""
enabled = True
# graph_condition = [lambda graph: graph.graph['cmd_params'].data_type == 'FP16']

def run_after(self):
from extensions.middle.fusings import Fusing
return [Fusing]

def run_before(self):
from extensions.middle.L2NormFusing import L2NormToNorm
return [L2NormToNorm]

def pattern(self):
return dict(
nodes=[
('input', dict(kind='data')),
('eps_or_input_data_1', dict(kind='data')), # one of these inputs is eps
('eps_or_input_data_2', dict(kind='data')),
('add', dict(kind='op', op='Add')),
('add_data', dict(kind='data')),
('pow_exp', dict(kind='data', value=lambda x: np.all(x < -0) if x is not None else False)),
('pow', dict(kind='op', op='Pow')),
('pow_data', dict(kind='data')),
('multiplicative_inverse', dict(kind='op', op='Mul')),
],
edges=[
('eps_or_input_data_1', 'add', {'in': 0}),
('eps_or_input_data_2', 'add', {'in': 1}),
('add', 'add_data'),
('add_data', 'pow', {'in': 0}),
('pow_exp', 'pow', {'in': 1}),
('pow', 'pow_data'),
('pow_data', 'multiplicative_inverse'),
('input', 'multiplicative_inverse'),
]
)

def replace_pattern(self, graph: Graph, match: dict):
change_const_value(match['add'])


def change_const_value(node: Node): # node is either max or add
is_port_1_const = node.in_port(1).get_source().node.soft_get('op') == 'Const'
port = 1 if is_port_1_const else 0
const_node = node.in_port(port).get_source().node
const_name = const_node.soft_get('name', const_node.id)

fp16_machine_eps = np.finfo(np.float16).eps
if const_node.value is not None and const_node.value < fp16_machine_eps:
fp16_machine_eps = np.array(fp16_machine_eps, dtype=const_node.value.dtype)
log.error('changing value of constant {} from {} to {} to '
'prevent division to zero'.format(const_name, const_node.value, fp16_machine_eps),
extra={'is_warning': True})
const_node.value = fp16_machine_eps
const_node.out_port(0).data.set_value(fp16_machine_eps)
12 changes: 6 additions & 6 deletions model-optimizer/extensions/middle/L2NormFusing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ class L2NormToNorm(MiddleReplacementPattern):
enabled = True
force_clean_up = True

def run_after(self):
from extensions.middle.pass_separator import PreMiddleStart
return [PreMiddleStart]

def run_before(self):
from extensions.middle.pass_separator import MiddleStart
return [MiddleStart]
from extensions.middle.pass_separator import PostMiddleStart
return [PostMiddleStart]

def run_after(self):
from extensions.middle.DivisionToZeroFP16Resolver import DivisionToZeroFP16ResolverMaximumEps, DivisionToZeroFP16ResolverAddEpsilon
return [DivisionToZeroFP16ResolverMaximumEps, DivisionToZeroFP16ResolverAddEpsilon]

def pattern(self):
return dict(
Expand Down

0 comments on commit 85ff1c8

Please sign in to comment.