diff --git a/src/models/chatglm.cpp b/src/models/chatglm.cpp index dccffa56..312b35dd 100644 --- a/src/models/chatglm.cpp +++ b/src/models/chatglm.cpp @@ -757,7 +757,8 @@ namespace fastllm { if (GetVersion() == 1) { positionIds.CopyFrom(Data(DataType::FLOAT32, {2, 1}, {(float) promptLen, (float) (index + 1)})); } else { - positionIds.CopyFrom(Data(DataType::FLOAT32, {2, 1}, {(float) promptLen + index + 1, (float) (index + 1)})); + int gap = add_special_tokens ? 1 : -1; + positionIds.CopyFrom(Data(DataType::FLOAT32, {2, 1}, {(float) promptLen + index + gap, (float) (index + gap)})); } } } diff --git a/tools/src/pytools.cpp b/tools/src/pytools.cpp index e24e9f7e..ed7b79b7 100644 --- a/tools/src/pytools.cpp +++ b/tools/src/pytools.cpp @@ -80,7 +80,7 @@ extern "C" { } DLL_EXPORT fastllm::GenerationConfig make_config(int max_length, bool do_sample, float top_p, int top_k, - float temperature, float repeat_penalty, bool output_logits) { + float temperature, float repeat_penalty, bool output_logits, bool add_special_tokens) { fastllm::GenerationConfig config; config.output_token_limit = max_length; config.temperature = temperature; @@ -90,6 +90,7 @@ extern "C" { config.top_k = top_k; } config.output_logits = output_logits; + config.add_special_tokens = add_special_tokens; return config; } @@ -314,7 +315,7 @@ extern "C" { int max_length, bool do_sample, float top_p, int top_k, float temperature, float repeat_penalty, bool output_logits) { auto model = models.GetModel(modelId); - auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits); + auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits, true); std::string s = model->Response(content, nullptr, config); return string_to_chars(s); } @@ -329,7 +330,7 @@ extern "C" { for (int i = 0; i < v.Count(0); i++) { tokens.push_back((int)((float*)v.cpuData)[i]); } - auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits); + auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits, true); for(int i = 0; i < stop_token_len; i++ ) { config.stop_token_ids.insert(stop_token_ids[i]); @@ -358,7 +359,7 @@ extern "C" { for (int i = 0; i < len; i++) { input.push_back(values[i]); } - auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits); + auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits, false); for(int i = 0; i < stop_token_len; i++ ) { config.stop_token_ids.insert(stop_token_ids[i]);