37 changes: 31 additions & 6 deletions vllm/device_allocator/cumem.py
@@ -16,8 +16,11 @@

import torch

from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available

logger = init_logger(__name__)


def find_loaded_library(lib_name) -> Optional[str]:
"""
@@ -165,6 +168,9 @@ def _python_malloc_callback(self, allocation_handle: HandleType) -> None:
        py_d_mem = allocation_handle[2]
        self.pointer_to_data[py_d_mem] = AllocationData(
            allocation_handle, self.current_tag)
        logger.debug(
            "Allocated %s bytes for %s with address %s from cumem allocator",
            allocation_handle[1], self.current_tag, py_d_mem)
        return

    def _python_free_callback(self, ptr: int) -> HandleType:
@@ -174,6 +180,9 @@ def _python_free_callback(self, ptr: int) -> HandleType:
        data = self.pointer_to_data.pop(ptr)
        if data.cpu_backup_tensor is not None:
            data.cpu_backup_tensor = None
        logger.debug(
            "Freed %s bytes for %s with address %s from cumem allocator",
            data.handle[1], data.tag, ptr)
        return data.handle
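
A small sketch of how the new per-allocation messages can be surfaced; this assumes vLLM's VLLM_LOGGING_LEVEL environment variable (read when the logger is configured, so it must be set before importing vllm), a CUDA device, and the "weights" tag used elsewhere in vLLM, none of which are part of this diff:

# Hedged sketch, not part of the PR: raise the log level before vllm is
# imported so the logger.debug calls above become visible.
import os

os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"

import torch
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()
with allocator.use_memory_pool(tag="weights"):
    # Allocations routed through the pool hit _python_malloc_callback and now
    # emit the "Allocated ... from cumem allocator" debug line.
    weights = torch.empty(1 << 20, dtype=torch.uint8, device="cuda")
# The matching "Freed ..." line appears once a block is actually unmapped,
# e.g. during sleep() or the snapshot-based cleanup later in this file.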

    def sleep(
@@ -197,9 +206,14 @@ def sleep(

        assert isinstance(offload_tags, tuple)

        total_bytes = 0
        backup_bytes = 0

        for ptr, data in self.pointer_to_data.items():
            handle = data.handle
            total_bytes += handle[1]
            if data.tag in offload_tags:
                backup_bytes += handle[1]
                size_in_bytes = handle[1]
                cpu_backup_tensor = torch.empty(
                    size_in_bytes,
@@ -211,6 +225,12 @@
                data.cpu_backup_tensor = cpu_backup_tensor
            unmap_and_release(handle)

        logger.info(
            "CuMemAllocator: sleep freed %.2f GiB memory in total, of which "
            "%.2f GiB is backed up in CPU and the remaining %.2f GiB is "
            "discarded directly.", total_bytes / 1024**3,
            backup_bytes / 1024**3,
            (total_bytes - backup_bytes) / 1024**3)

        gc.collect()
        torch.cuda.empty_cache()
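
As a rough sketch of how this new accounting surfaces at runtime (the get_instance() entry point, wake_up(), and the "weights" tag are assumptions about how the allocator is driven elsewhere in vLLM, not part of this diff):

# Hedged usage sketch, not part of the PR: requires a CUDA device and that
# the model's allocations were made under the "weights" tag.
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

# Back up only the "weights" allocations to pinned CPU memory; allocations
# under any other tag are unmapped and discarded outright.
allocator.sleep(offload_tags=("weights", ))
# The new log line then reports the total bytes freed, how much was backed
# up to CPU, and how much was discarded.

# wake_up() maps the memory back and restores the backed-up tensors.
allocator.wake_up()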

@@ -267,12 +287,17 @@ def use_memory_pool(self, tag: Optional[str] = None):
            # when using pluggable allocator, see
            # https://github.com/pytorch/pytorch/issues/145168 .
            # if we have some memory allocated and then freed,
            # the memory will not be released.
            # right now it is fine, because we only use this allocator
            # during weight loading and kv cache creation, where we only
            # allocate memory.
            # TODO: we need to find a way to release the memory,
            # i.e. calling torch.cuda.empty_cache()
            # the memory will not be released, e.g. in online quantization,
            # where the model is created in higher precision, and then
            # quantized in lower precision.
            # Find all unused allocations and manually release them.
            # TODO: we should expose `empty_cache` method in the memory pool.
            # TODO: ask for help from PyTorch team to expose this method.
            allocations = data[0].snapshot()
            for allocation in allocations:
                if allocation["allocated_size"] == 0:
                    handle = self._python_free_callback(allocation["address"])
Comment on lines +298 to +299
Member Author:
cc @zou3519 @ngimel I don't really want to be so intrusive as to interpret the memory snapshot, but I have no other way to free the memory pool :(

I really hope we can expose an `empty_cache` method on the memory pool from the PyTorch side.

Reply:

The problem is on the NVIDIA side; they are not implementing what you want. And as I've said repeatedly, it's not a question of exposing `empty_cache`.
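
For readers following this thread, a minimal standalone sketch of the workaround in question, factored out of the hunk above; it assumes data[0] is the torch.cuda.MemPool created for the pluggable allocator and that its snapshot() entries expose the "address" and "allocated_size" keys used in this PR:

# Hedged sketch, not part of the PR: walk the pool snapshot and collect the
# base addresses of segments holding no live allocations, i.e. blocks the
# pool still caches even though the caller has already freed them.
from typing import List


def find_fully_freed_segments(pool) -> List[int]:
    # `pool` is expected to be the torch.cuda.MemPool stored in data[0].
    addresses: List[int] = []
    for segment in pool.snapshot():
        if segment["allocated_size"] == 0:
            addresses.append(segment["address"])
    return addresses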

                    unmap_and_release(handle)
Comment on lines +296 to +300
Contributor:

Severity: high

The current implementation assumes that any memory block with allocated_size == 0 will have a corresponding entry in self.pointer_to_data. However, it's possible that the free callback was already triggered for an allocation, removing it from self.pointer_to_data, while the memory block is still tracked by the pool. This would lead to a KeyError when _python_free_callback is called, as it internally performs a pop, which could crash the application.

To make the code more robust, you should wrap the calls in a try...except KeyError block to gracefully handle cases where the allocation has already been freed and removed from pointer_to_data.

Suggested change
            allocations = data[0].snapshot()
            for allocation in allocations:
                if allocation["allocated_size"] == 0:
                    handle = self._python_free_callback(allocation["address"])
                    unmap_and_release(handle)
            allocations = data[0].snapshot()
            for allocation in allocations:
                if allocation["allocated_size"] == 0:
                    try:
                        handle = self._python_free_callback(
                            allocation["address"])
                        unmap_and_release(handle)
                    except KeyError:
                        # This can happen if the allocation was already freed
                        # through the normal path, but the memory pool has not
                        # released the block.
                        pass
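
An alternative sketch with the same effect (not part of the PR and not the bot's suggestion): check pointer_to_data membership up front instead of relying on the exception path, which keeps the happy path free of try/except; shown at the same indentation as the loop it would replace:

            # Hedged alternative to the try/except suggestion above: skip
            # snapshot entries whose address is no longer tracked, since the
            # normal free path may already have popped them.
            allocations = data[0].snapshot()
            for allocation in allocations:
                if allocation["allocated_size"] != 0:
                    continue
                if allocation["address"] not in self.pointer_to_data:
                    # Already released through the regular free callback; the
                    # pool just has not returned the block yet.
                    continue
                handle = self._python_free_callback(allocation["address"])
                unmap_and_release(handle)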

            self.current_tag = old_tag

    def get_current_usage(self) -> int: