UNAVAILABLE: Internal: unexpected error when creating modelInstanceState: [json.exception.parse_error.101] parse error at line 1, column 1: syntax error while parsing value - unexpected end of input; expected '[', '{', or a literal #500

@naphatkps

Description

System Info

  • GPU Name: NVIDIA GeForce RTX 3080 Ti
  • System Ram: 65GB

Who can help?

@juney-nvidia
@byshiue

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Copy the engine from TensorRT-LLM to the Triton server model directory.

docker run -it --rm --gpus all -p 8765:8000 --shm-size=1g \
-v ./all_models:/all_models \
-v ./scripts:/opt/scripts \
nvcr.io/nvidia/tritonserver:23.10-trtllm-python-py3

pip install sentencepiece protobuf

python /opt/scripts/launch_triton_server.py --model_repo /all_models/inflight_batcher_llm --world_size 1

Configure all config.pbtxt files.
The /all_models/inflight_batcher_llm/tensorrt_llm/config.pbtxt file:

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "tensorrt_llm"
backend: "tensorrtllm"
max_batch_size: 4

model_transaction_policy {
  decoupled: false
}

dynamic_batching {
    preferred_batch_size: [ 4 ]
    max_queue_delay_microseconds: 100
}

input [
  {
    name: "input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    allow_ragged_batch: true
  },
  {
    name: "input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  },
  {
    name: "request_output_len"
    data_type: TYPE_INT32
    dims: [ 1 ]
  },
  {
    name: "draft_input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "draft_logits"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "draft_acceptance_threshold"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "end_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "pad_id"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "bad_words_list"
    data_type: TYPE_INT32
    dims: [ 2, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "embedding_bias"
    data_type: TYPE_FP32
    dims: [ -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "beam_width"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "temperature"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_k"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_min"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_decay"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "runtime_top_p_reset_ids"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "len_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "early_stopping"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "repetition_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "min_length"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "beam_search_diversity_rate"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "presence_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "frequency_penalty"
    data_type: TYPE_FP32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "random_seed"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_log_probs"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_context_logits"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "return_generation_logits"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  {
    name: "stop"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
  },
  {
    name: "streaming"
    data_type: TYPE_BOOL
    dims: [ 1 ]
    optional: true
  },
  {
    name: "prompt_embedding_table"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  {
    name: "prompt_vocab_size"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  # the unique task ID for the given LoRA.
  # To perform inference with a specific LoRA for the first time `lora_task_id` `lora_weights` and `lora_config` must all be given.
  # The LoRA will be cached, so that subsequent requests for the same task only require `lora_task_id`.
  # If the cache is full the oldest LoRA will be evicted to make space for new ones.  An error is returned if `lora_task_id` is not cached.
  {
    name: "lora_task_id"
    data_type: TYPE_UINT64
    dims: [ 1 ]
    reshape: { shape: [ ] }
    optional: true
  },
  # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
  # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
  # each of the in / out tensors are first flattened and then concatenated together in the format above.
  # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
  {
    name: "lora_weights"
    data_type: TYPE_FP16
    dims: [ -1, -1 ]
    optional: true
    allow_ragged_batch: true
  },
  # module identifier (same size as the first dimension of lora_weights)
  # See LoraModule::ModuleType for model id mapping
  #
  # "attn_qkv": 0     # combined qkv adapter
  # "attn_q": 1       # q adapter
  # "attn_k": 2       # k adapter
  # "attn_v": 3       # v adapter
  # "attn_dense": 4   # adapter for the dense layer in attention
  # "mlp_h_to_4h": 5  # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
  # "mlp_4h_to_h": 6  # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
  # "mlp_gate": 7     # for llama2 adapter for gated mlp layer after attention / RMSNorm: gate
  #
  # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
  {
    name: "lora_config"
    data_type: TYPE_INT32
    dims: [ -1, 3 ]
    optional: true
    allow_ragged_batch: true
  }
]
output [
  {
    name: "output_ids"
    data_type: TYPE_INT32
    dims: [ -1, -1 ]
  },
  {
    name: "sequence_length"
    data_type: TYPE_INT32
    dims: [ -1 ]
  },
  {
    name: "cum_log_probs"
    data_type: TYPE_FP32
    dims: [ -1 ]
  },
  {
    name: "output_log_probs"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  },
  {
    name: "context_logits"
    data_type: TYPE_FP32
    dims: [ -1, -1 ]
  },
  {
    name: "generation_logits"
    data_type: TYPE_FP32
    dims: [ -1, -1, -1 ]
  }
]
instance_group [
  {
    count: 1
    kind : KIND_CPU
  }
]
parameters: {
  key: "max_beam_width"
  value: {
    string_value: "1"
  }
}
parameters: {
  key: "FORCE_CPU_ONLY_INPUT_TENSORS"
  value: {
    string_value: "no"
  }
}
parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "inflight_fused_batching"
  }
}
parameters: {
  key: "gpt_model_path"
  value: {
    string_value: "all_models/inflight_batcher_llm/tensorrt_llm/1/1gpu"
  }
}
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "infinite"
  }
}
parameters: {
  key: "max_attention_window_size"
  value: {
    string_value: "64"
  }
}
parameters: {
  key: "sink_token_length"
  value: {
    string_value: "${sink_token_length}"
  }
}
parameters: {
  key: "batch_scheduler_policy"
  value: {
    string_value: "max_utilization"
  }
}
parameters: {
  key: "kv_cache_free_gpu_mem_fraction"
  value: {
    string_value: "0.9"
  }
}
parameters: {
  key: "kv_cache_host_memory_bytes"
  value: {
    string_value: "${kv_cache_host_memory_bytes}"
  }
}
parameters: {
  key: "kv_cache_onboard_blocks"
  value: {
    string_value: "${kv_cache_onboard_blocks}"
  }
}
# enable_trt_overlap is deprecated and doesn't have any effect on the runtime
# parameters: {
#   key: "enable_trt_overlap"
#   value: {
#     string_value: "${enable_trt_overlap}"
#   }
# }
parameters: {
  key: "exclude_input_in_output"
  value: {
    string_value: "false"
  }
}
parameters: {
  key: "cancellation_check_period_ms"
  value: {
    string_value: "${cancellation_check_period_ms}"
  }
}
parameters: {
  key: "stats_check_period_ms"
  value: {
    string_value: "100"
  }
}
parameters: {
  key: "iter_stats_max_iterations"
  value: {
    string_value: "executor::kDefaultIterStatsMaxIterations"
  }
}
parameters: {
  key: "request_stats_max_iterations"
  value: {
    string_value: "executor::kDefaultRequestStatsMaxIterations"
  }
}
parameters: {
  key: "enable_kv_cache_reuse"
  value: {
    string_value: "true"
  }
}
parameters: {
  key: "normalize_log_probs"
  value: {
    string_value: "true"
  }
}
parameters: {
  key: "enable_chunked_context"
  value: {
    string_value: "false"
  }
}
parameters: {
  key: "gpu_device_ids"
  value: {
    string_value: "1"
  }
}
parameters: {
  key: "lora_cache_optimal_adapter_size"
  value: {
    string_value: "8"
  }
}
parameters: {
  key: "lora_cache_max_adapter_size"
  value: {
    string_value: "64"
  }
}
parameters: {
  key: "lora_cache_gpu_memory_fraction"
  value: {
    string_value: "0.05"
  }
}
parameters: {
  key: "lora_cache_host_memory_bytes"
  value: {
    string_value: "1G"
  }
}
parameters: {
  key: "decoding_mode"
  value: {
    string_value: "top_k"
  }
}
parameters: {
  key: "executor_worker_path"
  value: {
    string_value: "/opt/tritonserver/backends/tensorrtllm/trtllmExecutorWorker"
  }
}
parameters: {
  key: "medusa_choices"
  value: {
    string_value: "${medusa_choices}"
  }
}
parameters: {
  key: "gpu_weights_percent"
  value: {
    string_value: "${gpu_weights_percent}"
  }
}
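
Several parameters in the config above still hold template placeholders such as ${sink_token_length} and ${kv_cache_host_memory_bytes}, and the log below warns about ${add_special_tokens} in another model's config. Whether these placeholders contribute to the load failure is not established in this report. The sketch below only shows one way to list them before starting the server. The function name find_unfilled_placeholders and the sample text are hypothetical, and the comment handling is deliberately naive (it cuts each line at the first '#', so a '#' inside a quoted string would be misread).

```python
import re

# Matches template placeholders such as ${sink_token_length}.
PLACEHOLDER = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def find_unfilled_placeholders(pbtxt_text):
    """Return (line_number, placeholder_name) pairs found on active lines."""
    hits = []
    for lineno, line in enumerate(pbtxt_text.splitlines(), start=1):
        # Naive comment stripping: ignore everything after the first '#',
        # so commented-out parameters are not reported.
        active = line.split("#", 1)[0]
        for match in PLACEHOLDER.finditer(active):
            hits.append((lineno, match.group(1)))
    return hits


# Hypothetical excerpt in the same shape as the config.pbtxt above.
sample = '''parameters: {
  key: "sink_token_length"
  value: {
    string_value: "${sink_token_length}"
  }
}
# parameters: {
#     string_value: "${enable_trt_overlap}"
# }
parameters: {
  key: "max_beam_width"
  value: {
    string_value: "1"
  }
}
'''

for lineno, name in find_unfilled_placeholders(sample):
    print(f"line {lineno}: ${{{name}}} is still unfilled")
```

To check a real file, read it with pathlib.Path(...).read_text() and pass the text to the same function.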

Expected behavior

The LoRA model deploys successfully on the Triton server.

Actual behavior

I0617 07:26:19.481372 2404 pinned_memory_manager.cc:241] Pinned memory pool is created at '0x7fcc56000000' with size 268435456                                                                                                             
I0617 07:26:19.481848 2404 cuda_memory_manager.cc:107] CUDA memory pool is created on device 0 with size 67108864                                                                                                                                                               
I0617 07:26:19.481854 2404 cuda_memory_manager.cc:107] CUDA memory pool is created on device 1 with size 67108864                                                                                                                                                               
W0617 07:26:19.546485 2404 server.cc:238] failed to enable peer access for some device pairs                                                                                                                                                                                    
I0617 07:26:19.549115 2404 model_lifecycle.cc:461] loading: postprocessing:1                                                                                                                                                                                                    
I0617 07:26:19.549150 2404 model_lifecycle.cc:461] loading: preprocessing:1                                                                                                                                                                                                     
I0617 07:26:19.549206 2404 model_lifecycle.cc:461] loading: tensorrt_llm:1                                                                                                                                                                                                      
I0617 07:26:19.549257 2404 model_lifecycle.cc:461] loading: tensorrt_llm_bls:1                                                                                                                                                                                                  
I0617 07:26:19.614582 2404 python_be.cc:2199] TRITONBACKEND_ModelInstanceInitialize: preprocessing_0_0 (CPU device 0)                                                                                                                                                           
I0617 07:26:19.614591 2404 python_be.cc:2199] TRITONBACKEND_ModelInstanceInitialize: postprocessing_0_0 (CPU device 0)                                                                                                                                                          
I0617 07:26:19.614646 2404 python_be.cc:2199] TRITONBACKEND_ModelInstanceInitialize: tensorrt_llm_bls_0_0 (CPU device 0)          
E0617 07:26:19.779599 2404 backend_model.cc:634] ERROR: Failed to create instance: unexpected error when creating modelInstanceState: [json.exception.parse_error.101] parse error at line 1, column 1: syntax error while parsing value - unexpected end of input; expected '[', '{', or a literal
E0617 07:26:19.779664 2404 model_lifecycle.cc:621] failed to load 'tensorrt_llm' version 1: Internal: unexpected error when creating modelInstanceState: [json.exception.parse_error.101] parse error at line 1, column 1: syntax error while parsing value - unexpected end of input; expected '[', '{', or a literal
I0617 07:26:19.779672 2404 model_lifecycle.cc:756] failed to load 'tensorrt_llm'                                                                                                                                                                                                
I0617 07:26:19.817906 2404 model_lifecycle.cc:818] successfully loaded 'tensorrt_llm_bls'                                                                                  
[TensorRT-LLM][WARNING] Don't setup 'add_special_tokens' correctly (set value is ${add_special_tokens}). Set it as True by default.                                              
[TensorRT-LLM][WARNING] Don't setup 'skip_special_tokens' correctly (set value is ${skip_special_tokens}). Set it as True by default.                                            
I0617 07:26:21.746793 2404 model_lifecycle.cc:818] successfully loaded 'postprocessing'                                                                                    
I0617 07:26:21.752482 2404 model_lifecycle.cc:818] successfully loaded 'preprocessing'                                                                                     
E0617 07:26:21.752547 2404 model_repository_manager.cc:563] Invalid argument: ensemble 'ensemble' depends on 'tensorrt_llm' which has no loaded version. Model 'tensorrt_llm' loading failed with error: version 1 is at UNAVAILABLE state: Internal: unexpected error when creating modelInstanceState: [json.exception.parse_error.101] parse error at line 1, column 1: syntax error while parsing value - unexpected end of input; expected '[', '{', or a literal;
I0617 07:26:21.752598 2404 server.cc:592]                                                                                                                                                                                                                                       
+------------------+------+                                                                                                                                                                                                                                                     
| Repository Agent | Path |                                                                                                                                                
+------------------+------+                                                                                                       
+------------------+------+                                                                                                       
                                                                                                                                  
I0617 07:26:21.752628 2404 server.cc:619]                                                                                         
+-------------+--------------------------------------------------------+--------------------------------------------------------+ 
| Backend     | Path                                                   | Config                                                 | 
+-------------+--------------------------------------------------------+--------------------------------------------------------+ 
| python      | /opt/tritonserver/backends/python/libtriton_python.so  | {"cmdline":{"auto-complete-config":"false","backend-di | 
|             |                                                        | rectory":"/opt/tritonserver/backends","min-compute-cap | 
|             |                                                        | ability":"6.000000","shm-region-prefix-name":"prefix0_ | 
|             |                                                        | ","default-max-batch-size":"4"}}                       | 
| tensorrtllm | /opt/tritonserver/backends/tensorrtllm/libtriton_tenso | {"cmdline":{"auto-complete-config":"false","backend-di | 
|             | rrtllm.so                                              | rectory":"/opt/tritonserver/backends","min-compute-cap | 
|             |                                                        | ability":"6.000000","default-max-batch-size":"4"}}     | 
|             |                                                        |                                                        | 
+-------------+--------------------------------------------------------+--------------------------------------------------------+ 
                                                                                                                                  
I0617 07:26:21.752674 2404 server.cc:662]                                                                                         
+------------------+---------+---------------------------------------------------------------------------------------------------+
| Model            | Version | Status                                                                                            |
+------------------+---------+---------------------------------------------------------------------------------------------------+
| postprocessing   | 1       | READY                                                                                             |
| preprocessing    | 1       | READY                                                                                             |
| tensorrt_llm     | 1       | UNAVAILABLE: Internal: unexpected error when creating modelInstanceState: [json.exception.parse_e |
|                  |         | rror.101] parse error at line 1, column 1: syntax error while parsing value - unexpected end of i |
|                  |         | nput; expected '[', '{', or a literal                                                             |
| tensorrt_llm_bls | 1       | READY                                                                                             |
+------------------+---------+---------------------------------------------------------------------------------------------------+
                                                                                                                                  
I0617 07:26:21.780639 2404 metrics.cc:817] Collecting metrics for GPU 0: NVIDIA Graphics Device                                   
I0617 07:26:21.780674 2404 metrics.cc:817] Collecting metrics for GPU 1: NVIDIA GeForce RTX 3080 Ti                               
I0617 07:26:21.780984 2404 metrics.cc:710] Collecting CPU metrics                                                                 
I0617 07:26:21.781148 2404 tritonserver.cc:2458]                                                                                  
+----------------------------------+--------------------------------------------------------------------------------------------+ 
| Option                           | Value                                                                                      | 
+----------------------------------+--------------------------------------------------------------------------------------------+ 
| server_id                        | triton                                                                                     | 
| server_version                   | 2.39.0                                                                                     | 
| server_extensions                | classification sequence model_repository model_repository(unload_dependents) schedule_poli | 
|                                  | cy model_configuration system_shared_memory cuda_shared_memory binary_tensor_data paramete | 
|                                  | rs statistics trace logging                                                                | 
| model_repository_path[0]         | ../../all_models/inflight_batcher_llm                                                      |                                          
| model_control_mode               | MODE_NONE                                                                                  |                                          
| strict_model_config              | 1                                                                                          |                                          
| rate_limit                       | OFF                                                                                        |                                          
| pinned_memory_pool_byte_size     | 268435456                                                                                  |                                          
| cuda_memory_pool_byte_size{0}    | 67108864                                                                                   |                                          
| cuda_memory_pool_byte_size{1}    | 67108864                                                                                   |                                          
| min_supported_compute_capability | 6.0                                                                                        |                                          
| strict_readiness                 | 1                                                                                          |                                          
| exit_timeout                     | 30                                                                                         |                                          
| cache_enabled                    | 0                                                                                          |                                                                                     
+----------------------------------+--------------------------------------------------------------------------------------------+                                                                                     
                                                                                                                                                                                                                      
I0617 07:26:21.781159 2404 server.cc:293] Waiting for in-flight requests to complete.                                                                                                                                 
I0617 07:26:21.781168 2404 server.cc:309] Timeout 30: Found 0 model versions that have in-flight inferences                                                                                                           
I0617 07:26:21.781489 2404 server.cc:324] All models are stopped, unloading models                                                      
I0617 07:26:21.781497 2404 server.cc:331] Timeout 30: Found 3 live models and 0 in-flight non-inference requests                                                                                                      
I0617 07:26:22.781601 2404 server.cc:331] Timeout 29: Found 3 live models and 0 in-flight non-inference requests                                                                                                      
Cleaning up...                                                                                                                                                                                                        
Cleaning up...                                                                                                                                                                                                        
Cleaning up...                                                                                                                                                                                                        
I0617 07:26:22.987543 2404 model_lifecycle.cc:603] successfully unloaded 'tensorrt_llm_bls' version 1                                                                            
I0617 07:26:23.018701 2404 model_lifecycle.cc:603] successfully unloaded 'preprocessing' version 1                                                                               
I0617 07:26:23.020596 2404 model_lifecycle.cc:603] successfully unloaded 'postprocessing' version 1                                     
I0617 07:26:23.783391 2404 server.cc:331] Timeout 28: Found 0 live models and 0 in-flight non-inference requests                                                                                                      
error: creating server: Internal - failed to load all models                                                                            
--------------------------------------------------------------------------                                                              
Primary job  terminated normally, but 1 process returned            
a non-zero exit code. Per user-direction, the job has been aborted.                                                                     
--------------------------------------------------------------------------                                                                                                       
--------------------------------------------------------------------------                                                                                                       
mpirun detected that one or more processes exited with non-zero status, thus causing                                                                                             
the job to be terminated. The first process to do so was:                               

  Process name: [[29838,1],0]                                                           
  Exit code:    1                                                                       
-------------------------------------------------------------------------- 
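
The parse error in the log ("unexpected end of input" at line 1, column 1) is what a JSON parser reports when it is handed empty input. One possible explanation, which is an assumption and not confirmed anywhere in this report, is that the backend could not read a JSON file from the directory named by gpt_model_path. That value is the relative path all_models/inflight_batcher_llm/tensorrt_llm/1/1gpu, while the volume is mounted at /all_models. The sketch below checks whether a directory contains a non-empty, parseable config.json. The file name config.json and the helper check_engine_config are assumptions for illustration, not something the log states.

```python
import json
import tempfile
from pathlib import Path


def check_engine_config(engine_dir):
    """Return a short diagnosis string for <engine_dir>/config.json."""
    # Assumption: the engine directory is expected to hold a config.json.
    config_path = Path(engine_dir) / "config.json"
    if not config_path.is_file():
        return f"missing: {config_path}"
    text = config_path.read_text(encoding="utf-8")
    if not text.strip():
        # Empty input is the case that yields "unexpected end of input".
        return f"empty: {config_path}"
    try:
        json.loads(text)
    except json.JSONDecodeError as exc:
        return f"invalid JSON: {config_path} ({exc})"
    return f"ok: {config_path}"


# Demonstration on a throwaway directory, not on a real engine.
with tempfile.TemporaryDirectory() as tmp:
    print(check_engine_config(tmp))  # no config.json yet
    (Path(tmp) / "config.json").write_text("", encoding="utf-8")
    print(check_engine_config(tmp))  # file exists but is empty
    (Path(tmp) / "config.json").write_text('{"placeholder": true}', encoding="utf-8")
    print(check_engine_config(tmp))  # parseable JSON
```

Inside the container, the same function could be called with the absolute form of the gpt_model_path value to see which of the three cases applies.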

Additional notes

  • This error appears when gpt_model_type is configured correctly:
parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "inflight_batching" # v1 / inflight_batching / inflight_fused_batching
  }
}
  • SeaLLMs/SeaLLM-7B-v2 with LoRA
  • I can run the engine directly, but I encounter an error when deploying the engine on the Triton server.
python ../run.py --max_output_len=256 \
                 --tokenizer_dir ../../model_weights/models--SeaLLMs--SeaLLM-7B-v2/snapshots/f1bd48e0d75365c24a3c5ad006b2d0a0c9dca30f \
                 --engine_dir=./tmp_64/seallm/7B/engine/fp8/1-gpu/ \
                 --input_text "[INST]สวัสดีครับ ผมเครียดมากเรื่องเรียน ผมควรทำอย่างไรดีครับ [/INST]" \
                 --max_attention_window_size=128 \
                 --lora_task_uids 0 \
                 --use_py_session \
                 --temperature 0.2 \
                 --top_k 10 \
                 --top_p 0.8
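
For reference, the comments in the config.pbtxt above describe each lora_weights row as the flattened in and out adapter tensors concatenated together (length D x Hi + Ho x D), and each lora_config row as [ module_id, layer_idx, adapter_size ]. The sketch below builds one such pair of rows in plain Python. It assumes the in tensor has shape [D, Hi], the out tensor has shape [Ho, D], and flattening is row-major. Those three points are my reading of the comments, and build_lora_row is a hypothetical helper, not part of TensorRT-LLM.

```python
def flatten(matrix):
    """Flatten a list of rows into a single list (row-major order)."""
    return [value for row in matrix for value in row]


def build_lora_row(module_id, layer_idx, in_weights, out_weights):
    """Build one lora_weights row and the matching lora_config row.

    in_weights:  D rows of Hi values (assumed shape [D, Hi])
    out_weights: Ho rows of D values (assumed shape [Ho, D])
    """
    adapter_size = len(in_weights)  # D, also called the R value
    if any(len(row) != adapter_size for row in out_weights):
        raise ValueError("each out_weights row must have adapter_size values")
    weights_row = flatten(in_weights) + flatten(out_weights)
    config_row = [module_id, layer_idx, adapter_size]
    return weights_row, config_row


# Toy sizes for illustration only: D=2, Hi=3, Ho=4.
in_w = [[0.0] * 3 for _ in range(2)]
out_w = [[0.0] * 2 for _ in range(4)]

# module_id 0 corresponds to "attn_qkv" in the mapping listed in the config.
weights_row, config_row = build_lora_row(0, 5, in_w, out_w)
print(len(weights_row))  # D*Hi + Ho*D = 2*3 + 4*2 = 14
print(config_row)        # [0, 5, 2]
```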

Metadata

Labels: bug (Something isn't working)