Skip to content

Commit

Permalink
Merge branch 'develop' into leaky-relu
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Aug 19, 2020
2 parents 62d4aad + db272ca commit 2390b43
Show file tree
Hide file tree
Showing 128 changed files with 4,933 additions and 1,322 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ cc_test(op_compatible_info_test SRCS op_compatible_info_test.cc DEPS op_compatib

cc_library(save_load_util SRCS save_load_util DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc)

# Get the current working branch
execute_process(
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ else()
cc_library(gloo_wrapper SRCS gloo_wrapper.cc DEPS framework_proto variable_helper scope)
endif(WITH_GLOO)

cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_context)
cc_library(heter_wrapper SRCS heter_wrapper.cc DEPS framework_proto device_context heter_service_proto)

cc_test(test_fleet SRCS test_fleet.cc DEPS fleet_wrapper gloo_wrapper fs shell)
78 changes: 78 additions & 0 deletions paddle/fluid/framework/generator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/generator.h"

namespace paddle {
namespace framework {

std::shared_ptr<Generator> Generator::gen_instance_ = NULL;

GeneratorState* Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_.get();
}

void Generator::SetState(GeneratorState* state_in) {
std::lock_guard<std::mutex> lock(this->mutex);
*this->state_ = *state_in;
}

uint64_t Generator::GetCurrentSeed() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->current_seed;
}

uint64_t Generator::Seed() {
std::lock_guard<std::mutex> lock(this->mutex);
uint64_t seed;
std::random_device de;
seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF;
this->state_->current_seed = seed;
std::seed_seq seq({seed});
this->state_->cpu_engine.seed(seq);

return this->state_->current_seed;
}

void Generator::SetCurrentSeed(uint64_t seed) {
std::lock_guard<std::mutex> lock(this->mutex);
this->state_->current_seed = uint64_t(seed);
std::seed_seq seq({seed});
this->state_->cpu_engine.seed(seq);
}

std::mt19937_64& Generator::GetCPUEngine() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->cpu_engine;
}

void Generator::SetCPUEngine(std::mt19937_64 engine) {
std::lock_guard<std::mutex> lock(this->mutex);
this->state_->cpu_engine = std::mt19937_64(engine);
}

uint64_t Generator::Random64() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->cpu_engine();
}

} // namespace framework
} // namespace paddle
96 changes: 96 additions & 0 deletions paddle/fluid/framework/generator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <stdint.h>
#include <atomic>
#include <deque>
#include <iostream> // temp for debug
#include <memory>
#include <mutex> // NOLINT
#include <random>
#include <typeinfo>
#include <utility>

namespace paddle {
namespace framework {

struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
std::mt19937_64 cpu_engine;
};

struct Generator {
Generator() {
GeneratorState default_gen_state_cpu;
default_gen_state_cpu.device = -1;
default_gen_state_cpu.current_seed = 34342423252;
std::seed_seq seq({34342423252});
default_gen_state_cpu.cpu_engine = std::mt19937_64(seq);
this->state_ = std::make_shared<GeneratorState>(default_gen_state_cpu);
}
explicit Generator(GeneratorState state_in)
: state_{std::make_shared<GeneratorState>(state_in)} {}
Generator(const Generator& other)
: Generator(other, std::lock_guard<std::mutex>(other.mutex)) {}

// get random state
GeneratorState* GetState();
// set random state
void SetState(GeneratorState* state_in);
// get current seed
uint64_t GetCurrentSeed();
// random a seed and get
uint64_t Seed();

// set seed
void SetCurrentSeed(uint64_t seed);
// get cpu engine
std::mt19937_64& GetCPUEngine();
// set cpu engine
void SetCPUEngine(std::mt19937_64 engine);

uint64_t Random64();

bool is_init_py = false;

// CPU Generator singleton
static std::shared_ptr<Generator> GetInstance() {
if (NULL == gen_instance_) {
gen_instance_.reset(new paddle::framework::Generator());
}
return gen_instance_;
}

static std::shared_ptr<Generator> GetInstanceX() {
if (NULL == gen_instance_) {
gen_instance_.reset(new paddle::framework::Generator());
}
gen_instance_->is_init_py = true;
return gen_instance_;
}

private:
static std::shared_ptr<Generator> gen_instance_;
std::shared_ptr<GeneratorState> state_;
mutable std::mutex mutex;

Generator(const Generator& other, const std::lock_guard<std::mutex>&)
: state_(std::make_shared<GeneratorState>(*(other.state_))) {}
};

} // namespace framework
} // namespace paddle
4 changes: 4 additions & 0 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,3 +368,7 @@ REGISTER_PASS(conv_transpose_bn_fuse_pass,
paddle::framework::ir::ConvTransposeBNFusePass);
REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
REGISTER_PASS(depthwise_conv_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);
10 changes: 10 additions & 0 deletions paddle/fluid/framework/ir/conv_bn_fuse_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,16 @@ class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
std::string conv_type() const { return "conv2d_transpose"; }
};

class DepthwiseConvBNFusePass : public ConvBNFusePass {
public:
std::string conv_type() const { return "depthwise_conv2d"; }
};

class DepthwiseConvEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
public:
std::string conv_type() const { return "depthwise_conv2d"; }
};

} // namespace ir
} // namespace framework
} // namespace paddle
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/subgraph_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubGraphs() {
BriefNode *brief_node = itr.second;

if (!Agent(brief_node->node).marked()) {
VLOG(4) << brief_node->node->id() << " node not a trt candidate.";
VLOG(4) << brief_node->node->id() << " node named "
<< brief_node->node->Name() << " is not a trt candidate.";
continue;
}

Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/framework/prune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,23 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
should_run.push_back(true);
} else {
should_run.push_back(false);
// If the output of an op modifies feed vars, the op should not clip.
// For example, in the transformer structure, the third parameter returned
// by beam_search op is generally assigned to a feed var. Cutting the
// assign op will cause an error.
if (parent_block_id != -1) {
bool flag = false;
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if (feed_var_names.count(argu)) {
flag = true;
}
}
}
if (flag) {
should_run.back() = true;
}
}
}
}

Expand Down
31 changes: 31 additions & 0 deletions paddle/fluid/framework/prune_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,34 @@ TEST(Prune, recurrrent_op) {
EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
}

// If the output of an op modifies feed vars, the op should not clip.
TEST(Prune, recurrrent_op_2) {
f::ProgramDesc program;
f::BlockDesc *block = program.MutableBlock(0);
f::BlockDesc *sub_block = program.AppendBlock(*block);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
f::AttributeMap{}, block);

std::vector<std::string> state_var_name(1, "y");
AddOp("recurrent", {{"input", {"b", "c"}}}, {{"output", {"b1, c1"}}},
{{"ex_states", state_var_name},
{"states", state_var_name},
{"sub_block", sub_block}},
block);

EXPECT_TRUE(sub_block != nullptr);
AddOp("rnn_memory_helper", {{"input", {"x"}}}, {{"output", {"a"}}},
f::AttributeMap{}, sub_block);

f::proto::ProgramDesc *pdesc = program.Proto();
pdesc->mutable_blocks(0)->mutable_ops(1)->set_is_target(true);

f::proto::ProgramDesc pruned;
std::set<std::string> feed_var_names = {"x", "a"};

f::Prune(*pdesc, feed_var_names, &pruned);
EXPECT_EQ(pruned.blocks_size(), 2);
EXPECT_EQ(pruned.blocks(0).ops_size(), 2);
EXPECT_EQ(pruned.blocks(1).ops_size(), 1);
}
35 changes: 18 additions & 17 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,23 +42,17 @@ static void PrepareData(const platform::Place& place,
for (const auto& var_base : name_pair.second) {
const auto* tensor = GetTensorFromVar(var_base->Var());
if (tensor && tensor->IsInitialized()) {
auto tmp_place = tensor->place();

// TODO(jiabin): Support transform data layout when we Verify it on more
// tests
if (!(tmp_place == place)) {
auto kernel_type_for_var = op.GetKernelTypeForVar(
name_pair.first, *tensor, expected_kernel_key);
if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
continue;
} else {
VLOG(3) << "Transform Variable " << var_base->Name() << " from "
<< kernel_type_for_var << " to " << expected_kernel_key;
framework::Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
&out);
SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
}
auto kernel_type_for_var = op.GetKernelTypeForVar(
name_pair.first, *tensor, expected_kernel_key);
if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
continue;
} else {
VLOG(3) << "Transform Variable " << var_base->Name() << " from "
<< kernel_type_for_var << " to " << expected_kernel_key;
framework::Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor,
&out);
SetTensorToVariable(var_base->Var(), out, var_base->MutableVar());
}
}
}
Expand Down Expand Up @@ -93,6 +87,13 @@ PreparedOp PrepareOpImpl(const NameVarMap<VarType>& ins,
auto& kernels = kernels_iter->second;

framework::RuntimeContext ctx({}, {});
#ifdef PADDLE_WITH_MKLDNN
// MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
// GetKernelType functions, so we need to copy the attributes there.
// Const qualifier of Attrs had to be discarded to overwrite it.
auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
mutable_op_attrs = attrs;
#endif
auto expected_kernel_key =
op.GetExpectedKernelType(DygraphExecutionContext<VarType>(
op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs));
Expand Down
16 changes: 14 additions & 2 deletions paddle/fluid/imperative/tests/test_prepare_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ TEST(test_prepare_op, test_prepare_data) {
}
#endif

TEST(test_prepare_op, test_prepare_data_same_place) {
void TestPrepareDataSamePlace(framework::AttributeMap attr_map) {
std::shared_ptr<imperative::VarBase> vin(
new imperative::VarBase(false, "vin"));
std::shared_ptr<imperative::VarBase> vout(
Expand All @@ -198,7 +198,6 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
var_pair out_pair = var_pair("Out", vb_vector(1, vout));
imperative::NameVarBaseMap ins = {x_pair};
imperative::NameVarBaseMap outs = {out_pair};
framework::AttributeMap attr_map;
const std::string op_type = "relu";
const auto& info = framework::OpInfoMap::Instance().Get(op_type);
if (info.Checker()) info.Checker()->Check(&attr_map);
Expand All @@ -222,8 +221,21 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
}
}
}

TEST(test_prepare_op, test_prepare_data_same_place) {
TestPrepareDataSamePlace({});
}

#ifdef PADDLE_WITH_MKLDNN
TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) {
TestPrepareDataSamePlace({{"use_mkldnn", true}});
}
#endif
} // namespace imperative
} // namespace paddle

USE_OP(split);
USE_OP(relu);
#ifdef PADDLE_WITH_MKLDNN
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
#endif
7 changes: 6 additions & 1 deletion paddle/fluid/inference/tensorrt/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
} else if (shape.size() == 3UL) {
return nvinfer1::Dims3(shape[0], shape[1], shape[2]);
}
return nvinfer1::Dims4(shape[0], shape[1], 1, 1);
nvinfer1::Dims dims;
dims.nbDims = shape.size();
for (size_t i = 0; i < shape.size(); i++) {
dims.d[i] = shape[i];
}
return dims;
}
}
} // NOLINT
Expand Down
Loading

0 comments on commit 2390b43

Please sign in to comment.