Skip to content

Commit

Permalink
add test for matmul tile (PaddlePaddle#74)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn committed Mar 9, 2020
1 parent 0484b4b commit d24d1b9
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 4 deletions.
87 changes: 87 additions & 0 deletions cinn/backends/codegen_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,93 @@ void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, stru
ASSERT_EQ(Trim(tgt), Trim(out));
}

// This matches output of competitor.
TEST(CodeGenC, matmul_tile) {
using namespace ir;
const int M = 100;
const int K = 200;
const int N = 500;
const int bn = 32;
Placeholder<float> A("A", {M, K});
Placeholder<float> B("B", {K, N});

// C = A * B
lang::Buffer C_buf(Float(32));
Var k(K, "k");

Tensor C_init = Compute(
{M, N}, [&](Var i, Var j) { return Expr(0.f); }, "C_init");

Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return lang::Sum(A(i, k) * B(k, j), k); }, "C", k);
C->Bind(C_buf);
C_init->Bind(C_buf);
// C_init->stage()->ComputeAt(C->stage(), 1);

{
poly::Iterator i_outer, i_inner, j_outer, j_inner;
std::tie(i_outer, i_inner, j_outer, j_inner) = C_init->stage()->Tile(0, 1, bn, bn);
C_init->stage()->Reorder({i_outer, j_outer, i_inner, j_inner});
}

{
poly::Iterator i_outer, i_inner, j_outer, j_inner, k_outer, k_inner;
std::tie(i_outer, i_inner, j_outer, j_inner) = C->stage()->Tile(0, 1, bn, bn);
std::tie(k_outer, k_inner) = C->stage()->Split(poly::Iterator("k"), 4);
C->stage()->Reorder({i_outer, j_outer, i_inner, j_inner, k_outer, k_inner});
}

C_init->stage()->ComputeAt(C->stage(), 3);

// Code gen
auto funcs = Lower("matmul", {A, B, C_init, C});
ASSERT_EQ(funcs.size(), 1UL);

Target target;
target.arch = Target::Arch ::X86;
target.bits = Target::Bit ::k32;
target.os = Target::OS ::Linux;

Module module("module1", target);
module.Append(funcs.front());
module.Append(C_buf);

CodeGenC codegen(target);
auto out = codegen.Compile(module, CodeGenC::OutputKind::CImpl);
std::cout << "codegen C:" << std::endl << out << std::endl;

auto target_out = R"ROC(
#include <cinn_runtime.h>
#include <stdio.h>
cinn_buffer_t* _C = cinn_buffer_t::new_((cinn_device_kind_t)(0)/*target*/, cinn_float32_t());
void matmul(const struct cinn_buffer_t *_A, const struct cinn_buffer_t *_B, struct cinn_buffer_t *_C)
{
cinn_buffer_malloc((void*)(0), _C);
const float* A = (const float*)(cinn_buffer_get_data_const_handle(_A));
const float* B = (const float*)(cinn_buffer_get_data_const_handle(_B));
float* C = (float*)(cinn_buffer_get_data_handle(_C));
float* C_init = (float*)(cinn_buffer_get_data_handle(_C));
for (int32_t i_outer = 0; (i_outer <= 3); i_outer += 1){
for (int32_t j_outer = 0; (j_outer <= 15); j_outer += 1){
for (int32_t i_inner = 0; (i_inner <= min(31, ((-32 * i_outer) + 99))); i_inner += 1){
for (int32_t j_inner = 0; (j_inner <= min(31, ((-32 * j_outer) + 499))); j_inner += 1){
C_init[((((32 * i_outer) + i_inner) * 500) + ((32 * j_outer) + j_inner))] = 0;
for (int32_t k_outer = 0; (k_outer <= 49); k_outer += 1){
for (int32_t k_inner = 0; (k_inner <= 3); k_inner += 1){
C[((((32 * i_outer) + i_inner) * 500) + ((32 * j_outer) + j_inner))] = (C[((((32 * i_outer) + i_inner) * 500) + ((32 * j_outer) + j_inner))] + (A[((((32 * i_outer) + i_inner) * 200) + ((4 * k_outer) + k_inner))] * B[((((4 * k_outer) + k_inner) * 500) + ((32 * j_outer) + j_inner))]));
};
};
};
};
};
};
}
)ROC";

ASSERT_EQ(Trim(target_out), Trim(out));
}

TEST(CodeGenC, matmul_with_packed) {
const int M = 100;
const int K = 20;
Expand Down
8 changes: 4 additions & 4 deletions cinn/poly/stage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,16 @@ bool ComputeAtRelation::IsCompatible(Stage *self) {
CHECK(!self->domain().is_null());
CHECK(!stage->domain().is_null());

CHECK_LE(level, isl_set_dim(self->domain().get(), isl_dim_set));
CHECK_LE(level, isl_set_dim(stage->domain().get(), isl_dim_set));
CHECK_LE(level, isl_set_dim(self->transformed_domain().get(), isl_dim_set));
CHECK_LE(level, isl_set_dim(stage->transformed_domain().get(), isl_dim_set));

std::vector<int> selected_dims;
for (int i = 0; i <= level; i++) {
selected_dims.push_back(i);
}

auto stage_partial_set = SetGetDims(stage->domain(), selected_dims);
auto self_partial_set = SetGetDims(self->domain(), selected_dims);
auto stage_partial_set = SetGetDims(stage->transformed_domain(), selected_dims);
auto self_partial_set = SetGetDims(self->transformed_domain(), selected_dims);

stage_partial_set = isl::manage(isl_set_set_tuple_name(stage_partial_set.release(), ""));
self_partial_set = isl::manage(isl_set_set_tuple_name(self_partial_set.release(), ""));
Expand Down

0 comments on commit d24d1b9

Please sign in to comment.