Skip to content

Commit

Permalink
[BE] Use nested namespaces in sparse (#97581)
Browse files Browse the repository at this point in the history
<!--
copilot:summary
-->
### <samp>馃 Generated by Copilot at 59a5205</samp>

This pull request refactors the namespace declarations in several files under `aten/src/ATen/native/sparse` to use a more concise and consistent syntax. This improves the readability and reusability of the sparse tensor operations code.

Also, do not rely on deprecated `TensorBase::data` and instead use `TensorBase::data_ptr`

Pull Request resolved: #97581
Approved by: https://github.com/kit1980, https://github.com/huydhn
  • Loading branch information
malfet authored and pytorchmergebot committed Mar 26, 2023
1 parent 461f088 commit b73e8cd
Show file tree
Hide file tree
Showing 12 changed files with 38 additions and 60 deletions.
6 changes: 2 additions & 4 deletions aten/src/ATen/native/sparse/ParamUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
#include <ATen/ops/empty_like_native.h>
#endif

namespace at {
namespace native {
namespace at::native {

std::pair<Tensor, Tensor> softmax_sparse_input_preprocessing(
const Tensor& input_,
Expand Down Expand Up @@ -57,5 +56,4 @@ std::tuple<Tensor, Tensor, Tensor> softmax_backward_sparse_input_preprocessing(
return std::make_tuple(grad_input, grad, output);
}

} // namespace native
} // namespace at
} // namespace at::native
6 changes: 2 additions & 4 deletions aten/src/ATen/native/sparse/SoftMax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@

#include <map>

namespace at {
namespace native {
namespace at::native {
namespace {

int64_t get_nvalues(const IntArrayRef& sizes, int64_t sparse_dim) {
Expand Down Expand Up @@ -657,5 +656,4 @@ Tensor _sparse_log_softmax(const Tensor& self, Dimname dim, optional<ScalarType>
return at::_sparse_log_softmax(self, dimname_to_position(self, dim), dtype);
}

}
}
} // namespace at::native
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/AccumulateType.h>

namespace at {
namespace native {
namespace at::native {

namespace {

Expand Down Expand Up @@ -146,4 +145,4 @@ REGISTER_AVX512_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_interse
REGISTER_AVX2_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_VSX_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
REGISTER_ZVECTOR_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_cpu_kernel);
}}
} // namespace at::native
6 changes: 2 additions & 4 deletions aten/src/ATen/native/sparse/SparseBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

#include <c10/util/MaybeOwned.h>

namespace at {
namespace native {
namespace at::native {

Tensor& addmv_out_sparse_compressed(
const Tensor& self,
Expand Down Expand Up @@ -266,5 +265,4 @@ void sparse_sampled_addmm_check_inputs(

DEFINE_DISPATCH(sampled_addmm_sparse_csr_stub);

} // namespace native
} // namespace at
} // namespace at::native
38 changes: 16 additions & 22 deletions aten/src/ATen/native/sparse/SparseBlasImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
#include <ATen/Parallel.h>
#endif

namespace at {
namespace native {
namespace sparse {
namespace impl {
namespace at::native::sparse::impl {

Tensor& _compressed_row_strided_mm_out(const Tensor& compressed, const Tensor& strided, Tensor& result) {
const auto compressed_layout = compressed.layout();
Expand Down Expand Up @@ -279,25 +276,25 @@ void addmv_out_sparse_csr(
const Tensor& result) {
auto cont_values = mat.values().contiguous();
if (mat.layout() == kSparseBsr) {
addmv_sparse_bsr(cont_values.template data<scalar_t>(),
mat.crow_indices().template data<idx_t>(),
mat.col_indices().template data_ptr<idx_t>(),
addmv_sparse_bsr(cont_values.data_ptr<scalar_t>(),
mat.crow_indices().data_ptr<idx_t>(),
mat.col_indices().data_ptr<idx_t>(),
mat.size(0),
mat.values().size(1),
mat.values().size(2),
vec.template data<scalar_t>(),
alpha.template to<scalar_t>(),
beta.template to<scalar_t>(),
result.template data<scalar_t>());
vec.data_ptr<scalar_t>(),
alpha.to<scalar_t>(),
beta.to<scalar_t>(),
result.data_ptr<scalar_t>());
} else {
addmv_sparse_csr(cont_values.template data<scalar_t>(),
mat.crow_indices().template data<idx_t>(),
mat.col_indices().template data_ptr<idx_t>(),
addmv_sparse_csr(cont_values.data_ptr<scalar_t>(),
mat.crow_indices().data_ptr<idx_t>(),
mat.col_indices().data_ptr<idx_t>(),
mat.size(0),
vec.template data<scalar_t>(),
alpha.template to<scalar_t>(),
beta.template to<scalar_t>(),
result.template data<scalar_t>());
vec.data_ptr<scalar_t>(),
alpha.to<scalar_t>(),
beta.to<scalar_t>(),
result.data_ptr<scalar_t>());
}
}
} // anonymous namespace
Expand Down Expand Up @@ -378,7 +375,4 @@ void triangular_solve_out_sparse_csr(
}

} // namespace cpu
} // namespace impl
} // namespace sparse
} // namespace native
} // namespace at
} // namespace at::native::sparse::impl
6 changes: 2 additions & 4 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@
#include <ATen/ops/where.h>
#endif

namespace at {
namespace native {
namespace at::native {

using namespace at::sparse_csr;

Expand Down Expand Up @@ -1125,5 +1124,4 @@ Tensor select_copy_sparse_csr(const Tensor& self, int64_t dim, int64_t index) {
return select_sparse_csr_worker<false, true>(self, dim, index);
}

} // namespace native
} // namespace at
} // namespace at::native
6 changes: 2 additions & 4 deletions aten/src/ATen/native/sparse/SparseFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
#include <ATen/ops/where.h>
#endif

namespace at {
namespace native {
namespace at::native {

DEFINE_DISPATCH(spdiags_kernel_stub);

Expand Down Expand Up @@ -92,5 +91,4 @@ Tensor spdiags(
return result_coo;
}

} // namespace native
} // namespace at
} // namespace at::native
5 changes: 2 additions & 3 deletions aten/src/ATen/native/sparse/SparseMatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <ATen/ops/empty_like_native.h>
#endif

namespace at { namespace native {
namespace at::native {

using namespace at::sparse;

Expand Down Expand Up @@ -275,5 +275,4 @@ Tensor sparse_sparse_matmul_cpu(const Tensor& mat1_, const Tensor& mat2_) {
}


} // namespace native
} // namespace at
} // namespace at::native
6 changes: 2 additions & 4 deletions aten/src/ATen/native/sparse/SparseTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@
#include <ATen/ops/ones.h>
#endif

namespace at {
namespace native {
namespace at::native {

using namespace at::sparse;

Expand Down Expand Up @@ -851,5 +850,4 @@ Tensor empty_like_sparse_coo(
}
}

} // namespace native
} // namespace at
} // namespace at::native
4 changes: 2 additions & 2 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@

#include <algorithm>

namespace at { namespace native {
namespace at::native {

using namespace at::sparse;
// --------------------------------------------------------------------
Expand Down Expand Up @@ -2064,4 +2064,4 @@ Tensor& conj_physical_out_sparse(const Tensor& input, Tensor& result) {
return result;
}

}} // namespace at::native
} // namespace at::native
5 changes: 2 additions & 3 deletions aten/src/ATen/native/sparse/SparseUnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@
#include <ATen/ops/trunc_native.h>
#endif

namespace at {
namespace native {
namespace at::native {
namespace {

template <typename Ufunc>
Expand Down Expand Up @@ -261,4 +260,4 @@ Tensor& nan_to_num_sparse_(
return nan_to_num_sparse_out(self, nan, posinf, neginf, self);
}

}} // namespace at::native
} // namespace at::native
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include <ATen/native/sparse/ValidateCompressedIndicesCommon.h>
#include <ATen/native/cpu/Loops.h>

namespace at {
namespace native {
namespace at::native {

namespace {

Expand Down Expand Up @@ -43,4 +42,4 @@ void _validate_compressed_sparse_indices_cpu(
is_crow, cidx, idx, cdim, dim, nnz);
}

}}
} //namespace at::native

0 comments on commit b73e8cd

Please sign in to comment.