Skip to content

Commit

Permalink
Fix fallback issues to handle inplace case (pytorch#15726)
Browse files Browse the repository at this point in the history
Summary:
Fix fallback issues to handle inplace case
Pull Request resolved: pytorch#15726

Differential Revision: D13591243

Pulled By: yinghai

fbshipit-source-id: 6897f1daacb36beabcdfc22c39242bbdfdd0e534
  • Loading branch information
gujinghui authored and facebook-github-bot committed Jan 11, 2019
1 parent 0934e8d commit 07ea3e0
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions caffe2/ideep/operators/operator_fallback_ideep.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,14 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
}
} else {
VLOG(1) << "Input " << i << " is not ideep::tensor. Skipping copy.";
// Note(jiayq): This removes a const but conceptually
// local_input_blobs will only be used as const blob input for the
// base op so we are still fine.
local_input_blobs_[i]->ShareExternal(
const_cast<void *>(OperatorBase::Inputs()[i]->GetRaw()),
OperatorBase::Inputs()[i]->meta());
if (OperatorBase::Inputs()[i]->GetRaw() != local_input_blobs_[i]->GetRaw()) {
// Note(jiayq): This removes a const but conceptually
// local_input_blobs will only be used as const blob input for the
// base op so we are still fine.
local_input_blobs_[i]->ShareExternal(
const_cast<void *>(OperatorBase::Inputs()[i]->GetRaw()),
OperatorBase::Inputs()[i]->meta());
}
input_share_[i] = true;
}
}
Expand Down Expand Up @@ -150,8 +152,13 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
} else {
VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor";
Blob* dst = OperatorBase::OutputBlob(i);
dst->Reset(new Tensor(CPU));
BlobSetTensor(dst, src.Alias());
if (output_inplace_[i]) {
auto dtensor = BlobGetMutableTensor(dst, CPU);
dtensor->CopyFrom(src);
} else {
dst->Reset(new Tensor(CPU));
BlobSetTensor(dst, src.Alias());
}
}
}
return true;
Expand Down

0 comments on commit 07ea3e0

Please sign in to comment.