fix some bug, test=develop (PaddlePaddle#36888)
wanghuancoder authored and piotrekobi committed Nov 3, 2021
1 parent cb9de59 commit 79203ec
Showing 4 changed files with 57 additions and 10 deletions.
7 changes: 4 additions & 3 deletions paddle/fluid/framework/new_executor/interpretercore.cc
@@ -241,13 +241,14 @@ void InterpreterCore::BuildInplace() {
     auto& outputs = instr.Outputs();
     for (auto& pair : in_to_outs) {
       auto iter = inputs.find(pair.first);
-      if (iter != inputs.end()) {
+      if (iter != inputs.end() && !iter->second.empty()) {
         if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
           auto iterout = outputs.find(pair.second);
-          if (iterout != outputs.end()) {
+          if (iterout != outputs.end() && !iterout->second.empty()) {
             auto invar = global_scope_->Var(iter->second[0]);
             auto outvar = global_scope_->Var(iterout->second[0]);
-            if (invar && outvar) {
+            if (invar && outvar && invar->IsType<LoDTensor>() &&
+                outvar->IsType<LoDTensor>()) {
               instr.AddInplace(invar, outvar);
               VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type()
                       << " " << global_scope_->GetNameById(iter->second[0])
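The guards above keep BuildInplace from indexing into an empty name list or pairing variables that are not LoDTensors. A minimal, self-contained sketch of the same checking pattern follows; the Var, DenseTensor, and CanPairInplace names are illustrative stand-ins, not Paddle's actual types.

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Hypothetical stand-ins for the framework types; not Paddle's real API.
    struct DenseTensor {};
    struct SelectedRows {};

    struct Var {
      // Exactly one of these is non-null; mimics Variable::IsType<T>().
      const DenseTensor* tensor = nullptr;
      const SelectedRows* rows = nullptr;
      bool IsDenseTensor() const { return tensor != nullptr; }
    };

    using NameMap = std::map<std::string, std::vector<Var*>>;

    // An in-place pair (in_name -> out_name) is only recorded when both names
    // resolve to a non-empty slot and both variables hold dense tensors,
    // mirroring the guards added in the hunk above.
    bool CanPairInplace(const NameMap& inputs, const NameMap& outputs,
                        const std::string& in_name, const std::string& out_name) {
      auto in_it = inputs.find(in_name);
      if (in_it == inputs.end() || in_it->second.empty()) return false;
      auto out_it = outputs.find(out_name);
      if (out_it == outputs.end() || out_it->second.empty()) return false;
      Var* in_var = in_it->second[0];
      Var* out_var = out_it->second[0];
      return in_var && out_var && in_var->IsDenseTensor() && out_var->IsDenseTensor();
    }

    int main() {
      DenseTensor t;
      SelectedRows r;
      Var dense{&t, nullptr};
      Var sparse{nullptr, &r};
      NameMap ins{{"X", {&dense}}, {"Empty", {}}};
      NameMap outs{{"Out", {&dense}}, {"Rows", {&sparse}}};
      std::cout << CanPairInplace(ins, outs, "X", "Out") << "\n";      // 1: safe to reuse
      std::cout << CanPairInplace(ins, outs, "Empty", "Out") << "\n";  // 0: empty input slot
      std::cout << CanPairInplace(ins, outs, "X", "Rows") << "\n";     // 0: output is not a dense tensor
    }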
23 changes: 18 additions & 5 deletions paddle/fluid/framework/new_executor/interpretercore_util.cc
@@ -142,8 +142,8 @@ void build_variable_scope(const framework::ProgramDesc& pdesc,
     if (nullptr == var_scope->FindVar(var_name)) {
       var_scope->AddVar(var_desc->Name(), var_desc);
     } else {
-      auto* var_desc = var_scope->VarDesc(var_name);
-      if (nullptr == var_desc) {
+      auto* var_desc_tmp = var_scope->VarDesc(var_name);
+      if (nullptr == var_desc_tmp) {
         VLOG(3) << "update var:" << var_name << " desc from nullptr into "
                 << var_desc;
         var_scope->VarMetaInfo(var_name).vardesc_ = var_desc;
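The rename in this hunk works around variable shadowing: the inner `auto* var_desc` hid the enclosing `var_desc`, so the logging and metadata update below the null check could silently read the wrong pointer. A tiny generic illustration of the hazard (the names here are not Paddle's):

    #include <iostream>

    int main() {
      const char* var_desc = "outer descriptor";   // the descriptor the loop is visiting
      {
        // Before the fix: this inner declaration shadows the outer var_desc, so any
        // use of the name inside the block (logging, assignment) refers to the
        // lookup result rather than the descriptor being registered.
        const char* var_desc = nullptr;            // pretend VarDesc(var_name) returned null
        if (var_desc == nullptr) {
          std::cout << "lookup missed; a null descriptor would have been logged/stored\n";
        }
      }
      // Renaming the inner variable (var_desc_tmp in the hunk above) keeps the two
      // values distinct, so the metadata update stores the outer descriptor.
      std::cout << var_desc << "\n";               // still "outer descriptor"
    }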
@@ -206,9 +206,22 @@ void apply_device_guard(const OperatorBase* op_base,
     VLOG(3) << "Switch into CPUPlace by device_guard.";
     expected_kernel_key->place_ = platform::CPUPlace();
   } else if (op_device.find("gpu") != std::string::npos &&
-             platform::is_gpu_place(place)) {
-    VLOG(3) << "Switch into " << place << " by device_guard.";
-    expected_kernel_key->place_ = place;
+             (platform::is_gpu_place(place) ||
+              platform::is_npu_place(place))) {
+    // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
+    // will be executed and a warning will be given at the same time.
+    if (op_base->SupportGPU()) {
+      expected_kernel_key->place_ = place;
+    } else if (op_base->SupportNPU()) {
+      expected_kernel_key->place_ = place;
+    } else {
+      expected_kernel_key->place_ = platform::CPUPlace();
+      LOG_FIRST_N(WARNING, 1)
+          << "Op(" << op_base->Type()
+          << ") has no CUDA implementation. It will be assigned to CPUPlace.";
+    }
+    VLOG(3) << "Switch into " << expected_kernel_key->place_
+            << " by device_guard.";
   } else {
     PADDLE_THROW(
         platform::errors::Fatal("Unsupported current place %s", op_device));
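The new branch implements a simple fallback policy: honour the place requested by device_guard only when the op actually provides a GPU or NPU kernel, otherwise run on CPU and warn once. A simplified, stand-alone sketch of that decision follows; the Device enum, OpCaps struct, and ResolvePlace function are illustrative, not Paddle's API.

    #include <iostream>
    #include <string>

    // Illustrative device model; Paddle tracks this with platform::Place instead.
    enum class Device { kCPU, kGPU, kNPU };

    struct OpCaps {
      std::string type;           // op type name, used in the warning message
      bool supports_gpu = false;
      bool supports_npu = false;
    };

    // Decide where an op should run given the place requested by device_guard:
    // honour the request only if the op ships a kernel for that device,
    // otherwise fall back to CPU and warn (Paddle uses LOG_FIRST_N(WARNING, 1)).
    Device ResolvePlace(const OpCaps& op, Device requested) {
      if (requested == Device::kGPU && op.supports_gpu) return Device::kGPU;
      if (requested == Device::kNPU && op.supports_npu) return Device::kNPU;
      if (requested != Device::kCPU) {
        std::cerr << "Op(" << op.type
                  << ") has no kernel for the requested device; assigning it to CPUPlace.\n";
      }
      return Device::kCPU;
    }

    int main() {
      const OpCaps cpu_only{"cpu_only_op", /*supports_gpu=*/false, /*supports_npu=*/false};
      const OpCaps gpu_capable{"gpu_op", /*supports_gpu=*/true, /*supports_npu=*/false};
      std::cout << static_cast<int>(ResolvePlace(gpu_capable, Device::kGPU)) << "\n";  // 1 -> stays on GPU
      std::cout << static_cast<int>(ResolvePlace(cpu_only, Device::kGPU)) << "\n";     // 0 -> falls back to CPU
    }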
9 changes: 9 additions & 0 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -474,6 +474,15 @@ struct VariableMetaInfo {
 // TODO(zhiqiu): Maybe we need to add rwlock for VariableScope?
 class VariableScope : public ScopeBase {
  public:
+  VariableScope() {
+    // for @EMPTY@ variable
+    var_list_.push_back(nullptr);
+    name2id_[kEmptyVarName] = 0;
+    VariableMetaInfo info;
+    info.var_ref_count_ = 0;
+    info.vardesc_ = nullptr;
+    vec_meta_info_.push_back(info);
+  }
   Variable* FindVar(const std::string& name) const {
     auto it = name2id_.find(name);
     if (it != name2id_.end()) {
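The added constructor reserves index 0 of the scope's variable table for the framework's empty-variable name, so a lookup of kEmptyVarName always resolves to a deliberate null entry with zeroed metadata. A compact sketch of that reserved-slot idea, using an illustrative Registry class rather than the real VariableScope:

    #include <cassert>
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Variable {};  // placeholder for the framework's Variable type

    // Illustrative registry, not the real VariableScope: names map to indices
    // into a variable table, and index 0 is reserved at construction for the
    // framework's "empty variable" name so lookups of it always resolve.
    class Registry {
     public:
      static constexpr const char* kEmptyVarName = "@EMPTY@";

      Registry() {
        vars_.push_back(nullptr);   // slot 0: the empty variable, deliberately null
        ids_[kEmptyVarName] = 0;
        ref_counts_.push_back(0);   // its metadata starts zeroed, like VariableMetaInfo
      }

      int AddVar(const std::string& name) {
        auto it = ids_.find(name);
        if (it != ids_.end()) return it->second;
        vars_.push_back(std::make_unique<Variable>());
        ref_counts_.push_back(0);
        const int id = static_cast<int>(vars_.size()) - 1;
        ids_[name] = id;
        return id;
      }

      Variable* FindVar(const std::string& name) const {
        auto it = ids_.find(name);
        return it == ids_.end() ? nullptr : vars_[it->second].get();
      }

     private:
      std::vector<std::unique_ptr<Variable>> vars_;
      std::unordered_map<std::string, int> ids_;
      std::vector<int> ref_counts_;
    };

    int main() {
      Registry scope;
      // The empty name is always known but maps to a null variable.
      assert(scope.FindVar(Registry::kEmptyVarName) == nullptr);
      const int x_id = scope.AddVar("x");
      assert(x_id == 1);                      // real variables start at index 1
      assert(scope.FindVar("x") != nullptr);
      return 0;
    }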
28 changes: 26 additions & 2 deletions paddle/fluid/operators/controlflow/fetch_v2_op.cc
@@ -77,12 +77,35 @@ class FetchV2Op : public framework::OperatorWithKernel {
   framework::OpKernelType GetKernelTypeForVar(
       const std::string &var_name, const framework::Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
+    if (!tensor.IsInitialized()) {
+      return expected_kernel_type;
+    }
     return framework::OpKernelType(expected_kernel_type.data_type_,
                                    tensor.place(), tensor.layout());
   }
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
+    auto *fetch_var = ctx.InputVar("X");
+    if (fetch_var == nullptr) {
+      return framework::OpKernelType(framework::proto::VarType::FP32,
+                                     platform::CPUPlace());
+    }
+
+    if (fetch_var->IsType<framework::LoDTensor>()) {
+      auto &src_item = fetch_var->Get<framework::LoDTensor>();
+      if (!src_item.IsInitialized()) {
+        return framework::OpKernelType(framework::proto::VarType::FP32,
+                                       platform::CPUPlace());
+      }
+    } else {
+      auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
+      if (src_item.empty() || !src_item[0].IsInitialized()) {
+        return framework::OpKernelType(framework::proto::VarType::FP32,
+                                       platform::CPUPlace());
+      }
+    }
+
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"),
         platform::CPUPlace());
@@ -127,6 +150,9 @@ class FetchV2Kernel {
 
     if (fetch_var->IsType<framework::LoDTensor>()) {
       auto &src_item = fetch_var->Get<framework::LoDTensor>();
+      if (!src_item.IsInitialized()) {
+        return;
+      }
       auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col)));
       bool check_place = platform::is_cpu_place(src_item.place()) ||
                          platform::is_cuda_pinned_place(src_item.place());
@@ -173,9 +199,7 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker {
         .SetDefault(true);
     AddComment(R"DOC(
 FetchV2 Operator.
 It should not be configured by users directly.
 )DOC");
   }
 };
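Both fetch_v2_op.cc changes follow one pattern: probe the fetched variable first and fall back to a safe default (FP32 on CPU for kernel selection, or skipping the copy entirely) when it is missing, uninitialized, or an empty tensor array. A stand-alone sketch of that probe, with a toy MaybeTensor standing in for LoDTensor/LoDTensorArray:

    #include <iostream>
    #include <optional>
    #include <variant>
    #include <vector>

    // Toy stand-ins: a tensor is "initialized" when it has a dtype.
    enum class DType { kFP32, kFP64, kINT64 };
    struct MaybeTensor {
      std::optional<DType> dtype;  // empty => uninitialized, like Tensor::IsInitialized()
    };
    using TensorArray = std::vector<MaybeTensor>;
    using FetchInput = std::variant<MaybeTensor, TensorArray>;

    // Decide which dtype the fetch kernel should be selected for, mirroring
    // GetExpectedKernelType above: default to FP32 whenever the input is
    // missing, uninitialized, or an empty array.
    DType ExpectedFetchDType(const FetchInput* input) {
      if (input == nullptr) return DType::kFP32;
      if (const auto* t = std::get_if<MaybeTensor>(input)) {
        return t->dtype.value_or(DType::kFP32);
      }
      const auto& arr = std::get<TensorArray>(*input);
      if (arr.empty() || !arr[0].dtype) return DType::kFP32;
      return *arr[0].dtype;
    }

    int main() {
      FetchInput uninit = MaybeTensor{};             // fetch of an uninitialized tensor
      FetchInput ints = MaybeTensor{DType::kINT64};  // fetch of a real int64 tensor
      FetchInput empty_arr = TensorArray{};          // fetch of an empty tensor array
      std::cout << static_cast<int>(ExpectedFetchDType(&uninit)) << "\n";     // 0 -> FP32
      std::cout << static_cast<int>(ExpectedFetchDType(&ints)) << "\n";       // 2 -> INT64
      std::cout << static_cast<int>(ExpectedFetchDType(&empty_arr)) << "\n";  // 0 -> FP32
      std::cout << static_cast<int>(ExpectedFetchDType(nullptr)) << "\n";     // 0 -> FP32
    }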
