Commit
Merge pull request #362 from wangyumu/add-stop-tokens-ids
Fixes #360 add stop_token_ids
ztxz16 committed Nov 8, 2023
2 parents f97fff8 + 171e3c1 commit 737a58a
Showing 5 changed files with 66 additions and 18 deletions.
2 changes: 1 addition & 1 deletion include/fastllm.h
@@ -44,7 +44,7 @@ namespace fastllm {
float temperature = 1.0; // Temperature parameter, usually in the 0.1 ~ 1.0 range; larger values make the results more diverse
bool output_logits = false; // Whether to return logits
bool enable_hash_id = false; // Attach a hash id to the session

std::multiset <int> stop_token_ids;

bool IsSimpleGreedy() const {
if (fabs(repeat_penalty - 1) > 1e-8) {
11 changes: 11 additions & 0 deletions src/models/basellm.cpp
@@ -208,6 +208,11 @@ namespace fastllm {
inputTokens[i] = std::vector <float> {(float)ret[i]};
if (ret[i] == eos_token_id) {
isEnding[i] = true;
} else {
auto itStopTk = generationConfig.stop_token_ids.find(ret[i]);
if (itStopTk != generationConfig.stop_token_ids.end()) {
isEnding[i] = true;
}
}
if (isEnding[i]) {
curStrings.push_back("");
@@ -659,6 +664,12 @@ printf("tot = %d\n", tot);
if (curRet == model->eos_token_id) {
it.second->isEnding = true;
} else {
auto itStopTk = it.second->generationConfig.stop_token_ids.find(curRet);
if (itStopTk != it.second->generationConfig.stop_token_ids.end()) {
it.second->isEnding = true;
}
}
if (it.second->isEnding == false) {
it.second->currentTokens = std::vector<int>{curRet};
it.second->resultTokenQueue.push(curRet);
it.second->tokens.Push(curRet);
6 changes: 6 additions & 0 deletions src/models/llama.cpp
@@ -990,6 +990,12 @@ namespace fastllm {
if (curRet == model->eos_token_id) {
it.second->isEnding = true;
} else {
auto itStopTk = it.second->generationConfig.stop_token_ids.find(curRet);
if (itStopTk != it.second->generationConfig.stop_token_ids.end()) {
it.second->isEnding = true;
}
}
if (it.second->isEnding == false) {
it.second->currentTokens = std::vector<int>{curRet};
it.second->resultTokenQueue.push(curRet);
it.second->tokens.Push(curRet);
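Both basellm.cpp and llama.cpp now apply the same rule: a sequence ends when the sampled token is the model's eos token or any of the user-supplied stop tokens. A minimal sketch of that stopping condition (illustrative Python, not code from this repository):

    def should_stop(token_id, eos_token_id, stop_token_ids):
        # stop on the model's eos token, or on any extra stop token
        # supplied through GenerationConfig.stop_token_ids
        return token_id == eos_token_id or token_id in stop_token_ids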
51 changes: 36 additions & 15 deletions tools/fastllm_pytools/llm.py
@@ -23,7 +23,8 @@

fastllm_lib.launch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_void_p,
ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
ctypes.c_float, ctypes.c_float, ctypes.c_bool]
ctypes.c_float, ctypes.c_float, ctypes.c_bool,
ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
fastllm_lib.launch_response_llm_model.restype = ctypes.c_int

fastllm_lib.fetch_response_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
@@ -39,7 +40,8 @@

fastllm_lib.launch_response_str_llm_model.argtype = [ctypes.c_int, ctypes.c_char_p,
ctypes.c_int, ctypes.c_bool, ctypes.c_float, ctypes.c_int,
ctypes.c_float, ctypes.c_float, ctypes.c_bool]
ctypes.c_float, ctypes.c_float, ctypes.c_bool,
ctypes.c_int, ctypes.POINTER(ctypes.c_int)]
fastllm_lib.launch_response_str_llm_model.restype = ctypes.c_int

fastllm_lib.fetch_response_str_llm_model.argtypes = [ctypes.c_int, ctypes.c_int]
@@ -201,19 +203,29 @@ def tokenizer_decode_token(self, token_id: int) -> bytes:
break
return buffer_bytes[:result_len]

def stop_token_ctypes(self, stop_token_ids):
if stop_token_ids is None:
return 0, None
else:
return ctypes.c_int(len(stop_token_ids)), (ctypes.c_int * len(stop_token_ids))(*stop_token_ids)

def response_logits(self,
query: str,
history: List[Tuple[str, str]] = None,
tokenizer = None) -> str:
tokenizer = None,
stop_token_ids: List[int] = None,
) -> str:
prompt = query if self.direct_query else self.get_prompt(query, history);
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
if (tokenizer == None):
handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
ctypes.c_int(1), ctypes.c_bool(False), ctypes.c_float(1), ctypes.c_int(1),
ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True));
ctypes.c_float(1), ctypes.c_float(1), ctypes.c_bool(True),
stop_token_len, stop_token_list);
else:
input = tokenizer.encode(prompt);
handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
1, False, 1, 1, 1, 1, True);
1, False, 1, 1, 1, 1, True, stop_token_len, stop_token_list);
vocab_size = fastllm_lib.get_tokenizer_vocab_size(self.model);
logits = list(range(vocab_size))
array = (ctypes.c_float * (vocab_size * 4))(*logits);
@@ -226,7 +238,8 @@ def response_logits(self,
def response(self,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0) -> str:
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
stop_token_ids: List[int] = None) -> str:
ret = "";
for i in self.stream_response(query = query,
history = history,
@@ -235,19 +248,22 @@ def response(self,
top_p = top_p, top_k = top_k,
temperature = temperature,
repeat_penalty = repeat_penalty,
one_by_one = True):
one_by_one = True,
stop_token_ids = stop_token_ids):
ret += i;
return ret;

def stream_response(self,
query: str,
history: List[Tuple[str, str]] = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
one_by_one = True):
one_by_one = True, stop_token_ids: List[int] = None):
prompt = query if self.direct_query else self.get_prompt(query, history);
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids);
handle = fastllm_lib.launch_response_str_llm_model(self.model, prompt.encode(),
ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False));
ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
stop_token_len, stop_token_list);
res = "";
ret = b'';
fail_cnt = 0;
@@ -275,12 +291,15 @@ def stream_response(self,
def stream_response_raw(self,
input_tokens: List[int],
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
one_by_one = True
one_by_one = True,
stop_token_ids: List[int] = None
):
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
handle = fastllm_lib.launch_response_llm_model(self.model, len(input_tokens),
(ctypes.c_int * len(input_tokens))(*input_tokens),
ctypes.c_int(max_length), ctypes.c_bool(do_sample), ctypes.c_float(top_p), ctypes.c_int(top_k),
ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False))
ctypes.c_float(temperature), ctypes.c_float(repeat_penalty), ctypes.c_bool(False),
stop_token_len, stop_token_list)

# Long-tail characters may need several tokens before they can be decoded, so only bytes are returned and the string.decode strategy is left to the caller
# This makes it easier to count output tokens and to control decoding when the UTF-8 sequence is still incomplete
@@ -300,14 +319,15 @@ def stream_response_raw(self,
yield total_bytes

def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 8192,
do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, **kwargs):
do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0, stop_token_ids: List[int] = None, **kwargs):
if (not(history)):
history = [];
prompt = query if self.direct_query else self.get_prompt(query, history);
input = tokenizer.encode(prompt);
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
False);
False, stop_token_len, stop_token_list);

result = [];
while True:
@@ -321,14 +341,15 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max

def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, past_key_values = None,
max_length: int = 8192, do_sample = True, top_p = 0.8, top_k = 1, temperature = 1.0, repeat_penalty = 1.0,
return_past_key_values = False, **kwargs) -> str:
return_past_key_values = False, stop_token_ids: List[int] = None, **kwargs) -> str:
if (not(history)):
history = [];
prompt = query if self.direct_query else self.get_prompt(query, history);
input = tokenizer.encode(prompt);
stop_token_len, stop_token_list = self.stop_token_ctypes(stop_token_ids)
handle = fastllm_lib.launch_response_llm_model(self.model, len(input), (ctypes.c_int * len(input))(*input),
max_length, do_sample, top_p, top_k, temperature, repeat_penalty,
False);
False, stop_token_len, stop_token_list);
tokens = [];
while True:
cur = fastllm_lib.fetch_response_llm_model(self.model, handle);
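With these bindings in place, every Python entry point (response, response_logits, stream_response, stream_response_raw, chat, stream_chat) accepts an optional stop_token_ids list. A minimal usage sketch — the model path and the token id are placeholders, not values taken from this PR:

    from fastllm_pytools import llm

    # hypothetical converted model file; substitute your own .flm model
    model = llm.model("model.flm")

    # 7 is a placeholder id; pass whatever ids your tokenizer assigns to the
    # strings you want generation to stop on
    for piece in model.stream_response("Hello", stop_token_ids = [7]):
        print(piece, end = "", flush = True)

Passing stop_token_ids = None (the default) keeps the previous behaviour: only eos_token_id terminates generation.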
14 changes: 12 additions & 2 deletions tools/src/pytools.cpp
@@ -241,14 +241,19 @@ extern "C" {

DLL_EXPORT int launch_response_str_llm_model(int modelId, char *content,
int max_length, bool do_sample, float top_p, int top_k,
float temperature, float repeat_penalty, bool output_logits) {
float temperature, float repeat_penalty, bool output_logits,
int stop_token_len, int * stop_token_ids) {
auto model = models.GetModel(modelId);
std::vector <int> tokens;
auto v = model->weight.tokenizer.Encode(content);
for (int i = 0; i < v.Count(0); i++) {
tokens.push_back((int)((float*)v.cpuData)[i]);
}
auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits);
for (int i = 0; i < stop_token_len; i++) {
    config.stop_token_ids.insert(stop_token_ids[i]);
}
return model->LaunchResponseTokens(tokens, config);
}

@@ -261,12 +266,17 @@ extern "C" {

DLL_EXPORT int launch_response_llm_model(int modelId, int len, int *values,
int max_length, bool do_sample, float top_p, int top_k,
float temperature, float repeat_penalty, bool output_logits) {
float temperature, float repeat_penalty, bool output_logits,
int stop_token_len, int * stop_token_ids) {
std::vector <int> input;
for (int i = 0; i < len; i++) {
input.push_back(values[i]);
}
auto config = make_config(max_length, do_sample, top_p, top_k, temperature, repeat_penalty, output_logits);
for (int i = 0; i < stop_token_len; i++) {
    config.stop_token_ids.insert(stop_token_ids[i]);
}
auto model = models.GetModel(modelId);
return model->LaunchResponseTokens(input, config);
}
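On the Python side, stop_token_ctypes (in llm.py above) turns the optional list into the (length, pointer) pair these exports expect; None becomes a count of 0 with a NULL pointer, so the insertion loop above runs zero times and config.stop_token_ids stays empty. A standalone sketch of that conversion, assuming nothing beyond the ctypes standard library (the function name is illustrative):

    import ctypes

    def to_stop_token_args(stop_token_ids):
        # None -> (0, NULL): the C++ loop never touches the pointer
        if stop_token_ids is None:
            return ctypes.c_int(0), None
        arr = (ctypes.c_int * len(stop_token_ids))(*stop_token_ids)
        return ctypes.c_int(len(stop_token_ids)), arr

    length, array = to_stop_token_args([7, 13])  # placeholder ids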
