-
-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[sleep mode] save memory for on-the-fly quantization #24731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -16,8 +16,11 @@ | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
from vllm.logger import init_logger | ||||||||||||||||||||||||||||||||||||
from vllm.utils import is_pin_memory_available | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
logger = init_logger(__name__) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def find_loaded_library(lib_name) -> Optional[str]: | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
@@ -165,6 +168,9 @@ def _python_malloc_callback(self, allocation_handle: HandleType) -> None: | |||||||||||||||||||||||||||||||||||
py_d_mem = allocation_handle[2] | ||||||||||||||||||||||||||||||||||||
self.pointer_to_data[py_d_mem] = AllocationData( | ||||||||||||||||||||||||||||||||||||
allocation_handle, self.current_tag) | ||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||
"Allocated %s bytes for %s with address %s from cumem allocator", | ||||||||||||||||||||||||||||||||||||
allocation_handle[1], self.current_tag, py_d_mem) | ||||||||||||||||||||||||||||||||||||
return | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def _python_free_callback(self, ptr: int) -> HandleType: | ||||||||||||||||||||||||||||||||||||
|
@@ -174,6 +180,9 @@ def _python_free_callback(self, ptr: int) -> HandleType: | |||||||||||||||||||||||||||||||||||
data = self.pointer_to_data.pop(ptr) | ||||||||||||||||||||||||||||||||||||
if data.cpu_backup_tensor is not None: | ||||||||||||||||||||||||||||||||||||
data.cpu_backup_tensor = None | ||||||||||||||||||||||||||||||||||||
logger.debug( | ||||||||||||||||||||||||||||||||||||
"Freed %s bytes for %s with address %s from cumem allocator", | ||||||||||||||||||||||||||||||||||||
data.handle[1], data.tag, ptr) | ||||||||||||||||||||||||||||||||||||
return data.handle | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def sleep( | ||||||||||||||||||||||||||||||||||||
|
@@ -197,9 +206,14 @@ def sleep( | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
assert isinstance(offload_tags, tuple) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
total_bytes = 0 | ||||||||||||||||||||||||||||||||||||
backup_bytes = 0 | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
for ptr, data in self.pointer_to_data.items(): | ||||||||||||||||||||||||||||||||||||
handle = data.handle | ||||||||||||||||||||||||||||||||||||
total_bytes += handle[1] | ||||||||||||||||||||||||||||||||||||
if data.tag in offload_tags: | ||||||||||||||||||||||||||||||||||||
backup_bytes += handle[1] | ||||||||||||||||||||||||||||||||||||
size_in_bytes = handle[1] | ||||||||||||||||||||||||||||||||||||
cpu_backup_tensor = torch.empty( | ||||||||||||||||||||||||||||||||||||
size_in_bytes, | ||||||||||||||||||||||||||||||||||||
|
@@ -211,6 +225,12 @@ def sleep( | |||||||||||||||||||||||||||||||||||
data.cpu_backup_tensor = cpu_backup_tensor | ||||||||||||||||||||||||||||||||||||
unmap_and_release(handle) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
logger.info( | ||||||||||||||||||||||||||||||||||||
"CuMemAllocator: sleep freed %.2f GiB memory in total, of which " | ||||||||||||||||||||||||||||||||||||
"%.2f GiB is backed up in CPU and the rest %.2f GiB is discarded " | ||||||||||||||||||||||||||||||||||||
"directly.", total_bytes / 1024**3, backup_bytes / 1024**3, | ||||||||||||||||||||||||||||||||||||
(total_bytes - backup_bytes) / 1024**3) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
gc.collect() | ||||||||||||||||||||||||||||||||||||
torch.cuda.empty_cache() | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
@@ -267,12 +287,17 @@ def use_memory_pool(self, tag: Optional[str] = None): | |||||||||||||||||||||||||||||||||||
# when using pluggable allocator, see | ||||||||||||||||||||||||||||||||||||
# https://github.com/pytorch/pytorch/issues/145168 . | ||||||||||||||||||||||||||||||||||||
# if we have some memory allocated and then freed, | ||||||||||||||||||||||||||||||||||||
# the memory will not be released. | ||||||||||||||||||||||||||||||||||||
# right now it is fine, because we only use this allocator | ||||||||||||||||||||||||||||||||||||
# during weight loading and kv cache creation, where we only | ||||||||||||||||||||||||||||||||||||
# allocate memory. | ||||||||||||||||||||||||||||||||||||
# TODO: we need to find a way to release the memory, | ||||||||||||||||||||||||||||||||||||
# i.e. calling torch.cuda.empty_cache() | ||||||||||||||||||||||||||||||||||||
# the memory will not be released, e.g. in online quantization, | ||||||||||||||||||||||||||||||||||||
# where the model is created in higher precision, and then | ||||||||||||||||||||||||||||||||||||
# quantized in lower precision. | ||||||||||||||||||||||||||||||||||||
# Find all unused allocations and manually release them. | ||||||||||||||||||||||||||||||||||||
# TODO: we should expose `empty_cache` method in the memory pool. | ||||||||||||||||||||||||||||||||||||
# TODO: ask for help from PyTorch team to expose this method. | ||||||||||||||||||||||||||||||||||||
allocations = data[0].snapshot() | ||||||||||||||||||||||||||||||||||||
for allocation in allocations: | ||||||||||||||||||||||||||||||||||||
if allocation["allocated_size"] == 0: | ||||||||||||||||||||||||||||||||||||
handle = self._python_free_callback(allocation["address"]) | ||||||||||||||||||||||||||||||||||||
unmap_and_release(handle) | ||||||||||||||||||||||||||||||||||||
Comment on lines
+296
to
+300
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation assumes that any memory block with To make the code more robust, you should wrap the calls in a
Suggested change
|
||||||||||||||||||||||||||||||||||||
self.current_tag = old_tag | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def get_current_usage(self) -> int: | ||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @zou3519 @ngimel I don't really want to be so intrusive to interpret the memory snapshot, but I have no other ways to free the memory pool :(
really hope we can expose
empty_cache
method in the memory pool from pytorch side.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is on nvidia side, they are not implementing what you want. And as I've said repeatedly, it's not a question of exposing
empty_cache
.