Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions caffe2/ideep/operators/operator_fallback_ideep.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,14 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
}
} else {
VLOG(1) << "Input " << i << " is not ideep::tensor. Skipping copy.";
// Note(jiayq): This removes a const but conceptually
// local_input_blobs will only be used as const blob input for the
// base op so we are still fine.
local_input_blobs_[i]->ShareExternal(
const_cast<void *>(OperatorBase::Inputs()[i]->GetRaw()),
OperatorBase::Inputs()[i]->meta());
if (OperatorBase::Inputs()[i]->GetRaw() != local_input_blobs_[i]->GetRaw()) {
// Note(jiayq): This removes a const but conceptually
// local_input_blobs will only be used as const blob input for the
// base op so we are still fine.
local_input_blobs_[i]->ShareExternal(
const_cast<void *>(OperatorBase::Inputs()[i]->GetRaw()),
OperatorBase::Inputs()[i]->meta());
}
input_share_[i] = true;
}
}
Expand Down Expand Up @@ -150,8 +152,13 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
} else {
VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor";
Blob* dst = OperatorBase::OutputBlob(i);
dst->Reset(new Tensor(CPU));
BlobSetTensor(dst, src.Alias());
if (output_inplace_[i]) {
auto dtensor = BlobGetMutableTensor(dst, CPU);
dtensor->CopyFrom(src);
} else {
dst->Reset(new Tensor(CPU));
BlobSetTensor(dst, src.Alias());
}
}
}
return true;
Expand Down