scripts/gen.py (10 changes: 2 additions & 8 deletions)
@@ -80,7 +80,7 @@ class ArgTemplate(string.Template):
_XPARSER = lark.Lark(
_GRAMMAR, parser='lalr', propagate_positions=True, keep_all_tokens=True)

-# _FN_WHITELIST/_FN_FULL_OVERRIDE/_FN_BLACKLIST takes either name or mapsig.
+# _FN_FULL_OVERRIDE/_FN_BLACKLIST takes either name or mapsig.
_FN_BLACKLIST = set([])

# List of non-leaf ops we want to override both forward + backward.
@@ -90,12 +90,6 @@ class ArgTemplate(string.Template):
'max_pool3d(Tensor, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool) -> Tensor',
])

-# List of non-leaf ops we want to override forward.
-# TODO(#1362, #1364)
-_FN_WHITELIST = _FN_FULL_OVERRIDE | set([
-    'copy_(Tensor, Tensor, bool) -> Tensor',
-])
-
_FN_BLACKLIST_REGEX = [
# ATEN functions
r'[^(]*cudnn',
@@ -1033,7 +1027,7 @@ def generate_registrations(fgens, overrides):

# XLA is only able to override leaf ops and whitelisted non-leaf ops.
def is_overrideable(fgen):
-  return fgen.leaf or fgen.mapsig in _FN_WHITELIST or fgen.func in _FN_WHITELIST
+  return fgen.leaf or fgen.mapsig in _FN_FULL_OVERRIDE or fgen.func in _FN_FULL_OVERRIDE


def generate_functions(fgens):
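For readers skimming the codegen change: after this diff, whether a non-leaf op gets an XLA override is driven solely by _FN_FULL_OVERRIDE, matched either by mapsig or by full signature. Below is a minimal standalone sketch of the post-diff predicate (illustrative only; in gen.py the fgen records come from parsing the ATen declarations, and FuncGen here is a hypothetical stand-in):

    from collections import namedtuple

    # Hypothetical stand-in for gen.py's parsed function record.
    FuncGen = namedtuple('FuncGen', 'leaf mapsig func')

    _FN_FULL_OVERRIDE = {
        'max_pool3d(Tensor, IntArrayRef, IntArrayRef, IntArrayRef, IntArrayRef, bool) -> Tensor',
    }

    def is_overrideable(fgen):
        # Leaf ops are always overrideable; a non-leaf op must appear in
        # _FN_FULL_OVERRIDE, either by mapsig or by full signature.
        return (fgen.leaf or fgen.mapsig in _FN_FULL_OVERRIDE or
                fgen.func in _FN_FULL_OVERRIDE)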
test/cpp/test_aten_xla_tensor.cpp (27 changes: 22 additions & 5 deletions)
@@ -48,7 +48,7 @@ TEST_F(AtenXlaTensorTest, TestTo) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestIsFloatingPoint) {
@@ -2571,7 +2571,7 @@ TEST_F(AtenXlaTensorTest, TestZerosLikeOptions) {

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestZeros) {
@@ -2626,7 +2626,7 @@ TEST_F(AtenXlaTensorTest, TestOnesLikeOptions) {

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestFull) {
@@ -2670,7 +2670,7 @@ TEST_F(AtenXlaTensorTest, TestFullLikeOptions) {

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestARange) {
@@ -3075,6 +3075,23 @@ TEST_F(AtenXlaTensorTest, TestEinsumOuter) {
ExpectCounterChanged("xla::view", cpp_test::GetIgnoredCounters());
}

+TEST_F(AtenXlaTensorTest, TestEinsumOuterBackward) {
+  torch::Tensor a =
+      torch::rand({5}, torch::TensorOptions(torch::kFloat).requires_grad(true));
+  torch::Tensor b =
+      torch::rand({5}, torch::TensorOptions(torch::kFloat).requires_grad(true));
+  std::string equation = "i,j->ij";
+  auto testfn = [&](const std::vector<torch::Tensor>& inputs) -> torch::Tensor {
+    return torch::einsum(equation, inputs);
+  };
+  ForEachDevice([&](const torch::Device& device) {
+    TestBackward({a, b}, device, testfn, /*rtol=*/1e-3, /*atol=*/1e-4);
+  });
+
+  ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
+  ExpectCounterChanged("xla::view", cpp_test::GetIgnoredCounters());
+}
+
TEST_F(AtenXlaTensorTest, TestEinsumBatchMatMul) {
torch::Tensor a = torch::rand({3, 2, 5}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = torch::rand({3, 5, 4}, torch::TensorOptions(torch::kFloat));
@@ -7100,7 +7117,7 @@ TEST_F(AtenXlaTensorTest, TestContiguous) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::_copy_from", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSqueezeAll) {
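The tests above all flip the expected counter from xla::copy_ to xla::_copy_from: with the copy_ override gone, every device copy is expected to surface through _copy_from instead. A quick way to observe the same counters from Python, assuming a torch_xla build with this change and the torch_xla.debug.metrics counter API:

    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    device = xm.xla_device()
    t = torch.randn(2, 2).to(device)  # host -> device transfer

    # The transfer should bump the same counter the C++ tests assert on.
    print(met.counter_value('xla::_copy_from'))
    print(met.metrics_report())  # full counter/metric dump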
torch_xla/csrc/aten_xla_type.cpp (39 changes: 16 additions & 23 deletions)
@@ -223,7 +223,22 @@ at::Tensor AtenXlaType::_adaptive_avg_pool2d_backward(
at::Tensor AtenXlaType::_copy_from(const at::Tensor& self,
const at::Tensor& dst, bool non_blocking) {
XLA_FN_COUNTER("xla::");
-  copy_(const_cast<at::Tensor&>(dst), self, non_blocking);
+  auto dst_tensor = bridge::TryGetXlaTensor(dst);
+  auto self_tensor = bridge::TryGetXlaTensor(self);
+  if (!self_tensor) {
+    static bool sync_update =
+        xla::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true);
+    XLA_CHECK(dst_tensor);
+    dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update);
+  } else if (!dst_tensor) {
+    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
+    at::Tensor typed_tensor =
+        CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
+    dst.resize_as_(typed_tensor).copy_(typed_tensor);
+  } else {
+    XLATensor::copy_(*dst_tensor, *self_tensor);
+    bridge::ReplaceXlaTensor(dst, *dst_tensor);
+  }
return dst;
}

Expand Down Expand Up @@ -907,28 +922,6 @@ AtenXlaType::convolution_backward_overrideable(
: at::Tensor());
}

-at::Tensor& AtenXlaType::copy_(at::Tensor& self, const at::Tensor& src,
-                               bool non_blocking) {
-  XLA_FN_COUNTER("xla::");
-  auto self_tensor = bridge::TryGetXlaTensor(self);
-  auto src_tensor = bridge::TryGetXlaTensor(src);
-  if (!src_tensor) {
-    static bool sync_update =
-        xla::sys_util::GetEnvBool("XLA_TENSOR_UPDATE_SYNC", true);
-    XLA_CHECK(self_tensor);
-    self_tensor->UpdateFromTensor(src, /*sync=*/sync_update);
-  } else if (!self_tensor) {
-    at::Tensor tensor = src_tensor->ToTensor(/*detached=*/true);
-    at::Tensor typed_tensor =
-        CopyTensor(tensor, self.scalar_type(), /*copy=*/false);
-    self.resize_as_(typed_tensor).copy_(typed_tensor);
-  } else {
-    XLATensor::copy_(*self_tensor, *src_tensor);
-    bridge::ReplaceXlaTensor(self, *self_tensor);
-  }
-  return self;
-}
-
at::Tensor AtenXlaType::cos(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::cos(bridge::GetXlaTensor(self)));
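The copy logic itself is unchanged; it moves from copy_ into _copy_from, with the roles renamed: in _copy_from(self, dst), self is the source and dst the destination. The three branches therefore cover a host source written into an XLA destination (UpdateFromTensor, synchronous unless XLA_TENSOR_UPDATE_SYNC is set to 0), an XLA source read back into a host destination (materialize via ToTensor, then resize_as_().copy_()), and an XLA-to-XLA copy (a lazy XLATensor::copy_). A small Python driver that exercises each branch, assuming an XLA device is available:

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    host = torch.randn(4)
    xla_a = torch.randn(4, device=device)
    xla_b = torch.randn(4, device=device)

    xla_a.copy_(host)   # host source, XLA dest -> UpdateFromTensor branch
    host.copy_(xla_a)   # XLA source, host dest -> ToTensor + copy_ branch
    xla_b.copy_(xla_a)  # both on XLA           -> lazy XLATensor::copy_ branch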
torch_xla/csrc/aten_xla_type.h (3 changes: 0 additions & 3 deletions)
@@ -261,9 +261,6 @@ class AtenXlaType {
at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed,
at::IntArrayRef output_padding, int64_t groups);

-  static at::Tensor& copy_(at::Tensor& self, const at::Tensor& src,
-                           bool non_blocking);
-
static at::Tensor cos(const at::Tensor& self);

static at::Tensor& cos_(at::Tensor& self);