Port cholesky_inverse to ATen #50269

Closed
wants to merge 50 commits
Changes from all commits
Commits
50 commits
da9a254
Ported cholesky_inverse to ATen; tests pass
IvanYashchuk Dec 21, 2020
3c95d86
Renamed self -> input
IvanYashchuk Dec 21, 2020
4486c37
Make _out variant to be the primary one
IvanYashchuk Dec 21, 2020
2afc8ec
Merge remote-tracking branch 'upstream/master' into port-cholesky-inv…
IvanYashchuk Jan 7, 2021
c06c5a7
Corrected comments
IvanYashchuk Jan 7, 2021
b9da168
Added other dtypes to test_cholesky_inverse
IvanYashchuk Jan 7, 2021
39a7dbf
Use CPU, CUDA instead of DefaultBackend
IvanYashchuk Jan 7, 2021
06f02e1
Added path for batched cuda to use cholesky_solve
IvanYashchuk Jan 7, 2021
d9bf828
Moved cholesky_solve workaround to cholesky_inverse_out
IvanYashchuk Jan 7, 2021
84732f7
Added missing .conj
IvanYashchuk Jan 7, 2021
5cfdf52
Fixed checking for cuda device
IvanYashchuk Jan 7, 2021
69429c9
Add missing upper to cholesky_solve_out call
IvanYashchuk Jan 7, 2021
634f53f
Fix lda argument for cholesky_solve
IvanYashchuk Jan 7, 2021
68de787
Updated tests
IvanYashchuk Jan 7, 2021
b7b482f
Removed unused import
IvanYashchuk Jan 7, 2021
747fd71
Added op-based tests for cholesky_inverse
IvanYashchuk Jan 8, 2021
78ef953
Use c10 full dispatcher
IvanYashchuk Jan 8, 2021
db1602c
Merge remote-tracking branch 'upstream/master' into port-cholesky-inv…
IvanYashchuk Jan 8, 2021
f1ed1bc
Allow non column major out
IvanYashchuk Jan 8, 2021
962a302
Fix mypy failure
IvanYashchuk Jan 8, 2021
962aef1
Changed AT_ERROR -> TORCH_CHECK; modified error message for batched
IvanYashchuk Jan 11, 2021
0ae5113
Moved conj_wrapper to header
IvanYashchuk Jan 12, 2021
2641056
Added cpu kernel for matrix reflection and conjugation
IvanYashchuk Jan 12, 2021
ab48313
Added gpu kernel for matrix reflection and conjugation
IvanYashchuk Jan 12, 2021
8011adb
Merge remote-tracking branch 'upstream/master' into port-cholesky-inv…
IvanYashchuk Jan 12, 2021
4ecd7c5
Use DEFINE/DECLARE_DISPATCH for helper function instead of native_fun…
IvanYashchuk Jan 15, 2021
dd79d85
Added a comment why copy is needed
IvanYashchuk Jan 15, 2021
ae1acd1
Use at::native::resize_output
IvanYashchuk Jan 15, 2021
ec5e678
Allocate infos with the required shape, don't use resize
IvanYashchuk Jan 15, 2021
ccf1440
Fix comment
IvanYashchuk Jan 15, 2021
cbe7e38
Added comments on different code paths for out variant
IvanYashchuk Jan 15, 2021
e6b98cf
Removed cholesky_inverse test; it is replaced by OpInfo-based tests now
IvanYashchuk Jan 15, 2021
fc8c561
Added an option to skip scripted jit
IvanYashchuk Jan 15, 2021
927298a
Added comment that infos must be on CPU for magma's single matrix cho…
IvanYashchuk Jan 15, 2021
fded2b5
Use batched cholesky_solve for single matrix cholesky_inverse because…
IvanYashchuk Jan 15, 2021
853f126
Specialized cuda kernel for symmetrization is not needed anymore
IvanYashchuk Jan 15, 2021
9735cd9
Removed magmaCholeskyInverse
IvanYashchuk Jan 15, 2021
494c912
Merge remote-tracking branch 'upstream/master' into port-cholesky-inv…
IvanYashchuk Jan 15, 2021
aeba6d4
Revert changes to UnaryComplexKernels
IvanYashchuk Jan 15, 2021
edcb2f5
apply_cholesky_inverse support for batched inputs on CUDA using apply…
IvanYashchuk Jan 19, 2021
28e81bc
Use test_complex_grad=False
IvanYashchuk Jan 20, 2021
11bf2ec
Revert adding skipping scripted jit
IvanYashchuk Jan 20, 2021
a4b750a
Merge remote-tracking branch 'upstream/master' into port-cholesky-inv…
IvanYashchuk Jan 20, 2021
081a206
flake8
IvanYashchuk Jan 20, 2021
1dde747
Merge branch 'master' into port-cholesky-inverse
IvanYashchuk Jan 22, 2021
a7bdecf
Removed use_c10_dispatcher: full; it is now default
IvanYashchuk Jan 22, 2021
703d490
Added check_batched_gradgrad=False
IvanYashchuk Jan 22, 2021
3807d28
Updated the comment on infos' device
IvanYashchuk Jan 27, 2021
682f80a
Added test case for non-invertible input
IvanYashchuk Jan 27, 2021
f117e84
Merge remote-tracking branch 'upstream/master' into port-cholesky-inv…
IvanYashchuk Jan 27, 2021
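Several commits above ("Added path for batched cuda to use cholesky_solve", "Use batched cholesky_solve for single matrix cholesky_inverse because…") route the CUDA case through `cholesky_solve`. A hedged sketch of the identity that fallback relies on, again using the public libtorch API rather than anything added in this PR:

```cpp
// Sketch: solving A X = I from the Cholesky factor of A yields A^{-1},
// so cholesky_solve against the identity matches cholesky_inverse.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto b = torch::randn({4, 4}, torch::kFloat64);
  auto a = b.matmul(b.t()) + 4 * torch::eye(4, torch::kFloat64);   // SPD matrix

  auto l = torch::cholesky(a, /*upper=*/false);                    // A = L L^H
  auto eye = torch::eye(4, torch::kFloat64);

  auto inv_via_solve = torch::cholesky_solve(eye, l, /*upper=*/false);
  auto inv_via_potri = torch::cholesky_inverse(l, /*upper=*/false);

  std::cout << torch::allclose(inv_via_solve, inv_via_potri) << "\n";  // expect 1
  return 0;
}
```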
43 changes: 0 additions & 43 deletions aten/src/ATen/LegacyTHFunctionsCPU.cpp
@@ -686,49 +686,6 @@ std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A) {
}
return std::tuple<Tensor, Tensor>(res1, res2);
}
Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);

switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto output_ = checked_dense_tensor_unwrap(output, "output", 0, "_th_potri_out", false, DeviceType::CPU, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri_out", false, DeviceType::CPU, dispatch_scalar_type);
THDoubleTensor_potri(output_, self_, upper);
break;
}
case ScalarType::Float: {
auto output_ = checked_dense_tensor_unwrap(output, "output", 0, "_th_potri_out", false, DeviceType::CPU, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri_out", false, DeviceType::CPU, dispatch_scalar_type);
THFloatTensor_potri(output_, self_, upper);
break;
}
default:
AT_ERROR("_th_potri_out not supported on CPUType for ", dispatch_scalar_type);
}
return output;
}
Tensor _th_potri(const Tensor & self, bool upper) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
auto output_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
auto output = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(output_));
switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri", false, DeviceType::CPU, dispatch_scalar_type);
THDoubleTensor_potri(output_, self_, upper);
break;
}
case ScalarType::Float: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri", false, DeviceType::CPU, dispatch_scalar_type);
THFloatTensor_potri(output_, self_, upper);
break;
}
default:
AT_ERROR("_th_potri not supported on CPUType for ", dispatch_scalar_type);
}
return output;
}
std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
2 changes: 0 additions & 2 deletions aten/src/ATen/LegacyTHFunctionsCPU.h
@@ -38,8 +38,6 @@ Tensor & _th_histc_out(Tensor & result, const Tensor & self, int64_t bins, Scala
Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max);
std::tuple<Tensor &,Tensor &> _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A);
std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
Tensor _th_potri(const Tensor & self, bool upper);
std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self);
std::tuple<Tensor,Tensor> _th_geqrf(const Tensor & self);
Tensor & _th_ormqr_out(Tensor & result, const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose);
43 changes: 0 additions & 43 deletions aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
@@ -1222,49 +1222,6 @@ std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A) {
}
return std::tuple<Tensor, Tensor>(res1, res2);
}
Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);

switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto output_ = checked_dense_tensor_unwrap(output, "output", 0, "_th_potri_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri_out", false, DeviceType::CUDA, dispatch_scalar_type);
THCudaDoubleTensor_potri(globalContext().getTHCState(), output_, self_, upper);
break;
}
case ScalarType::Float: {
auto output_ = checked_dense_tensor_unwrap(output, "output", 0, "_th_potri_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri_out", false, DeviceType::CUDA, dispatch_scalar_type);
THCudaTensor_potri(globalContext().getTHCState(), output_, self_, upper);
break;
}
default:
AT_ERROR("_th_potri_out not supported on CUDAType for ", dispatch_scalar_type);
}
return output;
}
Tensor _th_potri(const Tensor & self, bool upper) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
auto output_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
auto output = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(output_));
switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri", false, DeviceType::CUDA, dispatch_scalar_type);
THCudaDoubleTensor_potri(globalContext().getTHCState(), output_, self_, upper);
break;
}
case ScalarType::Float: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_potri", false, DeviceType::CUDA, dispatch_scalar_type);
THCudaTensor_potri(globalContext().getTHCState(), output_, self_, upper);
break;
}
default:
AT_ERROR("_th_potri not supported on CUDAType for ", dispatch_scalar_type);
}
return output;
}
std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
100 changes: 97 additions & 3 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -49,6 +49,12 @@ extern "C" void cpotrf_(char *uplo, int *n, std::complex<float> *a, int *lda, in
extern "C" void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
extern "C" void spotrf_(char *uplo, int *n, float *a, int *lda, int *info);

// potri
extern "C" void zpotri_(char *uplo, int *n, std::complex<double> *a, int *lda, int *info);
extern "C" void cpotri_(char *uplo, int *n, std::complex<float> *a, int *lda, int *info);
extern "C" void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
extern "C" void spotri_(char *uplo, int *n, float *a, int *lda, int *info);

// trtrs
extern "C" void ztrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);
extern "C" void ctrtrs_(char *uplo, char *trans, char *diag, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info);
@@ -237,6 +243,22 @@ template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *
spotrf_(&uplo, &n, a, &lda, info);
}

template<> void lapackCholeskyInverse<c10::complex<double>>(char uplo, int n, c10::complex<double> *a, int lda, int *info) {
zpotri_(&uplo, &n, reinterpret_cast<std::complex<double>*>(a), &lda, info);
}

template<> void lapackCholeskyInverse<c10::complex<float>>(char uplo, int n, c10::complex<float> *a, int lda, int *info) {
cpotri_(&uplo, &n, reinterpret_cast<std::complex<float>*>(a), &lda, info);
}

template<> void lapackCholeskyInverse<double>(char uplo, int n, double *a, int lda, int *info) {
dpotri_(&uplo, &n, a, &lda, info);
}

template<> void lapackCholeskyInverse<float>(char uplo, int n, float *a, int lda, int *info) {
spotri_(&uplo, &n, a, &lda, info);
}

template<> void lapackTriangularSolve<c10::complex<double>>(char uplo, char trans, char diag, int n, int nrhs, c10::complex<double> *a, int lda, c10::complex<double> *b, int ldb, int *info) {
ztrtrs_(&uplo, &trans, &diag, &n, &nrhs, reinterpret_cast<std::complex<double>*>(a), &lda, reinterpret_cast<std::complex<double>*>(b), &ldb, info);
}
@@ -411,7 +433,7 @@ Computes the solution to a system of linear equations
where A is an n-by-n matrix and X and B are n-by-nrhs matrices.
Note that B is required to be a matrix, the usual, vector case, is obtained with nrhs = 1.
Above description is for non-batched input, the batched input is also supported.
This is an in-place routine, content of both A and b are overriden.
This is an in-place routine, content of both A and b are overwritten.
'infos' is an int Tensor containing error codes for each matrix in the batched input.
For more information see LAPACK's documentation for GESV routine.
*/
@@ -480,7 +502,7 @@ std::tuple<Tensor&,Tensor&> solve_out(Tensor& solution, Tensor& lu, const Tensor
// This is a type dispatching helper function for 'apply_solve'
Tensor& _linalg_solve_out_helper_cpu(Tensor& result, Tensor& input, Tensor& infos) {
// 'result' and 'input' should be in column major order (it should be checked before calling this function)
// the content of 'result', 'input' and 'infos' is overriden by 'apply_solve'
// the content of 'result', 'input' and 'infos' is overwritten by 'apply_solve'
// 'result' should contain data of 'other' tensor (right-hand-side of the linear system of equations)
// 'input' should contain data of original 'input' tensor (left-hand-side of the linear system of equations)
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "linalg_solve_out_cpu", [&]{
@@ -861,6 +883,78 @@ Tensor& linalg_cholesky_out(Tensor &result, const Tensor &self) {
return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky_inverse ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DEFINE_DISPATCH(cholesky_inverse_stub);

Tensor& cholesky_inverse_out_info(Tensor& result, Tensor& infos, const Tensor& input, bool upper) {
TORCH_INTERNAL_ASSERT(input.dim() >= 2);
TORCH_INTERNAL_ASSERT(input.size(-1) == input.size(-2));

TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type());
TORCH_INTERNAL_ASSERT(result.device() == input.device());

TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt);
TORCH_INTERNAL_ASSERT(infos.device() == at::kCPU);
TORCH_INTERNAL_ASSERT(infos.numel() == std::max<int64_t>(1, batchCount(input)));

// if result has no elements we can modify it
if (result.numel() == 0) {
at::native::resize_as_(result, input.transpose(-2, -1), MemoryFormat::Contiguous);
result.transpose_(-2, -1);
}

// result tensor must be in batched column major order (Fortran contiguous)
TORCH_INTERNAL_ASSERT(result.transpose(-2, -1).is_contiguous());
TORCH_INTERNAL_ASSERT(result.sizes().equals(input.sizes()));

// cholesky_inverse_stub (apply_cholesky_inverse) performs calculations in-place and result must be a copy of input
result.copy_(input);

// infos must be contiguous
TORCH_INTERNAL_ASSERT(infos.is_contiguous());
infos.fill_(0);

result = cholesky_inverse_stub(result.device().type(), result, infos, upper);
return result;
}

Tensor& cholesky_inverse_out(const Tensor &input, bool upper, Tensor &result) {
squareCheckInputs(input);
Review comment (Contributor): I don't think this is correct because input should be strictly 2-dimensional

TORCH_CHECK(result.scalar_type() == input.scalar_type(),
"result dtype ", result.scalar_type(), " does not match input dtype ", input.scalar_type());
TORCH_CHECK(result.device() == input.device(),
"result device ", result.device(), " does not match input device ", input.device());

// MAGMA requires 'infos' to reside in CPU memory, therefore we create 'infos' only on CPU for now.
auto infos = at::zeros({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt).device(kCPU));

// if result is not empty and not in batched column major format we have to allocate a temporary tensor
if (result.numel() != 0 && !result.transpose(-2, -1).is_contiguous()) {
Tensor result_tmp = at::empty({0}, input.options());
result_tmp = cholesky_inverse_out_info(result_tmp, infos, input, upper);
at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);
} else {
// use result's memory directly
result = cholesky_inverse_out_info(result, infos, input, upper);
}

// Now check LAPACK/MAGMA error codes
if (result.dim() > 2) {
batchCheckErrors(infos, "cholesky_inverse");
} else {
singleCheckErrors(infos.item().toInt(), "cholesky_inverse");
}
return result;
}

Tensor cholesky_inverse(const Tensor &input, bool upper) {
Tensor result = at::empty({0}, input.options());
result = at::cholesky_inverse_out(result, input, upper);
return result;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
@@ -1230,7 +1324,7 @@ Tensor orgqr(const Tensor& input, const Tensor& tau) {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// This function computes eigenvalues 'w' and eigenvectors 'v' of the input that is stored initially in 'v'
// The computation is done in-place: 'v' stores the input and will be overriden, 'w' should be an allocated empty array
// The computation is done in-place: 'v' stores the input and will be overwritten, 'w' should be an allocated empty array
// compute_v controls whether eigenvectors should be computed
// uplo_str controls the portion of input matrix to consider in computations, allowed values are "u", "U", "l", "L"
// infos is used to store information for possible checks for error
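To restate the math behind the new `lapackCholeskyInverse` bindings and `cholesky_inverse_out_info` above (a summary, not new behavior): LAPACK's `?potri` takes the Cholesky factor of a Hermitian positive-definite matrix and returns its inverse, writing only the triangle selected by `uplo`; the other triangle follows from Hermitian symmetry, which is what the reflect-and-conjugate step in the CPU kernel fills in.

```latex
% A Hermitian positive definite, factored as A = U^H U (uplo = 'U') or A = L L^H (uplo = 'L').
% ?potri maps the factor to the inverse; the unwritten triangle is recovered from
% (A^{-1})_{ij} = \overline{(A^{-1})_{ji}}.
\[
  A = U^{H} U \;\Rightarrow\; A^{-1} = U^{-1}\bigl(U^{-1}\bigr)^{H},
  \qquad
  A = L L^{H} \;\Rightarrow\; A^{-1} = \bigl(L^{-1}\bigr)^{H} L^{-1}.
\]
```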
7 changes: 7 additions & 0 deletions aten/src/ATen/native/BatchLinearAlgebra.h
@@ -14,6 +14,9 @@ namespace at { namespace native {
// Define per-batch functions to be used in the implementation of batched
// linear algebra operations

template<class scalar_t>
void lapackCholeskyInverse(char uplo, int n, scalar_t *a, int lda, int *info);

template<class scalar_t, class value_t=scalar_t>
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);

@@ -22,6 +25,10 @@ void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scala

#endif

using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, bool /*upper*/);

DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);

using eig_fn = std::tuple<Tensor, Tensor> (*)(const Tensor&, bool&);

DECLARE_DISPATCH(eig_fn, eig_stub);
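The single `DECLARE_DISPATCH` line above is the only piece of the dispatch machinery visible in this header; `DEFINE_DISPATCH` lives in BatchLinearAlgebra.cpp and the `REGISTER_*_DISPATCH` calls in the kernel file below. For readers unfamiliar with the pattern, here is a simplified, self-contained model of what the stub does — an illustration only, not the real ATen implementation (which also selects among CPU capabilities such as DEFAULT/AVX/AVX2):

```cpp
// Toy model of a DispatchStub: a per-device-type function-pointer table that
// kernel translation units fill in and the device-agnostic op calls through.
#include <array>
#include <cassert>
#include <cstddef>
#include <utility>

enum class DeviceType : std::size_t { CPU = 0, CUDA = 1, COUNT = 2 };

template <typename Fn>
struct SimpleDispatchStub {
  std::array<Fn, static_cast<std::size_t>(DeviceType::COUNT)> table{};

  void register_kernel(DeviceType d, Fn f) { table[static_cast<std::size_t>(d)] = f; }

  template <typename... Args>
  decltype(auto) operator()(DeviceType d, Args&&... args) {
    Fn f = table[static_cast<std::size_t>(d)];
    assert(f && "no kernel registered for this device type");
    return f(std::forward<Args>(args)...);
  }
};

// Analogue of DECLARE_DISPATCH + DEFINE_DISPATCH for a made-up op.
using demo_fn = int (*)(int);
SimpleDispatchStub<demo_fn> demo_stub;

// Analogue of REGISTER_ARCH_DISPATCH(demo_stub, DEFAULT, &demo_cpu_kernel).
int demo_cpu_kernel(int x) { return x + 1; }

int main() {
  demo_stub.register_kernel(DeviceType::CPU, &demo_cpu_kernel);
  return demo_stub(DeviceType::CPU, 41) == 42 ? 0 : 1;  // exit code 0 on success
}
```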
77 changes: 77 additions & 0 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -10,6 +10,79 @@ namespace at { namespace native {

namespace {

/*
Copies the lower (or upper) triangle of the square matrix to the other half and conjugates it.
This operation is performed in-place.
*/
template <typename scalar_t>
void apply_reflect_conj_tri_single(scalar_t* self, int64_t n, int64_t stride, bool upper) {
std::function<void(int64_t, int64_t)> loop = [](int64_t, int64_t){};
if (upper) {
loop = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
for (int64_t j = i + 1; j < n; j++) {
self[i * stride + j] = conj_impl(self[j * stride + i]);
}
}
};
} else {
loop = [&](int64_t start, int64_t end) {
for (int64_t i = start; i < end; i++) {
for (int64_t j = 0; j < i; j++) {
self[i * stride + j] = conj_impl(self[j * stride + i]);
}
}
};
}
// For small matrices OpenMP overhead is too large
if (n < 256) {
loop(0, n);
} else {
at::parallel_for(0, n, 0, loop);
}
}

/*
Computes the inverse of a symmetric (Hermitian) positive-definite n-by-n matrix 'input' using the Cholesky factorization
This is an in-place routine, content of 'input' is overwritten.
'infos' is an int Tensor containing error codes for each matrix in the batched input.
For more information see LAPACK's documentation for POTRI routine.
*/
template <typename scalar_t>
void apply_cholesky_inverse(Tensor& input, Tensor& infos, bool upper) {
#ifndef USE_LAPACK
TORCH_CHECK(false, "cholesky_inverse: LAPACK library not found in compilation");
#else
char uplo = upper ? 'U' : 'L';

auto input_data = input.data_ptr<scalar_t>();
auto infos_data = infos.data_ptr<int>();
auto input_matrix_stride = matrixStride(input);
auto batch_size = batchCount(input);
auto n = input.size(-2);
auto lda = std::max<int64_t>(1, n);

for (int64_t i = 0; i < batch_size; i++) {
scalar_t* input_working_ptr = &input_data[i * input_matrix_stride];
int* info_working_ptr = &infos_data[i];
lapackCholeskyInverse<scalar_t>(uplo, n, input_working_ptr, lda, info_working_ptr);
// LAPACK writes only to the upper/lower part of the matrix, leaving the other side unchanged
apply_reflect_conj_tri_single<scalar_t>(input_working_ptr, n, lda, upper);
}
#endif
}

// This is a type dispatching helper function for 'apply_cholesky_inverse'
Tensor& cholesky_inverse_kernel_impl(Tensor& result, Tensor& infos, bool upper) {
// This function calculates the inverse matrix in-place
// result should be in column major order and contain matrices to invert
// the content of result is overwritten by 'apply_cholesky_inverse'
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "cholesky_inverse_out_cpu", [&]{
apply_cholesky_inverse<scalar_t>(result, infos, upper);
});
return result;
}

template <typename scalar_t>
void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vecs_, int64_t* info_ptr) {
#ifndef USE_LAPACK
@@ -98,6 +171,10 @@ Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau, Tensor& infos, int6

} // anonymous namespace

REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl);
REGISTER_AVX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);

REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl);
REGISTER_AVX_DISPATCH(eig_stub, &eig_kernel_impl);
REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl);
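As a cross-check of `apply_cholesky_inverse` and `apply_reflect_conj_tri_single` above, here is a hedged Tensor-level equivalent of the reflect-and-conjugate step for the `upper == true` case, written against the public ATen API; the kernel does this element-wise and in-place, so this is illustration only:

```cpp
// Sketch: rebuild the full Hermitian inverse from a matrix whose upper
// triangle holds the ?potri output (upper == true case).
#include <ATen/ATen.h>

at::Tensor reflect_conj_upper(const at::Tensor& result) {
  auto upper_part = result.triu();                        // computed triangle + diagonal
  auto lower_part = result.triu(/*diagonal=*/1)           // strictly upper part...
                        .transpose(-2, -1)                // ...mirrored below the diagonal...
                        .conj();                          // ...and conjugated
  return upper_part + lower_part;
}
```

Because `triu`, `transpose` and `conj` all act on the last two dimensions, the same expression also covers batched inputs.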