Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix GC bug for LLM class #2882

Merged
merged 3 commits into from
Feb 15, 2024
Merged

[BugFix] Fix GC bug for LLM class #2882

merged 3 commits into from
Feb 15, 2024

Conversation

WoosukKwon
Copy link
Collaborator

@WoosukKwon WoosukKwon commented Feb 15, 2024

This PR fixes a bug that LLM class is not GC-ed even after gc.collect(), which was introduced by #1804 . It turns out that the bug happens because importing punica kernels raises an exception which is not released until the user uses LoRA.

Reproducible script:

import gc
import torch
from vllm import LLM

llm = LLM("facebook/opt-125m", enforce_eager=True)
del llm

gc.collect()
torch.cuda.empty_cache()
print(f"GPU memory usage: {torch.cuda.memory_allocated() / (1024 * 1024 * 1024):.2f} GB")

Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice sleuthing! Can you add a test for this?

@WoosukKwon WoosukKwon merged commit d7afab6 into main Feb 15, 2024
19 checks passed
@WoosukKwon WoosukKwon deleted the fix-punica-import branch February 15, 2024 06:17
@WoosukKwon WoosukKwon mentioned this pull request Feb 15, 2024
5 tasks
xjpang pushed a commit to xjpang/vllm that referenced this pull request Feb 20, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Feb 22, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants