Lower _euclidean_dist
wonjoolee95 committed Jan 4, 2023
1 parent 91ffaea commit 2506ed7
Showing 2 changed files with 13 additions and 7 deletions.
19 changes: 12 additions & 7 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -661,13 +661,6 @@ at::Tensor XLANativeFunctions::addmm(const at::Tensor& self,
       /*bias=*/bridge::GetXlaTensor(self)));
 }
 
-at::Tensor XLANativeFunctions::affine_grid_generator(const at::Tensor& theta,
-                                                     at::IntArrayRef size,
-                                                     bool align_corners) {
-  return at::functionalization::functionalize_aten_op<ATEN_OP(
-      affine_grid_generator)>::call(theta, size, align_corners);
-}
-
 at::Tensor XLANativeFunctions::alias_copy(const at::Tensor& self) {
   TORCH_LAZY_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(
@@ -3176,6 +3169,13 @@ XLANativeFunctions::native_group_norm(const at::Tensor& input,
 // core that call into view operators internally. These are all composite ops
 // that LTC can technically re-use / get for free, but we need to
 // "functionalize" them to remove the view ops before we can use them.
+at::Tensor XLANativeFunctions::affine_grid_generator(const at::Tensor& theta,
+                                                     at::IntArrayRef size,
+                                                     bool align_corners) {
+  return at::functionalization::functionalize_aten_op<ATEN_OP(
+      affine_grid_generator)>::call(theta, size, align_corners);
+}
+
 at::Tensor XLANativeFunctions::block_diag(at::TensorList tensors) {
   return at::functionalization::functionalize_aten_op<ATEN_OP(
       block_diag)>::call(tensors);
@@ -3206,6 +3206,11 @@ XLANativeFunctions::convolution_backward(
       output_padding, groups, output_mask);
 }
 
+at::Tensor XLANativeFunctions::_euclidean_dist(const at::Tensor& x1, const at::Tensor& x2) {
+  return at::functionalization::functionalize_aten_op<ATEN_OP(
+      _euclidean_dist)>::call(x1, x2);
+}
+
 at::Tensor XLANativeFunctions::new_empty_strided_symint(
     const at::Tensor& self, at::SymIntArrayRef size, at::SymIntArrayRef stride,
     c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
1 change: 1 addition & 0 deletions xla_native_functions.yaml
@@ -344,6 +344,7 @@ supported:
   - block_diag
   - _convolution
   - convolution_backward
+  - _euclidean_dist
   - slice_backward
   - diagonal_backward
   - new_empty_strided
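A minimal usage sketch, not part of the commit itself and assuming a torch_xla build that includes this change: torch.cdist with p=2 and compute_mode="use_mm_for_euclid_dist" decomposes to _euclidean_dist, so with this lowering the op should stay on the XLA device, and torch_xla's fallback counters can be used to confirm no aten::_euclidean_dist CPU fallback is recorded.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x1 = torch.randn(4, 3, device=device)
x2 = torch.randn(5, 3, device=device)

# With p=2 and this compute_mode, cdist decomposes into _euclidean_dist,
# which is now routed through the functionalized ATen op lowered here.
dist = torch.cdist(x1, x2, p=2, compute_mode="use_mm_for_euclid_dist")
xm.mark_step()

print(dist.cpu())  # 4x5 matrix of pairwise Euclidean distances
# Any CPU-fallback op shows up as an "aten::..." counter; this list should be empty.
print([name for name in met.counter_names() if "_euclidean_dist" in name])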
