diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1b4076df5..3d8522cfa 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -109,6 +109,7 @@ if (NGRAPH_PLAIDML_ENABLE) endif() add_subdirectory(python) +add_subdirectory(python/bfloat16) add_subdirectory(model_level_tests) if (DEFINED NGRAPH_TF_INSTALL_PREFIX) diff --git a/test/python/bfloat16/CMakeLists.txt b/test/python/bfloat16/CMakeLists.txt new file mode 100644 index 000000000..519c682e7 --- /dev/null +++ b/test/python/bfloat16/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright 2019 Nervana Systems Inc. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.4) + +file(GLOB files RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.py") +foreach(file ${files}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E create_symlink + ${CMAKE_CURRENT_SOURCE_DIR}/${file} + ${CMAKE_CURRENT_BINARY_DIR}/${file} +) +endforeach() \ No newline at end of file diff --git a/test/python/bfloat16/test_fusedbatchnorm_training_nchw.py b/test/python/bfloat16/test_fusedbatchnorm_training_nchw.py new file mode 100644 index 000000000..aee482364 --- /dev/null +++ b/test/python/bfloat16/test_fusedbatchnorm_training_nchw.py @@ -0,0 +1,90 @@ +# ============================================================================== +# Copyright 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""nGraph TensorFlow FusedBatchNorm test
+
+"""
+import numpy as np
+import tensorflow as tf
+import os
+import ngraph_bridge
+import pytest
+
+np.random.seed(5)
+
+# Inputs
+scale = [1.0, 0.9, 1.1]
+offset = [0.1, 0.2, -.3]
+input_shape_nchw = [4, 3, 1, 2]
+
+
+def tf_model():
+    x = tf.placeholder(tf.float32, shape=input_shape_nchw)
+
+    # cast the input dtype to bfloat16
+    x_c = tf.cast(x, dtype=tf.bfloat16)
+
+    # transpose the input to NHWC since TF does not support NCHW
+    x_t = tf.transpose(x_c, (0, 2, 3, 1))  # shape=[4, 1, 2, 3]
+
+    out_list = tf.nn.fused_batch_norm(x_t, scale, offset, data_format='NHWC')
+
+    # cast the output back to float32
+    norm = [tf.cast(i, dtype=tf.float32) for i in out_list]
+    return norm, x
+
+
+def ng_model():
+    x = tf.placeholder(tf.float32, shape=input_shape_nchw)
+    norm = tf.nn.fused_batch_norm(x, scale, offset, data_format='NCHW')
+    return norm, x
+
+
+config = tf.ConfigProto(
+    allow_soft_placement=True,
+    log_device_placement=False,
+    inter_op_parallelism_threads=1)
+
+k_np = np.random.rand(4, 3, 1, 2).astype('f')  # NCHW
+
+
+def test_fusedbatchnorm_nchw():
+    #Test 1: tf_model TF-native
+    with tf.Session(config=config) as sess_tf:
+        ngraph_bridge.disable()
+        tf_out, in_0 = tf_model()
+        feed_dict = {in_0: k_np}
+        tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict)
+
+    #Test 2: ng_model with ngraph, NNP backend
+    with tf.Session(config=config) as sess_ng:
+        ngraph_bridge.enable()
+        ngraph_bridge.update_config(config)
+        os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1'
+        ng_out, in_0 = ng_model()
+        feed_dict = {in_0: k_np}
+        ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict)
+
+    # transpose TF output from NHWC to NCHW for comparison with ngraph output
+    result1_bool = np.allclose(
+        np.transpose(tf_outval[0], (0, 3, 1, 2)),
+        ng_outval[0],
+        rtol=0,
+        atol=1e-02)
+    # these TF outputs do not need to be transposed since they have only 1 dimension
+    result2_bool = np.allclose(tf_outval[1], ng_outval[1], rtol=0, atol=1e-02)
+    result3_bool = np.allclose(tf_outval[2], ng_outval[2], rtol=0, atol=1e-02)
+
+    assert (result1_bool and result2_bool and result3_bool)
diff --git a/test/python/bfloat16/test_fusedbatchnorm_training_nhwc.py b/test/python/bfloat16/test_fusedbatchnorm_training_nhwc.py
new file mode 100644
index 000000000..8227896a4
--- /dev/null
+++ b/test/python/bfloat16/test_fusedbatchnorm_training_nhwc.py
@@ -0,0 +1,81 @@
+# ==============================================================================
+# Copyright 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""nGraph TensorFlow FusedBatchNorm test
+
+"""
+import numpy as np
+import tensorflow as tf
+import os
+import ngraph_bridge
+import pytest
+
+np.random.seed(5)
+
+# Inputs
+scale = [1.0, 0.9, 1.1]
+offset = [0.1, 0.2, -.3]
+input_shape_nhwc = [4, 1, 2, 3]
+
+
+def tf_model():
+    x = tf.placeholder(tf.float32, shape=input_shape_nhwc)
+
+    # cast the input dtype to bfloat16 for TF
+    x_c = tf.cast(x, dtype=tf.bfloat16)
+
+    out_list = tf.nn.fused_batch_norm(x_c, scale, offset, data_format='NHWC')
+
+    # cast the output dtype back to float32
+    norm = [tf.cast(i, dtype=tf.float32) for i in out_list]
+    return norm, x
+
+
+def ng_model():
+    x = tf.placeholder(tf.float32, shape=input_shape_nhwc)
+    norm = tf.nn.fused_batch_norm(x, scale, offset, data_format='NHWC')
+    return norm, x
+
+
+config = tf.ConfigProto(
+    allow_soft_placement=True,
+    log_device_placement=False,
+    inter_op_parallelism_threads=1)
+
+k_np = np.random.rand(4, 1, 2, 3).astype('f')  # NHWC
+
+
+def test_fusedbatchnorm_nhwc():
+    #Test 1: tf_model TF-native
+    with tf.Session(config=config) as sess_tf:
+        ngraph_bridge.disable()
+        tf_out, in_0 = tf_model()
+        feed_dict = {in_0: k_np}
+        tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict)
+
+    #Test 2: ng_model with ngraph, NNP backend
+    with tf.Session(config=config) as sess_ng:
+        ngraph_bridge.enable()
+        ngraph_bridge.update_config(config)
+        os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1'
+        ng_out, in_0 = ng_model()
+        feed_dict = {in_0: k_np}
+        ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict)
+
+    result1_bool = np.allclose(tf_outval[0], ng_outval[0], rtol=0, atol=1e-02)
+    result2_bool = np.allclose(tf_outval[1], ng_outval[1], rtol=0, atol=1e-02)
+    result3_bool = np.allclose(tf_outval[2], ng_outval[2], rtol=0, atol=1e-02)
+
+    assert (result1_bool and result2_bool and result3_bool)
diff --git a/test/python/bfloat16/test_maxpoolbackprop_nchw.py b/test/python/bfloat16/test_maxpoolbackprop_nchw.py
new file mode 100644
index 000000000..d62b51f35
--- /dev/null
+++ b/test/python/bfloat16/test_maxpoolbackprop_nchw.py
@@ -0,0 +1,151 @@
+# ==============================================================================
+# Copyright 2019 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""nGraph TensorFlow bridge MaxPoolBackprop operation test
+
+"""
+
+# Currently, this test fails with a segmentation fault
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import pytest
+import numpy as np
+import os
+
+import tensorflow as tf
+from tensorflow.python.ops.gen_nn_ops import max_pool_grad
+
+import ngraph_bridge
+
+# Test Ngraph Op MaxPoolBackprop with data format NCHW
+# TF Op:MaxPoolGrad
+
+np.random.seed(5)
+
+#Inputs
+N = 4
+C = 3
+H = 8
+W = 8
+
+valid_shape = [4, 3, 3, 3]
+same_shape = [4, 3, 4, 4]
+
+output_nchw = {
+    "VALID": np.random.rand(*valid_shape).astype('f'),
+    "SAME": np.random.rand(*same_shape).astype('f')
+}
+grad_nchw = {
+    "VALID": np.random.rand(*valid_shape).astype('f'),
+    "SAME": np.random.rand(*same_shape).astype('f')
+}
+
+stride_nhwc = [1, 2, 2, 1]
+ksize_nhwc = [1, 3, 3, 1]
+
+stride_nchw = [1, 1, 2, 2]
+ksize_nchw = [1, 1, 3, 3]
+
+
+# TF graph
+def tf_model(padding):
+    orig_in = tf.placeholder(tf.float32, shape=[N, C, H, W])
+    if padding == "VALID":
+        grad = tf.placeholder(tf.float32, shape=valid_shape)
+        orig_out = tf.placeholder(tf.float32, shape=valid_shape)
+    elif padding == "SAME":
+        grad = tf.placeholder(tf.float32, shape=same_shape)
+        orig_out = tf.placeholder(tf.float32, shape=same_shape)
+
+    # cast the input dtype to bfloat16 for TF
+    orig_in_c = tf.cast(orig_in, tf.bfloat16)
+    orig_out_c = tf.cast(orig_out, tf.bfloat16)
+    grad_c = tf.cast(grad, tf.bfloat16)
+
+    # transpose to NHWC
+    orig_in_t = tf.transpose(orig_in_c, (0, 2, 3, 1))
+    orig_out_t = tf.transpose(orig_out_c, (0, 2, 3, 1))
+    grad_t = tf.transpose(grad_c, (0, 2, 3, 1))
+
+    out = max_pool_grad(
+        orig_in_t,
+        orig_out_t,
+        grad_t,
+        ksize_nhwc,
+        stride_nhwc,
+        padding=padding,
+        data_format="NHWC")
+
+    # cast the output dtype back to float32
+    output = tf.cast(out, tf.float32)
+
+    # transpose to NCHW
+    output_nchw = tf.transpose(output, (0, 3, 1, 2))
+    return output_nchw, orig_in, orig_out, grad
+
+
+# Ngraph graph
+def ng_model(padding):
+    orig_in = tf.placeholder(tf.float32, shape=[N, C, H, W])
+    if padding == "VALID":
+        grad = tf.placeholder(tf.float32, shape=valid_shape)
+        orig_out = tf.placeholder(tf.float32, shape=valid_shape)
+    elif padding == "SAME":
+        grad = tf.placeholder(tf.float32, shape=same_shape)
+        orig_out = tf.placeholder(tf.float32, shape=same_shape)
+
+    out = max_pool_grad(
+        orig_in,
+        orig_out,
+        grad,
+        ksize_nchw,
+        stride_nchw,
+        padding=padding,
+        data_format="NCHW")
+    return out, orig_in, orig_out, grad
+
+
+config = tf.ConfigProto(
+    allow_soft_placement=True,
+    log_device_placement=False,
+    inter_op_parallelism_threads=1)
+
+i_np = np.random.rand(N, C, H, W).astype('f')  # NCHW
+
+
+@pytest.mark.parametrize("padding", ("VALID", "SAME"))
+def test_maxpoolbackprop_nchw(padding):
+    g_np = grad_nchw[padding]
+    o_np = output_nchw[padding]
+
+    #Test 1: tf_model TF-native
+    with tf.Session(config=config) as sess_tf:
+        ngraph_bridge.disable()
+        tf_out, orig_in, orig_out, grad = tf_model(padding)
+        feed_dict = {orig_in: i_np, orig_out: o_np, grad: g_np}
+        tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict)
+
+    #Test 2: ng_model with ngraph, NNP backend
+    with tf.Session(config=config) as sess_ng:
+        ngraph_bridge.enable()
+        ngraph_bridge.update_config(config)
+        os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1'
+        ng_out, orig_in, orig_out, grad = ng_model(padding)
+        feed_dict = {orig_in: i_np, orig_out:
o_np, grad: g_np} + ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict) + + assert (np.allclose(tf_outval, ng_outval, rtol=0, atol=1e-02)) diff --git a/test/python/bfloat16/test_maxpoolbackprop_nhwc.py b/test/python/bfloat16/test_maxpoolbackprop_nhwc.py new file mode 100644 index 000000000..857ae6e06 --- /dev/null +++ b/test/python/bfloat16/test_maxpoolbackprop_nhwc.py @@ -0,0 +1,139 @@ +# ============================================================================== +# Copyright 2019 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""nGraph TensorFlow bridge MaxPoolBackprop operation test + +""" + +# Currently, this test fails with a segmentation fault +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pytest +import numpy as np +import os + +import tensorflow as tf +from tensorflow.python.ops.gen_nn_ops import max_pool_grad + +import ngraph_bridge + +# Test Ngraph Op MaxPoolBackprop with data format NHWC +# TF Op:MaxPoolGrad + +np.random.seed(5) + +#Inputs +N = 4 +H = 8 +W = 8 +C = 3 + +valid_shape = [4, 3, 3, 3] +same_shape = [4, 4, 4, 3] + +output_nhwc = { + "VALID": np.random.rand(*valid_shape).astype('f'), + "SAME": np.random.rand(*same_shape).astype('f') +} +grad_nhwc = { + "VALID": np.random.rand(*valid_shape).astype('f'), + "SAME": np.random.rand(*same_shape).astype('f') +} +stride_nhwc = [1, 2, 2, 1] +ksize_nhwc = [1, 3, 3, 1] + + +# TF graph +def tf_model(padding): + orig_in = tf.placeholder(tf.float32, shape=[N, H, W, C]) + if padding == "VALID": + grad = tf.placeholder(tf.float32, shape=valid_shape) + orig_out = tf.placeholder(tf.float32, shape=valid_shape) + elif padding == "SAME": + grad = tf.placeholder(tf.float32, shape=same_shape) + orig_out = tf.placeholder(tf.float32, shape=same_shape) + + # cast the input dtype to bfloat16 for TF + orig_in_c = tf.cast(orig_in, tf.bfloat16) + orig_out_c = tf.cast(orig_out, tf.bfloat16) + grad_c = tf.cast(grad, tf.bfloat16) + + out = max_pool_grad( + orig_in_c, + orig_out_c, + grad_c, + ksize_nhwc, + stride_nhwc, + padding=padding, + data_format="NHWC") + + # cast the output dtype back to float32 + output = tf.cast(out, tf.float32) + return output, orig_in, orig_out, grad + + +# Ngraph graph +def ng_model(padding): + orig_in = tf.placeholder(tf.float32, shape=[N, H, W, C]) + if padding == "VALID": + grad = tf.placeholder(tf.float32, shape=valid_shape) + orig_out = tf.placeholder(tf.float32, shape=valid_shape) + elif padding == "SAME": + grad = tf.placeholder(tf.float32, shape=same_shape) + orig_out = tf.placeholder(tf.float32, shape=same_shape) + + out = max_pool_grad( + orig_in, + orig_out, + grad, + ksize_nhwc, + stride_nhwc, + padding=padding, + data_format="NHWC") + return out, orig_in, orig_out, grad + + +config = tf.ConfigProto( + allow_soft_placement=True, + log_device_placement=False, + inter_op_parallelism_threads=1) + +i_np = np.random.rand(N, H, W, 
C).astype('f') # NHWC + + +@pytest.mark.parametrize("padding", ("VALID", "SAME")) +def test_maxpoolbackprop_nhwc(padding): + g_np = grad_nhwc[padding] + o_np = output_nhwc[padding] + + #Test 1: tf_model TF-native + with tf.Session(config=config) as sess_tf: + ngraph_bridge.disable() + tf_out, orig_in, orig_out, grad = tf_model(padding) + feed_dict = {orig_in: i_np, orig_out: o_np, grad: g_np} + tf_outval = sess_tf.run(tf_out, feed_dict=feed_dict) + + #Test 2: model2 with ngraph, NNP backend + with tf.Session(config=config) as sess_ng: + ngraph_bridge.enable() + ngraph_bridge.update_config(config) + os.environ['NGRAPH_TF_DISABLE_DEASSIGN_CLUSTERS'] = '1' + ng_out, orig_in, orig_out, grad = ng_model(padding) + feed_dict = {orig_in: i_np, orig_out: o_np, grad: g_np} + ng_outval = sess_ng.run(ng_out, feed_dict=feed_dict) + + assert (np.allclose(tf_outval, ng_outval, rtol=0, atol=1e-02)) diff --git a/test/python/test_fusedbatchnorm.py b/test/python/test_fusedbatchnorm.py index b00cbb9a6..df1fbf534 100644 --- a/test/python/test_fusedbatchnorm.py +++ b/test/python/test_fusedbatchnorm.py @@ -29,63 +29,63 @@ # yes, it works without (tested over 1000 runs) but there's always a chance np.random.seed(5) +NHWC_TO_NCHW = (0, 3, 1, 2) +NCHW_TO_NHWC = (0, 2, 3, 1) + -@pytest.mark.skip(reason="new deviceless mode WIP") class TestFusedBatchNorm(NgraphTest): - x = np.random.rand(64, 3, 10, 8).astype('f') + x = np.random.rand(64, 3, 10, 8).astype('f') #NCHW scale = [1.0, 0.9, 1.1] offset = [0.1, 0.2, -.3] mean = [0.4, 0.5, 0.6] variance = [0.1, 0.2, 0.3] def test_fusedbatchnorm_nchw(self): - with self.device: + + def test_on_ng(sess): norm = tf.nn.fused_batch_norm( self.x, self.scale, self.offset, data_format='NCHW') + return (sess.run(norm)) - with self.session as sess: - result = sess.run(norm) - - with tf.device('/cpu:0'): - x_t = tf.transpose(self.x, (0, 2, 3, 1)) + def test_on_tf(sess): # tensorflow CPU doesn't support NCHW + x_t = tf.transpose(self.x, NCHW_TO_NHWC) # NHWC norm = tf.nn.fused_batch_norm( x_t, self.scale, self.offset, data_format='NHWC') + return (sess.run(norm)) - with self.session as sess: - expected = sess.run(norm) - + expected = self.without_ngraph(test_on_tf) + result = self.with_ngraph(test_on_ng) np.testing.assert_allclose( result[0], - np.transpose(expected[0], (0, 3, 1, 2)), + np.transpose(expected[0], NHWC_TO_NCHW), rtol=0, atol=5e-5) np.testing.assert_allclose(result[1], expected[1], rtol=0, atol=5e-5) np.testing.assert_allclose(result[2], expected[2], rtol=0, atol=5e-5) def test_fusedbatchnorm_nhwc(self): - x_t = tf.transpose(self.x, (0, 2, 3, 1)) + x_t = tf.transpose(self.x, NCHW_TO_NHWC) - with self.device: + def test_on_ng(sess): norm = tf.nn.fused_batch_norm( x_t, self.scale, self.offset, data_format='NHWC') + return (sess.run(norm)) - with self.session as sess: - result = sess.run(norm) - - with tf.device('/cpu:0'): + def test_on_tf(sess): norm = tf.nn.fused_batch_norm( x_t, self.scale, self.offset, data_format='NHWC') + return (sess.run(norm)) - with self.session as sess: - expected = sess.run(norm) - + expected = self.without_ngraph(test_on_tf) + result = self.with_ngraph(test_on_ng) np.testing.assert_allclose(result[0], expected[0], rtol=0, atol=5e-5) np.testing.assert_allclose(result[1], expected[1], rtol=0, atol=5e-5) np.testing.assert_allclose(result[2], expected[2], rtol=0, atol=5e-5) def test_fusedbatchnorm_inference_nchw(self): - with self.device: + + def test_on_ng(sess): norm = tf.nn.fused_batch_norm( self.x, self.scale, @@ -94,12 +94,10 @@ def 
test_fusedbatchnorm_inference_nchw(self): self.variance, data_format='NCHW', is_training=False) + return (sess.run(norm[0])) - with self.session as sess: - result = sess.run(norm[0]) - - with tf.device('/cpu:0'): - x_t = tf.transpose(self.x, (0, 2, 3, 1)) + def test_on_tf(sess): + x_t = tf.transpose(self.x, NCHW_TO_NHWC) norm = tf.nn.fused_batch_norm( x_t, self.scale, @@ -108,17 +106,17 @@ def test_fusedbatchnorm_inference_nchw(self): self.variance, data_format='NHWC', is_training=False) + return (sess.run(norm[0])) - with self.session as sess: - expected = sess.run(norm[0]) - + expected = self.without_ngraph(test_on_tf) + result = self.with_ngraph(test_on_ng) np.testing.assert_allclose( - result, np.transpose(expected, (0, 3, 1, 2)), rtol=0, atol=5e-5) + result, np.transpose(expected, NHWC_TO_NCHW), rtol=0, atol=5e-5) def test_fusedbatchnorm_inference_nhwc(self): - x_t = tf.transpose(self.x, (0, 2, 3, 1)) + x_t = tf.transpose(self.x, NCHW_TO_NHWC) - with self.device: + def test_on_ng(sess): norm = tf.nn.fused_batch_norm( x_t, self.scale, @@ -127,11 +125,9 @@ def test_fusedbatchnorm_inference_nhwc(self): self.variance, data_format='NHWC', is_training=False) + return (sess.run(norm[0])) - with self.session as sess: - result = sess.run(norm[0]) - - with tf.device('/cpu:0'): + def test_on_tf(sess): norm = tf.nn.fused_batch_norm( x_t, self.scale, @@ -140,8 +136,10 @@ def test_fusedbatchnorm_inference_nhwc(self): self.variance, data_format='NHWC', is_training=False) + return (sess.run(norm[0])) - with self.session as sess: - expected = sess.run(norm[0]) - - np.testing.assert_allclose(result, expected, rtol=0, atol=5e-5) + np.testing.assert_allclose( + self.with_ngraph(test_on_ng), + self.without_ngraph(test_on_tf), + rtol=0, + atol=5e-5) diff --git a/test/python/test_l2loss.py b/test/python/test_l2loss.py index 7b1a41d86..c03a138e4 100644 --- a/test/python/test_l2loss.py +++ b/test/python/test_l2loss.py @@ -26,6 +26,8 @@ from common import NgraphTest +np.random.seed(5) + class TestL2Loss(NgraphTest): diff --git a/test/python/test_maxpoolbackprop.py b/test/python/test_maxpoolbackprop.py index 243224677..5f2ffe535 100644 --- a/test/python/test_maxpoolbackprop.py +++ b/test/python/test_maxpoolbackprop.py @@ -31,73 +31,69 @@ NHWC_TO_NCHW = (0, 3, 1, 2) NCHW_TO_NHWC = (0, 2, 3, 1) +np.random.seed(5) + -@pytest.mark.skip(reason="new deviceless mode WIP") class TestMaxPoolBackpropInput(NgraphTest): + + # NHWC input_nhwc = np.random.rand(128, 224, 224, 3) - input_nchw = np.transpose(input_nhwc, NHWC_TO_NCHW) - output_nhwc = np.random.rand(128, 224, 224, 3) - output_nchw = np.transpose(output_nhwc, NHWC_TO_NCHW) strides_nhwc = ksize_nhwc = [1, 2, 3, 1] - strides_nchw = ksize_nchw = [1, 1, 2, 3] + output_nhwc = { + "VALID": np.random.rand(128, 112, 74, 3), + "SAME": np.random.rand(128, 112, 75, 3) + } grad_nhwc = { "VALID": np.random.rand(128, 112, 74, 3), "SAME": np.random.rand(128, 112, 75, 3) } + + # NCHW + input_nchw = np.transpose(input_nhwc, NHWC_TO_NCHW) + strides_nchw = ksize_nchw = [1, 1, 2, 3] + output_nchw = { + "VALID": np.random.rand(128, 3, 112, 74), + "SAME": np.random.rand(128, 3, 112, 75) + } grad_nchw = { - "VALID": np.transpose(grad_nhwc["VALID"], NHWC_TO_NCHW), - "SAME": np.transpose(grad_nhwc["SAME"], NHWC_TO_NCHW) + "VALID": np.random.rand(128, 3, 112, 74), + "SAME": np.random.rand(128, 3, 112, 75) } @pytest.mark.parametrize("padding", ("VALID", "SAME")) def test_nhwc(self, padding): strides = self.strides_nhwc ksize = self.ksize_nhwc - output = self.output_nhwc - np_nhwc = 
self.grad_nhwc[padding] + output = self.output_nhwc[padding] + g_nhwc = self.grad_nhwc[padding] if padding == "VALID": grad = tf.placeholder(tf.float32, shape=(128, 112, 74, 3)) elif padding == "SAME": grad = tf.placeholder(tf.float32, shape=(128, 112, 75, 3)) - - with self.device: - a = max_pool_grad( - self.input_nhwc, - output, - grad, - ksize, - strides, - padding=padding, - data_format="NHWC") - with self.session as sess: - result = sess.run(a, feed_dict={grad: np_nhwc}) - - with tf.device('/cpu:0'): - b = max_pool_grad( - self.input_nhwc, - output, - grad, - ksize, - strides, - padding=padding, - data_format="NHWC") - with self.session as sess: - expected = sess.run(b, feed_dict={grad: np_nhwc}) - - np.testing.assert_allclose(result, expected, rtol=5e-7) + out = max_pool_grad( + self.input_nhwc, + output, + grad, + ksize, + strides, + padding=padding, + data_format="NHWC") + sess_fn = lambda sess: sess.run(out, feed_dict={grad: g_nhwc}) + assert (np.allclose( + self.with_ngraph(sess_fn), self.without_ngraph(sess_fn), rtol=5e-7)) @pytest.mark.parametrize("padding", ("VALID", "SAME")) def test_nchw(self, padding): strides = self.strides_nchw ksize = self.ksize_nchw - output = self.output_nchw - np_nchw = self.grad_nchw[padding] + output = self.output_nchw[padding] + g_nchw = self.grad_nchw[padding] if padding == "VALID": grad = tf.placeholder(tf.float32, shape=(128, 3, 112, 74)) elif padding == "SAME": grad = tf.placeholder(tf.float32, shape=(128, 3, 112, 75)) - with self.device: + def test_on_ng(sess): a = max_pool_grad( self.input_nchw, output, @@ -106,27 +102,27 @@ def test_nchw(self, padding): strides, padding=padding, data_format="NCHW") - with self.session as sess: - result = sess.run(a, feed_dict={grad: np_nchw}) + return sess.run(a, feed_dict={grad: g_nchw}) + # To validate on the CPU side we will need to run in NHWC, because the CPU - # implementation of avgpool backprop does not support NCHW. We will + # implementation of maxpool backprop does not support NCHW. 
We will # transpose on the way in and on the way out - with tf.device('/cpu:0'): - grad = tf.transpose(grad, NCHW_TO_NHWC) - np_nhwc = self.grad_nhwc[padding] - output = self.output_nhwc + def test_on_tf(sess): + grad_t = tf.transpose(grad, NCHW_TO_NHWC) ksize = self.ksize_nhwc strides = self.strides_nhwc + input_t = np.transpose(self.input_nchw, NCHW_TO_NHWC) + output_t = np.transpose(output, NCHW_TO_NHWC) b = max_pool_grad( - self.input_nhwc, - output, - grad, + input_t, + output_t, + grad_t, ksize, strides, padding=padding, data_format="NHWC") b = tf.transpose(b, NHWC_TO_NCHW) - with self.session as sess: - expected = sess.run(b, feed_dict={grad: np_nhwc}) + return sess.run(b, feed_dict={grad: g_nchw}) - np.testing.assert_allclose(result, expected, rtol=5e-7) + assert np.allclose( + self.with_ngraph(test_on_ng), self.without_ngraph(test_on_tf)) diff --git a/test/python/test_relugrad.py b/test/python/test_relugrad.py index 989887f42..9ed7208fd 100644 --- a/test/python/test_relugrad.py +++ b/test/python/test_relugrad.py @@ -21,6 +21,7 @@ from __future__ import print_function import pytest +import numpy as np import tensorflow as tf from tensorflow.python.framework import constant_op @@ -28,46 +29,35 @@ from common import NgraphTest +np.random.seed(5) + -@pytest.mark.skip(reason="new deviceless mode WIP") class TestReluGradOperations(NgraphTest): def test_relugrad_2d(self): - gradients = constant_op.constant( - self.generate_random_numbers(6, 1.0, 10.0), shape=[2, 3]) - features = constant_op.constant( - self.generate_random_numbers(6, 0.0, 100.0), shape=[2, 3]) - - # Run on nGraph - with self.device: - out = relu_grad(gradients, features) - with self.session as sess: - result = sess.run(out) - - # Run on CPU - with self.cpu_device: - out = relu_grad(gradients, features) - with self.session as sess: - expected = sess.run(out) - - assert (result == expected).all() + gradients = tf.placeholder(tf.float32, [2, 3]) + features = tf.placeholder(tf.float32, [2, 3]) + out = relu_grad(gradients, features) + g = np.random.rand(2, 3) + f = np.random.rand(2, 3) + sess_fn = lambda sess: sess.run( + out, feed_dict={ + gradients: g, + features: f + }) + assert (np.allclose( + self.with_ngraph(sess_fn), self.without_ngraph(sess_fn))) def test_relugrad_1d(self): - gradients = constant_op.constant( - self.generate_random_numbers(100, 123.0, 345.0), shape=[100]) - features = constant_op.constant( - self.generate_random_numbers(100, 567.0, 789.0), shape=[100]) - - # Run on nGraph - with self.device: - out = relu_grad(gradients, features) - with self.session as sess: - result = sess.run(out) - - # Run on CPU - with self.cpu_device: - out = relu_grad(gradients, features) - with self.session as sess: - expected = sess.run(out) - - assert (result == expected).all() + gradients = tf.placeholder(tf.float32, [100]) + features = tf.placeholder(tf.float32, [100]) + out = relu_grad(gradients, features) + g = np.random.rand(100) + f = np.random.rand(100) + sess_fn = lambda sess: sess.run( + out, feed_dict={ + gradients: g, + features: f + }) + assert (np.allclose( + self.with_ngraph(sess_fn), self.without_ngraph(sess_fn))) diff --git a/test/python/test_sparse_softmax_cross_entropy_with_logits.py b/test/python/test_sparse_softmax_cross_entropy_with_logits.py index 3097e41ac..9ad5a9a6f 100644 --- a/test/python/test_sparse_softmax_cross_entropy_with_logits.py +++ b/test/python/test_sparse_softmax_cross_entropy_with_logits.py @@ -28,8 +28,9 @@ import numpy as np from common import NgraphTest +np.random.seed(5) + 
-@pytest.mark.skip(reason="new deviceless mode WIP") class TestSparseSoftmaxCrossEntropyWithLogitsOperations(NgraphTest): def test_sparse_softmax_cross_entropy_with_logits_2d(self): @@ -44,17 +45,11 @@ def test_sparse_softmax_cross_entropy_with_logits_2d(self): self.generate_random_numbers(total_size, 0.0, 1.0), shape=[batch_size, num_classes]) - # Run on CPU - with self.cpu_device: - out_cpu = sparse_softmax_cross_entropy_with_logits(features, labels) - with self.session as sess: - expected = sess.run(out_cpu) + out = sparse_softmax_cross_entropy_with_logits(features, labels) + sess_fn = lambda sess: sess.run(out) - # Run on nGraph - with self.device: - out = sparse_softmax_cross_entropy_with_logits(features, labels) - with self.session as sess: - result = sess.run(out) + expected = self.without_ngraph(sess_fn) + result = self.with_ngraph(sess_fn) - assert np.allclose(result[0], expected[0]) - assert np.allclose(result[1], expected[1]) + assert np.allclose(result[0], expected[0], rtol=0, atol=1e-02) + assert np.allclose(result[1], expected[1], rtol=0, atol=1e-02)