Commit 1e76d88

support output length truncated
1 parent ddcb722 commit 1e76d88

File tree

3 files changed (+13, -2 lines)

csrc/gpu/update_inputs_v2.cu

Lines changed: 9 additions & 2 deletions
@@ -44,6 +44,7 @@ __global__ void update_inputs_kernel_v2(
     const int input_ids_stride,
     const int end_length) {
   int thread_idx = threadIdx.x;
+  bool output_truncated = false;
   // update step_idx and stop_flags
   if (thread_idx < max_bsz) {
     bool stop_flag = stop_flags[thread_idx];
@@ -52,6 +53,7 @@ __global__ void update_inputs_kernel_v2(
     }
     if (step_idx[thread_idx] >= max_dec_len[thread_idx]) {
       stop_flags[thread_idx] = true;
+      output_truncated = true;
     }
   }
   __syncthreads();
@@ -61,8 +63,13 @@ __global__ void update_inputs_kernel_v2(
     if (seq_lens_this_time[thread_idx] == 0) {
       next_tokens[thread_idx] = -1;
     } else {
-      next_tokens[thread_idx] = end_ids[0];
-      kwargs_next_tokens[thread_idx] = end_ids[0];
+      if (output_truncated) {
+        next_tokens[thread_idx] = -4;  // -4 for truncated output.
+        kwargs_next_tokens[thread_idx] = -4;
+      } else {
+        next_tokens[thread_idx] = end_ids[0];
+        kwargs_next_tokens[thread_idx] = end_ids[0];
+      }
     }
   } else {
     kwargs_next_tokens[thread_idx] = next_tokens[thread_idx];
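The kernel now records why a slot stopped: reaching max_dec_len still sets stop_flags as before, but additionally marks the stop as a truncation, so the emitted token becomes the sentinel -4 instead of the end token. A minimal Python sketch of that per-slot decision (the function name and the active/truncated flags are illustrative, not part of the kernel):

    def next_token_for_stopped_slot(active: bool, truncated: bool, end_id: int) -> int:
        """Illustrative model of the kernel branch above, not the kernel itself."""
        if not active:   # seq_lens_this_time == 0: slot holds no sequence
            return -1    # -1 is rewritten to eos_token_id downstream
        if truncated:    # step_idx reached max_dec_len on this step
            return -4    # -4 marks a truncated output for downstream masking
        return end_id    # natural stop: emit the configured end token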

llm/predict/predictor.py

Lines changed: 3 additions & 0 deletions
@@ -1258,6 +1258,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         outputs = self._infer(self.model_inputs)
         outputs = outputs.numpy()
         outputs[outputs == -1] = self.tokenizer.eos_token_id
+        outputs[outputs < 0] = self.tokenizer.pad_token_id
         output_token.append(outputs)
         logger.info(f"running spend {time.time() - s_time}")

@@ -1568,6 +1569,7 @@ def send_task_to_queue(task_id):
         if flag_current_rank_run:
             output_tokens = self.model_inputs["all_token_ids"].numpy()
             output_tokens[output_tokens == -1] = self.tokenizer.eos_token_id
+            output_tokens[output_tokens < 0] = self.tokenizer.pad_token_id
             if detokenize:
                 outputs = self.tokenizer.batch_decode(
                     output_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
@@ -1786,6 +1788,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         outputs = self.predictor.run(list(self.model_inputs.values()))[0]
         outputs = outputs.numpy()
         outputs[outputs == -1] = self.tokenizer.eos_token_id
+        outputs[outputs < 0] = self.tokenizer.pad_token_id
         output_token.append(outputs)
         logger.info(f"running spend {time.time() - s_time}")
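In all three call sites the order of the two masks matters: -1 must be rewritten to eos first, so that the following outputs[outputs < 0] line only catches the remaining sentinels such as the new -4. A small self-contained numpy sketch (the token ids are assumed values, not PaddleNLP's):

    import numpy as np

    eos_token_id, pad_token_id = 2, 0            # assumed ids, for illustration

    outputs = np.array([[15, 27, 2, -1, -1],     # stopped naturally, then padded
                        [15, 27, 99, 101, -4]])  # cut off at max_dec_len

    outputs[outputs == -1] = eos_token_id        # ordinary stop padding -> eos
    outputs[outputs < 0] = pad_token_id          # leftover sentinels (-4) -> pad
    # outputs -> [[15, 27, 2, 2, 2],
    #             [15, 27, 99, 101, 0]]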

paddlenlp/trl/llm_utils.py

Lines changed: 1 addition & 0 deletions
@@ -689,6 +689,7 @@ def read_res(
         bsz = int(output_tensor[1, 0])
         output_numpy = output_tensor[2 : bsz + 2].numpy()
         output_numpy[output_numpy == -1] = tokenizer.eos_token_id
+        output_numpy[output_numpy < 0] = tokenizer.pad_token_id
         outputs.append(output_numpy)
         if int(output_tensor[0, 0]) == -1:
             break
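read_res applies the same masking while draining the result tensor, whose layout this hunk implies: row 0 carries a status flag (-1 ends the read loop), row 1 the batch size, and the next bsz rows the token ids. A hedged sketch of that parse with fabricated data (the layout is inferred from this hunk alone, not confirmed elsewhere):

    import numpy as np

    eos_token_id, pad_token_id = 2, 0                 # assumed ids, for illustration

    # Fabricated tensor following the inferred layout: status, bsz, then tokens.
    output_tensor = np.array([[-1], [3], [27], [-1], [-4]])

    bsz = int(output_tensor[1, 0])                    # sequences in this batch
    output_numpy = output_tensor[2 : bsz + 2].copy()  # token ids, one row per seq
    output_numpy[output_numpy == -1] = eos_token_id   # ordinary stop -> eos
    output_numpy[output_numpy < 0] = pad_token_id     # truncation sentinel -> pad
    stream_done = int(output_tensor[0, 0]) == -1      # -1 flags the final chunk
    # output_numpy -> [[27], [2], [0]]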
