18 changes: 10 additions & 8 deletions torch_xla/csrc/tensor.cpp

@@ -156,8 +156,9 @@ XLATensor XLATensor::Create(
   return xtensor;
 }
 
-XLATensor XLATensor::Create(ir::Value ir_value, const Device& device,
-                            xla::PrimitiveType logical_element_type) {
+XLATensor XLATensor::Create(
+    ir::Value ir_value, const Device& device,
+    c10::optional<at::ScalarType> logical_element_type) {
   XLATensor xtensor(std::move(ir_value), device, logical_element_type);
   TensorsArena::Get()->RegisterTensor(xtensor.data_ptr());
   return xtensor;
@@ -182,7 +183,7 @@ XLATensor::XLATensor(std::shared_ptr<xla::ComputationClient::Data> xla_data,
 }
 
 XLATensor::XLATensor(ir::Value ir_value, const Device& device,
-                     xla::PrimitiveType logical_element_type)
+                     c10::optional<at::ScalarType> logical_element_type)
     : data_(std::make_shared<Data>(std::move(ir_value), device,
                                    logical_element_type)) {
   TryLimitGraphSize();
@@ -214,7 +215,9 @@ void XLATensor::SetGradient(const XLATensor& grad) {
 }
 
 at::ScalarType XLATensor::dtype() const {
-  return TensorTypeFromXlaType(GetElementType());
+  return data()->logical_element_type
+             ? *data()->logical_element_type
+             : TensorTypeFromXlaType(shape().get().element_type());
 }
 
 xla::util::MaybeRef<xla::Shape> XLATensor::shape() const {
@@ -805,15 +808,15 @@ XLATensor XLATensor::DispatchComparisonOp(c10::Symbol kind,
                                           const XLATensor& input,
                                           const at::Scalar& other) {
   ir::NodePtr node = ir::ops::ComparisonOp(kind, input.GetIrValue(), other);
-  return Create(node, input.GetDevice(), xla::PrimitiveType::U8);
+  return Create(node, input.GetDevice(), at::ScalarType::Byte);
 }
 
 XLATensor XLATensor::DispatchComparisonOp(c10::Symbol kind,
                                           const XLATensor& input,
                                           const XLATensor& other) {
   ir::NodePtr node =
       ir::ops::ComparisonOp(kind, input.GetIrValue(), other.GetIrValue());
-  return Create(node, input.GetDevice(), xla::PrimitiveType::U8);
+  return Create(node, input.GetDevice(), at::ScalarType::Byte);
 }
 
 XLATensor XLATensor::threshold(const XLATensor& input, float threshold,
@@ -1147,8 +1150,7 @@ XLATensor XLATensor::log_base(const XLATensor& input, ir::OpKind op,
                               input.GetDevice());
 }
 
-void XLATensor::log_base_(XLATensor& input, ir::OpKind op,
-                          double base) {
+void XLATensor::log_base_(XLATensor& input, ir::OpKind op, double base) {
   input.SetIrValue(ir::ops::LogBase(input.GetIrValue(), op, base));
 }
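For context on the tensor.cpp changes: dtype() now prefers an explicitly recorded logical scalar type and only falls back to deriving one from the physical XLA shape. Below is a minimal standalone sketch of that resolution pattern; it uses std::optional and stand-in enums instead of the real c10::optional, at::ScalarType, and xla::PrimitiveType, and it assumes (as the comparison ops here suggest) a case where the on-device type, e.g. a PRED/boolean buffer, differs from the logical dtype (Byte).

```cpp
#include <iostream>
#include <optional>

// Stand-ins for xla::PrimitiveType and at::ScalarType; only the resolution
// logic of XLATensor::dtype() is illustrated here.
enum class PhysicalType { Pred, F32 };
enum class LogicalType { Bool, Byte, Float };

// Stand-in for TensorTypeFromXlaType(): derive a logical type from the
// physical on-device type.
LogicalType FromPhysical(PhysicalType t) {
  return t == PhysicalType::Pred ? LogicalType::Bool : LogicalType::Float;
}

struct TensorSketch {
  PhysicalType physical;
  // Empty optional means "the logical type matches the physical type".
  std::optional<LogicalType> logical_element_type;

  LogicalType dtype() const {
    // Mirrors the new XLATensor::dtype(): honor the override when present,
    // otherwise fall back to the type of the underlying XLA data.
    return logical_element_type ? *logical_element_type
                                : FromPhysical(physical);
  }
};

int main() {
  // A comparison result: a boolean-like buffer on device, logically Byte.
  TensorSketch cmp{PhysicalType::Pred, LogicalType::Byte};
  // An ordinary float tensor: no override recorded.
  TensorSketch plain{PhysicalType::F32, std::nullopt};
  std::cout << (cmp.dtype() == LogicalType::Byte) << "\n";    // prints 1
  std::cout << (plain.dtype() == LogicalType::Float) << "\n"; // prints 1
}
```

This is also why the DispatchComparisonOp call sites switch from xla::PrimitiveType::U8 to at::ScalarType::Byte: the argument now describes the tensor's logical dtype rather than its on-device storage type.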
22 changes: 6 additions & 16 deletions torch_xla/csrc/tensor.h

@@ -61,13 +61,6 @@ class XLATensor {
   at::ScalarType dtype() const;
   xla::util::MaybeRef<xla::Shape> shape() const;
 
-  xla::PrimitiveType GetElementType() const {
-    return data()->logical_element_type ==
-                   xla::PrimitiveType::PRIMITIVE_TYPE_INVALID
-               ? shape().get().element_type()
-               : data()->logical_element_type;
-  }
-
   const Device& GetDevice() const;
   xla::int64 GetUniqueId() const;
 
@@ -597,7 +590,7 @@ class XLATensor {
           device(device),
           unique_id(GetNextTensorId()) {}
     Data(ir::Value ir_value, const Device& device,
-         xla::PrimitiveType logical_element_type)
+         c10::optional<at::ScalarType> logical_element_type)
         : ir_value(std::move(ir_value)),
           device(device),
          unique_id(GetNextTensorId()),
@@ -619,23 +612,20 @@ class XLATensor {
     xla::int64 unique_id = 0;
     std::shared_ptr<XLATensor> grad;
     bool requires_grad = false;
-    // For some types (U8, S8 etc), the logical type of the tensor doesn't match
-    // the type of the underlying data.
-    xla::PrimitiveType logical_element_type =
-        xla::PrimitiveType::PRIMITIVE_TYPE_INVALID;
+    c10::optional<at::ScalarType> logical_element_type;
   };
 
   XLATensor(const at::Tensor& tensor, const Device& device, bool requires_grad);
   XLATensor(std::shared_ptr<xla::ComputationClient::Data> xla_data,
             bool requires_grad);
   XLATensor(ir::Value ir_value, const Device& device,
-            xla::PrimitiveType logical_element_type);
+            c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
   XLATensor(std::shared_ptr<View> view, const Device& device);
   XLATensor(std::shared_ptr<Data> data);
 
-  static XLATensor Create(ir::Value ir_value, const Device& device,
-                          xla::PrimitiveType logical_element_type =
-                              xla::PrimitiveType::PRIMITIVE_TYPE_INVALID);
+  static XLATensor Create(
+      ir::Value ir_value, const Device& device,
+      c10::optional<at::ScalarType> logical_element_type = c10::nullopt);
   static XLATensor Create(std::shared_ptr<View> view, const Device& device);
 
   Data* data() const;
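The header change is a sentinel-to-optional migration: xla::PrimitiveType::PRIMITIVE_TYPE_INVALID previously did double duty as "no logical type recorded", and that convention leaked into the default arguments of Create() and into the Data struct. A hedged before/after sketch of the API shape, again with stand-in types rather than the real XLATensor:

```cpp
#include <optional>

enum class ScalarType { Byte, Float };  // stand-in for at::ScalarType

// Before (sketch): a magic enum value meant "not set", and every reader of
// logical_element_type had to remember to compare against it:
//   static Tensor Create(PrimitiveType logical_element_type =
//                            PrimitiveType::PRIMITIVE_TYPE_INVALID);

// After (sketch): absence is expressed in the type system.
struct ApiSketch {
  std::optional<ScalarType> logical_element_type;

  static ApiSketch Create(
      std::optional<ScalarType> logical_element_type = std::nullopt) {
    return ApiSketch{logical_element_type};
  }
};

int main() {
  ApiSketch plain = ApiSketch::Create();                // no override
  ApiSketch cmp = ApiSketch::Create(ScalarType::Byte);  // override recorded
  // has_value() replaces the sentinel comparison from the old GetElementType().
  return plain.logical_element_type.has_value() ? 1 : 0;  // exits 0
}
```

The XLATensor constructor taking ir::Value also gains a c10::nullopt default in this diff, so existing call sites that never set a logical type keep compiling unchanged.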
3 changes: 1 addition & 2 deletions torch_xla/csrc/tensor_impl.cpp

@@ -98,8 +98,7 @@ void XLATensorImpl::SetupSizeProperties() {
 }
 
 caffe2::TypeMeta XLATensorImpl::GetTypeMeta(const XLATensor& tensor) {
-  return c10::scalarTypeToTypeMeta(
-      TensorTypeFromXlaType(tensor.GetElementType()));
+  return c10::scalarTypeToTypeMeta(tensor.dtype());
 }
 
 c10::Storage XLATensorImpl::GetStorage(const XLATensor& tensor) {
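A consistency note on the three files taken together: dtype() is now the single place where the logical override is resolved, XLATensorImpl::GetTypeMeta simply delegates to it, and the GetElementType() helper removed from tensor.h is left with no callers anywhere in this diff.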