Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hardware][Neuron] Refactor neuron support #3471

Merged
merged 23 commits into from
Mar 22, 2024
Merged

Conversation

zhuohan123
Copy link
Collaborator

@zhuohan123 zhuohan123 commented Mar 18, 2024

Refactor the neuron support. After this refactoring, the neuron support is completely isolated from the GPU support. The development on the GPU side no longer needs to consider its effect on the neuron pass. Also, all the is_neuron flags in the code have been removed, except the one that tells us that we are using neuron device.

Highlights of this change:

  • Isolated NeuronExecutor, NeuronWorker, NeuronModelRunner, and get_neuron_model function.
  • Remove all the unnecessary neuron model definitions. Have a general abstraction of neuron models.

Remaining issues for future PRs:

  • The NeuronModelRunner still has a complicated prepare_inputs function. This should be refactored together with the refactor of prepare_inputs in the future.
  • We probably need more tests for the neuron pass. I'll leave this task to the AWS team (cc @liangfu).

PR Checklist (Click to expand. Please read before submitting.)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@zhuohan123 zhuohan123 changed the title [WIP][Hardware][Neuron] Refactor neuron support [Hardware][Neuron] Refactor neuron support Mar 19, 2024
@zhuohan123
Copy link
Collaborator Author

@WoosukKwon This PR is ready for review!

cc @liangfu please take a look as well!

Copy link
Contributor

@liangfu liangfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @zhuohan123 for the proposed change. I like the idea of

  • abstracting execution with executor_class
  • deduplicating llama and mixtral model support

My concern comes from missing input_metadata in the forward function call, where we might need an extra refactoring to support paged attention on neuron backend.

examples/offline_inference_neuron.py Show resolved Hide resolved
vllm/config.py Show resolved Hide resolved
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
device_config = engine_configs[4]

if device_config.device_type == "neuron":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of string comparison, set a property in device_config?

FYI, we might consider some device-specific config with device="neuron config=(a=1, b=2, c=3)"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

device_type is different from device. I think we can safely assume device_type can only be cuda and neuron for now? We can put the config string in other fields of device_config.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, i think we are fine to move on with this.

vllm/model_executor/neuron_model_loader.py Show resolved Hide resolved
@liangfu
Copy link
Contributor

liangfu commented Mar 19, 2024

Tested with tensor_parallel_size=2 for the branch:

INFO 03-19 22:26:56 config.py:446] Custom all-reduce kernels are temporarily disabled due to stability issues. We will re-enable them once the issues are resolved.
INFO 03-19 22:26:56 llm_engine.py:67] Initializing an LLM engine (v0.3.3) with config: model='TinyLlama/TinyLlama-1.1B-Chat-v1.0', tokenizer='TinyLlama/TinyLlama-1.1B-Chat-v1.0', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=128, download_di\
r=None, load_format=auto, tensor_parallel_size=2, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cpu, seed=0)
tokenizer_config.json: 100% 1.29k/1.29k [00:00<00:00, 324kB/s]
tokenizer.model: 100% 500k/500k [00:00<00:00, 12.1MB/s]
tokenizer.json: 100% 1.84M/1.84M [00:00<00:00, 26.8MB/s]
special_tokens_map.json: 100% 551/551 [00:00<00:00, 1.14MB/s]
WARNING 03-19 22:26:57 utils.py:324] Pin memory is not supported on Neuron.
model.safetensors: 100% 2.20G/2.20G [00:08<00:00, 260MB/s]
generation_config.json: 100% 124/124 [00:00<00:00, 33.9kB/s]
2024-03-19 22:27:23.000450:  1486537  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-03-19 22:27:23.000452:  1486537  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/facc01e3-5e6a-4cf8-9d45-a1bca66c03ec/model.MODULE_622c52b3b2f99196d008+2c2d707e.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir\
/facc01e3-5e6a-4cf8-9d45-a1bca66c03ec/model.MODULE_622c52b3b2f99196d008+2c2d707e.neff --model-type=transformer --auto-cast=none --verbose=35
2024-03-19 22:27:23.000485:  1486538  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-03-19 22:27:23.000487:  1486538  INFO ||NEURON_CC_WRAPPER||: Call compiler with cmd: neuronx-cc compile --target=trn1 --framework=XLA /tmp/ubuntu/neuroncc_compile_workdir/73dc4908-e815-4c42-8313-5bd0daa701a6/model.MODULE_067094cf9b96d1dd4d8c+2c2d707e.hlo_module.pb --output /tmp/ubuntu/neuroncc_compile_workdir\
/73dc4908-e815-4c42-8313-5bd0daa701a6/model.MODULE_067094cf9b96d1dd4d8c+2c2d707e.neff --model-type=transformer --auto-cast=none --verbose=35
........
Compiler status PASS
2024-03-19 22:28:33.000124:  1486538  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
.
Compiler status PASS
2024-03-19 22:28:45.000561:  1486537  INFO ||NEURON_CACHE||: Compile cache path: /var/tmp/neuron-compile-cache
2024-Mar-19 22:28:45.0977 1486374:1486511 [1] nccl_net_ofi_init:1415 CCOM WARN NET/OFI aws-ofi-nccl initialization failed
2024-Mar-19 22:28:45.0977 1486374:1486511 [1] init.cc:137 CCOM WARN OFI plugin initNet() failed is EFA enabled?
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 16.53it/s]
Prompt: 'Hello, my name is', Generated text: ' _____. I am a ________. I am passionate about _'
Prompt: 'The president of the United States is', Generated text: ' a man named Donald Trump. 4. Biden/Harris team'
Prompt: 'The capital of France is', Generated text: ' Paris. France has a total area of 552,655'
Prompt: 'The future of AI is', Generated text: ' in collaboration. Here are four examples:\n\n1. IBM Watson: This'

Copy link
Contributor

@liangfu liangfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@WoosukKwon WoosukKwon self-assigned this Mar 21, 2024
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the great work! This is great refactoring. The code looks much cleaner now.

vllm/worker/worker.py Show resolved Hide resolved
vllm/utils.py Outdated Show resolved Hide resolved
vllm/utils.py Outdated Show resolved Hide resolved
@@ -1,6 +1,6 @@
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed, get_model
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why remove this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are two get_models, one for neuron and one for GPUs. Having get_model here will be confusing.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also we removed the get_model function from the utils. Previously this get_model is a dispatcher for different devices. Now we no longer need the dispatcher.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then should we remove it from all as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it's already removed. See the latest commit

vllm/utils.py Outdated Show resolved Hide resolved
@zhuohan123 zhuohan123 enabled auto-merge (squash) March 21, 2024 22:48
logger.warning(msg)


def is_pin_memory_available() -> bool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Does it make sense to use cache for this function? Wondering if it helps performance.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes let me add that

@zhuohan123 zhuohan123 enabled auto-merge (squash) March 22, 2024 00:18
@zhuohan123 zhuohan123 merged commit e90fc21 into main Mar 22, 2024
32 checks passed
@zhuohan123 zhuohan123 deleted the refactor-neuron-support branch April 26, 2024 00:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants