From 921a6fb78a2038cca9dd501be4e0f9c88fda3ef0 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 14 Apr 2022 11:16:19 +0800 Subject: [PATCH] Cherry pick final state ops (#41755) * [Yaml]add exp yaml (#41217) * add exp yaml * add exp api in test case * add determinant yaml * fix exp op unittest * change test class name * modify api name * compacted with raw api * fix det api * add python_api * add test eager for determinant op * [Yaml] Add assign yaml (#41428) * add assign yaml * add assign api * add assign backward api * add assign * add assign yaml * add assign * assign yaml * add assign raw kernel and use assign_raw in yaml * merge develop branch * add missing python_api * exchange assign and assign_raw kernel name (#41625) * exchange assign and assign_raw kernel name * fix register error * [Yaml]add gaussian_random yaml and test case (#41312) * add guassian random yaml * add gaussian_random yaml and test case * fix error modify of full yaml * import in_dygraph_mode * import _in_legacy_dygraph * add place arg in api * import __current_expected_place * fix test_egr_python_api failed case * add test case * add cast for NormalInitializer * fix test error * fix test error * rm unsed check code * fix test error in test_initializer_nn * modify by review * [Phi]fix split error when sections has 0 size and add test case (#41708) * fix split error when sections has 0 size and add test case * fix test case --- paddle/fluid/operators/strided_memcpy.h | 2 +- paddle/phi/kernels/assign_kernel.cc | 30 +++++-- paddle/phi/kernels/assign_kernel.h | 11 ++- paddle/phi/ops/compat/assign_sig.cc | 4 +- .../fluid/dygraph/varbase_patch_methods.py | 5 +- python/paddle/fluid/initializer.py | 79 ++++++++++++++++--- python/paddle/fluid/layers/nn.py | 11 ++- python/paddle/fluid/layers/tensor.py | 15 ++-- .../tests/unittests/test_activation_op.py | 3 +- .../fluid/tests/unittests/test_assign_op.py | 12 ++- .../tests/unittests/test_determinant_op.py | 10 ++- .../tests/unittests/test_egr_python_api.py | 3 - .../unittests/test_gaussian_random_op.py | 11 +++ .../fluid/tests/unittests/test_initializer.py | 64 ++++++++++++++- .../tests/unittests/test_initializer_nn.py | 2 +- .../fluid/tests/unittests/test_split_op.py | 19 +++++ python/paddle/tensor/linalg.py | 5 +- python/paddle/tensor/random.py | 10 ++- python/paddle/utils/code_gen/api.yaml | 41 ++++++++++ python/paddle/utils/code_gen/backward.yaml | 30 +++++++ 20 files changed, 320 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/operators/strided_memcpy.h b/paddle/fluid/operators/strided_memcpy.h index af29aac6b9052..90cf4128aae94 100644 --- a/paddle/fluid/operators/strided_memcpy.h +++ b/paddle/fluid/operators/strided_memcpy.h @@ -134,7 +134,7 @@ inline void StridedMemcpyWithAxis0( for (size_t i = 0; i < outputs->size(); ++i) { auto out_stride = stride_numel(shape_refer[i]->dims()); auto out = outputs->at(i); - if (out != nullptr) { + if (out != nullptr && out->initialized()) { StridedNumelCopyWithAxis(dev_ctx, axis, out->data(), out_stride, input.data() + input_offset, in_stride, out_stride[axis]); diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc index a330227fcfafd..720ebb5b78c9a 100644 --- a/paddle/phi/kernels/assign_kernel.cc +++ b/paddle/phi/kernels/assign_kernel.cc @@ -24,14 +24,21 @@ namespace phi { template void AssignKernel(const Context& dev_ctx, - paddle::optional x, + const DenseTensor& x, DenseTensor* out) { - if (x.get_ptr()) { - if (!x.is_initialized()) { + Copy(dev_ctx, x, x.place(), false, out); +} + +template +void AssignRawKernel(const Context& dev_ctx, + paddle::optional x, + DenseTensor* out) { + if (x) { + if (!x->IsInitialized()) { return; } auto& x_tensor = *x.get_ptr(); - Copy(dev_ctx, x_tensor, x_tensor.place(), false, out); + AssignKernel(dev_ctx, x_tensor, out); } } @@ -105,7 +112,13 @@ void AssignValueKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_GENERAL_KERNEL( - assign, CPU, ALL_LAYOUT, phi::AssignKernel, ALL_DTYPE) { + assign, CPU, ALL_LAYOUT, phi::AssignKernel, ALL_DTYPE) {} + +PD_REGISTER_GENERAL_KERNEL(assign_raw, + CPU, + ALL_LAYOUT, + phi::AssignRawKernel, + ALL_DTYPE) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); } PD_REGISTER_GENERAL_KERNEL(assign_array, @@ -124,7 +137,12 @@ PD_REGISTER_KERNEL(assign_value, #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_REGISTER_GENERAL_KERNEL( - assign, GPU, ALL_LAYOUT, phi::AssignKernel, ALL_DTYPE) { + assign, GPU, ALL_LAYOUT, phi::AssignKernel, ALL_DTYPE) {} +PD_REGISTER_GENERAL_KERNEL(assign_raw, + GPU, + ALL_LAYOUT, + phi::AssignRawKernel, + ALL_DTYPE) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); } PD_REGISTER_GENERAL_KERNEL(assign_array, diff --git a/paddle/phi/kernels/assign_kernel.h b/paddle/phi/kernels/assign_kernel.h index f1f3f024205a1..6881ac9f0ee22 100644 --- a/paddle/phi/kernels/assign_kernel.h +++ b/paddle/phi/kernels/assign_kernel.h @@ -21,13 +21,18 @@ namespace phi { +template +void AssignKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); + // In order to be compatible with the `AsDispensable` input in the original // assign op maker, the input parameter here needs to be dispensable, but // this looks weird template -void AssignKernel(const Context& dev_ctx, - paddle::optional x, - DenseTensor* out); +void AssignRawKernel(const Context& dev_ctx, + paddle::optional x, + DenseTensor* out); template void AssignArrayKernel(const Context& dev_ctx, diff --git a/paddle/phi/ops/compat/assign_sig.cc b/paddle/phi/ops/compat/assign_sig.cc index d149e8e6a9aa0..c8cd9e44ff9ae 100644 --- a/paddle/phi/ops/compat/assign_sig.cc +++ b/paddle/phi/ops/compat/assign_sig.cc @@ -23,10 +23,10 @@ KernelSignature AssignOpArgumentMapping(const ArgumentMappingContext& ctx) { } else if (ctx.IsSelectedRowsInput("X")) { return KernelSignature("assign_sr", {"X"}, {}, {"Out"}); } else { - return KernelSignature("assign", {"X"}, {}, {"Out"}); + return KernelSignature("assign_raw", {"X"}, {}, {"Out"}); } } else { - return KernelSignature("assign", {"X"}, {}, {"Out"}); + return KernelSignature("assign_raw", {"X"}, {}, {"Out"}); } } diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 83738c1f13194..e671064a4e0eb 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -23,7 +23,7 @@ from ..framework import convert_np_dtype_to_dtype_, _in_legacy_dygraph from .. import core from .. import unique_name -from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase +from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase, in_dygraph_mode from .base import switch_to_static_graph from .math_op_patch import monkey_patch_math_varbase from .parallel import scale_loss @@ -798,6 +798,9 @@ def _set_grad_ivar(self, value): @framework.dygraph_only def clone(self): + if in_dygraph_mode(): + return _C_ops.final_state_assign(self) + if _in_legacy_dygraph(): output = core.VarBase() else: diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py index 37eff6d132d03..ab38bbf56ee3c 100644 --- a/python/paddle/fluid/initializer.py +++ b/python/paddle/fluid/initializer.py @@ -331,22 +331,56 @@ def __call__(self, var, block=None): ["uint16", "float16", "float32", "float64"], "guassian_random") + # to be compatible of fp16 initalizers + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + out_dtype = VarDesc.VarType.FP32 + out_var = block.create_var( + name=unique_name.generate(".".join( + ['normal_init', var.name, 'tmp'])), + shape=var.shape, + dtype=out_dtype, + type=VarDesc.VarType.LOD_TENSOR, + persistable=False) + else: + out_dtype = var.dtype + out_var = var + if self._seed == 0: self._seed = block.program.random_seed - if framework._non_static_mode(): + if in_dygraph_mode(): + place = _current_expected_place() + out_var = _C_ops.final_state_gaussian_random( + var.shape, self._mean, self._std_dev, self._seed, out_dtype, + place) + out_var._share_underline_tensor_to(var) + + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + var_tmp = _C_ops.final_state_cast(out_var, var.dtype) + var_tmp._share_underline_tensor_to(var) + else: + out_var._share_underline_tensor_to(var) + return None + + if _in_legacy_dygraph(): out_var = _C_ops.gaussian_random( - 'shape', var.shape, 'dtype', var.dtype, 'mean', self._mean, + 'shape', var.shape, 'dtype', out_dtype, 'mean', self._mean, 'std', self._std_dev, 'seed', self._seed, 'use_mkldnn', False) - out_var._share_underline_tensor_to(var) + + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + var_tmp = _C_ops.cast(out_var, 'in_dtype', out_var.dtype, + 'out_dtype', var.dtype) + var_tmp._share_underline_tensor_to(var) + else: + out_var._share_underline_tensor_to(var) return None else: op = block.append_op( type="gaussian_random", - outputs={"Out": var}, + outputs={"Out": out_var}, attrs={ "shape": var.shape, - "dtype": var.dtype, + "dtype": out_dtype, "mean": self._mean, "std": self._std_dev, "seed": self._seed, @@ -354,6 +388,13 @@ def __call__(self, var, block=None): }, stop_gradient=True) + if var.dtype in [VarDesc.VarType.FP16, VarDesc.VarType.BF16]: + block.append_op( + type="cast", + inputs={"X": out_var}, + outputs={"Out": var}, + attrs={"in_dtype": out_var.dtype, + "out_dtype": var.dtype}) var.op = op return op @@ -566,10 +607,16 @@ def __call__(self, var, block=None): -limit, 'max', limit, 'seed', self._seed, 'dtype', out_dtype) else: - std = np.sqrt(2.0 / float(fan_in + fan_out)) - out_var = _C_ops.gaussian_random( - 'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0, - 'std', std, 'seed', self._seed) + std = math.sqrt(2.0 / float(fan_in + fan_out)) + + if in_dygraph_mode(): + place = _current_expected_place() + out_var = _C_ops.final_state_gaussian_random( + out_var.shape, 0.0, std, self._seed, out_dtype, place) + else: + out_var = _C_ops.gaussian_random( + 'shape', out_var.shape, 'dtype', out_dtype, 'mean', 0.0, + 'std', std, 'seed', self._seed) if var.dtype == VarDesc.VarType.FP16 or ( var.dtype == VarDesc.VarType.BF16 and not self._uniform): @@ -719,10 +766,16 @@ def __call__(self, var, block=None): self._seed, 'dtype', int(out_dtype)) else: - std = np.sqrt(2.0 / float(fan_in)) - out_var = _C_ops.gaussian_random( - 'shape', out_var.shape, 'dtype', - int(out_dtype), 'mean', 0.0, 'std', std, 'seed', self._seed) + std = math.sqrt(2.0 / float(fan_in)) + if in_dygraph_mode(): + place = _current_expected_place() + out_var = _C_ops.final_state_gaussian_random( + out_var.shape, 0.0, std, self._seed, out_dtype, place) + else: + out_var = _C_ops.gaussian_random( + 'shape', out_var.shape, 'dtype', + int(out_dtype), 'mean', 0.0, 'std', std, 'seed', + self._seed) if var.dtype == VarDesc.VarType.FP16 or ( var.dtype == VarDesc.VarType.BF16 and not self._uniform): diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 311a6278a89f8..fdc348f3a8377 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -28,6 +28,7 @@ from paddle.fluid.framework import _in_legacy_dygraph from ..initializer import Normal, Constant, NumpyArrayInitializer from ..framework import Variable, OpProtoHolder, _non_static_mode, dygraph_only, _dygraph_tracer, default_main_program, _varbase_creator, static_only, _global_flags, _in_legacy_dygraph, in_dygraph_mode +from ..framework import _current_expected_place from .. import dygraph_utils from ..param_attr import ParamAttr from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_ @@ -10964,7 +10965,15 @@ def gaussian_random(shape, if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if _non_static_mode(): + if in_dygraph_mode(): + shape = utils.convert_shape_to_list(shape) + place = _current_expected_place() + return _C_ops.final_state_gaussian_random(shape, + float(mean), + float(std), seed, dtype, + place) + + if _in_legacy_dygraph(): shape = utils.convert_shape_to_list(shape) return _C_ops.gaussian_random('shape', shape, 'mean', float(mean), 'std', diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 28e0d4eff377f..3a8dfdc858079 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -622,12 +622,15 @@ def assign(input, output=None): # after this api. if isinstance(input, (Variable, core.VarBase)): if _non_static_mode(): - if output is None: - if _in_legacy_dygraph(): - output = core.VarBase() - else: - output = core.eager.Tensor() - _C_ops.assign(input, output) + if in_dygraph_mode() and output is None: + output = _C_ops.final_state_assign(input) + else: + if output is None: + if _in_legacy_dygraph(): + output = core.VarBase() + else: + output = core.eager.Tensor() + _C_ops.assign(input, output) else: check_dtype(input.dtype, 'input', [ 'float16', 'uint16', 'float32', 'float64', 'int32', 'int64', diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index a79d1b0073869..b1c1d1b9f2b93 100755 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -51,7 +51,8 @@ def setUp(self): self.op_type = "exp" self.init_dtype() self.init_kernel_type() - self.check_eager = False + self.check_eager = True + self.python_api = paddle.exp np.random.seed(2049) x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) diff --git a/python/paddle/fluid/tests/unittests/test_assign_op.py b/python/paddle/fluid/tests/unittests/test_assign_op.py index 3dbd9311a71ed..bfe23c621270d 100644 --- a/python/paddle/fluid/tests/unittests/test_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_assign_op.py @@ -27,30 +27,32 @@ class TestAssignOp(op_test.OpTest): def setUp(self): + self.python_api = paddle.assign self.op_type = "assign" x = np.random.random(size=(100, 10)).astype('float64') self.inputs = {'X': x} self.outputs = {'Out': x} def test_forward(self): - self.check_output() + self.check_output(check_eager=True) def test_backward(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestAssignFP16Op(op_test.OpTest): def setUp(self): + self.python_api = paddle.assign self.op_type = "assign" x = np.random.random(size=(100, 10)).astype('float16') self.inputs = {'X': x} self.outputs = {'Out': x} def test_forward(self): - self.check_output() + self.check_output(check_eager=True) def test_backward(self): - self.check_grad(['X'], 'Out') + self.check_grad(['X'], 'Out', check_eager=True) class TestAssignOpWithLoDTensorArray(unittest.TestCase): @@ -171,6 +173,8 @@ def test_assign_BasicTypes(self): def test_clone(self): paddle.disable_static() + self.python_api = paddle.clone + x = paddle.ones([2]) x.stop_gradient = False clone_x = paddle.clone(x) diff --git a/python/paddle/fluid/tests/unittests/test_determinant_op.py b/python/paddle/fluid/tests/unittests/test_determinant_op.py index f8110bffa2f71..d447d213f3c81 100644 --- a/python/paddle/fluid/tests/unittests/test_determinant_op.py +++ b/python/paddle/fluid/tests/unittests/test_determinant_op.py @@ -22,21 +22,23 @@ import paddle.fluid as fluid import paddle.fluid.core as core import paddle.tensor as tensor +from paddle.fluid.framework import _test_eager_guard paddle.enable_static() class TestDeterminantOp(OpTest): def setUp(self): + self.python_api = paddle.linalg.det self.init_data() self.op_type = "determinant" self.outputs = {'Out': self.target} def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['Input'], ['Out']) + self.check_grad(['Input'], ['Out'], check_eager=True) def init_data(self): np.random.seed(0) @@ -89,6 +91,10 @@ def test_api_dygraph(self): self.assertEqual(np.allclose(out.numpy(), out_ref, rtol=1e-03), True) paddle.enable_static() + def test_eager(self): + with _test_eager_guard(): + self.test_api_dygraph() + class TestSlogDeterminantOp(OpTest): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/test_egr_python_api.py b/python/paddle/fluid/tests/unittests/test_egr_python_api.py index e7abed0964679..ae29c6c262a84 100644 --- a/python/paddle/fluid/tests/unittests/test_egr_python_api.py +++ b/python/paddle/fluid/tests/unittests/test_egr_python_api.py @@ -251,9 +251,6 @@ def constructor(self, place): self.assertTrue(egr_tensor12.place._equals(paddle.fluid.CPUPlace())) self.assertTrue(np.array_equal(egr_tensor12.numpy(), x)) - egr_tensor13 = paddle.randn([2, 2]) - self.assertTrue("eager_tmp" in egr_tensor13.name) - with self.assertRaisesRegexp( ValueError, "The shape of Parameter should not be None"): eager_param = EagerParamBase(shape=None, dtype="float32") diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py index 738441a46d377..0de09c98314c8 100644 --- a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py @@ -23,12 +23,14 @@ from paddle.fluid.op import Operator from paddle.fluid.executor import Executor from paddle.fluid.tests.unittests.op_test import OpTest, convert_uint16_to_float +from paddle.fluid.framework import _test_eager_guard import paddle class TestGaussianRandomOp(OpTest): def setUp(self): self.op_type = "gaussian_random" + self.python_api = paddle.normal self.set_attrs() self.inputs = {} self.use_mkldnn = False @@ -50,6 +52,10 @@ def set_attrs(self): def test_check_output(self): self.check_output_customized(self.verify_output) + def test_eager(self): + with _test_eager_guard(): + self.test_check_output() + def verify_output(self, outs): self.assertEqual(outs[0].shape, (123, 92)) hist, _ = np.histogram(outs[0], range=(-3, 5)) @@ -70,6 +76,7 @@ def verify_output(self, outs): class TestGaussianRandomBF16Op(OpTest): def setUp(self): self.op_type = "gaussian_random" + self.python_api = paddle.normal self.set_attrs() self.inputs = {} self.use_mkldnn = False @@ -93,6 +100,10 @@ def test_check_output(self): self.check_output_with_place_customized( self.verify_output, place=core.CUDAPlace(0)) + def test_eager(self): + with _test_eager_guard(): + self.test_check_output() + def verify_output(self, outs): outs = convert_uint16_to_float(outs) self.assertEqual(outs[0].shape, (123, 92)) diff --git a/python/paddle/fluid/tests/unittests/test_initializer.py b/python/paddle/fluid/tests/unittests/test_initializer.py index 91c2800836c9d..3a9387082e680 100644 --- a/python/paddle/fluid/tests/unittests/test_initializer.py +++ b/python/paddle/fluid/tests/unittests/test_initializer.py @@ -244,7 +244,7 @@ def test_normal_initializer(self, dtype="float32"): lod_level=0, name="param", initializer=initializer.NormalInitializer(2.3, 1.9, 123)) - num_ops = 1 + num_ops = 2 if (dtype == "float16" or dtype == "uint16") else 1 self.assertEqual(len(block.ops), num_ops) init_op = block.ops[0] self.assertEqual(init_op.type, 'gaussian_random') @@ -685,6 +685,68 @@ def test_uniform_initializer(self, dtype="float32"): self.func_uniform_initializer() +class TestXavierInitializerDygraph(unittest.TestCase): + def func_xvarier_initializer(self, dtype="float32"): + """ + In dygraph mode, we can use initializer directly to initialize a tensor. + """ + paddle.disable_static() + + tensor = paddle.zeros([1024, 1024, 16]) + tensor.stop_gradient = False + + xavier_ = paddle.fluid.initializer.XavierInitializer( + uniform=False, fan_in=3, fan_out=5) + xavier_(tensor) + + hist, _ = output_hist(tensor.numpy()) + + hist2, _ = output_hist( + np.random.normal(0, np.sqrt(2.0 / (3 + 5)), [1024, 1024, 16])) + + self.assertTrue( + np.allclose( + hist, hist2, rtol=0, atol=0.01), + "hist: " + str(hist) + " hist2: " + str(hist2)) + paddle.enable_static() + + def test_xavier_initializer(self, dtype="float32"): + with framework._test_eager_guard(): + self.func_xvarier_initializer() + self.func_xvarier_initializer() + + +class TestMSRAInitializerDygraph(unittest.TestCase): + def func_msra_initializer(self, dtype="float32"): + """ + In dygraph mode, we can use initializer directly to initialize a tensor. + """ + paddle.disable_static() + + tensor = paddle.zeros([1024, 1024, 16]) + tensor.stop_gradient = False + + msra_ = paddle.fluid.initializer.MSRAInitializer( + uniform=False, fan_in=4) + msra_(tensor) + + hist, _ = output_hist(tensor.numpy()) + + hist2, _ = output_hist( + np.random.normal(0, np.sqrt(2.0 / (4)), [1024, 1024, 16])) + + self.assertTrue( + np.allclose( + hist, hist2, rtol=0, atol=0.01), + "hist: " + str(hist) + " hist2: " + str(hist2)) + paddle.enable_static() + + def test_msra_initializer(self, dtype="float32"): + with framework._test_eager_guard(): + self.func_msra_initializer() + self.func_msra_initializer() + + class TesetconsistencyOfDynamicAndStaticGraph(unittest.TestCase): def func_order(self): paddle.set_device('cpu') diff --git a/python/paddle/fluid/tests/unittests/test_initializer_nn.py b/python/paddle/fluid/tests/unittests/test_initializer_nn.py index 74686652044ec..9953681e0f5bd 100644 --- a/python/paddle/fluid/tests/unittests/test_initializer_nn.py +++ b/python/paddle/fluid/tests/unittests/test_initializer_nn.py @@ -400,7 +400,7 @@ def test_normal_initializer(self, dtype="float32"): lod_level=0, name="param", initializer=initializer.Normal(2.3, 1.9)) - num_ops = 1 + num_ops = 2 if dtype in ["float16", "uint16"] else 1 self.assertEqual(len(block.ops), num_ops) init_op = block.ops[0] self.assertEqual(init_op.type, 'gaussian_random') diff --git a/python/paddle/fluid/tests/unittests/test_split_op.py b/python/paddle/fluid/tests/unittests/test_split_op.py index aac904dc2e15d..c826a0e1030f4 100644 --- a/python/paddle/fluid/tests/unittests/test_split_op.py +++ b/python/paddle/fluid/tests/unittests/test_split_op.py @@ -459,5 +459,24 @@ def test_axis_tensor_input(self): self.assertTrue(np.allclose(ex_x2, x2_out)) +class API_TestEmptySplit(unittest.TestCase): + def test_axis_input_empty_section(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([8, 6, 6]).astype("float32") + # input is a variable which shape is [8, 6, 6] + input = paddle.to_tensor(input_1) + x0, x1, x2 = paddle.split(input, num_or_sections=[5, 0, 3]) + x0_out = x0.numpy() + x1_out = x1.numpy() + x2_out = x2.numpy() + ex_x0, ex_x1, ex_x2 = np.split(input_1, [ + 5, + 5, + ]) + self.assertTrue(np.allclose(ex_x0, x0_out)) + self.assertTrue(np.allclose(ex_x1, x1_out)) + self.assertTrue(np.allclose(ex_x2, x2_out)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 33ff27202031f..51df977c00644 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1599,7 +1599,10 @@ def det(x, name=None): """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_det(x) + + if _in_legacy_dygraph(): return _C_ops.determinant(x) check_dtype(x.dtype, 'Input', ['float32', 'float64'], 'det') diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 82818d50510c9..3d0617e40d6b6 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -239,7 +239,15 @@ def gaussian(shape, mean=0.0, std=1.0, dtype=None, name=None): if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + shape = utils.convert_shape_to_list(shape) + place = _current_expected_place() + return _C_ops.final_state_gaussian_random(shape, + float(mean), + float(std), seed, dtype, + place) + + if _in_legacy_dygraph(): shape = utils.convert_shape_to_list(shape) return _C_ops.gaussian_random('shape', shape, 'mean', float(mean), 'std', diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 718c35683cb0b..6df8c6efcc03d 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -167,6 +167,16 @@ func : asinh backward : asinh_grad +# assign +- api : assign + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : assign + backward : assign_grad + # atan - api : atan args : (Tensor x) @@ -454,6 +464,15 @@ func : depthwise_conv2d_transpose backward : depthwise_conv2d_transpose_grad +- api : det + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : determinant + backward : det_grad + - api : diag args : (Tensor x, int offset, float padding_value) output : Tensor @@ -598,6 +617,16 @@ func : erfinv backward : erfinv_grad +# exp +- api : exp + args : (Tensor x) + output : Tensor + infer_meta : + func : UnchangedInferMeta + kernel : + func : exp + backward : exp_grad + # expand_as - api : expand_as args : (Tensor x, Tensor y, int[] target_shape) @@ -763,6 +792,18 @@ kernel : func : gather_tree +- api : gaussian_random + args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={}) + output: Tensor + infer_meta : + func : GaussianRandomInferMeta + param : [shape, mean, std, seed, dtype] + kernel : + func : gaussian_random + param : [shape, mean, std, seed, dtype] + data_type : dtype + backend : place + - api : gelu args : (Tensor x, bool approximate) output : Tensor(out) diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index f60563d5d018e..038097d72e3f1 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -89,6 +89,16 @@ kernel : func : asinh_grad +- backward_api : assign_grad + forward : assign (Tensor x) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out_grad] + kernel : + func : assign + - backward_api : atan2_grad forward : atan2 (Tensor x, Tensor y) -> Tensor(out) args : (Tensor x, Tensor y, Tensor out_grad) @@ -321,6 +331,16 @@ kernel : func : depthwise_conv2d_transpose_grad +- backward_api : det_grad + forward : det (Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : determinant_grad + - backward_api : diagonal_grad forward : diagonal (Tensor x, int offset, int axis1, int axis2) -> Tensor(out) args : (Tensor x, Tensor out_grad, int offset = 0, int axis1 = 0, int axis2 = 1) @@ -424,6 +444,16 @@ kernel : func : erfinv_grad +- backward_api : exp_grad + forward : exp (Tensor x) -> Tensor(out) + args : (Tensor out, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param : [out] + kernel : + func : exp_grad + - backward_api : expand_as_grad forward : expand_as (Tensor x, Tensor y, int[] target_shape) -> Tensor(out) args : (Tensor x, Tensor out_grad, int[] target_shape)