Update on "[Gradient Compression] Add a random generator to PowerSGD …
Browse files Browse the repository at this point in the history
…state for initializing low-rank matrix Q"

Previously, the random seed was the length of the input tensor, which is not guaranteed to differ across batches. Now a random generator is initialized in the PowerSGD state, and this generator is used to create a random seed that randomizes the low-rank tensor Q at every step.

As a result, the initial tensor Q is the same across all replicas at a given step, but different at different steps.

`torch.manual_seed` is used in the same way as in https://github.com/epfml/powersgd/blob/master/gradient_reducers.py#L675
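A minimal sketch of the idea (the class and helper names below are illustrative, not the actual hook implementation):

```python
import torch

class PowerSGDState:
    """Illustrative sketch of the relevant part of the PowerSGD state."""

    def __init__(self, random_seed=0):
        # One generator per state object. Every replica constructs the state
        # with the same random_seed, so all replicas draw identical seed
        # sequences in lockstep.
        self.rng = torch.Generator().manual_seed(random_seed)

def init_low_rank_q(state, n, rank, device, dtype):
    # Draw a fresh seed from the shared generator at every step: Q is then the
    # same across replicas at a given step, but different between steps.
    seed = int(torch.randint(1_000_000_000, (1,), generator=state.rng))
    torch.manual_seed(seed)
    return torch.randn(n, rank, device=device, dtype=dtype)
```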

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

Differential Revision: [D25191589](https://our.internmc.facebook.com/intern/diff/D25191589/)

[ghstack-poisoned]
wayi committed Nov 30, 2020
2 parents 563914d + 5bb2a87 commit a44241f
Showing 57 changed files with 16,140 additions and 15,552 deletions.
23 changes: 9 additions & 14 deletions CODEOWNERS
@@ -1,20 +1,15 @@
# This is a comment.
# Each line is a file pattern followed by one or more owners.

/docs/cpp @goldsborough @ebetica @yf225 @glaringlee
/torch/csrc/api/ @ebetica @goldsborough @yf225 @glaringlee
/test/cpp/api/ @ebetica @goldsborough @yf225 @glaringlee
/torch/utils/cpp_extension.py @goldsborough @fmassa @soumith @ezyang
/docs/cpp @glaringlee
/torch/csrc/api/ @glaringlee
/test/cpp/api/ @glaringlee
/torch/utils/cpp_extension.py @fmassa @soumith @ezyang

# Not there to strictly require the approval, but to be tagged as a reviewer
# on the PRs to push them into a high priority inbox.
/torch/csrc/api/data/ @apaszke
/torch/csrc/autograd/ @apaszke @albanD
/torch/csrc/jit/ @apaszke
/torch/nn/ @apaszke
/torch/autograd/ @apaszke @albanD
/torch/jit/ @apaszke
/torch/utils/data/ @apaszke
/torch/csrc/autograd/ @albanD
/torch/autograd/ @albanD

# Tensorpipe RPC Agent.
/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @jiayisuse @osalpekar @lw @beauby
@@ -23,9 +18,9 @@
# Distributed package
# This list is mostly if you'd like to be tagged as reviewer, feel free to add
# or remove yourself from it.
/torch/lib/c10d/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088
/torch/csrc/distributed/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088
/torch/distributed/ @apaszke @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088
/torch/lib/c10d/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088
/torch/csrc/distributed/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088
/torch/distributed/ @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088

# Distributed tests
# This list is mostly if you'd like to be tagged as reviewer, feel free to add
2 changes: 0 additions & 2 deletions aten/src/ATen/LegacyTHFunctionsCUDA.h
@@ -46,8 +46,6 @@ Tensor & _th_cross_kernel_out(Tensor & result, const Tensor & self, const Tensor
Tensor _th_cross_kernel(const Tensor & self, const Tensor & other, int64_t dim);
std::tuple<Tensor &,Tensor &> _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A);
std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
std::tuple<Tensor &,Tensor &> _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors);
std::tuple<Tensor,Tensor> _th_eig(const Tensor & self, bool eigenvectors);
Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
Tensor _th_potri(const Tensor & self, bool upper);
std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self);
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -552,6 +552,7 @@ _(aten, pixel_shuffle) \
_(aten, poisson) \
_(aten, polygamma) \
_(aten, pow) \
_(aten, float_power) \
_(aten, prelu) \
_(aten, prelu_backward) \
_(aten, prod) \
47 changes: 0 additions & 47 deletions aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
@@ -1588,53 +1588,6 @@ std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A) {
}
return std::tuple<Tensor, Tensor>(res1, res2);
}
std::tuple<Tensor &,Tensor &> _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);

switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type);
THCudaDoubleTensor_geev(globalContext().getTHCState(), res1_, res2_, self_, eigenvectors);
break;
}
case ScalarType::Float: {
auto res1_ = checked_dense_tensor_unwrap(res1, "res1", 0, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto res2_ = checked_dense_tensor_unwrap(res2, "res2", 0, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig_out", false, DeviceType::CUDA, dispatch_scalar_type);
THCudaTensor_geev(globalContext().getTHCState(), res1_, res2_, self_, eigenvectors);
break;
}
default:
AT_ERROR("_th_eig_out not supported on CUDAType for ", dispatch_scalar_type);
}
return std::tuple<Tensor &, Tensor &>(res1, res2);
}
std::tuple<Tensor,Tensor> _th_eig(const Tensor & self, bool eigenvectors) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
auto res1_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
auto res1 = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(res1_));
auto res2_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
auto res2 = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(res2_));
switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CUDA, dispatch_scalar_type);
THCudaDoubleTensor_geev(globalContext().getTHCState(), res1_, res2_, self_, eigenvectors);
break;
}
case ScalarType::Float: {
auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_eig", false, DeviceType::CUDA, dispatch_scalar_type);
THCudaTensor_geev(globalContext().getTHCState(), res1_, res2_, self_, eigenvectors);
break;
}
default:
AT_ERROR("_th_eig not supported on CUDAType for ", dispatch_scalar_type);
}
return std::tuple<Tensor, Tensor>(res1, res2);
}
Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper) {
// DeviceGuard omitted
auto dispatch_scalar_type = infer_scalar_type(self);
4 changes: 2 additions & 2 deletions aten/src/ATen/native/LinearAlgebraUtils.h
@@ -77,7 +77,7 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A, c
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
}

// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig)
// Validates input shapes for operations on batches of square matrices (inverse, cholesky, symeig, eig)
static inline void squareCheckInputs(const Tensor& self) {
TORCH_CHECK(self.dim() >= 2, "Tensor of matrices must have at least 2 dimensions. ");
TORCH_CHECK(self.size(-1) == self.size(-2),
@@ -135,7 +135,7 @@ static inline void singleCheckErrors(int64_t info, const char* name, bool allow_
} else if (info > 0) {
if (strstr(name, "svd")) {
AT_ERROR(name, ": the updating process of SBDSDC did not converge (error: ", info, ")");
} else if (strstr(name, "symeig")) {
} else if (strstr(name, "eig")) { // this catches both "eig" and "symeig"
AT_ERROR(name, ": the algorithm failed to converge; ", info,
" off-diagonal elements of an intermediate tridiagonal form did not converge to zero.");
} else if (!allow_singular) {
42 changes: 42 additions & 0 deletions aten/src/ATen/native/Pow.cpp
@@ -81,6 +81,48 @@ Tensor pow(Scalar base, const Tensor& exp) {
return native::pow_out(result, base, exp);
}

Tensor& float_power_out(Tensor& result, const Tensor& base, const Tensor& exp) {
auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ?
at::kComplexDouble : at::kDouble;
TORCH_CHECK(result.scalar_type() == dtype,
"output type ", result.scalar_type(), "is not the desired output type ", dtype);

return at::pow_out(result, base.to(dtype), exp.to(dtype));
}

Tensor& float_power_out(Tensor& result, const Tensor& base, Scalar exp) {
return at::float_power_out(result, base, c10::scalar_to_tensor(exp, base.device()));
}

Tensor& float_power_out(Tensor& result, Scalar base, const Tensor& exp) {
return at::float_power_out(result, c10::scalar_to_tensor(base, exp.device()), exp);
}

Tensor float_power(const Tensor& base, const Tensor& exp) {
auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
return at::pow(base.to(dtype), exp.to(dtype));
}

Tensor float_power(const Tensor& base, Scalar exp) {
return at::float_power(base, c10::scalar_to_tensor(exp, base.device()));
}

Tensor float_power(Scalar base, const Tensor& exp) {
return at::float_power(c10::scalar_to_tensor(base, exp.device()), exp);
}

Tensor& float_power_(Tensor& base, const Tensor& exp) {
auto dtype = (at::isComplexType(base.scalar_type()) || at::isComplexType(exp.scalar_type())) ? at::kComplexDouble : at::kDouble;
TORCH_CHECK(base.scalar_type() == dtype,
"self tensor type ", base.scalar_type(), "is not the desired type ", dtype);

return base.pow_(exp.to(dtype));
}

Tensor& float_power_(Tensor& base, Scalar exp) {
return base.float_power_(c10::scalar_to_tensor(exp, base.device()));
}

} // namespace native

} // namespace at
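A short usage sketch of the new op (assuming it is exposed in Python as torch.float_power / Tensor.float_power_, as the native declarations above suggest):

```python
import torch

base = torch.tensor([2, 3, 4])             # int64
exp = torch.tensor([8, 16, 0.5])

# float_power always computes in double precision (complex inputs promote to
# complex double), regardless of the input dtypes.
out = torch.float_power(base, exp)
print(out.dtype)                            # torch.float64

# The in-place variant requires self to already have the promoted dtype,
# mirroring the TORCH_CHECK in float_power_ above.
buf = base.double()
buf.float_power_(2)
```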
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Resize.h
@@ -14,7 +14,7 @@ namespace at { namespace native {
// Issues a warning if the output tensor has one or more elements and
// needs resizing
// NOTE: In the future the warning will become an error
void resize_output(Tensor& output, IntArrayRef shape);
CAFFE2_API void resize_output(Tensor& output, IntArrayRef shape);

// These functions are called by native::resize_ as well as (legacy) TH resize.
// They are not in TH/THTensor.cpp because the at namespace is easier
12 changes: 6 additions & 6 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -149,8 +149,8 @@ Tensor& deg2rad_out(Tensor& result, const Tensor& self) {
Tensor deg2rad(const Tensor& self) { return unary_op_impl(self, at::deg2rad_out); }
Tensor& deg2rad_(Tensor& self) { return unary_op_impl_(self, at::deg2rad_out); }

Tensor& asin_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, asin_stub); }
Tensor asin(const Tensor& self) { return unary_op_impl(self, at::asin_out); }
Tensor& asin_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, asin_stub); }
Tensor asin(const Tensor& self) { return unary_op_impl_float(self, asin_stub); }
Tensor& asin_(Tensor& self) { return unary_op_impl_(self, at::asin_out); }

// arcsin, alias of asin
@@ -258,12 +258,12 @@ Tensor& expm1_out(Tensor& result, const Tensor& self) { return unary_op_impl_out
Tensor expm1(const Tensor& self) { return unary_op_impl(self, at::expm1_out); }
Tensor& expm1_(Tensor& self) { return unary_op_impl_(self, at::expm1_out); }

Tensor& erf_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, erf_stub); }
Tensor erf(const Tensor& self) { return unary_op_impl(self, at::erf_out); }
Tensor& erf_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erf_stub); }
Tensor erf(const Tensor& self) { return unary_op_impl_float(self, erf_stub); }
Tensor& erf_(Tensor& self) { return unary_op_impl_(self, at::erf_out); }

Tensor& erfc_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, erfc_stub); }
Tensor erfc(const Tensor& self) { return unary_op_impl(self, at::erfc_out); }
Tensor& erfc_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, erfc_stub); }
Tensor erfc(const Tensor& self) { return unary_op_impl_float(self, erfc_stub); }
Tensor& erfc_(Tensor& self) { return unary_op_impl_(self, at::erfc_out); }

Tensor& frac_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, frac_stub); }
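The switch from unary_op_impl to unary_op_impl_float above makes asin, erf, and erfc compute in (and return) floating point for integral inputs. A quick sketch of the expected Python-level behavior under that assumption:

```python
import torch

x = torch.arange(5)                          # int64 tensor

# With the *_float helpers these ops promote integral inputs to the default
# floating-point dtype instead of requiring a matching integer output.
print(torch.erf(x).dtype)                    # torch.float32
print(torch.erfc(x).dtype)                   # torch.float32
print(torch.asin(torch.tensor([0, 1])).dtype)  # torch.float32
```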
149 changes: 148 additions & 1 deletion aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -8,6 +8,7 @@

#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/Resize.h>
#include <ATen/native/cuda/BatchLinearAlgebraLib.h>
#include <ATen/native/cpu/zmath.h>

@@ -123,6 +124,12 @@ void magmaSymeig(
value_t* w, scalar_t* wA, magma_int_t ldwa, scalar_t* work, magma_int_t lwork, value_t* rwork,
magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info);

template<class scalar_t>
void magmaEig(
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, scalar_t *A, magma_int_t lda,
scalar_t *wr, scalar_t *wi, scalar_t *VL, magma_int_t ldvl,
scalar_t *VR, magma_int_t ldvr, scalar_t *work, magma_int_t lwork, magma_int_t *info);

template<class scalar_t, class value_t=scalar_t>
void magmaSvd(
magma_vec_t jobz, magma_int_t m, magma_int_t n, scalar_t* A,
@@ -925,6 +932,26 @@ void magmaSymeig<c10::complex<float>, float>(
ldwa, reinterpret_cast<magmaFloatComplex*>(work), lwork, rwork, lrwork, iwork, liwork, info);
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaEig<double>(
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, double *A, magma_int_t lda,
double *wr, double *wi, double *VL, magma_int_t ldvl,
double *VR, magma_int_t ldvr, double *work, magma_int_t lwork, magma_int_t *info) {
MagmaStreamSyncGuard guard;
magma_dgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info);
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaEig<float>(
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, float *A, magma_int_t lda,
float *wr, float *wi, float *VL, magma_int_t ldvl,
float *VR, magma_int_t ldvr, float *work, magma_int_t lwork, magma_int_t *info) {
MagmaStreamSyncGuard guard;
magma_sgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info);
AT_CUDA_CHECK(cudaGetLastError());
}

template<>
void magmaSvd<double>(
@@ -1783,6 +1810,126 @@ std::tuple<Tensor, Tensor> _symeig_helper_cuda(const Tensor& self, bool eigenvec
}
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// magmaEig uses a hybrid CPU-GPU algorithm, which takes and returns CPU
// memory. So, we accept a GPU tensor, copy it to CPU memory, and later copy
// the returned values from CPU to GPU. See also magmaSymeig, which uses a
// similar approach.

template <typename scalar_t>
static void apply_eig(const Tensor& self, bool eigenvectors, Tensor& out_eigvals, Tensor& out_eigvecs,
int64_t *info_ptr) {
#ifndef USE_MAGMA
TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorch with MAGMA. "
"Either transfer the tensor to the CPU before calling torch.eig or recompile with MAGMA.");
#else
TORCH_INTERNAL_ASSERT(self.device() == at::kCPU, "Internal error: apply_eig needs a CPU tensor");
magma_vec_t jobvr = eigenvectors ? MagmaVec : MagmaNoVec;
magma_int_t n = magma_int_cast(self.size(-1), "n");
auto self_data = self.data_ptr<scalar_t>();

auto out_eigvals_data = out_eigvals.data_ptr<scalar_t>();
scalar_t *wr = out_eigvals_data;
scalar_t *wi = out_eigvals_data+n;

scalar_t *vr_data = NULL;
magma_int_t ldvr = 1;
if (jobvr == MagmaVec)
{
vr_data = out_eigvecs.data_ptr<scalar_t>();
ldvr = n;
}

if (n > 0) {
// call magmaEig once to get the optimal size of work_data
scalar_t wkopt;
magma_int_t info;
magmaEig<scalar_t>(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info);
magma_int_t lwork = (magma_int_t) wkopt;

// call it a 2nd time to do the actual work
scalar_t *work_data = nullptr;
ALLOCATE_ARRAY(work_data, scalar_t, lwork);
magmaEig<scalar_t>(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info);
*info_ptr = info;
}
#endif
}

/*
* Internal helper; like eig_cuda but:
* 1. assume that self is a square matrix of side "n"
* 2. return CPU tensors (because this is what magmaEig returns), which will be copied to GPU memory
* by the caller
*/
static std::tuple<Tensor,Tensor> eig_cuda_helper(const Tensor& self, int64_t n, bool eigenvectors) {
// copy self to pinned CPU memory
auto self_working_copy = at::empty_strided(
{n, n}, // square matrix
{1, n}, // column-ordered, as magmaEig expects
at::TensorOptions(at::kCPU).dtype(self.dtype()).pinned_memory(true));
self_working_copy.copy_(self);

// tensors holding the results. We use empty_strided to make them column-ordered
auto options = self.options().device(at::kCPU).memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto out_eigvals = at::empty_strided({n, 2}, {1, n}, options);
auto out_eigvecs = eigenvectors
? at::empty_strided({n, n}, {1, n}, options)
: Tensor();

int64_t info;
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cuda", [&]{
apply_eig<scalar_t>(self_working_copy, eigenvectors, out_eigvals, out_eigvecs, &info);
});
singleCheckErrors(info, "eig_cuda");

return std::tuple<Tensor, Tensor>(out_eigvals, out_eigvecs);
}

std::tuple<Tensor&, Tensor&> eig_cuda_out(Tensor& e, Tensor& v, const Tensor& self, bool eigenvectors) {
TORCH_CHECK(self.dim() == 2, "Expected a two-dimensional input but got ", self.dim(), " dimensions");
TORCH_CHECK(e.dtype() == self.dtype(), "Expected 'e' to have dtype ", self.dtype(), " but got ", e.dtype());
if (eigenvectors)
TORCH_CHECK(v.dtype() == self.dtype(), "Expected 'v' to have dtype ", self.dtype(), " but got ", v.dtype());
squareCheckInputs(self);
int64_t n = self.size(-1);

at::native::resize_output(e, {n, 2});
if (eigenvectors) {
at::native::resize_output(v, self.sizes());
}

// optimization: if self is empty, we can immediately return the empty
// GPU tensors, instead of getting empty CPU tensors from eig_cuda_helper
// and copying them to GPU
if (self.numel() == 0) {
return std::tuple<Tensor&, Tensor&>(e, v);
}

Tensor cpu_vals, cpu_vecs;
std::tie(cpu_vals, cpu_vecs) = eig_cuda_helper(self, n, eigenvectors);
e.copy_(cpu_vals);
if (eigenvectors) {
v.copy_(cpu_vecs);
}
return std::tuple<Tensor&, Tensor&>(e, v);
}

std::tuple<Tensor,Tensor> eig_cuda(const Tensor& self, bool eigenvectors) {
TORCH_CHECK(self.dim() == 2, "Expected a two-dimensional input but got ", self.dim(), " dimensions");
squareCheckInputs(self);
int64_t n = self.size(-1);

Tensor e, v;
e = at::empty({n, 2}, self.options());
if (eigenvectors) {
v = at::empty({n, n}, self.options());
}
eig_cuda_out(e, v, self, eigenvectors);
return std::tuple<Tensor, Tensor>(e, v);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// This function computes eigenvalues 'w' and eigenvectors 'v' of the tensor 'self'
@@ -1797,7 +1944,7 @@ std::tuple<Tensor, Tensor> _syevd_helper_cuda(const Tensor& self, bool compute_e
bool upper = uplo == 'U' ? true : false;
return _symeig_helper_cuda(self, compute_eigenvectors, upper);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<typename scalar_t>
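Because MAGMA's geev is a hybrid CPU-GPU routine, the new eig_cuda path above copies the input to pinned CPU memory and copies the results back. From Python the interface matches the existing CPU implementation of torch.eig: eigenvalues come back as an n x 2 tensor of (real, imaginary) parts, plus the right eigenvectors when requested. A quick sketch (assuming a MAGMA-enabled CUDA build):

```python
import torch

a = torch.randn(4, 4, device="cuda")

# Eigenvalues only: e has shape (4, 2); columns are the real and imaginary parts.
e, _ = torch.eig(a)

# Eigenvalues and right eigenvectors.
e, v = torch.eig(a, eigenvectors=True)
print(e.shape, v.shape)                  # torch.Size([4, 2]) torch.Size([4, 4])
```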
