Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import logging
import math
import uuid
from typing import Callable, Optional
from typing import Callable, Optional, Union

import torch
import os
Expand All @@ -50,7 +50,7 @@ def need_mmap() -> bool:
return True
return False

def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
def tensor_to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor:
"""
Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support.
"""
Expand Down Expand Up @@ -152,18 +152,26 @@ def convert_fn(t):
if isinstance(t, torch.nn.Parameter):
# For parameters, modify data in-place and return the parameter
if isinstance(t.data, torch.Tensor):
t.data = to_mmap(t.data)
t.data = tensor_to_mmap(t.data)
return t
elif isinstance(t, torch.Tensor):
# For buffers (plain tensors), return the converted tensor
return to_mmap(t)
return tensor_to_mmap(t)
return t

new_model = model._apply(convert_fn)
free_cpu_mem = get_free_memory(torch.device("cpu"))
logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
return new_model

def to_mmap(obj: Union[torch.nn.Module, torch.Tensor]) -> Union[torch.nn.Module, torch.Tensor]:
if isinstance(obj, torch.nn.Module):
return model_to_mmap(obj)
elif isinstance(obj, torch.Tensor):
return tensor_to_mmap(obj)
else:
raise ValueError(f"Unsupported type: {type(obj)}")


def string_to_seed(data):
crc = 0xFFFFFFFF
Expand Down Expand Up @@ -621,18 +629,27 @@ def model_patches_to(self, device):
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"):
patch_list[i] = patch_list[i].to(device)
if need_mmap():
patch_list[i] = to_mmap(patch_list[i])
else:
patch_list[i] = patch_list[i].to(device)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device)
if need_mmap():
patch_list[k] = to_mmap(patch_list[k])
else:
patch_list[k] = patch_list[k].to(device)
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, "to"):
self.model_options["model_function_wrapper"] = wrap_func.to(device)
if need_mmap():
self.model_options["model_function_wrapper"] = to_mmap(wrap_func)
else:
self.model_options["model_function_wrapper"] = wrap_func.to(device)

def model_patches_models(self):
to = self.model_options["transformer_options"]
Expand Down
Loading