
[Bugfix][Model] Refactor OLMo model to support new HF format in transformers 4.40.0 #4324

Merged: 6 commits into vllm-project:main on Apr 25, 2024

Conversation

@Isotr0py (Contributor) commented Apr 24, 2024


FIX #4310

This PR adds support for OLMo models in the HF format (OLMo-1B-hf, OLMo-1.7-7B-hf, etc.).

Note that support for OLMo models in the original format (OLMo-1B, OLMo-7B, etc.) will be deprecated, since hf_olmo is removed from the dependencies.
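
As a quick illustration of the new path, here is a minimal usage sketch with vLLM's offline LLM API and one of the HF-format checkpoints mentioned above; the sampling settings are just examples:

```python
from vllm import LLM, SamplingParams

# Load an OLMo checkpoint in the new HF format; hf_olmo is no longer required.
llm = LLM(model="allenai/OLMo-1.7-7B-hf")

# Illustrative sampling parameters.
params = SamplingParams(temperature=0.8, max_tokens=32)

outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```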



PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following prefixes:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] for changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to the Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented to ensure future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behavior of vLLM. It helps vLLM users understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and might not review it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient, and to make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, it will be assigned to a reviewer. Reviewers pick up PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@WoosukKwon (Collaborator) left a comment

LGTM! Thanks for the PR!

@natolambert commented

Thanks @Isotr0py -- I was getting blocked by this on something fun!

@natolambert commented Apr 24, 2024

FYI @WoosukKwon and @Isotr0py -- I think there's a bug in this? I just tried to run the code and I get this error about how the classes are registered.

This was with this model https://huggingface.co/allenai/OLMo-1.7-7B-hf, and another one "being trained".

ValueError: Model architectures ['OlmoForCausalLM'] are not supported for now. Supported architectures: ['AquilaModel', 'AquilaForCausalLM', 'BaiChuanForCausalLM', 'BaichuanForCausalLM', 'BloomForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'CohereForCausalLM', 'DbrxForCausalLM', 'DeciLMForCausalLM', 'DeepseekForCausalLM', 'FalconForCausalLM', 'GemmaForCausalLM', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTJForCausalLM', 'GPTNeoXForCausalLM', 'InternLMForCausalLM', 'InternLM2ForCausalLM', 'JAISLMHeadModel', 'LlamaForCausalLM', 'LlavaForConditionalGeneration', 'LLaMAForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'QuantMixtralForCausalLM', 'MptForCausalLM', 'MPTForCausalLM', 'MiniCPMForCausalLM', 'OLMoForCausalLM', 'OPTForCausalLM', 'OrionForCausalLM', 'PhiForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RWForCausalLM', 'StableLMEpochForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'XverseForCausalLM']

UPDATE: Resolved

@WoosukKwon (Collaborator) commented Apr 24, 2024

@natolambert Could you share your model name (or the HF repo name)? The architecture should be OLMoForCausalLM instead of OlmoForCausalLM.

@natolambert commented Apr 24, 2024

EDIT: re-running it, maybe I had cached weights. Yup it works!

@WoosukKwon was just adding it via edit: https://huggingface.co/allenai/OLMo-1.7-7B-hf

python -m vllm.entrypoints.openai.api_server --model allenai/OLMo-1.7-7B-hf --tensor-parallel-size 2
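
(Once the server is up, it can be queried through the OpenAI-compatible completions endpoint. A minimal sketch, assuming the default port 8000 and an illustrative prompt:)

```python
import requests

# Query the OpenAI-compatible server started by the command above.
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "allenai/OLMo-1.7-7B-hf",
        "prompt": "The capital of France is",
        "max_tokens": 16,
    },
)
print(resp.json()["choices"][0]["text"])
```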

@natolambert commented

There's a separate issue, but I think this is specific to our unreleased model (it at least starts the loading process). I'll investigate when it's a higher priority.

(RayWorkerWrapper pid=37267) INFO 04-24 12:38:36 utils.py:129] reading GPU P2P access cache from /home/nathanl/.config/vllm/gpu_p2p_access_cache_for_0,1.json
ERROR 04-24 12:38:49 worker_base.py:157] Error executing method load_model. This might cause deadlock in distributed execution.
ERROR 04-24 12:38:49 worker_base.py:157] Traceback (most recent call last):
ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 149, in execute_method
ERROR 04-24 12:38:49 worker_base.py:157]     return executor(*args, **kwargs)
ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/worker.py", line 117, in load_model
ERROR 04-24 12:38:49 worker_base.py:157]     self.model_runner.load_model()
ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 162, in load_model
ERROR 04-24 12:38:49 worker_base.py:157]     self.model = get_model(
ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
ERROR 04-24 12:38:49 worker_base.py:157]     return loader.load_model(model_config=model_config,
ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 224, in load_model
ERROR 04-24 12:38:49 worker_base.py:157]     model.load_weights(
ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/models/olmo.py", line 348, in load_weights
ERROR 04-24 12:38:49 worker_base.py:157]     weight_loader(param, loaded_weight, shard_id)
ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 501, in weight_loader
ERROR 04-24 12:38:49 worker_base.py:157]     loaded_weight = loaded_weight.narrow(output_dim, start_idx,
ERROR 04-24 12:38:49 worker_base.py:157] RuntimeError: start (0) + length (4096) exceeds dimension size (1024).
Traceback (most recent call last):
  File "/opt/miniconda3/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/miniconda3/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/entrypoints/openai/api_server.py", line 159, in <module>
    engine = AsyncLLMEngine.from_engine_args(
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 361, in from_engine_args
    engine = cls(
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 319, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 437, in _init_engine
    return engine_class(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 148, in __init__
    self.model_executor = executor_class(
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/executor/ray_gpu_executor.py", line 382, in __init__
    super().__init__(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 41, in __init__
    self._init_executor()
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/executor/ray_gpu_executor.py", line 45, in _init_executor
    self._init_workers_ray(placement_group)
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/executor/ray_gpu_executor.py", line 182, in _init_workers_ray
    self._run_workers(
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/executor/ray_gpu_executor.py", line 318, in _run_workers
    driver_worker_output = self.driver_worker.execute_method(
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 158, in execute_method
    raise e
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 149, in execute_method
    return executor(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/worker.py", line 117, in load_model
    self.model_runner.load_model()
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 162, in load_model
    self.model = get_model(
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
    return loader.load_model(model_config=model_config,
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 224, in load_model
    model.load_weights(
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/models/olmo.py", line 348, in load_weights
    weight_loader(param, loaded_weight, shard_id)
  File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 501, in weight_loader
    loaded_weight = loaded_weight.narrow(output_dim, start_idx,
RuntimeError: start (0) + length (4096) exceeds dimension size (1024).
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157] Error executing method load_model. This might cause deadlock in distributed execution.
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157] Traceback (most recent call last):
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 149, in execute_method
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]     return executor(*args, **kwargs)
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/worker.py", line 117, in load_model
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]     self.model_runner.load_model()
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 162, in load_model
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]     self.model = get_model(
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]     return loader.load_model(model_config=model_config,
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 224, in load_model
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]     model.load_weights(
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/models/olmo.py", line 348, in load_weights
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]     weight_loader(param, loaded_weight, shard_id)
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 501, in weight_loader
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157]     loaded_weight = loaded_weight.narrow(output_dim, start_idx,
(RayWorkerWrapper pid=37267) ERROR 04-24 12:38:49 worker_base.py:157] IndexError: start out of range (expected to be in range of [-1024, 1024], but got 4096)
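
(To illustrate what the error above means: the tensor-parallel weight loader slices the loaded checkpoint tensor with narrow(), and the call fails when the shard the model expects is larger than the tensor actually stored in the checkpoint. A toy reproduction with hypothetical shapes, not vLLM code:)

```python
import torch

# Hypothetical shapes: the checkpoint tensor has 1024 rows along the output
# dimension, but this rank expects a 4096-row shard starting at index 0.
loaded_weight = torch.empty(1024, 4096)
output_dim, start_idx, shard_size = 0, 0, 4096

try:
    loaded_weight.narrow(output_dim, start_idx, shard_size)
except RuntimeError as err:
    print(err)  # start (0) + length (4096) exceeds dimension size (1024).
```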

@2015aroras commented

We had typos in the architecture field of our HF config.json. In all the -hf repos (e.g. OLMo-1B-hf), the architecture was supposed to be OlmoForCausalLM, not OLMoForCausalLM. I have fixed it on the HF Hub now, sorry for the inconvenience.
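
(A quick way to check which class name a checkpoint advertises, using the standard transformers API; the repo name below is just the one mentioned above:)

```python
from transformers import AutoConfig

# Inspect the architectures field of the checkpoint's config.json.
config = AutoConfig.from_pretrained("allenai/OLMo-1B-hf")
print(config.architectures)  # after the fix this should print ['OlmoForCausalLM']
```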

@WoosukKwon (Collaborator) commented

@Isotr0py The model outputs gibberish:

$ python examples/llm_engine_example.py --model allenai/OLMo-1.7-7B-hf
...
RequestOutput(request_id=3, prompt='It is only with the heart that one can see rightly', prompt_token_ids=[1147, 310, 760, 342, 253, 2798, 326, 581, 476, 923, 35155], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='', token_ids=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], cumulative_logprob=-173.21343994140625, logprobs=None, finish_reason=length, stop_reason=None), CompletionOutput(index=1, text='|||IP_ADDRESS|||', token_ids=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0], cumulative_logprob=-173.21343994140625, logprobs=None, finish_reason=length, stop_reason=None), CompletionOutput(index=2, text='#', token_ids=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4], cumulative_logprob=-173.21343994140625, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1714011286.3906262, last_token_time=1714011286.8017027, first_scheduled_time=1714011286.3912067, first_token_time=1714011286.44032, time_in_queue=0.0005805492401123047, finished_time=1714011286.8016992), lora_request=None)

Could you take a look?

@Isotr0py (Contributor, Author) commented Apr 25, 2024

@WoosukKwon Oops, I forgot to remove the original lm_head code, which broke the model's lm_head when tie_word_embeddings is not set.
The 7B model should work now.
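
(For context, a minimal sketch of the weight-tying pattern involved, written in plain PyTorch rather than the actual vLLM module, so the names are illustrative: when tie_word_embeddings is set, the output projection reuses the embedding matrix; otherwise the model needs its own, separately loaded lm_head weight, which is the case that was broken for the 7B model.)

```python
import torch.nn as nn

class ToyOlmoHead(nn.Module):
    """Illustrative only; not the actual vLLM OlmoForCausalLM implementation."""

    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if tie_word_embeddings:
            # Tied case: logits are computed with the embedding matrix itself.
            self.lm_head_weight = self.embed_tokens.weight
        else:
            # Untied case: a separate lm_head weight is loaded from the checkpoint.
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            self.lm_head_weight = self.lm_head.weight

    def logits(self, hidden_states):
        return hidden_states @ self.lm_head_weight.t()
```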

examples/offline_inference.py:

$ python examples/offline_inference.py
INFO 04-25 13:55:05 llm_engine.py:98] Initializing an LLM engine (v0.4.1) with config: model='/data/LLM-model/OLMo-1.7-7B-hf', speculative_config=None, tokenizer='/data/LLM-model/OLMo-1.7-7B-hf', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0)
WARNING 04-25 13:55:05 cpu_executor.py:128] float16 is not supported on CPU, casting to bfloat16.
WARNING 04-25 13:55:05 cpu_executor.py:131] CUDA graph is not supported on CPU, fallback to the eager mode.
WARNING 04-25 13:55:05 cpu_executor.py:159] Environment variable VLLM_CPU_KVCACHE_SPACE (GB) for CPU backend is not set, using 4 by default.
INFO 04-25 13:55:07 selector.py:43] Using Torch SDPA backend.
INFO 04-25 13:56:57 cpu_executor.py:72] # CPU blocks: 512
Processed prompts:   0%|                                                                                                                                                                   | 0/4 [00:00<?, ?it/s]INFO 04-25 13:56:58 pynccl_utils.py:17] Failed to import NCCL library: NCCL only supports CUDA and ROCm backends.
INFO 04-25 13:56:58 pynccl_utils.py:18] It is expected if you are not running on NVIDIA GPUs.
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:27<00:00,  6.82s/it]
Prompt: 'Hello, my name is', Generated text: ' Renee Dokot and I am a native of the Bahamas.'
Prompt: 'The president of the United States is', Generated text: ' not just the leader of the USA but also of the free world." to Russian'
Prompt: 'The capital of France is', Generated text: ' Paris, which is the capital of the European Union. The Eiffel Tower'
Prompt: 'The future of AI is', Generated text: ' open source. If you want to build a new AI-powered application, you'

examples/llm_engine_example.py:

$ python examples/llm_engine_example.py --model /data/LLM-model/OLMo-1.7-7B-hf/
...
RequestOutput(request_id=3, prompt='It is only with the heart that one can see rightly', prompt_token_ids=[1147, 310, 760, 342, 253, 2798, 326, 581, 476, 923, 35155], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='; what is essential is invisible to the eye."\n\n- Antoine de', token_ids=[28, 752, 310, 5667, 310, 20624, 281, 253, 5130, 449, 187, 187, 14, 47255, 460, 372], cumulative_logprob=-8.499438464641571, logprobs=None, finish_reason=length, stop_reason=None), CompletionOutput(index=1, text='; what is essential is invisible to the eye."\n\nAntoine de', token_ids=[28, 752, 310, 5667, 310, 20624, 281, 253, 5130, 449, 187, 187, 1145, 936, 460, 372], cumulative_logprob=-8.654532624408603, logprobs=None, finish_reason=length, stop_reason=None), CompletionOutput(index=2, text='; what is essential is invisible to the eye."\n\nAns:\n', token_ids=[28, 752, 310, 5667, 310, 20624, 281, 253, 5130, 449, 187, 187, 1145, 84, 27, 187], cumulative_logprob=-9.909873271360993, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1714023940.599021, last_token_time=1714024001.6509027, first_scheduled_time=1714023940.6001248, first_token_time=1714023945.974483, time_in_queue=0.0011038780212402344, finished_time=1714024001.6508915), lora_request=None)

@WoosukKwon (Collaborator) commented

@Isotr0py Thanks for the fix! It works now!

@WoosukKwon merged commit fbf152d into vllm-project:main on Apr 25, 2024
48 checks passed
@WoosukKwon (Collaborator) commented

@natolambert Does your internal model have the same architecture as OLMo? If so, could you try again on the updated main branch? We have fixed the code a bit since your comment.

@Isotr0py deleted the olmo branch on April 26, 2024, 05:00
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 26, 2024
…formers 4.40.0 (vllm-project#4324)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@natolambert commented

@WoosukKwon nope, but I probably won't look into why until closer to release, unless you've seen stuff like this before

INFO 04-28 11:12:27 utils.py:129] reading GPU P2P access cache from /home/nathanl/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(RayWorkerWrapper pid=7436) INFO 04-28 11:12:27 utils.py:129] reading GPU P2P access cache from /home/nathanl/.config/vllm/gpu_p2p_access_cache_for_0,1.json
ERROR 04-28 11:12:40 worker_base.py:147] Error executing method load_model. This might cause deadlock in distributed execution.
ERROR 04-28 11:12:40 worker_base.py:147] Traceback (most recent call last):
ERROR 04-28 11:12:40 worker_base.py:147]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 139, in execute_method
ERROR 04-28 11:12:40 worker_base.py:147]     return executor(*args, **kwargs)
ERROR 04-28 11:12:40 worker_base.py:147]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/worker.py", line 117, in load_model
ERROR 04-28 11:12:40 worker_base.py:147]     self.model_runner.load_model()
ERROR 04-28 11:12:40 worker_base.py:147]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 161, in load_model
ERROR 04-28 11:12:40 worker_base.py:147]     self.model = get_model(
ERROR 04-28 11:12:40 worker_base.py:147]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
ERROR 04-28 11:12:40 worker_base.py:147]     return loader.load_model(model_config=model_config,
ERROR 04-28 11:12:40 worker_base.py:147]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/model_loader/loader.py", line 223, in load_model
ERROR 04-28 11:12:40 worker_base.py:147]     model.load_weights(
ERROR 04-28 11:12:40 worker_base.py:147]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/models/olmo.py", line 347, in load_weights
ERROR 04-28 11:12:40 worker_base.py:147]     weight_loader(param, loaded_weight, shard_id)
ERROR 04-28 11:12:40 worker_base.py:147]   File "/opt/miniconda3/lib/python3.10/site-packages/vllm/model_executor/layers/linear.py", line 517, in weight_loader
ERROR 04-28 11:12:40 worker_base.py:147]     loaded_weight = loaded_weight.narrow(output_dim, start_idx,
ERROR 04-28 11:12:40 worker_base.py:147] RuntimeError: start (0) + length (4096) exceeds dimension size (1024).

z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request May 7, 2024
…formers 4.40.0 (vllm-project#4324)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
mawong-amd pushed a commit to ROCm/vllm that referenced this pull request Jun 3, 2024
…formers 4.40.0 (vllm-project#4324)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Development

Successfully merging this pull request may close these issues.

[Bug]: Olmo model deployment exception due to key conflicts between hf_olmo and transformers 4.40.0