diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index d725bdffa010b..8c49a5f6de770 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -268,6 +268,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer) cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer) +cc_library(generator SRCS generator.cc) # Get the current working branch execute_process( diff --git a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt index 0d62488bfe67a..3eee0a1abbaf0 100644 --- a/paddle/fluid/framework/fleet/CMakeLists.txt +++ b/paddle/fluid/framework/fleet/CMakeLists.txt @@ -19,6 +19,6 @@ else() cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope) endif(WITH_GLOO) -cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_context) +cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_context heter_service_proto) cc_test(test_fleet SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell) diff --git a/paddle/fluid/framework/generator.cc b/paddle/fluid/framework/generator.cc new file mode 100644 index 0000000000000..d00e38784c2c0 --- /dev/null +++ b/paddle/fluid/framework/generator.cc @@ -0,0 +1,78 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include <deque>
+#include <memory>
+#include <mutex>
+#include <random>
+#include <utility>
+
+#include "paddle/fluid/framework/generator.h"
+
+namespace paddle {
+namespace framework {
+
+std::shared_ptr<Generator> Generator::gen_instance_ = NULL;
+
+GeneratorState* Generator::GetState() {
+  std::lock_guard<std::mutex> lock(this->mutex);
+  return this->state_.get();
+}
+
+void Generator::SetState(GeneratorState* state_in) {
+  std::lock_guard<std::mutex> lock(this->mutex);
+  *this->state_ = *state_in;
+}
+
+uint64_t Generator::GetCurrentSeed() {
+  std::lock_guard<std::mutex> lock(this->mutex);
+  return this->state_->current_seed;
+}
+
+uint64_t Generator::Seed() {
+  std::lock_guard<std::mutex> lock(this->mutex);
+  uint64_t seed;
+  std::random_device de;
+  seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF;
+  this->state_->current_seed = seed;
+  std::seed_seq seq({seed});
+  this->state_->cpu_engine.seed(seq);
+
+  return this->state_->current_seed;
+}
+
+void Generator::SetCurrentSeed(uint64_t seed) {
+  std::lock_guard<std::mutex> lock(this->mutex);
+  this->state_->current_seed = uint64_t(seed);
+  std::seed_seq seq({seed});
+  this->state_->cpu_engine.seed(seq);
+}
+
+std::mt19937_64& Generator::GetCPUEngine() {
+  std::lock_guard<std::mutex> lock(this->mutex);
+  return this->state_->cpu_engine;
+}
+
+void Generator::SetCPUEngine(std::mt19937_64 engine) {
+  std::lock_guard<std::mutex> lock(this->mutex);
+  this->state_->cpu_engine = std::mt19937_64(engine);
+}
+
+uint64_t Generator::Random64() {
+  std::lock_guard<std::mutex> lock(this->mutex);
+  return this->state_->cpu_engine();
+}
+
+}  // namespace framework
+}  // namespace paddle
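The new generator files above add a process-wide CPU random source: a GeneratorState (device id, current seed, std::mt19937_64 engine) held behind a shared_ptr singleton, with every accessor taking a std::mutex. The standalone sketch below is illustrative only and not part of the patch; DemoGenerator and its function-local-static singleton are simplifications of the shared_ptr-based Generator, but the locking and reseeding pattern is the same.

#include <cstdint>
#include <iostream>
#include <mutex>
#include <random>

class DemoGenerator {
 public:
  // Function-local static gives a thread-safe lazy singleton (C++11).
  static DemoGenerator& Instance() {
    static DemoGenerator gen;
    return gen;
  }

  // Draw a fresh seed from the OS entropy source and reseed the engine,
  // mirroring Generator::Seed() above.
  uint64_t Seed() {
    std::lock_guard<std::mutex> lock(mutex_);
    std::random_device de;
    // The mask keeps the seed within 53 bits, as generator.cc does.
    uint64_t seed =
        ((static_cast<uint64_t>(de()) << 32) + de()) & 0x1FFFFFFFFFFFFF;
    current_seed_ = seed;
    std::seed_seq seq({seed});
    engine_.seed(seq);
    return current_seed_;
  }

  uint64_t CurrentSeed() {
    std::lock_guard<std::mutex> lock(mutex_);
    return current_seed_;
  }

  // Every draw takes the lock, so concurrent callers get distinct values.
  uint64_t Random64() {
    std::lock_guard<std::mutex> lock(mutex_);
    return engine_();
  }

 private:
  DemoGenerator() {
    std::seed_seq seq({current_seed_});
    engine_.seed(seq);
  }

  std::mutex mutex_;
  std::mt19937_64 engine_;
  uint64_t current_seed_ = 34342423252ULL;  // same default as GeneratorState
};

int main() {
  auto& gen = DemoGenerator::Instance();
  std::cout << "default seed: " << gen.CurrentSeed() << "\n";
  std::cout << "draw:         " << gen.Random64() << "\n";
  gen.Seed();  // switch to a nondeterministic seed
  std::cout << "new seed:     " << gen.CurrentSeed() << "\n";
  std::cout << "draw:         " << gen.Random64() << "\n";
  return 0;
}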
diff --git a/paddle/fluid/framework/generator.h b/paddle/fluid/framework/generator.h
new file mode 100644
index 0000000000000..17870782ba72a
--- /dev/null
+++ b/paddle/fluid/framework/generator.h
@@ -0,0 +1,96 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <stdint.h>
+#include <atomic>
+#include <deque>
+#include <iostream>  // temp for debug
+#include <memory>
+#include <mutex>  // NOLINT
+#include <random>
+#include <typeinfo>
+#include <utility>
+
+namespace paddle {
+namespace framework {
+
+struct GeneratorState {
+  int64_t device = -1;
+  uint64_t current_seed = 34342423252;
+  std::mt19937_64 cpu_engine;
+};
+
+struct Generator {
+  Generator() {
+    GeneratorState default_gen_state_cpu;
+    default_gen_state_cpu.device = -1;
+    default_gen_state_cpu.current_seed = 34342423252;
+    std::seed_seq seq({34342423252});
+    default_gen_state_cpu.cpu_engine = std::mt19937_64(seq);
+    this->state_ = std::make_shared<GeneratorState>(default_gen_state_cpu);
+  }
+  explicit Generator(GeneratorState state_in)
+      : state_{std::make_shared<GeneratorState>(state_in)} {}
+  Generator(const Generator& other)
+      : Generator(other, std::lock_guard<std::mutex>(other.mutex)) {}
+
+  // get random state
+  GeneratorState* GetState();
+  // set random state
+  void SetState(GeneratorState* state_in);
+  // get current seed
+  uint64_t GetCurrentSeed();
+  // generate a new random seed and return it
+  uint64_t Seed();
+
+  // set seed
+  void SetCurrentSeed(uint64_t seed);
+  // get cpu engine
+  std::mt19937_64& GetCPUEngine();
+  // set cpu engine
+  void SetCPUEngine(std::mt19937_64 engine);
+
+  uint64_t Random64();
+
+  bool is_init_py = false;
+
+  // CPU Generator singleton
+  static std::shared_ptr<Generator> GetInstance() {
+    if (NULL == gen_instance_) {
+      gen_instance_.reset(new paddle::framework::Generator());
+    }
+    return gen_instance_;
+  }
+
+  static std::shared_ptr<Generator> GetInstanceX() {
+    if (NULL == gen_instance_) {
+      gen_instance_.reset(new paddle::framework::Generator());
+    }
+    gen_instance_->is_init_py = true;
+    return gen_instance_;
+  }
+
+ private:
+  static std::shared_ptr<Generator> gen_instance_;
+  std::shared_ptr<GeneratorState> state_;
+  mutable std::mutex mutex;
+
+  Generator(const Generator& other, const std::lock_guard<std::mutex>&)
+      : state_(std::make_shared<GeneratorState>(*(other.state_))) {}
+};
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
index 60e4ac8cbcfd8..9d3e0806ac79d 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
@@ -368,3 +368,7 @@ REGISTER_PASS(conv_transpose_bn_fuse_pass,
               paddle::framework::ir::ConvTransposeBNFusePass);
 REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
               paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
+REGISTER_PASS(depthwise_conv_bn_fuse_pass,
+              paddle::framework::ir::DepthwiseConvBNFusePass);
+REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
+              paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);
diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
index fcdbcf299c504..57a9f69ca15af 100644
--- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.h
@@ -56,6 +56,16 @@ class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
   std::string conv_type() const { return "conv2d_transpose"; }
 };
 
+class DepthwiseConvBNFusePass : public ConvBNFusePass {
+ public:
+  std::string conv_type() const { return "depthwise_conv2d"; }
+};
+
+class DepthwiseConvEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
+ public:
+  std::string conv_type() const { return "depthwise_conv2d"; }
+};
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/subgraph_detector.cc b/paddle/fluid/framework/ir/subgraph_detector.cc
index 62c91af15da60..7979953d7be82 100644
---
a/paddle/fluid/framework/ir/subgraph_detector.cc +++ b/paddle/fluid/framework/ir/subgraph_detector.cc @@ -309,7 +309,8 @@ std::vector> SubgraphDetector::ExtractSubGraphs() { BriefNode *brief_node = itr.second; if (!Agent(brief_node->node).marked()) { - VLOG(4) << brief_node->node->id() << " node not a trt candidate."; + VLOG(4) << brief_node->node->id() << " node named " + << brief_node->node->Name() << " is not a trt candidate."; continue; } diff --git a/paddle/fluid/framework/prune.cc b/paddle/fluid/framework/prune.cc index 919378c929185..274b0ca0d903d 100644 --- a/paddle/fluid/framework/prune.cc +++ b/paddle/fluid/framework/prune.cc @@ -210,6 +210,23 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, should_run.push_back(true); } else { should_run.push_back(false); + // If the output of an op modifies feed vars, the op should not clip. + // For example, in the transformer structure, the third parameter returned + // by beam_search op is generally assigned to a feed var. Cutting the + // assign op will cause an error. + if (parent_block_id != -1) { + bool flag = false; + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (feed_var_names.count(argu)) { + flag = true; + } + } + } + if (flag) { + should_run.back() = true; + } + } } } diff --git a/paddle/fluid/framework/prune_test.cc b/paddle/fluid/framework/prune_test.cc index eb5c241a8372a..12fa0c61f8121 100644 --- a/paddle/fluid/framework/prune_test.cc +++ b/paddle/fluid/framework/prune_test.cc @@ -185,3 +185,34 @@ TEST(Prune, recurrrent_op) { EXPECT_EQ(pruned.blocks(0).ops_size(), 2); EXPECT_EQ(pruned.blocks(1).ops_size(), 1); } + +// If the output of an op modifies feed vars, the op should not clip. +TEST(Prune, recurrrent_op_2) { + f::ProgramDesc program; + f::BlockDesc *block = program.MutableBlock(0); + f::BlockDesc *sub_block = program.AppendBlock(*block); + AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, + f::AttributeMap{}, block); + + std::vector state_var_name(1, "y"); + AddOp("recurrent", {{"input", {"b", "c"}}}, {{"output", {"b1, c1"}}}, + {{"ex_states", state_var_name}, + {"states", state_var_name}, + {"sub_block", sub_block}}, + block); + + EXPECT_TRUE(sub_block != nullptr); + AddOp("rnn_memory_helper", {{"input", {"x"}}}, {{"output", {"a"}}}, + f::AttributeMap{}, sub_block); + + f::proto::ProgramDesc *pdesc = program.Proto(); + pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true); + + f::proto::ProgramDesc pruned; + std::set feed_var_names = {"x", "a"}; + + f::Prune(*pdesc, feed_var_names, &pruned); + EXPECT_EQ(pruned.blocks_size(), 2); + EXPECT_EQ(pruned.blocks(0).ops_size(), 2); + EXPECT_EQ(pruned.blocks(1).ops_size(), 1); +} diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 82b91d2e77292..0336325bef6bd 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -42,23 +42,17 @@ static void PrepareData(const platform::Place& place, for (const auto& var_base : name_pair.second) { const auto* tensor = GetTensorFromVar(var_base->Var()); if (tensor && tensor->IsInitialized()) { - auto tmp_place = tensor->place(); - - // TODO(jiabin): Support transform data layout when we Verify it on more - // tests - if (!(tmp_place == place)) { - auto kernel_type_for_var = op.GetKernelTypeForVar( - name_pair.first, *tensor, expected_kernel_key); - if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { - continue; - } else { - VLOG(3) << 
"Transform Variable " << var_base->Name() << " from " - << kernel_type_for_var << " to " << expected_kernel_key; - framework::Tensor out; - TransformData(expected_kernel_key, kernel_type_for_var, *tensor, - &out); - SetTensorToVariable(var_base->Var(), out, var_base->MutableVar()); - } + auto kernel_type_for_var = op.GetKernelTypeForVar( + name_pair.first, *tensor, expected_kernel_key); + if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) { + continue; + } else { + VLOG(3) << "Transform Variable " << var_base->Name() << " from " + << kernel_type_for_var << " to " << expected_kernel_key; + framework::Tensor out; + TransformData(expected_kernel_key, kernel_type_for_var, *tensor, + &out); + SetTensorToVariable(var_base->Var(), out, var_base->MutableVar()); } } } @@ -93,6 +87,13 @@ PreparedOp PrepareOpImpl(const NameVarMap& ins, auto& kernels = kernels_iter->second; framework::RuntimeContext ctx({}, {}); +#ifdef PADDLE_WITH_MKLDNN + // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and + // GetKernelType functions, so we need to copy the attributes there. + // Const qualifier of Attrs had to be discarded to overwrite it. + auto& mutable_op_attrs = const_cast(op.Attrs()); + mutable_op_attrs = attrs; +#endif auto expected_kernel_key = op.GetExpectedKernelType(DygraphExecutionContext( op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs)); diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index c2e30b45a7f6c..f226c63f0c432 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -176,7 +176,7 @@ TEST(test_prepare_op, test_prepare_data) { } #endif -TEST(test_prepare_op, test_prepare_data_same_place) { +void TestPrepareDataSamePlace(framework::AttributeMap attr_map) { std::shared_ptr vin( new imperative::VarBase(false, "vin")); std::shared_ptr vout( @@ -198,7 +198,6 @@ TEST(test_prepare_op, test_prepare_data_same_place) { var_pair out_pair = var_pair("Out", vb_vector(1, vout)); imperative::NameVarBaseMap ins = {x_pair}; imperative::NameVarBaseMap outs = {out_pair}; - framework::AttributeMap attr_map; const std::string op_type = "relu"; const auto& info = framework::OpInfoMap::Instance().Get(op_type); if (info.Checker()) info.Checker()->Check(&attr_map); @@ -222,8 +221,21 @@ TEST(test_prepare_op, test_prepare_data_same_place) { } } } + +TEST(test_prepare_op, test_prepare_data_same_place) { + TestPrepareDataSamePlace({}); +} + +#ifdef PADDLE_WITH_MKLDNN +TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) { + TestPrepareDataSamePlace({{"use_mkldnn", true}}); +} +#endif } // namespace imperative } // namespace paddle USE_OP(split); USE_OP(relu); +#ifdef PADDLE_WITH_MKLDNN +USE_OP_DEVICE_KERNEL(relu, MKLDNN); +#endif diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index fdd71b0d88400..1a3413657ce6f 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -83,7 +83,12 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, std::string input, } else if (shape.size() == 3UL) { return nvinfer1::Dims3(shape[0], shape[1], shape[2]); } - return nvinfer1::Dims4(shape[0], shape[1], 1, 1); + nvinfer1::Dims dims; + dims.nbDims = shape.size(); + for (size_t i = 0; i < shape.size(); i++) { + dims.d[i] = shape[i]; + } + return dims; } } } // NOLINT diff --git a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu 
b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu index e7f9381e97137..5e43be90de3db 100644 --- a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu @@ -76,6 +76,16 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions( return ret; } +template +void EmbEltwiseLayernormPluginDynamic::terminate() { + for (auto ptr : embs_gpu_) { + if (ptr) cudaFree(ptr); + } + + if (bias_gpu_) cudaFree(bias_gpu_); + if (scale_gpu_) cudaFree(scale_gpu_); +} + template bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination( int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs, @@ -153,7 +163,7 @@ int EmbEltwiseLayernormPluginDynamic::enqueue( int64_t *emb_ptr_gpu_d = emb_ptr_tensor.mutable_data(platform::CUDAPlace(device_id)); - std::vector in_ptr, emb_ptr; + std::vector in_ptr, emb_ptr; for (int i = 0; i < input_num; i++) { in_ptr.push_back(reinterpret_cast(inputs[i])); emb_ptr.push_back(reinterpret_cast(embs_gpu_[i])); diff --git a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h index 8ac611cd7c62f..5babd87db0602 100644 --- a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h @@ -81,9 +81,13 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT { } nvinfer1::IPluginV2DynamicExt* clone() const override { - return new EmbEltwiseLayernormPluginDynamic( + auto ptr = new EmbEltwiseLayernormPluginDynamic( embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_, hidden_size_, eps_); + ptr->embs_gpu_ = embs_gpu_; + ptr->bias_gpu_ = bias_gpu_; + ptr->scale_gpu_ = scale_gpu_; + return ptr; } const char* getPluginType() const override { @@ -111,6 +115,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT { return sum_num; } + void terminate() override; void serialize(void* buffer) const override { // SerializeValue(&buffer, with_fp16_); SerializeValue(&buffer, emb_sizes_); diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index f1e11b6fba1f1..860f1039d5e10 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -80,6 +80,12 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs, #if IS_TRT_VERSION_GE(6000) +void PReluPluginDynamic::terminate() { + if (p_gpu_weight_) { + cudaFree(p_gpu_weight_); + } +} + int PReluPluginDynamic::initialize() { cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size()); cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float), diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h index 4756ca2e02257..3126366c5fdd8 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h @@ -102,12 +102,15 @@ class PReluPluginDynamic : public DynamicPluginTensorRT { } ~PReluPluginDynamic() { cudaFree(p_gpu_weight_); } nvinfer1::IPluginV2DynamicExt* clone() const override { - return new PReluPluginDynamic(weight_.data(), weight_.size(), mode_); + auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_); + ptr->p_gpu_weight_ = p_gpu_weight_; + return ptr; } const 
char* getPluginType() const override { return "prelu_plugin"; } int getNbOutputs() const override { return 1; } int initialize() override; + void terminate() override; size_t getSerializationSize() const override; void serialize(void* buffer) const override; diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h index 8fe1edc4bf032..24cd8e0368182 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h @@ -51,8 +51,11 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT { } nvinfer1::IPluginV2DynamicExt* clone() const override { - return new SkipLayerNormPluginDynamic( + auto ptr = new SkipLayerNormPluginDynamic( bias_.data(), scale_.data(), bias_size_, scale_size_, eps_, ban_fp16_); + ptr->bias_gpu_ = bias_gpu_; + ptr->scale_gpu_ = bias_gpu_; + return ptr; } const char* getPluginType() const override { return "skip_layernorm_plugin"; } diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 959ba2288acc0..9a3a73f6c946d 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -471,19 +471,10 @@ if(WITH_GPU AND TENSORRT_FOUND) inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz") endif() - inference_analysis_test(test_trt_dynamic_shape_ernie_serialize SRCS trt_dynamic_shape_ernie_deserialize_test.cc + inference_analysis_test(test_trt_dynamic_shape_ernie_ser_deser SRCS trt_dynamic_shape_ernie_deserialize_test.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_unserialized) - set(TEST_TRT_ERNIE_SER_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test/ernie_model_4_serialized/") - if (NOT EXISTS ${TEST_TRT_ERNIE_SER_MODEL}) - inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_serialized.tgz") - endif() - - inference_analysis_test(test_trt_dynamic_shape_ernie_deserialize SRCS trt_dynamic_shape_ernie_deserialize_test.cc - EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4_serialized) - endif() set(LITE_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lite") diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc index 6526b87436557..7e5dfa2424dbc 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc @@ -123,8 +123,11 @@ void trt_ernie(bool with_fp16, std::vector result) { config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false); config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, opt_input_shape); + AnalysisConfig* config_deser = new AnalysisConfig(config); + std::vector out_data; - run(config, &out_data); + run(config, &out_data); // serialize + run(*config_deser, &out_data); // deserialize for (size_t i = 0; i < out_data.size(); i++) { EXPECT_NEAR(result[i], out_data[i], 1e-6); } diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc index babe9977cd571..c99ebcdcb5f31 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc 
+++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc @@ -126,7 +126,7 @@ void trt_ernie(bool with_fp16, std::vector result) { std::vector out_data; run(config, &out_data); for (size_t i = 0; i < out_data.size(); i++) { - EXPECT_NEAR(result[i], out_data[i], 1e-6); + EXPECT_NEAR(result[i], out_data[i], 1e-5); } } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e74f363d886e4..48d1ec9461a88 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -88,7 +88,9 @@ endif() cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEPS operator) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor device_memory_aligment) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows +lod_tensor maxouting unpooling pooling lod_rank_table context_project +sequence_pooling executor device_memory_aligment generator) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse) diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index a8c4107add1be..9ed169fe3502e 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -196,7 +196,7 @@ framework::OpKernelType ConvOp::GetKernelTypeForVar( auto ar = paddle::framework::AttrReader(attrs); const std::string data_format = ar.Get("data_format"); auto dl = framework::StringToDataLayout(data_format); - // Some models may have intentionally set "AnyLayout" for pool + // Some models may have intentionally set "AnyLayout" for conv // op. 
Treat this as NCHW (default data_format value) if (dl != framework::DataLayout::kAnyLayout) { return framework::OpKernelType(expected_kernel_type.data_type_, diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index 16e2ca464b5c4..7081490fd1bf0 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -24,34 +24,62 @@ class CudnnLSTMOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("W"), - "Input(Weight) of LSTM should not be null."); - - PADDLE_ENFORCE(ctx->HasInput("InitH"), - "Input(init_h) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("InitC"), - "Input(init_c) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Cache"), - "Input(Cache) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("last_h"), - "Output(last_h) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("last_c"), - "Output(last_c) of LSTM should not be null."); + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM"); + + OP_INOUT_CHECK(ctx->HasOutput("Reserve"), "Output", "Reserve", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasOutput("StateOut"), "Output", "StateOut", + "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasOutput("LastH"), "Output", "LastH", "CudnnLSTM"); + OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM"); auto in_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_EQ(in_dims.size(), 3, "Input(X)'s rank must be 3."); + auto init_dims = ctx->GetInputDim("InitH"); + PADDLE_ENFORCE_EQ(in_dims.size(), 3, + platform::errors::InvalidArgument( + "The rank of Input in CudnnLSTM must be 3. But " + "received Input's rank is %d.", + in_dims.size())); + PADDLE_ENFORCE_EQ(init_dims.size(), 3, + platform::errors::InvalidArgument( + "The rank of InitH in CudnnLSTM must be 3. But " + "received InitH's rank is %d.", + init_dims.size())); + + PADDLE_ENFORCE_EQ(in_dims[1], init_dims[1], + platform::errors::InvalidArgument( + "The in_dims[1] (Input dims) and init_dims[1] (InitH " + "dims) should be equal. But " + "received in_dims[1] is %d and init_dims[1] is %d.", + in_dims[1], init_dims[1])); + PADDLE_ENFORCE_EQ(in_dims[2], init_dims[2], + platform::errors::InvalidArgument( + "The in_dims[2] (Input dims) and init_dims[2] (InitH " + "dims) should be equal. But " + "received in_dims[2] is %d and init_dims[2] is %d.", + in_dims[2], init_dims[2])); auto out_dims = in_dims; auto hidden_size = ctx->Attrs().Get("hidden_size"); - out_dims[2] = hidden_size; + bool is_bidirec = ctx->Attrs().Get("is_bidirec"); + out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size; + auto last_dims = init_dims; + last_dims[0] = is_bidirec ? 
last_dims[0] * 2 : last_dims[0]; ctx->SetOutputDim("Out", out_dims); - ctx->SetOutputDim("last_h", ctx->GetInputDim("InitH")); - ctx->SetOutputDim("last_c", ctx->GetInputDim("InitC")); + ctx->SetOutputDim("LastH", last_dims); + ctx->SetOutputDim("LastC", last_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), + ctx.device_context()); } }; @@ -84,33 +112,31 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { "(Tensor) the learnable hidden-hidden weights." " The shape is (N), where N is total weight size of the LSTM. " " cudnn concatenate all the weight to one Tensor"); - AddInput("Cache", - "The cache of dropout op, a RAW type variable including random " - "number generator states and some descriptors, which is used in " - "cudnn kernel.") - .AsDispensable(); + AddOutput("Reserve", + "(Tensor, a temporary output Tensor to store the reserve_data " + "of cudnn kernel.") + .AsIntermediate(); + AddOutput("StateOut", + "Share memory with State. " + "Store the global drop state when training"); AddOutput("Out", "(Tensor) the hidden state of LSTM operator. " "The shape is ( seq_len x batch_size x hidden_size) if " "is_bidirec is False" "and When is_bidirec is True, the shape will be ( seq_len x " "batch_size x hidden_size * 2) "); - AddOutput("last_h", + AddOutput("LastH", "(Tensor) the hidden state of the last step. " "The shape is ( num_layers x batch_size x hidden_size) if " "is_bidirec is False" "and When is_bidirec is True, the shape will be (num_layers*2 x " "batch_size x hidden_size)"); - AddOutput("last_c", + AddOutput("LastC", "(Tensor) the cell state of the last step" "The shape is ( num_layers x batch_size x hidden_size) if " "is_bidirec is False" "and When is_bidirect is True, the shape will be (num_layers*2 x " "batch_size x hidden_size*2)"); - AddAttr("max_len", - "max length of the LSTM op" - "the first dim of the Input can NOT be greater than max_len") - .SetDefault(20); AddAttr( "dropout_prob", "dropout prob of the dropout op" @@ -120,14 +146,14 @@ class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("is_bidirec", "is_bidirec" "if it is bidirectional rnn" - "The will affect the shape of the Out, last_h, and last_c") + "The will affect the shape of the Out, LastH, and LastC") .SetDefault(false); AddAttr("input_size", "input size ot the Input Tensor").SetDefault(10); AddAttr("hidden_size", "hidden size of the LSTM").SetDefault(100); AddAttr("num_layers", "the total layer number of the LSTM") .SetDefault(1); AddAttr("is_test", "True if in test phase.").SetDefault(false); - AddAttr("seed", "seed to used if fix_seed is True").SetDefault(-1); + AddAttr("seed", "seed to used if fix_seed is True").SetDefault(0); AddComment(R"DOC( CUDNN LSTM implementation @@ -172,16 +198,10 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input(Input) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Cache"), - "Input(last_c) of LSTM should not be null."); - PADDLE_ENFORCE(ctx->HasInput("InitH"), - "Input(init_h) of LSTM should not be null."); - - PADDLE_ENFORCE(ctx->HasInput("InitC"), - "Input(init_c) of LSTM should not 
be null."); + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad"); + OP_INOUT_CHECK(ctx->HasInput("W"), "Input", "W", "CudnnLSTMGrad"); + OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad"); + OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad"); auto SetOutGradDim = [&ctx](const std::string& name) { auto g_name = framework::GradVarName(name); @@ -195,6 +215,12 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel { SetOutGradDim("InitH"); SetOutGradDim("InitC"); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } }; template @@ -209,13 +235,12 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker { op->SetInput("InitH", this->Input("InitH")); op->SetInput("InitC", this->Input("InitC")); op->SetInput("W", this->Input("W")); - if (this->HasInput("Cache")) { - op->SetInput("Cache", this->Input("Cache")); - } + op->SetInput("Reserve", this->Output("Reserve")); + op->SetInput("StateOut", this->Output("StateOut")); op->SetInput("Out", this->Output("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - op->SetInput(framework::GradVarName("last_c"), this->OutputGrad("last_c")); - op->SetInput(framework::GradVarName("last_h"), this->OutputGrad("last_h")); + op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC")); + op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH")); op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); op->SetOutput(framework::GradVarName("W"), this->InputGrad("W")); diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc index 579dddee8e821..37e5e518ea2af 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/cudnn_rnn_cache.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cudnn_desc.h" namespace paddle { namespace operators { @@ -33,8 +34,10 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { auto w = ctx.Input("W"); Tensor *out = ctx.Output("Out"); - Tensor *last_h = ctx.Output("last_h"); - Tensor *last_c = ctx.Output("last_c"); + Tensor *last_h = ctx.Output("LastH"); + Tensor *last_c = ctx.Output("LastC"); + Tensor *reserve = ctx.Output("Reserve"); + Tensor *state_out = ctx.Output("StateOut"); const T *x_data = x->data(); const T *init_h_data = init_h->data(); @@ -46,72 +49,56 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { T *last_h_data = last_h->mutable_data(ctx.GetPlace()); T *last_c_data = last_c->mutable_data(ctx.GetPlace()); - size_t max_len = ctx.Attr("max_len"); float dropout_prob = ctx.Attr("dropout_prob"); bool is_bidirec = ctx.Attr("is_bidirec"); - int input_size = ctx.Attr("input_size"); int hidden_size = ctx.Attr("hidden_size"); int num_layers = ctx.Attr("num_layers"); bool is_test = ctx.Attr("is_test"); + int seed = ctx.Attr("seed"); auto &dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - auto *cache_var = ctx.InputVar("Cache"); - if (!cache_var) { - // The RAW type cache variable wouldn't be created and broadcasted on - // multi-devices before the first running. 
- // use parent scope to make cache persistable - auto *scope = const_cast(ctx.scope().parent()); - auto cache_var_name = ctx.InputNames("Cache")[0]; - cache_var = scope->Var(cache_var_name); - } - CudnnRNNCache *cudnn_rnn_cache = nullptr; - if (cache_var->IsInitialized()) { - // const_cast is usually bad. - cudnn_rnn_cache = const_cast(cache_var) - ->GetMutable(); - } else { - // const_cast is usually bad. - cudnn_rnn_cache = const_cast(cache_var) - ->GetMutable(); - std::random_device rnd; - int seed = ctx.Attr("seed"); - if (seed == -1) { - seed = rnd(); - } - - auto input_w_numel = w->numel(); - auto batch_size = x->dims()[1]; - cudnn_rnn_cache->init(handle, ctx.GetPlace(), max_len, batch_size, - input_size, hidden_size, num_layers, dropout_prob, - is_bidirec, seed, input_w_numel); - } - auto run_seq_len = x->dims()[0]; + CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache(); + + auto input_w_numel = w->numel(); + auto seq_len = x->dims()[0]; + auto batch_size = x->dims()[1]; + auto input_dim = x->dims()[2]; + size_t reserve_size; + bool state_initialized = state_out->IsInitialized() ? true : false; + cudnnDataType_t cudnn_type = platform::ToCudnnDataType( + framework::ToDataType(std::type_index(typeid(T)))); + cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size, + input_dim, hidden_size, num_layers, dropout_prob, + is_bidirec, seed, input_w_numel, &reserve_size, + state_out, state_initialized, cudnn_type); + + auto *reserve_data = reserve->mutable_data( + {static_cast(reserve_size)}, ctx.GetPlace()); if (is_test) { // for inference PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardInference( - handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, - cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_, - init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data, - cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data, - cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_, - last_c_data, cudnn_rnn_cache->workspace_data_.data(), + handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_, + x_data, cudnn_rnn_cache->hx_desc_, init_h_data, + cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_, + w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_, + last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data, + cudnn_rnn_cache->workspace_data_.data(), cudnn_rnn_cache->workspace_size_)); } else { // for train PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNForwardTraining( - handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, - cudnn_rnn_cache->x_desc_, x_data, cudnn_rnn_cache->hx_desc_, - init_h_data, cudnn_rnn_cache->cx_desc_, init_c_data, - cudnn_rnn_cache->w_desc_, w_data, cudnn_rnn_cache->y_desc_, out_data, - cudnn_rnn_cache->hy_desc_, last_h_data, cudnn_rnn_cache->cy_desc_, - last_c_data, cudnn_rnn_cache->workspace_data_.data(), - cudnn_rnn_cache->workspace_size_, - cudnn_rnn_cache->reserve_data_.data(), - cudnn_rnn_cache->reserve_size_)); + handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_, + x_data, cudnn_rnn_cache->hx_desc_, init_h_data, + cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->w_desc_, + w_data, cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->hy_desc_, + last_h_data, cudnn_rnn_cache->cy_desc_, last_c_data, + cudnn_rnn_cache->workspace_data_.data(), + cudnn_rnn_cache->workspace_size_, reserve_data, reserve_size)); } + delete cudnn_rnn_cache; } }; @@ -123,15 +110,13 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { auto *weight = 
ctx.Input("W"); auto *init_h = ctx.Input("InitH"); auto *init_c = ctx.Input("InitC"); - // auto * last_h = ctx.Input("last_h"); - // auto * last_c = ctx.Input("last_c"); + auto *reserve = ctx.Input("Reserve"); + auto *state_out = ctx.Input("StateOut"); + auto *out = ctx.Input("Out"); auto *out_grad = ctx.Input(framework::GradVarName("Out")); - auto *last_h_grad = ctx.Input(framework::GradVarName("last_h")); - auto *last_c_grad = ctx.Input(framework::GradVarName("last_c")); - - // auto* init_h = ctx.Input("init_h"); - // auto* init_c = ctx.Input("init_c"); + auto *last_h_grad = ctx.Input(framework::GradVarName("LastH")); + auto *last_c_grad = ctx.Input(framework::GradVarName("LastC")); auto *in_grad = ctx.Output(framework::GradVarName("Input")); auto *weight_grad = ctx.Output(framework::GradVarName("W")); @@ -140,116 +125,75 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { auto &dev_ctx = ctx.template device_context(); auto handle = dev_ctx.cudnn_handle(); - auto *cache_var = ctx.InputVar("Cache"); - PADDLE_ENFORCE(cache_var->IsInitialized()); - CudnnRNNCache *cudnn_rnn_cache = - const_cast(cache_var) - ->GetMutable(); auto input_dims = input->dims(); auto init_h_dims = init_h->dims(); auto init_c_dims = init_c->dims(); - in_grad->mutable_data(ctx.GetPlace()); - weight_grad->mutable_data(ctx.GetPlace()); - math::SetConstant zero; - zero(dev_ctx, in_grad, static_cast(0.0)); - zero(dev_ctx, weight_grad, static_cast(0.0)); - - T *init_h_grad_data = NULL; - if (init_h_grad == nullptr) { - Tensor init_h_grad_temp; - init_h_grad_temp.mutable_data(init_h_dims, ctx.GetPlace()); - zero(dev_ctx, &init_h_grad_temp, static_cast(0.0)); - - init_h_grad_data = init_h_grad_temp.data(); - } else { - init_h_grad->mutable_data(init_h_dims, ctx.GetPlace()); - zero(dev_ctx, init_h_grad, static_cast(0.0)); - init_h_grad_data = init_h_grad->data(); - } - - T *init_c_grad_data = NULL; - if (init_c_grad == nullptr) { - Tensor init_c_grad_temp; - init_c_grad_temp.mutable_data(init_c_dims, ctx.GetPlace()); - zero(dev_ctx, &init_c_grad_temp, static_cast(0.0)); - init_c_grad_data = init_c_grad_temp.data(); - } else { - init_c_grad->mutable_data(init_c_dims, ctx.GetPlace()); - zero(dev_ctx, init_c_grad, static_cast(0.0)); - init_c_grad_data = init_c_grad->data(); - } + auto *weight_data = weight->data(); + auto *init_h_data = init_h->data(); + auto *init_c_data = init_c->data(); + auto *out_data = out->data(); + auto *out_grad_data = out_grad->data(); + auto *last_h_grad_data = last_h_grad->data(); + auto *last_c_grad_data = last_c_grad->data(); - const T *last_h_grad_data = NULL; - if (last_h_grad == nullptr) { - Tensor last_h_grad_temp; - last_h_grad_temp.mutable_data(init_h_dims, ctx.GetPlace()); - zero(dev_ctx, &last_h_grad_temp, static_cast(0.0)); - - last_h_grad_data = (const T *)last_h_grad_temp.data(); - } else { - last_h_grad_data = last_h_grad->data(); - } - - const T *last_c_grad_data = NULL; - if (last_c_grad == nullptr) { - Tensor last_c_grad_temp; - last_c_grad_temp.mutable_data(init_c_dims, ctx.GetPlace()); - zero(dev_ctx, &last_c_grad_temp, static_cast(0.0)); - - last_c_grad_data = (const T *)last_c_grad_temp.data(); - } else { - last_c_grad_data = last_c_grad->data(); - } + math::SetConstant zero; + weight_grad->mutable_data(ctx.GetPlace()); + zero(dev_ctx, weight_grad, static_cast(0.0)); - const T *out_grad_data = NULL; - if (out_grad == nullptr) { - Tensor out_grad_temp; - out_grad_temp.mutable_data(out->dims(), ctx.GetPlace()); - zero(dev_ctx, &out_grad_temp, static_cast(0.0)); + 
in_grad->mutable_data(input_dims, ctx.GetPlace()); + auto *in_grad_data = in_grad->data(); - out_grad_data = (const T *)out_grad_temp.data(); - } else { - out_grad_data = out_grad->data(); - } + init_h_grad->mutable_data(init_h_dims, ctx.GetPlace()); + auto *init_h_grad_data = init_h_grad->data(); - // zero( dev_ctx, last_h_grad, static_cast(0.0)); - // zero( dev_ctx, last_c_grad, static_cast(0.0)); + init_c_grad->mutable_data(init_c_dims, ctx.GetPlace()); + auto *init_c_grad_data = init_c_grad->data(); - auto out_data = out->data(); - // auto out_grad_data = out_grad->data(); - auto weight_data = weight->data(); - auto init_h_data = init_h->data(); - auto init_c_data = init_c->data(); - auto in_grad_data = in_grad->data(); + float dropout_prob = ctx.Attr("dropout_prob"); + bool is_bidirec = ctx.Attr("is_bidirec"); + int hidden_size = ctx.Attr("hidden_size"); + int num_layers = ctx.Attr("num_layers"); + int seed = ctx.Attr("seed"); + + CudnnRNNCache *cudnn_rnn_cache = new CudnnRNNCache(); + + auto input_w_numel = weight->numel(); + auto seq_len = input_dims[0]; + auto batch_size = input->dims()[1]; + auto input_dim = input->dims()[2]; + size_t reserve_size; + cudnnDataType_t cudnn_type = platform::ToCudnnDataType( + framework::ToDataType(std::type_index(typeid(T)))); + cudnn_rnn_cache->init(handle, ctx.GetPlace(), seq_len, batch_size, + input_dim, hidden_size, num_layers, dropout_prob, + is_bidirec, seed, input_w_numel, &reserve_size, + const_cast(state_out), true, cudnn_type); auto work_data = cudnn_rnn_cache->workspace_data_.data(); - auto reserve_data = cudnn_rnn_cache->reserve_data_.data(); + const uint8_t *reserve_data = reserve->data(); - auto run_seq_len = input_dims[0]; - PADDLE_ENFORCE_LE((size_t)run_seq_len, cudnn_rnn_cache->max_length_, - "cudnn running seq_len CAN not greater max_lengh"); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardData( - handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, - cudnn_rnn_cache->y_desc_, out_data, cudnn_rnn_cache->dy_desc_, - out_grad_data, cudnn_rnn_cache->dhy_desc_, last_h_grad_data, - cudnn_rnn_cache->dcy_desc_, last_c_grad_data, cudnn_rnn_cache->w_desc_, - weight_data, cudnn_rnn_cache->hx_desc_, init_h_data, - cudnn_rnn_cache->cx_desc_, init_c_data, cudnn_rnn_cache->dx_desc_, - in_grad_data, cudnn_rnn_cache->dhx_desc_, init_h_grad_data, - cudnn_rnn_cache->dcx_desc_, init_c_grad_data, work_data, - cudnn_rnn_cache->workspace_size_, reserve_data, - cudnn_rnn_cache->reserve_size_)); + handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->y_desc_, + out_data, cudnn_rnn_cache->y_desc_, out_grad_data, + cudnn_rnn_cache->hy_desc_, last_h_grad_data, cudnn_rnn_cache->cy_desc_, + last_c_grad_data, cudnn_rnn_cache->w_desc_, weight_data, + cudnn_rnn_cache->hx_desc_, init_h_data, cudnn_rnn_cache->cx_desc_, + init_c_data, cudnn_rnn_cache->x_desc_, in_grad_data, + cudnn_rnn_cache->hx_desc_, init_h_grad_data, cudnn_rnn_cache->cx_desc_, + init_c_grad_data, work_data, cudnn_rnn_cache->workspace_size_, + const_cast(reserve_data), reserve_size)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnRNNBackwardWeights( - handle, cudnn_rnn_cache->rnn_desc_, run_seq_len, - cudnn_rnn_cache->x_desc_, input->data(), cudnn_rnn_cache->hx_desc_, - init_h->data(), cudnn_rnn_cache->y_desc_, out->data(), + handle, cudnn_rnn_cache->rnn_desc_, seq_len, cudnn_rnn_cache->x_desc_, + input->data(), cudnn_rnn_cache->hx_desc_, init_h->data(), + cudnn_rnn_cache->y_desc_, out->data(), cudnn_rnn_cache->workspace_data_.data(), - cudnn_rnn_cache->workspace_size_, 
cudnn_rnn_cache->dw_desc_, - weight_grad->data(), cudnn_rnn_cache->reserve_data_.data(), - cudnn_rnn_cache->reserve_size_)); + cudnn_rnn_cache->workspace_size_, cudnn_rnn_cache->w_desc_, + weight_grad->data(), const_cast(reserve_data), + reserve_size)); + delete cudnn_rnn_cache; } }; @@ -257,5 +201,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel); -REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel); +REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel, + ops::CudnnLSTMGPUKernel); +REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel, + ops::CudnnLSTMGPUGradKernel); diff --git a/paddle/fluid/operators/cudnn_rnn_cache.h b/paddle/fluid/operators/cudnn_rnn_cache.h index cd33338abc622..13a3e7d09b9f6 100644 --- a/paddle/fluid/operators/cudnn_rnn_cache.h +++ b/paddle/fluid/operators/cudnn_rnn_cache.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/cudnn_helper.h" @@ -24,16 +25,12 @@ struct CudnnRNNCache { CudnnRNNCache() { x_desc_ = NULL; y_desc_ = NULL; - dx_desc_ = NULL; - dy_desc_ = NULL; } ~CudnnRNNCache() { release(); } cudnnRNNDescriptor_t rnn_desc_; cudnnTensorDescriptor_t *x_desc_; cudnnTensorDescriptor_t *y_desc_; - cudnnTensorDescriptor_t *dx_desc_; - cudnnTensorDescriptor_t *dy_desc_; cudnnTensorDescriptor_t hx_desc_; cudnnTensorDescriptor_t cx_desc_; @@ -55,13 +52,9 @@ struct CudnnRNNCache { cudnnFilterDescriptor_t dw_desc_; size_t workspace_size_; - size_t reserve_size_; - framework::Tensor reserve_data_; framework::Tensor workspace_data_; - framework::Tensor dropout_state_; - - size_t max_length_; + size_t seq_length_; float dropout_prob_; bool is_bidirec_; @@ -72,10 +65,12 @@ struct CudnnRNNCache { int num_layers_; int seed_; - void init(cudnnHandle_t handle, const platform::Place &place, size_t max_len, + void init(cudnnHandle_t handle, const platform::Place &place, size_t seq_len, int batch_size, int input_size, int hidden_size, int num_layers, - float dropout_prob, bool is_bidirec, int seed, int weight_numel) { - max_length_ = max_len; + float dropout_prob, bool is_bidirec, int seed, int weight_numel, + size_t *reserve_size_, framework::Tensor *dropout_state_, + bool initialized, cudnnDataType_t cudnn_type) { + seq_length_ = seq_len; batch_size_ = batch_size; input_size_ = input_size; hidden_size_ = hidden_size; @@ -84,55 +79,34 @@ struct CudnnRNNCache { is_bidirec_ = is_bidirec; seed_ = seed; - x_desc_ = new cudnnTensorDescriptor_t[max_length_]; - y_desc_ = new cudnnTensorDescriptor_t[max_length_]; - dx_desc_ = new cudnnTensorDescriptor_t[max_length_]; - dy_desc_ = new cudnnTensorDescriptor_t[max_length_]; - int dim_a[3]; - int stride_a[3]; + const auto numDirections = is_bidirec_ ? 2 : 1; + auto cudnn_size = + cudnn_type == CUDNN_DATA_FLOAT ? 
sizeof(float) : sizeof(double); + + x_desc_ = new cudnnTensorDescriptor_t[seq_length_]; + y_desc_ = new cudnnTensorDescriptor_t[seq_length_]; + std::vector dims = {batch_size_, input_size_, 1}; + std::vector strides = {input_size_, 1, 1}; + + std::vector dims_y = {batch_size_, hidden_size_ * numDirections, 1}; + std::vector strides_y = {hidden_size_ * numDirections, 1, 1}; - for (size_t i = 0; i < max_length_; ++i) { + for (size_t i = 0; i < seq_length_; ++i) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateTensorDescriptor(&x_desc_[i])); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateTensorDescriptor(&y_desc_[i])); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&dx_desc_[i])); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnCreateTensorDescriptor(&dy_desc_[i])); - dim_a[0] = batch_size_; - dim_a[1] = input_size_; - dim_a[2] = 1; - - stride_a[0] = dim_a[2] * dim_a[1]; - stride_a[1] = dim_a[2]; - stride_a[2] = 1; - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - x_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dx_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); - - dim_a[0] = batch_size_; - dim_a[1] = is_bidirec_ ? hidden_size_ * 2 : hidden_size_; - dim_a[2] = 1; - - stride_a[0] = dim_a[2] * dim_a[1]; - stride_a[1] = dim_a[2]; - stride_a[2] = 1; PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - y_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + x_desc_[i], cudnn_type, 3, dims.data(), strides.data())); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dy_desc_[i], CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + y_desc_[i], cudnn_type, 3, dims_y.data(), strides_y.data())); } - dim_a[0] = num_layers_ * (is_bidirec_ ? 
2 : 1); - dim_a[1] = batch_size_; - dim_a[2] = hidden_size_; - - stride_a[0] = dim_a[2] * dim_a[1]; - stride_a[1] = dim_a[2]; - stride_a[2] = 1; + std::vector dims_hx = {num_layers_ * numDirections, batch_size_, + hidden_size_}; + std::vector strides_hx = {hidden_size_ * batch_size_, hidden_size_, 1}; PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateTensorDescriptor(&hx_desc_)); @@ -152,33 +126,44 @@ struct CudnnRNNCache { platform::dynload::cudnnCreateTensorDescriptor(&dcy_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - hx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + hx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - cx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + cx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - hy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + hy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - cy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + cy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dhx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + dhx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dcx_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + dcx_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dhy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + dhy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( - dcy_desc_, CUDNN_DATA_FLOAT, 3, dim_a, stride_a)); + dcy_desc_, cudnn_type, 3, dims_hx.data(), strides_hx.data())); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateDropoutDescriptor(&dropout_desc_)); size_t state_size; - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); - dropout_state_.Resize({static_cast(state_size)}); - auto *dropout_state_data = dropout_state_.mutable_data(place); - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetDropoutDescriptor( - dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size, - seed_)); + if (!initialized) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnDropoutGetStatesSize(handle, &state_size)); + dropout_state_->Resize({static_cast(state_size)}); + uint8_t *dropout_state_data = + dropout_state_->mutable_data(place); + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetDropoutDescriptor( + dropout_desc_, handle, dropout_prob_, dropout_state_data, state_size, + seed_)); + } else { + uint8_t *dropout_state_data = dropout_state_->data(); + auto dropout_state_dims = dropout_state_->dims(); + state_size = dropout_state_dims[0]; + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cudnnRestoreDropoutDescriptor( + dropout_desc_, handle, dropout_prob_, dropout_state_data, + state_size, 0)); + } PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_)); @@ -188,12 +173,12 @@ struct CudnnRNNCache { handle, rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT, is_bidirec_ ? 
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, - CUDNN_RNN_ALGO_STANDARD, CUDNN_DATA_FLOAT)); + CUDNN_RNN_ALGO_STANDARD, cudnn_type)); #else PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetRNNDescriptor( rnn_desc_, hidden_size_, num_layers_, dropout_desc_, CUDNN_LINEAR_INPUT, is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, CUDNN_LSTM, - CUDNN_DATA_FLOAT)); + cudnn_type)); #endif PADDLE_ENFORCE_CUDA_SUCCESS( @@ -202,48 +187,42 @@ struct CudnnRNNCache { platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNParamsSize( - handle, rnn_desc_, x_desc_[0], &weights_size_, CUDNN_DATA_FLOAT)); + handle, rnn_desc_, x_desc_[0], &weights_size_, cudnn_type)); + + PADDLE_ENFORCE_EQ( + weights_size_, cudnn_size * weight_numel, + platform::errors::InvalidArgument( + "The cudnn lstm and setting weight size should be same.")); - PADDLE_ENFORCE_EQ(weights_size_, sizeof(float) * weight_numel, - "cudnn lstm weight size should be SAME"); int dim_w[3]; - dim_w[0] = weights_size_ / sizeof(float); + dim_w[0] = weights_size_ / cudnn_size; dim_w[1] = 1; dim_w[2] = 1; PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor( - w_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w)); + w_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor( - dw_desc_, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dim_w)); + dw_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize( - handle, rnn_desc_, max_length_, x_desc_, &workspace_size_)); + handle, rnn_desc_, seq_length_, x_desc_, &workspace_size_)); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnGetRNNTrainingReserveSize( - handle, rnn_desc_, max_length_, x_desc_, &reserve_size_)); - - reserve_data_.Resize({static_cast(reserve_size_)}); - reserve_data_.mutable_data(place); + handle, rnn_desc_, seq_length_, x_desc_, reserve_size_)); workspace_data_.Resize({static_cast(workspace_size_)}); workspace_data_.mutable_data(place); } void release() { - for (size_t i = 0; i < max_length_; ++i) { + for (size_t i = 0; i < seq_length_; ++i) { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i])); PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(y_desc_[i])); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(dx_desc_[i])); - PADDLE_ENFORCE_CUDA_SUCCESS( - platform::dynload::cudnnDestroyTensorDescriptor(dy_desc_[i])); } delete[] x_desc_; delete[] y_desc_; - delete[] dx_desc_; - delete[] dy_desc_; PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_)); diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc new file mode 100644 index 0000000000000..495b640bb4399 --- /dev/null +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -0,0 +1,150 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/expand_as_v2_op.h" +#include +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class ExpandAsV2Op : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandAsV2"); + OP_INOUT_CHECK(ctx->HasInput("target_tensor"), "Input", "target_tensor", + "ExpandAsV2"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExpandAsV2"); + auto x_dims = ctx->GetInputDim("X"); + auto target_tensor_dims = ctx->GetInputDim("target_tensor"); + PADDLE_ENFORCE_GE( + target_tensor_dims.size(), static_cast(x_dims.size()), + platform::errors::InvalidArgument( + "The rank of Input(target_tensor) must be greater than or equal " + "to the rank of Input(X). But received Input(X): input " + "rank %u, input shape [%s]; received Input(target_tensor): " + "input rank %u, input shape [%s].", + x_dims.size(), x_dims, target_tensor_dims.size(), + target_tensor_dims)); + PADDLE_ENFORCE_LE( + target_tensor_dims.size(), MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The rank of Input(target_tensor) must be less than or equal " + "to %d. But received: input rank %u, input shape [%s].", + MAX_RANK_SUPPORTED, target_tensor_dims.size(), target_tensor_dims)); + std::vector out_shape(target_tensor_dims.size()); + ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); + } +}; + +class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor). A tensor with rank in [1, 6]. " + "X is the input to be expanded."); + AddOutput("Out", + "(Tensor, default Tensor). A tensor with rank in [1, 6]. " + "The rank of Output(Out) is the same as that of Input(target_tensor). " + "After expanding, the size of each dimension of Output(Out) is equal " + "to the size of the corresponding dimension of Input(target_tensor)."); + AddInput("target_tensor", "The tensor whose shape Input(X) will be expanded to match."); + AddComment(R"DOC( +Expand the input tensor so that its shape matches the shape of the given +'target_tensor'. The rank of X should be in [1, 6], and the rank of +'target_tensor' must not be less than the rank of X.
The following is an example: +Input(X) is a 3-D tensor with shape [2, 3, 1]: + [ + [[1], [2], [3]], + [[4], [5], [6]] + ] +target_tensor's shape: [2, 3, 2] +Output(Out) is a 3-D tensor with shape [2, 3, 2]: + [ + [[1, 1], [2, 2], [3, 3]], + [[4, 4], [5, 5], [6, 6]] + ] +)DOC"); + } +}; + +class ExpandAsV2GradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandAsV2Grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "ExpandAsV2Grad"); + + auto x_dims = ctx->GetInputDim("X"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +template +class ExpandAsV2GradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("expand_as_v2_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("target_tensor", this->Input("target_tensor")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_NO_NEED_BUFFER_VARS_INFERER(ExpandAsV2GradNoNeedBufVarsInferer, "X"); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(expand_as_v2, ops::ExpandAsV2Op, ops::ExpandAsV2OpMaker, + ops::ExpandAsV2GradOpMaker, + ops::ExpandAsV2GradOpMaker); +REGISTER_OPERATOR(expand_as_v2_grad, ops::ExpandAsV2GradOp, + ops::ExpandAsV2GradNoNeedBufVarsInferer); +REGISTER_OP_CPU_KERNEL( + expand_as_v2, + ops::ExpandAsV2Kernel, + ops::ExpandAsV2Kernel, + ops::ExpandAsV2Kernel, + ops::ExpandAsV2Kernel, + ops::ExpandAsV2Kernel); +REGISTER_OP_CPU_KERNEL( + expand_as_v2_grad, + ops::ExpandAsV2GradKernel, + ops::ExpandAsV2GradKernel, + ops::ExpandAsV2GradKernel, + ops::ExpandAsV2GradKernel); diff --git a/paddle/fluid/operators/expand_as_v2_op.cu b/paddle/fluid/operators/expand_as_v2_op.cu new file mode 100644 index 0000000000000..e315144472dd9 --- /dev/null +++ b/paddle/fluid/operators/expand_as_v2_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
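Before the CUDA registration file continues, a brief illustration of the shape rule described in the DOC above may help. The following is a minimal, self-contained sketch (plain standard C++, not Paddle code; the helper name `ComputeRepeatTimes` is invented for illustration) of how ExpandAsV2Kernel derives per-dimension repeat counts: the input shape is left-padded with 1s up to the target rank, non-singleton dimensions must match the target exactly, and singleton dimensions are repeated up to the target size.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the loop in ExpandAsV2Kernel::ExpandAs:
// pad the input shape, check non-singleton dims, return repeat counts.
std::vector<int64_t> ComputeRepeatTimes(std::vector<int64_t> in_dims,
                                        const std::vector<int64_t>& target_dims) {
  // Pad the input rank up to the target rank with leading 1s.
  in_dims.insert(in_dims.begin(), target_dims.size() - in_dims.size(), 1);
  std::vector<int64_t> repeat_times(in_dims.size(), 1);
  for (size_t i = 0; i < in_dims.size(); ++i) {
    assert(target_dims[i] != 0 && "target shape must not contain zeros");
    if (in_dims[i] != 1) {
      // Non-singleton dimensions must match the target exactly.
      assert(in_dims[i] == target_dims[i]);
      repeat_times[i] = 1;
    } else {
      // Singleton dimensions are broadcast up to the target size.
      repeat_times[i] = target_dims[i];
    }
  }
  return repeat_times;
}

int main() {
  // X has shape [3, 1]; target_tensor has shape [2, 3, 4].
  auto times = ComputeRepeatTimes({3, 1}, {2, 3, 4});
  for (auto t : times) std::cout << t << ' ';  // prints: 2 1 4
  std::cout << '\n';
  return 0;
}
```

In the kernel these counts are then handed to Eigen's `broadcast` as `bcast_dims`.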
*/ +#include "paddle/fluid/operators/expand_as_v2_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + expand_as_v2, + ops::ExpandAsV2Kernel, + ops::ExpandAsV2Kernel, + ops::ExpandAsV2Kernel, + ops::ExpandAsV2Kernel, + ops::ExpandAsV2Kernel); +REGISTER_OP_CUDA_KERNEL( + expand_as_v2_grad, + ops::ExpandAsV2GradKernel, + ops::ExpandAsV2GradKernel, + ops::ExpandAsV2GradKernel, + ops::ExpandAsV2GradKernel); diff --git a/paddle/fluid/operators/expand_as_v2_op.h b/paddle/fluid/operators/expand_as_v2_op.h new file mode 100644 index 0000000000000..a4c30dfe1298d --- /dev/null +++ b/paddle/fluid/operators/expand_as_v2_op.h @@ -0,0 +1,214 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +#define MAX_RANK_SUPPORTED 6 + +#define EXPAND_AS_TEMPLATE(z, n, data) \ + case n + 1: { \ + ExpandAs(context); \ + break; \ + } +#define REP_EXPAND_AS_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_AS_TEMPLATE, ~) +#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED)) +#define EXPAND_AS_GRAD_CASE(n) \ + case n: { \ + ExpandAsBackward(context, reshape_dims_vec, reduce_dims_vec); \ + break; \ + } +#define EXPAND_AS_GRAD_TEMPLATE(z, n, data) \ + BOOST_PP_IF(COND(n), EXPAND_AS_GRAD_CASE(n), ) +#define REP_EXPAND_AS_GRAD_TEMPLATE(n) \ + BOOST_PP_REPEAT(n, EXPAND_AS_GRAD_TEMPLATE, ~) + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; +template +using EigenTensor = framework::EigenTensor; + +template +class ExpandAsV2Kernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto rank = context.Input("X")->dims().size(); + auto* target_tensor = context.Input("target_tensor"); + auto target_rank = target_tensor->dims().size(); + PADDLE_ENFORCE_GE(target_rank, rank, + platform::errors::InvalidArgument( + "The rank (%d) of the input 'target_tensor' for " + "expand_as_v2 op must be greater than or equal to " + "the rank (%d) of the input 'x'.", + target_rank, rank)); + PADDLE_ENFORCE_GE(rank, 1, platform::errors::InvalidArgument( + "The rank (%d) of the input 'x' for " + "expand_as_v2 op must be positive.", + rank)); + PADDLE_ENFORCE_LE(target_rank, MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The rank (%d) of the input 'target_tensor' for " + "expand_as_v2 op must be less than or equal to %d.", + target_rank, MAX_RANK_SUPPORTED)); + + switch (target_rank) { REP_EXPAND_AS_TEMPLATE(MAX_RANK_SUPPORTED) } + } + + protected: + template + void ExpandAs(const framework::ExecutionContext& context) const { + auto* in0 = context.Input("X"); + auto in_dims = in0->dims(); + auto* target_tensor = context.Input("target_tensor"); + auto vec_in_dims = 
framework::vectorize(in_dims); + auto target_shape = framework::vectorize(target_tensor->dims()); + auto diff = target_shape.size() - vec_in_dims.size(); + vec_in_dims.insert(vec_in_dims.begin(), diff, 1); + std::vector repeat_times(vec_in_dims.size()); + for (size_t i = 0; i < vec_in_dims.size(); ++i) { + PADDLE_ENFORCE_NE(target_shape[i], 0, + platform::errors::InvalidArgument( + "The value of target shape cannot be zero.")); + if (vec_in_dims[i] != 1) { + PADDLE_ENFORCE_EQ( + vec_in_dims[i], target_shape[i], + platform::errors::InvalidArgument( + "The value (%d) of the non-singleton dimension does not match" + " the corresponding value (%d) in " + "target tensor for expand_as_v2 op.", + vec_in_dims[i], target_shape[i])); + repeat_times[i] = 1; + } else { + repeat_times[i] = target_shape[i]; + } + } + auto* out0 = context.Output("Out"); + Eigen::DSizes bcast_dims; + for (size_t i = 0; i < repeat_times.size(); ++i) { + bcast_dims[i] = repeat_times[i]; + } + + framework::DDim new_in_dims = framework::make_ddim(vec_in_dims); + framework::DDim out_dims = framework::make_ddim(target_shape); + + out0->Resize(out_dims); + auto x = EigenTensor::From(*in0, new_in_dims); + out0->mutable_data(context.GetPlace()); + auto y = EigenTensor::From(*out0, out_dims); + auto& place = + *context.template device_context().eigen_device(); + y.device(place) = x.broadcast(bcast_dims); + } +}; + +template +class ExpandAsV2GradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("X"); + auto* target_tensor = context.Input("target_tensor"); + auto x_dims = in0->dims(); + auto target_shape = target_tensor->dims(); + auto vec_in_dims = framework::vectorize(x_dims); + auto diff = target_shape.size() - vec_in_dims.size(); + vec_in_dims.insert(vec_in_dims.begin(), diff, 1); + std::vector repeat_times(vec_in_dims.size()); + for (size_t i = 0; i < vec_in_dims.size(); ++i) { + repeat_times[i] = target_shape[i] / vec_in_dims[i]; + } + std::vector reshape_dims_vec; + std::vector reduce_dims_vec; + for (size_t i = 0; i < repeat_times.size(); ++i) { + reduce_dims_vec.push_back(reshape_dims_vec.size()); + reshape_dims_vec.push_back(repeat_times[i]); + reshape_dims_vec.push_back(vec_in_dims[i]); + } + + int dims = reduce_dims_vec.size(); + bool just_copy = true; + for (size_t i = 0; i < repeat_times.size(); i++) { + if (repeat_times[i] != 1) { + just_copy = false; + break; + } + } + // no need reduce, just copy + if (just_copy) { + auto* in0 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + out0->mutable_data(context.GetPlace()); + framework::TensorCopy(*in0, context.GetPlace(), context.device_context(), + out0); + } else { + PADDLE_ENFORCE_GE(dims, 1, + platform::errors::InvalidArgument( + "The rank of the input 'Out@GRAD' for " + "expand_as_v2_grad op must be greater than or " + "equal to 1, but the value received is %d.", + dims)); + PADDLE_ENFORCE_LE(dims, MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The rank of the input 'Out@GRAD' for " + "expand_as_v2_grad op must be less than or equal " + "to %d, but the value received is %d.", + MAX_RANK_SUPPORTED, dims)); + switch (dims) { REP_EXPAND_AS_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) } + } + } + + protected: + template + void ExpandAsBackward(const framework::ExecutionContext& context, + const std::vector& reshape_dims_vec, + const std::vector& reduce_dims_vec) const { + size_t reshape_size = 
reshape_dims_vec.size(); + size_t reduce_size = reduce_dims_vec.size(); + auto* in0 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + out0->mutable_data(context.GetPlace()); + auto x_grad = EigenVector::Flatten(*out0); + Eigen::DSizes reshape_dims; + for (size_t i = 0; i < reshape_size; ++i) { + reshape_dims[i] = reshape_dims_vec[i]; + } + Eigen::DSizes reduce_dims; + for (size_t i = 0; i < reduce_size; ++i) { + reduce_dims[i] = reduce_dims_vec[i]; + } + auto out_grad = EigenVector::Flatten(*in0); + x_grad.device( + *context.template device_context().eigen_device()) = + out_grad.reshape(reshape_dims) + .sum(reduce_dims) + .reshape(x_grad.dimensions()); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fake_dequantize_op.cc b/paddle/fluid/operators/fake_dequantize_op.cc index 0d2b951ee1c54..9b0328b0945ba 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cc +++ b/paddle/fluid/operators/fake_dequantize_op.cc @@ -37,20 +37,49 @@ template struct ChannelDequantizeFunctor { void operator()(const platform::CPUDeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor** scales, - const int scale_num, T max_range, framework::Tensor* out) { + const int scale_num, T max_range, const int quant_axis, + framework::Tensor* out) { if (scale_num == 1) { - const int channel = in->dims()[0]; + // Dequant op is before quantized op + // Dequantize the weight of quantized op + auto in_dims = in->dims(); + const int64_t channel = in_dims[quant_axis]; const T* scale_factor = scales[0]->data(); - for (int i = 0; i < channel; i++) { - T s = scale_factor[i]; - framework::Tensor one_channel_in = in->Slice(i, i + 1); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - auto in_e = framework::EigenVector::Flatten(one_channel_in); - auto out_e = framework::EigenVector::Flatten(one_channel_out); - auto& dev = *dev_ctx.eigen_device(); - out_e.device(dev) = in_e * s / max_range; + if (quant_axis == 0) { + for (int64_t i = 0; i < channel; i++) { + T s = scale_factor[i]; + framework::Tensor one_channel_in = in->Slice(i, i + 1); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto in_e = framework::EigenVector::Flatten(one_channel_in); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + auto& dev = *dev_ctx.eigen_device(); + out_e.device(dev) = in_e * s / max_range; + } + } else if (quant_axis == 1) { + int64_t out_iter = 1; + for (int i = 0; i < quant_axis; i++) { + out_iter *= in_dims[i]; + } + int64_t step_i = in->numel() / out_iter; + int64_t step_j = in->numel() / (out_iter * channel); + auto* in_data = in->data(); + auto* out_data = out->mutable_data(dev_ctx.GetPlace()); + for (int64_t i = 0; i < out_iter; i++) { + for (int64_t j = 0; j < channel; j++) { + auto* cur_in = in_data + i * step_i + j * step_j; + auto* cur_out = out_data + i * step_i + j * step_j; + T s = scale_factor[j]; + for (int64_t k = 0; k < step_j; k++) { + *cur_out = (*cur_in) * s / max_range; + ++cur_in; + ++cur_out; + } + } + } } } else if (scale_num == 2) { + // Dequant op is after quantized op + // Dequantize the output tensor of quantized op int batch_size = in->dims()[0]; int channel = in->dims()[1]; const T* scale_one = scales[0]->data(); @@ -157,6 +186,18 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker "Quantization bit numbers in quantization stage. 
" "The size of `quant_bits` should be equal to the size of `Scales`.") .SetDefault({8}); + AddAttr("quant_axis", + "(int, default 0) The axis for quantization. " + "For conv2d, depthwise_conv2d, conv2d_transpose " + "and mul, the quant_axis is equal to the cout axis.") + .SetDefault(0) + .AddCustomChecker([](const int& quant_axis) { + PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument( + "'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + }); AddComment(R"DOC( FakeChannelWiseDequantizeMaxAbsOp operator. diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index 02f9dc827d68c..54a92b055a39d 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -45,8 +45,9 @@ struct DequantizeFunctor { }; template -__global__ void DequantizeOneScale(const T* in, const T* scale, T max_range, - int num, int channel, T* out) { +__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, + T max_range, int num, int channel, + T* out) { int tid = threadIdx.x; int channel_size = num / channel; const T* in_c = in + blockIdx.x * channel_size; @@ -56,6 +57,23 @@ __global__ void DequantizeOneScale(const T* in, const T* scale, T max_range, } } +template +__global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale, + T max_range, const int num, + const int cin, const int cout, + T* out) { + int cout_wh_size = num / cin; + int wh_size = cout_wh_size / cout; + + T s = scale[blockIdx.x]; + const T* in_current = in + threadIdx.x * cout_wh_size + blockIdx.x * wh_size; + T* out_current = out + threadIdx.x * cout_wh_size + blockIdx.x * wh_size; + + for (int i = 0; i < wh_size; i++) { + out_current[i] = in_current[i] * s / max_range; + } +} + template __global__ void DequantizeTwoScale(const T* in, const T* scale_one, const T* scale_two, T max_range, int num, @@ -74,18 +92,29 @@ template struct ChannelDequantizeFunctor { void operator()(const platform::CUDADeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor** scales, - const int scale_num, T max_range, framework::Tensor* out) { + const int scale_num, T max_range, const int quant_axis, + framework::Tensor* out) { + auto in_dims = in->dims(); const T* in_data = in->data(); T* out_data = out->mutable_data(dev_ctx.GetPlace()); if (scale_num == 1) { int num = in->numel(); - int channel = in->dims()[0]; const T* scale_factor = scales[0]->data(); - int block = 1024; - int grid = channel; - DequantizeOneScale<<>>( - in_data, scale_factor, max_range, num, channel, out_data); + if (quant_axis == 0) { + int grid = in_dims[0]; + int block = 1024; + DequantizeOneScaleQuantAxis0<<>>( + in_data, scale_factor, max_range, num, in_dims[0], out_data); + } else if (quant_axis == 1) { + // Dequantize weight of Cin * Cout * W * H + int grid = in_dims[1]; + int block = in_dims[0]; + DequantizeOneScaleQuantAxis1<<>>( + in_data, scale_factor, max_range, num, in_dims[0], in_dims[1], + out_data); + } } else if (scale_num == 2) { + // Not need to consider quant_axis int num = in->numel(); int batch_size = in->dims()[0]; int channel = in->dims()[1]; diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h index 500960098f5ce..6ddb12771fd51 100644 --- a/paddle/fluid/operators/fake_dequantize_op.h +++ b/paddle/fluid/operators/fake_dequantize_op.h @@ -33,7 +33,7 @@ template struct ChannelDequantizeFunctor { void operator()(const 
DeviceContext& dev_ctx, const framework::Tensor* in, const framework::Tensor** scales, const int scale_num, - T max_range, framework::Tensor* out); + T max_range, const int quant_axis, framework::Tensor* out); }; template @@ -63,6 +63,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { auto* out = ctx.Output("Out"); auto quant_bits = ctx.Attr>("quant_bits"); + auto quant_axis = ctx.Attr("quant_axis"); int max_range = 1; auto& dev_ctx = ctx.template device_context(); @@ -70,12 +71,12 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { int scale_num = scales.size(); if (scale_num == 1) { PADDLE_ENFORCE_EQ( - scales[0]->numel(), in->dims()[0], + scales[0]->numel(), in->dims()[quant_axis], platform::errors::PreconditionNotMet( "The number of first scale values must be the same with " - "first dimension value of Input(X) when the `Scales` has only " - "one element, but %ld != %ld here.", - scales[0]->numel(), in->dims()[0])); + "quant_axis dimension value of Input(X) when the `Scales` has " + "only one element, but %ld != %ld here.", + scales[0]->numel(), in->dims()[quant_axis])); max_range *= (std::pow(2, quant_bits[0] - 1) - 1); } else if (scale_num == 2) { PADDLE_ENFORCE_EQ( @@ -94,7 +95,8 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel { (std::pow(2, quant_bits[1] - 1) - 1); } ChannelDequantizeFunctor()( - dev_ctx, in, scales.data(), scale_num, static_cast(max_range), out); + dev_ctx, in, scales.data(), scale_num, static_cast(max_range), + quant_axis, out); } }; diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 358f122c8359f..04ac4a35208a5 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
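Before fake_quantize_op.cc continues, here is a small self-contained sketch of the channel-wise scheme used by the dequantize functors above and the quantize functors that follow (plain C++, not Paddle code; the 2-D row-major layout and the helper name `ChannelAbsMax` are assumptions for illustration). Each channel along `quant_axis` gets its own scale, namely its absolute maximum; values are clipped to [-s, s], quantized to round(x * bin_cnt / s), and recovered as q * s / max_range, where bin_cnt = max_range = 2^(bits-1) - 1. For quant_axis == 1 the same idea applies; only the stride bookkeeping (the step_i/step_j loops in the functors) changes.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Per-channel absolute maximum over a 2-D weight stored row-major as
// [channel][k], mirroring FindChannelAbsMaxFunctor for quant_axis == 0.
std::vector<float> ChannelAbsMax(const std::vector<float>& w, int channel, int k) {
  std::vector<float> scales(channel, 0.0f);
  for (int c = 0; c < channel; ++c) {
    for (int i = 0; i < k; ++i) {
      scales[c] = std::max(scales[c], std::fabs(w[c * k + i]));
    }
  }
  return scales;
}

int main() {
  const int channel = 2, k = 3, bit_length = 8;
  const float bin_cnt = (1 << (bit_length - 1)) - 1;  // 127 for 8 bits
  std::vector<float> w = {0.5f, -1.0f, 0.25f,   // channel 0, abs-max 1.0
                          2.0f,  0.5f, -4.0f};  // channel 1, abs-max 4.0
  std::vector<float> scales = ChannelAbsMax(w, channel, k);
  for (int c = 0; c < channel; ++c) {
    for (int i = 0; i < k; ++i) {
      float s = scales[c];
      // Clip and quantize, as ChannelClipAndFakeQuantFunctor does ...
      float q = std::round(std::min(std::max(w[c * k + i], -s), s) * bin_cnt / s);
      // ... then dequantize, as ChannelDequantizeFunctor does (max_range == bin_cnt).
      float x = q * s / bin_cnt;
      std::printf("channel=%d original=% .4f dequantized=% .4f\n", c, w[c * k + i], x);
    }
  }
  return 0;
}
```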
*/ #include "paddle/fluid/operators/fake_quantize_op.h" +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/operators/clip_op.h" @@ -39,13 +40,41 @@ template struct FindAbsMaxFunctor; template struct FindChannelAbsMaxFunctor { - void operator()(const platform::CPUDeviceContext& ctx, const T* in, - const int num, const int channel, T* out) { - const int channel_size = num / channel; - for (int i = 0; i < channel; i++) { - auto* start = in + i * channel_size; - auto* end = in + (i + 1) * channel_size; - out[i] = std::abs(*(std::max_element(start, end, Compare()))); + void operator()(const platform::CPUDeviceContext& ctx, + const framework::Tensor& in_tensor, const int quant_axis, + T* out_abs_max) { + // At present, channelwise quantization supports conv2d, depthwise_conv2d + // conv2d_transpose and mul + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + auto* in_data = in_tensor.data(); + auto in_dims = in_tensor.dims(); + const int64_t channel = in_dims[quant_axis]; + if (quant_axis == 0) { + const int64_t channel_size = in_tensor.numel() / channel; + for (int64_t i = 0; i < channel; i++) { + auto* start = in_data + i * channel_size; + auto* end = in_data + (i + 1) * channel_size; + out_abs_max[i] = + std::abs(*(std::max_element(start, end, Compare()))); + } + } else if (quant_axis == 1) { + for (int64_t i = 0; i < channel; i++) { + out_abs_max[i] = 0; + } + const int64_t step_i = in_tensor.numel() / in_dims[0]; + const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]); + for (int64_t i = 0; i < in_dims[0]; i++) { + for (int64_t j = 0; j < in_dims[1]; j++) { + auto* start = in_data + i * step_i + j * step_j; + auto* end = in_data + i * step_i + (j + 1) * step_j; + T abs_max = std::abs(*(std::max_element(start, end, Compare()))); + out_abs_max[j] = std::max(out_abs_max[j], abs_max); + } + } } } }; @@ -92,26 +121,53 @@ template struct ChannelClipAndFakeQuantFunctor { void operator()(const platform::CPUDeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int channel, + const int bin_cnt, const int quant_axis, framework::Tensor* out) { + // At present, channelwise quantization supports conv2d, depthwise_conv2d + // conv2d_transpose and mul + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); auto* scale_data = scale.data(); auto* in_data = in.data(); auto* out_data = out->mutable_data(ctx.GetPlace()); - const int channel_size = in.numel() / channel; + auto in_dims = in.dims(); + const int64_t channel = in_dims[quant_axis]; platform::Transform trans; - for (int i = 0; i < channel; i++) { - T s = scale_data[i]; - auto* start = in_data + i * channel_size; - auto* end = in_data + (i + 1) * channel_size; - trans(ctx, start, end, out_data + i * channel_size, - ClipFunctor(-s, s)); - } - for (int i = 0; i < channel; i++) { - T s = scale_data[i]; - T inv_s = inverse(s); - framework::Tensor one_channel_out = out->Slice(i, i + 1); - auto out_e = framework::EigenVector::Flatten(one_channel_out); - out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + if (quant_axis == 0) { + const int64_t channel_size = in.numel() / channel; + for (int64_t i = 0; i < channel; i++) { + T s = scale_data[i]; + auto* start = in_data + i * channel_size; + 
auto* end = in_data + (i + 1) * channel_size; + trans(ctx, start, end, out_data + i * channel_size, + ClipFunctor(-s, s)); + } + for (int64_t i = 0; i < channel; i++) { + T s = scale_data[i]; + T inv_s = inverse(s); + framework::Tensor one_channel_out = out->Slice(i, i + 1); + auto out_e = framework::EigenVector::Flatten(one_channel_out); + out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); + } + } else if (quant_axis == 1) { + const int64_t step_i = in.numel() / in_dims[0]; + const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]); + for (int i = 0; i < in_dims[0]; i++) { + for (int j = 0; j < in_dims[1]; j++) { + T s = scale_data[j]; + T inv_s = inverse(s); + auto* start = in_data + i * step_i + j * step_j; + auto* end = in_data + i * step_i + (j + 1) * step_j; + auto* cur_out_data = out_data + i * step_i + j * step_j; + trans(ctx, start, end, cur_out_data, ClipFunctor(-s, s)); + for (int k = 0; k < step_j; k++) { + cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]); + } + } + } } } }; @@ -247,8 +303,9 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel { "FakeChannelWiseQuantizeAbsMax"); OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale", "FakeChannelWiseQuantizeAbsMax"); + int quant_axis = ctx->Attrs().Get("quant_axis"); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]}); + ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]}); ctx->ShareLoD("X", /*->*/ "Out"); } @@ -269,6 +326,18 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker "(Tensor) Output of quantized low level tensor, " "but also saved as float data type."); AddOutput("OutScale", "(Tensor) Current channel wise scale"); + AddAttr("quant_axis", + "(int, default 0) The axis for quantization. 
" + "For conv2d, depthwise_conv2d, conv2d_transpose " + "and mul, the quant_axis is equal to the cout axis.") + .SetDefault(0) + .AddCustomChecker([](const int& quant_axis) { + PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument( + "'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + }); AddAttr("bit_length", "(int, default 8)") .SetDefault(8) .AddCustomChecker([](const int& bit_length) { diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 75a55fa821f0a..6ff3c7ec632f2 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -75,8 +75,8 @@ struct FindAbsMaxFunctor { template struct FindAbsMaxFunctor; template -__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c, - T* out) { +__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, + const int c, T* out) { int tid = threadIdx.x; int channel_size = n / c; const T* in_c = in + blockIdx.x * channel_size; @@ -100,14 +100,69 @@ __global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c, } } +template +__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, + const int cin, const int cout, + T* out) { + extern __shared__ T shared_max_data[]; + int cout_wh_size = n / cin; + int wh_size = n / (cin * cout); + + int tid = threadIdx.x; + int bid = blockIdx.x; + const T* in_current = in + tid * cout_wh_size + bid * wh_size; + shared_max_data[tid] = T(0); + for (int i = 0; i < wh_size; i++) { + T tmp = fabs(in_current[i]); + if (tmp > shared_max_data[tid]) { + shared_max_data[tid] = tmp; + } + } + __syncthreads(); + + int len = blockDim.x; + for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) { + if (tid < i && tid + i < len && + shared_max_data[tid] < shared_max_data[tid + i]) { + shared_max_data[tid] = shared_max_data[tid + i]; + } + if (i == 1) { + i = 0; // break the loop + } + __syncthreads(); + } + if (tid == 0) { + out[bid] = shared_max_data[0]; + } +} + template struct FindChannelAbsMaxFunctor { - void operator()(const platform::CUDADeviceContext& ctx, const T* in, - const int num, const int channel, T* out) { - int block = 1024; - int grid = channel; - FindChannelAbsMaxKernel<<>>( - in, num, channel, out); + void operator()(const platform::CUDADeviceContext& ctx, + const framework::Tensor& in_tensor, const int quant_axis, + T* out_abs_max) { + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + const int num = in_tensor.numel(); + auto in_dims = in_tensor.dims(); + int channel = in_dims[quant_axis]; + const T* in_data = in_tensor.data(); + if (quant_axis == 0) { + int grid = channel; + int block = 1024; + FindChannelAbsMaxKernelQuantAxis0< + T><<>>( + in_data, num, channel, out_abs_max); + } else if (quant_axis == 1) { + int grid = in_dims[1]; + int block = in_dims[0]; + FindChannelAbsMaxKernelQuantAxis1< + T><<>>( + in_data, num, in_dims[0], in_dims[1], out_abs_max); + } } }; @@ -189,10 +244,12 @@ struct ClipAndFakeQuantDequantFunctor { template struct ClipAndFakeQuantDequantFunctor; +// ChannelClipAndQuantKernel for quant_axis is 0 template -__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale, - const int bin_cnt, const int n, - const int c, T* out) { +__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, + 
const int bin_cnt, + const int n, const int c, + T* out) { int tid = threadIdx.x; int channel_size = n / c; @@ -211,22 +268,57 @@ __global__ void ChannelClipAndQuantKernel(const T* in, const T* scale, } } +// ChannelClipAndQuantKernel for quant_axis is 1 +template +__global__ void ChannelClipAndQuantKernelQuantAxis1(const T* in, const T* scale, + const int bin_cnt, + const int n, const int cin, + const int cout, T* out) { + T s = scale[blockIdx.x % cout]; + T inv_s = inverse(s); + + int wh_size = n / (cin * cout); + const T* in_c = in + blockIdx.x * wh_size; + T* out_c = out + blockIdx.x * wh_size; + + for (int i = threadIdx.x; i < wh_size; i += blockDim.x) { + T x = in_c[i]; + T v = x > s ? s : x; + v = v < -s ? -s : v; + v = bin_cnt * inv_s * v; + out_c[i] = round(v); + } +} + template struct ChannelClipAndFakeQuantFunctor { void operator()(const platform::CUDADeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, - const int bin_cnt, const int channel, + const int bin_cnt, const int quant_axis, framework::Tensor* out) { - int num = in.numel(); - int block = 1024; - int grid = channel; + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1, true, + platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " + "the received is %d", + quant_axis)); + int num = in.numel(); + auto in_dims = in.dims(); const T* in_data = in.data(); const T* scale_data = scale.data(); T* out_data = out->mutable_data(ctx.GetPlace()); - ChannelClipAndQuantKernel<<>>( - in_data, scale_data, bin_cnt, num, channel, out_data); + if (quant_axis == 0) { + int grid = in_dims[0]; + int block = 1024; + ChannelClipAndQuantKernelQuantAxis0<<>>( + in_data, scale_data, bin_cnt, num, in_dims[0], out_data); + } else if (quant_axis == 1) { + int grid = in_dims[0] * in_dims[1]; + int block = 1024; + ChannelClipAndQuantKernelQuantAxis1<<>>( + in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data); + } } }; diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 4136217fb0c5f..5c6e0b1f6e26d 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -61,15 +61,15 @@ struct FindRangeAbsMaxFunctor { template struct FindChannelAbsMaxFunctor { - void operator()(const DeviceContext& ctx, const T* in, const int num, - const int channel, T* out); + void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor, + const int quant_axis, T* out_abs_max); }; template struct ChannelClipAndFakeQuantFunctor { void operator()(const DeviceContext& ctx, const framework::Tensor& in, const framework::Tensor& scale, const int bin_cnt, - const int channel, framework::Tensor* out); + const int quant_axis, framework::Tensor* out); }; template @@ -144,12 +144,13 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel { int bit_length = context.Attr("bit_length"); int bin_cnt = std::pow(2, bit_length - 1) - 1; + int quant_axis = context.Attr("quant_axis"); auto& dev_ctx = context.template device_context(); - FindChannelAbsMaxFunctor()( - dev_ctx, in->data(), in->numel(), in->dims()[0], out_scale_data); + FindChannelAbsMaxFunctor()(dev_ctx, *in, quant_axis, + out_scale_data); ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, in->dims()[0], out); + dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out); } }; diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 253078751ce66..898c063afdd43 100644 --- 
a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include + #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/fill_constant_op.h" #ifdef PADDLE_WITH_MKLDNN @@ -30,13 +31,13 @@ class CPUGaussianRandomKernel : public framework::OpKernel { float mean = context.Attr("mean"); float std = context.Attr("std"); auto* tensor = context.Output("Out"); - unsigned int seed = static_cast(context.Attr("seed")); std::minstd_rand engine; if (seed == 0) { seed = std::random_device()(); } engine.seed(seed); + std::normal_distribution dist(mean, std); const std::string op_type = "gaussian_random"; diff --git a/paddle/fluid/operators/log_softmax_op.cc b/paddle/fluid/operators/log_softmax_op.cc new file mode 100644 index 0000000000000..d6e2b3ecff8c8 --- /dev/null +++ b/paddle/fluid/operators/log_softmax_op.cc @@ -0,0 +1,128 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/log_softmax_op.h" +#include +#include +#include "paddle/fluid/operators/common_infer_shape_functions.h" + +namespace paddle { +namespace operators { + +class LogSoftmaxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + return UnaryOpUnchangedInferShapeCheckAxis(ctx); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class LogSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of softmax, " + "whose dimension :attr:`axis` is the input_feature_dimensions."); + AddOutput("Out", "The normalized values with the same shape as X."); + AddAttr("axis", + "The dimension index of Input(x) to perform log_softmax," + "default -1 for last dimension") + .SetDefault(-1); + AddComment(R"DOC( +LogSoftmax Operator. 
+ +)DOC"); + } +}; + +class LogSoftmaxOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } +}; + +class LogSoftmaxGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "log_softmax_grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@grad", "log_softmax_grad"); + PADDLE_ENFORCE_EQ( + ctx->GetInputDim("Out"), + ctx->GetInputDim(framework::GradVarName("Out")), + platform::errors::InvalidArgument("Input(Out) and its gradients " + "should have the same shape.")); + + ctx->SetOutputDim(framework::GradVarName("X"), + ctx->GetInputDim(framework::GradVarName("Out"))); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.device_context()); + } +}; + +template +class LogSoftmaxGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("log_softmax_grad"); + op->SetInput("Out", this->Output("Out")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(log_softmax, ops::LogSoftmaxOp, ops::LogSoftmaxOpMaker, + ops::LogSoftmaxOpInferVarType, + ops::LogSoftmaxGradOpMaker, + ops::LogSoftmaxGradOpMaker); +REGISTER_OPERATOR(log_softmax_grad, ops::LogSoftmaxGradOp); + +REGISTER_OP_CPU_KERNEL( + log_softmax, + ops::LogSoftmaxKernel, + ops::LogSoftmaxKernel); +REGISTER_OP_CPU_KERNEL( + log_softmax_grad, + ops::LogSoftmaxGradKernel, + ops::LogSoftmaxGradKernel); diff --git a/paddle/fluid/operators/log_softmax_op.cu b/paddle/fluid/operators/log_softmax_op.cu new file mode 100644 index 0000000000000..02fca246d241d --- /dev/null +++ b/paddle/fluid/operators/log_softmax_op.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/operators/log_softmax_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL( + log_softmax, ops::LogSoftmaxKernel, + ops::LogSoftmaxKernel, + ops::LogSoftmaxKernel); +REGISTER_OP_CUDA_KERNEL( + log_softmax_grad, ops::LogSoftmaxGradKernel, + ops::LogSoftmaxGradKernel, + ops::LogSoftmaxGradKernel); diff --git a/paddle/fluid/operators/log_softmax_op.h b/paddle/fluid/operators/log_softmax_op.h new file mode 100644 index 0000000000000..b983ac54157d9 --- /dev/null +++ b/paddle/fluid/operators/log_softmax_op.h @@ -0,0 +1,192 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +using EigenMatrix = framework::EigenMatrix; + +static inline int CanonicalAxis(const int axis, const int rank) { + if (axis < 0) { + return axis + rank; + } + return axis; +} + +static inline int SizeToAxis(const int axis, const framework::DDim dims) { + int size = 1; + for (int i = 0; i < axis; i++) { + size *= dims[i]; + } + return size; +} + +static inline int SizeFromAxis(const int axis, const framework::DDim dims) { + int size = 1; + for (int i = axis; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +template +struct ValueClip { + HOSTDEVICE T operator()(const T& x) const { + const T kThreshold = static_cast(-64.); + return x < kThreshold ? kThreshold : x; + } +}; + +template +struct LogSoftmaxFunctor { + void operator()(const DeviceContext& context, const framework::Tensor* X, + framework::Tensor* Y, const int axis) { + constexpr int kBatchDim = 0; + constexpr int kClassDim = 1; + constexpr int kAxisDim = 1; + + int axis_dim = X->dims()[axis]; + const int n = SizeToAxis(axis, X->dims()); + const int d = SizeFromAxis(axis, X->dims()); + framework::DDim dim_2d{n, d}; + + auto logits = EigenMatrix::From(*X, dim_2d); + auto log_softmax = EigenMatrix::From(*Y, dim_2d); + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + const int num_remain = num_classes / axis_dim; + + Eigen::DSizes along_axis(kAxisDim); + Eigen::DSizes batch_classes(batch_size, num_classes); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + Eigen::DSizes batch_one_remain(batch_size, 1, num_remain); + Eigen::DSizes one_axis_one(1, axis_dim, 1); + Eigen::DSizes one_axis(1, axis_dim); + Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); + + // For numerical stability, logits should be shifted by maximum number along + // axis, calculate shifted_logits into log_softmax tensor for memory reuse. 
+ if (num_remain == 1) { + // axis == -1, axis and class in same dimension, calculate along + // class dimension directly for higher performance + log_softmax.device(*context.eigen_device()) = + (logits - + logits.maximum(along_axis) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)) + .unaryExpr(ValueClip()); + } else { + // axis != -1, class dimension split into (axis, remain), max and sum + // should be calculated along axis dimension + log_softmax.device(*context.eigen_device()) = + (logits.reshape(batch_axis_remain) - + logits.reshape(batch_axis_remain) + .maximum(along_axis) + .eval() + .reshape(batch_one_remain) + .broadcast(one_axis_one) + .reshape(batch_classes)) + .unaryExpr(ValueClip()); + } + + log_softmax.device(*context.eigen_device()) = + log_softmax - + log_softmax.exp() + .eval() + .reshape(batch_axis_remain) + .sum(along_axis) + .log() + .broadcast(one_axis); + } +}; + +template +class LogSoftmaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Out = context.Output("Out"); + const int rank = X->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + + // allocate memory on device. + Out->mutable_data(context.GetPlace()); + + LogSoftmaxFunctor()( + context.template device_context(), X, Out, axis); + } +}; + +template +struct LogSoftmaxGradFunctor { + void operator()(const DeviceContext& context, const framework::Tensor* Y, + const framework::Tensor* dY, framework::Tensor* dX, + const int axis) { + constexpr int kBatchDim = 0; + constexpr int kClassDim = 1; + + const int n = SizeToAxis(axis, Y->dims()); + const int d = SizeFromAxis(axis, Y->dims()); + framework::DDim dim_2d{n, d}; + + auto y = EigenMatrix::From(*Y, dim_2d); + auto dy = EigenMatrix::From(*dY, dim_2d); + auto dx = EigenMatrix::From(*dX, dim_2d); + + const int axis_dim = Y->dims()[axis]; + const int batch_size = y.dimension(kBatchDim); + const int num_classes = y.dimension(kClassDim); + const int num_remain = num_classes / axis_dim; + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); + Eigen::DSizes one_axis(1, axis_dim); + + dx.device(*context.eigen_device()) = + dy - + (y.exp()) * (dy.reshape(batch_axis_remain) + .sum(along_class) + .broadcast(one_axis)); + } +}; + +template +class LogSoftmaxGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* Out = context.Input("Out"); + auto* dOut = + context.Input(framework::GradVarName("Out")); + auto* dX = context.Output(framework::GradVarName("X")); + const int rank = Out->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + + // allocate memory on device. 
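// Note (illustrative, not from the original patch): LogSoftmaxGradFunctor
// above implements the identity that follows from
// y_i = x_i - log(sum_j exp(x_j)):
//   dL/dx_i = dy_i - softmax(x)_i * sum_j dy_j = dy_i - exp(y_i) * sum_j dy_j,
// with the sum taken over the class (axis) dimension. For instance, with
// y = log_softmax([0, 0]) and dy = [1, 0], exp(y) = [0.5, 0.5] and
// sum(dy) = 1, so dx = [0.5, -0.5].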
+ dX->mutable_data(context.GetPlace()); + + LogSoftmaxGradFunctor()( + context.template device_context(), Out, dOut, dX, axis); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/randint_op.cc b/paddle/fluid/operators/randint_op.cc index 9f6df3f32b746..11ce738e00151 100644 --- a/paddle/fluid/operators/randint_op.cc +++ b/paddle/fluid/operators/randint_op.cc @@ -14,6 +14,7 @@ #include #include + #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/uniform_random_op.h" @@ -37,11 +38,11 @@ class CPURandintKernel : public framework::OpKernel { new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor); } } - auto* out = ctx.Output("Out"); if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape)); T* data = out->mutable_data(ctx.GetPlace()); int64_t size = out->numel(); + unsigned int seed = static_cast(ctx.Attr("seed")); std::minstd_rand engine; if (seed == 0) { diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index e0c56307639af..a4487cde27799 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/uniform_random_op.h" #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" + namespace paddle { namespace operators { @@ -55,19 +57,40 @@ class CPUUniformRandomKernel : public framework::OpKernel { "supports SelectedRows and LoDTensor"); } T *data = tensor->mutable_data(ctx.GetPlace()); - unsigned int seed = static_cast(ctx.Attr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); + + int64_t size = tensor->numel(); std::uniform_real_distribution dist( static_cast(ctx.Attr("min")), static_cast(ctx.Attr("max"))); - int64_t size = tensor->numel(); - for (int64_t i = 0; i < size; ++i) { - data[i] = dist(engine); + auto gen_ptr = framework::Generator::GetInstance(); + if (gen_ptr->is_init_py) { + std::mt19937_64 &gen_engine = gen_ptr->GetCPUEngine(); + // auto gen_engine = gen_ptr_->GetCPUEngine(); + // std::uniform_real_distribution dist( + // static_cast(ctx.Attr("min")), + // static_cast(ctx.Attr("max"))); + + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(gen_engine); + } + } else { + unsigned int seed = static_cast(ctx.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + // std::uniform_real_distribution dist( + // static_cast(ctx.Attr("min")), + // static_cast(ctx.Attr("max"))); + // int64_t size = tensor->numel(); + for (int64_t i = 0; i < size; ++i) { + data[i] = dist(engine); + } } + // std::mt19937_64 &engine = gen_ptr->GetCPUEngine(); + // auto engine = gen_ptr_->GetCPUEngine(); + unsigned int diag_num = static_cast(ctx.Attr("diag_num")); unsigned int diag_step = diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index 53c79cf672e7d..c024bb87b09c0 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/uniform_random_op.h" @@ -87,9 +88,14 @@ class GPUUniformRandomKernel : public framework::OpKernel { } T* data = tensor->mutable_data(context.GetPlace()); unsigned int seed = static_cast(context.Attr("seed")); - if (seed == 0) { - std::random_device rd; - seed = rd(); + if (framework::Generator::GetInstance()->is_init_py) { + seed = static_cast( + framework::Generator::GetInstance()->GetCurrentSeed()); + } else { + if (seed == 0) { + std::random_device rd; + seed = rd(); + } } T min = static_cast(context.Attr("min")); T max = static_cast(context.Attr("max")); diff --git a/paddle/fluid/operators/uniform_random_op.h b/paddle/fluid/operators/uniform_random_op.h index 867b10441640c..d263dd03dd0de 100644 --- a/paddle/fluid/operators/uniform_random_op.h +++ b/paddle/fluid/operators/uniform_random_op.h @@ -17,6 +17,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 0eb28f0c0c356..ebeb14e940e5f 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -100,6 +100,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(cudnnCreateDropoutDescriptor); \ __macro(cudnnDropoutGetStatesSize); \ __macro(cudnnSetDropoutDescriptor); \ + __macro(cudnnRestoreDropoutDescriptor); \ __macro(cudnnCreateRNNDescriptor); \ __macro(cudnnGetRNNParamsSize); \ __macro(cudnnGetRNNWorkspaceSize); \ diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 318a45919af72..ef19fcc547555 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,7 +1,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper prune feed_fetch_method pass_builder parallel_executor profiler layer tracer engine scope_pool analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context - gloo_wrapper infer_io_utils heter_wrapper) + gloo_wrapper infer_io_utils heter_wrapper generator) if (WITH_NCCL) set(PYBIND_DEPS ${PYBIND_DEPS} nccl_wrapper) @@ -37,7 +37,8 @@ set(PYBIND_SRCS data_set_py.cc imperative.cc ir.cc - inference_api.cc) + inference_api.cc + generator_py.cc) if (WITH_CRYPTO) set(PYBIND_DEPS ${PYBIND_DEPS} paddle_crypto) diff --git a/paddle/fluid/pybind/generator_py.cc b/paddle/fluid/pybind/generator_py.cc new file mode 100644 index 0000000000000..3bccd5fb2dd92 --- /dev/null +++ b/paddle/fluid/pybind/generator_py.cc @@ -0,0 +1,51 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#include + +#ifdef _POSIX_C_SOURCE +#undef _POSIX_C_SOURCE +#endif + +#ifdef _XOPEN_SOURCE +#undef _XOPEN_SOURCE +#endif + +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/pybind/generator_py.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { +void BindGenerator(py::module* m) { + py::class_(*m, "GeneratorState", ""); + py::class_(*m, "mt19937_64", ""); + py::class_>( + *m, "Generator") + .def(py::init([]() { return framework::Generator::GetInstanceX(); }), + py::return_value_policy::reference) + .def("get_state", &framework::Generator::GetState, + py::return_value_policy::move) + .def("set_state", &framework::Generator::SetState) + .def("manual_seed", &framework::Generator::SetCurrentSeed) + .def("seed", &framework::Generator::Seed) + .def("initial_seed", &framework::Generator::GetCurrentSeed) + .def("random", &framework::Generator::Random64) + .def("get_cpu_engine", &framework::Generator::GetCPUEngine, + py::return_value_policy::move) + .def("set_cpu_engine", &framework::Generator::SetCPUEngine); +} // end Generator +} // end namespace pybind +} // end namespace paddle diff --git a/paddle/fluid/pybind/generator_py.h b/paddle/fluid/pybind/generator_py.h new file mode 100644 index 0000000000000..d37654c1ba24e --- /dev/null +++ b/paddle/fluid/pybind/generator_py.h @@ -0,0 +1,28 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +void BindGenerator(py::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/op_function.h b/paddle/fluid/pybind/op_function.h index b33555759f8ff..70b321f658cd2 100644 --- a/paddle/fluid/pybind/op_function.h +++ b/paddle/fluid/pybind/op_function.h @@ -90,15 +90,36 @@ CastPyHandleToVarBaseList(const std::string& op_type, return result; } // namespace pybind -static inline void ConstructAttrMapFromPyArgs(framework::AttributeMap* attrs, +static inline void ConstructAttrMapFromPyArgs(const std::string& op_type, + int start_idx, + framework::AttributeMap* attrs, const py::args& args) { PADDLE_ENFORCE_EQ( args.size() % 2, 0, platform::errors::InvalidArgument( "The number of arguments for arributes should be even.")); for (size_t i = 0; i < args.size(); i += 2) { - auto name = args[i].cast(); - auto value = args[i + 1].cast(); + std::string name; + framework::Attribute value; + try { + name = args[i].cast(); + } catch (std::exception& e) { + PyObject* py_obj = args[i].ptr(); // get underlying PyObject + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) must be str, but got " + "%s", + op_type, start_idx + i, Py_TYPE(py_obj)->tp_name)); + } + try { + value = args[i + 1].cast(); + } catch (std::exception& e) { + PyObject* py_obj = args[i + 1].ptr(); // get underlying PyObject + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) must be " + "Attribute type (one of str, bool, int, int64, float, or list of " + "them), but got %s", + op_type, start_idx + i + 1, Py_TYPE(py_obj)->tp_name)); + } (*attrs)[name] = value; } } diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 93ba9feedf95b..a770275bf08c0 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -57,6 +57,9 @@ std::map> op_outs_map = { {"batch_norm", {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", "ReserveSpace"}}, + {"sync_batch_norm", + {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", + "ReserveSpace"}}, }; // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are @@ -76,6 +79,7 @@ std::map> op_passing_outs_map = { {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"momentum", {"ParamOut", "VelocityOut"}}, {"batch_norm", {"MeanOut", "VarianceOut"}}, + {"sync_batch_norm", {"MeanOut", "VarianceOut"}}, {"accuracy", {"Correct", "Total"}}, {"fill_constant", {"Out"}}, {"matmul", {"Out"}}, @@ -146,7 +150,7 @@ R"( { %s framework::AttributeMap attrs; - ConstructAttrMapFromPyArgs(&attrs, args); + ConstructAttrMapFromPyArgs("%s", %d, &attrs, args); { py::gil_scoped_release release; auto tracer = imperative::GetCurrentTracer(); @@ -204,6 +208,7 @@ GenerateOpFunctions(const std::string& module_name) { std::string ins_initializer_with_null = ""; std::string py_arg = ""; int arg_idx = 0; + int input_args_num = 0; std::string ins_cast_str = ""; for (auto& input : op_proto->inputs()) { auto& in_name = input.name(); @@ -216,6 +221,7 @@ GenerateOpFunctions(const std::string& module_name) { paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); input_args += input_arg; input_args += ","; + input_args_num++; const auto in_cast_type = input.duplicable() ? 
CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; ins_cast_str += @@ -269,6 +275,7 @@ GenerateOpFunctions(const std::string& module_name) { } input_args += out_type; input_args += out_name; + input_args_num++; if (output.dispensable()) { const auto out_template = @@ -295,6 +302,7 @@ GenerateOpFunctions(const std::string& module_name) { auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); input_args += ARG_OUT_NUM_TYPE; input_args += out_num_str; + input_args_num++; outs_initializer += paddle::string::Sprintf( OUT_DUPLICABLE_INITIALIZER_TEMPLATE, out_name, out_num_str); } else { @@ -334,9 +342,9 @@ GenerateOpFunctions(const std::string& module_name) { // generate op funtcion body auto op_function_str = paddle::string::Sprintf( OP_FUNCTION_TEMPLATE, return_type, func_name, function_args, - ins_cast_str, outs_initializer, ins_initializer, - ins_initializer_with_null + outs_initializer_with_null, op_type, - return_str); + ins_cast_str, op_type, input_args_num, outs_initializer, + ins_initializer, ins_initializer_with_null + outs_initializer_with_null, + op_type, return_str); // generate pybind item auto bind_function_str = paddle::string::Sprintf( diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f426ca82966d3..635a81dff0d37 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -64,6 +64,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/data_set_py.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/fleet_wrapper_py.h" +#include "paddle/fluid/pybind/generator_py.h" #include "paddle/fluid/pybind/global_value_getter_setter.h" #include "paddle/fluid/pybind/gloo_wrapper_py.h" #include "paddle/fluid/pybind/heter_wrapper_py.h" @@ -2503,6 +2504,7 @@ All parameter, weight, gradient are variables in Paddle. 
BindNode(&m); BindInferenceApi(&m); BindDataset(&m); + BindGenerator(&m); #ifdef PADDLE_WITH_CRYPTO BindCrypto(&m); #endif diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat index ad0583c08cd38..d3f3fc79c4d97 100644 --- a/paddle/scripts/paddle_build.bat +++ b/paddle/scripts/paddle_build.bat @@ -25,12 +25,13 @@ mkdir build cd /d build rem ------initialize the virtual environment------ -if not defined PYTHON_ROOT set PYTHON_ROOT=c:\Python37 +if not defined PYTHON_ROOT set PYTHON_ROOT=C:\Python37 +set PATH=%PYTHON_ROOT%;%PYTHON_ROOT%\Scripts;%PATH% + set PYTHON_EXECUTABLE=%PYTHON_ROOT%\python.exe %PYTHON_EXECUTABLE% -m pip install virtualenv -if not exist paddle_winci (%PYTHON_EXECUTABLE% -m virtualenv paddle_winci) +%PYTHON_EXECUTABLE% -m virtualenv paddle_winci call paddle_winci\Scripts\activate.bat -if %ERRORLEVEL% NEQ 0 exit /b %ERRORLEVEL% rem ------pre install requirement---------- where python @@ -38,8 +39,11 @@ where pip pip install --upgrade pip pip install wheel pip install gym -pip install -r %work_dir%\python\requirements.txt -if %ERRORLEVEL% NEQ 0 exit /b %ERRORLEVEL% +pip install -U -r %work_dir%\python\requirements.txt +if %ERRORLEVEL% NEQ 0 ( + call paddle_winci\Scripts\deactivate.bat + exit /b %ERRORLEVEL% +) rem ------initialize common variable------ if not defined CUDA_TOOLKIT_ROOT_DIR set CUDA_TOOLKIT_ROOT_DIR="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.0" @@ -57,7 +61,7 @@ dir %cache_dir% set INFERENCE_DEMO_INSTALL_DIR=%cache_dir:\=/%/inference_demo if not exist %cache_dir%\tools ( - git clone https://github.com/zhouwei25/tools %cache_dir%\tools + git clone https://github.com/zhouwei25/tools.git %cache_dir%\tools if %ERRORLEVEL% NEQ 0 exit /b %ERRORLEVEL% ) @@ -124,7 +128,7 @@ goto:eof :cmake_error call paddle_winci\Scripts\deactivate.bat -echo cmake failed! +echo Cmake failed, will exit! exit /b 1 rem --------------------------------------------------------------------------------------------- @@ -133,39 +137,41 @@ echo ======================================== echo Step 2. Buile Paddle ... echo ======================================== call "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\vcvarsall.bat" amd64 -set build_times=1 +set build_times=1 :build_tp -echo BUILD THIRD_PARTY %build_times% +echo Build third_party the %build_times% time: msbuild /m /p:Configuration=Release /verbosity:quiet third_party.vcxproj -echo BUILD THIRD_PARTY RESULT %ERRORLEVEL% if %ERRORLEVEL% NEQ 0 ( set /a build_times=%build_times%+1 if %build_times% GTR 3 ( exit /b 1 ) else ( + echo Build third_party failed, will retry! goto :build_tp ) ) +echo Build third_party successfully! set build_times=1 :build_paddle -echo BUILD PADDLE %build_times% -msbuild /m /p:Configuration=Release /verbosity:quiet paddle.sln -echo BUILD PADDLE RESULT %ERRORLEVEL% +echo Build Paddle the %build_times% time: +msbuild /m /p:Configuration=Release /verbosity:minimal paddle.sln if %ERRORLEVEL% NEQ 0 ( set /a build_times=%build_times%+1 if %build_times% GTR 2 ( exit /b 1 ) else ( + echo Build Paddle failed, will retry! goto :build_paddle ) ) +echo Build Paddle successfully! goto:eof :build_error call paddle_winci\Scripts\deactivate.bat -echo build paddle failed! +echo Build Paddle failed, will exit! exit /b 7 rem --------------------------------------------------------------------------------------------- @@ -185,6 +191,7 @@ goto:eof :test_whl_pacakage_error call paddle_winci\Scripts\deactivate.bat +echo Pip install whl package failed, will exit! 
exit /b 3 rem --------------------------------------------------------------------------------------------- @@ -206,6 +213,7 @@ goto:eof :unit_test_error call paddle_winci\Scripts\deactivate.bat +echo Running unit tests failed, will exit! exit /b 8 rem --------------------------------------------------------------------------------------------- diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 6235f1c12ba7e..518e2c0c4d90d 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -101,6 +101,7 @@ from .tensor.manipulation import cast #DEFINE_ALIAS from .tensor.manipulation import concat #DEFINE_ALIAS from .tensor.manipulation import expand #DEFINE_ALIAS +from .tensor.manipulation import broadcast_to #DEFINE_ALIAS from .tensor.manipulation import expand_as #DEFINE_ALIAS from .tensor.manipulation import tile #DEFINE_ALIAS from .tensor.manipulation import flatten #DEFINE_ALIAS @@ -125,6 +126,7 @@ from .tensor.manipulation import flip #DEFINE_ALIAS from .tensor.manipulation import unbind #DEFINE_ALIAS from .tensor.manipulation import roll #DEFINE_ALIAS +from .tensor.manipulation import chunk #DEFINE_ALIAS from .tensor.math import abs #DEFINE_ALIAS from .tensor.math import acos #DEFINE_ALIAS from .tensor.math import asin #DEFINE_ALIAS @@ -225,6 +227,8 @@ from .framework import InverseTimeDecay #DEFINE_ALIAS from .framework import PolynomialDecay #DEFINE_ALIAS from .framework import CosineDecay #DEFINE_ALIAS +from .framework import set_default_dtype #DEFINE_ALIAS +from .framework import get_default_dtype #DEFINE_ALIAS from .tensor.search import index_sample #DEFINE_ALIAS from .tensor.stat import mean #DEFINE_ALIAS diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 31bfd482766cb..5531160d7c503 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -81,6 +81,8 @@ def _set_distributed_strategy(self, dist_strategy): class DistributedStrategy(object): + __lock_attr = False + def __init__(self): """ DistributedStrategy is the main configuration entry for distributed training of Paddle. 
@@ -95,6 +97,13 @@ def __init__(self): """ self.strategy = distributed_strategy_pb2.DistributedStrategy() + self.__lock_attr = True + + def __setattr__(self, key, value): + if self.__lock_attr and not hasattr(self, key): + raise TypeError("%s is not a attribute of %s" % + (key, self.__class__.__name__)) + object.__setattr__(self, key, value) def save_to_prototxt(self, output): """ diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index 0cf909c98c057..856290a05046d 100644 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -110,6 +110,14 @@ def role_id(self): """ raise NotImplementedError("Please implement this method in child class") + def node_num(self): + """ + Get the training node number + Returns: + int: node num + """ + raise NotImplementedError("Please implement this method in child class") + def get_trainer_endpoints(self): """ return trainer endpoints @@ -286,6 +294,14 @@ def server_num(self): self.generate_role() return self._trainers_num + def node_num(self): + """ + return the training node number + """ + if not self._role_is_generated: + self.generate_role() + return self._node_num + def get_trainer_endpoints(self): """ get endpoint of all trainers @@ -353,6 +369,8 @@ def _ps_env(self): self._trainers_num = trainers_num self._role = role self._current_id = current_id + self._node_num = len( + set([x.split(':')[0] for x in self._worker_endpoints])) def _collective_env(self): self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) @@ -363,6 +381,8 @@ def _collective_env(self): assert self._worker_endpoints is not None, "can't find PADDLE_TRAINER_ENDPOINTS" self._worker_endpoints = self._worker_endpoints.split(",") self._trainers_num = len(self._worker_endpoints) + self._node_num = len( + set([x.split(':')[0] for x in self._worker_endpoints])) def _init_gloo_env(self): def init_gloo_instance(role="trainer"): @@ -513,12 +533,16 @@ def _user_defined_ps_env(self): self._cur_endpoint = self._worker_endpoints[self._current_id] elif self._role == Role.SERVER: self._cur_endpoint = self._server_endpoints[self._current_id] + self._node_num = len( + set([x.split(':')[0] for x in self._worker_endpoints])) def _user_defined_collective_env(self): self._worker_endpoints = self._kwargs.get("worker_endpoints") self._current_id = self._kwargs.get("current_id") self._trainers_num = len(self._worker_endpoints) self._training_role = Role.Worker + self._node_num = len( + set([x.split(':')[0] for x in self._worker_endpoints])) def generate_role(self): """ diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py index 78478b9691b21..b9ff31a068e7f 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py @@ -119,18 +119,26 @@ def _try_to_compile(self, startup_program, main_program, loss): local_build_strategy.nccl_comm_num = \ dist_strategy.nccl_comm_num + if self.user_defined_strategy.recompute == True: + logging.warn( + "set enable_sequential_execution=True since you have enable the recompute strategy" + ) + local_build_strategy.enable_sequential_execution = True + exe_strategy = self.user_defined_strategy.execution_strategy - node_num = self.role_maker.worker_num() + worker_num = self.role_maker.worker_num() + node_num = 
self.role_maker.node_num() if self.role_maker._is_collective: - assert node_num >= 1, "nccl2 node_num must >= 1, now:{}" % node_num + assert worker_num >= 1, "nccl2 worker_num must >= 1, now:{}" % worker_num - if node_num <= 1: + if worker_num <= 1: # local mode if local_build_strategy.nccl_comm_num > 1: logging.warn("set nccl_comm_num=1 since you only have 1 node.") local_build_strategy.nccl_comm_num = 1 + if node_num <= 1: if local_build_strategy.use_hierarchical_allreduce: logging.warn( "set hierachical_allreduce=False since you only have 1 node." diff --git a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py index 96247474927b9..07b69f19e7ebd 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py @@ -30,7 +30,8 @@ def _set_basic_info(self, loss, role_maker, user_defined_optimizer, user_defined_strategy): super(RecomputeOptimizer, self)._set_basic_info( loss, role_maker, user_defined_optimizer, user_defined_strategy) - self.wrapped_opt._set_checkpoints([]) + self.wrapped_opt._set_checkpoints( + list(user_defined_strategy.recompute_configs["checkpoints"])) def _can_apply(self): if self.user_defined_strategy.recompute == True: diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 88dd815d937a4..7e0d8c0de5b78 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -89,6 +89,7 @@ from .io import save, load, load_program_state, set_program_state from .dygraph.checkpoint import save_dygraph, load_dygraph from .dygraph.varbase_patch_methods import monkey_patch_varbase +from . import generator Tensor = LoDTensor enable_imperative = enable_dygraph disable_imperative = disable_dygraph @@ -96,7 +97,7 @@ __all__ = framework.__all__ + executor.__all__ + \ trainer_desc.__all__ + transpiler.__all__ + \ parallel_executor.__all__ + lod_tensor.__all__ + \ - data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + [ + data_feed_desc.__all__ + compiler.__all__ + backward.__all__ + generator.__all__ + [ 'io', 'initializer', 'embedding', diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 3097e1d82a9cb..244a621611060 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -29,6 +29,7 @@ from .quantization_pass import _get_op_input_var_names from .quantization_pass import _get_op_output_var_names from .quantization_pass import _get_output_name_index +from .quantization_pass import _channelwise_quant_axis1_ops __all__ = ['PostTrainingQuantization', 'WeightQuantization'] @@ -316,6 +317,7 @@ def __init__(self, self._out_scale_op_list = _out_scale_op_list self._quantized_weight_var_name = set() self._quantized_act_var_name = set() + self.weight_op_pairs = {} self._sampling_data = {} self._quantized_var_kl_threshold = {} self._quantized_var_min = {} @@ -436,6 +438,8 @@ def _optimize_fp32_model(self): graph = IrGraph(core.Graph(self._program.desc), for_test=True) graph = _remove_ctrl_vars(graph) graph = _apply_pass(self._scope, graph, 'conv_bn_fuse_pass') + graph = _apply_pass(self._scope, graph, 'depthwise_conv_bn_fuse_pass') + graph = _apply_pass(self._scope, graph, 'conv_transpose_bn_fuse_pass') self._program = 
graph.to_program() def _collect_target_varnames(self): @@ -446,10 +450,11 @@ def _collect_target_varnames(self): # TODO(juncaipeng), consider the name_scope of skip_quant _logger.info("Collect quantized variable names ...") - def collect_var_name(var_name_list, persistable_var_names): + def collect_var_name(var_name_list, persistable_var_names, op_type): for var_name in var_name_list: if var_name in persistable_var_names: self._quantized_weight_var_name.add(var_name) + self.weight_op_pairs[var_name] = op_type else: self._quantized_act_var_name.add(var_name) @@ -462,13 +467,15 @@ def collect_var_name(var_name_list, persistable_var_names): # For quantized ops, sample inputs and outputs if op_type in self._quantizable_op_type: collect_var_name( - _get_op_input_var_names(op), persistable_var_names) + _get_op_input_var_names(op), persistable_var_names, op_type) collect_var_name( - _get_op_output_var_names(op), persistable_var_names) + _get_op_output_var_names(op), persistable_var_names, + op_type) # For other op, only sample output scale elif op_type in self._out_scale_op_list: collect_var_name( - _get_op_output_var_names(op), persistable_var_names) + _get_op_output_var_names(op), persistable_var_names, + op_type) def _set_activation_persistable(self): ''' @@ -492,45 +499,75 @@ def _sample_threshold(self): Sample the input threshold(min, max, or abs_max) in every iterations. ''' assert self._algo in ["abs_max", "min_max"], \ - "The algo should be abs_max or min_max to sample min max value." - + "The algo should be abs_max or min_max for _sample_threshold." if self._algo == "abs_max": - # Only calculate abs_max value for weight for once - if self._quantized_var_abs_max == {}: - for var_name in self._quantized_weight_var_name: - var_tensor = _load_variable_data(self._scope, var_name) - abs_max_per_channel = [] - for i in range(var_tensor.shape[0]): - abs_max_per_channel.append( - float(np.max(np.abs(var_tensor[i])))) - self._quantized_var_abs_max[var_name] = abs_max_per_channel - for var_name in self._quantized_act_var_name: - var_tensor = _load_variable_data(self._scope, var_name) - abs_max_value = float(np.max(np.abs(var_tensor))) - if (var_name not in self._quantized_var_abs_max) or \ - (abs_max_value > self._quantized_var_abs_max[var_name]): - self._quantized_var_abs_max[var_name] = abs_max_value + self._sample_threshold_abs_max() elif self._algo == "min_max": - if self._quantized_var_min == {} and self._quantized_var_max == {}: - for var_name in self._quantized_weight_var_name: - var_tensor = _load_variable_data(self._scope, var_name) - min_per_channel = [] - max_per_channle = [] - for i in range(var_tensor.shape[0]): - min_per_channel.append(float(np.min(var_tensor[i]))) - max_per_channle.append(float(np.max(var_tensor[i]))) - self._quantized_var_min[var_name] = min_per_channel - self._quantized_var_max[var_name] = max_per_channle - for var_name in self._quantized_act_var_name: + self._sample_threshold_min_max() + + def _sample_threshold_abs_max(self): + assert self._algo == "abs_max", \ + "The algo should be abs_max for _sample_threshold_abs_max." 
+ # Only calculate abs_max value for weight for once + if self._quantized_var_abs_max == {}: + for var_name in self._quantized_weight_var_name: + var_tensor = _load_variable_data(self._scope, var_name) + if self._weight_quantize_type == "abs_max": + abs_max_value = float(np.max(np.abs(var_tensor))) + elif self._weight_quantize_type == "channel_wise_abs_max": + abs_max_value = [] + if self.weight_op_pairs[ + var_name] in _channelwise_quant_axis1_ops: + for i in range(var_tensor.shape[1]): + abs_max_value.append( + float(np.max(np.abs(var_tensor[:, i])))) + else: + for i in range(var_tensor.shape[0]): + abs_max_value.append( + float(np.max(np.abs(var_tensor[i])))) + self._quantized_var_abs_max[var_name] = abs_max_value + + for var_name in self._quantized_act_var_name: + var_tensor = _load_variable_data(self._scope, var_name) + abs_max_value = float(np.max(np.abs(var_tensor))) + if (var_name not in self._quantized_var_abs_max) or \ + (abs_max_value > self._quantized_var_abs_max[var_name]): + self._quantized_var_abs_max[var_name] = abs_max_value + + def _sample_threshold_min_max(self): + assert self._algo == "min_max", \ + "The algo should be min_max for _sample_threshold_min_max." + if self._quantized_var_min == {} and self._quantized_var_max == {}: + for var_name in self._quantized_weight_var_name: var_tensor = _load_variable_data(self._scope, var_name) - min_value = float(np.min(var_tensor)) - max_value = float(np.max(var_tensor)) - if (var_name not in self._quantized_var_min) or \ - (min_value < self._quantized_var_min[var_name]): - self._quantized_var_min[var_name] = min_value - if (var_name not in self._quantized_var_max) or \ - (max_value > self._quantized_var_max[var_name]): - self._quantized_var_max[var_name] = max_value + if self._weight_quantize_type == "abs_max": + min_value = float(np.min(var_tensor)) + max_value = float(np.max(var_tensor)) + elif self._weight_quantize_type == "channel_wise_abs_max": + min_value = [] + max_value = [] + if self.weight_op_pairs[ + var_name] in _channelwise_quant_axis1_ops: + for i in range(var_tensor.shape[1]): + min_value.append(float(np.min(var_tensor[:, i]))) + max_value.append(float(np.max(var_tensor[:, i]))) + else: + for i in range(var_tensor.shape[0]): + min_value.append(float(np.min(var_tensor[i]))) + max_value.append(float(np.max(var_tensor[i]))) + self._quantized_var_min[var_name] = min_value + self._quantized_var_max[var_name] = max_value + + for var_name in self._quantized_act_var_name: + var_tensor = _load_variable_data(self._scope, var_name) + min_value = float(np.min(var_tensor)) + max_value = float(np.max(var_tensor)) + if (var_name not in self._quantized_var_min) or \ + (min_value < self._quantized_var_min[var_name]): + self._quantized_var_min[var_name] = min_value + if (var_name not in self._quantized_var_max) or \ + (max_value > self._quantized_var_max[var_name]): + self._quantized_var_max[var_name] = max_value def _save_input_threhold(self): ''' @@ -554,11 +591,6 @@ def _sample_data(self, iter): applied in every iteration. ''' assert self._algo == "KL", "The algo should be KL to sample data." 
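To make the axis rule above concrete, here is a standalone numpy sketch (illustration only, not part of the patch) of how the per-channel abs_max threshold is taken: along axis 1 for weights of ops in _channelwise_quant_axis1_ops (conv2d_transpose, mul) and along axis 0 otherwise. The helper name channel_abs_max and the 2-D weight shape are hypothetical.

    import numpy as np

    def channel_abs_max(weight, op_type):
        # axis 1 for conv2d_transpose / mul weights, axis 0 for conv2d-style weights
        quant_axis = 1 if op_type in ['conv2d_transpose', 'mul'] else 0
        if quant_axis == 1:
            return [float(np.max(np.abs(weight[:, i]))) for i in range(weight.shape[1])]
        return [float(np.max(np.abs(weight[i]))) for i in range(weight.shape[0])]

    w = np.random.randn(16, 8).astype('float32')   # e.g. a mul (fc) weight [in, out]
    scales = channel_abs_max(w, 'mul')             # 8 per-output-channel thresholds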
- for var_name in self._quantized_weight_var_name: - if var_name not in self._sampling_data: - var_tensor = _load_variable_data(self._scope, var_name) - self._sampling_data[var_name] = var_tensor - if self._is_use_cache_file: for var_name in self._quantized_act_var_name: var_tensor = _load_variable_data(self._scope, var_name) @@ -584,15 +616,20 @@ def _calculate_kl_threshold(self): # Abs_max threshold for weights for var_name in self._quantized_weight_var_name: - weight_data = self._sampling_data[var_name] - weight_threshold = None + weight_data = _load_variable_data(self._scope, var_name) if self._weight_quantize_type == "abs_max": - weight_threshold = np.max(np.abs(weight_data)) + weight_threshold = float(np.max(np.abs(weight_data))) elif self._weight_quantize_type == "channel_wise_abs_max": weight_threshold = [] - for i in range(weight_data.shape[0]): - abs_max_value = np.max(np.abs(weight_data[i])) - weight_threshold.append(abs_max_value) + if self.weight_op_pairs[ + var_name] in _channelwise_quant_axis1_ops: + for i in range(weight_data.shape[1]): + weight_threshold.append( + float(np.max(np.abs(weight_data[:, i])))) + else: + for i in range(weight_data.shape[0]): + weight_threshold.append( + float(np.max(np.abs(weight_data[i])))) self._quantized_var_kl_threshold[var_name] = weight_threshold # KL threshold for activations diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 8851bcc6440d4..0eef94896287a 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -111,6 +111,10 @@ "scale": [["X"], ["Out"]], } +_conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose'] + +_channelwise_quant_axis1_ops = ['conv2d_transpose', 'mul'] + def _get_op_input_var_names(op): """ """ @@ -185,10 +189,24 @@ def _is_input_all_not_persistable(graph, op_node): return is_input_all_not_persistable +def _check_grandchild_op_node(op_node, grandchild_op_name): + ''' + Check whether the fake_quant node has a grandchild op node named + grandchild_op_name. + ''' + for out1_var_node in op_node.outputs: + for out1_op_node in out1_var_node.outputs: + for out2_var_node in out1_op_node.outputs: + for out2_op_node in out2_var_node.outputs: + if out2_op_node.name() == grandchild_op_name: + return True + return False + + class QuantizationTransformPass(object): """ - Quantize the ops that have weights. Add quant and dequant ops for the quantized - ops's inputs. + Quantize the ops that have weights. Add quant and dequant ops for + the quantized ops's inputs. """ _supported_quantizable_op_type = [ 'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul' @@ -311,8 +329,8 @@ def __init__(self, if weight_quantize_type not in quant_type: raise ValueError( "Unknown weight_quantize_type: '%s'. It can only be " - "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'." - % (str(weight_quantize_type))) + "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' " + "or 'moving_average_abs_max'." % (str(weight_quantize_type))) self._activation_quantize_type = activation_quantize_type self._weight_quantize_type = weight_quantize_type @@ -323,7 +341,6 @@ def __init__(self, for op in self._quantizable_ops: assert op in QuantizationTransformPass._supported_quantizable_op_type, \ op + " is not supported for quantization." 
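As a worked illustration (not part of the patch) of the symmetric quantization arithmetic behind the inserted fake_channel_wise_quantize_abs_max ops — the same round(x / scale * ((1 << (num_bits - 1)) - 1)) step that _quant applies later in this patch — using a hypothetical helper fake_channel_quant and a toy 2-D weight:

    import numpy as np

    def fake_channel_quant(x, num_bits=8, quant_axis=0):
        bnt = (1 << (num_bits - 1)) - 1                  # 127 for 8 bits
        scales = np.max(np.abs(x), axis=1 - quant_axis)  # abs_max per channel (2-D case)
        if quant_axis == 0:
            q = np.round(x / scales[:, None] * bnt)
        else:
            q = np.round(x / scales[None, :] * bnt)
        return q, scales

    w = np.array([[0.5, -1.0], [2.0, 0.25]], dtype='float32')
    q, s = fake_channel_quant(w, quant_axis=0)   # row scales [1.0, 2.0] -> q = [[64, -127], [127, 16]]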
- self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._quantizable_grad_ops = [ '%s_grad' % (op) for op in self._quantizable_ops ] @@ -356,10 +373,12 @@ def _quant_preprocess(op_node): user_skipped = False if isinstance(self._skip_pattern, list): user_skipped = op_node.op().has_attr("op_namescope") and \ - any(pattern in op_node.op().attr("op_namescope") for pattern in self._skip_pattern) + any(pattern in op_node.op().attr("op_namescope") \ + for pattern in self._skip_pattern) elif isinstance(self._skip_pattern, str): user_skipped = op_node.op().has_attr("op_namescope") and \ - op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 + op_node.op().attr("op_namescope").find( + self._skip_pattern) != -1 if user_skipped: op_node.op()._set_attr("skip_quant", True) @@ -373,15 +392,11 @@ def _transform_forward(graph, op): if var_node.name() in dequantized_vars: dequant_var_node = dequantized_vars[var_node.name()] else: - name = var_node.name() if name in processed_vars: continue - - if var_node.name() in persistable_vars: - is_weight = True - else: - is_weight = False + is_weight = True if var_node.name() in persistable_vars \ + else False # if var node is weight and weight_preprocess_func is not None, # will insert weight preprocess func @@ -415,20 +430,14 @@ def _transform_forward(graph, op): else self._activation_bits quant_type = self._weight_quantize_type if is_weight \ else self._activation_quantize_type - if quant_type == 'channel_wise_abs_max': - assert is_weight, "'channel_wise_abs_max' can only be applied on weights." - if op.name() in self._conv_ops: - quant_var_node, scale_var_node = self._insert_channel_quant_op( - graph, var_node, name, quant_bits) - dequant_var_node = self._insert_channel_dequant_op( - graph, quant_var_node, [scale_var_node], - [quant_bits]) - else: - quant_var_node, scale_var_node = self._insert_quant_op( - graph, var_node, name, quant_bits, 'abs_max') - dequant_var_node = self._insert_dequant_op( - graph, quant_var_node, scale_var_node, - quant_bits) + if quant_type == 'channel_wise_abs_max': # Weight quantization + quant_axis = 1 if op.name() in \ + _channelwise_quant_axis1_ops else 0 + quant_var_node, scale_var_node = self._insert_channel_quant_op( + graph, var_node, name, quant_bits, quant_axis) + dequant_var_node = self._insert_channel_dequant_op( + graph, quant_var_node, [scale_var_node], + [quant_bits], quant_axis) else: quant_var_node, scale_var_node = self._insert_quant_op( graph, var_node, name, quant_bits, quant_type) @@ -529,11 +538,19 @@ def _insert_quant_abs_max_op(self, graph, var_node, name, quant_bits): var_type=var_node.type(), shape=var_node.shape(), var_dtype=var_node.dtype()) - scale_var_node = graph.create_var_node( + scale_var_node = graph.create_persistable_node( name=self._quantized_scale_name(name), var_type=var_node.type(), shape=[1], var_dtype=var_node.dtype()) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + _init_var_node( + scale_var_node, + np.zeros( + scale_var_node.shape(), dtype=data_type), + self._scope, + self._place) quant_op_node = graph.create_op_node( op_type='fake_quantize_abs_max', attrs={ @@ -706,7 +723,8 @@ def _insert_quant_moving_average_abs_max_op(self, graph, var_node, name, return quant_var_node, scale_out_node - def _insert_channel_quant_op(self, graph, var_node, name, quant_bits): + def _insert_channel_quant_op(self, graph, var_node, name, quant_bits, + quant_axis): """ Insert fake_channel_wise_quantize_abs_max op in the graph. 
""" @@ -717,15 +735,24 @@ def _insert_channel_quant_op(self, graph, var_node, name, quant_bits): var_type=var_node.type(), shape=var_node.shape(), var_dtype=var_node.dtype()) - scale_var_node = graph.create_var_node( + scale_var_node = graph.create_persistable_node( name=self._quantized_scale_name(name), var_type=var_node.type(), - shape=[var_node.shape()[0]], + shape=[var_node.shape()[quant_axis]], var_dtype=var_node.dtype()) + data_type = 'float64' if var_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + _init_var_node( + scale_var_node, + np.zeros( + scale_var_node.shape(), dtype=data_type), + self._scope, + self._place) quant_op_node = graph.create_op_node( op_type='fake_channel_wise_quantize_abs_max', attrs={ 'bit_length': quant_bits, + 'quant_axis': quant_axis, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, inputs={'X': var_node}, @@ -763,7 +790,7 @@ def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits): return dequant_var_node def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes, - quant_bits): + quant_bits, quant_axis): """ Insert fake_channel_wise_dequantize_max_abs in the graph. """ @@ -778,6 +805,7 @@ def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes, op_type='fake_channel_wise_dequantize_max_abs', attrs={ 'quant_bits': quant_bits, + 'quant_axis': quant_axis, 'op_role': core.op_proto_and_checker_maker.OpRole.Forward }, inputs={'X': var_node, @@ -1036,7 +1064,6 @@ def __init__(self, self._weight_bits = weight_bits self._activation_bits = activation_bits self._weight_quantize_type = weight_quantize_type - self._conv_ops = ['conv2d', 'depthwise_conv2d', 'conv2d_transpose'] self._fake_quant_op_names = _fake_quant_op_list self._fake_dequant_op_names = _fake_dequant_op_list self._op_input_rename_map = collections.OrderedDict() @@ -1063,34 +1090,37 @@ def apply(self, graph): if input_arg_name in graph.out_node_mapping_table.keys(): input_arg_name = graph.out_node_mapping_table[ input_arg_name] - if input_arg_name in persistable_vars: - if self._weight_quantize_type == 'abs_max': - param = self._load_var(input_arg_name) - scale_v = np.max(np.abs(param)) - elif self._weight_quantize_type == 'channel_wise_abs_max': - param = self._load_var(input_arg_name) - if len(param.shape) == 4: # conv2d or depthwise_conv2d - scale_v = [] - for i in range(param.shape[0]): - scale_v.append(np.max(np.abs(param[i]))) - else: - scale_v = np.max(np.abs(param)) + if input_arg_name not in persistable_vars: + scale_v = graph._find_node_by_name( + op_node.outputs, op_node.output('OutScale')[0]) + self._quant_var_scale_map[input_arg_name] = scale_v + else: + # Obtain scale from OutScale var node + scale_v = self._load_var(op_node.output('OutScale')[0]) + assert scale_v.ndim in [ + 1, 2 + ], "the dim of scale_v should be 1 or 2" + if scale_v.ndim == 2: + scale_v = scale_v[0] + if scale_v.size == 1: + scale_v = scale_v[0] else: - scale_v = self._load_var( - op_node.output('OutScale')[0])[0] + scale_v = scale_v.tolist() self._quant_var_scale_map[input_arg_name] = scale_v - self._remove_fake_quant_and_dequant_op(graph, op_node) - # quantize weight and restore + # Quantize weight and restore param_v = self._load_var(input_arg_name) - quantized_param_v = self._quant(param_v, scale_v, - self._weight_bits) + if isinstance(scale_v, list) and \ + any(_check_grandchild_op_node(op_node, op) + for op in _channelwise_quant_axis1_ops): + quant_axis = 1 + else: + quant_axis = 0 + quantized_param_v = self._quant( + param_v, scale_v, 
self._weight_bits, quant_axis) self._restore_var(input_arg_name, quantized_param_v) - else: - scale_v = graph._find_node_by_name( - op_node.outputs, op_node.output('OutScale')[0]) - self._quant_var_scale_map[input_arg_name] = scale_v + self._remove_fake_quant_and_dequant_op(graph, op_node) - # Remove all fake dequant op +# Remove all fake dequant op ops = graph.all_op_nodes() for op_node in ops: op_name = op_node.name() @@ -1103,8 +1133,7 @@ def apply(self, graph): op_node_desc = op_node.op() if op_node_desc.has_attr("quantization_type") and \ op_node_desc.attr("quantization_type") == "qat_with_weight": - if self._weight_quantize_type == 'channel_wise_abs_max' \ - and op_node.name() in self._conv_ops: + if self._weight_quantize_type == 'channel_wise_abs_max': self._insert_post_channel_dequant_op(graph, op_node) else: self._insert_post_dequant_op(graph, op_node) @@ -1295,10 +1324,15 @@ def _is_float(self, v): return isinstance(v, float) or isinstance(v, np.float32) \ or isinstance(v, np.float64) - def _quant(self, x, scale, num_bits): + def _quant(self, x, scale, num_bits, quant_axis): + assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.' if isinstance(scale, list): for i, s in enumerate(scale): - x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1)) + if quant_axis == 0: + x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1)) + else: + x[:, i] = np.round(x[:, i] / s * ( + (1 << (num_bits - 1)) - 1)) return x else: return np.round(x / scale * ((1 << (num_bits - 1)) - 1)) @@ -1468,6 +1502,10 @@ def apply(self, graph): for op in target_ops: for output_var_name in _get_op_output_var_names(op): in_node = graph._find_node_by_name(op.outputs, output_var_name) + if in_node.dtype() not in \ + [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: + continue + scale_node = graph.create_persistable_node( name=self._scale_name(in_node.name()), var_type=core.VarDesc.VarType.LOD_TENSOR, @@ -1570,17 +1608,26 @@ def apply(self, graph): if op_node.name() in self._teller_set: var_names = _get_op_output_var_names(op_node) for var_name in var_names: - # For compatibility, we save output threshold by two methods. + in_node = graph._find_node_by_name(op_node.outputs, + var_name) + if in_node.dtype() not in \ + [core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32]: + continue + scale_name = self._scale_name(var_name) - scale_v = np.array( - self._scope.find_var(scale_name).get_tensor())[0] - op_node.op()._set_attr("out_threshold", float(scale_v)) + scale_var = self._scope.find_var(scale_name) + assert scale_var is not None, \ + "Can not find {} variable in the scope".format(scale_name) + scale_value = np.array(scale_var.get_tensor())[0] + + # For compatibility, we save output threshold by two methods. 
+ op_node.op()._set_attr("out_threshold", float(scale_value)) argname_index = _get_output_name_index(op_node, var_name) assert argname_index is not None, \ var_name + " is not the output of the op" op_node.op()._set_attr(argname_index[0] + str(argname_index[1]) \ - + "_threshold", float(scale_v)) + + "_threshold", float(scale_value)) graph.resolve_hazard() return graph diff --git a/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py b/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py index c9ea15bf6cde9..32292c8a47b50 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py +++ b/python/paddle/fluid/contrib/slim/tests/test_user_defined_quantization.py @@ -33,34 +33,29 @@ os.environ["CPU_NUM"] = "1" -def residual_block(img, label, num=1): - def conv_bn_layer(input, - ch_out, - filter_size, - stride, - padding, - act='relu', - bias_attr=False): - tmp = fluid.layers.conv2d( - input=input, - filter_size=filter_size, - num_filters=ch_out, - stride=stride, - padding=padding, - use_cudnn=False, - act=None, - bias_attr=bias_attr) - return fluid.layers.batch_norm(input=tmp, act=act) - - hidden = img - for _ in six.moves.xrange(num): - conv = conv_bn_layer(hidden, 20, 3, 1, 1, act=None, bias_attr=True) - short = conv_bn_layer(hidden, 20, 1, 1, 0, act=None) - hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') - fc = fluid.layers.fc(input=hidden, size=10, act='softmax') - loss = fluid.layers.cross_entropy(input=fc, label=label) - loss = fluid.layers.mean(loss) - return loss +def conv_net(img, label): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + pool_type='max', + act="relu") + conv_pool_1 = fluid.layers.batch_norm(conv_pool_1) + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + pool_type='avg', + act="relu") + hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu') + prediction = fluid.layers.fc(input=hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + avg_loss = fluid.layers.mean(loss) + return avg_loss def pact(x, name=None): @@ -102,7 +97,7 @@ def build_program(main, startup, is_test): img.stop_gradient = False label = fluid.layers.data( name='label', shape=[1], dtype='int64') - loss = residual_block(img, label, 1) + loss = conv_net(img, label) if not is_test: opt = fluid.optimizer.SGD(learning_rate=0.0001) opt.minimize(loss) diff --git a/python/paddle/fluid/dataloader/__init__.py b/python/paddle/fluid/dataloader/__init__.py index 2f15811e4f360..597f1f217483c 100644 --- a/python/paddle/fluid/dataloader/__init__.py +++ b/python/paddle/fluid/dataloader/__init__.py @@ -23,6 +23,10 @@ from . import dataloader_iter from .dataloader_iter import * +from . 
import sampler +from .sampler import * + __all__ = dataset.__all__ \ + batch_sampler.__all__ \ - + dataloader_iter.__all__ + + dataloader_iter.__all__ \ + + sampler.__all__ diff --git a/python/paddle/fluid/dataloader/batch_sampler.py b/python/paddle/fluid/dataloader/batch_sampler.py index 811468c523b2f..8043237c0d97d 100644 --- a/python/paddle/fluid/dataloader/batch_sampler.py +++ b/python/paddle/fluid/dataloader/batch_sampler.py @@ -16,12 +16,13 @@ from __future__ import division import numpy as np +from .sampler import Sampler, SequenceSampler from .dataset import Dataset, IterableDataset __all__ = ["BatchSampler"] -class BatchSampler(object): +class BatchSampler(Sampler): """ A base implement of batch sampler used by `paddle.io.DataLoader` which yield mini-batch indices(a list/tuple with length as @@ -41,10 +42,11 @@ class BatchSampler(object): implement or other python object which implemented :code:`__len__` for BatchSampler to get indices as the range of :attr:`dataset` length. Default None. - indices (list|tuple): a substitution parameter for - :attr:`dataset` either :attr:`dataset` or - :attr:`indices` should be set, give the whole - indices to sampler from directly. Default None. + sampler (Sampler): this could be a :code:`paddle.io.Dataset` + instance which implemented :code:`__iter__` to yield + sample indices. :attr:`sampler` and :attr:`dataset` + can not be set in the same time. If :attr:`sampler` + is set, :attr:`shuffle` should not be set. Default None. shuffle(bool): whether to shuffle indices order before genrating batch indices. Default False. batch_size(int): sample indice number in a mini-batch indices. @@ -58,16 +60,7 @@ class BatchSampler(object): .. code-block:: python - from paddle.io import BatchSampler, Dataset - - # init with indices - bs = BatchSampler(indices=list(range(100)), - shuffle=True, - batch_size=8, - drop_last=True) - - for batch_indices in bs: - print(batch_indices) + from paddle.io import RandomSampler, BatchSampler, Dataset # init with dataset class RandomDataset(Dataset): @@ -90,34 +83,42 @@ def __len__(self): for batch_indices in bs: print(batch_indices) + # init with sampler + sampler = RandomSampler(RandomDataset(100)) + bs = BatchSampler(sampler=sampler, + shuffle=True, + batch_size=8, + drop_last=True) + + for batch_indices in bs: + print(batch_indices) + + see `paddle.io.DataLoader` """ def __init__(self, dataset=None, - indices=None, + sampler=None, shuffle=False, batch_size=1, drop_last=False): if dataset is None: - assert indices is not None, \ - "either dataset or indices should be set" - assert isinstance(indices, list) or isinstance(indices, tuple), \ - "indices should be a list or tuple, but got {}".format(type(indices)) - self.indices = indices - self.sampler_iter = None + assert sampler is not None, \ + "either dataset or sampler should be set" + assert isinstance(sampler, Sampler), \ + "sampler should be a paddle.io.Sampler, but got {}".format(type(sampler)) + assert not shuffle, "shuffle should be False when sampler is set" + self.sampler = sampler else: - if isinstance(dataset, IterableDataset): - self.sampler_iter = iter( - _InfiniteIterableSampler(dataset, batch_size)) - else: - self.sampler_iter = None - assert isinstance(dataset, Dataset), \ - "dataset should be an instance of paddle.io.Dataset" - assert indices is None, \ - "should not set both dataset and indices" - self.indices = list(range(len(dataset))) + assert isinstance(dataset, Dataset), \ + "dataset should be a paddle.io.Dataset" + assert not isinstance(dataset, 
IterableDataset), \ + "dataset should not be a paddle.io.IterableDataset" + assert sampler is None, \ + "should not set both dataset and sampler" + self.sampler = SequenceSampler(dataset) assert isinstance(batch_size, int) and batch_size > 0, \ "batch_size should be a positive integer, but got {}".format(batch_size) @@ -130,15 +131,8 @@ def __init__(self, self.drop_last = drop_last def __iter__(self): - if self.sampler_iter: - yield next(self.sampler_iter) - - if self.shuffle: - np.random.shuffle(self.indices) - _iter = iter(self.indices) - batch_indices = [] - for idx in _iter: + for idx in self.sampler: batch_indices.append(idx) if len(batch_indices) == self.batch_size: yield batch_indices @@ -147,10 +141,7 @@ def __iter__(self): yield batch_indices def __len__(self): - if self.sampler_iter: - raise RuntimeError("'{}' should not be called for IterableDataset". - format('__len__')) - num_samples = len(self.indices) + num_samples = len(self.sampler) num_samples += int(not self.drop_last) * (self.batch_size - 1) return num_samples // self.batch_size diff --git a/python/paddle/fluid/dataloader/sampler.py b/python/paddle/fluid/dataloader/sampler.py new file mode 100644 index 0000000000000..d2f3231cc6b12 --- /dev/null +++ b/python/paddle/fluid/dataloader/sampler.py @@ -0,0 +1,232 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from __future__ import division + +import numpy as np + +__all__ = ["Sampler", "SequenceSampler", "RandomSampler"] + + +class Sampler(object): + """ + An abstract class to encapsulate methods and behaviors of samplers. + + All sampler used by :code:`paddle.io.BatchSampler` should be a subclass + of :code:`paddle.io.Sampler`, BatchSampler subclasses should + implement following methods: + + :code:`__iter__`: return sample index iterably, which iterate over indices + of dataset elements + + :code:`__len__`: the number of sample in :attr:`data_source` + + + Args: + data_source(Dataset, optional): this could be an instance of + :code:`paddle.io.Dataset` other Python object which + implemented :code:`__len__` for Sampler to get indices + as the range of :attr:`dataset` length. Default None. + + Returns: + Sampler: an iterable object for sample indices iterating + + Examples: + + .. 
code-block:: python + + from paddle.io import Dataset, Sampler + + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + class MySampler(Sampler): + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self): + return len(self.data_source) + + sampler = MySampler(data_source=RandomDataset(100)) + + for index in sampler: + print(index) + + see `paddle.io.BatchSampler` + see `paddle.io.DataLoader` + + """ + + def __init__(self, data_source=None): + self.data_source = data_source + + def __iter__(self): + raise NotImplementedError + + # Not define __len__ method in this base class here for __len__ + # is not needed in same sence, e.g. paddle.io.IterableDataset + + +class SequenceSampler(Sampler): + """ + Iterate samples sequentially, yield :code:`0, 1, 2, ..., len(data_source) -1` + generally, + + Args: + data_source(Dataset): dataset to sample, this could be an + instance of :code:`paddle.io.Dataset` other Python + object which implemented :code:`__len__`. + + Returns: + Sampler: a Sampler yield sample index sequentially + + Examples: + + .. code-block:: python + + from paddle.io import Dataset, SequenceSampler + + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + sampler = SequenceSampler(data_source=RandomDataset(100)) + + for index in sampler: + print(index) + + see `paddle.io.Sampler` + """ + + def __init__(self, data_source): + self.data_source = data_source + + def __iter__(self): + return iter(range(len(self.data_source))) + + def __len__(self): + return len(self.data_source) + + +class RandomSampler(Sampler): + """ + Iterate samples randomly, yield shuffled indices, if :attr:`replacement=False`, + yield shuffled indices of the whole data souce, if :attr:`replacement=True`, + :attr:`num_samples` can set to specify the sample number to draw. + + Args: + data_source(Dataset): dataset to sample, this could be an + instance of :code:`paddle.io.Dataset` other Python + object which implemented :code:`__len__`. + replacement(bool): If False, sample the whole dataset, If False, + set :attr:`num_samples` for how many sample to draw. Default False. + num_samples(int): set sample number to draw if :attr:`replacement` + is True. Default None. + generator(Generator): specify a generator to sample the data source. Default None + + Returns: + Sampler: a Sampler yield sample index randomly + + Examples: + + .. 
code-block:: python + + from paddle.io import Dataset, RandomSampler + + class RandomDataset(Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + image = np.random.random([784]).astype('float32') + label = np.random.randint(0, 9, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + sampler = RandomSampler(data_souce=RandomDataset(100)) + + for index in sampler: + print(index) + + see `paddle.io.Sampler` + """ + + def __init__(self, + data_source, + replacement=False, + num_samples=None, + generator=None): + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + + if not isinstance(self.replacement, bool): + raise TypeError("expect boolean value for replacement, but got " + "replacement={}".format(self.replacement)) + + if self._num_samples is not None and not replacement: + raise ValueError( + "num_samples should not be specified while replacement is False") + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError("num_samples should be a positive integer, " + "but got num_samples={}".format(self.num_samples)) + + @property + def num_samples(self): + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self): + n = len(self.data_source) + if self.generator: + for index in self.generator: + yield index + else: + if self.replacement: + for index in np.random.choice( + np.arange(n), self.num_samples, replace=True).tolist(): + yield index + else: + for index in np.random.choice( + np.arange(n), n, replace=False).tolist(): + yield index + + def __len__(self): + return self.num_samples diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py index aeece9513b577..13f38b0726c27 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/origin_info.py @@ -18,8 +18,8 @@ import inspect import gast - from paddle.fluid import core +from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap from paddle.fluid.framework import Program # NOTE(liym27): Please use `getattr(ast_node, ORIGI_INFO)` instead of . operation to get the original information of ast node. @@ -197,18 +197,6 @@ def attach_origin_info(ast_node, func): return ast_node -# NOTE: inspect.unwrap() exits in PY3 but not in PY2. -def unwrap(func): - def _is_wrapped(f): - return hasattr(f, '__wrapped__') - - unwrapped_f = func - while (_is_wrapped(unwrapped_f)): - unwrapped_f = unwrapped_f.__wrapped__ - - return unwrapped_f - - def ast_walk(transformed_node, static_node): """ Recursively yield all descendant nodes in the trees starting at transformed_node and static_node (including itself) in parallel. diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 88562dd40a63b..ceacba25375c6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -13,32 +13,38 @@ # limitations under the License. 
from __future__ import print_function -import gast + +import collections import inspect -import warnings import textwrap import threading -import collections +import warnings + +import gast import numpy as np -from paddle.fluid import core, scope_guard -from paddle.fluid import framework +from paddle.fluid import core from paddle.fluid import executor +from paddle.fluid import framework +from paddle.fluid import scope_guard from paddle.fluid import unique_name +from paddle.fluid.data_feeder import check_type from paddle.fluid.dygraph import layers -from paddle.fluid.layers.utils import flatten -from paddle.fluid.layers.utils import pack_sequence_as +from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import switch_to_static_graph from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst +from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA +from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data +from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info +from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map +from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info +from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code -from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func +from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap +from paddle.fluid.layers.utils import flatten +from paddle.fluid.layers.utils import pack_sequence_as from paddle.fluid.wrapped_decorator import signature_safe_contextmanager -from paddle.fluid.dygraph.base import param_guard -from paddle.fluid.data_feeder import check_type -from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from -from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map -from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info -from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA __all__ = ['ProgramTranslator', 'convert_to_static'] @@ -89,7 +95,7 @@ def foo(x, y): """ # Note: In Python2, it will raise OSError when inspect function # with decorator directly and function.__wrapped__ holds the actual function. - func = getattr(func, '__wrapped__', func) + func = unwrap(func) source_code = func_to_source_code(func) # TODO(liym27): @@ -669,7 +675,9 @@ def func(x): dygraph_func ), "Input dygraph_func is not a callable in ProgramTranslator.get_code" # Gets AST from dygraph function - raw_code = inspect.getsource(dygraph_func) + + unwrap_func = unwrap(dygraph_func) + raw_code = inspect.getsource(unwrap_func) code = textwrap.dedent(raw_code) root = gast.parse(code) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index def201cedc242..6636bf7c4b405 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -369,13 +369,14 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): function, the other inner functions are invisible for the decorated function. 
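For context — illustration only, not part of the patch — the behaviour the relocated unwrap helper relies on: under Python 3, functools.wraps records the original callable in __wrapped__, and unwrap simply follows that chain so that inspect.getsource sees the undecorated function.

    import functools

    def my_decorator(f):
        @functools.wraps(f)              # sets wrapper.__wrapped__ = f
        def wrapper(*args, **kwargs):
            return f(*args, **kwargs)
        return wrapper

    @my_decorator
    def foo(x):
        return x + 1

    target = foo
    while hasattr(target, '__wrapped__'):    # same loop as utils.unwrap
        target = target.__wrapped__
    assert target is foo.__wrapped__ and target(1) == 2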
""" - def remove_file(filepath): + def remove_if_exit(filepath): if os.path.exists(filepath): os.remove(filepath) source = ast_to_source_code(ast_root) import_fluid = "import paddle.fluid as fluid\n" source = import_fluid + source + if six.PY2: source = source.encode('utf-8') f = tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) @@ -387,8 +388,8 @@ def remove_file(filepath): f.write(source) if delete_on_exit: - atexit.register(lambda: remove_file(f.name)) - atexit.register(lambda: remove_file(f.name[:-3] + ".pyc")) + atexit.register(lambda: remove_if_exit(f.name)) + atexit.register(lambda: remove_if_exit(f.name[:-3] + ".pyc")) module = imp.load_source(module_name, f.name) func_name = dyfunc.__name__ @@ -1052,3 +1053,19 @@ def _parse_multi_target_assign(self, node): value_node = target return new_nodes + + +# NOTE: inspect.unwrap() exits in PY3 but not in PY2. +def unwrap(func): + """ + Returns the object wrapped by decorators. + """ + + def _is_wrapped(f): + return hasattr(f, '__wrapped__') + + unwrapped_f = func + while (_is_wrapped(unwrapped_f)): + unwrapped_f = unwrapped_f.__wrapped__ + + return unwrapped_f diff --git a/python/paddle/fluid/dygraph/layer_object_helper.py b/python/paddle/fluid/dygraph/layer_object_helper.py index f2e914a2137d0..a904f80639752 100644 --- a/python/paddle/fluid/dygraph/layer_object_helper.py +++ b/python/paddle/fluid/dygraph/layer_object_helper.py @@ -136,18 +136,13 @@ def get_parameter(self, name): return param # TODO: this should not be called anymore after all activation func move to Layers - def append_activation(self, - input_var, - act=None, - use_cudnn=None, - use_mkl_dnn=None): + def append_activation(self, input_var, act=None, use_cudnn=None): """Append activation Args: input_var: the input variable. The len(input_var.shape) is larger or equal than 2. act: activation type - use_mkl_dnn: if use mkldnn use_cudnn: if use cudnn Return the Variable of after append activation @@ -163,8 +158,9 @@ def append_activation(self, if (use_cudnn is not None) and use_cudnn: act['use_cudnn'] = use_cudnn - if (use_mkl_dnn is not None) and use_mkl_dnn: - act['use_mkldnn'] = use_mkl_dnn + use_mkldnn = core.globals()["FLAGS_use_mkldnn"] + if (use_mkldnn is not None) and use_mkldnn: + act['use_mkldnn'] = use_mkldnn act_type = act.pop('type') tmp = self.create_variable_for_type_inference(dtype=input_var.dtype) diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 1bed04479fb22..7b2434e1a213a 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -283,7 +283,7 @@ def forward_pre_hook(layer, input): def create_parameter(self, shape, attr=None, - dtype='float32', + dtype=None, is_bias=False, default_initializer=None): """Create parameters for this layer. 
diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index e56f26f1b1b94..45744841fc5be 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -35,7 +35,7 @@ 'Conv2D', 'Conv3D', 'Pool2D', 'Linear', 'BatchNorm', 'Dropout', 'Embedding', 'GRUUnit', 'InstanceNorm', 'LayerNorm', 'NCE', 'PRelu', 'BilinearTensorProduct', 'Conv2DTranspose', 'Conv3DTranspose', 'GroupNorm', - 'SpectralNorm', 'TreeConv', 'Flatten' + 'SpectralNorm', 'TreeConv', 'Flatten', 'SyncBatchNorm' ] @@ -180,6 +180,7 @@ def __init__(self, if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") self._use_cudnn = use_cudnn + self._use_mkldnn = core.globals()["FLAGS_use_mkldnn"] self._filter_size = filter_size self._num_filters = num_filters self._param_attr = param_attr @@ -187,7 +188,8 @@ def __init__(self, self._dtype = dtype if (self._num_channels == self._groups and - num_filters % self._num_channels == 0 and not self._use_cudnn): + num_filters % self._num_channels == 0 and + not self._use_cudnn and not self._use_mkldnn): self._l_type = 'depthwise_conv2d' else: self._l_type = 'conv2d' @@ -224,14 +226,15 @@ def forward(self, input): if in_dygraph_mode() and self._l_type == 'conv2d': attrs = ('strides', self._stride, 'paddings', self._padding, 'dilations', self._dilation, 'groups', self._groups - if self._groups else 1, 'use_cudnn', self._use_cudnn) + if self._groups else 1, 'use_cudnn', self._use_cudnn, + 'use_mkldnn', self._use_mkldnn) out = core.ops.conv2d(input, self.weight, *attrs) pre_bias = out - pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, self.bias, - 1) - return dygraph_utils._append_activation_in_dygraph(pre_act, - self._act) + pre_act = dygraph_utils._append_bias_in_dygraph( + pre_bias, self.bias, 1, use_mkldnn=self._use_mkldnn) + return dygraph_utils._append_activation_in_dygraph( + pre_act, self._act, use_mkldnn=self._use_mkldnn) inputs = { 'Input': [input], 'Filter': [self.weight], @@ -242,7 +245,7 @@ def forward(self, input): 'dilations': self._dilation, 'groups': self._groups if self._groups else 1, 'use_cudnn': self._use_cudnn, - 'use_mkldnn': False, + 'use_mkldnn': self._use_mkldnn, } check_variable_and_dtype(input, 'input', @@ -267,7 +270,8 @@ def forward(self, input): inputs={'X': [pre_bias], 'Y': [self.bias]}, outputs={'Out': [pre_act]}, - attrs={'axis': 1}) + attrs={'axis': 1, + 'use_mkldnn': self._use_mkldnn}) else: pre_act = pre_bias @@ -828,6 +832,8 @@ def __init__(self, if not isinstance(use_cudnn, bool): raise ValueError("use_cudnn should be True or False") + self._use_mkldnn = core.globals()["FLAGS_use_mkldnn"] + if data_format not in ["NCHW", "NHWC"]: raise ValueError( "Attr(data_format) should be 'NCHW' or 'NHWC'. 
Received " @@ -853,8 +859,8 @@ def forward(self, input): 'global_pooling', self._global_pooling, 'strides', self._pool_stride, 'paddings', self._pool_padding, 'use_cudnn', self._use_cudnn, 'ceil_mode', self._ceil_mode, - 'use_mkldnn', False, 'exclusive', self._exclusive, - 'data_format', self._data_format) + 'use_mkldnn', self._use_mkldnn, 'exclusive', + self._exclusive, 'data_format', self._data_format) return core.ops.pool2d(input, *attrs) check_variable_and_dtype( @@ -869,7 +875,7 @@ def forward(self, input): "paddings": self._pool_padding, "use_cudnn": self._use_cudnn, "ceil_mode": self._ceil_mode, - "use_mkldnn": False, + "use_mkldnn": self._use_mkldnn, "exclusive": self._exclusive, "data_format": self._data_format, } @@ -958,16 +964,22 @@ def __init__(self, self.bias = self.create_parameter( shape=[output_dim], attr=bias_attr, dtype=dtype, is_bias=True) + self._use_mkldnn = core.globals()["FLAGS_use_mkldnn"] + def forward(self, input): if in_dygraph_mode(): pre_bias = _varbase_creator(dtype=input.dtype) core.ops.matmul(input, self.weight, pre_bias, 'transpose_X', False, - 'transpose_Y', False, "alpha", 1) + 'transpose_Y', False, "alpha", 1, "use_mkldnn", + self._use_mkldnn) pre_act = dygraph_utils._append_bias_in_dygraph( - pre_bias, self.bias, axis=len(input.shape) - 1) + pre_bias, + self.bias, + axis=len(input.shape) - 1, + use_mkldnn=self._use_mkldnn) - return dygraph_utils._append_activation_in_dygraph(pre_act, - self._act) + return dygraph_utils._append_activation_in_dygraph( + pre_act, self._act, use_mkldnn=self._use_mkldnn) check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'], "Linear") @@ -976,6 +988,7 @@ def forward(self, input): "transpose_X": False, "transpose_Y": False, "alpha": 1, + "use_mkldnn": self._use_mkldnn, } inputs = {"X": [input], "Y": [self.weight]} @@ -990,7 +1003,10 @@ def forward(self, input): inputs={'X': [tmp], 'Y': [self.bias]}, outputs={'Out': [pre_activation]}, - attrs={'axis': len(input.shape) - 1}) + attrs={ + 'axis': len(input.shape) - 1, + 'use_mkldnn': self._use_mkldnn + }) else: pre_activation = tmp return self._helper.append_activation(pre_activation, act=self._act) @@ -1250,6 +1266,7 @@ def __init__(self, self._param_attr = param_attr self._bias_attr = bias_attr self._act = act + self._use_mkldnn = core.globals()["FLAGS_use_mkldnn"] assert bias_attr is not False, "bias_attr should not be False in batch_norm." @@ -1314,8 +1331,8 @@ def forward(self, input): if in_dygraph_mode(): attrs = ("momentum", self._momentum, "epsilon", self._epsilon, "is_test", not self.training, "data_layout", - self._data_layout, "use_mkldnn", False, "fuse_with_relu", - self._fuse_with_relu, "use_global_stats", + self._data_layout, "use_mkldnn", self._use_mkldnn, + "fuse_with_relu", self._fuse_with_relu, "use_global_stats", self._use_global_stats, 'trainable_statistics', self._trainable_statistics) batch_norm_out, _, _, _, _, _ = core.ops.batch_norm( @@ -1323,7 +1340,7 @@ def forward(self, input): mean_out, variance_out, *attrs) return dygraph_utils._append_activation_in_dygraph( - batch_norm_out, act=self._act) + batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn) check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'], 'BatchNorm') @@ -3185,6 +3202,220 @@ def forward(self, nodes_vector, edge_set): return self._helper.append_activation(pre_activation, act=self._act) +class SyncBatchNorm(layers.Layer): + """ + This interface is used to construct a callable object of the ``SyncBatchNorm`` class. 
+ It implements the function of the Cross-GPU Synchronized Batch Normalization Layer, and can + be used as a normalizer function for other operations, such as conv2d and fully connected + operations. + The data is normalized by the mean and variance of the channel based on whole mini-batch + , which including data in all gpus. + Refer to `Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift `_ + for more details. + + When model in training mode, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are the statistics of whole mini-batch data in all gpus. + Calculated as follows: + + .. math:: + + \\mu_{\\beta} &\\gets \\frac{1}{m} \\sum_{i=1}^{m} x_i \\qquad &//\\ + \ mini-batch\ mean \\\\ + \\sigma_{\\beta}^{2} &\\gets \\frac{1}{m} \\sum_{i=1}^{m}(x_i - \\ + \\mu_{\\beta})^2 \\qquad &//\ mini-batch\ variance \\\\ + + - :math:`x` : whole mini-batch data in all gpus + - :math:`m` : the size of the whole mini-batch data + + When model in evaluation mode, the :math:`\\mu_{\\beta}` + and :math:`\\sigma_{\\beta}^{2}` are global statistics (moving_mean and moving_variance, + which usually got from the pre-trained model). Global statistics calculated as follows: + + .. math:: + moving\_mean = moving\_mean * momentum + \mu_{\beta} * (1. - momentum) \quad &// global mean \\ + moving\_variance = moving\_variance * momentum + \sigma_{\beta}^{2} * (1. - momentum) \quad &// global variance \\ + + The formula of normalization is as follows: + + .. math:: + + \\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\ + \\sigma_{\\beta}^{2} + \\eps}} \\qquad &//\ normalize \\\\ + y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\ scale\ and\ shift + + - :math:`\\eps` : add a smaller value to the variance to prevent division by zero + - :math:`\\gamma` : trainable scale parameter vector + - :math:`\\beta` : trainable shift parameter vector + + Parameters: + num_features(int): Indicate the number of channels of the input ``Tensor``. + epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5. + momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9. + weight_attr(ParamAttr|bool, optional): The parameter attribute for Parameter `scale` + of this layer. If it is set to None or one attribute of ParamAttr, this layerr + will create ParamAttr as param_attr. If the Initializer of the param_attr + is not set, the parameter is initialized with Xavier. If it is set to False, + this layer will not have trainable scale parameter. Default: None. + bias_attr(ParamAttr|bool, optional): The parameter attribute for the bias of this layer. + If it is set to None or one attribute of ParamAttr, this layer + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. If it is set to False, this layer will not + have trainable bias parameter. Default: None. + track_running_stats(bool, optional): Whether to compute global stats, which including running mean and + running variance. Default: True. + + Returns: + None + + Examples: + .. 
code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + + x = np.array([[[[0.3, 0.4], [0.3, 0.07]], [[0.83, 0.37], [0.18, 0.93]]]]).astype('float32') + paddle.disable_static() + x = paddle.to_tensor(x) + if paddle.fluid.is_compiled_with_cuda(): + sync_batch_norm = nn.SyncBatchNorm(2) + hidden1 = sync_batch_norm(x) + print(hidden1.numpy()) + # [[[[0.26824948, 1.0936325],[0.26824948, -1.6301316]],[[ 0.8095662, -0.665287],[-1.2744656, 1.1301866 ]]]] + """ + + def __init__(self, + num_features, + epsilon=1e-05, + momentum=0.9, + track_running_stats=True, + weight_attr=None, + bias_attr=None, + data_format='NCHW', + name=None): + super(SyncBatchNorm, self).__init__() + self._weight_attr = weight_attr + self._bias_attr = bias_attr + self._num_features = num_features + self._data_layout = data_format + self._momentum = momentum + self._epsilon = epsilon + self._track_running_stats = track_running_stats + + if self._track_running_stats == False: + logging.warn( + "moving mean and moving variance will be calculated whether `track_running_stats` is set to `True` or `False`, we will fix it in the next version." + ) + + param_shape = [self._num_features] + + # create parameter + if weight_attr == False: + self.weight = self.create_parameter( + attr=None, shape=param_shape, default_initializer=Constant(1.0)) + self.weight.stop_gradient = True + else: + self.weight = self.create_parameter( + attr=self._weight_attr, + shape=param_shape, + default_initializer=Constant(1.0)) + self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0. + + if bias_attr == False: + self.bias = self.create_parameter( + attr=None, + shape=param_shape, + default_initializer=Constant(0.0), + is_bias=True) + self.bias.stop_gradient = True + else: + self.bias = self.create_parameter( + attr=self._bias_attr, shape=param_shape, is_bias=True) + self.bias.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0. 
+ + self._mean = self.create_parameter( + attr=ParamAttr( + name=None, + initializer=Constant(0.0), + trainable=False, + do_model_average=True), + shape=param_shape, + dtype=self._dtype) + self._mean.stop_gradient = True + + self._variance = self.create_parameter( + attr=ParamAttr( + name=None, + initializer=Constant(1.0), + trainable=False, + do_model_average=True), + shape=param_shape, + dtype=self._dtype) + self._variance.stop_gradient = True + + def forward(self, x): + # create output + # mean and mean_out share the same memory + mean_out = self._mean + # variance and variance out share the same memory + variance_out = self._variance + + ### train mode: use mini-batch stats, eval mode: use global stats + if in_dygraph_mode(): + attrs = ("momentum", self._momentum, "epsilon", self._epsilon, + "is_test", not self.training, "data_layout", + self._data_layout, "use_mkldnn", False, "fuse_with_relu", + False, "use_global_stats", not self.training, + 'trainable_statistics', False) + sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm( + x, self.weight, self.bias, self._mean, self._variance, mean_out, + variance_out, *attrs) + + return sync_batch_norm_out + + check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'], + 'BatchNorm') + + attrs = { + "momentum": self._momentum, + "epsilon": self._epsilon, + "is_test": not self.training, + "data_layout": self._data_layout, + "use_mkldnn": False, + "fuse_with_relu": False, + "use_global_stats": not self.training, + "trainable_statistics": False, + } + + inputs = { + "X": [x], + "Scale": [self.weight], + "Bias": [self.bias], + "Mean": [self._mean], + "Variance": [self._variance] + } + + saved_mean = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + saved_variance = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + sync_batch_norm_out = self._helper.create_variable_for_type_inference( + self._dtype) + + outputs = { + "Y": [sync_batch_norm_out], + "MeanOut": [mean_out], + "VarianceOut": [variance_out], + "SavedMean": [saved_mean], + "SavedVariance": [saved_variance] + } + + self._helper.append_op( + type="sync_batch_norm", inputs=inputs, outputs=outputs, attrs=attrs) + return sync_batch_norm_out + + class Flatten(layers.Layer): """ :alias_main: paddle.nn.Flatten diff --git a/python/paddle/fluid/dygraph_utils.py b/python/paddle/fluid/dygraph_utils.py index 7b559494e6c3b..a2338b874f51a 100644 --- a/python/paddle/fluid/dygraph_utils.py +++ b/python/paddle/fluid/dygraph_utils.py @@ -45,17 +45,19 @@ def _append_activation_in_dygraph(input, @dygraph_only -def _append_bias_in_dygraph(input, bias=None, axis=1): +def _append_bias_in_dygraph(input, bias=None, axis=1, use_mkldnn=False): """Append bias operation in dygraph mode. Args: input: the input variable. 
bias: the bias to be appended axis: the axis to perform operation + use_mkldnn: whether to use mkldnn Return the Variable after bias operation """ if bias is None: return input - return core.ops.elementwise_add(input, bias, 'axis', axis) + return core.ops.elementwise_add(input, bias, 'axis', axis, 'use_mkldnn', + use_mkldnn) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index f16da029e29a6..5759b94276351 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -1156,6 +1156,26 @@ def _run_impl(self, program, feed, fetch_list, feed_var_name, compiled = isinstance(program, compiler.CompiledProgram) + # Check whether any fluid.data() variable is missing its feed data + if use_prune: + if compiled: + global_block = program._program.global_block() + else: + global_block = program.global_block() + for varname in global_block.vars: + vardesc = global_block.desc.find_var(cpt.to_bytes(varname)) + varobj = global_block.vars[varname] + + # Cannot check vars built by fluid.layers.data(), because fluid.layers.data() does not set need_check_feed + if vardesc.persistable() == False and \ + vardesc.type() == core.VarDesc.VarType.LOD_TENSOR and \ + vardesc.need_check_feed() == True and \ + varobj._stop_gradient == True and \ + varobj.is_data == True and \ + varobj.belong_to_optimizer == False and \ + varname not in feed: + raise ValueError('Need feed data for variable %s' % varname) + acp._auto_checkpoint(self, program) # For backward compatibility, run directly. diff --git a/python/paddle/fluid/generator.py b/python/paddle/fluid/generator.py new file mode 100644 index 0000000000000..24262e3f5666a --- /dev/null +++ b/python/paddle/fluid/generator.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""This is the definition of the Generator class, which manages the state of the algorithm that produces pseudo-random numbers.""" + +from .
import core + +__all__ = ['Generator'] + +default_rng_seed_val = 34342423252 + + +class Generator(object): + """Generator class""" + + def __init__(self, device="CPU"): + """init""" + self.device = device + seed_in = default_rng_seed_val + if self.device == "CPU": + self.generator = core.Generator() + self.generator.manual_seed(seed_in) + else: + raise ValueError( + "generator class with device %s does not exist, currently only support generator with device 'CPU' " + % device) + + def get_state(self): + return self.generator.get_state() + + def set_state(self, state): + self.generator.set_state(state) + + def manual_seed(self, seed): + self.generator.manual_seed(seed) + + def seed(self): + return self.generator.seed() + + def initial_seed(self): + return self.generator.initial_seed() + + def random(self): + return self.generator.random() + + def get_cpu_engine(self): + return self.generator.get_cpu_engine() + + def set_cpu_engine(self, cpu_engine): + self.generator.set_cpu_engine(cpu_engine) diff --git a/python/paddle/fluid/layer_helper_base.py b/python/paddle/fluid/layer_helper_base.py index 0b57b3fefd414..6e38c85556280 100644 --- a/python/paddle/fluid/layer_helper_base.py +++ b/python/paddle/fluid/layer_helper_base.py @@ -23,8 +23,13 @@ from . import core from .initializer import _global_weight_initializer, _global_bias_initializer +__all__ = ['LayerHelperBase'] + class LayerHelperBase(object): + # global dtype + __dtype = "float32" + def __init__(self, name, layer_type): self._layer_type = layer_type self._name = name @@ -45,6 +50,14 @@ def main_program(self): def startup_program(self): return default_startup_program() + @classmethod + def set_default_dtype(cls, dtype): + cls.__dtype = dtype + + @classmethod + def get_default_dtype(cls): + return cls.__dtype + def to_variable(self, value, name=None): """ The API will create a ``Variable`` object from numpy\.ndarray or Variable object. @@ -277,7 +290,7 @@ def __weight_normalize(g, v, dim): def create_parameter(self, attr, shape, - dtype, + dtype=None, is_bias=False, default_initializer=None, stop_gradient=False, @@ -299,6 +312,9 @@ def create_parameter(self, if not attr: return None assert isinstance(attr, ParamAttr) + # set global dtype + if not dtype: + dtype = self.__dtype if is_bias: suffix = 'b' default_initializer = _global_bias_initializer( @@ -372,6 +388,9 @@ def create_variable_for_type_inference(self, dtype, stop_gradient=False): based on operator's `VarTypeInference` implementation in infer_var_type. 
""" + # set global dtype + if not dtype: + dtype = self.__dtype return self.main_program.current_block().create_var( name=unique_name.generate_with_ignorable_key(".".join( [self.name, 'tmp'])), diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index fb1913285e2ea..7f726c6a599d5 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11399,7 +11399,12 @@ def gen_data(): """ if in_dygraph_mode(): return _elementwise_op_in_dygraph( - x, y, axis=axis, act=act, op_name='elementwise_add') + x, + y, + axis=axis, + act=act, + op_name='elementwise_add', + use_mkldnn=core.globals()["FLAGS_use_mkldnn"]) return _elementwise_op(LayerHelper('elementwise_add', **locals())) diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py index ac518ac83b7a0..1a61072ace6ff 100644 --- a/python/paddle/fluid/layers/ops.py +++ b/python/paddle/fluid/layers/ops.py @@ -645,6 +645,7 @@ def thresholded_relu(x, threshold=None): _gelu_ = generate_layer_fn('gelu') +@deprecated(since="2.0.0", update_to="paddle.nn.functional.gelu") def gelu(x, approximate=False): locals_var = locals().copy() kwargs = dict() @@ -655,10 +656,6 @@ def gelu(x, approximate=False): gelu.__doc__ = """ - :alias_main: paddle.nn.functional.gelu - :alias: paddle.nn.functional.gelu,paddle.nn.functional.activation.gelu - :old_api: paddle.fluid.layers.gelu - :strong:`GeLU Activation Operator` For more details, see [Gaussian Error Linear Units](https://arxiv.org/abs/1606.08415). diff --git a/python/paddle/fluid/layers/rnn.py b/python/paddle/fluid/layers/rnn.py index 39d25a9c7ffd4..bc1368b562d7b 100644 --- a/python/paddle/fluid/layers/rnn.py +++ b/python/paddle/fluid/layers/rnn.py @@ -2213,9 +2213,9 @@ def lstm(input, input ( :ref:`api_guide_Variable_en` ): LSTM input tensor, 3-D Tensor of shape :math:`[batch\_size, seq\_len, input\_dim]` . Data type is float32 or float64 init_h( :ref:`api_guide_Variable_en` ): The initial hidden state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` . If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64. + max_len (int): This parameter has no effect and will be discarded. init_c( :ref:`api_guide_Variable_en` ): The initial cell state of the LSTM, 3-D Tensor of shape :math:`[num\_layers, batch\_size, hidden\_size]` . If is_bidirec = True, shape should be :math:`[num\_layers*2, batch\_size, hidden\_size]` . Data type is float32 or float64. - max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len. hidden_size (int): hidden size of the LSTM. num_layers (int): total layers number of the LSTM. 
dropout_prob(float, optional): dropout prob, dropout ONLY work between rnn layers, NOT between time steps @@ -2256,7 +2256,6 @@ def lstm(input, data = fluid.data(name='x', shape=[None, 100], dtype='int64') emb = fluid.embedding(input=data, size=[vocab_size, emb_dim], is_sparse=True) batch_size = 20 - max_len = 100 dropout_prob = 0.2 input_size = 100 hidden_size = 150 @@ -2309,9 +2308,11 @@ def lstm(input, out = helper.create_variable_for_type_inference(dtype) last_h = helper.create_variable_for_type_inference(dtype) last_c = helper.create_variable_for_type_inference(dtype) - - cache = helper.create_variable( - persistable=True, type=core.VarDesc.VarType.RAW, stop_gradient=True) + reserve = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + state_out = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + state_out.persistable = True helper.append_op( type='cudnn_lstm', @@ -2320,15 +2321,15 @@ def lstm(input, 'InitH': init_h, 'InitC': init_c, 'W': weight, - 'Cache': cache, }, outputs={ 'Out': out, - 'last_h': last_h, - 'last_c': last_c, + 'LastH': last_h, + 'LastC': last_c, + 'Reserve': reserve, + 'StateOut': state_out, }, attrs={ - 'max_len': max_len, 'is_bidirec': is_bidirec, 'input_size': input_size, 'hidden_size': hidden_size, @@ -3102,7 +3103,8 @@ def beam_search_decode(ids, scores, beam_size, end_id, name=None): 'beam_search_encode') helper = LayerHelper('beam_search_decode', **locals()) sentence_ids = helper.create_variable_for_type_inference(dtype=ids.dtype) - sentence_scores = helper.create_variable_for_type_inference(dtype=ids.dtype) + sentence_scores = helper.create_variable_for_type_inference( + dtype=scores.dtype) helper.append_op( type="beam_search_decode", diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 05f5d174f8774..126b4465eae48 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -106,6 +106,7 @@ if (NOT ${WITH_GPU}) list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_se_resnext) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding) LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_transformer) + LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sync_batch_norm) LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision) elseif(${CUDNN_VERSION} VERSION_LESS 7100) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py index 55c6bad9af689..d904bdbfa96ae 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py @@ -112,13 +112,10 @@ class TestMKLDNNSwishDim2(TestSwish): def setUp(self): super(TestMKLDNNSwishDim2, self).setUp() - x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) - beta = 2.3 - out = x * expit(beta * x) + self.attrs["use_mkldnn"] = True - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} - self.outputs = {'Out': out} - self.attrs = {"use_mkldnn": True, "beta": beta} + def init_dtype(self): + self.dtype = np.float32 def init_dtype(self): self.dtype = np.float32 diff --git a/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py new file mode 100644 index 
0000000000000..5e2059592b517 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/parallel_dygraph_sync_batch_norm.py @@ -0,0 +1,108 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import contextlib +import unittest +import numpy as np +import six +import pickle + +import paddle +import paddle.fluid as fluid +import paddle.fluid.dygraph as dygraph +from paddle.fluid import core +from paddle.fluid.optimizer import SGDOptimizer +from paddle.nn import Conv2D, Pool2D, Linear, SyncBatchNorm +from paddle.fluid.dygraph.base import to_variable + +from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase + + +class TestLayer(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + groups=1, + act=None): + super(TestLayer, self).__init__() + + self._conv = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + + self._sync_batch_norm = SyncBatchNorm(num_filters) + + self._conv2 = Conv2D( + num_channels=num_filters, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + bias_attr=False) + + self._sync_batch_norm2 = SyncBatchNorm( + num_filters, + weight_attr=False, + bias_attr=False, + track_running_stats=False) + + def forward(self, inputs): + y = self._conv(inputs) + y = self._sync_batch_norm(y) + y = self._conv2(y) + y = self._sync_batch_norm2(y) + + return y + + +class TestSyncBatchNorm(TestParallelDyGraphRunnerBase): + def get_model(self): + model = TestLayer(3, 64, 7) + train_reader = paddle.batch( + paddle.dataset.flowers.test(use_xmap=False), + batch_size=32, + drop_last=True) + opt = fluid.optimizer.Adam( + learning_rate=1e-3, parameter_list=model.parameters()) + return model, train_reader, opt + + def run_one_loop(self, model, opt, data): + batch_size = len(data) + dy_x_data = np.array([x[0].reshape(3, 224, 224) + for x in data]).astype('float32') + img = to_variable(dy_x_data) + img.stop_gradient = False + + out = model(img) + + out = fluid.layers.mean(out) + + return out + + +if __name__ == "__main__": + runtime_main(TestSyncBatchNorm) diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index a5e8ae82e6fb4..8534de268a936 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -118,7 +118,7 @@ def setUp(self): x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) out = np.log(1 / (1 + np.exp(-x))) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.inputs = {'X': x} self.outputs = {'Out': out} def test_check_grad(self): @@ -127,6 +127,48 @@ def test_check_grad(self): self.check_grad(['X'], 'Out', max_relative_error=0.008) 
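The activation API tests added below (TestLogSigmoidAPI, TestReluAPI, TestGELUAPI, TestELUAPI) all follow the same pattern: compute a plain NumPy reference and compare it against the paddle.nn layer and paddle.nn.functional outputs in both static and dygraph modes. As a quick sanity check of the log-sigmoid reference used first, a standalone sketch (not part of the PR):

import numpy as np

def ref_logsigmoid(x):
    # log(sigmoid(x)) = log(1 / (1 + exp(-x))), the reference TestLogSigmoidAPI checks against
    return np.log(1.0 / (1.0 + np.exp(-x)))

x = np.array([-2.0, 0.0, 3.0], dtype='float32')
print(ref_logsigmoid(x))  # approximately [-2.1269, -0.6931, -0.0486]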
+class TestLogSigmoidAPI(unittest.TestCase): + # test paddle.nn.LogSigmoid, paddle.nn.functional.logsigmoid + def setUp(self): + self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32') + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [11, 17]) + out1 = F.logsigmoid(x) + m = paddle.nn.LogSigmoid() + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = np.log(1 / (1 + np.exp(-self.x_np))) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.logsigmoid(x) + m = paddle.nn.LogSigmoid() + out2 = m(x) + out_ref = np.log(1 / (1 + np.exp(-self.x_np))) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.logsigmoid, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.data(name='x_int32', shape=[11, 17], dtype='int32') + self.assertRaises(TypeError, F.logsigmoid, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[11, 17], dtype='float16') + F.logsigmoid(x_fp16) + + class TestTanh(TestActivation, TestParameter): def setUp(self): self.op_type = "tanh" @@ -644,7 +686,7 @@ def setUp(self): x[np.abs(x) < 0.005] = 0.02 out = np.maximum(x, 0) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.inputs = {'X': x} self.outputs = {'Out': out} def test_check_grad(self): @@ -653,18 +695,46 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') -class TestReluOpError(unittest.TestCase): +class TestReluAPI(unittest.TestCase): + # test paddle.nn.ReLU, paddle.nn.functional.relu + def setUp(self): + self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32') + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [10, 12]) + out1 = F.relu(x) + m = paddle.nn.ReLU() + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = np.maximum(self.x_np, 0) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.relu(x) + m = paddle.nn.ReLU() + out2 = m(x) + out_ref = np.maximum(self.x_np, 0) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + def test_errors(self): - with program_guard(Program()): + with paddle.static.program_guard(paddle.static.Program()): # The input type must be Variable. - self.assertRaises(TypeError, fluid.layers.relu, 1) + self.assertRaises(TypeError, F.relu, 1) # The input dtype must be float16, float32, float64. 
- x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32') - self.assertRaises(TypeError, fluid.layers.relu, x_int32) + x_int32 = paddle.data(name='x_int32', shape=[10, 12], dtype='int32') + self.assertRaises(TypeError, F.relu, x_int32) # support the input dtype is float16 - x_fp16 = fluid.layers.data( - name='x_fp16', shape=[12, 10], dtype='float16') - fluid.layers.relu(x_fp16) + x_fp16 = paddle.data(name='x_fp16', shape=[10, 12], dtype='float16') + F.relu(x_fp16) def ref_leaky_relu(x, alpha=0.01): @@ -789,7 +859,7 @@ def setUp(self): x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) out = gelu(x, approximate) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.inputs = {'X': x} self.outputs = {'Out': out} self.attrs = {"approximate": approximate} @@ -807,7 +877,7 @@ def setUp(self): x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) out = gelu(x, approximate) - self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.inputs = {'X': x} self.outputs = {'Out': out} self.attrs = {"approximate": approximate} @@ -817,6 +887,55 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') +class TestGELUAPI(unittest.TestCase): + # test paddle.nn.GELU, paddle.nn.functional.gelu + def setUp(self): + self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32') + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [11, 17]) + out1 = F.gelu(x) + m = paddle.nn.GELU() + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = gelu(self.x_np, False) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.gelu(x) + m = paddle.nn.GELU() + out2 = m(x) + out_ref = gelu(self.x_np, False) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + + out1 = F.gelu(x, True) + m = paddle.nn.GELU(True) + out2 = m(x) + out_ref = gelu(self.x_np, True) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.gelu, 1) + # The input dtype must be float16, float32, float64. + x_int32 = paddle.data(name='x_int32', shape=[11, 17], dtype='int32') + self.assertRaises(TypeError, F.gelu, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[11, 17], dtype='float16') + F.gelu(x_fp16) + + class TestBRelu(TestActivation): def setUp(self): self.op_type = "brelu" @@ -966,6 +1085,11 @@ def test_errors(self): fluid.layers.soft_relu(x_fp16) +def elu(x, alpha): + out_ref = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) + return out_ref.astype(x.dtype) + + class TestELU(TestActivation): def setUp(self): self.op_type = "elu" @@ -973,7 +1097,7 @@ def setUp(self): x = np.random.uniform(-3, 3, [10, 12]).astype(self.dtype) alpha = 1. - out = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) + out = elu(x, alpha) # Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. 
alpha = 1) # is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here self.inputs = {'X': x} @@ -986,16 +1110,53 @@ def test_check_grad(self): self.check_grad(['X'], 'Out') -class TestELUOpError(unittest.TestCase): +class TestELUAPI(unittest.TestCase): + # test paddle.nn.ELU, paddle.nn.functional.elu + def setUp(self): + self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32') + self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', [10, 12]) + out1 = F.elu(x) + m = paddle.nn.ELU() + out2 = m(x) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2]) + out_ref = elu(self.x_np, 1.0) + for r in res: + self.assertEqual(np.allclose(out_ref, r), True) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out1 = F.elu(x) + m = paddle.nn.ELU() + out2 = m(x) + out_ref = elu(self.x_np, 1.0) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + + out1 = F.elu(x, 0.2) + m = paddle.nn.ELU(0.2) + out2 = m(x) + out_ref = elu(self.x_np, 0.2) + for r in [out1, out2]: + self.assertEqual(np.allclose(out_ref, r.numpy()), True) + paddle.enable_static() + def test_errors(self): - with program_guard(Program(), Program()): - # The input type of elu_op must be Variable. - x1 = fluid.create_lod_tensor( - np.array([[-1]]), [[1]], fluid.CPUPlace()) - self.assertRaises(TypeError, fluid.layers.elu, x1) - # The input dtype of elu_op must be float16 float32 or float64. - x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32") - self.assertRaises(TypeError, fluid.layers.elu, x2) + with paddle.static.program_guard(paddle.static.Program()): + # The input type must be Variable. + self.assertRaises(TypeError, F.elu, 1) + # The input dtype must be float16, float32, float64. 
+ x_int32 = paddle.data(name='x_int32', shape=[10, 12], dtype='int32') + self.assertRaises(TypeError, F.elu, x_int32) + # support the input dtype is float16 + x_fp16 = paddle.data(name='x_fp16', shape=[10, 12], dtype='float16') + F.elu(x_fp16) class TestReciprocal(TestActivation): @@ -1494,73 +1655,5 @@ def test_check_grad(self): create_test_act_fp16_class(TestSwish) create_test_act_fp16_class(TestHardSwish) - -class TestNNReluAPI(unittest.TestCase): - def setUp(self): - self.init_data() - - def init_data(self): - self.x_shape = [10, 12] - self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) - self.y = self.ref_forward(self.x) - - def ref_forward(self, x): - return np.maximum(x, 0) - - def ref_backward(self, y, dy): - y_t = y.copy() - y_t[y_t > 0] = 1 - return y_t * dy - - def check_api(self, place=fluid.CPUPlace()): - main_program = Program() - myrelu = nn.ReLU() - with fluid.program_guard(main_program): - x = fluid.data(name='x', shape=self.x_shape) - x.stop_gradient = False - y = myrelu(x) - fluid.backward.append_backward(fluid.layers.mean(y)) - exe = fluid.Executor(place) - out = exe.run(main_program, - feed={'x': self.x}, - fetch_list=[y, y.grad_name, x.grad_name]) - self.assertTrue(np.allclose(out[0], self.y)) - self.assertTrue(np.allclose(out[2], self.ref_backward(self.y, out[1]))) - - with fluid.dygraph.guard(place): - x = fluid.dygraph.to_variable(self.x) - y = myrelu(x) - self.assertTrue(np.allclose(y.numpy(), self.y)) - - def test_check_api(self): - places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - self.check_api(place) - - -class TestNNFunctionalReluAPI(unittest.TestCase): - def setUp(self): - self.init_data() - - def init_data(self): - self.x_shape = [10, 12] - self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) - self.y = self.ref_forward(self.x) - - def ref_forward(self, x): - return np.maximum(x, 0) - - def test_check_api(self): - main_program = Program() - with fluid.program_guard(main_program): - x = fluid.data(name='x', shape=self.x_shape) - y = F.relu(x) - exe = fluid.Executor(fluid.CPUPlace()) - out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) - self.assertTrue(np.allclose(out[0], self.y)) - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_batch_sampler.py b/python/paddle/fluid/tests/unittests/test_batch_sampler.py index 7d90bbd0357bc..2e2a6144fd011 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_sampler.py +++ b/python/paddle/fluid/tests/unittests/test_batch_sampler.py @@ -17,7 +17,7 @@ import unittest import paddle.fluid as fluid -from paddle.io import BatchSampler, Dataset +from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, RandomSampler class RandomDataset(Dataset): @@ -35,6 +35,60 @@ def __len__(self): return self.sample_num +class TestSampler(unittest.TestCase): + def test_main(self): + dataset = RandomDataset(100, 10) + sampler = Sampler(dataset) + try: + iter(sampler) + self.assertTrue(False) + except NotImplementedError: + pass + + +class TestSequenceSampler(unittest.TestCase): + def test_main(self): + dataset = RandomDataset(100, 10) + sampler = SequenceSampler(dataset) + assert len(sampler) == 100 + + for i, index in enumerate(iter(sampler)): + assert i == index + + +class TestRandomSampler(unittest.TestCase): + def test_main(self): + dataset = RandomDataset(100, 10) + sampler = RandomSampler(dataset) + assert len(sampler) == 100 + + rets = [] + for i in 
iter(sampler): + rets.append(i) + assert tuple(sorted(rets)) == tuple(range(0, 100)) + + def test_with_num_samples(self): + dataset = RandomDataset(100, 10) + sampler = RandomSampler(dataset, num_samples=50, replacement=True) + assert len(sampler) == 50 + + rets = [] + for i in iter(sampler): + rets.append(i) + assert i >= 0 and i < 100 + + def test_with_generator(self): + dataset = RandomDataset(100, 10) + generator = iter(range(0, 60)) + sampler = RandomSampler(dataset, generator=generator) + assert len(sampler) == 100 + + rets = [] + for i in iter(sampler): + rets.append(i) + assert tuple(sorted(rets)) == tuple(range(0, 60)) + + class TestBatchSampler(unittest.TestCase): def setUp(self): self.num_samples = 1000 @@ -86,16 +140,18 @@ def setUp(self): self.drop_last = True -class TestBatchSamplerWithIndices(TestBatchSampler): +class TestBatchSamplerWithSampler(TestBatchSampler): def init_batch_sampler(self): + dataset = RandomDataset(1000, 10) + sampler = SequenceSampler(dataset) bs = BatchSampler( - indices=list(range(self.num_samples)), + sampler=sampler, batch_size=self.batch_size, drop_last=self.drop_last) return bs -class TestBatchSamplerWithIndicesAndDataSource(unittest.TestCase): +class TestBatchSamplerWithSamplerDropLast(unittest.TestCase): def setUp(self): self.num_samples = 1000 self.num_classes = 10 @@ -103,12 +159,22 @@ def setUp(self): self.shuffle = False self.drop_last = True + +class TestBatchSamplerWithSamplerShuffle(unittest.TestCase): + def setUp(self): + self.num_samples = 1000 + self.num_classes = 10 + self.batch_size = 32 + self.shuffle = True + self.drop_last = True + def test_main(self): try: dataset = RandomDataset(self.num_samples, self.num_classes) + sampler = RandomSampler(dataset) bs = BatchSampler( - dataset=dataset, - indices=list(range(self.num_samples)), + sampler=sampler, + shuffle=self.shuffle, batch_size=self.batch_size, drop_last=self.drop_last) self.assertTrue(False) diff --git a/python/paddle/fluid/tests/unittests/test_chunk_op.py b/python/paddle/fluid/tests/unittests/test_chunk_op.py new file mode 100644 index 0000000000000..043b326fbd987 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_chunk_op.py @@ -0,0 +1,138 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import numpy as np +from paddle.fluid import Program, program_guard +from paddle import fluid +import paddle + + +class TestChunkOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The type of axis in chunk_op should be int or Variable. + def test_axis_type(): + x1 = paddle.data(shape=[4], dtype='float16', name='x3') + paddle.chunk(x=x1, chunks=2, axis=3.2) + + self.assertRaises(TypeError, test_axis_type) + + # The type of axis in chunk op should be int or Variable. 
+ def test_axis_variable_type(): + x2 = paddle.data(shape=[4], dtype='float16', name='x9') + x3 = paddle.data(shape=[1], dtype='float16', name='x10') + paddle.chunk(input=x2, chunks=2, axis=x3) + + self.assertRaises(TypeError, test_axis_variable_type) + + # The type of num_or_sections in chunk_op should be int, tuple or list. + def test_chunks_type(): + x4 = paddle.data(shape=[4], dtype='float16', name='x4') + paddle.chunk(input=x4, chunks=2.1, axis=3) + + self.assertRaises(TypeError, test_chunks_type) + + def test_axis_type_tensor(): + x5 = paddle.data(shape=[4], dtype='float16', name='x6') + paddle.chunk(input=x5, chunks=2, axis=3.2) + + self.assertRaises(TypeError, test_axis_type_tensor) + + +class API_TestChunk(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = paddle.data('data1', shape=[4, 6, 6], dtype='float64') + data2 = paddle.data('data2', shape=[1], dtype='int32') + x0, x1, x2 = paddle.chunk(data1, chunks=3, axis=data2) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + input1 = np.random.random([4, 6, 6]).astype('float64') + input2 = np.array([2]).astype('int32') + r0, r1, r2, = exe.run(feed={"data1": input1, + "data2": input2}, + fetch_list=[x0, x1, x2]) + ex_x0, ex_x1, ex_x2 = np.array_split(input1, 3, axis=2) + self.assertTrue(np.allclose(ex_x0, r0)) + self.assertTrue(np.allclose(ex_x1, r1)) + self.assertTrue(np.allclose(ex_x2, r2)) + + +class API_TestChunk1(unittest.TestCase): + def test_out(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data1 = paddle.data('data1', shape=[4, 6, 6], dtype='float64') + x0, x1, x2 = paddle.chunk(data1, chunks=3, axis=2) + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + input1 = np.random.random([4, 6, 6]).astype('float64') + r0, r1, r2, = exe.run(feed={"data1": input1}, + fetch_list=[x0, x1, x2]) + ex_x0, ex_x1, ex_x2 = np.array_split(input1, 3, axis=2) + self.assertTrue(np.allclose(ex_x0, r0)) + self.assertTrue(np.allclose(ex_x1, r1)) + self.assertTrue(np.allclose(ex_x2, r2)) + + +class API_TestDygraphChunk(unittest.TestCase): + def test_out1(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("int32") + # input is a variable which shape is [4, 6, 6] + input = fluid.dygraph.to_variable(input_1) + x0, x1, x2 = paddle.chunk(input, chunks=3, axis=1) + x0_out = x0.numpy() + x1_out = x1.numpy() + x2_out = x2.numpy() + ex_x0, ex_x1, ex_x2 = np.array_split(input_1, 3, axis=1) + self.assertTrue(np.allclose(ex_x0, x0_out)) + self.assertTrue(np.allclose(ex_x1, x1_out)) + self.assertTrue(np.allclose(ex_x2, x2_out)) + + def test_out2(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("bool") + # input is a variable which shape is [4, 6, 6] + input = fluid.dygraph.to_variable(input_1) + x0, x1, x2 = paddle.chunk(input, chunks=3, axis=1) + x0_out = x0.numpy() + x1_out = x1.numpy() + x2_out = x2.numpy() + ex_x0, ex_x1, ex_x2 = np.array_split(input_1, 3, axis=1) + self.assertTrue(np.allclose(ex_x0, x0_out)) + self.assertTrue(np.allclose(ex_x1, x1_out)) + self.assertTrue(np.allclose(ex_x2, x2_out)) + + def test_axis_tensor_input(self): + with fluid.dygraph.guard(): + input_1 = np.random.random([4, 6, 6]).astype("int32") + # input is a variable which shape is [4, 6, 6] + input = fluid.dygraph.to_variable(input_1) + num1 = paddle.full(shape=[1], fill_value=1, dtype='int32') + x0, x1, x2 = paddle.chunk(input, chunks=3, axis=num1) + x0_out = x0.numpy() + x1_out = x1.numpy() + 
x2_out = x2.numpy() + ex_x0, ex_x1, ex_x2 = np.array_split(input_1, 3, axis=1) + self.assertTrue(np.allclose(ex_x0, x0_out)) + self.assertTrue(np.allclose(ex_x1, x1_out)) + self.assertTrue(np.allclose(ex_x2, x2_out)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_default_dtype.py b/python/paddle/fluid/tests/unittests/test_default_dtype.py new file mode 100644 index 0000000000000..eba4ec3420f2d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_default_dtype.py @@ -0,0 +1,41 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +import numpy as np +from paddle.framework import set_default_dtype, get_default_dtype +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph import Linear +import paddle.fluid.core as core +from paddle import to_variable + + +class TestDefaultType(unittest.TestCase): + def check_default(self): + self.assertEqual("float32", get_default_dtype()) + + def test_api(self): + self.check_default() + + set_default_dtype("float64") + self.assertEqual("float64", get_default_dtype()) + + set_default_dtype(np.int32) + self.assertEqual("int32", get_default_dtype()) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_executor_check_feed.py b/python/paddle/fluid/tests/unittests/test_executor_check_feed.py new file mode 100644 index 0000000000000..6b1e3c5a28a54 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_executor_check_feed.py @@ -0,0 +1,84 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
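The TestDefaultType case above drives the new default-dtype plumbing through paddle.framework.set_default_dtype/get_default_dtype, which is expected to reach the classmethods added to LayerHelperBase earlier in this diff; create_parameter and create_variable_for_type_inference fall back to that stored value whenever dtype is None. A minimal sketch of the lower-level hook, using only the names shown in the layer_helper_base.py hunk:

from paddle.fluid.layer_helper_base import LayerHelperBase

# The classmethods added in this PR store a process-wide default dtype.
LayerHelperBase.set_default_dtype("float64")
assert LayerHelperBase.get_default_dtype() == "float64"

# From here on, create_parameter(...) and create_variable_for_type_inference(...)
# calls that omit dtype (or pass None) pick up "float64" instead of "float32".
LayerHelperBase.set_default_dtype("float32")  # restore the usual default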
+ +from __future__ import print_function + +import unittest + +import numpy +import paddle.fluid.core as core +import paddle.fluid as fluid + + +class TestExecutor(unittest.TestCase): + def net(self): + lr = fluid.data(name="lr", shape=[1], dtype='float32') + x = fluid.data(name="x", shape=[None, 1], dtype='float32') + y = fluid.data(name="y", shape=[None, 1], dtype='float32') + y_predict = fluid.layers.fc(input=x, size=1, act=None) + + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + + opt = fluid.optimizer.Adam(learning_rate=lr) + opt.minimize(avg_cost) + + return lr, avg_cost + + def test_program_check_feed(self): + main_program = fluid.Program() + startup_program = fluid.Program() + scope = fluid.Scope() + with fluid.program_guard(main_program, startup_program): + with fluid.scope_guard(scope): + cpu = fluid.CPUPlace() + exe = fluid.Executor(cpu) + lr, cost = self.net() + exe.run(startup_program) + train_data = [[1.0], [2.0], [3.0], [4.0]] + y_true = [[2.0], [4.0], [6.0], [8.0]] + a = 0 + with self.assertRaises(ValueError): + exe.run(feed={'x': train_data, + 'lr': a}, + fetch_list=[lr, cost], + return_numpy=False, + use_prune=True) + + def test_compiled_program_check_feed(self): + main_program = fluid.Program() + startup_program = fluid.Program() + scope = fluid.Scope() + with fluid.program_guard(main_program, startup_program): + with fluid.scope_guard(scope): + cpu = fluid.CPUPlace() + exe = fluid.Executor(cpu) + lr, cost = self.net() + exe.run(startup_program) + compiled_prog = fluid.CompiledProgram( + main_program).with_data_parallel(loss_name=cost.name) + train_data = [[1.0], [2.0], [3.0], [4.0]] + y_true = [[2.0], [4.0], [6.0], [8.0]] + a = 0 + with self.assertRaises(ValueError): + exe.run(compiled_prog, + feed={'x': train_data, + 'lr': a}, + fetch_list=[lr, cost], + return_numpy=False, + use_prune=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py new file mode 100755 index 0000000000000..0ccb725870cde --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py @@ -0,0 +1,121 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid + + +class TestExpandAsOpRank1(OpTest): + def setUp(self): + self.op_type = "expand_as_v2" + x = np.random.rand(100).astype("float64") + target_tensor = np.random.rand(2, 100).astype("float64") + self.inputs = {'X': x, 'target_tensor': target_tensor} + self.attrs = {} + bcast_dims = [2, 1] + output = np.tile(self.inputs['X'], bcast_dims) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestExpandAsOpRank2(OpTest): + def setUp(self): + self.op_type = "expand_as_v2" + x = np.random.rand(10, 12).astype("float64") + target_tensor = np.random.rand(10, 12).astype("float64") + self.inputs = {'X': x, 'target_tensor': target_tensor} + self.attrs = {} + bcast_dims = [1, 1] + output = np.tile(self.inputs['X'], bcast_dims) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestExpandAsOpRank3(OpTest): + def setUp(self): + self.op_type = "expand_as_v2" + x = np.random.rand(2, 3, 20).astype("float64") + target_tensor = np.random.rand(2, 3, 20).astype("float64") + self.inputs = {'X': x, 'target_tensor': target_tensor} + self.attrs = {} + bcast_dims = [1, 1, 1] + output = np.tile(self.inputs['X'], bcast_dims) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestExpandAsOpRank4(OpTest): + def setUp(self): + self.op_type = "expand_as_v2" + x = np.random.rand(1, 1, 7, 16).astype("float64") + target_tensor = np.random.rand(4, 6, 7, 16).astype("float64") + self.inputs = {'X': x, 'target_tensor': target_tensor} + self.attrs = {} + bcast_dims = [4, 6, 1, 1] + output = np.tile(self.inputs['X'], bcast_dims) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +# Test python API +class TestExpandAPI(unittest.TestCase): + def test_api(self): + input1 = np.random.random([12, 14]).astype("float32") + input2 = np.random.random([2, 12, 14]).astype("float32") + x = fluid.layers.data( + name='x', shape=[12, 14], append_batch_size=False, dtype="float32") + + y = fluid.layers.data( + name='target_tensor', + shape=[2, 12, 14], + append_batch_size=False, + dtype="float32") + + out_1 = paddle.expand_as(x, y=y) + + exe = fluid.Executor(place=fluid.CPUPlace()) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": input1, + "target_tensor": input2}, + fetch_list=[out_1]) + assert np.array_equal(res_1[0], np.tile(input1, (2, 1, 1))) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py index 94669bc28f955..aee6ca249f535 100644 --- a/python/paddle/fluid/tests/unittests/test_expand_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_v2_op.py @@ -193,7 +193,7 @@ def test_errors(self): x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8") self.assertRaises(TypeError, paddle.tensor.expand, x2, shape) x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool") - x3.stop_gradient = True + x3.stop_gradient = False self.assertRaises(ValueError, paddle.tensor.expand, x3, shape) diff --git 
a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py index 0812b02b47db7..b30e0a6775ea9 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py @@ -31,45 +31,45 @@ def dequantize_max_abs(x, scale, max_range): return y -def channel_wise_quantize_max_abs(x, quant_bit=8, use_second_dim=False): +def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0): + assert quant_axis in [0, 1], "The quant_axis should be 0 or 1." scales = [] - if not use_second_dim: + y = x.copy() + max_range = math.pow(2, quant_bit - 1) - 1 + if quant_axis == 0: for i in range(x.shape[0]): - scales.append(np.max(np.abs(x[i])).astype("float32")) - y = x.copy() - max_range = math.pow(2, quant_bit - 1) - 1 - for i, scale in enumerate(scales): - y[i] = np.round(x[i] / scale * max_range) - else: - for i in range(x.shape[0]): - s = [] - for j in range(x.shape[1]): - s.append(np.max(np.abs(x[i][j])).astype("float32")) - scales.append(s) - scales = np.amax(np.array(scales), axis=0) - y = x.copy() - max_range = math.pow(2, quant_bit - 1) - 1 - for i in range(x.shape[0]): - for j, scale in enumerate(scales): - y[i][j] = np.round(x[i][j] / scale * max_range) + scale = np.max(np.abs(x[i])).astype("float32") + scales.append(scale) + y[i] = np.round(x[i] * max_range / scale) + elif quant_axis == 1: + for i in range(x.shape[1]): + scale = np.max(np.abs(x[:, i])).astype("float32") + scales.append(scale) + y[:, i] = np.round(x[:, i] * max_range / scale) return y, scales def channel_wise_dequantize_max_abs(x, scales, quant_bits, + quant_axis, activation_scale=None): - if activation_scale is None: - y = x.copy() - for i in range(x.shape[0]): - y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * x[i] + assert quant_axis in [0, 1], "The quant_axis should be 0 or 1." 
+ + if isinstance(quant_bits, list): + max_range = math.pow(2, quant_bits[0] - 1) - 1 else: - y = x.copy() + max_range = math.pow(2, quant_bits - 1) - 1 + y = x.copy() + if quant_axis == 0: for i in range(x.shape[0]): - for j in range(x.shape[1]): - y[i][j] = (scales[j] / - (math.pow(2, quant_bits[0] - 1) - 1)) * x[i][j] - y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1) + y[i] = x[i] * scales[i] / max_range + elif quant_axis == 1: + for i in range(x.shape[1]): + y[:, i] = x[:, i] * scales[i] / max_range + + if activation_scale is not None: + y = y * activation_scale / (math.pow(2, quant_bits[1] - 1) - 1) return y @@ -83,9 +83,8 @@ def setUp(self): self.set_args() self.op_type = "fake_channel_wise_dequantize_max_abs" x = np.random.randn(4, 3, 64, 64).astype(self.data_type) - yq, scales = channel_wise_quantize_max_abs( - x, self.quant_bits[0], use_second_dim=True) - ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, + yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], 1) + ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, 1, self.activation_scale) self.inputs = { @@ -105,25 +104,39 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest): def set_args(self): self.quant_bits = [8] self.data_type = "float32" + self.quant_axis = 0 def setUp(self): self.set_args() self.op_type = "fake_channel_wise_dequantize_max_abs" x = np.random.randn(4, 3, 64, 64).astype(self.data_type) - yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0]) - ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits) + yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], + self.quant_axis) + ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, + self.quant_axis) self.inputs = { 'X': yq, 'Scales': [("scales0", np.array(scales).astype(self.data_type))] } - self.attrs = {'quant_bits': self.quant_bits} + self.attrs = { + 'quant_bits': self.quant_bits, + 'quant_axis': self.quant_axis + } self.outputs = {'Out': ydq} def test_check_output(self): self.check_output() +class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1( + TestFakeChannelWiseDequantizeMaxAbsOpOneScale): + def set_args(self): + self.quant_bits = [8] + self.data_type = "float32" + self.quant_axis = 1 + + class TestFakeDequantizeMaxAbsOp(OpTest): def set_args(self): self.num_bits = 8 diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index 1c8335e3bceab..7835fd3f53ddb 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -72,28 +72,62 @@ def test_check_output(self): class TestFakeChannelWiseQuantizeOp(OpTest): def setUp(self): + self.set_arg() + assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1." 
+ self.op_type = "fake_channel_wise_quantize_abs_max" - self.attrs = {'bit_length': 8} - self.inputs = { - 'X': np.random.random((4, 3, 64, 64)).astype("float32"), - } + self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis} + scales = [] - for i in range(self.inputs['X'].shape[0]): - scales.append(np.max(np.abs(self.inputs['X'][i])).astype("float32")) outputs = self.inputs['X'].copy() - for i, scale in enumerate(scales): - outputs[i] = np.round(outputs[i] / scale * ( - (1 << (self.attrs['bit_length'] - 1)) - 1)) + bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 + if self.quant_axis == 0: + for i in range(self.inputs['X'].shape[0]): + scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32") + scales.append(scale_v) + outputs[i] = np.round(outputs[i] / scale_v * bnt) + elif self.quant_axis == 1: + for i in range(self.inputs['X'].shape[1]): + scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype( + "float32") + scales.append(scale_v) + outputs[:, i] = np.round(outputs[:, i] / scale_v * bnt) self.outputs = { 'Out': outputs, 'OutScale': np.array(scales).astype("float32"), } + def set_arg(self): + self.quant_axis = 0 + self.inputs = { + 'X': np.random.random((20, 15, 6, 6)).astype("float32"), + } + def test_check_output(self): self.check_output() +class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp): + def set_quant_axis(self): + self.quant_axis = 1 + self.inputs = { + 'X': np.random.random((15, 20, 5, 5)).astype("float32"), + } + + +class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp): + def set_quant_axis(self): + self.quant_axis = 0 + self.inputs = {'X': np.random.random((30, 15)).astype("float32"), } + + +class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp): + def set_quant_axis(self): + self.quant_axis = 1 + self.inputs = {'X': np.random.random((30, 15)).astype("float32"), } + + class TestFakeQuantizeRangeAbsMaxOp(OpTest): def setUp(self): self.op_type = "fake_quantize_range_abs_max" diff --git a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py index 0e19069d5c04e..22a1434ae251a 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_amp_meta_optimizer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
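Note: the fleet tests that follow (amp, base, dgc, gradient merge, graph executor, lamb, lars, localsgd, pipeline, recompute) switch role_maker from the incubate path to paddle.distributed.fleet.base.role_maker and, where possible, hoist the imports to module level. The setup they all share is:

import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker

role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)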
+import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker import unittest import paddle import os @@ -23,8 +25,6 @@ def setUp(self): os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" def test_amp_optimizer(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) input_x = paddle.fluid.layers.data( diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base.py b/python/paddle/fluid/tests/unittests/test_fleet_base.py index 3a79b694cad5b..ca657a5a619b6 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base.py @@ -14,6 +14,8 @@ import unittest import paddle +import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker import os @@ -26,67 +28,49 @@ def setUp(self): "127.0.0.1:36001,127.0.0.2:36001" def test_init(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) def test_is_first_worker(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_first_worker(): print("test fleet first worker done.") def test_worker_index(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) print(fleet.worker_index()) def test_worker_num(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) print(fleet.worker_num()) def test_is_worker(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_worker(): print("test fleet is worker") def test_worker_endpoints(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) print(fleet.worker_endpoints(to_string=True)) def test_server_num(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_server(): print("fleet server num: {}".format(fleet.server_num())) def test_server_index(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_server(): print("fleet server index: {}".format(fleet.server_index())) def test_server_endpoints(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_server(): @@ -94,55 +78,41 @@ def test_server_endpoints(self): fleet.server_endpoints(to_string=True))) def test_is_server(self): - import paddle.distributed.fleet as fleet - import 
paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_server(): print("test fleet is server") def test_util(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) self.assertEqual(fleet.util, None) def test_barrier_worker(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_worker(): fleet.barrier_worker() def test_init_worker(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_worker(): fleet.init_worker() def test_run_server(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_worker(): fleet.run_worker() def test_stop_worker(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) if fleet.is_worker(): fleet.stop_worker() def test_distributed_optimizer(self): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) @@ -150,10 +120,6 @@ def test_distributed_optimizer(self): optimizer = fleet.distributed_optimizer(optimizer) def test_minimize(self): - import paddle - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker - input_x = paddle.fluid.layers.data( name="x", shape=[32], dtype='float32') input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') diff --git a/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py index 1d211a77008b4..b43687ce1cdab 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_dgc_meta_optimizer.py @@ -17,7 +17,7 @@ from paddle import fluid import os import paddle.distributed.fleet as fleet -import paddle.fluid.incubate.fleet.base.role_maker as role_maker +import paddle.distributed.fleet.base.role_maker as role_maker class TestFleetDGCOptimizer(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py index 45dd461237ba5..40e0168e1ac93 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py @@ -289,6 +289,11 @@ def test_execution_strategy(self): strategy = paddle.distributed.fleet.DistributedStrategy() strategy.execution_strategy = exe_strategy + def test_unknown_strategy(self): + strategy = paddle.distributed.fleet.DistributedStrategy() + with self.assertRaises(TypeError): + strategy.unknown_key = 'UNK' + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py 
b/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py index 49ce09877f0a0..581f8becbbff1 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py @@ -16,7 +16,7 @@ import paddle import os import paddle.distributed.fleet as fleet -import paddle.fluid.incubate.fleet.base.role_maker as role_maker +import paddle.distributed.fleet.base.role_maker as role_maker class TestFleetGradientMergeMetaOptimizer(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_graph_executor.py b/python/paddle/fluid/tests/unittests/test_fleet_graph_executor.py index d2e0112ba298c..4d92c6f70541d 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_graph_executor.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_graph_executor.py @@ -14,6 +14,8 @@ import unittest import paddle +import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker import os from launch_function_helper import launch_func @@ -39,8 +41,6 @@ def test_graph_execution_optimizer(self): } def node_func(): - import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) input_x = paddle.fluid.layers.data( diff --git a/python/paddle/fluid/tests/unittests/test_fleet_lamb_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_lamb_meta_optimizer.py index 8ad051924f274..134aea363b55e 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_lamb_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_lamb_meta_optimizer.py @@ -17,7 +17,7 @@ from paddle import fluid import os import paddle.distributed.fleet as fleet -import paddle.fluid.incubate.fleet.base.role_maker as role_maker +import paddle.distributed.fleet.base.role_maker as role_maker class TestFleetLambMetaOptimizer(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py index 87c4823693e2e..b15db0b12d001 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_lars_meta_optimizer.py @@ -17,7 +17,7 @@ from paddle import fluid import os import paddle.distributed.fleet as fleet -import paddle.fluid.incubate.fleet.base.role_maker as role_maker +import paddle.distributed.fleet.base.role_maker as role_maker class TestFleetLarsMetaOptimizer(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py index f4bb870484949..86098d42b823b 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_localsgd_meta_optimizer.py @@ -17,7 +17,7 @@ import os import paddle.distributed.fleet as fleet -import paddle.fluid.incubate.fleet.base.role_maker as role_maker +import paddle.distributed.fleet.base.role_maker as role_maker class TestFleetLocalSGDMetaOptimizer(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py index d35f2fe5e6288..ca969bc4032b1 100644 --- 
a/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py @@ -25,7 +25,7 @@ def setUp(self): def test_pipeline_optimizer(self): import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker + import paddle.distributed.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) with paddle.fluid.device_guard("cpu"): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py index f07c6421192a0..95e1c3a360257 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py @@ -27,7 +27,7 @@ def setUp(self): def test_recompute_optimizer(self): import paddle.distributed.fleet as fleet - import paddle.fluid.incubate.fleet.base.role_maker as role_maker + import paddle.distributed.fleet.base.role_maker as role_maker role = role_maker.PaddleCloudRoleMaker(is_collective=True) fleet.init(role) input_x = paddle.fluid.layers.data( @@ -43,7 +43,7 @@ def test_recompute_optimizer(self): strategy = paddle.distributed.fleet.DistributedStrategy() strategy.recompute = True - strategy.recompute_configs = {"checkpoints": ["fc2"]} + strategy.recompute_configs = {"checkpoints": ["fc_1.tmp_0"]} optimizer = paddle.optimizer.SGD(learning_rate=0.01) optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py index f80d45ed5e09d..cf9b3e1e9a160 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py @@ -34,6 +34,7 @@ def test_rolemaker_base(self): self.assertRaises(Exception, role.worker_index) self.assertRaises(Exception, role.server_index) self.assertRaises(Exception, role.role_id) + self.assertRaises(Exception, role.node_num) trainer_endpoints = role.get_trainer_endpoints() self.assertTrue(len(trainer_endpoints) == 0) @@ -80,10 +81,12 @@ def test_tr_rolemaker(self): worker_endpoints = ro.get_trainer_endpoints() self.assertEqual(worker_endpoints[0], '127.0.0.1:36001') self.assertEqual(ro.role_id(), 0) + self.assertEqual(ro.node_num(), 2) def test_tr_rolemaker_collective(self): ro = role_maker.PaddleCloudRoleMaker(is_collective=True) self.assertEqual(ro.worker_num(), 2) + self.assertEqual(ro.node_num(), 2) def test_ps_rolemaker(self): """Test ps rolemaker.""" diff --git a/python/paddle/fluid/tests/unittests/test_generator.py b/python/paddle/fluid/tests/unittests/test_generator.py new file mode 100644 index 0000000000000..6cc43d3d54982 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_generator.py @@ -0,0 +1,44 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Test cloud role maker.""" + +from __future__ import print_function +import os +import unittest +import paddle.fluid.generator as generator +import time # temp for debug + + +class TestGenerator(unittest.TestCase): + """ + Test cases for cpu generator. + """ + + def test_basic_generator(self): + """Test basic generator.""" + gen = generator.Generator() + gen.manual_seed(123123143) + s = gen.initial_seed() + s = gen.seed() + st = gen.get_state() + gen.set_state(st) + gen.random() + gen.set_cpu_engine(gen.get_cpu_engine()) + + def test_basic_generator_error(self): + self.assertRaises(ValueError, generator.Generator, device="CUDA") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index 8a88c2d673c4d..f83f8ef35215e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -21,6 +21,7 @@ from paddle.fluid import Linear from test_imperative_base import new_program_scope import paddle.fluid.dygraph_utils as dygraph_utils +from paddle.fluid.dygraph.layer_object_helper import LayerObjectHelper import paddle @@ -629,6 +630,16 @@ def test_append_activation_in_dygraph2(self): res2 = fluid.layers.sigmoid(a) self.assertTrue(np.allclose(res1.numpy(), res2.numpy())) + def test_append_activation_in_dygraph3(self): + a_np = np.random.random(size=(10, 20, 30)).astype(np.float32) + helper = LayerObjectHelper(fluid.unique_name.generate("test")) + func = helper.append_activation + with fluid.dygraph.guard(): + a = fluid.dygraph.to_variable(a_np) + res1 = func(a, act="sigmoid", use_cudnn=True) + res2 = fluid.layers.sigmoid(a) + self.assertTrue(np.array_equal(res1.numpy(), res2.numpy())) + def test_append_bias_in_dygraph_exception(self): with new_program_scope(): np_inp = np.random.random(size=(10, 20, 30)).astype(np.float32) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 017992ecc84e4..4ce7bd693f3de 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -283,6 +283,24 @@ def test_layer_norm(self): with self.assertRaises(ValueError): lm(base.to_variable(inp)) + def test_SyncBatchNorm(self): + if core.is_compiled_with_cuda(): + with self.static_graph(): + t = layers.data(name='t', shape=[-1, 3, 5, 5], dtype='float32') + my_sync_bn = nn.SyncBatchNorm(3) + ret = my_sync_bn(t) + static_ret = self.get_static_graph_result( + feed={'t': np.ones( + [3, 3, 5, 5], dtype='float32')}, + fetch_list=[ret])[0] + + with self.dynamic_graph(): + t = np.ones([3, 3, 5, 5], dtype='float32') + my_syncbn = paddle.nn.SyncBatchNorm(3) + dy_ret = my_syncbn(base.to_variable(t)) + dy_ret_value = dy_ret.numpy() + self.assertTrue(np.array_equal(static_ret, static_ret)) + def test_relu(self): with self.static_graph(): t = layers.data(name='t', shape=[3, 3], dtype='float32') diff --git a/python/paddle/fluid/tests/unittests/test_log_softmax.py b/python/paddle/fluid/tests/unittests/test_log_softmax.py index 2b77624734d33..e3d7003ecedb6 100644 --- a/python/paddle/fluid/tests/unittests/test_log_softmax.py +++ b/python/paddle/fluid/tests/unittests/test_log_softmax.py @@ -14,93 +14,136 @@ import unittest import numpy as np -import paddle.fluid as fluid -import paddle.fluid.core as 
core -import paddle.nn as nn -import paddle.nn.functional as functional +from op_test import OpTest +import paddle +import paddle.nn.functional as F +np.random.seed(10) -def stable_softmax(x): + +def ref_log_softmax(x): shiftx = (x - np.max(x)) - exps = np.exp(shiftx) - return exps / np.sum(exps) + out = shiftx - np.log(np.exp(shiftx).sum()) + return out -def ref_log_softmax(x, axis=None, dtype=None): - x_t = x.copy() - if dtype is not None: - x_t = x_t.astype(dtype) - if axis is None: - axis = -1 - out = np.apply_along_axis(stable_softmax, axis, x_t) - return np.log(out) +def ref_log_softmax_grad(x, axis): + if axis < 0: + axis += len(x.shape) + out = np.apply_along_axis(ref_log_softmax, axis, x) + axis_dim = x.shape[axis] + dout = np.full_like(x, fill_value=1. / x.size) + dx = dout - np.exp(out) * dout.copy().sum(axis=axis, keepdims=True).repeat( + axis_dim, axis=axis) + return dx -class TestNNLogSoftmaxAPI(unittest.TestCase): +class TestLogSoftmaxOp(OpTest): def setUp(self): - self.init_data() + self.op_type = 'log_softmax' + self.dtype = 'float64' + self.shape = [2, 3, 4, 5] + self.axis = -1 + self.set_attrs() - def init_data(self): - self.x_shape = [2, 3, 4, 5] - self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) + x = np.random.uniform(0.1, 1., self.shape).astype(self.dtype) + out = np.apply_along_axis(ref_log_softmax, self.axis, x) + self.x_grad = ref_log_softmax_grad(x, self.axis) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + self.attrs = {'axis': self.axis} + + def set_attrs(self): + pass + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad]) + + +class TestLogSoftmaxShape(TestLogSoftmaxOp): + def set_attrs(self): + self.shape = [12, 10] - def check_api(self, place=fluid.CPUPlace(), axis=None): - ref_out = ref_log_softmax(self.x, axis) - main_program = fluid.Program() - mylogsoftmax = nn.LogSoftmax(axis) - with fluid.program_guard(main_program): - x = fluid.data(name='x', shape=self.x_shape) - y = mylogsoftmax(x) - exe = fluid.Executor(place) - out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) +class TestLogSoftmaxAxis(TestLogSoftmaxOp): + def set_attrs(self): + self.axis = 1 + + +class TestNNLogSoftmaxAPI(unittest.TestCase): + def setUp(self): + self.x_shape = [2, 3, 4, 5] + self.x = np.random.uniform(-1., 1., self.x_shape).astype(np.float32) + self.place = paddle.CUDAPlace(0) \ + if paddle.fluid.core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def check_api(self, axis=-1): + ref_out = np.apply_along_axis(ref_log_softmax, axis, self.x) + + logsoftmax = paddle.nn.LogSoftmax(axis) + # test static api + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data(name='x', shape=self.x_shape) + y = logsoftmax(x) + exe = paddle.static.Executor(self.place) + out = exe.run(feed={'x': self.x}, fetch_list=[y]) self.assertTrue(np.allclose(out[0], ref_out)) - with fluid.dygraph.guard(place): - x = fluid.dygraph.to_variable(self.x) - y = mylogsoftmax(x) + # test dygrapg api + paddle.disable_static() + x = paddle.to_variable(self.x) + y = logsoftmax(x) self.assertTrue(np.allclose(y.numpy(), ref_out)) + paddle.enable_static() def test_check_api(self): - places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - for axis in [None, 2]: - self.check_api(place, axis) + for axis in [-1, 1]: + self.check_api(axis) class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase): 
def setUp(self): - self.init_data() - - def init_data(self): self.x_shape = [2, 3, 4, 5] self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32) - - def check_api(self, place=fluid.CPUPlace(), axis=None, dtype=None): - ref_out = ref_log_softmax(self.x, axis, dtype) - main_program = fluid.Program() - mylogsoftmax = nn.LogSoftmax(axis) - with fluid.program_guard(main_program): - x = fluid.data(name='x', shape=self.x_shape) - y = functional.log_softmax(x, axis, dtype) - exe = fluid.Executor(place) - out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y]) + self.place = paddle.CUDAPlace(0) \ + if paddle.fluid.core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def check_api(self, axis=-1, dtype=None): + x = self.x.copy() + if dtype is not None: + x = x.astype(dtype) + ref_out = np.apply_along_axis(ref_log_softmax, axis, x) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data(name='x', shape=self.x_shape) + y = F.log_softmax(x, axis, dtype) + exe = paddle.static.Executor(self.place) + out = exe.run(feed={'x': self.x}, fetch_list=[y]) self.assertTrue(np.allclose(out[0], ref_out)) - with fluid.dygraph.guard(place): - x = fluid.dygraph.to_variable(self.x) - y = functional.log_softmax(x, axis, dtype) - self.assertTrue(np.allclose(y.numpy(), ref_out)) + paddle.disable_static() + x = paddle.to_variable(self.x) + y = F.log_softmax(x, axis, dtype) + self.assertTrue(np.allclose(y.numpy(), ref_out), True) + paddle.enable_static() def test_check_api(self): - places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for place in places: - self.check_api(place, None, None) - self.check_api(place, None, np.float64) + for axis in [-1, 1]: + self.check_api(axis) + self.check_api(-1, 'float64') + + def test_errors(self): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data(name='X1', shape=[100], dtype='int32') + self.assertRaises(TypeError, F.log_softmax, x) + + x = paddle.data(name='X2', shape=[100], dtype='float32') + self.assertRaises(TypeError, F.log_softmax, x, dtype='int32') if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py index d4189eca03697..90430bbce4d18 100644 --- a/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_lstm_cudnn_op.py @@ -20,15 +20,14 @@ import paddle.fluid.core as core from op_test import OpTest import paddle.fluid as fluid +import paddle.fluid.layers as layers SIGMOID_THRESHOLD_MIN = -40.0 SIGMOID_THRESHOLD_MAX = 13.0 EXP_MAX_INPUT = 40.0 -def lstm_naive( - input, - w, ): +def lstm_naive(input, w): seq_len, batch_size, hidden_size = input.shape offset = 0 @@ -86,8 +85,8 @@ def tanh(x): return (2. / (1. + np.exp(y))) - 1. 
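For reference, lstm_naive implements the standard LSTM recurrence, which the loop below applies step by step over the sequence. A generic single-step cell in NumPy (the gate order and flat-weight packing here are illustrative and do not necessarily match cudnn's layout):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, w_x, w_h, b):
    # w_x: (input_size, 4*hidden), w_h: (hidden, 4*hidden), b: (4*hidden,)
    gates = x_t @ w_x + h_prev @ w_h + b
    i, f, g, o = np.split(gates, 4, axis=-1)
    c_t = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(g)    # new cell state
    h_t = sigmoid(o) * np.tanh(c_t)                        # new hidden state
    return h_t, c_t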
output = [] - pre_h = np.zeros((batch_size, hidden_size), dtype=input.dtype) - pre_c = np.zeros((batch_size, hidden_size), dtype=input.dtype) + pre_h = np.zeros((1, batch_size, hidden_size), dtype=input.dtype) + pre_c = np.zeros((1, batch_size, hidden_size), dtype=input.dtype) for i in range(seq_len): emb_1 = input[i] @@ -110,7 +109,6 @@ def tanh(x): output = np.concatenate(output, -1) output = output.reshape((batch_size, -1, hidden_size)) - output = output.transpose((1, 0, 2)) return output, pre_h, pre_c @@ -119,11 +117,12 @@ def tanh(x): @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestCUDNNLstmOp(OpTest): + # TODO(GaoWei8):when input dtype is fp64, precision threshold should be removed. def setUp(self): self.op_type = "cudnn_lstm" - self.dtype = np.float32 + self.dtype = np.float64 - num_steps = 20 + seq_length = 20 batch_size = 5 hidden_size = 20 @@ -133,33 +132,24 @@ def setUp(self): weight_size += hidden_size * 8 input = np.random.uniform( - low=-0.1, high=0.1, size=(num_steps, batch_size, + low=-0.1, high=0.1, size=(seq_length, batch_size, hidden_size)).astype(self.dtype) flat_w = np.random.uniform( low=-0.1, high=0.1, size=(weight_size)).astype(self.dtype) output, last_hidden, last_cell = lstm_naive(input, flat_w) - init_h = np.zeros((batch_size, hidden_size), dtype=np.float32) - init_c = np.zeros((batch_size, hidden_size), dtype=np.float32) - scope = core.Scope() - program = fluid.Program() - block = program.global_block() - - cache_temp = block.create_var( - name="Cache", - persistable=True, - type=core.VarDesc.VarType.RAW, - stop_gradient=True) + init_h = np.zeros((1, batch_size, hidden_size), dtype=np.float64) + init_c = np.zeros((1, batch_size, hidden_size), dtype=np.float64) + state_out = np.ndarray((300)).astype("uint8") + self.inputs = { - 'Input': OpTest.np_dtype_to_fluid_dtype(input), - 'W': OpTest.np_dtype_to_fluid_dtype(flat_w), - 'InitH': OpTest.np_dtype_to_fluid_dtype(init_h), - 'InitC': OpTest.np_dtype_to_fluid_dtype(init_c), + 'Input': input, + 'W': flat_w, + 'InitH': init_h, + 'InitC': init_c } - self.cache_name_list = ['Cache'] self.attrs = { - 'max_len': num_steps, 'dropout_prob': 0.0, 'is_bidirec': False, 'input_size': hidden_size, @@ -168,22 +158,61 @@ def setUp(self): } self.outputs = { 'Out': output, - "last_h": last_hidden, - 'last_c': last_cell + "LastH": last_hidden, + 'LastC': last_cell, + 'Reserve': np.ndarray((400)).astype("uint8"), + 'StateOut': state_out } def test_output_with_place(self): # depend on the scope structure place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-5, check_dygraph=False) + self.check_output_with_place( + place, no_check_set=['Reserve', 'StateOut']) def test_grad_with_place(self): # depend on the scope structure place = core.CUDAPlace(0) self.check_grad_with_place( place, - set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'last_h', 'last_c'], - check_dygraph=False) + set(['Input', 'W', 'InitH', 'InitC']), ['Out', 'LastH', 'LastC'], + max_relative_error=1e-4) + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestCUDNNlstmAPI(unittest.TestCase): + def test_lstm(self): + seq_len = 20 + batch_size = 5 + hidden_size = 20 + dropout_prob = 0.0 + num_layers = 1 + input = fluid.data( + name='input', + shape=[seq_len, batch_size, hidden_size], + dtype='float64') + init_h = layers.fill_constant([num_layers, batch_size, hidden_size], + 'float64', 0.0) + init_c = layers.fill_constant([num_layers, batch_size, hidden_size], + 
'float64', 0.0) + rnn_out, last_h, last_c = layers.lstm(input, init_h, init_c, seq_len, + hidden_size, num_layers, + dropout_prob) + exe = fluid.Executor(fluid.CUDAPlace(0)) + exe.run(fluid.default_startup_program()) + input_i = np.random.uniform( + low=-0.1, high=0.1, size=(seq_len, batch_size, + hidden_size)).astype("float64") + out = exe.run(fluid.default_main_program(), + feed={'input': input_i}, + fetch_list=[rnn_out, last_h, last_c, 'cudnn_lstm_0.w_0']) + + output, last_hidden, last_cell = lstm_naive(input_i, out[3]) + + self.assertTrue(np.allclose(output, out[0], atol=1e-5)) + self.assertTrue(np.allclose(last_hidden, out[1], atol=1e-5)) + self.assertTrue(np.allclose(last_cell, out[2], atol=1e-5)) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py index 3799640b98800..a2befb4a29a0f 100644 --- a/python/paddle/fluid/tests/unittests/test_mean_op.py +++ b/python/paddle/fluid/tests/unittests/test_mean_op.py @@ -86,6 +86,7 @@ def setUp(self): else paddle.CPUPlace() def test_api_static(self): + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.data('X', self.x_shape) out1 = paddle.mean(x) @@ -102,7 +103,9 @@ def test_api_static(self): for out in res: self.assertEqual(np.allclose(out, out_ref), True) - def test_api_imperative(self): + def test_api_dygraph(self): + paddle.disable_static(self.place) + def test_case(x, axis=None, keepdim=False): x_tensor = paddle.to_variable(x) out = paddle.mean(x_tensor, axis, keepdim) @@ -113,7 +116,6 @@ def test_case(x, axis=None, keepdim=False): out_ref = np.mean(x, axis, keepdims=keepdim) self.assertEqual(np.allclose(out.numpy(), out_ref), True) - paddle.disable_static(self.place) test_case(self.x) test_case(self.x, []) test_case(self.x, -1) @@ -125,6 +127,7 @@ def test_case(x, axis=None, keepdim=False): paddle.enable_static() def test_errors(self): + paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): x = paddle.data('X', [10, 12], 'int8') self.assertRaises(TypeError, paddle.mean, x) diff --git a/python/paddle/fluid/tests/unittests/test_normalize.py b/python/paddle/fluid/tests/unittests/test_normalize.py new file mode 100644 index 0000000000000..6595a29b24ae2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_normalize.py @@ -0,0 +1,102 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
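Note: several tests in this patch (test_mean_op above, test_normalize below) now pair paddle.disable_static() and paddle.enable_static() explicitly, so dygraph and static-graph checks cannot leak execution mode into each other. A minimal sketch of the pattern, with a hypothetical mean check and only APIs already used elsewhere in this patch:

import numpy as np
import paddle

x_np = np.ones([2, 3], dtype="float32")

# dygraph part
paddle.disable_static()
dy_out = paddle.mean(paddle.to_variable(x_np)).numpy()

# static-graph part
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.data("x", [2, 3], dtype="float32")
    out = paddle.mean(x)
    exe = paddle.static.Executor(paddle.CPUPlace())
    st_out = exe.run(feed={"x": x_np}, fetch_list=[out])[0]

assert np.allclose(dy_out, st_out)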
+ +from __future__ import print_function + +import unittest +import paddle +import paddle.nn.functional as F +import paddle.fluid as fluid +import paddle.fluid.core as core +import numpy as np + + +def p_normalize(x, axis=1, p=2, epsilon=1e-12, keepdims=True): + if len(x.shape) == 1: + axis = 0 + xp = np.power(np.abs(x), p) + s = np.sum(xp, axis=axis, keepdims=keepdims) + r = np.maximum(np.power(s, 1.0 / p), epsilon) + return x / r + + +class TestNNFunctionalNormalize(unittest.TestCase): + def setUp(self): + self.input_np = np.random.random(size=(10, 10)).astype(np.float32) + self.input_np2 = np.array([0.0, 0.0]).astype(np.float32) + self.expected0 = p_normalize(self.input_np) + self.expected1 = p_normalize(self.input_np, p=1.5) + self.expected2 = p_normalize(self.input_np, axis=0) + self.expected3 = p_normalize(self.input_np2) + + def run_imperative(self): + x = paddle.to_variable(self.input_np) + y = F.normalize(x) + self.assertTrue(np.allclose(y.numpy(), self.expected0)) + + y = F.normalize(x, p=1.5) + self.assertTrue(np.allclose(y.numpy(), self.expected1)) + + y = F.normalize(x, axis=0) + self.assertTrue(np.allclose(y.numpy(), self.expected2)) + + x = paddle.to_variable(self.input_np2) + y = F.normalize(x) + self.assertTrue(np.allclose(y.numpy(), self.expected3)) + + def run_static(self, use_gpu=False): + x = paddle.data(name='input', shape=[10, 10], dtype='float32') + x2 = paddle.data(name='input2', shape=[2], dtype='float32') + result0 = F.normalize(x) + result1 = F.normalize(x, p=1.5) + result2 = F.normalize(x, axis=0) + result3 = F.normalize(x, name='aaa') + result4 = F.normalize(x2) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + static_result = exe.run( + feed={"input": self.input_np, + "input2": self.input_np2}, + fetch_list=[result0, result1, result2, result4]) + + self.assertTrue(np.allclose(static_result[0], self.expected0)) + self.assertTrue(np.allclose(static_result[1], self.expected1)) + self.assertTrue(np.allclose(static_result[2], self.expected2)) + self.assertTrue('aaa' in result3.name) + self.assertTrue(np.allclose(static_result[3], self.expected3)) + + def test_cpu(self): + paddle.disable_static(place=paddle.fluid.CPUPlace()) + self.run_imperative() + paddle.enable_static() + + with fluid.program_guard(fluid.Program()): + self.run_static() + + def test_gpu(self): + if not fluid.core.is_compiled_with_cuda(): + return + + paddle.disable_static(place=paddle.fluid.CUDAPlace(0)) + self.run_imperative() + paddle.enable_static() + + with fluid.program_guard(fluid.Program()): + self.run_static(use_gpu=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py new file mode 100644 index 0000000000000..5c34b35fc83a3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_sync_batch_norm.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
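Note: the input_np2 = [0.0, 0.0] case above exists to exercise the epsilon clamp in normalize: when the p-norm of a slice falls below epsilon, the denominator is clamped, so an all-zero input comes back as zeros instead of NaN. A quick NumPy check of the same reference formula:

import numpy as np

x = np.array([0.0, 0.0], dtype="float32")
denom = np.maximum(np.power(np.sum(np.abs(x) ** 2), 0.5), 1e-12)
print(x / denom)                                          # [0. 0.], no NaN

x = np.array([3.0, 4.0], dtype="float32")
print(x / np.maximum(np.sqrt(np.sum(x ** 2)), 1e-12))     # [0.6 0.8]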
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase +import paddle.fluid as fluid + +import os +flag_name = os.path.splitext(__file__)[0] + + +class TestParallelDygraphMnist(TestDistBase): + def _setup_config(self): + self._sync_mode = False + self._nccl2_mode = True + self._dygraph = False #True + + def test_mnist(self): + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "parallel_dygraph_sync_batch_norm.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_random_seed.py b/python/paddle/fluid/tests/unittests/test_random_seed.py new file mode 100644 index 0000000000000..31120a73042c9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_random_seed.py @@ -0,0 +1,119 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test cloud role maker.""" + +from __future__ import print_function +import os +import unittest +import paddle.fluid.generator as generator + +import time # temp for debug +import paddle.fluid as fluid +import numpy as np +import paddle +import paddle.fluid.core as core + + +class TestGeneratorSeed(unittest.TestCase): + """ + Test cases for cpu generator seed. + """ + + def test_generator_uniform_random_dygraph(self): + """Test Generator seed.""" + gen = generator.Generator() + + fluid.enable_dygraph() + + gen.manual_seed(12312321111) + x = fluid.layers.uniform_random([10], dtype="float32", min=0.0, max=1.0) + st1 = gen.get_state() + x1 = fluid.layers.uniform_random( + [10], dtype="float32", min=0.0, max=1.0) + gen.set_state(st1) + x2 = fluid.layers.uniform_random( + [10], dtype="float32", min=0.0, max=1.0) + gen.manual_seed(12312321111) + x3 = fluid.layers.uniform_random( + [10], dtype="float32", min=0.0, max=1.0) + x_np = x.numpy() + x1_np = x1.numpy() + x2_np = x2.numpy() + x3_np = x3.numpy() + + if not core.is_compiled_with_cuda(): + self.assertTrue(np.allclose(x1_np, x2_np)) + self.assertTrue(np.allclose(x_np, x3_np)) + + def test_generator_uniform_random_static(self): + + fluid.disable_dygraph() + + gen = generator.Generator() + gen.manual_seed(123123143) + + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + # example 1: + # attr shape is a list which doesn't contain tensor Variable. 
+ result_1 = fluid.layers.uniform_random(shape=[3, 4]) + result_2 = fluid.layers.uniform_random(shape=[3, 4]) + + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(startup_program) + out1 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + #gen.set_state(cur_state) + gen.manual_seed(123123143) + out2 = exe.run(train_program, + feed={}, + fetch_list=[result_1, result_2]) + + out1_res1 = np.array(out1[0]) + out1_res2 = np.array(out1[1]) + out2_res1 = np.array(out2[0]) + out2_res2 = np.array(out2[1]) + + if not core.is_compiled_with_cuda(): + self.assertTrue(np.allclose(out1_res1, out2_res1)) + self.assertTrue(np.allclose(out1_res2, out2_res2)) + self.assertTrue(not np.allclose(out1_res2, out1_res1)) + + def test_generator_randint_dygraph(self): + """Test Generator seed.""" + gen = generator.Generator() + + fluid.enable_dygraph() + + gen.manual_seed(12312321111) + x = paddle.randint(low=1) + st1 = gen.get_state() + x1 = paddle.randint(low=1) + gen.set_state(st1) + x2 = paddle.randint(low=1) + gen.manual_seed(12312321111) + x3 = paddle.randint(low=1) + x_np = x.numpy() + x1_np = x1.numpy() + x2_np = x2.numpy() + x3_np = x3.numpy() + if not core.is_compiled_with_cuda(): + self.assertTrue(np.allclose(x1_np, x2_np)) + self.assertTrue(np.allclose(x_np, x3_np)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py index 8fd118c019303..806b6b90e7e2d 100644 --- a/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_sync_batch_norm_op.py @@ -25,6 +25,7 @@ import paddle.fluid.core as core import paddle.fluid as fluid from paddle.fluid import compiler +from paddle.fluid import Program, program_guard from op_test import OpTest, _set_use_system_allocator @@ -202,5 +203,22 @@ def setUp(self): self.atol = 1e-2 +class TestDygraphSyncBatchNormAPIError(unittest.TestCase): + def test_errors(self): + if not core.is_compiled_with_cuda(): + return + + with program_guard(Program(), Program()): + my_sync_batch_norm = fluid.dygraph.SyncBatchNorm(10) + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CUDAPlace(0)) + self.assertRaises(TypeError, my_sync_batch_norm, x1) + + # the input dtype of SyncBatchNorm must be float16 or float32 or float64 + # float16 only can be set on GPU place + x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="int32") + self.assertRaises(TypeError, my_sync_batch_norm, x2) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_tile_op.py b/python/paddle/fluid/tests/unittests/test_tile_op.py index 0fc8efbd49e56..73e6ff5dbd69c 100644 --- a/python/paddle/fluid/tests/unittests/test_tile_op.py +++ b/python/paddle/fluid/tests/unittests/test_tile_op.py @@ -22,7 +22,7 @@ from paddle.fluid import compiler, Program, program_guard -# Situation 1: repeat_times is a list(without tensor) +# Situation 1: repeat_times is a list (without tensor) class TestTileOpRank1(OpTest): def setUp(self): self.op_type = "tile" @@ -81,7 +81,7 @@ def init_data(self): self.repeat_times = (3, 2, 1, 2) -# Situation 2: repeat_times is a list(with tensor) +# Situation 2: repeat_times is a list (with tensor) class TestTileOpRank1_tensor_attr(OpTest): def setUp(self): self.op_type = "tile" @@ -162,7 +162,7 @@ def setUp(self): self.op_type = "tile" self.inputs = { 'X': np.random.randint( - 10, size=(2, 4, 5)).astype("int32") + 10, 
size=(4, 4, 5)).astype("int32") } self.attrs = {'repeat_times': [2, 1, 4]} output = np.tile(self.inputs['X'], (2, 1, 4)) @@ -211,38 +211,30 @@ def test_errors(self): x2 = fluid.layers.data(name='x2', shape=[4], dtype="uint8") self.assertRaises(TypeError, paddle.tile, x2, repeat_times) x3 = fluid.layers.data(name='x3', shape=[4], dtype="bool") - x3.stop_gradient = True + x3.stop_gradient = False self.assertRaises(ValueError, paddle.tile, x3, repeat_times) # Test python API class TestTileAPI(unittest.TestCase): def test_api(self): - input = np.random.random([12, 14]).astype("float32") - x = fluid.layers.data( - name='x', shape=[12, 14], append_batch_size=False, dtype="float32") - - positive_2 = fluid.layers.fill_constant([1], "int32", 2) - repeat_times = fluid.layers.data( - name="repeat_times", shape=[2], append_batch_size=False) - - out_1 = paddle.tile(x, repeat_times=[2, 3]) - out_2 = paddle.tile(x, repeat_times=[positive_2, 3]) - out_3 = paddle.tile(x, repeat_times=repeat_times) - - g0 = fluid.backward.calc_gradient(out_2, x) - - exe = fluid.Executor(place=fluid.CPUPlace()) - res_1, res_2, res_3 = exe.run(fluid.default_main_program(), - feed={ - "x": input, - "repeat_times": - np.array([1, 3]).astype("int32") - }, - fetch_list=[out_1, out_2, out_3]) - assert np.array_equal(res_1, np.tile(input, (2, 3))) - assert np.array_equal(res_2, np.tile(input, (2, 3))) - assert np.array_equal(res_3, np.tile(input, (1, 3))) + with fluid.dygraph.guard(): + np_x = np.random.random([12, 14]).astype("float32") + x = paddle.to_variable(np_x) + + positive_2 = np.array([2]).astype("int32") + positive_2 = paddle.to_variable(positive_2) + + repeat_times = np.array([2, 3]).astype("int32") + repeat_times = paddle.to_variable(repeat_times) + + out_1 = paddle.tile(x, repeat_times=[2, 3]) + out_2 = paddle.tile(x, repeat_times=[positive_2, 3]) + out_3 = paddle.tile(x, repeat_times=repeat_times) + + assert np.array_equal(out_1.numpy(), np.tile(np_x, (2, 3))) + assert np.array_equal(out_2.numpy(), np.tile(np_x, (2, 3))) + assert np.array_equal(out_3.numpy(), np.tile(np_x, (2, 3))) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py index b8258f3153a80..0de0eeb464ad7 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_check_set_white_list.py @@ -26,4 +26,5 @@ 'cross_entropy2', 'seed', 'amp_check_finite_and_scale', + 'cudnn_lstm', ] diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index ce6868b5c70ae..5300ab935a340 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -41,7 +41,8 @@ 'unpool', \ 'yolov3_loss', \ 'inverse', \ - 'bilateral_slice' + 'bilateral_slice',\ + 'cudnn_lstm' ] NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp'] diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index 215546293a406..f01dc01973a60 100644 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -15,7 +15,8 @@ # TODO: import framework api under this directory __all__ = [ 'create_global_var', 'create_parameter', 'ParamAttr', 'Variable', - 'CPUPlace', 'CUDAPlace', 'CUDAPinnedPlace' + 'CPUPlace', 
'CUDAPlace', 'CUDAPinnedPlace', 'get_default_dtype', + 'set_default_dtype' ] __all__ += [ @@ -30,6 +31,8 @@ from . import random from .random import manual_seed +from .framework import get_default_dtype +from .framework import set_default_dtype from ..fluid.framework import Variable #DEFINE_ALIAS from ..fluid.framework import ComplexVariable #DEFINE_ALIAS diff --git a/python/paddle/framework/framework.py b/python/paddle/framework/framework.py index 65654b59c0830..4d5b2c8e6fcb1 100644 --- a/python/paddle/framework/framework.py +++ b/python/paddle/framework/framework.py @@ -13,5 +13,46 @@ # limitations under the License. # TODO: define framework api -# __all__ = ['set_default_dtype', -# 'get_default_dtype'] +from paddle.fluid.layer_helper_base import LayerHelperBase +from paddle.fluid.data_feeder import convert_dtype + +__all__ = ['set_default_dtype', 'get_default_dtype'] + + +def set_default_dtype(d): + """ + Set default dtype. The default dtype is initially float32 + + Args: + d(string|np.dtype): the dtype to make the default + + Returns: + None. + + Examples: + .. code-block:: python + + import paddle + paddle.set_default_dtype("float32") + + """ + d = convert_dtype(d) + LayerHelperBase.set_default_dtype(d) + + +def get_default_dtype(): + """ + Get the current default dtype. The default dtype is initially float32 + + Args: + None. + Returns: + The default dtype. + + Examples: + .. code-block:: python + + import paddle + paddle.get_default_dtype() + """ + return LayerHelperBase.get_default_dtype() diff --git a/python/paddle/incubate/hapi/tests/test_metrics.py b/python/paddle/incubate/hapi/tests/test_metrics.py index 3d25a275d5f1c..19c94b73f61a2 100644 --- a/python/paddle/incubate/hapi/tests/test_metrics.py +++ b/python/paddle/incubate/hapi/tests/test_metrics.py @@ -40,7 +40,8 @@ def accuracy(pred, label, topk=(1, )): def convert_to_one_hot(y, C): - oh = np.random.random((y.shape[0], C)).astype('float32') * .5 + oh = np.random.choice(np.arange(C), C, replace=False).astype('float32') / C + oh = np.tile(oh[np.newaxis, :], (y.shape[0], 1)) for i in range(y.shape[0]): oh[i, int(y[i])] = 1. 
return oh diff --git a/python/paddle/io/__init__.py b/python/paddle/io/__init__.py index 875f3ff2e9155..89bbd5916578b 100644 --- a/python/paddle/io/__init__.py +++ b/python/paddle/io/__init__.py @@ -20,6 +20,9 @@ # 'Transform', 'DataLoader', 'get_worker_info', + 'Sampler', + 'SequenceSampler', + 'RandomSampler', 'load', 'save', 'load_program_state', @@ -38,7 +41,8 @@ ] from ..fluid.io import DataLoader -from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info +from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \ + Sampler, SequenceSampler, RandomSampler from ..fluid.io import load, save, load_program_state, set_program_state, \ load_inference_model, save_inference_model, batch from ..reader import shuffle, buffered, cache, chain, firstn, compose, map_readers, xmap_readers diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index a52d45521fd1b..7b6dcdf7f67de 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -51,11 +51,14 @@ from .decode import gather_tree #DEFINE_ALIAS from .input import data #DEFINE_ALIAS # from .input import Input #DEFINE_ALIAS +from .layer.activation import ELU +from .layer.activation import GELU from .layer.activation import Hardshrink # from .layer.activation import PReLU #DEFINE_ALIAS -from .layer.activation import ReLU #DEFINE_ALIAS +from .layer.activation import ReLU from .layer.activation import LeakyReLU #DEFINE_ALIAS from .layer.activation import Sigmoid #DEFINE_ALIAS +from .layer.activation import LogSigmoid # from .layer.activation import Softmax #DEFINE_ALIAS from .layer.activation import LogSoftmax #DEFINE_ALIAS from .layer.activation import HSigmoid #DEFINE_ALIAS @@ -89,6 +92,7 @@ from .layer.loss import KLDivLoss #DEFINE_ALIAS from .layer.loss import MarginRankingLoss #DEFINE_ALIAS from .layer.norm import BatchNorm #DEFINE_ALIAS +from .layer.norm import SyncBatchNorm #DEFINE_ALIAS from .layer.norm import GroupNorm #DEFINE_ALIAS from .layer.norm import LayerNorm #DEFINE_ALIAS from .layer.norm import SpectralNorm #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index fa85b19426cd2..bc71b8bdf06d2 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -150,6 +150,7 @@ from .norm import l2_normalize #DEFINE_ALIAS # from .norm import layer_norm #DEFINE_ALIAS from .norm import lrn #DEFINE_ALIAS +from .norm import normalize #DEFINE_ALIAS # from .norm import spectral_norm #DEFINE_ALIAS from .pooling import pool2d #DEFINE_ALIAS from .pooling import pool3d #DEFINE_ALIAS diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index b02aa769dffc9..2d4f121b1d6bb 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -14,12 +14,9 @@ # TODO: define activation functions of neural network from ...fluid.layers import brelu #DEFINE_ALIAS -from ...fluid.layers import elu #DEFINE_ALIAS from ...fluid.layers import erf #DEFINE_ALIAS -from ...fluid.layers import gelu #DEFINE_ALIAS from ...fluid.layers import hard_sigmoid #DEFINE_ALIAS from ...fluid.layers import hard_swish #DEFINE_ALIAS -from ...fluid.layers import logsigmoid #DEFINE_ALIAS from ...fluid.layers import maxout #DEFINE_ALIAS from ...fluid.layers import relu6 #DEFINE_ALIAS from ...fluid.layers import selu #DEFINE_ALIAS @@ -64,10 +61,112 @@ from ...fluid.layer_helper import LayerHelper from 
...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_ from ...fluid import core -from ...fluid.data_feeder import check_variable_and_dtype +from ...fluid.data_feeder import check_variable_and_dtype, check_dtype import paddle +def elu(x, alpha=1.0, name=None): + """ + elu activation. + + .. math:: + + elu(x) = max(0, x) + min(0, \\alpha * (e^{x}-1)) + + Parameters: + x (Tensor): The input Tensor with data type float32, float64. + alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) + out = F.elu(x, alpha=0.2) + # [[-0.12642411 6. ] + # [ 1. 15.6 ]] + """ + + if in_dygraph_mode(): + return core.ops.elu(x, 'alpha', alpha) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'elu') + helper = LayerHelper("elu", **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='elu', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'alpha': alpha}) + return out + + +def gelu(x, approximate=False, name=None): + """ + gelu activation. + + if approximate is True + .. math:: + gelu(x) = 0.5 * x * (1 + tanh(\\sqrt{\\frac{2}{\\pi}} * (x + 0.044715x^{3}))) + else + .. math:: + gelu(x) = 0.5 * x * (1 + erf(\\frac{x}{\\sqrt{2}})) + + Parameters: + x (Tensor): The input Tensor with data type float32, float64. + approximate (bool, optional): Wether to enable approximation. Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + data = np.random.randn(2, 3).astype("float32") + x = paddle.to_tensor(data) + + out = F.gelu(x) + + data + # array([[ 0.87165993, -1.0541513 , -0.37214822], + # [ 0.15647964, 0.32496083, 0.33045998]], dtype=float32) + out + # array([[ 0.70456535, -0.15380788, -0.13207214], + # [ 0.08796856, 0.20387867, 0.2080159 ]], dtype=float32) + """ + + if in_dygraph_mode(): + return core.ops.gelu(x, 'approximate', approximate) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'gelu') + helper = LayerHelper("gelu", **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op( + type='gelu', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'approximate': approximate}) + return out + + def hardshrink(x, threshold=0.5, name=None): """ hard shrinkage activation @@ -296,11 +395,8 @@ def leaky_relu(x, negative_slope=0.01, name=None): return out -def relu(input, inplace=False, name=None): +def relu(x, name=None): """ - :alias_main: paddle.nn.functional.relu - :alias: paddle.nn.functional.relu,paddle.nn.functional.activation.relu - ReLU Activation. .. math: @@ -308,44 +404,74 @@ def relu(input, inplace=False, name=None): out = max(x, 0) Parameters: - input (Variable): The input variable. A multi-dimension Tensor with type float16, float32, or float64. 
- inplace (bool, optional): If inplace is True, the input and output of ``ReLU`` are the same variable. - Otherwise, the input and output of ``ReLU`` are different variables. Default: False. Note that if x is - more than one OPs' input, inplace must be False. - name (str, optional): The default value is None. Normally there is no need for user to set this property. - For more information, please refer to :ref:`api_guide_Name` . + x (Tensor): The input Tensor with data type float32, float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. Returns: - Output of relu operator, a Tensor with shape same as input + A Tensor with the same data type and shape as ``x`` . Examples: .. code-block:: python - import paddle.fluid as fluid - import paddle.nn.functional as functional - import numpy as np + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() - data = np.array([-2, 0, 1]).astype('float32') - with fluid.dygraph.guard(): - data = fluid.dygraph.to_variable(data) - res = functional.relu(data) # [0, 0, 1] + x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) + out = F.relu(x) # [0., 0., 1.] """ if in_dygraph_mode(): - if inplace: - warnings.warn( - "Inplace on ReLU is not allowed and will be discarded in dygraph mode currently." - ) - return core.ops.relu(input) - - check_variable_and_dtype(input, 'input', ['float16', 'float32', 'float64'], - 'relu') + return core.ops.relu(x) + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu') helper = LayerHelper('relu', **locals()) - outs = input if inplace else helper.create_variable_for_type_inference( - input.dtype) - helper.append_op(type='relu', inputs={'X': [input]}, outputs={'Out': outs}) - return outs + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='relu', inputs={'X': x}, outputs={'Out': out}) + return out + + +def logsigmoid(x, name=None): + """ + logsigmoid activation. + + .. math: + + logsigmoid(x) = \log \frac{1}{1 + e^{-x}} + + Parameters: + x (Tensor): The input Tensor with data type float32, float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor with the same data type and shape as ``x`` . + + Examples: + .. code-block:: python + + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) + out = F.logsigmoid(x) # [0.7310586, 0.880797, 0.95257413, 0.98201376] + """ + + if in_dygraph_mode(): + return core.ops.logsigmoid(x) + + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'logsigmoid') + helper = LayerHelper("logsigmoid", **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='logsigmoid', inputs={'X': x}, outputs={'Out': out}) + return out def softmax(x, axis=-1, name=None): @@ -464,12 +590,10 @@ def softmax(x, axis=-1, name=None): return paddle.fluid.layers.softmax(input=x, axis=axis, name=name) -def log_softmax(input, axis=None, dtype=None, name=None): +def log_softmax(x, axis=-1, dtype=None, name=None): """ - :alias_main: paddle.nn.functional.log_softmax - :alias: paddle.nn.functional.log_softmax,paddle.nn.functional.activation.log_softmax - - This operator implements the log_softmax layer. 
The calculation process is as follows: + This operator implements the log_softmax layer. The calculation process is + as follows: .. math:: @@ -477,78 +601,85 @@ def log_softmax(input, axis=None, dtype=None, name=None): = log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) Parameters: - input (Variable): The input variable. A multi-dimension Tensor with type float32, or float64. - axis (int, optional): The index of dimension to perform softmax calculations, it should be in - range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None. - None and -1 means the last dimension. - dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified, - the input tensor is casted to dtype before the operation is performed. This is useful for - preventing data type overflows. Default: None. Supported dtype: float32 or float64 - name (str, optional): The default value is None. Normally there is no need for user to set this property. - For more information, please refer to :ref:`api_guide_Name` . + x (Tensor): The input Tensor with data type float32, float64. + axis (int, optional): The axis along which to perform log_softmax + calculations. It should be in range [-D, D), where D is the + dimensions of ``x`` . If ``axis`` < 0, it works the same way as + :math:`axis + D` . Default is -1. + dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data + type of the output tensor. If dtype is specified, ``x`` is casted + to ``dtype`` before the operation is performed. This is useful for + preventing data type overflows. Supported dtype: float32, float64. + If ``dtype`` is None, the output Tensor has the same dtype as x. + Default is None. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. Returns: - Variable: ``Tensor`` indicates the output of softmax. The data type and shape are the same as ``input``. + A Tensor with the same shape and data type (use ``dtype`` if it is + specified) as x. Examples: .. 
code-block:: python - import paddle.fluid as fluid - import paddle.nn.functional as F - import numpy as np - - data = np.array([[[-2.0, 3.0, -4.0, 5.0], - [3.0, -4.0, 5.0, -6.0], - [-7.0, -8.0, 8.0, 9.0]], - [[1.0, -2.0, -3.0, 4.0], - [-5.0, 6.0, 7.0, -8.0], - [6.0, 7.0, 8.0, 9.0]]]).astype('float32') - with fluid.dygraph.guard(): - data = fluid.dygraph.to_variable(data) - res = F.log_softmax(data, -1) - # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] - # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] - # [-16.313261 -17.313261 -1.3132617 -0.31326184]] - # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] - # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] - # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] + import paddle + import paddle.nn.functional as F + import numpy as np + + paddle.disable_static() + + x = np.array([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]], 'float32') + x = paddle.to_tensor(x) + out1 = F.log_softmax(x) + out2 = F.log_softmax(x, dtype='float64') + # out1's data type is float32; out2's data type is float64 + # out1 and out2's value is as follows: + # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] + # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] + # [-16.313261 -17.313261 -1.3132617 -0.31326184]] + # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] + # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] + # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] """ - axis = -1 if axis is None else axis - dtype = convert_np_dtype_to_dtype_(dtype) if dtype is not None else dtype + if axis is None: + axis = -1 + if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)): + dtype = convert_np_dtype_to_dtype_(dtype) if in_dygraph_mode(): - outs_cast = input if dtype is None \ - else core.ops.cast(input, 'in_dtype', input.dtype, 'out_dtype', dtype) - outs_softmax = core.ops.softmax(outs_cast, 'axis', axis, 'use_cudnn', - False) - return core.ops.log(outs_softmax) + if dtype is not None: + x = core.ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype) + return core.ops.log_softmax(x, 'axis', axis) if dtype is None: - check_variable_and_dtype( - input, 'input', ['float16', 'float32', 'float64'], 'log_softmax') + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'log_softmax') + else: + check_dtype(dtype, 'dtype', ['float32', 'float64'], 'log_softmax', + 'If dtype is not None, it only support float32 or float64.') helper = LayerHelper("log_softmax", **locals()) - outs_cast = input + out_cast = x if dtype is not None: - outs_cast = helper.create_variable_for_type_inference(dtype) + out_cast = helper.create_variable_for_type_inference(dtype) helper.append_op( type='cast', - inputs={'X': input}, - outputs={'Out': outs_cast}, - attrs={'in_dtype': input.dtype, + inputs={'X': x}, + outputs={'Out': out_cast}, + attrs={'in_dtype': x.dtype, 'out_dtype': dtype}) - outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype) + out = helper.create_variable_for_type_inference(out_cast.dtype) helper.append_op( - type='softmax', - inputs={'X': outs_cast}, - outputs={'Out': outs_softmax}, - attrs={'axis': axis, - 'use_cudnn': False}) - - outs_log = helper.create_variable_for_type_inference(outs_softmax.dtype) - helper.append_op( - type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log}) + type='log_softmax', + inputs={'X': out_cast}, + outputs={'Out': out}, + attrs={'axis': axis}) - return outs_log + return out diff --git 
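The functional activations added above (``elu``, ``gelu``, ``logsigmoid``, ``relu``, and the rewritten ``log_softmax``) all follow the same pattern: a dygraph fast path through ``core.ops`` and a static-graph fallback through ``LayerHelper``, returning a Tensor with the same shape as ``x``. A minimal usage sketch in the dygraph style used by the docstrings above (``paddle.disable_static`` plus ``paddle.to_tensor``); the input values are illustrative only and no printed results are claimed here.

.. code-block:: python

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()  # dygraph mode, as in the docstring examples

    x = paddle.to_tensor(np.array([-2.0, -0.5, 0.0, 1.5], dtype='float32'))

    # Each call returns a new Tensor with the same shape and dtype as x.
    y_elu = F.elu(x, alpha=1.0)           # negative branch scaled by alpha
    y_gelu = F.gelu(x, approximate=True)  # tanh approximation of GELU
    y_logsig = F.logsigmoid(x)            # log(1 / (1 + exp(-x)))
    y_relu = F.relu(x)                    # max(x, 0)

    # log_softmax accepts an optional dtype that up-casts x before the op,
    # which helps avoid overflow for large logits.
    logits = paddle.to_tensor(np.random.randn(2, 3).astype('float32'))
    log_probs = F.log_softmax(logits, axis=-1, dtype='float64')

In static-graph mode the same calls go through the ``check_variable_and_dtype`` checks and build ``elu``, ``gelu``, ``logsigmoid``, ``relu``, and ``log_softmax`` ops via ``LayerHelper`` instead of ``core.ops``.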
a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index e08c707b8daa6..f8bc0b1b54e96 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -176,61 +176,61 @@ def margin_ranking_loss(input, return result_out -def l1_loss(x, label, reduction='mean', name=None): +def l1_loss(input, label, reduction='mean', name=None): """ - This operator computes the L1 Loss of Tensor ``x`` and ``label`` as follows. + This operator computes the L1 Loss of Tensor ``input`` and ``label`` as follows. - If :attr:`reduction` set to ``'none'``, the loss is: + If `reduction` set to ``'none'``, the loss is: .. math:: - Out = \lvert x - label\rvert + Out = \lvert input - label\rvert - If :attr:`reduction` set to ``'mean'``, the loss is: + If `reduction` set to ``'mean'``, the loss is: .. math:: - Out = MEAN(\lvert x - label\rvert) + Out = MEAN(\lvert input - label\rvert) - If :attr:`reduction` set to ``'sum'``, the loss is: + If `reduction` set to ``'sum'``, the loss is: .. math:: - Out = SUM(\lvert x - label\rvert) + Out = SUM(\lvert input - label\rvert) Parameters: - x (Tensor): The input tensor. The shapes is [N, *], where N is batch size and `*` means any number of additional dimensions. It's data type should be float32, float64, int32, int64. - label (Tensor): label. The shapes is [N, *], same shape as ``x`` . It's data type should be float32, float64, int32, int64. + input (Tensor): The input tensor. The shapes is [N, *], where N is batch size and `*` means any number of additional dimensions. It's data type should be float32, float64, int32, int64. + label (Tensor): label. The shapes is [N, *], same shape as ``input`` . It's data type should be float32, float64, int32, int64. reduction (str, optional): Indicate the reduction to apply to the loss, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. - If :attr:`reduction` is ``'none'``, the unreduced loss is returned; - If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. - If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. + If `reduction` is ``'none'``, the unreduced loss is returned; + If `reduction` is ``'mean'``, the reduced mean loss is returned. + If `reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - Tensor, the L1 Loss of Tensor ``x`` and ``label``. - If :attr:`reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``x`` . - If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1], which means the output is a scalar. + Tensor, the L1 Loss of Tensor ``input`` and ``label``. + If `reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``input`` . + If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1]. Examples: .. 
code-block:: python import paddle import numpy as np paddle.disable_static() - x_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32") + input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32") label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32") - x = paddle.to_variable(x_data) + input = paddle.to_variable(input_data) label = paddle.to_variable(label_data) - l1_loss = paddle.nn.functional.l1_loss(x, label) + l1_loss = paddle.nn.functional.l1_loss(input, label) print(l1_loss.numpy()) # [0.35] - l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='none') + l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='none') print(l1_loss.numpy()) # [[0.20000005 0.19999999] # [0.2 0.79999995]] - l1_loss = paddle.nn.functional.l1_loss(x, label, reduction='sum') + l1_loss = paddle.nn.functional.l1_loss(input, label, reduction='sum') print(l1_loss.numpy()) # [1.4] """ @@ -241,7 +241,7 @@ def l1_loss(x, label, reduction='mean', name=None): if in_dygraph_mode(): unreduced = _elementwise_op_in_dygraph( - x, label, axis=-1, act='abs', op_name='elementwise_sub') + input, label, axis=-1, act='abs', op_name='elementwise_sub') if reduction == 'mean': return core.ops.mean(unreduced) elif reduction == 'sum': @@ -251,18 +251,18 @@ def l1_loss(x, label, reduction='mean', name=None): return unreduced fluid.data_feeder.check_variable_and_dtype( - x, 'x', ['float32', 'float64', 'int32', 'int64'], 'l1_loss') + input, 'input', ['float32', 'float64', 'int32', 'int64'], 'l1_loss') fluid.data_feeder.check_variable_and_dtype( label, 'label', ['float32', 'float64', 'int32', 'int64'], 'l1_loss') if reduction == 'sum': - unreduced = paddle.elementwise_sub(x, label, act='abs') + unreduced = paddle.elementwise_sub(input, label, act='abs') return paddle.sum(unreduced, name=name) elif reduction == 'mean': - unreduced = paddle.elementwise_sub(x, label, act='abs') + unreduced = paddle.elementwise_sub(input, label, act='abs') return paddle.mean(unreduced, name=name) else: - return paddle.elementwise_sub(x, label, act='abs', name=name) + return paddle.elementwise_sub(input, label, act='abs', name=name) def nll_loss(input, diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 04b031b91ce38..0b007041b4ab3 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -13,6 +13,11 @@ # limitations under the License. # TODO: define normalization api +import paddle +import paddle.fluid as fluid +from ...fluid.data_feeder import check_variable_and_dtype, check_type +from ...fluid.layer_helper import LayerHelper +from ...fluid.framework import in_dygraph_mode, core from ...fluid.layers import l2_normalize #DEFINE_ALIAS from ...fluid.layers import lrn #DEFINE_ALIAS @@ -24,5 +29,84 @@ 'l2_normalize', # 'layer_norm', 'lrn', + 'normalize', # 'spectral_norm' ] + + +def normalize(x, p=2, axis=1, epsilon=1e-12, name=None): + """ + This op normalizes ``x`` along dimension ``axis`` using :math:`L_p` norm. This layer computes + + .. math:: + + y = \frac{x}{ \max\left( \lvert \lvert x \rvert \rvert_p, epsilon\right) } + + .. math:: + \lvert \lvert x \rvert \rvert_p = \left(\sum_i {\lvert x_i\rvert^p} \right)^{1/p} + + where, :math:`\sum_i{\lvert x_i\rvert^p}` is calculated along the ``axis`` dimension. + + + Args: + x (Tensor): The input tensor could be N-D tensor, and the input data type could be float32 or float64. + p (float|int, optional): The exponent value in the norm formulation. 
Default: 2 + axis (int, optional): The axis on which to apply normalization. If ``x`` is 1-D tensor, ``axis`` is fixed to 0. If `axis < 0`, \ + the dimension to normalization is `x.ndim + axis`. -1 is the last dimension. + epsilon (float, optional): Small float added to denominator to avoid dividing by zero. Default is 1e-12. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, the output has the same shape and data type with ``x``. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + import paddle.nn.functional as F + + paddle.disable_static() + x = np.arange(6, dtype=np.float32).reshape(2,3) + x = paddle.to_variable(x) + y = F.normalize(x) + print(y.numpy()) + # [[0. 0.4472136 0.8944272 ] + # [0.42426404 0.5656854 0.7071067 ]] + + y = F.normalize(x, p=1.5) + print(y.numpy()) + # [[0. 0.40862012 0.81724024] + # [0.35684016 0.4757869 0.5947336 ]] + + y = F.normalize(x, axis=0) + print(y.numpy()) + # [[0. 0.24253564 0.37139067] + # [1. 0.97014254 0.9284767 ]] + """ + if len(x.shape) == 1: + axis = 0 + if in_dygraph_mode(): + eps = fluid.dygraph.base.to_variable([epsilon], dtype=x.dtype) + out = core.ops.p_norm(x, 'axis', axis, 'porder', + float(p), 'keepdim', True, 'epsilon', epsilon) + return x / core.ops.elementwise_max(out, eps) + + check_type(p, 'p', (float, int), 'normalize') + check_type(axis, 'axis', (int), 'normalize') + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'normalize') + + attrs = { + 'axis': axis, + 'porder': float(p), + 'keepdim': True, + 'epsilon': epsilon, + } + helper = LayerHelper('p_norm', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='p_norm', inputs={'X': x}, outputs={'Out': out}, attrs=attrs) + eps = out.block.create_var(dtype=out.dtype) + paddle.fill_constant([1], out.dtype, epsilon, out=eps) + return paddle.elementwise_div(x, paddle.maximum(out, eps), name=name) diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 9fb8ea78a16ab..f64252da5428a 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -65,6 +65,7 @@ from .loss import KLDivLoss #DEFINE_ALIAS from .loss import MarginRankingLoss #DEFINE_ALIAS from .norm import BatchNorm #DEFINE_ALIAS +from .norm import SyncBatchNorm #DEFINE_ALIAS from .norm import GroupNorm #DEFINE_ALIAS from .norm import LayerNorm #DEFINE_ALIAS from .norm import SpectralNorm #DEFINE_ALIAS diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index d37cb97094a52..6716813221841 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -15,12 +15,15 @@ # TODO: define activation functions of neural network __all__ = [ + 'ELU', + 'GELU', 'Hardshrink', # 'PReLU', 'ReLU', 'LeakyReLU', 'Sigmoid', # 'Softmax', + 'LogSigmoid', 'LogSoftmax', 'HSigmoid' ] @@ -31,6 +34,103 @@ from .. import functional as F +class ELU(layers.Layer): + """ + ELU Activation. + + .. math:: + + ELU(x) = max(0, x) + min(0, \\alpha * (e^{x}-1)) + + Parameters: + alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. 
code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([[-1,6],[1,15.6]])) + m = paddle.nn.ELU(0.2) + out = m(x) + # [[-0.12642411 6. ] + # [ 1. 15.6 ]] + """ + + def __init__(self, alpha=1.0, name=None): + super(ELU, self).__init__() + self._alpha = alpha + self._name = name + + def forward(self, x): + return F.elu(x, self._alpha, self._name) + + +class GELU(layers.Layer): + """ + GELU Activation. + + If approximate is True + + .. math:: + + GELU(x) = 0.5 * x * (1 + tanh(\\sqrt{\\frac{2}{\\pi}} * (x + 0.044715x^{3}))) + + else + + .. math:: + + GELU(x) = 0.5 * x * (1 + erf(\\frac{x}{\\sqrt{2}})) + + Parameters: + approximate (bool, optional): Wether to enable approximation. Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + data = np.random.randn(2, 3).astype("float32") + x = paddle.to_tensor(data) + + m = paddle.nn.GELU() + out = m(x) + + data + # array([[ 0.87165993, -1.0541513 , -0.37214822], + # [ 0.15647964, 0.32496083, 0.33045998]], dtype=float32) + out + # array([[ 0.70456535, -0.15380788, -0.13207214], + # [ 0.08796856, 0.20387867, 0.2080159 ]], dtype=float32) + """ + + def __init__(self, approximate=False, name=None): + super(GELU, self).__init__() + self._approximate = approximate + self._name = name + + def forward(self, x): + return F.gelu(x, self._approximate, self._name) + + class Hardshrink(layers.Layer): """ Hardshrink Activation @@ -216,44 +316,39 @@ def forward(self, input, label, path_table=None, path_code=None): class ReLU(layers.Layer): """ - :alias_main: paddle.nn.ReLU - :alias: paddle.nn.ReLU,paddle.nn.layer.ReLU,paddle.nn.layer.activation.ReLU - ReLU Activation. .. math: - out = max(x, 0) + ReLU(x) = max(x, 0) Parameters: - inplace (bool, optional): If inplace is True, the input and output of - ``ReLU`` are the same variable. Otherwise, the input and output of - ``ReLU`` are different variables. Default False. Note that if x is - more than one OPs' input, inplace must be False. - - Returns: - None + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. Examples: .. code-block:: python - import paddle.fluid as fluid - import paddle.nn as nn - import numpy as np + import paddle + import numpy as np - data = np.array([-2, 0, 1]).astype('float32') - my_relu = nn.ReLU() - with fluid.dygraph.guard(): - data = fluid.dygraph.to_variable(data) - res = my_relu(data) # [0, 0, 1] + paddle.disable_static() + + x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32')) + m = paddle.nn.ReLU() + out = m(x) # [0., 0., 1.] 
""" - def __init__(self, inplace=False): + def __init__(self, name=None): super(ReLU, self).__init__() - self._inplace = inplace + self._name = name - def forward(self, input): - return F.relu(input, self._inplace) + def forward(self, x): + return F.relu(x, self._name) class LeakyReLU(layers.Layer): @@ -342,11 +437,46 @@ def forward(self, x): return F.sigmoid(x, self.name) -class LogSoftmax(layers.Layer): +class LogSigmoid(layers.Layer): """ - :alias_main: paddle.nn.LogSoftmax - :alias: paddle.nn.LogSoftmax,paddle.nn.layer.LogSoftmax,paddle.nn.layer.activation.LogSoftmax + LogSigmoid Activation. + + .. math: + LogSigmoid(x) = \log \frac{1}{1 + e^{-x}} + + Parameters: + x (Tensor): The input Tensor with data type float32, or float64. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.to_tensor(np.array([1.0, 2.0, 3.0, 4.0])) + m = paddle.nn.LogSigmoid() + out = m(x) # [0.7310586, 0.880797, 0.95257413, 0.98201376] + """ + + def __init__(self, name=None): + super(LogSigmoid, self).__init__() + self._name = name + + def forward(self, x): + return F.logsigmoid(x, self._name) + + +class LogSoftmax(layers.Layer): + """ This operator implements the log_softmax layer. The calculation process is as follows: .. math:: @@ -355,44 +485,46 @@ class LogSoftmax(layers.Layer): = log(\\frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}) Parameters: - axis (int, optional): The index of dimension to perform softmax calculations, it should be in - range :math:`[-1, rank-1]`, while :math:`rank` is the rank of input variable. Default: None. - None and -1 means the last dimension. - dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of returned tensor. If specified, - the input tensor is casted to dtype before the operation is performed. This is useful for - preventing data type overflows. Default: None. Supported dtype: float32 or float64 + axis (int, optional): The axis along which to perform log_softmax + calculations. It should be in range [-D, D), where D is the + dimensions of the input Tensor . If ``axis`` < 0, it works the + same way as :math:`axis + D` . Default is -1. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. - Returns: - None + Shape: + - input: Tensor with any shape. + - output: Tensor with the same shape as input. Examples: .. 
code-block:: python - import paddle.fluid as fluid - import paddle.nn as nn - import numpy as np + import paddle + import numpy as np - data = np.array([[[-2.0, 3.0, -4.0, 5.0], - [3.0, -4.0, 5.0, -6.0], - [-7.0, -8.0, 8.0, 9.0]], - [[1.0, -2.0, -3.0, 4.0], - [-5.0, 6.0, 7.0, -8.0], - [6.0, 7.0, 8.0, 9.0]]]).astype('float32') - my_log_softnmax = nn.LogSoftmax() - with fluid.dygraph.guard(): - data = fluid.dygraph.to_variable(data) - res = my_log_softnmax(data) - # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] - # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] - # [-16.313261 -17.313261 -1.3132617 -0.31326184]] - # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] - # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] - # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] + paddle.disable_static() + + x = np.array([[[-2.0, 3.0, -4.0, 5.0], + [3.0, -4.0, 5.0, -6.0], + [-7.0, -8.0, 8.0, 9.0]], + [[1.0, -2.0, -3.0, 4.0], + [-5.0, 6.0, 7.0, -8.0], + [6.0, 7.0, 8.0, 9.0]]]) + m = paddle.nn.LogSoftmax() + x = paddle.to_tensor(x) + out = m(x) + # [[[ -7.1278396 -2.1278396 -9.127839 -0.12783948] + # [ -2.1270514 -9.127051 -0.12705144 -11.127051 ] + # [-16.313261 -17.313261 -1.3132617 -0.31326184]] + # [[ -3.0518122 -6.051812 -7.051812 -0.051812 ] + # [-12.313267 -1.3132664 -0.3132665 -15.313267 ] + # [ -3.4401896 -2.4401896 -1.4401896 -0.44018966]]] """ - def __init__(self, axis=None): + def __init__(self, axis=-1, name=None): super(LogSoftmax, self).__init__() self._axis = axis + self._name = name - def forward(self, input): - return F.log_softmax(input, self._axis) + def forward(self, x): + return F.log_softmax(x, self._axis) diff --git a/python/paddle/nn/layer/loss.py b/python/paddle/nn/layer/loss.py index bc4f32f9c3186..5067264ee792d 100644 --- a/python/paddle/nn/layer/loss.py +++ b/python/paddle/nn/layer/loss.py @@ -256,39 +256,39 @@ def forward(self, input, label): class L1Loss(fluid.dygraph.Layer): """ This interface is used to construct a callable object of the ``L1Loss`` class. - The L1Loss layer calculates the L1 Loss of ``x`` and ``label`` as follows. + The L1Loss layer calculates the L1 Loss of ``input`` and ``label`` as follows. - If :attr:`reduction` set to ``'none'``, the loss is: + If `reduction` set to ``'none'``, the loss is: .. math:: - Out = \lvert x - label\rvert + Out = \lvert input - label\rvert - If :attr:`reduction` set to ``'mean'``, the loss is: + If `reduction` set to ``'mean'``, the loss is: .. math:: - Out = MEAN(\lvert x - label\rvert) + Out = MEAN(\lvert input - label\rvert) - If :attr:`reduction` set to ``'sum'``, the loss is: + If `reduction` set to ``'sum'``, the loss is: .. math:: - Out = SUM(\lvert x - label\rvert) + Out = SUM(\lvert input - label\rvert) Parameters: reduction (str, optional): Indicate the reduction to apply to the loss, the candicates are ``'none'`` | ``'mean'`` | ``'sum'``. - If :attr:`reduction` is ``'none'``, the unreduced loss is returned; - If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned. - If :attr:`reduction` is ``'sum'``, the reduced sum loss is returned. + If `reduction` is ``'none'``, the unreduced loss is returned; + If `reduction` is ``'mean'``, the reduced mean loss is returned. + If `reduction` is ``'sum'``, the reduced sum loss is returned. Default is ``'mean'``. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Shape: - x (Tensor): The input tensor. 
The shapes is [N, *], where N is batch size and `*` means any number of additional dimensions. It's data type should be float32, float64, int32, int64. - label (Tensor): label. The shapes is [N, *], same shape as ``x`` . It's data type should be float32, float64, int32, int64. - output (Tensor): The L1 Loss of ``x`` and ``label``. - If :attr:`reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``x`` . - If :attr:`reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1], which means the output is a scalar. + input (Tensor): The input tensor. The shapes is [N, *], where N is batch size and `*` means any number of additional dimensions. It's data type should be float32, float64, int32, int64. + label (Tensor): label. The shapes is [N, *], same shape as ``input`` . It's data type should be float32, float64, int32, int64. + output (Tensor): The L1 Loss of ``input`` and ``label``. + If `reduction` is ``'none'``, the shape of output loss is [N, *], the same as ``input`` . + If `reduction` is ``'mean'`` or ``'sum'``, the shape of output loss is [1]. Examples: .. code-block:: python @@ -296,23 +296,23 @@ class L1Loss(fluid.dygraph.Layer): import numpy as np paddle.disable_static() - x_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32") + input_data = np.array([[1.5, 0.8], [0.2, 1.3]]).astype("float32") label_data = np.array([[1.7, 1], [0.4, 0.5]]).astype("float32") - x = paddle.to_variable(x_data) + input = paddle.to_variable(input_data) label = paddle.to_variable(label_data) l1_loss = paddle.nn.loss.L1Loss() - output = l1_loss(x, label) + output = l1_loss(input, label) print(output.numpy()) # [0.35] l1_loss = paddle.nn.loss.L1Loss(reduction='sum') - output = l1_loss(x, label) + output = l1_loss(input, label) print(output.numpy()) # [1.4] l1_loss = paddle.nn.loss.L1Loss(reduction='none') - output = l1_loss(x, label) + output = l1_loss(input, label) print(output.numpy()) # [[0.20000005 0.19999999] # [0.2 0.79999995]] @@ -327,9 +327,9 @@ def __init__(self, reduction='mean', name=None): self.reduction = reduction self.name = name - def forward(self, x, label): + def forward(self, input, label): return paddle.nn.functional.l1_loss( - x, label, self.reduction, name=self.name) + input, label, self.reduction, name=self.name) class BCELoss(fluid.dygraph.Layer): diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 1beba62c1809f..1d00f9c7b8b02 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -20,7 +20,9 @@ from ...fluid.dygraph import GroupNorm #DEFINE_ALIAS from ...fluid.dygraph import LayerNorm #DEFINE_ALIAS from ...fluid.dygraph import SpectralNorm #DEFINE_ALIAS +from ...fluid.dygraph import SyncBatchNorm #DEFINE_ALIAS __all__ = [ - 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm' + 'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'InstanceNorm', + 'SyncBatchNorm' ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index aa0d8c408899a..77d821d56b8e1 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -74,6 +74,7 @@ from .manipulation import cast #DEFINE_ALIAS from .manipulation import concat #DEFINE_ALIAS from .manipulation import expand #DEFINE_ALIAS +from .manipulation import broadcast_to #DEFINE_ALIAS from .manipulation import expand_as #DEFINE_ALIAS from .manipulation import tile #DEFINE_ALIAS from .manipulation import flatten #DEFINE_ALIAS @@ -98,6 +99,7 @@ from .manipulation import flip #DEFINE_ALIAS 
from .manipulation import unbind #DEFINE_ALIAS from .manipulation import roll #DEFINE_ALIAS +from .manipulation import chunk #DEFINE_ALIAS from .math import abs #DEFINE_ALIAS from .math import acos #DEFINE_ALIAS from .math import asin #DEFINE_ALIAS diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 9e2b7286ba677..2c8157645de29 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -23,7 +23,6 @@ import numpy as np # TODO: define functions to manipulate a tensor from ..fluid.layers import cast #DEFINE_ALIAS -from ..fluid.layers import expand_as #DEFINE_ALIAS from ..fluid.layers import reshape #DEFINE_ALIAS from ..fluid.layers import scatter #DEFINE_ALIAS from ..fluid.layers import slice #DEFINE_ALIAS @@ -44,6 +43,7 @@ 'cast', 'concat', 'expand', + 'broadcast_to', 'expand_as', 'flatten', 'gather', @@ -56,6 +56,7 @@ 'shard_index', 'slice', 'split', + 'chunk', 'squeeze', 'stack', 'strided_slice', @@ -789,82 +790,107 @@ def unbind(input, axis=0): return outs -def tile(x, repeat_times, name=None): +def chunk(x, chunks, axis=0, name=None): """ - :alias_main: paddle.tile - :alias: paddle.tile,paddle.tensor.tile,paddle.tensor.manipulation.tile - Construct a new tensor by repeating ``x`` the number of times given by the parameter ``repeat_times``. - The rank of ``x`` should be less than or equal to 6, and the size of the shape of ``repeat_times`` should - be less than or equal to 6. - If the size of the parameter ``repeat_times`` is ``d``, and the rank for ``x`` is ``r``, then the number - of dimensions for the result is ``max(d, r)``. - If ``r < d``, ``x`` if first promoted to a d-dimensional tensor by inserting new axes at the beginning. - For example, a tensor ``x`` with the shape(3,) is promoted to a 2-D tensor with the shape (1, 3) if ``d`` is 2 - and a 3-D tensor with the shape(1, 1, 3) if ``d`` is 3. - If ``r > d``, ``repeat_times`` is first promoted by inserting 1's at the beginning. - For example, if the tensor ``x`` is with a shape(4, 3, 2, 2) and ``repeat_times`` is a tuple (3, 2), - ``repeat_times`` is first promoted to a tuple (1, 1, 3, 2). - The following gives an using case: - .. code-block:: text - Input(x) is a 3-D tensor of shape (2, 3, 1): - [ - [[1], [2], [3]], - [[4], [5], [6]] - ] - Attr(repeat_times): [1, 2, 2] - Output(out) is a 3-D tensor of shape (2, 6, 2): - [ - [[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]], - [[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]] - ] + Split the input tensor into multiple sub-Tensors. + Args: - x (Tensor): The input tensor, its data type should be bool, float32, float64, int32. The rank of ``x`` should be in [1, 6]. - repeat_times (Tensor|tuple|list): The number of repeating times for each dimension of the input ``x``. If repeat_times is a list or tuple, the elements of - it should be integers or Tensors with shape [1]. If repeat_times is Tensor, it should be an 1-D Tensor. The size of its shape should be in [1, 6]. - name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name` . + x (Tensor): An N-D Tensor. The data type is bool, float16, float32, float64, int32 or int64. + chunks(int): The number of tensors to split ``x`` into along the given axis. + axis (int|Tensor, optional): The axis along which to split, it can be a scalar with type + ``int`` or a ``Tensor`` with shape [1] and data type ``int32`` or ``int64``. + If :math:`axis < 0`, the axis to split along is :math:`rank(x) + axis`.
Default is 0. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . Returns: - N-D Tensor. The data type is the same as ``x``. After tiling, each dimension of the output is equal to the corresponding dimension of ``x`` multiplying the corresponding value given by ``repeat_times`` . + list(Tensor): The list of segmented Tensors. Raises: - TypeError: The type of ``repeat_times`` must be list, tuple or Tensor. - ValueError: The elements of ``repeat_times`` cannot be negative. + TypeError: The data type of ``x`` must be one of bool, float16, float32, float64, int32, int64. + TypeError: ``chunks`` is not int. + TypeError: ``axis`` is not int or Tensor. The data type of ``axis`` must be int32 or int64 when it's a Tensor. + Examples: + .. code-block:: python + + import numpy as np + import paddle + + paddle.disable_static() + # x is a Tensor whose shape is [3, 9, 5] + x_np = np.random.random([3, 9, 5]).astype("int32") + x = paddle.to_variable(x_np) + + out0, out1, out2 = paddle.chunk(x, chunks=3, axis=1) + # out0.shape [3, 3, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 3, 5] + + + # axis is negative, the real axis is (rank(x) + axis), whose real + # value is 1. + out0, out1, out2 = paddle.chunk(x, chunks=3, axis=-2) + # out0.shape [3, 3, 5] + # out1.shape [3, 3, 5] + # out2.shape [3, 3, 5] + """ + check_type(chunks, 'chunks', (int), 'chunk') + return paddle.fluid.layers.split( + input=x, num_or_sections=chunks, dim=axis, name=name) + + +def tile(x, repeat_times, name=None): + """ + + Construct a new Tensor by repeating ``x`` the number of times given by ``repeat_times``. + After tiling, the number of elements of the i'th dimension of the output is equal to ``x.dims[i] * repeat_times[i]``. + + Both the number of dimensions of ``x`` and the number of elements in ``repeat_times`` should be less than or equal to 6. + + Args: + x (Tensor): The input tensor, its data type should be bool, float32, float64, int32 or int64. + repeat_times (Tensor|tuple|list): The number of repeating times. If repeat_times is a list or tuple, all its elements + should be integers or 1-D Tensors with the data type int32. If repeat_times is a Tensor, it should be a 1-D Tensor with the data type int32. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor. The data type is the same as ``x``. + Examples: ..
code-block:: python + import paddle import numpy as np + paddle.disable_static() - # example 1: - np_data_1 = np.array([1, 2, 3]).astype('int32') - data_1 = paddle..to_variable(np_data_1) - tiled_1 = paddle.tile(data_1, repeat_times=[2, 1]) + np_data = np.array([1, 2, 3]).astype('int32') + data = paddle.to_variable(np_data) + out = paddle.tile(data, repeat_times=[2, 1]) + np_out = out.numpy() # [[1, 2, 3], [1, 2, 3]] - # example 2: + + out = paddle.tile(data, repeat_times=[2, 2]) + np_out = out.numpy() + # [[1, 2, 3, 1, 2, 3], [1, 2, 3, 1, 2, 3]] + np_repeat_times = np.array([2, 1]).astype("int32") repeat_times = paddle.to_variable(np_repeat_times) - tiled_2 = paddle.tile(data_1, repeat_times=repeat_times) + out = paddle.tile(data, repeat_times=repeat_times) + np_out = out.numpy() # [[1, 2, 3], [1, 2, 3]] """ - if in_dygraph_mode(): - if isinstance(repeat_times, (list, tuple)): - repeat_times = [ - item.numpy()[0] if isinstance(item, Variable) else item - for item in repeat_times - ] - - return core.ops.tile(x, 'repeat_times', repeat_times) - - inputs = {"X": [x]} - attrs = {} check_variable_and_dtype( x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'tile') check_type(repeat_times, 'repeat_times', (list, tuple, Variable), 'tile') - if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == True: + if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False: raise ValueError( "When the date type is bool for the input 'x' of tile op, you " - "must set its stop_gradient to be False by " - "some_var.stop_gradient == True supporting some_var is the input.") + "must set its stop_gradient to be True by " + "some_var.stop_gradient = True, supporting some_var as the input.") helper = LayerHelper('tile', input=x, **locals()) + inputs = {"X": [x]} + attrs = {} + def get_attr_repeat_times(list_repeat_times): attrs_repeat_times = [] for idx, times in enumerate(list_repeat_times): @@ -873,13 +899,13 @@ def get_attr_repeat_times(list_repeat_times): else: attrs_repeat_times.append(times) assert times > 0, ( - "Every element given in repeat_times must be positive.") + "All elements in repeat_times must be positive for tile.") return attrs_repeat_times if isinstance(repeat_times, Variable): repeat_times.stop_gradient = True inputs['RepeatTimes'] = repeat_times - attrs['repeat_times'] = [-1] * len(repeat_times.shape) + attrs['repeat_times'] = [-1] elif isinstance(repeat_times, (list, tuple)): attrs['repeat_times'] = get_attr_repeat_times(repeat_times) if utils._contain_var(repeat_times): @@ -893,67 +919,103 @@ def get_attr_repeat_times(list_repeat_times): return out +def expand_as(x, y, name=None): + """ + + Expand the input tensor ``x`` to the same shape as the input tensor ``y``. + + Both the number of dimensions of ``x`` and ``y`` must be less than or equal to 6, and the number of dimensions of ``y`` must be greater than or equal to that of ``x``. The dimension to expand must have a value of 1. + + Args: + x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64. + y (Tensor): The input tensor that gives the shape for ``x`` to expand to. + name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor: A Tensor with the same shape as ``y``. The data type is the same as ``x``. + + Examples: ..
code-block:: python + + import numpy as np + import paddle + + paddle.disable_static() + + np_data_x = np.array([1, 2, 3]).astype('int32') + np_data_y = np.array([[1, 2, 3], [4, 5, 6]]).astype('int32') + data_x = paddle.to_variable(np_data_x) + data_y = paddle.to_variable(np_data_y) + out = paddle.expand_as(data_x, data_y) + np_out = out.numpy() + # [[1, 2, 3], [1, 2, 3]] + """ + check_variable_and_dtype( + x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand_as') + check_type(y, 'y', Variable, 'expand_as') + + if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False: + raise ValueError( + "When the data type of input 'x' for expand_as is bool, " + "you must set its stop_gradient to be True by " + "some_var.stop_gradient = True, supporting " + "some_var as the input 'x'.") + inputs = {"X": [x], "target_tensor": [y]} + + helper = LayerHelper('expand_as', input=x, **locals()) + dtype = helper.input_dtype(input_param_name='x') + out = helper.create_variable_for_type_inference(dtype) + helper.append_op(type='expand_as_v2', inputs=inputs, outputs={'Out': out}) + return out + + def expand(x, shape, name=None): """ - :alias_main: paddle.expand - :alias: paddle.expand,paddle.tensor.expand,paddle.tensor.manipulation.expand Expand the input tensor to a given shape. - The rank of ``x`` should be less than or equal to 6, and the number of elements in ``shape`` should also be less than or equal to 6. The size of the dimension to expand must be 1. + Both the number of dimensions of ``x`` and the number of elements in ``shape`` should be less than or equal to 6. The dimension to expand must have a value of 1. Args: - x (Tensor): The input Tensor with rank in [1, 6]. The data type is bool, float32, float64 or int32. - shape (list|tuple|Tensor): The result shape after expanding. The data type is int32. If shape is a list or tuple, all elements of - it should be integers or Tensors with shape (1,). If shape is a Tensor, it should be an 1-D Tensor. - The value -1 in shape, means keeping the corresponding dimension unchanged. + x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64. + shape (list|tuple|Tensor): The result shape after expanding. The data type is int32. If shape is a list or tuple, all its elements + should be integers or 1-D Tensors with the data type int32. If shape is a Tensor, it should be a 1-D Tensor with the data type int32. + The value -1 in shape means keeping the corresponding dimension unchanged. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Returns: - Tensor: A Tensor with the given shape. The data type is the same as ``x``. - - Raises: - TypeError: The type of ``shape`` must be list, tuple or Variable. - ValueError: The elements of ``shape`` must be positive or -1. + N-D Tensor: A Tensor with the given shape. The data type is the same as ``x``. Examples: ..
code-block:: python import numpy as np import paddle - paddle.disable_static() - # example 1: - np_data_1 = np.array([1, 2, 3]).astype=('int32) - data_1 = paddle.to_variable(np_data_1) - expanded_1 = paddle.expand(data_1, shape=[2, 3]) + paddle.disable_static() + np_data = np.array([1, 2, 3]).astype('int32') + data = paddle.to_variable(np_data) + out = paddle.expand(data, shape=[2, 3]) + out = out.numpy() # [[1, 2, 3], [1, 2, 3]] - # example 2: np_shape = np.array([2, 3]).astype=('int32) shape = paddle.to_variable(np_shape) - expanded_2 = paddle.expand(data_1, shape=shape) + out = paddle.expand(data, shape=shape) + out = out.numpy() # [[1, 2, 3], [1, 2, 3]] """ - if in_dygraph_mode(): - if isinstance(shape, (list, tuple)): - expand_shape = [ - item.numpy()[0] if isinstance(item, Variable) else item - for item in shape - ] - - return core.ops.expand_v2(x, 'shape', expand_shape) - - inputs = {"X": [x]} - attrs = {} check_variable_and_dtype( x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'expand') check_type(shape, 'shape', (list, tuple, Variable), 'expand') + + inputs = {"X": [x]} + attrs = {} + if convert_dtype(x.dtype) == 'bool' and x.stop_gradient == False: raise ValueError("When the data type of input 'x' for expand is bool, " "you must set its stop_gradient to be False by " - " some_var.stop_gradient = False, supporting " + "some_var.stop_gradient = True, supporting " "some_var as the input.") helper = LayerHelper('expand', input=x, **locals()) @@ -966,7 +1028,7 @@ def get_attr_expand_shape(list_expand_shape): else: attrs_expand_shape.append(shape) assert shape > 0 or shape == -1, ( - "Every element in shape must be positive or -1.") + "All elements in shape of expand must be positive or -1.") return attrs_expand_shape if isinstance(shape, Variable): @@ -983,3 +1045,6 @@ def get_attr_expand_shape(list_expand_shape): helper.append_op( type='expand_v2', inputs=inputs, outputs={'Out': out}, attrs=attrs) return out + + +broadcast_to = expand diff --git a/python/requirements.txt b/python/requirements.txt index a055ad92139ce..76e6b5dedbeaf 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -1,6 +1,6 @@ requests>=2.20.0 -numpy>=1.12, <=1.16.4 ; python_version<"3.5" -numpy>=1.12 ; python_version>="3.5" +numpy>=1.13, <=1.16.4 ; python_version<"3.5" +numpy>=1.13 ; python_version>="3.5" protobuf>=3.1.0 gast==0.3.3 matplotlib<=2.2.4 ; python_version<"3.6"
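Taken together, the manipulation changes above add ``chunk`` (a thin wrapper over ``fluid.layers.split``), a simplified ``tile``, a new ``expand_as`` built on the ``expand_as_v2`` op, and ``broadcast_to`` as an alias of ``expand``. A short usage sketch in the same dygraph style as the docstrings in this diff; the shapes given in the comments follow from the documented semantics rather than from a verified run.

.. code-block:: python

    import numpy as np
    import paddle

    paddle.disable_static()

    x = paddle.to_tensor(np.arange(12).reshape(3, 4).astype('int32'))

    # chunk: split x along axis 1 into two sub-Tensors of shape [3, 2] each.
    part0, part1 = paddle.chunk(x, chunks=2, axis=1)

    # tile: repeat twice along dim 0 and once along dim 1 -> shape [6, 4].
    tiled = paddle.tile(x, repeat_times=[2, 1])

    # expand / broadcast_to: dimensions of size 1 are broadcast to the target shape.
    row = paddle.to_tensor(np.array([[1, 2, 3]]).astype('int32'))  # shape [1, 3]
    expanded = paddle.expand(row, shape=[4, 3])
    broadcast = paddle.broadcast_to(row, shape=[4, 3])  # alias of expand

    # expand_as: take the target shape from another Tensor instead of a list.
    target = paddle.to_tensor(np.zeros((4, 3)).astype('int32'))
    expanded_as = paddle.expand_as(row, target)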