diff --git a/include/graph.h b/include/graph.h
index 27bf717..f3d3a04 100644
--- a/include/graph.h
+++ b/include/graph.h
@@ -52,6 +52,7 @@ namespace fastllm {
 
         void Update();
         void AddTo(ComputeGraphNode &input0, ComputeGraphNode &input1, float alpha = 1.0); // input0 += input1 * alpha
+        void DataTypeAs(ComputeGraphNode &input, ComputeGraphNode &input1); // set input's dataType to match input1's
         void Embedding(ComputeGraphNode &input, ComputeGraphNode &weight, ComputeGraphNode &output);
         void ExpandHead(ComputeGraphNode &input, int headDim);
         void FusedAttention(ComputeGraphNode &q, ComputeGraphNode &k, ComputeGraphNode &v,
@@ -65,6 +66,7 @@ namespace fastllm {
         void Silu(ComputeGraphNode &input, ComputeGraphNode &output);
         void Split(ComputeGraphNode &input, int axis, int start, int end, ComputeGraphNode &output);
         void SplitLastTokenStates(ComputeGraphNode &input, ComputeGraphNode &output);
+        void Swiglu(ComputeGraphNode &input, ComputeGraphNode &output);
 
         // the following ops are for debugging
         void Exit(); // exit
diff --git a/src/graph.cpp b/src/graph.cpp
index 1862afe..e6cba90 100644
--- a/src/graph.cpp
+++ b/src/graph.cpp
@@ -36,6 +36,20 @@ namespace fastllm {
                 auto data = allDatas[op.datas.find("input")->second];
                 data->ToDevice(DataDevice::CPU);
                 data->Print();
+            } else if (op.type == "DataTypeAs") {
+                auto input = allDatas[op.datas.find("input")->second];
+                DataType dataType = allDatas[op.datas.find("input1")->second]->dataType;
+                if (input->dataType != dataType) {
+                    if (dataType == DataType::FLOAT32) {
+                        excutor.Run("ToFloat32", {
+                                {"input", input}
+                        }, {}, {});
+                    } else if (dataType == DataType::FLOAT16) {
+                        excutor.Run("ToFloat16", {
+                                {"input", input}
+                        }, {}, {});
+                    }
+                }
             } else if (op.type == "ExpandHeads") {
                 auto data = allDatas[op.datas.find("input")->second];
                 int headDim = op.intParams.find("headDim")->second;
@@ -207,6 +221,14 @@ namespace fastllm {
         );
     }
 
+    void ComputeGraph::DataTypeAs(ComputeGraphNode &input, ComputeGraphNode &input1) {
+        this->ops.push_back (
+            ComputeGraphOp("DataTypeAs",
+                {{"input", input.name}, {"input1", input1.name}},
+                {}, {})
+        );
+    }
+
     void ComputeGraph::MulTo(ComputeGraphNode &input0, ComputeGraphNode &input1) {
         this->ops.push_back (
             ComputeGraphOp("MulTo",
@@ -218,7 +240,15 @@
     void ComputeGraph::Silu(ComputeGraphNode &input, ComputeGraphNode &output) {
         this->ops.push_back (
             ComputeGraphOp("Silu",
-                {{"input", "w1"}, {"output", "w1"}},
+                {{"input", input.name}, {"output", output.name}},
                 {}, {})
         );
     }
+
+    void ComputeGraph::Swiglu(ComputeGraphNode &input, ComputeGraphNode &output) {
+        this->ops.push_back (
+            ComputeGraphOp("Swiglu",
+                {{"input", input.name}, {"output", output.name}},
+                {}, {})
+        );
+    }
diff --git a/src/models/basellm.cpp b/src/models/basellm.cpp
index d8070e3..d298604 100644
--- a/src/models/basellm.cpp
+++ b/src/models/basellm.cpp
@@ -896,7 +896,8 @@
 printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)total / spend);
         } else if (dataType == DataType::FLOAT16) {
             AssertInFastLLM(this->model_struct == "chatglm" ||
-                            this->model_struct == "llama",
+                            this->model_struct == "llama" ||
+                            this->model_struct == "graph",
                             this->model_struct + " doesn't support float16");
         } else {
             ErrorInFastLLM("SetDataType Error: datatype should be float32 or float16");
diff --git a/src/models/graphllm.cpp b/src/models/graphllm.cpp
index 1166af9..1f31a5f 100644
--- a/src/models/graphllm.cpp
+++ b/src/models/graphllm.cpp
@@ -114,10 +114,12 @@ namespace fastllm {
         for (auto &it : weight.weight) {
             weightDicts[it.first] = &it.second;
         }
+        Data atype = Data(this->dataType);
         std::map <std::string, Data*> inputs = {
             {"inputIds", (Data*)&inputIds},
             {"positionIds", (Data*)&positionIds},
             {"attentionMask", (Data*)&attentionMask},
+            {"atype", (Data*)&atype},
             {"sin", &sinData}, {"cos", &cosData}
         };
         for (int i = 0; i < block_cnt; i++) {
@@ -250,10 +252,11 @@ namespace fastllm {
         for (auto &it : model->weight.weight) {
             wNodes[it.first] = ComputeGraphNode(it.first);
         }
-        ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), sin("sin"), cos("cos");
+        ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos");
         ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput");
         ComputeGraphNode q("q"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits");
         graph.Embedding(inputIds, wNodes["model.embed_tokens.weight"], hiddenStates);
+        graph.DataTypeAs(hiddenStates, atype);
         for (int i = 0; i < model->block_cnt; i++) {
             std::string pre = "model.layers." + std::to_string(i);
             ComputeGraphNode pastKey("pastKey_" + std::to_string(i)), pastValue("pastValue_" + std::to_string(i));
@@ -289,6 +292,11 @@ namespace fastllm {
         model->max_positions = atoi(model->weight.dicts["seq_length"].c_str());
         model->rope_base = 10000 * pow(3, ((float)model->rotary_dim / (model->rotary_dim - 2)));
         model->rope_factor = 1.0;
+
+        model->pre_prompt = "";
+        model->user_role = "<_user>";
+        model->bot_role = "<_bot>";
+        model->history_sep = "";
     }
 
     std::map <std::string, std::vector <std::pair <std::string, DataType> > >
@@ -331,10 +339,11 @@ namespace fastllm {
         for (auto &it : model->weight.weight) {
             wNodes[it.first] = ComputeGraphNode(it.first);
         }
-        ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), sin("sin"), cos("cos");
+        ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos");
         ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput");
         ComputeGraphNode q("q"), kv("kv"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits");
         graph.Embedding(inputIds, wNodes["transformer.word_embeddings.weight"], hiddenStates);
+        graph.DataTypeAs(hiddenStates, atype);
         for (int i = 0; i < model->block_cnt; i++) {
             std::string pre = "transformer.h." + std::to_string(i);
             ComputeGraphNode pastKey("pastKey_" + std::to_string(i)), pastValue("pastValue_" + std::to_string(i));
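
Usage sketch (not part of the diff): the snippet below shows how the two new graph ops might be wired when a model config builds its compute graph, mirroring the builders touched in graphllm.cpp above. `ComputeGraph`, `ComputeGraphNode`, `Embedding`, `Update`, and the new `DataTypeAs`/`Swiglu` come from this patch; the weight name, the surrounding function, and the omitted linear projection are illustrative assumptions only.

```cpp
#include "graph.h"

using namespace fastllm;

// Illustrative sketch: cast the embedding output to the runtime dtype, then
// apply the fused SwiGLU op. "atype" is the companion input Data constructed
// as Data(this->dataType) at inference time (see the graphllm.cpp hunk above).
void BuildToyGraph(ComputeGraph &graph) {
    ComputeGraphNode inputIds("inputIds"), atype("atype");
    ComputeGraphNode embedWeight("model.embed_tokens.weight"); // hypothetical weight name
    ComputeGraphNode hiddenStates("hiddenStates"), w1("w1"), w2("w2");

    graph.Embedding(inputIds, embedWeight, hiddenStates);
    // New op: convert hiddenStates in place to atype's dataType (float32/float16),
    // so graph models follow the dtype chosen via SetDataType.
    graph.DataTypeAs(hiddenStates, atype);
    // ... the linear projection producing the concatenated gate/up tensor w1 would go here ...
    // New op: fused SwiGLU activation, writing the result into w2.
    graph.Swiglu(w1, w2);
    graph.Update();
}
```

Reading the target dtype from an "atype" input rather than baking it into the graph definition is what lets the same graph run in both float32 and float16, which is why basellm.cpp now also accepts model_struct == "graph" for FLOAT16.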