Skip to content

Commit

Permalink
Merge pull request #338 from TylunasLi/bug_fix_att_opt
Browse files Browse the repository at this point in the history
兼容ChatGLM-6B最初版本,并修正ChatGLM2-6B的prompt构造问题
  • Loading branch information
ztxz16 committed Oct 7, 2023
2 parents 80a3917 + f7b17f9 commit baf2404
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 32 deletions.
10 changes: 7 additions & 3 deletions include/models/chatglm.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ namespace fastllm {
public:
ChatGLMModel (); // 构造函数

virtual void InitParams(); // 初始化参数信息

// 推理
virtual int Forward(
virtual int Forward(
const Data &inputIds,
const Data &attentionMask,
const Data &positionIds,
Expand Down Expand Up @@ -56,7 +58,7 @@ namespace fastllm {
const std::vector <std::map <std::string, int> > &params,
Data &inputIds, Data &attentionMask, Data &positionIds);

virtual void WarmUp(); // 预热
virtual void WarmUp(); // 预热

virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // 根据历史信息和当前输入生成prompt

Expand All @@ -66,7 +68,9 @@ namespace fastllm {

void UpdateSinCos(float rope);
private:
virtual void CausalMask(Data &data, int start) {}; // 因果mask?
virtual void CausalMask(Data &data, int start) {}; // 因果mask?

int gmask_token_id;

float rope = 1.0f;
};
Expand Down
56 changes: 29 additions & 27 deletions src/models/chatglm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,31 @@ namespace fastllm {
ChatGLMModel::ChatGLMModel() {
this->model_type = "chatglm";

this->bos_token_id = 130004;
this->eos_token_id = 130005;
this->bos_token_id = 130004; // V1 后期版本 bos token,可通过 config.json 覆盖
this->eos_token_id = 130005; // V1 后期版本 eos token,可通过 config.json 覆盖
this->gmask_token_id= 150001; // V1最初版本, 150528 tokens,部分 config.json 没有 gmask_token_id,因此取默认值。

this->rope = -1.0;
this->UpdateSinCos(1.0f);
weight.embeddingNames.insert("transformer.word_embeddings.weight");
weight.embeddingNames.insert("transformer.embedding.word_embeddings.weight");
}

void ChatGLMModel::InitParams() {
basellm::InitParams();
if (GetVersion() == 1) {
if (this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end()) {
this->gmask_token_id = atoi(this->weight.dicts["gmask_token_id"].c_str());
}
} else if (GetVersion() == 2) {
this->gmask_token_id = 64790;
this->bos_token_id = 64792;
}
if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
UpdateSinCos(atof(this->weight.dicts["rope_ratio"].c_str()));
}
}

int ChatGLMModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
const fastllm::Data &positionIds, std::vector<std::pair<Data, Data>> &pastKeyValues,
const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
Expand All @@ -86,9 +102,6 @@ namespace fastllm {
const GenerationConfig &generationConfig,
const LastTokensManager &lastTokens,
std::vector <std::vector <float>*> *retLogits) {
if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
UpdateSinCos(atof(this->weight.dicts["rope_ratio"].c_str()));
}
int maxLen = inputIds.dims[1];
Data inputEmbeddings;
Data attenInput;
Expand Down Expand Up @@ -328,9 +341,6 @@ namespace fastllm {
const std::vector <GenerationConfig> &generationConfigs,
const LastTokensManager &lastTokens,
std::vector <std::vector <float>*> *retLogits) {
if (this->weight.dicts.find("rope_ratio") != this->weight.dicts.end()) {
UpdateSinCos(atof(this->weight.dicts["rope_ratio"].c_str()));
}
int seqLen = inputIds.dims[1];
sinData.ToDevice(DataDevice::CUDA);
cosData.ToDevice(DataDevice::CUDA);
Expand Down Expand Up @@ -712,8 +722,6 @@ namespace fastllm {
attentionMask.ToDevice(DataDevice::CPU);
positionIds.ToDevice(DataDevice::CPU);

int gmask_token_id = this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end() ?
atoi(this->weight.dicts["gmask_token_id"].c_str()) : 130001;
int index = params.find("index")->second;
int promptLen = params.find("promptLen")->second;

Expand All @@ -723,9 +731,9 @@ namespace fastllm {
ids.push_back(gmask_token_id);
ids.push_back(bos_token_id);
} else if (GetVersion() == 2) {
if (ids.size() < 2 || ids[0] != 64790 || ids[1] != 64792) {
ids.insert(ids.begin(), 64792);
ids.insert(ids.begin(), 64790);
if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
ids.insert(ids.begin(), this->bos_token_id);
ids.insert(ids.begin(), this->gmask_token_id);
}
}
}
Expand Down Expand Up @@ -775,8 +783,6 @@ namespace fastllm {
int batch = inputTokens.size();
int index = params[0].find("index")->second;
if (index == 0) {
int gmask_token_id = this->weight.dicts.find("gmask_token_id") != this->weight.dicts.end() ?
atoi(this->weight.dicts["gmask_token_id"].c_str()) : 130001;
std::vector<int> seqLens;
seqLens.resize(batch);
int maxLen = 0;
Expand Down Expand Up @@ -815,8 +821,8 @@ namespace fastllm {
} else {
auto &tokens = inputTokens[i];
int len = tokens.size(), base = maxLen - 2 - len;
ids[i * maxLen + base] = 64790;
ids[i * maxLen + base + 1] = 64792;
ids[i * maxLen + base] = gmask_token_id;
ids[i * maxLen + base + 1] = bos_token_id;
for (int j = 0; j < len; j++) {
ids[i * maxLen + base + 2 + j] = tokens[j];
}
Expand Down Expand Up @@ -889,28 +895,24 @@ namespace fastllm {
}

std::string ChatGLMModel::MakeInput(const std::string &history, int round, const std::string &input) {
if (GetVersion() == 2)
round++;
if (round == 0 && GetVersion() == 1) {
return input;
} else {
#if defined(_WIN32) or defined(_WIN64)
std::vector <uint8_t> vask = {233, 151, 174, 239, 188, 154, 0};
std::vector <uint8_t> vans = {231, 173, 148, 239, 188, 154, 0};
std::string sask = (char*)vask.data();
std::string sans = (char*)vans.data();
return (history + ("[Round " + std::to_string(round) + "]\n\n" + sask + input + "\n\n" + sans));
return history + ("[Round " + std::to_string(round) + u8"]\n\n问:" + input + u8"\n\n答:");
#else
return history + ("[Round " + std::to_string(round) + "]\n\n问:" + input + "\n\n答:");
#endif
}
}

std::string ChatGLMModel::MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output) {
if (GetVersion() == 2)
round++;
#if defined(_WIN32) or defined(_WIN64)
std::vector <uint8_t> vask = {233, 151, 174, 239, 188, 154, 0};
std::vector <uint8_t> vans = {231, 173, 148, 239, 188, 154, 0};
std::string sask = (char*)vask.data();
std::string sans = (char*)vans.data();
return (history + ("[Round " + std::to_string(round) + "]\n\n" + sask + input + "\n\n" + sans + output + "\n"));
return (history + ("[Round " + std::to_string(round) + u8"]\n\n问:" + input + u8"\n\n答:" + output + "\n"));
#else
return (history + ("[Round " + std::to_string(round) + "]\n\n问:" + input + "\n\n答:" + output + "\n\n"));
#endif
Expand Down
4 changes: 2 additions & 2 deletions tools/fastllm_pytools/torch2flm.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ def tofile(exportPath,
for v in vocab.keys():
if (modelInfo['model_type'] == "qwen"):
s = v
elif (modelInfo["model_type"] == "moss"):
s = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v]
else:
s = v.encode()
if (modelInfo["model_type"] == "moss"):
s = [(ord(c) if c not in tokenizer.byte_decoder else tokenizer.byte_decoder[c]) for c in v]
fo.write(struct.pack('i', len(s)))
for c in s:
fo.write(struct.pack('i', c))
Expand Down

0 comments on commit baf2404

Please sign in to comment.