1 change: 1 addition & 0 deletions test/CMakeLists.txt
@@ -109,6 +109,7 @@ if (NGRAPH_PLAIDML_ENABLE)
endif()

add_subdirectory(python)
add_subdirectory(python/bfloat16)
add_subdirectory(model_level_tests)

if (DEFINED NGRAPH_TF_INSTALL_PREFIX)
23 changes: 23 additions & 0 deletions test/python/bfloat16/CMakeLists.txt
@@ -0,0 +1,23 @@
# Copyright 2019 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cmake_minimum_required(VERSION 3.4)

file(GLOB files RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.py")
foreach(file ${files})
    execute_process(
        COMMAND ${CMAKE_COMMAND} -E create_symlink
        ${CMAKE_CURRENT_SOURCE_DIR}/${file}
        ${CMAKE_CURRENT_BINARY_DIR}/${file}
    )
endforeach()
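Note: this CMakeLists.txt only symlinks the .py test files from the source tree into the build tree so the tests can be run in place; presumably the suite is then invoked with pytest from the corresponding build directory (e.g. on test/python/bfloat16). The exact invocation is not shown in this PR.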
90 changes: 90 additions & 0 deletions test/python/bfloat16/test_fusedbatchnorm_training_nchw.py
@@ -0,0 +1,90 @@
# ==============================================================================
# Copyright 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""nGraph TensorFlow FusedBatchNorm test

"""
import numpy as np
import tensorflow as tf
import os
import ngraph_bridge
import pytest

np.random.seed(5)

# Inputs
scale = [1.0, 0.9, 1.1]
offset = [0.1, 0.2, -.3]
input_shape_nchw = [4, 3, 1, 2]


def tf_model():
    x = tf.placeholder(tf.float32, shape=input_shape_nchw)

    # cast the input to bfloat16
    x_c = tf.cast(x, dtype=tf.bfloat16)

    # transpose the input to NHWC, since TF's fused_batch_norm does not support NCHW here
    x_t = tf.transpose(x_c, (0, 2, 3, 1))  # shape=[4, 1, 2, 3]

    out_list = tf.nn.fused_batch_norm(x_t, scale, offset, data_format='NHWC')

    # cast the outputs back to float32
    norm = [tf.cast(i, dtype=tf.float32) for i in out_list]
    return norm, x


def ng_model():
    x = tf.placeholder(tf.float32, shape=input_shape_nchw)
    norm = tf.nn.fused_batch_norm(x, scale, offset, data_format='NCHW')
    return norm, x


config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    inter_op_parallelism_threads=1)

k_np = np.random.rand(4, 3, 1, 2).astype('f') # NCHW


def test_fusedbatchnorm_nchw():
    # Test 1: tf_model, TF native (nGraph disabled)
    with tf.Session(config=config) as sess_tf:
        ngraph_bridge.disable()
        tf_out, in_0 = tf_model()
        feed_dict = {in_0: k_np}
        tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict)

    # Test 2: ng_model with nGraph (NNP backend)
    with tf.Session(config=config) as sess_ng:
        ngraph_bridge.enable()
        ngraph_bridge.update_config(config)
        os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1'
        ng_out, in_0 = ng_model()
        feed_dict = {in_0: k_np}
        ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict)

    # transpose the TF output from NHWC back to NCHW for comparison with the nGraph output
    result1_bool = np.allclose(
        np.transpose(tf_outval[0], (0, 3, 1, 2)),
        ng_outval[0],
        rtol=0,
        atol=1e-02)
    # the mean and variance outputs are per-channel (1-D), so no transpose is needed
    result2_bool = np.allclose(tf_outval[1], ng_outval[1], rtol=0, atol=1e-02)
    result3_bool = np.allclose(tf_outval[2], ng_outval[2], rtol=0, atol=1e-02)

    assert result1_bool and result2_bool and result3_bool
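As a sanity check on the layout handling above, a small self-contained numpy sketch (illustrative only, not part of this PR) confirming that axes (0, 2, 3, 1) map NCHW to NHWC and (0, 3, 1, 2) map it back:

import numpy as np

x_nchw = np.random.rand(4, 3, 1, 2)          # NCHW
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))  # NCHW -> NHWC
x_back = np.transpose(x_nhwc, (0, 3, 1, 2))  # NHWC -> NCHW
assert x_nhwc.shape == (4, 1, 2, 3)
assert np.array_equal(x_nchw, x_back)        # the two transposes are inverses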
81 changes: 81 additions & 0 deletions test/python/bfloat16/test_fusedbatchnorm_training_nhwc.py
@@ -0,0 +1,81 @@
# ==============================================================================
# Copyright 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""nGraph TensorFlow FusedBatchNorm test

"""
import numpy as np
import tensorflow as tf
import os
import ngraph_bridge
import pytest

np.random.seed(5)

# Inputs
scale = [1.0, 0.9, 1.1]
offset = [0.1, 0.2, -.3]
input_shape_nhwc = [4, 1, 2, 3]


def tf_model():
    x = tf.placeholder(tf.float32, shape=input_shape_nhwc)

    # cast the input to bfloat16 for TF
    x_c = tf.cast(x, dtype=tf.bfloat16)

    out_list = tf.nn.fused_batch_norm(x_c, scale, offset, data_format='NHWC')

    # cast the outputs back to float32
    norm = [tf.cast(i, dtype=tf.float32) for i in out_list]
    return norm, x


def ng_model():
    x = tf.placeholder(tf.float32, shape=input_shape_nhwc)
    norm = tf.nn.fused_batch_norm(x, scale, offset, data_format='NHWC')
    return norm, x


config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    inter_op_parallelism_threads=1)

k_np = np.random.rand(4, 1, 2, 3).astype('f') # NHWC


def test_fusedbatchnorm_nhwc():
    # Test 1: tf_model, TF native (nGraph disabled)
    with tf.Session(config=config) as sess_tf:
        ngraph_bridge.disable()
        tf_out, in_0 = tf_model()
        feed_dict = {in_0: k_np}
        tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict)

    # Test 2: ng_model with nGraph (NNP backend)
    with tf.Session(config=config) as sess_ng:
        ngraph_bridge.enable()
        ngraph_bridge.update_config(config)
        os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1'
        ng_out, in_0 = ng_model()
        feed_dict = {in_0: k_np}
        ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict)

    # both models run in NHWC, so the outputs can be compared directly
    result1_bool = np.allclose(tf_outval[0], ng_outval[0], rtol=0, atol=1e-02)
    result2_bool = np.allclose(tf_outval[1], ng_outval[1], rtol=0, atol=1e-02)
    result3_bool = np.allclose(tf_outval[2], ng_outval[2], rtol=0, atol=1e-02)

    assert result1_bool and result2_bool and result3_bool
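A note on the loose atol=1e-02 used in these comparisons: bfloat16 keeps only 7 explicit mantissa bits, so near 1.0 one ulp is 2**-7 ≈ 0.008, and casting through bfloat16 alone can move a value by nearly that much. A minimal numpy sketch of this (illustrative only; it simulates the cast by truncating the low 16 bits of a float32, whereas a real cast may round to nearest):

import numpy as np

def as_bfloat16(x):
    # bfloat16 is the top 16 bits of an IEEE float32 bit pattern
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

x = np.array([1.2345678], dtype=np.float32)
print(abs(x - as_bfloat16(x)))  # below one bfloat16 ulp, i.e. under ~8e-3 here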
151 changes: 151 additions & 0 deletions test/python/bfloat16/test_maxpoolbackprop_nchw.py
@@ -0,0 +1,151 @@
# ==============================================================================
# Copyright 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""nGraph TensorFlow bridge MaxPoolBackprop operation test

"""

# Currently, this test fails with a segmentation fault
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import pytest
import numpy as np
import os

import tensorflow as tf
from tensorflow.python.ops.gen_nn_ops import max_pool_grad

import ngraph_bridge

# Tests the nGraph op MaxPoolBackprop with data format NCHW
# TF op: MaxPoolGrad

np.random.seed(5)

# Inputs
N = 4
C = 3
H = 8
W = 8

valid_shape = [4, 3, 3, 3]
same_shape = [4, 3, 4, 4]

output_nchw = {
    "VALID": np.random.rand(*valid_shape).astype('f'),
    "SAME": np.random.rand(*same_shape).astype('f')
}
grad_nchw = {
    "VALID": np.random.rand(*valid_shape).astype('f'),
    "SAME": np.random.rand(*same_shape).astype('f')
}

stride_nhwc = [1, 2, 2, 1]
ksize_nhwc = [1, 3, 3, 1]

stride_nchw = [1, 1, 2, 2]
ksize_nchw = [1, 1, 3, 3]


# TF graph
def tf_model(padding):
    orig_in = tf.placeholder(tf.float32, shape=[N, C, H, W])
    if padding == "VALID":
        grad = tf.placeholder(tf.float32, shape=valid_shape)
        orig_out = tf.placeholder(tf.float32, shape=valid_shape)
    elif padding == "SAME":
        grad = tf.placeholder(tf.float32, shape=same_shape)
        orig_out = tf.placeholder(tf.float32, shape=same_shape)

    # cast the inputs to bfloat16 for TF
    orig_in_c = tf.cast(orig_in, tf.bfloat16)
    orig_out_c = tf.cast(orig_out, tf.bfloat16)
    grad_c = tf.cast(grad, tf.bfloat16)

    # transpose the inputs to NHWC
    orig_in_t = tf.transpose(orig_in_c, (0, 2, 3, 1))
    orig_out_t = tf.transpose(orig_out_c, (0, 2, 3, 1))
    grad_t = tf.transpose(grad_c, (0, 2, 3, 1))

    out = max_pool_grad(
        orig_in_t,
        orig_out_t,
        grad_t,
        ksize_nhwc,
        stride_nhwc,
        padding=padding,
        data_format="NHWC")

    # cast the output back to float32
    output = tf.cast(out, tf.float32)

    # transpose the output back to NCHW
    # (named out_nchw to avoid shadowing the module-level output_nchw dict)
    out_nchw = tf.transpose(output, (0, 3, 1, 2))
    return out_nchw, orig_in, orig_out, grad


# Ngraph graph
def ng_model(padding):
    orig_in = tf.placeholder(tf.float32, shape=[N, C, H, W])
    if padding == "VALID":
        grad = tf.placeholder(tf.float32, shape=valid_shape)
        orig_out = tf.placeholder(tf.float32, shape=valid_shape)
    elif padding == "SAME":
        grad = tf.placeholder(tf.float32, shape=same_shape)
        orig_out = tf.placeholder(tf.float32, shape=same_shape)

    out = max_pool_grad(
        orig_in,
        orig_out,
        grad,
        ksize_nchw,
        stride_nchw,
        padding=padding,
        data_format="NCHW")
    return out, orig_in, orig_out, grad


config = tf.ConfigProto(
    allow_soft_placement=True,
    log_device_placement=False,
    inter_op_parallelism_threads=1)

i_np = np.random.rand(N, C, H, W).astype('f')  # NCHW


@pytest.mark.parametrize("padding", ("VALID", "SAME"))
def test_maxpoolbackprop_nchw(padding):
    g_np = grad_nchw[padding]
    o_np = output_nchw[padding]

    # Test 1: tf_model, TF native (nGraph disabled)
    with tf.Session(config=config) as sess_tf:
        ngraph_bridge.disable()
        tf_out, orig_in, orig_out, grad = tf_model(padding)
        feed_dict = {orig_in: i_np, orig_out: o_np, grad: g_np}
        tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict)

    # Test 2: ng_model with nGraph (NNP backend)
    with tf.Session(config=config) as sess_ng:
        ngraph_bridge.enable()
        ngraph_bridge.update_config(config)
        os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1'
        ng_out, orig_in, orig_out, grad = ng_model(padding)
        feed_dict = {orig_in: i_np, orig_out: o_np, grad: g_np}
        ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict)

    assert np.allclose(tf_outval, ng_outval, rtol=0, atol=1e-02)
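For reference, the same backprop can be obtained without the raw gen_nn_ops.max_pool_grad kernel by differentiating tf.nn.max_pool; a minimal TF1-style sketch (illustrative only, not part of this PR; it assumes a build where MaxPool supports data_format='NCHW', and note that max_pool_grad likewise expects orig_out to be the actual max_pool output of orig_in):

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[4, 3, 8, 8])  # NCHW input
y = tf.nn.max_pool(
    x, ksize=[1, 1, 3, 3], strides=[1, 1, 2, 2],
    padding='VALID', data_format='NCHW')  # forward pool, shape [4, 3, 3, 3]
g = tf.placeholder(tf.float32, shape=y.shape)  # upstream gradient
dx, = tf.gradients(y, x, grad_ys=g)  # TF inserts a MaxPoolGrad op here

with tf.Session() as sess:
    dx_val = sess.run(dx, feed_dict={
        x: np.random.rand(4, 3, 8, 8).astype('f'),
        g: np.random.rand(4, 3, 3, 3).astype('f')
    })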