Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions scripts/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,7 @@ class ArgTemplate(string.Template):

_FN_BLACKLIST = set([
'numel',
'ones',
'ones_like',
'result_type',
'zero_',
'zeros',
'zeros_like',
# FIXME: Remove functions below when we switch to override leaf nodes only.
# The function names below might map to multiple function overrloads.
# If the function overload is a leaf node, we must have it in AtenXlaType::
Expand All @@ -99,13 +94,19 @@ class ArgTemplate(string.Template):
'blackman_window',
'empty_like',
'eye',
'full',
'full_like',
'hamming_window',
'hann_window',
'narrow',
'ones',
'ones_like',
'randperm',
'reshape',
'size',
'to',
'zeros',
'zeros_like',
])

_FN_BLACKLIST_REGEX = [
Expand Down
31 changes: 20 additions & 11 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2388,7 +2388,7 @@ TEST_F(AtenXlaTensorTest, TestEmptyLike) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty_like", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEmptyLikeOptions) {
Expand All @@ -2402,7 +2402,7 @@ TEST_F(AtenXlaTensorTest, TestEmptyLikeOptions) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty_like", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestEmpty) {
Expand All @@ -2427,7 +2427,8 @@ TEST_F(AtenXlaTensorTest, TestZerosLike) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::zeros_like", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::zero_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestZerosLikeOptions) {
Expand All @@ -2441,7 +2442,8 @@ TEST_F(AtenXlaTensorTest, TestZerosLikeOptions) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::zeros_like", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestZeros) {
Expand All @@ -2453,7 +2455,8 @@ TEST_F(AtenXlaTensorTest, TestZeros) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::zeros", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::zero_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestOnes) {
Expand All @@ -2465,7 +2468,8 @@ TEST_F(AtenXlaTensorTest, TestOnes) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::ones", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::fill_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestOnesLike) {
Expand All @@ -2478,7 +2482,8 @@ TEST_F(AtenXlaTensorTest, TestOnesLike) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::ones_like", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::fill_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestOnesLikeOptions) {
Expand All @@ -2492,7 +2497,8 @@ TEST_F(AtenXlaTensorTest, TestOnesLikeOptions) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::ones_like", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestFull) {
Expand All @@ -2505,7 +2511,8 @@ TEST_F(AtenXlaTensorTest, TestFull) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::full", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::fill_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestFullLike) {
Expand All @@ -2518,7 +2525,8 @@ TEST_F(AtenXlaTensorTest, TestFullLike) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::full_like", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::fill_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestFullLikeOptions) {
Expand All @@ -2533,7 +2541,8 @@ TEST_F(AtenXlaTensorTest, TestFullLikeOptions) {
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::full_like", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::empty", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::copy_", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestARange) {
Expand Down
67 changes: 8 additions & 59 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,30 +997,26 @@ at::Tensor AtenXlaType::embedding_dense_backward(const at::Tensor& grad_output,
num_weights, padding_idx, scale_grad_by_freq));
}

at::Tensor AtenXlaType::empty(at::IntArrayRef size,
const at::TensorOptions& options,
c10::optional<at::MemoryFormat> memory_format) {
at::Tensor AtenXlaType::empty(
at::IntArrayRef size, const at::TensorOptions& options,
c10::optional<at::MemoryFormat> /* memory_format */) {
XLA_FN_COUNTER("xla::");
// PT empty*() are optimizations to avoid initializing the data when it is
// known it will be completely rewritten. But since for us doing a zero*()
// does not actually end up doing any memory initialization, we use that and
// avoid going to CPU for it. A common PT pattern is indeed doing empty() plus
// s_copy_().
return full(size, 0, options);
}

at::Tensor AtenXlaType::empty_like(
const at::Tensor& self, const at::TensorOptions& options,
c10::optional<at::MemoryFormat> memory_format) {
XLA_FN_COUNTER("xla::");
return full_like(self, 0, options, memory_format);
XlaOptions xla_options(options);
return bridge::AtenFromXlaTensor(
XLATensor::full(XlaHelpers::I64List(size), 0, xla_options.get_device(),
xla_options.get_scalar_type()));
}

at::Tensor AtenXlaType::empty_strided(at::IntArrayRef size,
at::IntArrayRef stride,
const at::TensorOptions& options) {
XLA_FN_COUNTER("xla::");
at::Tensor t = full(size, 0, options);
at::Tensor t = empty(size, options, c10::nullopt);
return as_strided(t, size, stride, /*storage_offset=*/0);
}

Expand Down Expand Up @@ -1205,27 +1201,6 @@ at::Tensor& AtenXlaType::frac_(at::Tensor& self) {
return self;
}

at::Tensor AtenXlaType::full(at::IntArrayRef size, at::Scalar fill_value,
const at::TensorOptions& options) {
XLA_FN_COUNTER("xla::");
XlaOptions xla_options(options);
return bridge::AtenFromXlaTensor(
XLATensor::full(XlaHelpers::I64List(size), fill_value,
xla_options.get_device(), xla_options.get_scalar_type()));
}

at::Tensor AtenXlaType::full_like(
const at::Tensor& self, at::Scalar fill_value,
const at::TensorOptions& options,
c10::optional<at::MemoryFormat> /* memory_format */) {
XLA_FN_COUNTER("xla::");
XLATensor self_tensor = bridge::GetXlaTensor(self);
XlaOptions xla_options(options, self_tensor.GetDevice());
return bridge::AtenFromXlaTensor(
XLATensor::full_like(self_tensor, fill_value, xla_options.get_device(),
xla_options.scalar_type));
}

at::Tensor AtenXlaType::gather(const at::Tensor& self, int64_t dim,
const at::Tensor& index,
bool /* sparse_grad */) {
Expand Down Expand Up @@ -2092,19 +2067,6 @@ at::Tensor AtenXlaType::norm(const at::Tensor& self,
bridge::GetXlaTensor(self), p, c10::nullopt, dim, keepdim));
}

at::Tensor AtenXlaType::ones(at::IntArrayRef size,
const at::TensorOptions& options) {
XLA_FN_COUNTER("xla::");
return full(size, 1, options);
}

at::Tensor AtenXlaType::ones_like(
const at::Tensor& self, const at::TensorOptions& options,
c10::optional<at::MemoryFormat> memory_format) {
XLA_FN_COUNTER("xla::");
return full_like(self, 1, options, memory_format);
}

at::Tensor AtenXlaType::permute(const at::Tensor& self, at::IntArrayRef dims) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::permute(
Expand Down Expand Up @@ -2883,19 +2845,6 @@ at::Tensor& AtenXlaType::zero_(at::Tensor& self) {
return self;
}

at::Tensor AtenXlaType::zeros(at::IntArrayRef size,
const at::TensorOptions& options) {
XLA_FN_COUNTER("xla::");
return full(size, 0, options);
}

at::Tensor AtenXlaType::zeros_like(
const at::Tensor& self, const at::TensorOptions& options,
c10::optional<at::MemoryFormat> memory_format) {
XLA_FN_COUNTER("xla::");
return full_like(self, 0, options, memory_format);
}

void AtenXlaType::InitializeAtenBindings() {
static std::once_flag once;
std::call_once(once, []() { AtenInitialize(); });
Expand Down
25 changes: 0 additions & 25 deletions torch_xla/csrc/aten_xla_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,6 @@ class AtenXlaType {
const at::TensorOptions& options,
c10::optional<at::MemoryFormat> memory_format);

static at::Tensor empty_like(const at::Tensor& self,
const at::TensorOptions& options,
c10::optional<at::MemoryFormat> memory_format);

static at::Tensor empty_strided(at::IntArrayRef size, at::IntArrayRef stride,
const at::TensorOptions& options);

Expand Down Expand Up @@ -366,13 +362,6 @@ class AtenXlaType {

static at::Tensor& frac_(at::Tensor& self);

static at::Tensor full(at::IntArrayRef size, at::Scalar fill_value,
const at::TensorOptions& options);

static at::Tensor full_like(const at::Tensor& self, at::Scalar fill_value,
const at::TensorOptions& options,
c10::optional<at::MemoryFormat> memory_format);

static at::Tensor gather(const at::Tensor& self, int64_t dim,
const at::Tensor& index, bool sparse_grad);

Expand Down Expand Up @@ -650,13 +639,6 @@ class AtenXlaType {
static at::Tensor norm(const at::Tensor& self, c10::optional<at::Scalar> p,
at::IntArrayRef dim, bool keepdim);

static at::Tensor ones(at::IntArrayRef size,
const at::TensorOptions& options);

static at::Tensor ones_like(const at::Tensor& self,
const at::TensorOptions& options,
c10::optional<at::MemoryFormat> memory_format);

static at::Tensor permute(const at::Tensor& self, at::IntArrayRef dims);

static at::Tensor pow(const at::Tensor& self, at::Scalar exponent);
Expand Down Expand Up @@ -908,13 +890,6 @@ class AtenXlaType {
static at::Tensor view(const at::Tensor& self, at::IntArrayRef size);

static at::Tensor& zero_(at::Tensor& self);

static at::Tensor zeros(at::IntArrayRef size,
const at::TensorOptions& options);

static at::Tensor zeros_like(const at::Tensor& self,
const at::TensorOptions& options,
c10::optional<at::MemoryFormat> memory_format);
};

} // namespace torch_xla