No white space included in tokens sent back by Llama2 in streaming mode #332

Open · jfpichlme opened this issue Feb 7, 2024 · 22 comments
Labels: bug (Something isn't working), triaged (Issue has been triaged by maintainers)

Comments

@jfpichlme

System Info

  • DGX H100
  • TensorRT-LLM 0.7.1

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

1. Set up Llama 2 (7B, 13B, 70B) in streaming mode:

model_config:

name: "tensorrt_llm"
backend: "tensorrtllm"
max_batch_size: 300

model_transaction_policy {
  decoupled: True
}

dynamic_batching {
}

input [
  {
    name: "input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    allow_ragged_batch: true
  },
  {
    name: "input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "request_output_len"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "draft_input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "end_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "pad_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "bad_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "embedding_bias"
    data_type: TYPE_FP32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "beam_width"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_k"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "len_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "repetition_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "min_length"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "presence_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "random_seed"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_log_probs"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
  },
  {
    name: "streaming"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
  },
  {
    name: "prompt_embedding_table"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "prompt_vocab_size"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  }
]
output [
  {
    name: "output_ids"
    data_type: TYPE_INT32
    dims: [ -1, -1 ]
  },
  {
    name: "sequence_length"
    data_type: TYPE_INT32
    dims: [ -1 ]
  },
  {
    name: "cum_log_probs"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "output_log_probs"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  }
]
instance_group [
  {
    count: 1
    kind : KIND_CPU
  }
]
parameters: {
  key: "max_beam_width"
  value: {
    string_value: "1"
  }
}
parameters: {
  key: "FORCE_CPU_ONLY_INPUT_TENSORS"
  value: {
    string_value: "no"
  }
}
parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "inflight_fused_batching"
  }
}
parameters: {
  key: "gpt_model_path"
  value: {
    string_value: "/Llama2/03_Model_Dir_RT_070224/03_Model_Dir/01_Llama2_70b_TP8_300_STR/tensorrt_llm/1"
  }
}
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "180000"
  }
}
parameters: {
  key: "max_attention_window_size"
  value: {
    string_value: "max_sequence_length"
  }
}
parameters: {
  key: "batch_scheduler_policy"
  value: {
    string_value: "max_utilization"
  }
}
parameters: {
  key: "max_num_sequences"
  value: {
    string_value: "1000"
  }
}
parameters: {
  key: "enable_trt_overlap"
  value: {
    string_value: "True"
  }
}
parameters: {
  key: "exclude_input_in_output"
  value: {
    string_value: "True"
  }
}
parameters: {
  key: "enable_kv_cache_reuse"
  value: {
    string_value: "False"
  }
}
parameters: {
  key: "normalize_log_probs"
  value: {
    string_value: "True"
  }
}

preprocessing:


name: "preprocessing"
backend: "python"
max_batch_size: 300
input [
    {
        name: "QUERY"
        data_type: TYPE_STRING
        dims: [ -1 ]
    },
    {
        name: "REQUEST_OUTPUT_LEN"
        data_type: TYPE_INT32
        dims: [ -1 ]
    },
    {
        name: "BAD_WORDS_DICT"
        data_type: TYPE_STRING
        dims: [ -1 ]
        optional: true
    },
    {
        name: "STOP_WORDS_DICT"
        data_type: TYPE_STRING
        dims: [ -1 ]
        optional: true
    },
    {
        name: "EMBEDDING_BIAS_WORDS"
        data_type: TYPE_STRING
        dims: [ -1 ]
        optional: true
    },
    {
        name: "EMBEDDING_BIAS_WEIGHTS"
        data_type: TYPE_FP32
        dims: [ -1 ]
        optional: true
    }
]
output [
    {
        name: "INPUT_ID"
        data_type: TYPE_INT32
        dims: [ -1 ]
    },
    {
        name: "REQUEST_INPUT_LEN"
        data_type: TYPE_INT32
        dims: [ 1 ]
    },
    {
        name: "BAD_WORDS_IDS"
        data_type: TYPE_INT32
        dims: [ 2, -1 ]
    },
    {
        name: "STOP_WORDS_IDS"
        data_type: TYPE_INT32
        dims: [ 2, -1 ]
    },
    {
        name: "EMBEDDING_BIAS"
        data_type: TYPE_FP32
        dims: [ -1 ]
    },
    {
        name: "REQUEST_OUTPUT_LEN"
        data_type: TYPE_INT32
        dims: [ -1 ]
    }
]

parameters {
  key: "tokenizer_dir"
  value: {
    string_value: "/Llama2/01_HF_Model_Folder/01_LLama_70B/Llama-2-70b-chat-hf"
  }
}

parameters {
  key: "tokenizer_type"
  value: {
    string_value: "llama"
  }
}

parameters {
  key: "add_special_tokens"
  value: {
    string_value: "False"
  }
}

instance_group [
    {
        count: 1
        kind: KIND_CPU
    }
]

postprocessing:


name: "postprocessing"
backend: "python"
max_batch_size: 300
input [
  {
    name: "TOKENS_BATCH"
    data_type: TYPE_INT32
    dims: [ -1, -1 ]
  },
  {
    name: "SEQUENCE_LENGTH"
    data_type: TYPE_INT32
    dims: [ -1 ]
  },
  {
    name: "CUM_LOG_PROBS"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "OUTPUT_LOG_PROBS"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  }
]
output [
  {
    name: "OUTPUT"
    data_type: TYPE_STRING
    dims: [ -1 ]
  },
  {
    name: "OUT_CUM_LOG_PROBS"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "OUT_OUTPUT_LOG_PROBS"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  }
]

parameters {
  key: "tokenizer_dir"
  value: {
    string_value: "/Llama2/01_HF_Model_Folder/01_LLama_70B/Llama-2-70b-chat-hf"
  }
}

parameters {
  key: "tokenizer_type"
  value: {
    string_value: "llama"
  }
}

parameters {
  key: "skip_special_tokens"
  value: {
    string_value: "False"
  }
}

instance_group [
    {
        count: 1
        kind: KIND_CPU
    }
]

ensemble:

name: "ensemble"
platform: "ensemble"
max_batch_size: 300
input [
  {
    name: "text_input"
    data_type: TYPE_STRING
    dims: [ -1 ]
  },
  {
    name: "max_tokens"
    data_type: TYPE_INT32
    dims: [ -1 ]
  },
  {
   name: "bad_words"
   data_type: TYPE_STRING
   dims: [ -1 ]
   optional: true
  },
  {
   name: "stop_words"
   data_type: TYPE_STRING
   dims: [ -1 ]
   optional: true
  },
  {
    name: "end_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "pad_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "top_k"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "length_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "repetition_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "min_length"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "presence_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "random_seed"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    optional: true
  },
  {
    name: "return_log_probs"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
  },
  {
    name: "beam_width"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
  },
  {
    name: "stream"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
  },
  {
    name: "prompt_embedding_table"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
  },
  {
    name: "prompt_vocab_size"
    data_type: TYPE_INT32
    dims: [ 1 ]
    optional: true
  },
  {
      name: "embedding_bias_words"
      data_type: TYPE_STRING
      dims: [ -1 ]
      optional: true
  },
  {
      name: "embedding_bias_weights"
      data_type: TYPE_FP32
      dims: [ -1 ]
      optional: true
  }
]
output [
  {
    name: "text_output"
    data_type: TYPE_STRING
    dims: [ -1 ]
  },
  {
    name: "cum_log_probs"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "output_log_probs"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  }
]
ensemble_scheduling {
  step [
    {
      model_name: "preprocessing"
      model_version: -1
      input_map {
        key: "QUERY"
        value: "text_input"
      }
      input_map {
        key: "REQUEST_OUTPUT_LEN"
        value: "max_tokens"
      }
      input_map {
        key: "BAD_WORDS_DICT"
        value: "bad_words"
      }
      input_map {
        key: "STOP_WORDS_DICT"
        value: "stop_words"
      }
      input_map {
        key: "EMBEDDING_BIAS_WORDS"
        value: "embedding_bias_words"
      }
      input_map {
        key: "EMBEDDING_BIAS_WEIGHTS"
        value: "embedding_bias_weights"
      }
      output_map {
        key: "REQUEST_INPUT_LEN"
        value: "_REQUEST_INPUT_LEN"
      }
      output_map {
        key: "INPUT_ID"
        value: "_INPUT_ID"
      }
      output_map {
        key: "REQUEST_OUTPUT_LEN"
        value: "_REQUEST_OUTPUT_LEN"
      }
      output_map {
        key: "STOP_WORDS_IDS"
        value: "_STOP_WORDS_IDS"
      }
      output_map {
        key: "BAD_WORDS_IDS"
        value: "_BAD_WORDS_IDS"
      }
      output_map {
        key: "EMBEDDING_BIAS"
        value: "_EMBEDDING_BIAS"
      }
    },
    {
      model_name: "tensorrt_llm"
      model_version: -1
      input_map {
        key: "input_ids"
        value: "_INPUT_ID"
      }
      input_map {
        key: "input_lengths"
        value: "_REQUEST_INPUT_LEN"
      }
      input_map {
        key: "request_output_len"
        value: "_REQUEST_OUTPUT_LEN"
      }
      input_map {
          key: "end_id"
          value: "end_id"
      }
      input_map {
          key: "pad_id"
          value: "pad_id"
      }
      input_map {
          key: "embedding_bias"
          value: "_EMBEDDING_BIAS"
      }
      input_map {
          key: "runtime_top_k"
          value: "top_k"
      }
      input_map {
          key: "runtime_top_p"
          value: "top_p"
      }
      input_map {
          key: "temperature"
          value: "temperature"
      }
      input_map {
          key: "len_penalty"
          value: "length_penalty"
      }
      input_map {
          key: "repetition_penalty"
          value: "repetition_penalty"
      }
      input_map {
          key: "min_length"
          value: "min_length"
      }
      input_map {
          key: "presence_penalty"
          value: "presence_penalty"
      }
      input_map {
          key: "random_seed"
          value: "random_seed"
      }
      input_map {
          key: "return_log_probs"
          value: "return_log_probs"
      }
      input_map {
          key: "beam_width"
          value: "beam_width"
      }
      input_map {
          key: "streaming"
          value: "stream"
      }
      input_map {
        key: "prompt_embedding_table"
        value: "prompt_embedding_table"
      }
      input_map {
        key: "prompt_vocab_size"
        value: "prompt_vocab_size"
      }
      input_map {
        key: "stop_words_list"
        value: "_STOP_WORDS_IDS"
      }
      input_map {
        key: "bad_words_list"
        value: "_BAD_WORDS_IDS"
      }
      output_map {
        key: "output_ids"
        value: "_TOKENS_BATCH"
      }
      output_map {
        key: "sequence_length"
        value: "_SEQUENCE_LENGTH"
      },
      output_map {
        key: "cum_log_probs"
        value: "_CUM_LOG_PROBS"
      }
      output_map {
        key: "output_log_probs"
        value: "_OUTPUT_LOG_PROBS"
      }
    },
    {
      model_name: "postprocessing"
      model_version: -1
      input_map {
        key: "TOKENS_BATCH"
        value: "_TOKENS_BATCH"
      }
      input_map {
        key: "CUM_LOG_PROBS"
        value: "_CUM_LOG_PROBS"
      }
      input_map {
        key: "OUTPUT_LOG_PROBS"
        value: "_OUTPUT_LOG_PROBS"
      }
      input_map {
        key: "SEQUENCE_LENGTH"
        value: "_SEQUENCE_LENGTH"
      }
      output_map {
        key: "OUTPUT"
        value: "text_output"
      }
      output_map {
        key: "OUT_OUTPUT_LOG_PROBS"
        value: "output_log_probs"
      }
      output_map {
        key: "OUT_CUM_LOG_PROBS"
        value: "cum_log_probs"
      }
    }
  ]
}

2. Use the NVIDIA client notebook (the install does not work, but downloading langchain_nvidia_trt.llms directly solves the problem):

https://github.com/NVIDIA/GenerativeAIExamples/blob/main/notebooks/01-llm-streaming-client.ipynb

(I have also written my own gRPC client, which produces the same output.)

3. Send an inference request via gRPC to the Triton server (a minimal client sketch is shown below).
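
A minimal gRPC streaming client sketch for step 3 is shown here (not the NVIDIA notebook client). It assumes Triton's gRPC endpoint is reachable on localhost:8001, that the ensemble model is named "ensemble" as in the config above, and that the tritonclient[grpc] package is installed; the input shapes carry a batch dimension because max_batch_size > 0.

    import queue
    from functools import partial

    import numpy as np
    import tritonclient.grpc as grpcclient

    responses = queue.Queue()

    def callback(q, result, error):
        # Each streamed (decoupled) response, or an error, is delivered here.
        q.put(error if error is not None else result)

    client = grpcclient.InferenceServerClient(url="localhost:8001")

    inputs = [
        grpcclient.InferInput("text_input", [1, 1], "BYTES"),
        grpcclient.InferInput("max_tokens", [1, 1], "INT32"),
        grpcclient.InferInput("stream", [1, 1], "BOOL"),
    ]
    inputs[0].set_data_from_numpy(np.array([["What is the fastest land animal?"]], dtype=object))
    inputs[1].set_data_from_numpy(np.array([[64]], dtype=np.int32))
    inputs[2].set_data_from_numpy(np.array([[True]], dtype=bool))

    client.start_stream(callback=partial(callback, responses))
    client.async_stream_infer(model_name="ensemble", inputs=inputs)
    client.stop_stream()  # closing the stream waits for the responses to the enqueued request

    while not responses.empty():
        item = responses.get()
        if isinstance(item, Exception):
            raise item
        chunk = item.as_numpy("text_output")
        print(chunk.flatten()[0].decode("utf-8", errors="replace"), end="", flush=True)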

Expected behavior

Produce output tokens including whitespace:

The fastest land animal is the cheetah, which can run up to 70 miles per hour(1

Actual behavior

Triton produces output tokens without whitespace:

Thefastestlandanimalisthecheetah,whichcanrunupto70milesperhour(1

Additional notes

I am not sure whether this is a bug or whether I am missing a flag. Any help is highly appreciated.

Model build:

python convert_checkpoint.py --model_dir /Llama2/ \
                             --output_dir /Llama2/03_TensorRT_0102_Model_Dir/01_LLama_7B_TP1/01_Converted_Weights/ \
                             --dtype float16 \
                             --tp_size 8

trtllm-build --checkpoint_dir /Llama2/03_TensorRT_0102_Model_Dir/01_LLama_7B_TP1/01_Converted_Weights/ \
             --output_dir /Llama2/03_TensorRT_0102_Model_Dir/01_Engine_Dir/01_LLama_7B_TP1/02_Build_Engines/ \
             --gpt_attention_plugin float16 \
             --gemm_plugin float16 \
             --remove_input_padding enable \
             --paged_kv_cache enable \
             --enable_xqa enable \
             --max_batch_size 300
jfpichlme added the bug (Something isn't working) label on Feb 7, 2024
@philross

philross commented Feb 8, 2024

We are experiencing the same issue.

@jfpichlme
Author

For now I am using a workaround that is probably not ideal. In the postprocessing script (/postprocessing/1/model.py) I changed the _postprocessing function to return the actual token ids.

    def _postprocessing(self, tokens_batch, sequence_lengths):
        outputs = []
        for batch_idx, beam_tokens in enumerate(tokens_batch):
            for beam_idx, tokens in enumerate(beam_tokens):
                seq_len = sequence_lengths[batch_idx][beam_idx]
                output = tokens[:seq_len]
                outputs.append(output)
        return outputs

I collect all the token IDs on the client side and then decode the entire sequence, which produces the correct output (a sketch is shown below).
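
A hedged sketch of that client-side step, assuming a Hugging Face transformers tokenizer and that each streamed response now carries raw token IDs instead of text (the streamed_chunks iterable is illustrative, not part of the backend):

    from transformers import AutoTokenizer

    # Same tokenizer directory as in the preprocessing/postprocessing configs above.
    tokenizer = AutoTokenizer.from_pretrained(
        "/Llama2/01_HF_Model_Folder/01_LLama_70B/Llama-2-70b-chat-hf")

    collected_ids = []
    for chunk in streamed_chunks:  # illustrative: the token IDs returned by each streamed response
        collected_ids.extend(int(t) for t in chunk)

    # Decoding the whole sequence at once preserves the whitespace correctly.
    text = tokenizer.decode(collected_ids, skip_special_tokens=True)
    print(text)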

@enochlev

enochlev commented Feb 8, 2024

The tokenizers in transformers do not handle this automatically when calling the decode function.

The standard way of handling this is to hold tokens in a cache until a space is detected, at which point everything after the space is put back into the cache.

The other suggested method inspects the token-ID strings (convert_ids_to_tokens) instead of the decoded text and looks for the "▁" symbol.

Here is a workaround using the second method:

    def _postprocessing(self, tokens_batch, sequence_lengths):
        outputs = []
        for batch_idx, beam_tokens in enumerate(tokens_batch):
            for beam_idx, tokens in enumerate(beam_tokens):
                seq_len = sequence_lengths[batch_idx][beam_idx]
                output = self.tokenizer.decode(
                    tokens[:seq_len],
                    skip_special_tokens=False)

                # For streaming mode: SentencePiece marks a leading space with "▁",
                # which tokenizer.decode() drops, so re-add it by inspecting the raw token.
                token_id_string = self.tokenizer.convert_ids_to_tokens(tokens[:seq_len], skip_special_tokens=True)[0]
                if token_id_string[0] == "▁":
                    output = " " + output

                outputs.append(output.encode('utf8'))
        return outputs
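
For the first (buffering) method, a minimal client-side sketch follows. It implements the buffering idea as a prefix diff (keep every ID seen so far, re-decode the whole prefix, and emit only the new suffix) rather than the exact space-based cache described above; the class and method names are illustrative only.

    class StreamDetokenizer:
        """Illustrative helper: emit only the text added by newly streamed token IDs."""

        def __init__(self, tokenizer):
            self.tokenizer = tokenizer
            self.ids = []
            self.emitted = ""

        def feed(self, new_ids):
            self.ids.extend(new_ids)
            text = self.tokenizer.decode(self.ids, skip_special_tokens=True)
            delta = text[len(self.emitted):]  # includes any leading space the tokenizer restores
            self.emitted = text
            return delta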

@enochlev

enochlev commented Feb 8, 2024

@Shixiaowei02 I can create a PR for this

@byshiue
Collaborator

byshiue commented Feb 9, 2024

Have you tried the tensorrt_llm_bls module?

@enochlev

enochlev commented Feb 9, 2024

btw @jfpichlme,
how did you get TensorRT-LLM working with the new workflow, specifically with the trtllm-build command? Which Docker command and version of tensorrtllm_backend did you use?

@jfpichlme
Author

Hi enochlev,

Build the container

I used Option 2 in the tensorrtllm_backend repo to build the Docker container:

# Update the submodules
cd tensorrtllm_backend
git lfs install
git submodule update --init --recursive

# Use the Dockerfile to build the backend in a container
# For x86_64
DOCKER_BUILDKIT=1 docker build -t triton_trt_llm -f dockerfile/Dockerfile.trt_llm_backend .

The Docker version is 22.04, and the tensorrt_llm git version is TensorRT-LLM backend (#324).

Build the Models

This process now consists of two steps: first a convert_checkpoint step, then a build step.

  1. Perform the conversion step via the following command:
python convert_checkpoint.py --model_dir <huggingface_format_model_dir> \
                             --output_dir ./tllm_checkpoint_1gpu_fp16 \
                             --dtype float16 \
                             --tp_size 8
  2. Perform the build step:
trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp16 \
             --output_dir <engine_output_dir> \
             --gpt_attention_plugin float16 \
             --gemm_plugin float16 \
             --remove_input_padding enable \
             --paged_kv_cache enable \
             --enable_xqa enable \
             --max_batch_size 300

Combine your model:

The last step is to copy the created engine files into the tensorrt_llm/1/ directory and adapt the config files. You can see the model configs in my initial comment; a rough layout sketch follows.
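
For reference, the resulting model repository layout (assuming the four models from the configs above) looks roughly like this:

    <model_repo>/
    ├── ensemble/          config.pbtxt, 1/
    ├── preprocessing/     config.pbtxt, 1/model.py
    ├── postprocessing/    config.pbtxt, 1/model.py
    └── tensorrt_llm/      config.pbtxt, 1/<engine files>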

I hope this helps you. @byshiue I will test the tensorrt_llm_bls module now.

@ekarmazin

@jfpichlme any luck with BLS + streaming? I have the same problem, and for some reason I can't get my gRPC client to work with BLS.

@jfpichlme
Author

Hi @ekarmazin,
BLS + streaming did not work for me. At the moment I am sticking to the proposed solution where I modify the postprocessing script and buffer the tokens on the client side. To approximate streaming (displaying word-by-word output), I decode around 6-10 tokens at a time. However, this does not work perfectly all the time.

@ekarmazin

@jfpichlme I kind of got it working with BLS; it now produces proper output with whitespace. But I ran into accuracy problems when enabling --use_paged_context_fmha, though that is a different issue.

@ekarmazin

@byshiue same issue with the BLS model. Spaces are present when accumulate_tokens is true, and missing when it is false.
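
For reference, that switch is a parameter in the tensorrt_llm_bls config.pbtxt; a minimal snippet is below, assuming the parameter name matches the shipped example config:

parameters: {
  key: "accumulate_tokens"
  value: {
    string_value: "true"
  }
}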

@schetlur-nv
Collaborator

@enochlev apologies for the delayed response. Would you still be able to PR the fix you suggested?

@npuichigo

Any update on this?

@enochlev

I will find some time around work this week and push an update.

@plt12138

mark

@quwu0820

Mark

@charllll

charllll commented Apr 8, 2024

Any update?

@Saigut

Saigut commented Apr 12, 2024

Mark

@HermitSun

Mark

@yxchia98

Mark

@elinx

elinx commented Apr 18, 2024

(Quoting @enochlev's workaround from above.)

@enochlev this will crash if the last token is EOS; a quick fix:

                token_id_string = self.tokenizer.convert_ids_to_tokens(tokens[:seq_len], skip_special_tokens=True)
                if len(token_id_string) > 0 and len(token_id_string[0]) > 0 and token_id_string[0][0] == "▁":
                    output = " " + output

@enochlev

@elinx Really appreciate catching that...

I just submitted a PR including your suggestion. It worked in my local environment before I submitted the PR, so it has my approval (if that means anything 😁)

Link to PR
