diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 361f15e5b9c7..4549db4d4af0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -24,7 +24,7 @@ import logging import math import uuid -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch import os @@ -50,7 +50,7 @@ def need_mmap() -> bool: return True return False -def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: +def tensor_to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: """ Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support. """ @@ -152,11 +152,11 @@ def convert_fn(t): if isinstance(t, torch.nn.Parameter): # For parameters, modify data in-place and return the parameter if isinstance(t.data, torch.Tensor): - t.data = to_mmap(t.data) + t.data = tensor_to_mmap(t.data) return t elif isinstance(t, torch.Tensor): # For buffers (plain tensors), return the converted tensor - return to_mmap(t) + return tensor_to_mmap(t) return t new_model = model._apply(convert_fn) @@ -164,6 +164,14 @@ def convert_fn(t): logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") return new_model +def to_mmap(obj: Union[torch.nn.Module, torch.Tensor]) -> Union[torch.nn.Module, torch.Tensor]: + if isinstance(obj, torch.nn.Module): + return model_to_mmap(obj) + elif isinstance(obj, torch.Tensor): + return tensor_to_mmap(obj) + else: + raise ValueError(f"Unsupported type: {type(obj)}") + def string_to_seed(data): crc = 0xFFFFFFFF @@ -621,18 +629,27 @@ def model_patches_to(self, device): patch_list = patches[name] for i in range(len(patch_list)): if hasattr(patch_list[i], "to"): - patch_list[i] = patch_list[i].to(device) + if need_mmap(): + patch_list[i] = to_mmap(patch_list[i]) + else: + patch_list[i] = patch_list[i].to(device) if "patches_replace" in to: patches = to["patches_replace"] for name in patches: patch_list = patches[name] for k in patch_list: if hasattr(patch_list[k], "to"): - patch_list[k] = patch_list[k].to(device) + if need_mmap(): + patch_list[k] = to_mmap(patch_list[k]) + else: + patch_list[k] = patch_list[k].to(device) if "model_function_wrapper" in self.model_options: wrap_func = self.model_options["model_function_wrapper"] if hasattr(wrap_func, "to"): - self.model_options["model_function_wrapper"] = wrap_func.to(device) + if need_mmap(): + self.model_options["model_function_wrapper"] = to_mmap(wrap_func) + else: + self.model_options["model_function_wrapper"] = wrap_func.to(device) def model_patches_models(self): to = self.model_options["transformer_options"]