
[Model] Add support for GPT-J #226

Merged: 15 commits into vllm-project:main on Jul 9, 2023
Conversation

AndreSlavescu
Contributor

Reference to issue #198.

@WoosukKwon WoosukKwon self-requested a review June 23, 2023 21:11
@WoosukKwon
Collaborator

@AndreSlavescu Awesome! Thanks for your contribution. Is this PR ready for review? If not, please ping me when it is ready. Thanks again!

@silvacarl2

Can you merge this change, so we can test it out with our fine-tuned GPT-J model?

8-)

@WoosukKwon
Collaborator

@AndreSlavescu What's going on with the PR? If you are not able to continue it, no worries, I can take it over. Please let us know if you have any questions.

@AndreSlavescu
Contributor Author

@WoosukKwon Hi, sorry for the delayed reply; I had a busy schedule this past week. I won't have much time this coming week either, so please continue with it if you'd like.
Thanks!

@ri938
Contributor

ri938 commented Jun 28, 2023

Is it just waiting for review, or does it require additional work? Is it expected to work (if so, I can use it now)?

@WoosukKwon
Collaborator

@ri938 This PR is not ready yet. I'll take this over and finish the PR soon.

@WoosukKwon
Collaborator

WoosukKwon commented Jul 7, 2023

The PR is currently blocked because GPT-J's rotary embedding requires a new kernel (IIUC, it's different from GPT-NeoX's rotary embedding). I will address it this weekend. Update: it turns out this is not a problem.
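For context on the difference mentioned above: GPT-J applies the rotary embedding to interleaved pairs of adjacent dimensions, while GPT-NeoX pairs each dimension with its counterpart half a head-dimension away. A minimal pure-Python sketch of the two pairing schemes (illustrative only, not vLLM's CUDA kernel; the function names are mine):

```python
import math

def rotate_interleaved(x, theta):
    # GPT-J style: rotate adjacent pairs (x0, x1), (x2, x3), ...
    c, s = math.cos(theta), math.sin(theta)
    out = list(x)
    for i in range(0, len(x), 2):
        out[i]     = x[i] * c - x[i + 1] * s
        out[i + 1] = x[i] * s + x[i + 1] * c
    return out

def rotate_half(x, theta):
    # GPT-NeoX style: pair x[i] with x[i + d/2].
    d = len(x) // 2
    c, s = math.cos(theta), math.sin(theta)
    out = list(x)
    for i in range(d):
        out[i]     = x[i] * c - x[i + d] * s
        out[i + d] = x[i] * s + x[i + d] * c
    return out
```

Both are plain 2-D rotations; only the choice of which dimensions form a pair differs, which is why the same kernel can serve both layouts once the indexing is parameterized.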

@WoosukKwon
Collaborator

@zhuohan123 This PR is ready for review. Please take a look at it.

@WoosukKwon WoosukKwon changed the title GPT-J model [Model] Add support for GPT-J Jul 8, 2023
Collaborator

@zhuohan123 zhuohan123 left a comment


LGTM! Left some minor comments.

vllm/model_executor/layers/sampler.py (resolved)
vllm/model_executor/models/gpt_j.py (outdated, resolved)
vllm/model_executor/models/gpt_j.py (outdated, resolved)
@WoosukKwon WoosukKwon merged commit c894836 into vllm-project:main Jul 9, 2023
2 checks passed
@WoosukKwon
Collaborator

@silvacarl2 @ri938 We've just merged this PR. Please install vLLM from source and try it out!

@silvacarl2

Cool, will do!

@silvacarl2

got this error:

python offline_inference.py
INFO 07-09 10:47:10 llm_engine.py:59] Initializing an LLM engine with config: model='EleutherAI/gpt-j-6b', dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)
Traceback (most recent call last):
  File "offline_inference.py", line 14, in <module>
    llm = LLM(model="EleutherAI/gpt-j-6b")
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/entrypoints/llm.py", line 55, in __init__
    self.llm_engine = LLMEngine.from_engine_args(engine_args)
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/engine/llm_engine.py", line 151, in from_engine_args
    engine = cls(*engine_configs, distributed_init_method, devices,
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/engine/llm_engine.py", line 93, in __init__
    worker = worker_cls(
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/worker/worker.py", line 45, in __init__
    self.model = get_model(model_config)
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/model_executor/model_loader.py", line 34, in get_model
    model_class = _get_model_architecture(model_config.hf_config)
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/model_executor/model_loader.py", line 27, in _get_model_architecture
    raise ValueError(
ValueError: Model architectures ['GPTJForCausalLM'] are not supported for now. Supported architectures: ['GPT2LMHeadModel', 'GPTNeoXForCausalLM', 'LlamaForCausalLM', 'OPTForCausalLM']

@silvacarl2

same with gpt-neo:

python offline_inference.py
Downloading (…)lve/main/config.json: 100% 1.46k/1.46k [00:00<00:00, 1.09MB/s]
INFO 07-09 10:48:12 llm_engine.py:59] Initializing an LLM engine with config: model='EleutherAI/gpt-neo-2.7B', dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)
Downloading (…)okenizer_config.json: 100% 200/200 [00:00<00:00, 173kB/s]
Downloading (…)olve/main/vocab.json: 100% 798k/798k [00:00<00:00, 5.96MB/s]
Downloading (…)olve/main/merges.txt: 100% 456k/456k [00:00<00:00, 38.1MB/s]
Downloading (…)cial_tokens_map.json: 100% 90.0/90.0 [00:00<00:00, 168kB/s]
Traceback (most recent call last):
  File "offline_inference.py", line 14, in <module>
    llm = LLM(model="EleutherAI/gpt-neo-2.7B")
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/entrypoints/llm.py", line 55, in __init__
    self.llm_engine = LLMEngine.from_engine_args(engine_args)
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/engine/llm_engine.py", line 151, in from_engine_args
    engine = cls(*engine_configs, distributed_init_method, devices,
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/engine/llm_engine.py", line 93, in __init__
    worker = worker_cls(
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/worker/worker.py", line 45, in __init__
    self.model = get_model(model_config)
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/model_executor/model_loader.py", line 34, in get_model
    model_class = _get_model_architecture(model_config.hf_config)
  File "/home/silvacarl/.local/lib/python3.8/site-packages/vllm/model_executor/model_loader.py", line 27, in _get_model_architecture
    raise ValueError(
ValueError: Model architectures ['GPTNeoForCausalLM'] are not supported for now. Supported architectures: ['GPT2LMHeadModel', 'GPTNeoXForCausalLM', 'LlamaForCausalLM', 'OPTForCausalLM']

@WoosukKwon
Collaborator

@silvacarl2 Could you check again if you installed the latest vLLM from source?

BTW, GPTNeo is not supported yet.
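The "not supported" error above comes from a straightforward name lookup: the loader matches the `architectures` field of the Hugging Face config against a registry of implemented model classes and raises if nothing matches. A hedged sketch of that dispatch logic (the registry name, its string values, and the function signature here are illustrative, not vLLM's exact code):

```python
# Illustrative registry mapping HF architecture names to model classes.
# In this sketch the values are just class-name strings.
_MODEL_REGISTRY = {
    "GPT2LMHeadModel": "GPT2Model",
    "GPTNeoXForCausalLM": "GPTNeoXModel",
    "GPTJForCausalLM": "GPTJModel",  # the entry this PR adds
    "LlamaForCausalLM": "LlamaModel",
    "OPTForCausalLM": "OPTModel",
}

def get_model_architecture(architectures):
    """Return the registered model class for the first supported name."""
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY)}")
```

So the same binary behaves differently before and after the merge only because `GPTJForCausalLM` gained a registry entry; `GPTNeoForCausalLM` (GPT-Neo, distinct from GPT-NeoX) still has none, which matches the second traceback above.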

@silvacarl2

NP, trying out others

@zhuohan123 zhuohan123 mentioned this pull request Jul 12, 2023
@leegohi04517
Contributor

I installed vLLM from source and encountered this problem:
(generator38) fsuser@recau5mvammeirzd3:~/chat_generator$ python -m vllm.entrypoints.openai.api_server \
    --model PygmalionAI/pygmalion-6b \
    --host 0.0.0.0
INFO 07-20 03:51:12 llm_engine.py:60] Initializing an LLM engine with config: model='PygmalionAI/pygmalion-6b', tokenizer='PygmalionAI/pygmalion-6b', tokenizer_mode=auto, trust_remote_code=False, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=1, seed=0)
Traceback (most recent call last):
  File "/home/fsuser/anaconda3/envs/generator38/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/fsuser/anaconda3/envs/generator38/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/fsuser/vllm/vllm/entrypoints/openai/api_server.py", line 583, in <module>
    engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/home/fsuser/vllm/vllm/engine/async_llm_engine.py", line 232, in from_engine_args
    engine = cls(engine_args.worker_use_ray,
  File "/home/fsuser/vllm/vllm/engine/async_llm_engine.py", line 55, in __init__
    self.engine = engine_class(*args, **kwargs)
  File "/home/fsuser/vllm/vllm/engine/llm_engine.py", line 99, in __init__
    worker = worker_cls(
  File "/home/fsuser/vllm/vllm/worker/worker.py", line 45, in __init__
    self.model = get_model(model_config)
  File "/home/fsuser/vllm/vllm/model_executor/model_loader.py", line 43, in get_model
    model = model_class(model_config.hf_config)
  File "/home/fsuser/vllm/vllm/model_executor/models/gpt_j.py", line 192, in __init__
    self.transformer = GPTJModel(config)
  File "/home/fsuser/vllm/vllm/model_executor/models/gpt_j.py", line 157, in __init__
    [GPTJBlock(config) for _ in range(config.n_layer)])
  File "/home/fsuser/vllm/vllm/model_executor/models/gpt_j.py", line 157, in <listcomp>
    [GPTJBlock(config) for _ in range(config.n_layer)])
  File "/home/fsuser/vllm/vllm/model_executor/models/gpt_j.py", line 122, in __init__
    self.attn = GPTJAttention(config)
  File "/home/fsuser/vllm/vllm/model_executor/models/gpt_j.py", line 68, in __init__
    assert config.rotary
  File "/home/fsuser/anaconda3/envs/generator38/lib/python3.8/site-packages/transformers/configuration_utils.py", line 260, in __getattribute__
    return super().__getattribute__(key)
AttributeError: 'GPTJConfig' object has no attribute 'rotary'
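The failure above happens at config-validation time, before any weights load: the `PygmalionAI/pygmalion-6b` checkpoint ships a `GPTJConfig` that simply omits the `rotary` field, so the bare `assert config.rotary` raises `AttributeError` rather than `AssertionError`. A small sketch of the failure mode and a tolerant alternative (the stand-in config and the choice of `True` as a default are illustrative assumptions, not vLLM's fix):

```python
from types import SimpleNamespace

# Stand-in for an HF config object that omits `rotary`,
# like the pygmalion-6b checkpoint's config in the traceback above.
config = SimpleNamespace(n_layer=28)

# The strict check dies with AttributeError before the assert can even fail:
try:
    assert config.rotary
    strict_ok = True
except AttributeError:
    strict_ok = False

# Reading the attribute with a default tolerates the missing field.
# Defaulting to True (rotary enabled) is an assumption for this sketch.
rotary = getattr(config, "rotary", True)
print(strict_ok, rotary)  # prints: False True
```

Whether defaulting is actually safe depends on the checkpoint; the point is only that a missing attribute and an attribute set to `False` are different conditions and need different handling.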

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request May 7, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
6 participants