48 changes: 28 additions & 20 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -15,9 +15,6 @@ DEFINE_DISPATCH(div_stub);
 
 Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
   if (other.is_sparse()) {
-    if (!result.defined()) {
-      result = at::empty({0}, self.options());
-    }
     if (self.is_sparse()) {
       at::_sparse_add_out(result, self, other, alpha);
     } else {
@@ -29,13 +26,18 @@ Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
   }
   auto iter = TensorIterator::binary_op(result, self, other);
   add_stub(iter->device_type(), *iter, alpha);
-  result = iter->output();
   return result;
 }
 
 Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) {
   Tensor result;
-  return native::add_out(result, self, other, alpha);
+  if (other.is_sparse()) {
+    result = at::empty({0}, self.options());
+    return native::add_out(result, self, other, alpha);
+  }
+  auto iter = TensorIterator::binary_op(result, self, other);
+  add_stub(iter->device_type(), *iter, alpha);
+  return iter->output();
 }
 
 Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
@@ -44,9 +46,6 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) {
 
 Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
   if (self.is_sparse()) {
-    if (!result.defined()) {
-      result = at::empty({0}, self.options());
-    }
     if (other.dim() != 0) {
       AT_ERROR("div(): sparse division only supports division by a scalar ",
                "(got shape ", other.sizes(), " for argument 'other')");
@@ -55,13 +54,18 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
   }
   auto iter = TensorIterator::binary_op(result, self, other);
   div_stub(iter->device_type(), *iter);
-  result = iter->output();
   return result;
 }
 
 Tensor div(const Tensor& self, const Tensor& other) {
   Tensor result;
-  return native::div_out(result, self, other);
+  if (self.is_sparse()) {
+    result = at::empty({0}, self.options());
+    return native::div_out(result, self, other);
+  }
+  auto iter = TensorIterator::binary_op(result, self, other);
+  div_stub(iter->device_type(), *iter);
+  return iter->output();
 }
 
 Tensor& div_(Tensor& self, const Tensor& other) {
@@ -70,20 +74,22 @@ Tensor& div_(Tensor& self, const Tensor& other) {
 
 Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
   if (self.is_sparse() || other.is_sparse()) {
-    if (!result.defined()) {
-      result = at::empty({0}, self.options());
-    }
     return at::_sparse_mul_out(result, self, other);
   }
   auto iter = TensorIterator::binary_op(result, self, other);
   mul_stub(iter->device_type(), *iter);
-  result = iter->output();
   return result;
 }
 
 Tensor mul(const Tensor& self, const Tensor& other) {
   Tensor result;
-  return native::mul_out(result, self, other);
+  if (self.is_sparse() || other.is_sparse()) {
+    result = at::empty({0}, self.options());
+    return native::mul_out(result, self, other);
+  }
+  auto iter = TensorIterator::binary_op(result, self, other);
+  mul_stub(iter->device_type(), *iter);
+  return iter->output();
 }
 
 Tensor& mul_(Tensor& self, const Tensor& other) {
@@ -92,9 +98,6 @@ Tensor& mul_(Tensor& self, const Tensor& other) {
 
 Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
   if (other.is_sparse()) {
-    if (!result.defined()) {
-      result = at::empty({0}, self.options());
-    }
     if (!self.sizes().equals(other.sizes())) {
       AT_ERROR("sizes do not match");
     }
@@ -109,13 +112,18 @@ Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
   }
   auto iter = TensorIterator::binary_op(result, self, other);
   sub_stub(iter->device_type(), *iter, alpha);
-  result = iter->output();
   return result;
 }
 
 Tensor sub(const Tensor& self, const Tensor& other, Scalar alpha) {
   Tensor result;
-  return native::sub_out(result, self, other, alpha);
+  if (other.is_sparse()) {
+    result = at::empty({0}, self.options());
+    return native::sub_out(result, self, other, alpha);
+  }
+  auto iter = TensorIterator::binary_op(result, self, other);
+  sub_stub(iter->device_type(), *iter, alpha);
+  return iter->output();
 }
 
 Tensor& sub_(Tensor& self, const Tensor& other, Scalar alpha) {
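Reader's note (not part of the diff): all four ops follow the same pattern. The `_out` variants no longer handle an undefined `result` (the `!result.defined()` allocation and the trailing `result = iter->output();` are removed), while each functional variant gets its own TensorIterator path and only pre-allocates an empty tensor on the sparse branch before delegating to `native::*_out`. Below is a minimal sketch of how the two entry points divide responsibility after this change, assuming the ATen C++ API of this era (`at::ones`, `at::empty`, and the out-parameter-first `at::add_out`); it is an illustration, not code from the PR.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor a = at::ones({2, 2});
  at::Tensor b = at::ones({2, 2});

  // Functional variant: builds its own TensorIterator internally and
  // returns the iterator's freshly allocated output; no out tensor needed.
  at::Tensor c = at::add(a, b, /*alpha=*/1);

  // Out variant: after this change it assumes `result` is already a
  // defined tensor. Callers pre-allocate it (as the sparse branch of the
  // functional variants now does), e.g. an empty tensor the op may resize.
  at::Tensor result = at::empty({0}, a.options());
  at::add_out(result, a, b, /*alpha=*/1);

  std::cout << c << "\n" << result << "\n";
  return 0;
}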