Update on "[ONNX] Use parameter values in onnx shape inference (#49706)"
Adds an additional run of ONNX shape inference after constant folding, since initializers may have changed and affected shape inference.

Differential Revision: [D26050881](https://our.internmc.facebook.com/intern/diff/D26050881)

[ghstack-poisoned]
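To make the motivation concrete, here is a minimal Python sketch of the idea (illustrative only, not code from this PR; the model, file name, and opset below are assumptions): constant folding can turn computed shape values into initializers, and a second shape inference pass can then make use of those values.

import torch
import onnx

class TinyModel(torch.nn.Module):
    def forward(self, x):
        # The reshape target is computable at export time, so constant
        # folding can turn it into an initializer.
        return x.reshape(x.new_ones(2, 8).shape)

# Export with constant folding enabled (hypothetical file name and opset).
torch.onnx.export(TinyModel(), (torch.randn(4, 4),), "tiny.onnx",
                  do_constant_folding=True, opset_version=12)

# Running shape inference again after folding (what the exporter now does
# internally) lets the folded initializer values inform the inferred shapes.
model = onnx.load("tiny.onnx")
inferred = onnx.shape_inference.infer_shapes(model)
print(inferred.graph.output)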
BowenBao committed Jan 26, 2021
2 parents 78dcad9 + a51b9a8 commit adddcc2
Showing 98 changed files with 2,985 additions and 572 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/lint.yml
@@ -170,6 +170,8 @@ jobs:
# FunctionsManual.cpp is excluded to keep this diff clean. It will be fixed
# in a follow up PR.
# /torch/csrc/generic/*.cpp is excluded because those files aren't actually built.
# deploy/interpreter files are excluded because they use macros and other techniques
# that are not easily converted to accepted C++
python tools/clang_tidy.py \
--verbose \
--paths torch/csrc/ \
@@ -186,6 +188,10 @@ jobs:
-g"-torch/csrc/autograd/FunctionsManual.cpp" \
-g"-torch/csrc/generic/*.cpp" \
-g"-torch/csrc/jit/codegen/cuda/runtime/*" \
-g"-torch/csrc/deploy/interpreter/interpreter.cpp" \
-g"-torch/csrc/deploy/interpreter/interpreter.h" \
-g"-torch/csrc/deploy/interpreter/interpreter_impl.h" \
-g"-torch/csrc/deploy/interpreter/test_main.cpp" \
"$@" > ${GITHUB_WORKSPACE}/clang-tidy-output.txt
cat ${GITHUB_WORKSPACE}/clang-tidy-output.txt
3 changes: 3 additions & 0 deletions .gitignore
@@ -66,6 +66,9 @@ torch/csrc/autograd/generated/*
torch/testing/_internal/generated/annotated_fn_args.py
torch/testing/_internal/data/*.pt
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/deploy/interpreter/cpython
torch/csrc/deploy/interpreter/frozen
torch/csrc/deploy/interpreter/third_party/typing_extensions.py
torch/csrc/generated
torch/csrc/generic/TensorMethods.cpp
torch/csrc/jit/generated/*
11 changes: 11 additions & 0 deletions .jenkins/pytorch/build.sh
@@ -23,6 +23,17 @@ if [[ "$BUILD_ENVIRONMENT" == *-mobile-code-analysis* ]]; then
exec "$(dirname "${BASH_SOURCE[0]}")/build-mobile-code-analysis.sh" "$@"
fi

if [[ "$BUILD_ENVIRONMENT" == *linux-xenial-cuda10.2-cudnn7-py3-gcc7* ]]; then
# Enabling DEPLOY build (embedded torch python interpreter, experimental)
# only on one config for now, can expand later
export USE_DEPLOY=ON

# Deploy feature builds cpython. It requires these packages.
# TODO move this to dockerfile?
sudo apt-get -qq update
sudo apt-get -qq install libffi-dev libbz2-dev libreadline-dev libncurses5-dev libncursesw5-dev libgdbm-dev libsqlite3-dev uuid-dev tk-dev
fi

echo "Python version:"
python --version

8 changes: 8 additions & 0 deletions .jenkins/pytorch/test.sh
@@ -354,6 +354,11 @@ test_vec256() {
fi
}

test_torch_deploy() {
SIMPLE_MODEL_PATH=torch/csrc/deploy/example/simple.pt LIBINTERPRETER_PATH=build/lib/libinterpreter.so build/bin/interpreter_test
assert_git_not_dirty
}

if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
(cd test && python -c "import torch; print(torch.__config__.show())")
(cd test && python -c "import torch; print(torch.__config__.parallel_info())")
@@ -371,6 +376,9 @@ elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
# TODO: run some C++ tests
echo "no-op at the moment"
elif [[ "${BUILD_ENVIRONMENT}" == *-test1 || "${JOB_BASE_NAME}" == *-test1 ]]; then
if [[ "${BUILD_ENVIRONMENT}" == pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-test1 ]]; then
test_torch_deploy
fi
install_torchvision
test_python_shard1
elif [[ "${BUILD_ENVIRONMENT}" == *-test2 || "${JOB_BASE_NAME}" == *-test2 ]]; then
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -919,3 +919,8 @@ endif()

include(cmake/Summary.cmake)
caffe2_print_configuration_summary()

# ---[ Torch Deploy
if(USE_DEPLOY)
add_subdirectory(torch/csrc/deploy)
endif()
64 changes: 58 additions & 6 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -82,6 +82,22 @@ extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, floa
// geev
extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
extern "C" void cgeev_(char *jobvl, char *jobvr, int *n,
std::complex<float> *a, int *lda,
std::complex<float> *w,
std::complex<float> *vl, int *ldvl,
std::complex<float> *vr, int *ldvr,
std::complex<float> *work, int *lwork,
float *rwork,
int *info);
extern "C" void zgeev_(char *jobvl, char *jobvr, int *n,
std::complex<double> *a, int *lda,
std::complex<double> *w,
std::complex<double> *vl, int *ldvl,
std::complex<double> *vr, int *ldvr,
std::complex<double> *work, int *lwork,
double *rwork,
int *info);

// gesdd
extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex<double> *a, int *lda,
@@ -307,14 +323,44 @@ template<> void lapackSyevd<float>(char jobz, char uplo, int n, float *a, int ld
ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info);
}

template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda, double *wr, double *wi, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, int *info) {
template<> void lapackEig<double>(char jobvl, char jobvr, int n, double *a, int lda, double *w, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, double *rwork, int *info) {
// lapack [sd]geev wants two separate output arrays: wr and wi for the real
// and imaginary parts
double *wr = w;
double *wi = w + n;
(void)rwork; // unused
dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
}

template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda, float *wr, float *wi, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, int *info) {
template<> void lapackEig<float>(char jobvl, char jobvr, int n, float *a, int lda, float *w, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, float *rwork, int *info) {
// lapack [sd]geev wants two separate output arrays: wr and wi for the real
// and imaginary parts
float *wr = w;
float *wi = w + n;
(void)rwork; // unused
sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
}

template<> void lapackEig<c10::complex<double>, double>(char jobvl, char jobvr, int n, c10::complex<double> *a, int lda, c10::complex<double> *w, c10::complex<double> *vl, int ldvl, c10::complex<double> *vr, int ldvr, c10::complex<double> *work, int lwork, double *rwork, int *info) {
zgeev_(&jobvl, &jobvr, &n,
reinterpret_cast<std::complex<double>*>(a), &lda,
reinterpret_cast<std::complex<double>*>(w),
reinterpret_cast<std::complex<double>*>(vl), &ldvl,
reinterpret_cast<std::complex<double>*>(vr), &ldvr,
reinterpret_cast<std::complex<double>*>(work), &lwork,
rwork, info);
}

template<> void lapackEig<c10::complex<float>, float>(char jobvl, char jobvr, int n, c10::complex<float> *a, int lda, c10::complex<float> *w, c10::complex<float> *vl, int ldvl, c10::complex<float> *vr, int ldvr, c10::complex<float> *work, int lwork, float *rwork, int *info) {
cgeev_(&jobvl, &jobvr, &n,
reinterpret_cast<std::complex<float>*>(a), &lda,
reinterpret_cast<std::complex<float>*>(w),
reinterpret_cast<std::complex<float>*>(vl), &ldvl,
reinterpret_cast<std::complex<float>*>(vr), &ldvr,
reinterpret_cast<std::complex<float>*>(work), &lwork,
rwork, info);
}

template<> void lapackSvd<c10::complex<double>, double>(char jobz, int m, int n, c10::complex<double> *a, int lda,
double *s, c10::complex<double> *u, int ldu, c10::complex<double> *vt, int ldvt, c10::complex<double> *work, int lwork, double *rwork, int *iwork, int *info) {
zgesdd_(&jobz, &m, &n, reinterpret_cast<std::complex<double>*>(a), &lda, s, reinterpret_cast<std::complex<double>*>(u), &ldu,
@@ -1441,7 +1487,11 @@ std::tuple<Tensor&, Tensor&> eig_out(Tensor& e, Tensor& v, const Tensor& self, b
TORCH_CHECK(v.dtype() == self.dtype(), "Expected 'v' to have dtype ", self.dtype(), " but got ", v.dtype());
int64_t n = self.size(-1);

at::native::resize_output(e, {n, 2});
if (isComplexType(at::typeMetaToScalarType(self.dtype()))) {
at::native::resize_output(e, {n});
} else {
at::native::resize_output(e, {n, 2});
}
if (eigenvectors) {
at::native::resize_output(v, self.sizes());
}
@@ -1566,6 +1616,8 @@ std::tuple<Tensor, Tensor, Tensor> _svd_helper_cpu(const Tensor& self, bool some
VT_working_copy.zero_();
}
// so far we have computed VT, but torch.svd returns V instead. Adjust accordingly.
// Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V.
VT_working_copy = VT_working_copy.conj();
VT_working_copy.transpose_(-2, -1);
return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy);
}
@@ -1596,8 +1648,8 @@ std::tuple<Tensor&, Tensor&, Tensor&> svd_out(Tensor& U, Tensor& S, Tensor& V,
1. the 2nd parameter is bool some=True, which is effectively the opposite
of full_matrices=True
2. svd returns V, while linalg.svd returns VT. To accommodate the
difference, we transpose() V upon return
2. svd returns V, while linalg.svd returns VT = V^T (for real inputs) or VT = V^H (for complex inputs).
To accommodate the difference, we transpose() and conj() V upon return
*/

std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) {
@@ -1608,7 +1660,7 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matr
Tensor U, S, V;
std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv);
if (compute_uv) {
Tensor VT = V.transpose(-2, -1);
Tensor VT = V.conj().transpose(-2, -1);
return std::make_tuple(U, S, VT);
} else {
Tensor empty_U = at::empty({0}, self.options());
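A hedged Python illustration of the V/VT relationship noted in the comments above (a sketch assuming a build that includes this change; the input is arbitrary):

import torch

a = torch.randn(4, 4, dtype=torch.complex128)
_, _, V = torch.svd(a)            # torch.svd returns V
_, _, VT = torch.linalg.svd(a)    # torch.linalg.svd returns VT (= V^H for complex input)
# For complex inputs, VT is the conjugate transpose of V, not just the transpose.
assert torch.allclose(VT, V.conj().transpose(-2, -1))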
4 changes: 2 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.h
@@ -14,8 +14,8 @@ namespace at { namespace native {
// Define per-batch functions to be used in the implementation of batched
// linear algebra operations

template<class scalar_t>
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info);
template<class scalar_t, class value_t=scalar_t>
void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info);

template<class scalar_t>
void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info);
36 changes: 28 additions & 8 deletions aten/src/ATen/native/BatchLinearAlgebraKernel.cpp
@@ -2,6 +2,7 @@
#include <ATen/Dispatch.h>
#include <ATen/native/BatchLinearAlgebra.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/cpu/zmath.h>

#include <TH/TH.h> // for USE_LAPACK

@@ -15,29 +16,38 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec
TORCH_CHECK(false, "Calling torch.eig on a CPU tensor requires compiling ",
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
#else
using value_t = typename c10::scalar_value_type<scalar_t>::type;

char jobvr = eigenvectors ? 'V' : 'N';
int64_t n = self.size(-1);
auto self_data = self.data_ptr<scalar_t>();

auto vals_data = vals_.data_ptr<scalar_t>();
scalar_t* wr = vals_data;
scalar_t* wi = vals_data + n;

scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr<scalar_t>() : nullptr;
int ldvr = eigenvectors ? n : 1;

Tensor rwork;
value_t* rwork_data = nullptr;
if (self.is_complex()) {
ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype()));
rwork = at::empty({n*2}, self.options().dtype(real_dtype));
rwork_data = rwork.data_ptr<value_t>();
}

if (n > 0) {
// call lapackEig once to get the optimal size for work data
scalar_t wkopt;
int info;
lapackEig<scalar_t>('N', jobvr, n, self_data, n, wr, wi,
nullptr, 1, vecs_data, ldvr, &wkopt, -1, &info);
int lwork = static_cast<int>(wkopt);
lapackEig<scalar_t, value_t>('N', jobvr, n, self_data, n, wr,
nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork_data, &info);
int lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));

// call again to do the actual work
Tensor work = at::empty({lwork}, self.dtype());
lapackEig<scalar_t>('N', jobvr, n, self_data, n, wr, wi,
nullptr, 1, vecs_data, ldvr, work.data_ptr<scalar_t>(), lwork, &info);
lapackEig<scalar_t, value_t>('N', jobvr, n, self_data, n, wr,
nullptr, 1, vecs_data, ldvr, work.data_ptr<scalar_t>(), lwork, rwork_data, &info);
*info_ptr = info;
}
#endif
@@ -55,13 +65,23 @@ std::tuple<Tensor, Tensor> eig_kernel_impl(const Tensor& self, bool& eigenvector
self_.copy_(self);

auto options = self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor vals_ = at::empty_strided({n, 2}, {1, n}, options);

// the API is slightly different for the complex vs real case: if the input
// is complex, eigenvals will be a vector of complex values. If the input is real,
// eigenvals will be an (n, 2) matrix containing the real and imaginary parts
// in each column
Tensor vals_;
if (self.is_complex()) {
vals_ = at::empty({n}, options);
} else {
vals_ = at::empty_strided({n, 2}, {1, n}, options);
}
Tensor vecs_ = eigenvectors
? at::empty_strided({n, n}, {1, n}, options)
: Tensor();

int64_t info;
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cpu", [&]{
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cpu", [&]{
apply_eig<scalar_t>(self_, eigenvectors, vals_, vecs_, &info);
});
singleCheckErrors(info, "eig_cpu");
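A hedged Python sketch of the two eigenvalue layouts described in the comment above (assumes a build that includes this change; torch.eig was the relevant API at the time):

import torch

real_in = torch.randn(3, 3, dtype=torch.float64)
e_real, _ = torch.eig(real_in)    # real input: shape (3, 2); columns hold real and imaginary parts,
                                  # matching the packed wr/wi buffer handed to LAPACK above

cplx_in = torch.randn(3, 3, dtype=torch.complex128)
e_cplx, _ = torch.eig(cplx_in)    # complex input: shape (3,), complex eigenvalues
print(e_real.shape, e_cplx.shape)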
12 changes: 5 additions & 7 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -147,16 +147,14 @@ Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {

// If not Hermitian use singular value decomposition, else use eigenvalue decomposition
if (!hermitian) {
// until https://github.com/pytorch/pytorch/issues/45821 is resolved
// svd() returns conjugated V for complex-valued input
Tensor U, S, V_conj;
Tensor U, S, V;
// TODO: replace input.svd with linalg_svd
std::tie(U, S, V_conj) = input.svd();
// using linalg_svd breaks pytorch/xla, see https://github.com/pytorch/xla/issues/2755
std::tie(U, S, V) = input.svd();
Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order
Tensor S_pseudoinv = at::where(S > (rcond.unsqueeze(-1) * max_val), S.reciprocal(), at::zeros({}, S.options())).to(input.dtype());
// computes V @ diag(S_pseudoinv) @ U.T.conj()
// TODO: replace V_conj.conj() -> V once https://github.com/pytorch/pytorch/issues/45821 is resolved
return at::matmul(V_conj.conj() * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
// computes V @ diag(S_pseudoinv) @ U.conj().T
return at::matmul(V * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1));
} else {
Tensor S, U;
std::tie(S, U) = at::linalg_eigh(input);
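A hedged Python sketch of the non-Hermitian branch above, pinv(A) = V @ diag(S_pinv) @ U^H with small singular values zeroed (the rcond value is illustrative; assumes a build that includes this change):

import torch

A = torch.randn(5, 3, dtype=torch.complex128)
U, S, V = torch.svd(A)            # S is real-valued and sorted in descending order
rcond = 1e-15
cutoff = rcond * S.max()
# Invert singular values above the threshold, zero the rest.
S_pinv = torch.where(S > cutoff, S.reciprocal(), torch.zeros((), dtype=S.dtype))
A_pinv = V @ torch.diag(S_pinv.to(A.dtype)) @ U.conj().transpose(-2, -1)
assert torch.allclose(A_pinv, torch.linalg.pinv(A))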
