10 changes: 9 additions & 1 deletion comfy/model_base.py
@@ -60,6 +60,7 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_management import get_free_memory

class ModelType(Enum):
EPS = 1
@@ -297,8 +298,15 @@ def load_model_weights(self, sd, unet_prefix=""):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)

free_cpu_memory = get_free_memory(torch.device("cpu"))
logging.debug(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
logging.debug(f"model destination device {next(self.diffusion_model.parameters()).device}")
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
logging.debug(f"load model {self.model_config} weights process end")
# assign=True swaps in the mmap-backed tensors instead of copying them into existing storage
m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
free_cpu_memory = get_free_memory(torch.device("cpu"))
logging.debug(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB")
if len(m) > 0:
logging.warning("unet missing: {}".format(m))

77 changes: 67 additions & 10 deletions comfy/model_management.py
@@ -26,6 +26,18 @@
import platform
import weakref
import gc
import os

from functools import lru_cache

@lru_cache(maxsize=1)
def get_mmap_mem_threshold_gb():
mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))
logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}")
return mmap_mem_threshold_gb

def get_free_disk():
return psutil.disk_usage("/").free
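Usage note (illustrative, not from the diff): because get_mmap_mem_threshold_gb is wrapped in lru_cache(maxsize=1), the environment variable is read once and the result is frozen for the life of the process, so it must be set before the first call. A sketch with an arbitrary example value, assuming the function is in scope:

import os
os.environ["MMAP_MEM_THRESHOLD_GB"] = "8"   # example value; unset or 0 keeps mmap offload disabled

get_mmap_mem_threshold_gb()   # -> 8, read from the environment and cached
get_mmap_mem_threshold_gb()   # -> 8, served from the cache; later changes to the env var are ignored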

class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -519,16 +531,58 @@ def should_reload_model(self, force_patch_weights=False):
return False

def model_unload(self, memory_to_free=None, unpatch_weights=True):
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
self.model.detach(unpatch_weights)
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
return True
logging.debug(f"model_unload: {self.model.model.__class__.__name__}")
logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
logging.debug(f"unpatch_weights: {unpatch_weights}")
logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
logging.debug(f"offload_device: {self.model.offload_device}")

# if available_memory < reserved_memory:
# logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB")
# return False
# else:
# offload_memory = available_memory - reserved_memory
#
# if offload_memory < memory_to_free:
# memory_to_free = offload_memory
# logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB")
# logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB")

if memory_to_free is None:
# free the full model
memory_to_free = self.model.loaded_size()

available_memory = get_free_memory(self.model.offload_device)
logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")

mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage
if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size():
partially_unload = True
else:
partially_unload = False

if partially_unload:
logging.debug("Do partially unload")
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB")
if freed < memory_to_free:
logging.warning(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB")
else:
logging.debug("Do full unload")
self.model.detach(unpatch_weights)
logging.debug("Do full unload done")
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None

available_memory = get_free_memory(self.model.offload_device)
logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")

if partially_unload:
return False
else:
return True
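Worked example (illustrative numbers only): with MMAP_MEM_THRESHOLD_GB=8, 10 GB free on the offload device, a 20 GB model loaded, and a request to free 6 GB, 6 GB > 10 GB - 8 GB, so the partial-unload branch is taken and the method returns False. A sketch mirroring the condition above:

GiB = 1024 * 1024 * 1024
memory_to_free     = 6 * GiB    # hypothetical request
available_memory   = 10 * GiB   # hypothetical free memory on the offload device
loaded_size        = 20 * GiB   # hypothetical size of the model currently loaded
mmap_mem_threshold = 8 * GiB    # MMAP_MEM_THRESHOLD_GB=8, reserved for other system usage

# Partial unload is chosen when the offload device cannot absorb the freed weights while
# keeping the reserve intact, or when only part of the loaded model needs to be freed.
partially_unload = (memory_to_free > available_memory - mmap_mem_threshold
                    or memory_to_free < loaded_size)   # -> True in this example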


def model_use_more_vram(self, extra_memory, force_patch_weights=False):
return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
@@ -577,6 +631,7 @@ def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()

def free_memory(memory_required, device, keep_loaded=[]):
logging.debug("start to free mem")
cleanup_models_gc()
unloaded_model = []
can_unload = []
@@ -614,6 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
return unloaded_models

def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
logging.debug(f"start to load models")
cleanup_models_gc()
global vram_state

@@ -635,6 +691,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
models_to_load = []

for x in models:
logging.debug(f"start loading model to vram: {x.model.__class__.__name__}")
loaded_model = LoadedModel(x)
try:
loaded_model_index = current_loaded_models.index(loaded_model)
145 changes: 143 additions & 2 deletions comfy/model_patcher.py
@@ -27,6 +27,10 @@
from typing import Callable, Optional

import torch
import os
import tempfile
import weakref
import gc

import comfy.float
import comfy.hooks
@@ -36,6 +40,129 @@
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk

def need_mmap() -> bool:
free_cpu_mem = get_free_memory(torch.device("cpu"))
mmap_mem_threshold_gb = get_mmap_mem_threshold_gb()
if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB")
return True
return False

def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
"""
Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
"""
# Move to CPU if needed
if t.is_cuda:

Review comment (on the t.is_cuda check above): Considering that ComfyUI may need to support other hardware, such as Moore Threads' MUSA (for a tensor on a Moore Threads GPU, t.is_cuda returns False while t.is_musa is True), should this condition be changed to if not t._is_cpu?

Author reply: OK, I'll update this later to handle the MUSA case.
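A sketch of the backend-agnostic variant the reviewer suggests (for illustration only, not the committed code):

# Rely on the device type rather than a CUDA-specific flag, so MUSA/XPU/MPS tensors
# are also copied to CPU before being written out for mmap.
if t.device.type != "cpu":
    cpu_tensor = t.cpu()
else:
    cpu_tensor = t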

cpu_tensor = t.cpu()
else:
cpu_tensor = t

# Create a temporary file path (torch.save below writes the actual file)
if filename is None:
temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_')
else:
temp_file = filename

# Save tensor to file
torch.save(cpu_tensor, temp_file)

# If we created a CPU copy from CUDA, delete it to free memory
if t.is_cuda:
del cpu_tensor
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()

# Load with mmap - this doesn't load all data into RAM
mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False)

# Register cleanup callback - will be called when tensor is garbage collected
def _cleanup():
try:
if os.path.exists(temp_file):
os.remove(temp_file)
logging.debug(f"Cleaned up mmap file: {temp_file}")
except Exception:
pass

weakref.finalize(mmap_tensor, _cleanup)

# # Save original 'to' method
# original_to = mmap_tensor.to

# # Create custom 'to' method that cleans up file when moving to CUDA
# def custom_to(*args, **kwargs):
# # Determine target device
# target_device = None
# if len(args) > 0:
# if isinstance(args[0], torch.device):
# target_device = args[0]
# elif isinstance(args[0], str):
# target_device = torch.device(args[0])
# if 'device' in kwargs:
# target_device = kwargs['device']
# if isinstance(target_device, str):
# target_device = torch.device(target_device)
#
# # Call original 'to' method first to move data
# result = original_to(*args, **kwargs)
#
# # NOTE: Cleanup disabled to avoid blocking model load performance
# # If moved to CUDA, cleanup the mmap file after the move
# if target_device is not None and target_device.type == 'cuda':
# _cleanup()
#
# return result

# # Replace the 'to' method
# mmap_tensor.to = custom_to

return mmap_tensor
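Usage sketch for to_mmap (illustrative, not from the diff, assuming to_mmap is imported from comfy.model_patcher): the returned tensor reads its data from a temporary comfy_mmap_*.pt file on demand, and the weakref.finalize callback removes that file once the tensor is garbage collected.

import gc
import torch

t = torch.randn(1024, 1024)
mm = to_mmap(t)              # mm is backed by a temp file and paged in on demand
assert torch.equal(mm, t)    # contents are unchanged

del mm
gc.collect()                 # finalizer runs and the temp file is deleted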

def model_to_mmap(model: torch.nn.Module):
"""Convert all parameters and buffers to memory-mapped tensors

This function mimics PyTorch's Module.to() behavior but converts
tensors to memory-mapped format instead, using _apply() method.

Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244

Note: For Parameters, we modify .data in-place because
MemoryMappedTensor cannot be wrapped in torch.nn.Parameter.
For buffers, _apply() will automatically update the reference.

Args:
model: PyTorch module to convert

Returns:
The same model with all tensors converted to memory-mapped format
"""
free_cpu_mem = get_free_memory(torch.device("cpu"))
logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")

def convert_fn(t):
"""Convert function for _apply()

- For Parameters: modify .data and return the Parameter object
- For buffers (plain Tensors): return new MemoryMappedTensor
"""
if isinstance(t, torch.nn.Parameter):
# For parameters, modify data in-place and return the parameter
if isinstance(t.data, torch.Tensor):
t.data = to_mmap(t.data)
return t
elif isinstance(t, torch.Tensor):
# For buffers (plain tensors), return the converted tensor
return to_mmap(t)
return t

new_model = model._apply(convert_fn)
free_cpu_mem = get_free_memory(torch.device("cpu"))
logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
return new_model
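Usage sketch for model_to_mmap (illustrative, not from the diff, assuming model_to_mmap is in scope): parameters keep their nn.Parameter wrapper while their .data is swapped for a file-backed tensor; buffers are replaced outright by _apply.

import torch

mod = torch.nn.Linear(8, 8)
model_to_mmap(mod)
assert isinstance(mod.weight, torch.nn.Parameter)   # still a Parameter after conversion
assert mod.weight.device.type == "cpu"              # its data is now a CPU, file-backed tensor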


def string_to_seed(data):
@@ -486,6 +613,7 @@ def get_model_object(self, name: str) -> torch.nn.Module:
return comfy.utils.get_attr(self.model, name)

def model_patches_to(self, device):
# TODO(sf): to mmap
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
@@ -782,9 +910,15 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
self.model.current_weight_patches_uuid = None
self.backup.clear()


if device_to is not None:
self.model.to(device_to)
if need_mmap():
# offload to mmap
model_to_mmap(self.model)
else:
self.model.to(device_to)
self.model.device = device_to

self.model.model_loaded_weight_memory = 0

for m in self.model.modules():
@@ -837,7 +971,14 @@ def partially_unload(self, device_to, memory_to_free=0):
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
m.to(device_to)
if need_mmap():
if get_free_disk() < module_mem:
logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB")
break
# offload to mmap
model_to_mmap(m)
else:
m.to(device_to)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
1 change: 1 addition & 0 deletions comfy/sd.py
@@ -1321,6 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
logging.warning("{} {}".format(diffusers_keys[k], k))

offload_device = model_management.unet_offload_device()
logging.debug(f"loader load model to offload device: {offload_device}")
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
weight_dtype = None
3 changes: 3 additions & 0 deletions comfy/utils.py
@@ -60,6 +60,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try:
if not DISABLE_MMAP:
logging.debug(f"load_torch_file of safetensors into mmap True")
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
@@ -80,6 +82,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
else:
torch_args = {}
if MMAP_TORCH_FILES:
logging.debug(f"load_torch_file of torch state dict into mmap True")
torch_args["mmap"] = True

if safe_load or ALWAYS_SAFE_LOAD:
Expand Down