Merge pull request PaddlePaddle#1 from PaddlePaddle/develop
fetch remote
BillDior committed Dec 28, 2020
2 parents 53a7bc2 + fdf8c72 commit e9febda
Showing 67 changed files with 2,809 additions and 403 deletions.
126 changes: 126 additions & 0 deletions docs/develop_guides/add_hardware.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/index.rst
@@ -105,6 +105,7 @@ Welcome to Paddle-Lite's documentation!
develop_guides/add_operation
develop_guides/add_layout
develop_guides/add_new_pass
develop_guides/add_hardware

.. toctree::
:maxdepth: 1
2 changes: 2 additions & 0 deletions lite/CMakeLists.txt
@@ -63,6 +63,7 @@ if (WITH_TESTING)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "mobilenet_v1_int8_for_rockchip_npu.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "mobilenet_v1_int8_for_imagination_nna.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "fast_rcnn_fluid184.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ocr_rec_quant_mul_lstm_for_arm.tar.gz")
else()
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "GoogleNet_inference.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "step_rnn.tar.gz")
@@ -81,6 +82,7 @@ if (WITH_TESTING)
# data
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ILSVRC2012_small.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "bert_data.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL_FOR_UNITTESTS} "ocr_rec_img_txt.tar.gz")
if (NOT "${LITE_BAIDU_XPU_INTERNAL_URL_FOR_UNITTESTS}" STREQUAL "")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_BAIDU_XPU_INTERNAL_URL_FOR_UNITTESTS} "mmdnn_data.tar.gz")
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_BAIDU_XPU_INTERNAL_URL_FOR_UNITTESTS} "content_dnn_data.tar.gz")
2 changes: 1 addition & 1 deletion lite/api/android/jni/src/com/baidu/paddle/lite/Tensor.java
@@ -21,7 +21,7 @@ public class Tensor {

/**
 * Java doesn't have pointers. To manage the life cycle of the underlying
 * C++ Tensor object, we use a long value to hold it.
 */
private long cppTensorPointer;

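The long-as-pointer pattern this comment describes is the standard JNI idiom: the native side allocates the C++ object and hands its address back as a jlong, which Java stashes in a field. A minimal sketch follows; the class and function names are illustrative, not Paddle-Lite's actual bindings.

// Minimal sketch of the long-as-pointer JNI idiom described above.
// Names are hypothetical, not Paddle-Lite's actual JNI bindings.
#include <jni.h>

namespace demo {
struct Tensor { /* tensor state lives here */ };
}  // namespace demo

extern "C" {

// Allocate the C++ object and return its address; the Java side stores the
// returned value in a field such as `cppTensorPointer`.
JNIEXPORT jlong JNICALL
Java_com_example_Tensor_newCppTensor(JNIEnv*, jobject) {
  return reinterpret_cast<jlong>(new demo::Tensor());
}

// Cast the stored handle back and free the C++ object; the Java side must
// call this exactly once when it is done with the Tensor.
JNIEXPORT void JNICALL
Java_com_example_Tensor_deleteCppTensor(JNIEnv*, jobject, jlong handle) {
  delete reinterpret_cast<demo::Tensor*>(handle);
}

}  // extern "C"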
1 change: 1 addition & 0 deletions lite/api/paddle_use_passes.h
@@ -74,3 +74,4 @@ USE_MIR_PASS(__xpu__conv2d_fuse_pass);
USE_MIR_PASS(__xpu__conv2d_link_previous_out_max_pass);
USE_MIR_PASS(__xpu__sfa_head_meanstd_fuse_pass);
USE_MIR_PASS(__xpu__sfa_head_moment_fuse_pass);
USE_MIR_PASS(__xpu__softmax_topk_fuse_pass);
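USE_MIR_PASS entries like these exist so the linker does not strip each pass's static registration out of the final binary. A rough sketch of the registration-touch idiom such macros typically expand to is below; it is illustrative only, and none of the names reflect Paddle-Lite's actual macro internals.

// Rough sketch of the registration-touch idiom behind USE_*-style macros.
// All names here are hypothetical, not Paddle-Lite's actual internals.

// Normally defined in the pass's own translation unit; its side effect
// registers the pass in a global registry.
bool touch_my_fuse_pass() {
  // ... register the pass under its name in a global registry ...
  return true;
}

// What the use-site macro expands to: referencing the touch symbol forces
// the linker to keep the pass's object file (and thus its registration).
#define USE_PASS_SKETCH(name__)   \
  extern bool touch_##name__();   \
  static bool use_##name__ = touch_##name__();

USE_PASS_SKETCH(my_fuse_pass)

int main() { return 0; }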
26 changes: 26 additions & 0 deletions lite/api/test_helper.h
@@ -22,6 +22,10 @@
#endif
#include <time.h>
#include <cmath>
#include <fstream>
#include <string>
#include <vector>
#include "lite/utils/cp_logging.h"

// for eval
DEFINE_string(model_dir, "", "model dir");
@@ -74,5 +78,27 @@ double compute_standard_deviation(const T* in,
return sqrt(variance);
}

void ReadTxtFile(const std::string& file_path, float* dest, int num) {
CHECK(!file_path.empty());
CHECK(dest != nullptr);
std::ifstream ifs(file_path);
if (!ifs.is_open()) {
LOG(FATAL) << "open file error:" << file_path;
}
for (int i = 0; i < num; i++) {
ifs >> dest[i];
}
ifs.close();
}

template <typename T>
T ShapeProduction(const std::vector<T>& shape) {
T num = 1;
for (auto i : shape) {
num *= i;
}
return num;
}

} // namespace lite
} // namespace paddle
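A usage sketch for the two helpers added above: ShapeProduction folds a shape into an element count, and ReadTxtFile fills a buffer from a whitespace-separated text file. The file name and shape here are hypothetical, and "lite/api/test_helper.h" is assumed to be included.

// Usage sketch for the helpers above (hypothetical file name and shape).
#include <cstdint>
#include <vector>

void FillInputFromTxt() {
  std::vector<int64_t> shape{1, 3, 32, 100};  // e.g. an OCR recognition input
  int64_t num = paddle::lite::ShapeProduction(shape);  // 1*3*32*100 = 9600
  std::vector<float> buf(num);
  // Reads `num` whitespace-separated floats into buf; LOG(FATAL)s if the
  // file cannot be opened.
  paddle::lite::ReadTxtFile("input_0.txt", buf.data(),
                            static_cast<int>(num));
}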
43 changes: 41 additions & 2 deletions lite/backends/arm/math/activation.cc
@@ -720,11 +720,50 @@ void act_hard_swish<float>(const float* din,
float scale,
float offset,
int threads) {
int nums_per_thread = size / threads;
int remain = size - nums_per_thread * threads;
int neon_loop_cnt_dim4 = nums_per_thread >> 2;
int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);

const float* ptr_in = din;
float* ptr_out = dout;
float scale_r = 1. / scale;
float32x4_t scale_v, offset_v, threshold_v, zero;
offset_v = vdupq_n_f32(offset);
scale_v = vdupq_n_f32(scale_r);
zero = vdupq_n_f32(0.);
threshold_v = vdupq_n_f32(threshold);

#pragma omp parallel for
for (int i = 0; i < threads; i++) {
const float* ptr_in_thread = ptr_in + i * nums_per_thread;
float* ptr_out_thread = ptr_out + i * nums_per_thread;
for (int j = 0; j < neon_loop_cnt_dim4; j++) {
float32x4_t in = vld1q_f32(ptr_in_thread);
float32x4_t in_add_offset = vaddq_f32(in, offset_v);
float32x4_t tmp1 = vmaxq_f32(zero, in_add_offset);
float32x4_t tmp2 = vminq_f32(threshold_v, tmp1);
float32x4_t tmp3 = vmulq_f32(scale_v, in);
float32x4_t tmp4 = vmulq_f32(tmp2, tmp3);
vst1q_f32(ptr_out_thread, tmp4);
ptr_in_thread += 4;
ptr_out_thread += 4;
}

for (int j = 0; j < neon_loop_remain_dim4; j++) {
ptr_out_thread[0] =
std::min(std::max(0.f, ptr_in_thread[0] + offset), threshold) *
ptr_in_thread[0] * scale_r;
ptr_in_thread++;
ptr_out_thread++;
}
}

ptr_out = dout + threads * nums_per_thread;
ptr_in = din + threads * nums_per_thread;
for (int i = 0; i < remain; i++) {
ptr_out[0] = std::min(std::max(0.f, ptr_in[0] + offset), threshold) *
             ptr_in[0] * scale_r;
ptr_in++;
ptr_out++;
}
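The vectorized loop above computes hard-swish, out = x * min(max(0, x + offset), threshold) / scale, splitting the work across threads and handling four lanes per NEON iteration plus scalar tails. As a correctness baseline for the NEON path, a minimal scalar version of the same formula:

// Scalar reference for the vectorized hard_swish above, implementing
// out[i] = in[i] * min(max(0, in[i] + offset), threshold) / scale.
#include <algorithm>

void hard_swish_ref(const float* din, float* dout, int size,
                    float threshold, float scale, float offset) {
  const float scale_r = 1.f / scale;
  for (int i = 0; i < size; ++i) {
    const float x = din[i];
    dout[i] = x * std::min(std::max(x + offset, 0.f), threshold) * scale_r;
  }
}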
49 changes: 49 additions & 0 deletions lite/backends/arm/math/col_im_transform.cc
@@ -74,6 +74,55 @@ void col2im<float>(const float* data_col,
}
}

template <>
void col2im<int32_t>(const int32_t* data_col,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h0,
const int pad_h1,
const int pad_w0,
const int pad_w1,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
int32_t* data_im) {
memset(data_im, 0, height * width * channels * sizeof(int32_t));
const int output_h =
(height + pad_h0 + pad_h1 - (dilation_h * (kernel_h - 1) + 1)) /
stride_h +
1;
const int output_w =
(width + pad_w0 + pad_w1 - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
1;
const int channel_size = height * width;
for (int channel = channels; channel--; data_im += channel_size) {
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_h0 + kernel_row * dilation_h;
for (int output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
data_col += output_w;
} else {
int input_col = -pad_w0 + kernel_col * dilation_w;
for (int output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
data_im[input_row * width + input_col] += *data_col;
}
data_col++;
input_col += stride_w;
}
}
input_row += stride_h;
}
}
}
}
}

} // namespace math
} // namespace arm
} // namespace lite
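The new int32_t specialization mirrors the float col2im above: it accumulates a column buffer of channels * kernel_h * kernel_w values per output position back into the channels * height * width image, summing overlapping contributions. A small usage sketch with illustrative shapes, assuming the col2im declaration is in scope:

// Usage sketch for col2im<int32_t> (shapes are illustrative).
#include <cstdint>
#include <vector>

void Col2ImExample() {
  const int channels = 8, height = 16, width = 16;
  const int kernel_h = 3, kernel_w = 3;
  const int pad = 1, stride = 1, dilation = 1;
  // Same output-size formula the implementation uses.
  const int output_h =
      (height + 2 * pad - (dilation * (kernel_h - 1) + 1)) / stride + 1;
  const int output_w =
      (width + 2 * pad - (dilation * (kernel_w - 1) + 1)) / stride + 1;

  // One column of channels * kernel_h * kernel_w values per output position.
  std::vector<int32_t> data_col(channels * kernel_h * kernel_w * output_h *
                                output_w);
  std::vector<int32_t> data_im(channels * height * width);

  paddle::lite::arm::math::col2im<int32_t>(
      data_col.data(), channels, height, width, kernel_h, kernel_w,
      /*pad_h0=*/pad, /*pad_h1=*/pad, /*pad_w0=*/pad, /*pad_w1=*/pad,
      stride, stride, dilation, dilation, data_im.data());
}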
