82 changes: 60 additions & 22 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -3300,18 +3300,32 @@ XLANativeFunctions::native_group_norm(const at::Tensor& input,
eps);
}

at::Tensor XLANativeFunctions::_cdist_forward(
const at::Tensor& x1, const at::Tensor& x2, double p,
c10::optional<int64_t> compute_mode) {
// compute_mode is ignored because the use_mm_for_euclid_dist lowering
// (compute_mode is 0 or 1) is achieved through composite ops from
// native pytorch.
TORCH_LAZY_FN_COUNTER("xla::");
XLA_CHECK(p >= 0) << "p value for the p-norm distance must be >= 0";
return bridge::AtenFromXlaTensor(tensor_methods::cdist_forward(
bridge::GetXlaTensor(x1), bridge::GetXlaTensor(x2), p));
}
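
For context on the comment above: when compute_mode requests use_mm_for_euclid_dist, the p = 2 path decomposes into matmuls and reductions that XLA already lowers. A minimal sketch of that decomposition follows; the helper name is hypothetical and not part of this PR.

#include <ATen/ATen.h>

// ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b; clamp_min(0) guards against
// small negative values caused by floating-point cancellation.
at::Tensor euclid_dist_via_mm(const at::Tensor& x1, const at::Tensor& x2) {
  at::Tensor x1_sq = (x1 * x1).sum(/*dim=*/{-1}, /*keepdim=*/true);
  at::Tensor x2_sq = (x2 * x2).sum(/*dim=*/{-1}, /*keepdim=*/true);
  at::Tensor inner = at::matmul(x1, x2.transpose(-2, -1));
  return (x1_sq + x2_sq.transpose(-2, -1) - 2 * inner).clamp_min(0).sqrt();
}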

// All of the below ops correspond to CompositeExplicitAutograd kernels from
// core that call into view operators internally. These are all composite ops
// that LTC can technically re-use / get for free, but we need to
// "functionalize" them to remove the view ops before we can use them.
at::Tensor XLANativeFunctions::affine_grid_generator(const at::Tensor& theta,
at::IntArrayRef size,
bool align_corners) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op<ATEN_OP(
affine_grid_generator)>::call(theta, size, align_corners);
}
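
A minimal sketch of what "functionalize" buys here, using a hypothetical composite kernel (the *_copy ops are ATen's real functionalized variants):

// A composite kernel written with view ops returns a tensor that aliases
// its input, which the XLA backend cannot represent:
at::Tensor composite_with_views(const at::Tensor& self) {
  return self.transpose(0, 1).unsqueeze(0);  // two view ops
}

// functionalize_aten_op re-runs the same kernel with each view op
// replaced by its *_copy counterpart, roughly:
at::Tensor composite_functionalized(const at::Tensor& self) {
  return at::unsqueeze_copy(at::transpose_copy(self, 0, 1), 0);
}

The XLA_CHECK guard asserts this path is only taken while functionalization is enabled; ops that have a usable eager composite (e.g. _convolution below) fall back to at::native instead.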

at::Tensor XLANativeFunctions::block_diag(at::TensorList tensors) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op<ATEN_OP(
block_diag)>::call(tensors);
}
@@ -3322,6 +3336,13 @@ at::Tensor XLANativeFunctions::_convolution(
at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed,
at::IntArrayRef output_padding, int64_t groups, bool benchmark,
bool deterministic, bool cudnn_enabled, bool allow_tf32) {
// See Note: [Disabling functionalization]
if (xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
return at::native::_convolution(input, weight, bias, stride, padding,
dilation, transposed, output_padding,
groups, benchmark, deterministic,
cudnn_enabled, allow_tf32);
}
return at::functionalization::functionalize_aten_op<ATEN_OP(
_convolution)>::call(input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups, benchmark,
@@ -3337,16 +3358,9 @@ XLANativeFunctions::convolution_backward(
::std::array<bool, 3> output_mask) {
// See Note: [Disabling functionalization]
if (xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
return at::native::call_fallback_fn<
&xla_cpu_fallback, ATEN_OP(convolution_backward)>::call(grad_output,
input, weight,
bias_sizes,
stride, padding,
dilation,
transposed,
output_padding,
groups,
output_mask);
return at::native::convolution_backward(
grad_output, input, weight, bias_sizes, stride, padding, dilation,
transposed, output_padding, groups, output_mask);
}
// TODO (alanwaketan): Let's reuse
// `at::functionalization::functionalize_aten_op` after upstream has solved
@@ -3377,6 +3391,7 @@ XLANativeFunctions::convolution_backward(
at::Tensor XLANativeFunctions::diag_embed(const at::Tensor& self,
int64_t offset, int64_t dim1,
int64_t dim2) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op<ATEN_OP(
diag_embed)>::call(self, offset, dim1, dim2);
}
@@ -3386,6 +3401,11 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight,
c10::SymInt padding_idx,
bool scale_grad_by_freq,
bool sparse) {
// See Note: [Disabling functionalization]
if (xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
return at::native::embedding_symint(weight, indices, padding_idx,
scale_grad_by_freq, sparse);
}
// TODO: For now, route to native, which dispatches supported XLA operations.
// We need to make use of the TPU embedding core here eventually.
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
@@ -3395,6 +3415,7 @@ at::Tensor XLANativeFunctions::embedding_symint(const at::Tensor& weight,

at::Tensor XLANativeFunctions::_euclidean_dist(const at::Tensor& x1,
const at::Tensor& x2) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op<ATEN_OP(
_euclidean_dist)>::call(x1, x2);
}
@@ -3417,31 +3438,43 @@ at::Tensor XLANativeFunctions::narrow_copy_symint(const at::Tensor& self,
int64_t dim,
c10::SymInt start,
c10::SymInt length) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
narrow_copy)>::call(self, dim, start, length);
}

at::Tensor XLANativeFunctions::pixel_shuffle(const at::Tensor& self,
int64_t upscale_factor) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op<ATEN_OP(
pixel_shuffle)>::call(self, upscale_factor);
}

at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self,
int64_t downscale_factor) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op<ATEN_OP(
pixel_unshuffle)>::call(self, downscale_factor);
}

at::Tensor XLANativeFunctions::reshape_symint(const at::Tensor& self,
c10::SymIntArrayRef shape) {
// See Note: [Disabling functionalization]
if (xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
return at::native::reshape_symint(self, shape);
}
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
reshape)>::call(self, shape);
}

at::Tensor XLANativeFunctions::select_backward_symint(
const at::Tensor& grad_output, c10::SymIntArrayRef input_sizes, int64_t dim,
c10::SymInt index) {
// See Note: [Disabling functionalization]
if (xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
return at::native::select_backward_symint(grad_output, input_sizes, dim,
index);
}
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
select_backward)>::call(grad_output, input_sizes, dim, index);
}
@@ -3468,13 +3501,18 @@ at::Tensor XLANativeFunctions::slice(const at::Tensor& self, int64_t dim,
}

at::Tensor XLANativeFunctions::t(const at::Tensor& self) {
// See Note: [Disabling functionalization]
if (xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
return transpose_copy(self, 0, 1);
}
return at::functionalization::functionalize_aten_op<ATEN_OP(t)>::call(self);
}
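
Note the fallback here is transpose_copy rather than at::native::t: the native kernel returns a view aliasing self, and with functionalization disabled XLA needs a materialized result. A small illustration of the difference, in plain ATen (not PR code):

void alias_demo() {
  at::Tensor a = at::ones({2, 3});
  at::Tensor v = a.t();                        // view: shares a's storage
  at::Tensor c = at::transpose_copy(a, 0, 1);  // copy: owns its storage
  a.add_(1);
  // v now reflects the update to `a`; c does not.
}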

at::Tensor XLANativeFunctions::_trilinear(
const at::Tensor& i1, const at::Tensor& i2, const at::Tensor& i3,
at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3,
at::IntArrayRef sumdim, int64_t unroll_dim) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op<ATEN_OP(
_trilinear)>::call(i1, i2, i3, expand1, expand2, expand3, sumdim,
unroll_dim);
@@ -3483,18 +3521,21 @@ at::Tensor XLANativeFunctions::_trilinear(
at::Tensor XLANativeFunctions::linalg_pinv(
const at::Tensor& self, const c10::optional<at::Tensor>& atol,
const c10::optional<at::Tensor>& rtol, bool hermitian) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op<ATEN_OP2(
linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}

at::Tensor XLANativeFunctions::mvlgamma(const at::Tensor& self, int64_t p) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op<ATEN_OP(mvlgamma)>::call(
self, p);
}

at::Tensor XLANativeFunctions::diagonal_backward_symint(
const at::Tensor& grad_output, at::SymIntArrayRef input_sizes,
int64_t offset, int64_t dim1, int64_t dim2) {
XLA_CHECK(!xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false));
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
}
@@ -3503,24 +3544,21 @@ at::Tensor XLANativeFunctions::slice_backward(const at::Tensor& grad_output,
at::IntArrayRef input_sizes,
int64_t dim, int64_t start,
int64_t end, int64_t step) {
// See Note: [Disabling functionalization]
if (xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
return at::native::slice_backward(grad_output, input_sizes, dim, start, end,
step);
}
return at::functionalization::functionalize_aten_op<ATEN_OP(
slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}

at::Tensor XLANativeFunctions::_cdist_forward(
const at::Tensor& x1, const at::Tensor& x2, double p,
c10::optional<int64_t> compute_mode) {
// compute_mode is ignored because the use_mm_for_euclid_dist lowering
// (compute_mode is 0 or 1) is achieved through composite ops from
// native pytorch.
TORCH_LAZY_FN_COUNTER("xla::");
XLA_CHECK(p >= 0) << "p value for the p-norm distance must be >= 0";
return bridge::AtenFromXlaTensor(tensor_methods::cdist_forward(
bridge::GetXlaTensor(x1), bridge::GetXlaTensor(x2), p));
}

at::Tensor XLANativeFunctions::permute(const at::Tensor& self,
at::IntArrayRef dims) {
// See Note: [Disabling functionalization]
if (xla::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) {
return permute_copy(self, dims);
}
return at::functionalization::functionalize_aten_op<ATEN_OP(permute)>::call(
self, dims);
}