Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[wip][pt1][tensor] BlobGetMutableTensor returns Tensor #14136

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions c10/util/typeid.cpp
Expand Up @@ -71,7 +71,7 @@ CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(25, detail::_guard_long_unique<long>);
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(
26,
detail::_guard_long_unique<std::vector<long>>);

CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, _CaffeHighestPreallocatedTypeId)
// 27 is TensorImplPtr see tensor.cc
CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)

This comment was marked as off-topic.


} // namespace caffe2
4 changes: 2 additions & 2 deletions c10/util/typeid.h
Expand Up @@ -605,6 +605,6 @@ CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(25, detail::_guard_long_unique<long>)
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(
26,
detail::_guard_long_unique<std::vector<long>>)

CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, _CaffeHighestPreallocatedTypeId)
// 27 is TensorImplPtr see tensor.h
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(28, _CaffeHighestPreallocatedTypeId)
} // namespace caffe2
46 changes: 39 additions & 7 deletions caffe2/core/blob.h
Expand Up @@ -10,11 +10,14 @@

#include <ATen/core/blob.h>
#include <c10/util/typeid.h>
#include <ATen/core/intrusive_ptr.h>
#include "caffe2/core/logging.h"
#include "caffe2/core/tensor.h"

namespace caffe2 {

using TensorImplPtr = c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>;

inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
bool is_match = blob.meta().Match<Tensor>();
if (!is_match) {
Expand All @@ -24,10 +27,40 @@ inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
return tensor && *tensor && tensor->GetDeviceType() == device_type;
}

inline bool XBlobIsTensorType(const Blob& blob, DeviceType device_type) {
jerryzh168 marked this conversation as resolved.
Show resolved Hide resolved
if (!blob.meta().Match<TensorImplPtr>()) {
return false;
}
const auto& tensor_impl_ptr = blob.Get<TensorImplPtr>();
return tensor_impl_ptr && tensor_impl_ptr->device_type() == device_type;
}

// Replaces the contents of |blob| with a copy of |tensor|; the blob takes
// ownership of the stored copy, and a pointer to it is returned.
inline Tensor* BlobSetTensor(Blob* blob, const Tensor& tensor) {
  Tensor* stored = new Tensor(tensor);
  return blob->Reset<Tensor>(stored);
}

// Returns a Tensor handle backed by the TensorImplPtr stored in |blob|,
// shaped to |dims|. If the blob does not hold a TensorImplPtr, or the held
// tensor's dtype/device disagree with |options|, the blob is reset with a
// freshly allocated tensor instead of reusing the old storage.
inline Tensor
XBlobGetMutableTensor(Blob* blob, at::IntList dims, at::TensorOptions options) {
  auto* stored = blob->GetMutableOrNull<TensorImplPtr>();
  // Reuse only when the blob already holds a TensorImplPtr whose dtype and
  // device both match the requested options.
  const bool reusable = stored != nullptr &&
      (*stored)->dtype() == options.dtype() &&
      (*stored)->GetDevice() == options.device();
  if (!reusable) {
    VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<TensorImplPtr>()
            << " dims: " << dims << " options: " << options;
    auto* fresh = blob->Reset<TensorImplPtr>(
        new TensorImplPtr(caffe2::empty(dims, options).getTensorImpl()));
    return Tensor(*fresh);
  }
  if ((*stored)->sizes() != dims) {
    // Shape mismatch only: keep the storage object, just resize it.
    (*stored)->Resize(dims);
  }
  // Materialize the data with the tensor's current dtype before handing out.
  (*stored)->raw_mutable_data((*stored)->dtype());
  return Tensor(*stored);
}

// need to keep both for clangr codemod
inline Tensor*
BlobGetMutableTensor(Blob* blob, at::IntList dims, at::TensorOptions options) {
if (blob->IsType<Tensor>()) {
Expand All @@ -40,9 +73,8 @@ BlobGetMutableTensor(Blob* blob, at::IntList dims, at::TensorOptions options) {
}
if (tensor->dtype() == options.dtype()) {
tensor->raw_mutable_data();
} else {
// create a new Tensor when the data_type doesn't match
return BlobSetTensor(blob, caffe2::empty(dims, options));
} else { // create a new Tensor when the data_type doesn't match
return blob->Reset<Tensor>(new Tensor(caffe2::empty(dims, options)));
}
return tensor;
}
Expand All @@ -53,23 +85,23 @@ BlobGetMutableTensor(Blob* blob, at::IntList dims, at::TensorOptions options) {
VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
<< " dims: " << dims;
// << " options: " << options; (operator<< for Options is in at:: now)
return BlobSetTensor(blob, caffe2::empty(dims, options));
// TODO: Blob store Tensor directly?
return blob->Reset<Tensor>(new Tensor(caffe2::empty(dims, options)));
}


// Returns a mutable Tensor* of |device_type| stored in |blob|. Reuses the
// blob's existing Tensor when it is defined and already on that device;
// otherwise resets the blob with a fresh Tensor(device_type).
// NOTE(review): the scraped diff contained both the removed line
// (BlobSetTensor) and its replacement (blob->Reset) as consecutive returns,
// leaving the second unreachable; only the PR's new line is kept.
inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
  if (blob->IsType<Tensor>()) {
    Tensor* tensor = blob->GetMutable<Tensor>();
    if (*tensor && tensor->GetDeviceType() == device_type) {
      return tensor;
    }
  }

  // if we're here, then either Blob didn't hold a Tensor
  // or that Tensor had the wrong DeviceType.
  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
          << " DeviceType:" << device_type;

  return blob->Reset<Tensor>(new Tensor(device_type));
}

} // namespace caffe2
Expand Down
19 changes: 18 additions & 1 deletion caffe2/core/operator.h
Expand Up @@ -127,6 +127,14 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
return BlobGetMutableTensor(outputs_.at(idx), type);
}

// Returns output |idx| as a Tensor handle shaped |dims| with |options|.
// The caller must supply an explicit device in |options|.
inline Tensor
XOutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
  CAFFE_ENFORCE_WITH_CALLER(
      options.device_opt() != c10::nullopt,
      "device must be provided in option.");
  auto* output_blob = outputs_.at(idx);
  return XBlobGetMutableTensor(output_blob, dims, options);
}

inline Tensor*
OutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
CAFFE_ENFORCE_WITH_CALLER(
Expand Down Expand Up @@ -495,14 +503,23 @@ class Operator : public OperatorBase {
return OperatorBase::template Input<Tensor>(idx, type);
}

// Returns output |idx| as a Tensor handle. When |options| carries no device,
// this operator's context device is filled in before delegating to
// OperatorBase::XOutputTensor (which enforces that a device is present).
// NOTE(review): the scraped diff left the removed old `Output` signature line
// and two review-UI artifact lines inside this method; both are dropped here.
Tensor XOutput(int idx, at::IntList dims, at::TensorOptions options) {
  if (options.device_opt() == c10::nullopt) {
    return OperatorBase::XOutputTensor(
        idx, dims, at::TensorOptions(options).device(context_.device()));
  }
  return OperatorBase::XOutputTensor(idx, dims, options);
}

// Returns output |idx| as a Tensor*; defaults the device to this operator's
// context device when |options| leaves it unset.
Tensor* Output(int idx, at::IntList dims, at::TensorOptions options) {
  const bool has_device = options.device_opt() != c10::nullopt;
  if (has_device) {
    return OperatorBase::OutputTensor(idx, dims, options);
  }
  return OperatorBase::OutputTensor(
      idx, dims, at::TensorOptions(options).device(context_.device()));
}


inline Tensor* Output(int idx, DeviceType type = Context::GetDeviceType()) {
return OperatorBase::template Output<Tensor>(idx, type);
}
Expand Down
2 changes: 2 additions & 0 deletions caffe2/core/tensor.cc
Expand Up @@ -5,6 +5,8 @@
namespace caffe2 {

CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(12, Tensor);
using TensorImplPtr = c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>;

This comment was marked as off-topic.

This comment was marked as off-topic.

CAFFE_DEFINE_PREALLOCATED_KNOWN_TYPE(27, TensorImplPtr);

TensorPrinter::TensorPrinter(
const std::string& tensor_name,
Expand Down
11 changes: 11 additions & 0 deletions caffe2/core/tensor.h
Expand Up @@ -32,10 +32,18 @@ class CAFFE2_API Tensor final {
return impl_.defined();
}

// Wraps an existing TensorImpl handle. Uses a member-init list with
// std::move so the by-value |ptr| is moved into impl_ instead of the
// original's default-construct-then-copy-assign, which performed an
// extra intrusive_ptr refcount bump/drop.
Tensor(TensorImplPtr ptr) : impl_(std::move(ptr)) {}
jerryzh168 marked this conversation as resolved.
Show resolved Hide resolved

// Returns the underlying TensorImpl as a non-owning raw pointer ("unsafe":
// the pointer must not outlive this Tensor or the last owner of impl_).
TensorImpl* unsafeGetTensorImpl() const {
return impl_.get();
}

// Returns the underlying TensorImpl handle, sharing ownership with this
// Tensor (copies the intrusive_ptr).
// NOTE(review): two GitHub review-UI artifact lines embedded inside this
// method body in the scraped source are removed so the method compiles.
TensorImplPtr getTensorImpl() {
  return impl_;
}

/**
* @brief Creates a tensor of the given device type.
*
Expand Down Expand Up @@ -450,6 +458,9 @@ CAFFE2_API void ReinitializeAndCopyFrom(

CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(12, Tensor)

using TensorImplPtr = c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>;
CAFFE_DECLARE_PREALLOCATED_KNOWN_TYPE(27, TensorImplPtr)

using TensorCPU = Tensor;

constexpr int k_limit_default_ = 1000;
Expand Down