Fix eigvals_op (PaddlePaddle#12)
* dymf tmp

* add dymf tmp

* local test change

* pull thread pool

* fix conflict

* delete unused log

* local change for mirror 0

* fix dymf

* code clean

* fix code clean

* code clean

* code clean

* fix dymf

* fix dymf

* add endpass optimize

* clean code

* fix endpass optimize

* fix

* fix

* fix eigvals_op

* merge pre-stable

* merge pre-stable

Co-authored-by: yaoxuefeng6 <yaoxuefeng@baidu.com>
Co-authored-by: Thunderbrook <a754913769@163.com>
3 people committed Jun 24, 2022
1 parent 022b54e commit 7134e26
Showing 1 changed file with 26 additions and 26 deletions: paddle/fluid/operators/eigvals_op.h
@@ -94,14 +94,14 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
   w.mutable_data<T>(phi::make_ddim({n_dim << 1}), ctx.GetPlace());
 
   int64_t work_mem = work->memory_size();
-  // int64_t required_work_mem = 3 * n_dim * sizeof(T);
-  // PADDLE_ENFORCE_GE(
-  //     work_mem, 3 * n_dim * sizeof(T),
-  //     platform::errors::InvalidArgument(
-  //         "The memory size of the work tensor in LapackEigvals function "
-  //         "should be at least %" PRId64 " bytes, "
-  //         "but received work\'s memory size = %" PRId64 " bytes.",
-  //         required_work_mem, work_mem));
+  int64_t required_work_mem = 3 * n_dim * sizeof(T);
+  PADDLE_ENFORCE_GE(
+      work_mem, 3 * n_dim * sizeof(T),
+      platform::errors::InvalidArgument(
+          "The memory size of the work tensor in LapackEigvals function "
+          "should be at least %" PRId64 " bytes, "
+          "but received work\'s memory size = %" PRId64 " bytes.",
+          required_work_mem, work_mem));
 
   int info = 0;
   phi::funcs::lapackEig<T>('N', 'N', static_cast<int>(n_dim),
@@ -134,24 +134,24 @@ LapackEigvals(const framework::ExecutionContext& ctx, const Tensor& input,
 
   int64_t work_mem = work->memory_size();
   int64_t n_dim = input.dims()[1];
-  // int64_t required_work_mem = 3 * n_dim * sizeof(T);
-  // PADDLE_ENFORCE_GE(
-  //     work_mem, 3 * n_dim * sizeof(T),
-  //     platform::errors::InvalidArgument(
-  //         "The memory size of the work tensor in LapackEigvals function "
-  //         "should be at least %" PRId64 " bytes, "
-  //         "but received work\'s memory size = %" PRId64 " bytes.",
-  //         required_work_mem, work_mem));
-
-  // int64_t rwork_mem = rwork->memory_size();
-  // int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::dtype::Real<T>);
-  // PADDLE_ENFORCE_GE(
-  //     rwork_mem, required_rwork_mem,
-  //     platform::errors::InvalidArgument(
-  //         "The memory size of the rwork tensor in LapackEigvals function "
-  //         "should be at least %" PRId64 " bytes, "
-  //         "but received rwork\'s memory size = %" PRId64 " bytes.",
-  //         required_rwork_mem, rwork_mem));
+  int64_t required_work_mem = 3 * n_dim * sizeof(T);
+  PADDLE_ENFORCE_GE(
+      work_mem, 3 * n_dim * sizeof(T),
+      platform::errors::InvalidArgument(
+          "The memory size of the work tensor in LapackEigvals function "
+          "should be at least %" PRId64 " bytes, "
+          "but received work\'s memory size = %" PRId64 " bytes.",
+          required_work_mem, work_mem));
+
+  int64_t rwork_mem = rwork->memory_size();
+  int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::dtype::Real<T>);
+  PADDLE_ENFORCE_GE(
+      rwork_mem, required_rwork_mem,
+      platform::errors::InvalidArgument(
+          "The memory size of the rwork tensor in LapackEigvals function "
+          "should be at least %" PRId64 " bytes, "
+          "but received rwork\'s memory size = %" PRId64 " bytes.",
+          required_rwork_mem, rwork_mem));
 
   int info = 0;
   phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
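For reference, the restored checks only assert a lower bound on the scratch buffers handed to the LAPACK eigenvalue routine: the work tensor must hold at least 3 * n_dim values of T, and (for complex T) the rwork tensor at least 2 * n_dim real values. Below is a minimal standalone C++ sketch of that byte-count arithmetic; check_workspace is an illustrative name only (not Paddle code), and plain assert stands in for PADDLE_ENFORCE_GE.

#include <cassert>
#include <complex>
#include <cstdint>

// Illustrative only: mirrors the size arithmetic enforced by the restored
// checks for an n_dim x n_dim input of element type T with real part Real.
template <typename T, typename Real>
void check_workspace(std::int64_t n_dim, std::int64_t work_bytes,
                     std::int64_t rwork_bytes) {
  const std::int64_t required_work =
      3 * n_dim * static_cast<std::int64_t>(sizeof(T));
  const std::int64_t required_rwork =
      (n_dim << 1) * static_cast<std::int64_t>(sizeof(Real));  // 2 * n_dim
  assert(work_bytes >= required_work);    // lower bound on the work tensor
  assert(rwork_bytes >= required_rwork);  // lower bound on the rwork tensor
}

int main() {
  // e.g. a 4 x 4 std::complex<float> input:
  //   work  needs at least 3 * 4 * sizeof(std::complex<float>) = 96 bytes
  //   rwork needs at least 2 * 4 * sizeof(float)               = 32 bytes
  check_workspace<std::complex<float>, float>(4, 96, 32);
  return 0;
}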
