
Conversation

@youkaichao (Member) commented Sep 12, 2025

Purpose

Solves #19855

Test Plan

import gc

import torch
from vllm import LLM
from vllm.utils import GiB_bytes


def print_current_mem_usage(tag):
    # Collect garbage and drop cached blocks so that mem_get_info
    # reflects the memory that is actually still allocated on the GPU.
    gc.collect()
    torch.cuda.empty_cache()
    free_bytes, total = torch.cuda.mem_get_info()
    print(f"[mem_usage] {tag} | current used: {(total - free_bytes) / GiB_bytes:.2f} GiB")


def test_fp8_sleep():
    model_path = "Qwen/Qwen2.5-7B-Instruct"

    model = LLM(
        model=model_path,
        dtype="bfloat16",
        gpu_memory_utilization=0.8,
        trust_remote_code=True,
        enable_sleep_mode=True,
        quantization="fp8",
    )

    print_current_mem_usage("before sleep")
    model.sleep()
    print_current_mem_usage("after sleep")
    model.wake_up(["weights"])
    print_current_mem_usage("after wakeup weights")


if __name__ == "__main__":
    test_fp8_sleep()

Run the code, and the memory taken with quantization is close to 9 GiB, showing the benefit of online quantization.

Test Result



Signed-off-by: youkaichao <youkaichao@gmail.com>
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request aims to save memory during on-the-fly quantization by manually releasing unused memory blocks from the custom CuMemAllocator pool. The changes introduce logic to iterate through memory pool allocations at the end of a use_memory_pool context and free any blocks that are no longer in use. This is a workaround for a PyTorch issue where torch.cuda.empty_cache() does not work with pluggable allocators. The changes also include additional logging to provide visibility into memory management. My review focuses on improving the robustness of the new memory release logic. I've identified a potential KeyError that could occur and suggested adding error handling to prevent crashes.

Comment on lines +296 to +300
allocations = data[0].snapshot()
for allocation in allocations:
    if allocation["allocated_size"] == 0:
        handle = self._python_free_callback(allocation["address"])
        unmap_and_release(handle)

Severity: high

The current implementation assumes that any memory block with allocated_size == 0 will have a corresponding entry in self.pointer_to_data. However, it's possible that the free callback was already triggered for an allocation, removing it from self.pointer_to_data, while the memory block is still tracked by the pool. This would lead to a KeyError when _python_free_callback is called, as it internally performs a pop, which could crash the application.

To make the code more robust, you should wrap the calls in a try...except KeyError block to gracefully handle cases where the allocation has already been freed and removed from pointer_to_data.

Suggested change:

allocations = data[0].snapshot()
for allocation in allocations:
    if allocation["allocated_size"] == 0:
        try:
            handle = self._python_free_callback(
                allocation["address"])
            unmap_and_release(handle)
        except KeyError:
            # This can happen if the allocation was already freed
            # through the normal path, but the memory pool has not
            # released the block.
            pass

Comment on lines +298 to +299
if allocation["allocated_size"] == 0:
    handle = self._python_free_callback(allocation["address"])
@youkaichao (Member, Author) commented:

cc @zou3519 @ngimel I don't really want to be so intrusive as to interpret the memory snapshot, but I have no other way to free the memory pool :(

I really hope we can expose an empty_cache method on the memory pool from the PyTorch side.

Reply:

The problem is on the NVIDIA side; they are not implementing what you want. And as I've said repeatedly, it's not a question of exposing empty_cache.

@youkaichao youkaichao enabled auto-merge (squash) September 12, 2025 09:56
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 12, 2025
@jiaau commented Sep 12, 2025

Question: I got a segmentation fault when trying to free memory inside the pool context to prevent OOM.

Hi @youkaichao,
Awesome work!
I have a question regarding the memory release mechanism. I noticed that the allocated memory seems to be cleaned up when the pool context is exited. However, in my use case, loading a very large model can lead to an out-of-memory (OOM) error before the loading process completes and the context is exited.
To address this, I tried to call unmap_and_release(handle) inside the pool context to free memory for tensors that are no longer needed. Unfortunately, this attempt resulted in a segmentation fault.
It seems that the memory can only be safely released upon exiting the context. Could you please explain why this is the case? Is there a specific reason related to the memory pool's design that prevents releasing memory mid-process?
My goal is to free up memory progressively during the weight loading process to avoid the OOM issue. Any insights or suggestions on how to achieve this would be greatly appreciated.
Thank you!

@youkaichao youkaichao merged commit fdb09c7 into vllm-project:main Sep 12, 2025
52 checks passed
@youkaichao (Member, Author) commented:

@jiaau your OOM might be caused by memory fragmentation. Right now the custom memory pool does not support empty_cache directly, so the hack in this PR can only be applied when the memory pool will never be used again (which is the case in sleep mode).

I think you will need the real empty_cache function; maybe write a new C++ extension to expose the PyTorch function to Python.

@jiaau commented Sep 12, 2025

@youkaichao Thanks a lot for your reply and the helpful explanation!
You were right, my OOM issue was indeed caused by memory fragmentation. I was able to solve it by using the method described in this PR: #23875, specifically by using empty_cache within the disable_memory_pool context.
I have a follow-up question to better understand the mechanism. You mentioned that "the hack in this PR can only be applied when the memory pool will never be used again". Could you please elaborate on why that is the case? I'm curious to learn more about the technical reasons behind this limitation.
Thanks again for your insight!

@ngimel commented Sep 12, 2025

maybe write a new c++ extension to expose the pytorch function to python.

It's not a question of writing an extension; you can expose empty_cache with a mempool argument in PyTorch today. But it won't do what you want without changes to MemPool itself and the caching allocator (which, again, you can do, but it requires a pretty carefully written PR).

@vermouth1992 (Contributor) commented:

Can we have one more line to test whether wakeup kv_cache yields expected memory? Thanks!

@youkaichao (Member, Author) commented:

@jiaau one possible solution to your memory fragmentation issue might be to allocate a very large tensor and release it at the beginning; then, hopefully, later tensors can reuse this buffer.

But anyway, I think for on-the-fly quantization you need enough memory to hold the bf16 checkpoint first. We don't do per-layer on-the-fly quantization. If you only want to pay the memory cost of the fp8 checkpoint, you need to convert the checkpoint to fp8 directly.
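
As a rough illustration of that pre-allocation idea, here is a minimal sketch (not part of this PR; the size and names are arbitrary). It relies on the usual caching-allocator behavior where a freed block stays cached and can be split for later allocations:

import torch

def prereserve_gpu_memory(num_gib: int) -> None:
    # Allocate one large placeholder tensor so the allocator requests
    # a single big block up front.
    placeholder = torch.empty(num_gib * 1024**3, dtype=torch.uint8, device="cuda")
    # Dropping the reference returns the block to the allocator's cache
    # (it is not handed back to the driver), so later, smaller tensors
    # can be carved out of it instead of fragmenting fresh blocks.
    del placeholder

# e.g. call prereserve_gpu_memory(16) once before starting to load weights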

@youkaichao (Member, Author) commented:

Can we have one more line to test whether wakeup kv_cache yields expected memory? Thanks!

@vermouth1992 Running the example in the description:

Before this PR, I get "Available KV cache memory: 9.89 GiB".
After this PR, I get "Available KV cache memory: 15.97 GiB".

We can confirm that we have more memory for the KV cache now. Testing it in CI would be quite complicated, though.
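
For reference, one way to check this by hand (an untested sketch; it assumes the "kv_cache" wake-up tag available in vLLM's sleep mode) is to extend test_fp8_sleep above with two more lines:

    # appended at the end of test_fp8_sleep() above
    model.wake_up(["kv_cache"])
    print_current_mem_usage("after wakeup kv_cache")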

skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
dsxsteven pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
@youkaichao youkaichao deleted the quantized_sleep branch September 18, 2025 15:05
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025