Skip to content

Commit

Permalink
Merge pull request #2618 from alibaba/feature/sync
Browse files Browse the repository at this point in the history
[MNN:Sync] Sync Internal 2.7.2
  • Loading branch information
jxt1234 committed Oct 20, 2023
2 parents 1bbb3b1 + 476083a commit 6edd12f
Show file tree
Hide file tree
Showing 163 changed files with 12,818 additions and 700 deletions.
7 changes: 4 additions & 3 deletions 3rd_party/OpenCLHeaders/CL/cl2.hpp
Expand Up @@ -7111,16 +7111,17 @@ class CommandQueue : public detail::Wrapper<cl_command_queue>
size_t num_local_workgroups,
const cl_workgroup_qcom *local_workgroups_array,
cl_uint num_events_in_wait_list,
const cl_event *event_wait_list,
cl_event *event)
const vector<Event>* events = NULL,
Event* event = NULL) const
{
cl_event tmp;
cl_int err = detail::errHandler(
::clEnqueueRecordingQCOM(
object_, recording, num_args, arg_array, num_global_offsets,
global_offset_array, num_global_workgroups, global_workgroup_array,
num_local_workgroups, local_workgroups_array, num_events_in_wait_list,
event_wait_list, &tmp),
(events != NULL && num_events_in_wait_list > 0) ? (cl_event*) &events->front() : NULL,
(event != NULL) ? &tmp : NULL),
__ENQUEUE_READ_BUFFER_ERR);

if (event != NULL && err == CL_SUCCESS)
Expand Down
6 changes: 6 additions & 0 deletions CMakeLists.txt
Expand Up @@ -55,6 +55,7 @@ option(MNN_BUILD_CODEGEN "Build with codegen" OFF)
option(MNN_ENABLE_COVERAGE "Build with coverage enable" OFF)
option(MNN_BUILD_PROTOBUFFER "Build with protobuffer in MNN" ON)
option(MNN_BUILD_OPENCV "Build OpenCV api in MNN." OFF)
option(MNN_BUILD_LLM "Build llm library based MNN." OFF)
option(MNN_INTERNAL "Build with MNN internal features, such as model authentication, metrics logging" OFF)
option(MNN_JNI "Build MNN Jni for java to use" OFF)

Expand Down Expand Up @@ -612,6 +613,11 @@ IF(MNN_BUILD_CODEGEN)
include(${CMAKE_CURRENT_LIST_DIR}/codegen/CMakeLists.txt)
ENDIF()

IF(MNN_BUILD_LLM)
# add_definitions(-DMNN_BUILD_LLM)
include(${CMAKE_CURRENT_LIST_DIR}/llm/CMakeLists.txt)
ENDIF()

# NPU
IF(MNN_NPU)
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/source/backend/hiai/)
Expand Down
2 changes: 2 additions & 0 deletions docs/compile/cmake.md
Expand Up @@ -45,6 +45,7 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下:
| MNN_CUDA_PROFILE | 是否打开CUDA profile工具,默认为`OFF` |
| MNN_CUDA_QUANT | 是否打开CUDA 量化文件编译,默认为`OFF` |
| MNN_CUDA_BF16 | 是否打开CUDA Bf16文件编译,默认为`OFF` |
| MNN_CUDA_TUNE_PARAM | 是否打开CUDA TUNE相关文件编译,目前仅支持安培及以上架构,默认为`OFF` |
| MNN_TENSORRT | 是否构建`TensorRT`后端,默认为`OFF` |
| MNN_COREML | 是否构建`CoreML`后端,默认为`OFF` |
| MNN_NNAPI | 是否构建`NNAPI`后端,默认为`OFF` |
Expand Down Expand Up @@ -82,3 +83,4 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下:
| MNN_OPENCV_BENCH | 构建MNN的OpenCV功能是否开启性能benchmark,默认为`OFF` |
| MNN_VULKAN_IMAGE | 构建MNN的Vulkan后端时采用Image内存模式,以便支持FP16和部分移动端上GPU的加速,默认为`ON` |
| MNN_LOW_MEMORY | 是否支持低内存模式,支持低内存模式使用权值量化模型并设置`low_memory`则会使用计算时反量化,默认为`OFF` |
| MNN_BUILD_LLM | 是否构建基于MNN的llm库和demo,默认为`OFF` |
1 change: 1 addition & 0 deletions express/Executor.cpp
Expand Up @@ -535,6 +535,7 @@ void Executor::_makeCache(const std::vector<EXPRP>& expr, bool forceCPU) {
TensorUtils::getDescribe(tensor.get())->quantAttr.reset(new QuantAttr);
auto quant = TensorUtils::getDescribe(tensor.get())->quantAttr.get();
quant->scale = TensorUtils::getDescribe(srcTensor)->quantAttr.get()->scale;
quant->zero = TensorUtils::getDescribe(srcTensor)->quantAttr.get()->zero;
}

TensorUtils::getDescribe(tensor.get())->index = (int)scheduleInfo.allTensors.size();
Expand Down
13 changes: 8 additions & 5 deletions express/MathOp.cpp
Expand Up @@ -329,6 +329,7 @@ VARP _Asin(VARP x)
{
return _Unary(x, UnaryOpOperation_ASIN);
}

/*Computes acos of x element-wise.
Args:
x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
Expand All @@ -344,7 +345,7 @@ VARP _Acos(VARP x)
/*Computes acosh of x element-wise.
Args:
x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
Note: The input of acosh must lie within [1, +inf). The output of acosh will lie within [0, +inf).
Returns:
A variable. Has the same type as x.
*/
Expand All @@ -368,7 +369,7 @@ VARP _Asinh(VARP x)
/*Computes atanh of x element-wise.
Args:
x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
Note: The input of atanh must lie within (-1, 1). The output of atanh will lie within (-inf, +inf).
Returns:
A variable. Has the same type as x.
*/
Expand All @@ -389,6 +390,7 @@ VARP _Cosh(VARP x)
return _Unary(x, UnaryOpOperation_COSH);
}


/*Computes sinh of x element-wise.
Args:
x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
Expand All @@ -404,7 +406,7 @@ VARP _Sinh(VARP x)
/*Computes the Gauss error function of `x` element-wise.
Args:
x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
Note: The input of erf may lie within (-inf, +inf). The output of erf will lie within (-1.0, 1.0).
Returns:
A variable. Has the same type as x.
*/
Expand All @@ -428,7 +430,7 @@ VARP _Erfc(VARP x)
/*Computes the inverse function for erf, for `x` element-wise.
Args:
x: A variable. Must be one of the following types: Halide_Type_Int or Halide_Type_Float
Note: The output of atan will lie within the invertible range of tan, i.e (0.0, pi).
Note: The input of erfinv must lie within (-1, 1).
Returns:
A variable. Has the same type as x.
*/
Expand Down Expand Up @@ -514,6 +516,7 @@ A variable. Has the same type as x.
VARP _Tanh(VARP x) {
return _Unary(x, UnaryOpOperation_TANH);
}

/*Computes sigmoid of x element-wise.
Args:
x: A variable. Must be one of the following types: Halide_Type_Float
Expand All @@ -524,6 +527,7 @@ VARP _Sigmoid(VARP x) {
return _Unary(x, UnaryOpOperation_SIGMOID);
}


/*Computes ((exponential of x) - 1) element-wise.
Args:
x: A variable. Must be one of the following types: Halide_Type_Float
Expand All @@ -534,7 +538,6 @@ VARP _Expm1(VARP x) {
return _Unary(x, UnaryOpOperation_EXPM1);
}


/*Returns x + y element-wise.
Args:
x: A variable. Must be one of the following types:
Expand Down
2 changes: 1 addition & 1 deletion include/MNN/MNNDefine.h
Expand Up @@ -69,6 +69,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \
#define STR(x) STR_IMP(x)
#define MNN_VERSION_MAJOR 2
#define MNN_VERSION_MINOR 7
#define MNN_VERSION_PATCH 1
#define MNN_VERSION_PATCH 2
#define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH)
#endif /* MNNDefine_h */
1 change: 0 additions & 1 deletion include/MNN/expr/MathOp.hpp
Expand Up @@ -61,7 +61,6 @@ MNN_PUBLIC VARP _Atanh(VARP x);
MNN_PUBLIC VARP _Reciprocal(VARP x);
MNN_PUBLIC VARP _Log1p(VARP x);
MNN_PUBLIC VARP _Gelu(VARP x);
//Only one but not in UnaryOPs
MNN_PUBLIC VARP _Tanh(VARP x);
MNN_PUBLIC VARP _Sigmoid(VARP x);
MNN_PUBLIC VARP _Erf(VARP x);
Expand Down
21 changes: 21 additions & 0 deletions llm/CMakeLists.txt
@@ -0,0 +1,21 @@
# Header search path for the llm sources and the demo.
include_directories(${CMAKE_CURRENT_LIST_DIR}/include/)

# Collect every implementation file of the llm library.
file(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/src/*.cpp)

# Library flavour: static with MSVC (Windows), shared elsewhere (Linux/Mac).
if (MSVC)
    add_library(llm STATIC ${SRCS})
else()
    add_library(llm SHARED ${SRCS})
    set_target_properties(llm PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS TRUE)
endif()

# Settings shared by both library flavours.
target_link_libraries(llm ${MNN_DEPS})
target_compile_features(llm PRIVATE cxx_std_17)

# Demo executable built on top of the llm library.
add_executable(llm_demo ${CMAKE_CURRENT_LIST_DIR}/llm_demo.cpp)
target_compile_features(llm_demo PRIVATE cxx_std_17)
target_link_libraries(llm_demo llm)
133 changes: 133 additions & 0 deletions llm/include/llm.hpp
@@ -0,0 +1,133 @@
//
// llm.hpp
//
// Created by MNN on 2023/08/25.
// ZhaodeWang
//

#ifndef LLM_hpp
#define LLM_hpp

#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include <iostream>

#include <MNN/AutoTime.hpp>
#include <MNN/expr/Expr.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/MathOp.hpp>
#include <MNN/expr/NeuralNetWorkOp.hpp>
#include "tokenizer.hpp"

using namespace MNN;
using namespace Express;

class MNN_PUBLIC Llm {
public:
Llm() {
// default tokenier is senrencepiece
tokenizer_.reset(new Sentencepiece);
}
static Llm* createLLM(const std::string& path);
VARP gen_embedding(const std::vector<int>& input_ids);
void load(const std::string& model_dir);
int forward(const std::vector<int>& input_ids);
std::vector<int> tokenizer_encode(const std::string& input_str);
std::string decode(int id);
std::string response(const std::string& input_str, std::ostream* os = &std::cout);
float load_progress() { return load_progress_; }
void reset();
private:
virtual std::vector<int> tokenizer(const std::string& query) = 0;
virtual VARP gen_attention_mask(int seq_len) = 0;
virtual VARP gen_position_ids(int seq_len) = 0;
virtual bool is_stop(int token_id) = 0;
protected:
// model configs
bool is_single_ = false;
int layer_nums_ = 0;
int hidden_size_ = 4096;
std::vector<int> key_value_shape_ = {};
std::string model_name_ = "";
// gen info
int gen_seq_len_ = 0;
int all_seq_len_ = 0;
int max_seq_len_ = 256;
float load_progress_ = 0.f;
// tokenizer
std::unique_ptr<Tokenizer> tokenizer_;
private:
// MNN Modules
std::shared_ptr<Executor::RuntimeManager> runtime_manager_;
std::vector<std::shared_ptr<Module>> modules_;
std::vector<VARP> past_key_values_;
// model dir
std::string model_dir_;
// tokenizer
std::vector<std::string> word_decoder_;
std::unordered_map<std::string, int> word_encoder_;
};

// some llm models
class Chatglm_6b : public Llm {
public:
Chatglm_6b() {
model_name_ = "Chatglm_6b";
layer_nums_ = 28;
key_value_shape_ = {2, 0, 1, 32, 128};
}
private:
virtual std::vector<int> tokenizer(const std::string& query) override;
virtual VARP gen_attention_mask(int seq_len) override;
virtual VARP gen_position_ids(int seq_len) override;
virtual bool is_stop(int token_id) override;
int context_len_ = 0;
};

class Chatglm2_6b : public Llm {
public:
Chatglm2_6b() {
model_name_ = "Chatglm2_6b";
layer_nums_ = 28;
key_value_shape_ = {2, 0, 1, 2, 128};
}
private:
virtual std::vector<int> tokenizer(const std::string& query) override;
virtual VARP gen_attention_mask(int seq_len) override;
virtual VARP gen_position_ids(int seq_len) override;
virtual bool is_stop(int token_id) override;
};


class Qwen_7b : public Llm {
public:
Qwen_7b() {
model_name_ = "Qwen_7b";
layer_nums_ = 32;
key_value_shape_ = {2, 1, 0, 32, 128};
tokenizer_.reset(new Tiktoken);
}
private:
virtual std::vector<int> tokenizer(const std::string& query) override;
virtual VARP gen_attention_mask(int seq_len) override;
virtual VARP gen_position_ids(int seq_len) override;
virtual bool is_stop(int token_id) override;
};

class Llama2_7b : public Llm {
public:
Llama2_7b() {
model_name_ = "Llama2_7b";
layer_nums_ = 32;
key_value_shape_ = {2, 1, 32, 0, 128};
}
private:
virtual std::vector<int> tokenizer(const std::string& query) override;
virtual VARP gen_attention_mask(int seq_len) override;
virtual VARP gen_position_ids(int seq_len) override;
virtual bool is_stop(int token_id) override;
};

#endif // LLM_hpp
87 changes: 87 additions & 0 deletions llm/include/tokenizer.hpp
@@ -0,0 +1,87 @@
//
// tokenizer.hpp
//
// Created by MNN on 2023/09/25.
// ZhaodeWang
//

#ifndef TOKENIZER_hpp
#define TOKENIZER_hpp

#include <vector>
#include <memory>
#include <string>
#include <unordered_map>
#include <iostream>
#include <string_view>

/**
 * Abstract tokenizer interface: load a vocabulary/model file, encode text
 * into token ids, and decode a single id back to text.
 */
class Tokenizer {
public:
    Tokenizer() = default;
    // Implementations are owned and destroyed through Tokenizer pointers,
    // so the destructor must be virtual to run the derived destructor.
    virtual ~Tokenizer() = default;
    // Loads tokenizer data from `filename`; the bool result reports the outcome.
    virtual bool load(const std::string& filename) = 0;
    // Converts `str` into a sequence of token ids.
    virtual std::vector<int> encode(const std::string& str) = 0;
    // Converts a single token id back to its text form.
    virtual std::string decode(int id) = 0;
};

class Sentencepiece : public Tokenizer {
public:
Sentencepiece() = default;
virtual bool load(const std::string& filename) override;
virtual std::vector<int> encode(const std::string& str) override;
virtual std::string decode(int id) override;
private:
enum ModelType {
UNIGRAM = 1,
BPE = 2,
WORD = 3,
CHAR = 4
};
enum PieceType {
NORMAL = 1,
UNKNOWN = 2,
CONTROL = 3,
USER_DEFINED = 4,
UNUSED = 5,
BYTE = 6
};
struct SentencePiece {
std::string piece;
float score;
PieceType type = PieceType::NORMAL;
};
using EncodeResult = std::vector<std::pair<std::string_view, int>>;
private:
// model train type
ModelType type_ = BPE;
// byte fall back enable
bool byte_fall_back_ = true;
// unknown id.
int unk_id_ = 0;
// pieces from model
std::vector<SentencePiece> sentence_pieces_;
// piece -> id map for normal pieces
std::unordered_map<std::string, int> pieces_;
// piece -> id map for control, unknown, and byte pieces
std::unordered_map<std::string, int> reserved_id_map_;
private:
float get_score(int id) const;
bool is_unused(int id) const;
bool is_control(int id) const;
int piece_to_id(const std::string& w) const;
std::string byte_to_piece(unsigned char c) const;
EncodeResult bpe_encode(std::string_view str, float alpha = 0.f);
};

class Tiktoken : public Tokenizer {
public:
Tiktoken() = default;
virtual bool load(const std::string& filename) override;
virtual std::vector<int> encode(const std::string& str) override;
virtual std::string decode(int id) override;
private:
std::vector<std::string> decoder_;
std::vector<int> tokens_;
std::vector<int> token_ids_;
};

#endif // TOKENIZER_hpp

0 comments on commit 6edd12f

Please sign in to comment.