refine tensor.reshape (PaddlePaddle#194)
Superjomn committed Aug 31, 2020
1 parent 18be7b3 commit 59f2b7a
Showing 2 changed files with 15 additions and 7 deletions.
cinn/ir/tensor.cc (11 additions, 0 deletions)
@@ -3,6 +3,7 @@
 #include <cstring>

 #include "cinn/cinn.h"
+#include "cinn/common/arithmatic.h"
 #include "cinn/common/cas.h"
 #include "cinn/common/common.h"
 #include "cinn/common/ir_util.h"
@@ -410,6 +411,16 @@ ir::Tensor _Tensor_::Reshape(const std::vector<Expr> &shape, poly::StageMap stages) {
   auto n = make_shared<_Tensor_>();
   auto selft = Tensor(const_cast<ir::_Tensor_ *>(this));

+  {
+    Expr this_num_elements = Expr(1);
+    for (auto &e : this->shape) this_num_elements = this_num_elements * e;
+
+    Expr num_elements = Expr(1);
+    for (auto &e : shape) num_elements = num_elements * e;
+
+    CHECK(MathIsZero(this_num_elements - num_elements)) << "number of elements mismatch";
+  }
+
   n->name = Context::Global().NewName(name + "_reshape");
   n->shape = shape;
   n->domain = shape;
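Why the new block: a reshape is only meaningful when the source and target shapes cover the same number of elements, so Reshape now multiplies the extents on both sides and checks that their difference simplifies to zero. MathIsZero (presumably provided by the newly included cinn/common/arithmatic.h) lets the comparison go through symbolic simplification rather than requiring syntactically identical products. A minimal standalone C++ sketch of the same invariant, restricted to static integer shapes (illustrative only, not CINN's API):

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative check: a reshape is legal only when the products of the
// dimension extents match. CINN's version does the same on symbolic Expr
// shapes, asking the CAS whether (old_count - new_count) reduces to zero.
static bool ReshapePreservesElementCount(const std::vector<int64_t>& old_shape,
                                         const std::vector<int64_t>& new_shape) {
  int64_t old_count = 1, new_count = 1;
  for (int64_t d : old_shape) old_count *= d;
  for (int64_t d : new_shape) new_count *= d;
  // Mirrors CHECK(MathIsZero(this_num_elements - num_elements)).
  return old_count - new_count == 0;
}

int main() {
  assert(ReshapePreservesElementCount({10, 10, 100}, {100, 100}));    // 10000 == 10000
  assert(!ReshapePreservesElementCount({10, 10, 100}, {10, 10, 10})); // 10000 != 1000, would trip the CHECK
  return 0;
}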
cinn/ir/tensor_test.cc (4 additions, 7 deletions)
@@ -72,7 +72,6 @@ TEST(Tensor, Reshape) {
   auto A1 = A->Reshape({Expr(10), Expr(10), Expr(100)}, stages);
   auto B = Compute(A1->shape, [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; });

-  stages[A1]->ShareBufferWith(stages[A]);
   stages->InsertLazily(B);

   auto func = lang::Lower("fn", stages, {A, B});
@@ -122,13 +121,10 @@ TEST(Tensor, ReshapeCopied) {
   auto A1 = A->ReshapeCopied({Expr(10), Expr(10), Expr(100)}, stages);
   auto B = Compute(A1->shape, [=](Expr i, Expr j, Expr k) { return A1(i, j, k) * 2.f; });

-  stages[A1]->ShareBufferWith(stages[A]);
   stages->InsertLazily(B);

-  auto func = lang::Lower("fn", stages, {A, B});
-
   lang::Module::Builder builder("some_modue", common::DefaultHostTarget());
-  builder.AddFunction(func);
+  auto func = lang::Lower("fn", stages, {A, B}, {}, {}, &builder);

   backends::CodeGenC codegenc(common::DefaultHostTarget());
   codegenc.SetInlineBuiltinCodes(false);
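Design note on the lowering change: instead of lowering first and registering the function with builder.AddFunction(func), the test now passes the Module::Builder into lang::Lower. Presumably this lets Lower attach everything it produces to the module, including the global buffer that backs the reshaped copy in the expected output below; the two extra {} arguments are assumed here to be empty optional parameter lists, since their meaning is not visible in this diff.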
@@ -139,13 +135,14 @@
 #include <cinn_runtime.h>
 #include <stdio.h>
+cinn_buffer_t* _A_copied_2_reshape_3 = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t(), { 10, 10, 100 }, 32/*align*/);
 void fn(void* _args, int32_t num_args)
 {
   const cinn_buffer_t* _A = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[0]));
   cinn_buffer_t* _tensor_4 = cinn_pod_value_to_buffer_p(&(((cinn_pod_value_t*)(_args))[1]));
   cinn_buffer_malloc((void*)(0), _tensor_4);
-  cinn_buffer_malloc((void*)(0), _A);
-  const float* A_copied_2_reshape_3 = ((const float*)(_A->memory));
+  cinn_buffer_malloc((void*)(0), _A_copied_2_reshape_3);
+  const float* A_copied_2_reshape_3 = ((const float*)(_A_copied_2_reshape_3->memory));
   float* tensor_4 = ((float*)(_tensor_4->memory));
   for (int32_t i = 0; i < 10; i += 1) {
     for (int32_t j = 0; j < 10; j += 1) {
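The updated expectation reflects dropping ShareBufferWith: A_copied_2_reshape_3 no longer aliases the input buffer _A but reads from its own statically declared buffer _A_copied_2_reshape_3, which fn now allocates with cinn_buffer_malloc before use.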
