Skip to content

Commit

Permalink
Merge pull request #456 from ztxz16/revert-451-master
Browse files Browse the repository at this point in the history
Revert "添加add_special_tokens选项,默认true,支持chatglm模型"
  • Loading branch information
ztxz16 committed May 16, 2024
2 parents a7d1bd6 + e1370d4 commit 7e56deb
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 37 deletions.
1 change: 0 additions & 1 deletion include/fastllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ namespace fastllm {
float temperature = 1.0; // 温度参数,一般在0.1 ~ 1.0之间,设大这个参数可以带来结果的多样性
bool output_logits = false; // 是否返回logits
bool enable_hash_id = false; // 给会话添加hash id
bool add_special_tokens = true; // prompt添加special tokens(chatglm模型生效)
std::multiset <int> stop_token_ids;

bool IsSimpleGreedy() const {
Expand Down
8 changes: 2 additions & 6 deletions src/models/basellm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ namespace fastllm {
std::vector<float> results;
LastTokensManager tokens(1, generationConfig.last_n);
int promptLen = inputTokens[0].size(), index = 0;
int add_special_tokens = generationConfig.add_special_tokens? 1: 0;
FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}, {"add_special_tokens", add_special_tokens}},
inputIds, attentionMask, positionIds);
FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
while (true) {
auto st = std::chrono::system_clock::now();
int ret = Forward(inputIds, attentionMask, positionIds, pastKeyValues, generationConfig, tokens);
Expand Down Expand Up @@ -123,8 +121,7 @@ namespace fastllm {
results.clear();

inputTokens[0] = std::vector<float> {(float)ret};
FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, {"add_special_tokens", add_special_tokens}
inputIds, attentionMask, positionIds);
FillLLMInputs(inputTokens, {{"promptLen", promptLen}, {"index", index}}, inputIds, attentionMask, positionIds);
if (index == generationConfig.output_token_limit) {
break;
}
Expand Down Expand Up @@ -198,7 +195,6 @@ namespace fastllm {
}
params[0]["index"] = 0;
int index = 0;
params[0]["add_special_tokens"] = generationConfig.add_special_tokens? 1: 0;

LastTokensManager tokensManager (batch, generationConfig.last_n);
std::vector <bool> isEnding = std::vector <bool> (batch, false);
Expand Down
48 changes: 18 additions & 30 deletions src/models/chatglm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,19 +751,16 @@ namespace fastllm {

int index = params.find("index")->second;
int promptLen = params.find("promptLen")->second;
bool add_special_tokens = params.find("add_special_tokens")->second == 0? false: true;

if (index == 0) {
if (add_special_tokens) {
for (auto &ids: inputTokens) {
if (GetVersion() == 1) {
ids.push_back(gmask_token_id);
ids.push_back(bos_token_id);
} else if (GetVersion() == 2) {
if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
ids.insert(ids.begin(), this->bos_token_id);
ids.insert(ids.begin(), this->gmask_token_id);
}
for (auto &ids: inputTokens) {
if (GetVersion() == 1) {
ids.push_back(gmask_token_id);
ids.push_back(bos_token_id);
} else if (GetVersion() == 2) {
if (ids.size() < 2 || ids[0] != this->gmask_token_id || ids[1] != this->bos_token_id) {
ids.insert(ids.begin(), this->bos_token_id);
ids.insert(ids.begin(), this->gmask_token_id);
}
}
}
Expand Down Expand Up @@ -812,17 +809,12 @@ namespace fastllm {

int batch = inputTokens.size();
int index = params[0].find("index")->second;
bool add_special_tokens = params[0].find("add_special_tokens")->second == 0? false: true;
int special_tokens_offset = 0;
if (add_special_tokens) {
special_tokens_offset = 2;
}
if (index == 0) {
std::vector<int> seqLens;
seqLens.resize(batch);
int maxLen = 0;
for (int i = 0; i < batch; i++) {
maxLen = std::max(maxLen, (int) inputTokens[i].size() + special_tokens_offset);
maxLen = std::max(maxLen, (int) inputTokens[i].size() + 2);
seqLens[i] = (int) inputTokens[i].size();
}

Expand All @@ -832,15 +824,13 @@ namespace fastllm {
for (int i = 0; i < batch; i++) {
if (GetVersion() == 1) {
auto &tokens = inputTokens[i];
int len = tokens.size(), base = maxLen - special_tokens_offset - len;
int len = tokens.size(), base = maxLen - 2 - len;
for (int j = 0; j < len; j++) {
ids[i * maxLen + base + j] = tokens[j];
}
if (add_special_tokens) {
ids[i * maxLen + base + len] = gmask_token_id;
ids[i * maxLen + base + len + 1] = bos_token_id;
}
len += special_tokens_offset;
ids[i * maxLen + base + len] = gmask_token_id;
ids[i * maxLen + base + len + 1] = bos_token_id;
len += 2;
for (int j = 0; j < len - 1; j++) {
vpids[i * 2 * maxLen + base + j] = j;
}
Expand All @@ -857,15 +847,13 @@ namespace fastllm {
}
} else {
auto &tokens = inputTokens[i];
int len = tokens.size(), base = maxLen - special_tokens_offset - len;
if (add_special_tokens) {
ids[i * maxLen + base] = gmask_token_id;
ids[i * maxLen + base + 1] = bos_token_id;
}
int len = tokens.size(), base = maxLen - 2 - len;
ids[i * maxLen + base] = gmask_token_id;
ids[i * maxLen + base + 1] = bos_token_id;
for (int j = 0; j < len; j++) {
ids[i * maxLen + base + special_tokens_offset + j] = tokens[j];
ids[i * maxLen + base + 2 + j] = tokens[j];
}
len += special_tokens_offset;
len += 2;
for (int j = 0; j < len; j++) {
vpids[i * 2 * maxLen + base + j] = j;
}
Expand Down

0 comments on commit 7e56deb

Please sign in to comment.