Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions ngraph_bridge/ngraph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,32 +128,34 @@ Status ValuesFromConstNode(const NodeDef& node,
n_elements *= shape.dim(i).size();
}
values->resize(n_elements);
auto val_lastsaved = (T)0; // cast
for (auto i = 0; i < n_elements; i++) {
auto& tensor = node.attr().at("value").tensor();
auto dt = node.attr().at("dtype").type();
int64 val_size = 0;
auto val_i = (T)0; // cast
switch (dt) {
// TODO(amprocte/NGRAPH-2502): there are more element types to support
// here
case DT_INT32:
(*values)[i] = (tensor.int_val_size() == 1 ? tensor.int_val()[0]
: tensor.int_val()[i]);
val_size = tensor.int_val_size();
if (val_size > i) val_i = tensor.int_val()[i];
break;
case DT_INT64:
(*values)[i] = (tensor.int64_val_size() == 1 ? tensor.int64_val()[0]
: tensor.int64_val()[i]);
val_size = tensor.int64_val_size();
if (val_size > i) val_i = tensor.int64_val()[i];
break;
case DT_FLOAT:
(*values)[i] = (tensor.float_val_size() == 1 ? tensor.float_val()[0]
: tensor.float_val()[i]);
val_size = tensor.float_val_size();
if (val_size > i) val_i = tensor.float_val()[i];
break;
case DT_BOOL:
(*values)[i] = (tensor.bool_val_size() == 1 ? tensor.bool_val()[0]
: tensor.bool_val()[i]);
val_size = tensor.bool_val_size();
if (val_size > i) val_i = tensor.bool_val()[i];
break;
case DT_DOUBLE:
(*values)[i] =
(tensor.double_val_size() == 1 ? tensor.double_val()[0]
: tensor.double_val()[i]);
val_size = tensor.double_val_size();
if (val_size > i) val_i = tensor.double_val()[i];
break;
default:
NGRAPH_VLOG(0)
Expand All @@ -165,6 +167,14 @@ Status ValuesFromConstNode(const NodeDef& node,
DataType_Name(dt),
" on an empty tensor");
}
if (val_size == 0) {
return errors::InvalidArgument("Empty values vector");
} else if (i < val_size) {
(*values)[i] = val_i;
val_lastsaved = val_i;
} else {
(*values)[i] = val_lastsaved;
}
}
} else {
values->resize(tensor_content_size / sizeof(VecT));
Expand Down
95 changes: 95 additions & 0 deletions test/python/test_const.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# ==============================================================================
# Copyright 2018-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""nGraph TensorFlow bridge Const operation test

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pytest

import tensorflow as tf
import os

from common import NgraphTest

# Uncomment for debugging; also add -s on the command line, e.g.
# (venv-tf-py3) [build_cmake]$
# NGRAPH_TF_LOG_PLACEMENT=1 NGRAPH_TF_VLOG_LEVEL=6 pytest -s -k test_const_scalarval ../test/python/test_const.py
import logging
logging.basicConfig(level=logging.DEBUG)


class TestConstOperations(NgraphTest):
    """Checks that tf.constant tensors evaluate identically with and
    without the nGraph bridge, for the value-layouts the bridge's
    ValuesFromConstNode handles: full list, scalar broadcast, partial
    fill (last value repeated), and the empty-values edge case."""

    def test_const_listvals(self):
        # Flat list that fully populates the [2, 3] shape.
        zz = tf.constant([1, 2, 3, 4, 5, 6], dtype=float, shape=[2, 3])

        def run_test(sess):
            return sess.run(zz)

        assert (
            self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()

    def test_const_listvals_2(self):
        # Nested list matching the [2, 3] shape exactly.
        zz = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=float, shape=[2, 3])

        def run_test(sess):
            return sess.run(zz)

        assert (
            self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()

    def test_const_scalarval(self):
        # A single scalar broadcast to every element of the [2, 3] tensor.
        zz = tf.constant(-3, dtype=float, shape=[2, 3])

        def run_test(sess):
            return sess.run(zz)

        assert (
            self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()

    def test_const_lastfill(self):
        # Fewer values than elements: TF repeats the last value to fill
        # the remaining slots.
        zz = tf.constant([1, 2], dtype=float, shape=[2, 3])

        def run_test(sess):
            return sess.run(zz)

        assert (
            self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()

    def test_const_empty(self):
        log = logging.getLogger('test_const_empty')
        zz = tf.constant([], dtype=float, shape=[2, 3])

        def run_test(sess):
            log.debug('Invoking sess.run(zz)')
            return sess.run(zz)

        # Ideally we want same behavior for both TF & NG, but for now we are
        # deviating: nGraph raises an error while TF fills in zeros.
        # assert (
        #     self.with_ngraph(run_test) == self.without_ngraph(run_test)).all()

        # BUG FIX: the original wrapped `assert False` in a bare `try/except:`,
        # so the except clause swallowed the AssertionError itself and the test
        # could never fail, even if nGraph stopped raising. pytest.raises
        # fails the test when no exception is raised, which is the intended
        # semantics.
        with pytest.raises(Exception):
            self.with_ngraph(run_test)
        log.debug('Passed, expected NG to raise error...')