From 6037b5a4db6f0520a8a3851d876b5a7dc91f6dc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Thu, 4 Jul 2024 10:08:23 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=86=E7=A6=BBgraphllm=E7=9A=84=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CMakeLists.txt | 4 +- src/model.cpp | 4 +- src/models/graph/qwen2.cpp | 82 ++++++++++++++++++++++++++++++ src/models/graph/telechat.cpp | 96 +++++++++++++++++++++++++++++++++++ 4 files changed, 182 insertions(+), 4 deletions(-) create mode 100644 src/models/graph/qwen2.cpp create mode 100644 src/models/graph/telechat.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 0950a14..51f1077 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,11 +44,13 @@ endif() message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +file(GLOB GRAPH_MODEL_FILES "src/models/graph/*.cpp") set(FASTLLM_CXX_SOURCES src/fastllm.cpp src/device.cpp src/model.cpp src/executor.cpp src/template.cpp src/graph.cpp src/devices/cpu/cpudevice.cpp src/devices/cpu/cpudevicebatch.cpp src/models/graphllm.cpp src/models/chatglm.cpp src/models/moss.cpp src/models/llama.cpp src/models/qwen.cpp src/models/basellm.cpp src/models/glm.cpp src/models/minicpm.cpp src/models/internlm2.cpp src/models/bert.cpp src/models/moe.cpp src/models/deepseekv2.cpp - third_party/json11/json11.cpp) + third_party/json11/json11.cpp + ${GRAPH_MODEL_FILES}) include_directories(include) include_directories(include/utils) diff --git a/src/model.cpp b/src/model.cpp index c251767..ce747fc 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -170,10 +170,8 @@ namespace fastllm { model = (basellm*)(new GLMModel()); } else if (modelType == "bert") { model = (basellm*)(new BertModel()); - } else if (modelType == "telechat") { - model = new GraphLLMModel("telechat"); } else { - ErrorInFastLLM("Unknown model type: " + modelType); + model = new GraphLLMModel(modelType); } return model; } diff --git 
a/src/models/graph/qwen2.cpp b/src/models/graph/qwen2.cpp
new file mode 100644
index 0000000..ef176b2
--- /dev/null
+++ b/src/models/graph/qwen2.cpp
@@ -0,0 +1,94 @@
+// Graph-based description of the Qwen2 architecture for fastllm's
+// GraphLLMModel runtime: maps checkpoint tensors to load-time data types
+// and builds the per-layer compute graph (RMSNorm -> QKV -> RoPE -> attention -> MLP).
+#include "graphllm.h"
+
+namespace fastllm {
+    // NOTE(review): inheritance made explicitly public — `class` defaults to
+    // private, which would forbid converting to GraphLLMModelConfig* outside
+    // the class (e.g. in the registration/factory path).
+    class Qwen2GraphModelConfig : public GraphLLMModelConfig {
+    public:
+        // Qwen2 needs no extra hyper-parameters beyond what GraphLLMModel reads.
+        void InitParams(GraphLLMModel *model) {
+        }
+
+        // For every tensor name in the checkpoint, returns the target tensor
+        // name(s) and the DataType each should be loaded as.
+        std::map <std::string, std::vector <std::pair <std::string, DataType> > >
+                GetTensorMap(GraphLLMModel *model, const std::vector <std::string> &tensorNames) {
+            std::map <std::string, std::vector <std::pair <std::string, DataType> > > ret;
+            std::string embeddingName = "model.embed_tokens.weight";
+            std::string logitsName = "lm_head.weight";
+            std::set <std::string> linearNames = {
+                ".self_attn.q_proj.weight", ".self_attn.k_proj.weight", ".self_attn.v_proj.weight", ".self_attn.o_proj.weight",
+                ".mlp.gate_proj.weight", ".mlp.up_proj.weight", ".mlp.down_proj.weight"
+            };
+            ret[embeddingName].push_back(std::make_pair(embeddingName, DataType::DATA_AUTO_EMBEDDING));
+            for (int i = 0; i < model->block_cnt; i++) {
+                std::string pre = "model.layers." + std::to_string(i);
+                for (auto &it : linearNames) {
+                    ret[pre + it].push_back(std::make_pair(pre + it, DataType::DATA_AUTO_LINEAR));
+                }
+            }
+            // Any tensor not claimed above is loaded unchanged.
+            for (auto &name : tensorNames) {
+                if (ret[name].size() == 0) {
+                    ret[name].push_back(std::make_pair(name, DataType::DATA_AUTO_NONE));
+                }
+            }
+            // No lm_head tensor in the checkpoint: also emit the logits weight
+            // from the embedding tensor (tied weights); otherwise force LINEAR.
+            if (ret.find(logitsName) == ret.end()) {
+                ret[embeddingName].push_back(std::make_pair(logitsName, DataType::DATA_AUTO_LINEAR));
+            } else {
+                ret[logitsName][0].second = DataType::DATA_AUTO_LINEAR;
+            }
+            return ret;
+        }
+
+        // Builds the static forward compute graph for the whole model.
+        void BuildGraph(GraphLLMModel *model) {
+            auto &graph = *(model->GetGraph());
+            std::map <std::string, ComputeGraphNode> wNodes;
+            for (auto &it : model->weight.weight) {
+                wNodes[it.first] = ComputeGraphNode(it.first);
+            }
+            ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos"), seqLens("seqLens");
+            ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput");
+            ComputeGraphNode q("q"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits");
+            graph.Embedding(inputIds, wNodes["model.embed_tokens.weight"], hiddenStates);
+            graph.DataTypeAs(hiddenStates, atype);
+            for (int i = 0; i < model->block_cnt; i++) {
+                std::string pre = "model.layers." + std::to_string(i);
+                ComputeGraphNode pastKey("pastKey_" + std::to_string(i)), pastValue("pastValue_" + std::to_string(i));
+                graph.RMSNorm(hiddenStates, wNodes[pre + ".input_layernorm.weight"], model->rms_norm_eps, attenInput);
+                graph.Linear(attenInput, wNodes[pre + ".self_attn.q_proj.weight"], wNodes[pre + ".self_attn.q_proj.bias"], q);
+                graph.Linear(attenInput, wNodes[pre + ".self_attn.k_proj.weight"], wNodes[pre + ".self_attn.k_proj.bias"], k);
+                graph.Linear(attenInput, wNodes[pre + ".self_attn.v_proj.weight"], wNodes[pre + ".self_attn.v_proj.bias"], v);
+                graph.ExpandHead(q, model->head_dim);
+                graph.ExpandHead(k, model->head_dim);
+                graph.ExpandHead(v, model->head_dim);
+                graph.LlamaRotatePosition2D(q, positionIds, sin, cos, model->rotary_dim);
+                graph.LlamaRotatePosition2D(k, positionIds, sin, cos, model->rotary_dim);
+                graph.FusedAttention(q, pastKey, pastValue, k, v, attenInput, attentionMask, attenOutput, seqLens, 1.0 / sqrt(model->head_dim), 0, 128);
+                graph.Linear(attenOutput, wNodes[pre + ".self_attn.o_proj.weight"], wNodes[pre + ".self_attn.o_proj.bias"], attenLastOutput);
+                graph.AddTo(hiddenStates, attenLastOutput);
+                graph.RMSNorm(hiddenStates, wNodes[pre + ".post_attention_layernorm.weight"], model->rms_norm_eps, attenInput);
+                graph.Linear(attenInput, wNodes[pre + ".mlp.gate_proj.weight"], wNodes[pre + ".mlp.gate_proj.bias"], w1);
+                graph.Linear(attenInput, wNodes[pre + ".mlp.up_proj.weight"], wNodes[pre + ".mlp.up_proj.bias"], w3);
+                graph.Silu(w1, w1);
+                graph.MulTo(w1, w3);
+                graph.Linear(w1, wNodes[pre + ".mlp.down_proj.weight"], wNodes[pre + ".mlp.down_proj.bias"], w2);
+                graph.AddTo(hiddenStates, w2);
+            }
+
+            // Only the last token of each sequence feeds the LM head.
+            graph.SplitLastTokenStates(hiddenStates, seqLens, lastTokensStates);
+            graph.RMSNorm(lastTokensStates, wNodes["model.norm.weight"], model->rms_norm_eps, lastTokensStates);
+            graph.Linear(lastTokensStates, wNodes["lm_head.weight"], wNodes["lm_head.bias"], logits);
+
+            OptimizeComputeGraph(graph, model->weight);
+            graph.Update();
+        }
+    };
+    REGISTERGRAPHMODELCONFIG(qwen2, Qwen2GraphModelConfig)
+}
diff --git a/src/models/graph/telechat.cpp b/src/models/graph/telechat.cpp
new file mode 100644
index 0000000..bd5f929
--- /dev/null
+++ b/src/models/graph/telechat.cpp
@@ -0,0 +1,101 @@
+// Graph-based description of the TeleChat architecture for fastllm's
+// GraphLLMModel runtime (separate query + fused key_value projections, SwiGLU MLP).
+#include "graphllm.h"
+
+namespace fastllm {
+    class TeleChatGraphModelConfig : public GraphLLMModelConfig {
+    public:
+        // Reads TeleChat-specific hyper-parameters and chat-template strings.
+        void InitParams(GraphLLMModel *model) {
+            model->block_cnt = atoi(model->weight.dicts["n_layer"].c_str());
+            model->max_positions = atoi(model->weight.dicts["seq_length"].c_str());
+            // NOTE(review): NTK-style rope base scaling — confirm the formula
+            // against the reference TeleChat implementation.
+            model->rope_base = 10000 * pow(3, ((float)model->rotary_dim / (model->rotary_dim - 2)));
+            model->rope_factor = 1.0;
+
+            model->pre_prompt = "";
+            model->user_role = "<_user>";
+            model->bot_role = "<_bot>";
+            model->history_sep = "";
+        }
+
+        // For every tensor name in the checkpoint, returns the target tensor
+        // name(s) and the DataType each should be loaded as.
+        std::map <std::string, std::vector <std::pair <std::string, DataType> > >
+                GetTensorMap(GraphLLMModel *model, const std::vector <std::string> &tensorNames) {
+            std::set <std::string> linearNames = {
+                ".self_attention.query.weight", ".self_attention.key_value.weight", ".self_attention.dense.weight",
+                ".mlp.gate_proj.weight", ".mlp.up_proj.weight", ".mlp.down_proj.weight"
+            };
+            std::string embeddingName = "transformer.word_embeddings.weight";
+            std::string logitsName = "transformer.lm_head.weight";
+            std::map <std::string, std::vector <std::pair <std::string, DataType> > > ret;
+            ret[embeddingName].push_back(std::make_pair(embeddingName, DataType::DATA_AUTO_EMBEDDING));
+            for (int i = 0; i < model->block_cnt; i++) {
+                std::string pre = "transformer.h." + std::to_string(i);
+                for (auto &it : linearNames) {
+                    ret[pre + it].push_back(std::make_pair(pre + it, DataType::DATA_AUTO_LINEAR));
+                }
+            }
+            // Any tensor not claimed above is loaded unchanged.
+            for (auto &name : tensorNames) {
+                if (ret[name].size() == 0) {
+                    ret[name].push_back(std::make_pair(name, DataType::DATA_AUTO_NONE));
+                }
+            }
+            // (The original patch repeated this tied-weights block twice; once is enough.)
+            if (ret.find(logitsName) == ret.end()) {
+                ret[embeddingName].push_back(std::make_pair(logitsName, DataType::DATA_AUTO_LINEAR));
+            } else {
+                ret[logitsName][0].second = DataType::DATA_AUTO_LINEAR;
+            }
+            return ret;
+        }
+
+        // Builds the static forward compute graph for the whole model.
+        void BuildGraph(GraphLLMModel *model) {
+            auto &graph = *(model->GetGraph());
+            std::map <std::string, ComputeGraphNode> wNodes;
+            for (auto &it : model->weight.weight) {
+                wNodes[it.first] = ComputeGraphNode(it.first);
+            }
+            ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos"), seqLens("seqLens");
+            ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput");
+            ComputeGraphNode q("q"), kv("kv"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits");
+            graph.Embedding(inputIds, wNodes["transformer.word_embeddings.weight"], hiddenStates);
+            graph.DataTypeAs(hiddenStates, atype);
+            for (int i = 0; i < model->block_cnt; i++) {
+                std::string pre = "transformer.h." + std::to_string(i);
+                ComputeGraphNode pastKey("pastKey_" + std::to_string(i)), pastValue("pastValue_" + std::to_string(i));
+                graph.RMSNorm(hiddenStates, wNodes[pre + ".input_layernorm.weight"], model->rms_norm_eps, attenInput);
+                graph.Linear(attenInput, wNodes[pre + ".self_attention.query.weight"], wNodes[pre + ".self_attention.query.bias"], q);
+                graph.Linear(attenInput, wNodes[pre + ".self_attention.key_value.weight"], wNodes[pre + ".self_attention.key_value.bias"], kv);
+                // key_value is a fused projection: split its last dim into K and V.
+                graph.ExpandHead(kv, model->head_dim * 2);
+                graph.Split(kv, -1, 0, model->head_dim, k);
+                graph.Split(kv, -1, model->head_dim, model->head_dim * 2, v);
+                graph.ExpandHead(q, model->head_dim);
+                graph.LlamaRotatePosition2D(q, positionIds, sin, cos, model->rotary_dim);
+                graph.LlamaRotatePosition2D(k, positionIds, sin, cos, model->rotary_dim);
+                graph.FusedAttention(q, pastKey, pastValue, k, v, attenInput, attentionMask, attenOutput, seqLens, 1.0 / sqrt(model->head_dim), 0, 128);
+                graph.Linear(attenOutput, wNodes[pre + ".self_attention.dense.weight"], wNodes[pre + ".self_attention.dense.bias"], attenLastOutput);
+                graph.AddTo(hiddenStates, attenLastOutput);
+                graph.RMSNorm(hiddenStates, wNodes[pre + ".post_attention_layernorm.weight"], model->rms_norm_eps, attenInput);
+                graph.Linear(attenInput, wNodes[pre + ".mlp.gate_proj.weight"], wNodes[pre + ".mlp.gate_proj.bias"], w1);
+                graph.Linear(attenInput, wNodes[pre + ".mlp.up_proj.weight"], wNodes[pre + ".mlp.up_proj.bias"], w3);
+                graph.Silu(w1, w1);
+                graph.MulTo(w1, w3);
+                graph.Linear(w1, wNodes[pre + ".mlp.down_proj.weight"], wNodes[pre + ".mlp.down_proj.bias"], w2);
+                graph.AddTo(hiddenStates, w2);
+            }
+
+            // Only the last token of each sequence feeds the LM head.
+            graph.SplitLastTokenStates(hiddenStates, seqLens, lastTokensStates);
+            graph.RMSNorm(lastTokensStates, wNodes["transformer.ln_f.weight"], model->rms_norm_eps, lastTokensStates);
+            graph.Linear(lastTokensStates, wNodes["transformer.lm_head.weight"], wNodes["transformer.lm_head.bias"], logits);
+
+            OptimizeComputeGraph(graph, model->weight);
+            graph.Update();
+        }
+    };
+    REGISTERGRAPHMODELCONFIG(telechat, TeleChatGraphModelConfig)
+}