Fixed svd for empty matrices
IvanYashchuk committed Jan 26, 2021
1 parent f7b339d commit a877b8c
Showing 4 changed files with 106 additions and 106 deletions.
44 changes: 21 additions & 23 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -1490,6 +1490,8 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT,
int info;
auto m = self.size(-2);
auto n = self.size(-1);
auto lda = std::max<int64_t>(1, m);
auto ldvt = std::max<int64_t>(1, n);
auto mn = std::min(m, n);
Tensor iwork = at::empty({8 * mn}, at::kInt);
auto iwork_data = iwork.data_ptr<int>();
@@ -1508,7 +1510,7 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT,
// and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
int lwork = -1;
scalar_t wkopt;
lapackSvd<scalar_t, value_t>(jobz, m, n, self_data, m, S_data, U_data, m, VT_data, n, &wkopt, lwork, rwork_data, iwork_data, &info);
lapackSvd<scalar_t, value_t>(jobz, m, n, self_data, lda, S_data, U_data, lda, VT_data, ldvt, &wkopt, lwork, rwork_data, iwork_data, &info);
lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));
Tensor work = at::empty({lwork}, self.options());
auto work_data = work.data_ptr<scalar_t>();
@@ -1520,8 +1522,8 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT,
scalar_t* VT_working_ptr = &VT_data[i * VT_stride];

// Compute S, U (optionally) and VT (optionally)
lapackSvd<scalar_t, value_t>(jobz, m, n, self_working_ptr, m,
S_working_ptr, U_working_ptr, m, VT_working_ptr, n, work_data, lwork, rwork_data, iwork_data, &info);
lapackSvd<scalar_t, value_t>(jobz, m, n, self_working_ptr, lda,
S_working_ptr, U_working_ptr, lda, VT_working_ptr, ldvt, work_data, lwork, rwork_data, iwork_data, &info);
infos[i] = info;
if (info != 0) {
return;
@@ -1540,31 +1542,27 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
Tensor U_working_copy, S_working_copy, VT_working_copy;
std::tie(U_working_copy, S_working_copy, VT_working_copy) = _create_U_S_VT(self, some, compute_uv);

if (self.numel() > 0) {
auto self_working_copy = cloneBatchedColumnMajor(self);

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cpu", [&]{
apply_svd<scalar_t>(self_working_copy, U_working_copy, S_working_copy, VT_working_copy, jobz, infos);
});
auto self_working_copy = cloneBatchedColumnMajor(self);

if (self.dim() > 2) {
batchCheckErrors(infos, "svd_cpu");
} else {
singleCheckErrors(infos[0], "svd_cpu");
}
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cpu", [&]{
apply_svd<scalar_t>(self_working_copy, U_working_copy, S_working_copy, VT_working_copy, jobz, infos);
});

if (compute_uv) {
if (some) {
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}
} else {
VT_working_copy.zero_();
U_working_copy.zero_();
}
if (self.dim() > 2) {
batchCheckErrors(infos, "svd_cpu");
} else {
U_working_copy.zero_();
singleCheckErrors(infos[0], "svd_cpu");
}

if (!compute_uv) {
VT_working_copy.zero_();
U_working_copy.zero_();
}

if (some) {
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}

// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
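Taken together, the CPU changes clamp lda and ldvt to at least 1 (LAPACK requires lda >= max(1, m) even when m is 0, so this is what lets the underlying SVD routine accept zero-sized inputs) and drop the numel() > 0 guard so empty matrices flow through the same error-check, zeroing, narrowing, and transpose steps. Below is a rough Python sketch of that reordered tail of _svd_helper_cpu; the standalone helper name svd_postprocess is illustrative and not part of the commit.

import torch

def svd_postprocess(U, S, VT, some, compute_uv, k):
    # Zero the factors when they were not requested, narrow VT for the
    # reduced ("some") factorization, then expose V = VT^T because
    # torch.svd returns V rather than VT.
    if not compute_uv:
        VT.zero_()
        U.zero_()
    if some:
        VT = VT.narrow(-2, 0, k)
    return U, S, VT.transpose(-2, -1)

For an empty input this produces correctly shaped zero-sized factors rather than returning early with untouched placeholders.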
77 changes: 37 additions & 40 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -2158,6 +2158,8 @@ AT_ERROR("svd: MAGMA library not found in "

magma_int_t m = magma_int_cast(self.size(-2), "m");
magma_int_t n = magma_int_cast(self.size(-1), "n");
auto lda = std::max<magma_int_t>(1, m);
auto ldvt = std::max<magma_int_t>(1, n);
auto mn = std::min(m, n);

c10::Storage storage_rwork;
@@ -2178,7 +2180,7 @@ AT_ERROR("svd: MAGMA library not found in "
// and (batch_size - 1) calls to allocate and deallocate workspace using at::empty()
magma_int_t lwork = -1;
scalar_t wkopt;
magmaSvd<scalar_t, value_t>(jobz, m, n, self_data, m, S_data, U_data, m, VT_data, n, &wkopt, lwork, rwork, iwork, &info);
magmaSvd<scalar_t, value_t>(jobz, m, n, self_data, lda, S_data, U_data, lda, VT_data, ldvt, &wkopt, lwork, rwork, iwork, &info);
lwork = magma_int_cast(real_impl<scalar_t, value_t>(wkopt), "work_size");
scalar_t* work;
ALLOCATE_ARRAY(work, scalar_t, lwork);
@@ -2190,8 +2192,8 @@ AT_ERROR("svd: MAGMA library not found in "
scalar_t* VT_working_ptr = &VT_data[i * VT_stride];

// Compute S, U (optionally), VT (optionally)
magmaSvd<scalar_t, value_t>(jobz, m, n, self_working_ptr, m,
S_working_ptr, U_working_ptr, m, VT_working_ptr, n, work, lwork, rwork, iwork, &info);
magmaSvd<scalar_t, value_t>(jobz, m, n, self_working_ptr, lda,
S_working_ptr, U_working_ptr, lda, VT_working_ptr, ldvt, work, lwork, rwork, iwork, &info);
infos[i] = info;
if (info != 0) {
return;
@@ -2210,47 +2212,42 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_legacy(const Tensor& self, b
Tensor U_working_copy, S_working_copy, VT_working_copy;
std::tie(U_working_copy, S_working_copy, VT_working_copy) = _create_U_S_VT(self, some, compute_uv);

if (self.numel() > 0) {
// The input matrix, U, S and VT have to reside in pinned memory.
// Additionally, the input and U have to be in column major format.
// _create_U_S_VT takes care of part of these requirements (for U, S and VT)
// For the input matrix, these requirements are taken care of below.
// Specify strides
auto self_col_major_strides = at::detail::defaultStrides(self.sizes());
self_col_major_strides[self.dim() - 2] = 1;
self_col_major_strides[self.dim() - 1] = m;
// Create strided tensor in pinned memory
auto self_working_copy = at::empty_strided(self.sizes(), self_col_major_strides,
at::TensorOptions(at::kCPU).dtype(self.dtype()).pinned_memory(true));
self_working_copy.copy_(self);

AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda", [&] {
apply_svd<scalar_t>(self_working_copy, U_working_copy, S_working_copy, VT_working_copy, jobchar, infos);
});

if (self.dim() > 2) {
batchCheckErrors(infos, "svd_cuda");
} else {
singleCheckErrors(infos[0], "svd_cuda");
}
// The input matrix, U, S and VT have to reside in pinned memory.
// Additionally, the input and U have to be in column major format.
// _create_U_S_VT takes care of part of these requirements (for U, S and VT)
// For the input matrix, these requirements are taken care of below.
// Specify strides
auto self_col_major_strides = at::detail::defaultStrides(self.sizes());
self_col_major_strides[self.dim() - 2] = 1;
self_col_major_strides[self.dim() - 1] = m;
// Create strided tensor in pinned memory
auto self_working_copy = at::empty_strided(self.sizes(), self_col_major_strides,
at::TensorOptions(at::kCPU).dtype(self.dtype()).pinned_memory(true));
self_working_copy.copy_(self);

U_working_copy = same_stride_to(U_working_copy, self.options());
S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device()));
VT_working_copy = same_stride_to(VT_working_copy, self.options());
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda", [&] {
apply_svd<scalar_t>(self_working_copy, U_working_copy, S_working_copy, VT_working_copy, jobchar, infos);
});

if (compute_uv) {
if (some) {
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}
} else {
VT_working_copy.zero_();
U_working_copy.zero_();
}
if (self.dim() > 2) {
batchCheckErrors(infos, "svd_cuda");
} else {
U_working_copy = same_stride_to(U_working_copy, self.options()).zero_();
S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device()));
VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_();
singleCheckErrors(infos[0], "svd_cuda");
}

U_working_copy = same_stride_to(U_working_copy, self.options());
S_working_copy = same_stride_to(S_working_copy, S_working_copy.options().device(self.device()));
VT_working_copy = same_stride_to(VT_working_copy, self.options());

if (!compute_uv) {
VT_working_copy.zero_();
U_working_copy.zero_();
}

if (some) {
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}

// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
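The legacy MAGMA path gets the same lda/ldvt clamp and still stages the input through a pinned, column-major working copy. The sketch below restates, in Python, how those batched column-major strides are assembled; the helper name is illustrative and only mirrors the two stride assignments in the diff.

def batched_column_major_strides(sizes):
    # Start from the default contiguous (row-major) strides, as
    # at::detail::defaultStrides would produce them.
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * sizes[i + 1]
    # Make each matrix column-major: unit stride down a column and a
    # stride of m between columns; the batch strides stay as they were.
    m = sizes[-2]
    strides[-2] = 1
    strides[-1] = m
    return strides

print(batched_column_major_strides([2, 3, 4]))  # [12, 1, 3]

Each 3x4 matrix then occupies 12 contiguous elements stored column by column, which is the layout the batched LAPACK-style call expects.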
48 changes: 25 additions & 23 deletions aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu
@@ -165,7 +165,9 @@ inline static void _apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor&

int batchsize = cuda_int_cast(batchCount(self), "batch size");
int m = cuda_int_cast(self.size(-2), "m");
int n = cuda_int_cast(self.size(-1), "n");
int n = cuda_int_cast(self.size(-1), "n");
int lda = std::max<int>(1, m);
int ldvt = std::max<int>(1, n);

for(int i = 0; i < batchsize; i++){
// gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU
@@ -179,12 +181,12 @@ inline static void _apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor&
at::cuda::solver::gesvdj<scalar_t>(
handle, jobz, /*econ=*/ some ? 1 : 0, m, n,
self_data + i * self_stride,
m,
lda,
S_data + i * S_stride,
U_data + i * U_stride,
m,
lda,
VT_data + i * VT_stride,
n,
ldvt,
infos.data_ptr<int>() + i,
gesvdj_params
);
@@ -223,7 +225,9 @@ inline static void _apply_svd_lib_gesvdjBatched(const Tensor& self, Tensor& U, T

int batchsize = cuda_int_cast(batchCount(self), "batch size");
int m = cuda_int_cast(self.size(-2), "m");
int n = cuda_int_cast(self.size(-1), "n");
int n = cuda_int_cast(self.size(-1), "n");
int lda = std::max<int>(1, m);
int ldvt = std::max<int>(1, n);

TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got "
"m = ", m, " n = ", n);
@@ -238,7 +242,7 @@ inline static void _apply_svd_lib_gesvdjBatched(const Tensor& self, Tensor& U, T
auto handle = at::cuda::getCurrentCUDASolverDnHandle();
auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
at::cuda::solver::gesvdjBatched<scalar_t>(
handle, jobz, m, n, self_data, m, S_data, U_data, m, VT_data, n,
handle, jobz, m, n, self_data, lda, S_data, U_data, lda, VT_data, ldvt,
infos.data_ptr<int>(), gesvdj_params, batchsize
);

@@ -273,25 +277,23 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cuda_lib(const Tensor& self, bool
_create_U_S_VT(self, some, compute_uv, /* svd_use_cusolver = */ true);
// U, S, V working copies are already column majored now

if (self.numel() > 0) {
// heuristic for using `gesvdjBatched` over `gesvdj`
if (m <= 32 && n <= 32 && batch_size > 1 && (!some || m == n)) {
apply_svd_lib_gesvdjBatched(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv);
} else {
apply_svd_lib_gesvdj(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv, some);
}
// heuristic for using `gesvdjBatched` over `gesvdj`
if (m <= 32 && n <= 32 && batch_size > 1 && (!some || m == n)) {
apply_svd_lib_gesvdjBatched(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv);
} else {
apply_svd_lib_gesvdj(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv, some);
}

// A device-host sync will be performed.
batchCheckErrors(infos, "svd_cuda");
// A device-host sync will be performed.
batchCheckErrors(infos, "svd_cuda");

if (compute_uv) {
if (some) {
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}
} else {
VT_working_copy.zero_();
U_working_copy.zero_();
}
if (!compute_uv) {
VT_working_copy.zero_();
U_working_copy.zero_();
}

if (some) {
VT_working_copy = VT_working_copy.narrow(-2, 0, k);
}

// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
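The cuSOLVER path applies the same clamping and keeps its heuristic for choosing gesvdjBatched over per-matrix gesvdj, now run unconditionally. A compact restatement of that predicate, with a hypothetical helper name, under the constraint asserted in this file that the batched kernel only handles matrices up to 32x32:

def use_gesvdj_batched(m, n, batch_size, some):
    # Presumably the batched kernel lacks an economy ("some") mode, so its
    # full factors only match the reduced ones when the matrix is square;
    # otherwise fall back to per-matrix gesvdj.
    return m <= 32 and n <= 32 and batch_size > 1 and (not some or m == n)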
43 changes: 23 additions & 20 deletions test/test_linalg.py
@@ -1768,27 +1768,30 @@ def run_test(dims, some, compute_uv):

# test non-contiguous
x = torch.randn(*dims, dtype=dtype, device=device)
n_dim = len(dims)
# Reverse the batch dimensions and the matrix dimensions and then concat them
x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2))
assert not x.is_contiguous(), "x is intentionally non-contiguous"
resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv)
if compute_uv:
if some:
x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1)))
self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T')
if x.numel() > 0:
n_dim = len(dims)
# Reverse the batch dimensions and the matrix dimensions and then concat them
x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2))
assert not x.is_contiguous(), "x is intentionally non-contiguous"
resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv)
if compute_uv:
if some:
x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1)))
self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T')
else:
narrow_u = resu[..., :min(*dims[-2:])]
narrow_v = resv[..., :min(*dims[-2:])]
x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1)))
self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T')
else:
narrow_u = resu[..., :min(*dims[-2:])]
narrow_v = resv[..., :min(*dims[-2:])]
x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1)))
self.assertEqual(x, x_recon, atol=1e-8, rtol=0, msg='Incorrect reconstruction using U @ diag(S) @ V.T')
else:
_, singvals, _ = torch.svd(x, compute_uv=True)
self.assertEqual(singvals, ress, msg='Singular values mismatch')
self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero')
self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero')

shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices
_, singvals, _ = torch.svd(x, compute_uv=True)
self.assertEqual(singvals, ress, msg='Singular values mismatch')
self.assertEqual(resu, torch.zeros_like(resu), msg='U not zero')
self.assertEqual(resv, torch.zeros_like(resv), msg='V not zero')

shapes = [(0, 0), (5, 0), (0, 5), # empty matrices
(0, 0, 0), (0, 5, 5), (0, 5, 3), # zero batch dimension
(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices
(7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices
(3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices
for dims, some, compute_uv in product(shapes, [True, False], [True, False]):
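The test now covers empty matrices and zero-sized batch dimensions, and it only runs the non-contiguous reconstruction check when x.numel() > 0, presumably because a zero-element tensor reports as contiguous and the assert would fail. A condensed illustration of the shapes the new cases imply; the default dtype/device and the post-fix behavior are assumptions here.

import torch

x = torch.randn(5, 0)                    # matrix with zero columns
u, s, v = torch.svd(x)                   # some=True, compute_uv=True by default
print(u.shape, s.shape, v.shape)         # expected: (5, 0), (0,), (0, 0)

xb = torch.randn(0, 5, 3)                # zero-sized batch dimension
ub, sb, vb = torch.svd(xb)
print(ub.shape, sb.shape, vb.shape)      # expected: (0, 5, 3), (0, 3), (0, 3, 3)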