Skip to content

Commit

Permalink
补一些算子
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Jun 27, 2024
1 parent 10cfc24 commit b1e6c8e
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ namespace fastllm {
void Update();

void AddTo(ComputeGraphNode &input0, ComputeGraphNode &input1, float alpha = 1.0); // input0 += input1 * alpha
void DataTypeAs(ComputeGraphNode &input, ComputeGraphNode &input1); // 将input的dataType设成和input1一样
void Embedding(ComputeGraphNode &input, ComputeGraphNode &weight, ComputeGraphNode &output);
void ExpandHead(ComputeGraphNode &input, int headDim);
void FusedAttention(ComputeGraphNode &q, ComputeGraphNode &k, ComputeGraphNode &v,
Expand All @@ -65,6 +66,7 @@ namespace fastllm {
void Silu(ComputeGraphNode &input, ComputeGraphNode &output);
void Split(ComputeGraphNode &input, int axis, int start, int end, ComputeGraphNode &output);
void SplitLastTokenStates(ComputeGraphNode &input, ComputeGraphNode &output);
void Swiglu(ComputeGraphNode &input, ComputeGraphNode &output);

// 以下op用于调试
void Exit(); // 退出
Expand Down
32 changes: 31 additions & 1 deletion src/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@ namespace fastllm {
auto data = allDatas[op.datas.find("input")->second];
data->ToDevice(DataDevice::CPU);
data->Print();
} else if (op.type == "DataTypeAs") {
auto input = allDatas[op.datas.find("input")->second];
DataType dataType = allDatas[op.datas.find("input1")->second]->dataType;
if (input->dataType != dataType) {
if (dataType == DataType::FLOAT32) {
excutor.Run("ToFloat32", {
{"input", input}
}, {}, {});
} else if (dataType == DataType::FLOAT16) {
excutor.Run("ToFloat16", {
{"input", input}
}, {}, {});
}
}
} else if (op.type == "ExpandHeads") {
auto data = allDatas[op.datas.find("input")->second];
int headDim = op.intParams.find("headDim")->second;
Expand Down Expand Up @@ -207,6 +221,14 @@ namespace fastllm {
);
}

// Record a "DataTypeAs" op in the graph: at execution time it converts
// `input`'s dataType to match `input1`'s (see the DataTypeAs branch in Update/Run).
void ComputeGraph::DataTypeAs(ComputeGraphNode &input, ComputeGraphNode &input1) {
    ComputeGraphOp op("DataTypeAs",
                      {{"input", input.name}, {"input1", input1.name}},
                      {}, {});
    this->ops.push_back(op);
}

void ComputeGraph::MulTo(ComputeGraphNode &input0, ComputeGraphNode &input1) {
this->ops.push_back (
ComputeGraphOp("MulTo",
Expand All @@ -218,7 +240,15 @@ namespace fastllm {
// Record a "Silu" op in the graph, wiring the caller-supplied input/output
// nodes by name.
// Fix: drop the stale hard-coded {{"input", "w1"}, {"output", "w1"}} data map
// (leftover pre-diff line) — it ignored the `input`/`output` parameters and
// duplicated the initializer list; use the parameterized node names instead.
void ComputeGraph::Silu(ComputeGraphNode &input, ComputeGraphNode &output) {
    this->ops.push_back (
        ComputeGraphOp("Silu",
            {{"input", input.name}, {"output", output.name}},
            {}, {})
    );
}

// Record a "Swiglu" op in the graph, connecting the named input node to the
// named output node; mirrors the other single-input activation builders above.
void ComputeGraph::Swiglu(ComputeGraphNode &input, ComputeGraphNode &output) {
    ComputeGraphOp op("Swiglu",
                      {{"input", input.name}, {"output", output.name}},
                      {}, {});
    this->ops.push_back(op);
}
Expand Down
3 changes: 2 additions & 1 deletion src/models/basellm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,8 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to

} else if (dataType == DataType::FLOAT16) {
AssertInFastLLM(this->model_struct == "chatglm" ||
this->model_struct == "llama",
this->model_struct == "llama" ||
this->model_struct == "graph",
this->model_struct + " doesn't support float16");
} else {
ErrorInFastLLM("SetDataType Error: datatype should be float32 or float16");
Expand Down
13 changes: 11 additions & 2 deletions src/models/graphllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,12 @@ namespace fastllm {
for (auto &it : weight.weight) {
weightDicts[it.first] = &it.second;
}
Data atype = Data(this->dataType);
std::map <std::string, Data*> inputs = {
{"inputIds", (Data*)&inputIds},
{"positionIds", (Data*)&positionIds},
{"attentionMask", (Data*)&attentionMask},
{"atype", (Data*)&atype},
{"sin", &sinData}, {"cos", &cosData}
};
for (int i = 0; i < block_cnt; i++) {
Expand Down Expand Up @@ -250,10 +252,11 @@ namespace fastllm {
for (auto &it : model->weight.weight) {
wNodes[it.first] = ComputeGraphNode(it.first);
}
ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), sin("sin"), cos("cos");
ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos");
ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput");
ComputeGraphNode q("q"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits");
graph.Embedding(inputIds, wNodes["model.embed_tokens.weight"], hiddenStates);
graph.DataTypeAs(hiddenStates, atype);
for (int i = 0; i < model->block_cnt; i++) {
std::string pre = "model.layers." + std::to_string(i);
ComputeGraphNode pastKey("pastKey_" + std::to_string(i)), pastValue("pastValue_" + std::to_string(i));
Expand Down Expand Up @@ -289,6 +292,11 @@ namespace fastllm {
model->max_positions = atoi(model->weight.dicts["seq_length"].c_str());
model->rope_base = 10000 * pow(3, ((float)model->rotary_dim / (model->rotary_dim - 2)));
model->rope_factor = 1.0;

model->pre_prompt = "";
model->user_role = "<_user>";
model->bot_role = "<_bot>";
model->history_sep = "";
}

std::map <std::string, std::vector <std::pair <std::string, DataType> > >
Expand Down Expand Up @@ -331,10 +339,11 @@ namespace fastllm {
for (auto &it : model->weight.weight) {
wNodes[it.first] = ComputeGraphNode(it.first);
}
ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), sin("sin"), cos("cos");
ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos");
ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput");
ComputeGraphNode q("q"), kv("kv"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits");
graph.Embedding(inputIds, wNodes["transformer.word_embeddings.weight"], hiddenStates);
graph.DataTypeAs(hiddenStates, atype);
for (int i = 0; i < model->block_cnt; i++) {
std::string pre = "transformer.h." + std::to_string(i);
ComputeGraphNode pastKey("pastKey_" + std::to_string(i)), pastValue("pastValue_" + std::to_string(i));
Expand Down

0 comments on commit b1e6c8e

Please sign in to comment.