
Commit f98f3b9
fix chatglm
黄宇扬 committed Jul 11, 2024
1 parent 721497b
Showing 2 changed files with 7 additions and 5 deletions.
src/models/chatglm.cpp (3 changes: 2 additions, 1 deletion)
@@ -757,7 +757,8 @@ namespace fastllm {
             if (GetVersion() == 1) {
                 positionIds.CopyFrom(Data(DataType::FLOAT32, {2, 1}, {(float) promptLen, (float) (index + 1)}));
             } else {
-                positionIds.CopyFrom(Data(DataType::FLOAT32, {2, 1}, {(float) promptLen + index + 1, (float) (index + 1)}));
+                int gap = add_special_tokens ? 1 : -1;
+                positionIds.CopyFrom(Data(DataType::FLOAT32, {2, 1}, {(float) promptLen + index + gap, (float) (index + gap)}));
             }
         }
     }
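The new branch shifts the per-step position ids by +1 when special tokens were added to the prompt and by -1 when they were not. A minimal, self-contained sketch of just that arithmetic (illustrative values only; fastllm's Data type is not used here):

#include <cstdio>

// Sketch of the position-id arithmetic from the chatglm.cpp change above.
// promptLen and the loop bound are made-up example values.
int main() {
    int promptLen = 10;               // assumed length of the encoded prompt
    bool add_special_tokens = true;   // flag that now drives the offset
    for (int index = 0; index < 3; index++) {           // first few decode steps
        int gap = add_special_tokens ? 1 : -1;           // +1 with special tokens, -1 without
        float first  = (float) promptLen + index + gap;  // first value of the {2, 1} positionIds
        float second = (float) (index + gap);            // second value of the {2, 1} positionIds
        printf("step %d: positionIds = {%.0f, %.0f}\n", index, first, second);
    }
    return 0;
}

With add_special_tokens = true this reproduces the old hard-coded +1; with false both values drop by 2, which the previous code could not express.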
tools/src/pytools.cpp (9 changes: 5 additions, 4 deletions)
@@ -80,7 +80,7 @@ extern "C" {
     }
 
     DLL_EXPORT fastllm::GenerationConfig make_config(int max_length, bool do_sample, float top_p, int top_k,
-                                                     float temperature, float repeat_penalty, bool output_logits) {
+                                                     float temperature, float repeat_penalty, bool output_logits, bool add_special_tokens) {
         fastllm::GenerationConfig config;
         config.output_token_limit = max_length;
         config.temperature = temperature;
@@ -90,6 +90,7 @@ extern "C" {
             config.top_k = top_k;
         }
         config.output_logits = output_logits;
+        config.add_special_tokens = add_special_tokens;
         return config;
     }

@@ -314,7 +315,7 @@ extern "C" {
                                              int max_length, bool do_sample, float top_p, int top_k,
                                              float temperature, float repeat_penalty, bool output_logits) {
         auto model = models.GetModel(modelId);
-        auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits);
+        auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits, true);
         std::string s = model->Response(content, nullptr, config);
         return string_to_chars(s);
     }
@@ -329,7 +330,7 @@ extern "C" {
         for (int i = 0; i < v.Count(0); i++) {
             tokens.push_back((int)((float*)v.cpuData)[i]);
         }
-        auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits);
+        auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits, true);
         for(int i = 0; i < stop_token_len; i++ )
         {
             config.stop_token_ids.insert(stop_token_ids[i]);
@@ -358,7 +359,7 @@ extern "C" {
         for (int i = 0; i < len; i++) {
             input.push_back(values[i]);
         }
-        auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits);
+        auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits, false);
         for(int i = 0; i < stop_token_len; i++ )
         {
             config.stop_token_ids.insert(stop_token_ids[i]);
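Taken together, the pytools.cpp changes thread the new flag through the C API: make_config gains a trailing bool add_special_tokens, that value is copied onto the GenerationConfig, and the call sites above pass true for the first two entry points and false for the last one. A hypothetical stand-in (not fastllm's real types) showing how the extended signature is used:

#include <cstdio>

// Hypothetical stand-in for fastllm::GenerationConfig; only the fields
// touched or implied by this commit are modelled here.
struct GenerationConfigSketch {
    int output_token_limit = 256;
    float temperature = 1.0f;
    float repeat_penalty = 1.0f;
    float top_p = 1.0f;
    int top_k = 1;
    bool output_logits = false;
    bool add_special_tokens = true;   // field assumed to exist on fastllm::GenerationConfig
};

// Mirrors the extended make_config signature in tools/src/pytools.cpp.
GenerationConfigSketch make_config_sketch(int max_length, bool do_sample, float top_p, int top_k,
                                          float temperature, float repeat_penalty, bool output_logits,
                                          bool add_special_tokens) {
    GenerationConfigSketch config;
    config.output_token_limit = max_length;
    config.temperature = temperature;
    config.repeat_penalty = repeat_penalty;
    if (do_sample) {
        config.top_p = top_p;
        config.top_k = top_k;
    }
    config.output_logits = output_logits;
    config.add_special_tokens = add_special_tokens;   // the assignment this commit adds
    return config;
}

int main() {
    // Two of the call sites in the diff pass true, the last one passes false.
    auto withSpecial    = make_config_sketch(512, true, 0.9f, 40, 0.8f, 1.1f, false, true);
    auto withoutSpecial = make_config_sketch(512, true, 0.9f, 40, 0.8f, 1.1f, false, false);
    printf("add_special_tokens: %d vs %d\n", withSpecial.add_special_tokens, withoutSpecial.add_special_tokens);
    return 0;
}

The true/false split presumably separates entry points that receive raw text from one that receives pre-built token values, though the surrounding function names are not visible in this diff.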
