diff --git a/comfy/model_base.py b/comfy/model_base.py
index 8274c7dea192..d6ef644dd42b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -60,6 +60,7 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher
+from comfy.model_management import get_free_memory
 
 class ModelType(Enum):
     EPS = 1
@@ -297,8 +298,15 @@ def load_model_weights(self, sd, unet_prefix=""):
             if k.startswith(unet_prefix):
                 to_load[k[len(unet_prefix):]] = sd.pop(k)
 
+        free_cpu_memory = get_free_memory(torch.device("cpu"))
+        logging.debug(f"load_model_weights start, free cpu memory: {free_cpu_memory/(1024*1024*1024)} GB")
+        logging.debug(f"model destination device: {next(self.diffusion_model.parameters()).device}")
         to_load = self.model_config.process_unet_state_dict(to_load)
-        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+        logging.debug(f"finished processing state dict for {self.model_config}")
+        # assign=True keeps the (possibly mmap-backed) tensors from the state dict instead of
+        # copying them into freshly allocated parameters.
+        m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True)
+        free_cpu_memory = get_free_memory(torch.device("cpu"))
+        logging.debug(f"load_model_weights for {self.model_config} done, free cpu memory: {free_cpu_memory/(1024*1024*1024)} GB")
         if len(m) > 0:
             logging.warning("unet missing: {}".format(m))
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 79d6ff9d441d..a2ad5db2aac7 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -26,6 +26,18 @@
 import platform
 import weakref
 import gc
+import os
+
+from functools import lru_cache
+
+@lru_cache(maxsize=1)
+def get_mmap_mem_threshold_gb():
+    mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0"))
+    logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}")
+    return mmap_mem_threshold_gb
+
+def get_free_disk():
+    return psutil.disk_usage("/").free
 
 class VRAMState(Enum):
     DISABLED = 0    #No vram present: no need to move models to vram
@@ -519,16 +531,58 @@ def should_reload_model(self, force_patch_weights=False):
         return False
 
     def model_unload(self, memory_to_free=None, unpatch_weights=True):
-        if memory_to_free is not None:
-            if memory_to_free < self.model.loaded_size():
-                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
-                if freed >= memory_to_free:
-                    return False
-        self.model.detach(unpatch_weights)
-        self.model_finalizer.detach()
-        self.model_finalizer = None
-        self.real_model = None
-        return True
+        if memory_to_free is None:
+            # No explicit target: free the whole model.
+            memory_to_free = self.model.loaded_size()
+
+        logging.debug(f"model_unload: {self.model.model.__class__.__name__}")
+        logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB")
+        logging.debug(f"unpatch_weights: {unpatch_weights}")
+        logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB")
+        logging.debug(f"offload_device: {self.model.offload_device}")
+
+        available_memory = get_free_memory(self.model.offload_device)
+        logging.debug(f"before unload, free memory on offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
+
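+        # Decide between a partial and a full unload: go partial when the caller asked to free
+        # less than the whole model, or when the offload device does not have enough free
+        # memory (after keeping MMAP_MEM_THRESHOLD_GB in reserve for the rest of the system)
+        # to receive everything that is about to leave the GPU.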
+        mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024  # reserved for other system usage
+        partially_unload = (memory_to_free > available_memory - mmap_mem_threshold
+                            or memory_to_free < self.model.loaded_size())
+
+        if partially_unload:
+            logging.debug("Doing a partial unload")
+            freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+            logging.debug(f"partial unload freed {freed/(1024*1024*1024)} GB of vram")
+            if freed < memory_to_free:
+                logging.warning(f"Partial unload could not free enough memory: freed {freed/(1024*1024*1024)} GB, requested {memory_to_free/(1024*1024*1024)} GB")
+        else:
+            logging.debug("Doing a full unload")
+            self.model.detach(unpatch_weights)
+            logging.debug("Full unload done")
+            self.model_finalizer.detach()
+            self.model_finalizer = None
+            self.real_model = None
+
+        available_memory = get_free_memory(self.model.offload_device)
+        logging.debug(f"after unload, free memory on offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB")
+
+        return not partially_unload
 
     def model_use_more_vram(self, extra_memory, force_patch_weights=False):
         return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
@@ -577,6 +631,7 @@ def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
 
 def free_memory(memory_required, device, keep_loaded=[]):
+    logging.debug("free_memory start")
     cleanup_models_gc()
     unloaded_model = []
     can_unload = []
@@ -614,6 +669,7 @@
     return unloaded_models
 
 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+    logging.debug("load_models_gpu start")
     cleanup_models_gc()
     global vram_state
 
@@ -635,6 +691,7 @@
     models_to_load = []
 
     for x in models:
+        logging.debug(f"start loading model to vram: {x.model.__class__.__name__}")
         loaded_model = LoadedModel(x)
         try:
             loaded_model_index = current_loaded_models.index(loaded_model)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index c0b68fb8cff7..361f15e5b9c7 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -27,6 +27,10 @@
 from typing import Callable, Optional
 
 import torch
+import os
+import tempfile
+import weakref
+import gc
 
 import comfy.float
 import comfy.hooks
@@ -36,6 +40,129 @@
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
+from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk
+
+def need_mmap() -> bool:
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    mmap_mem_threshold_gb = get_mmap_mem_threshold_gb()
+    if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024:
+        logging.debug(f"Enabling mmap: free cpu memory {free_cpu_mem/(1024*1024*1024)} GB is below the {mmap_mem_threshold_gb} GB threshold")
+        return True
+    return False
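+
+# Illustrative example: with MMAP_MEM_THRESHOLD_GB=16 and only 12 GB of free CPU memory,
+# need_mmap() returns True and offloaded weights are written to disk-backed mmap files
+# instead of RAM. With the default threshold of 0 the mmap path never triggers.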
logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") + return True + return False + +def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: + """ + Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support. + """ + # Move to CPU if needed + if t.is_cuda: + cpu_tensor = t.cpu() + else: + cpu_tensor = t + + # Create temporary file + if filename is None: + temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_') + else: + temp_file = filename + + # Save tensor to file + torch.save(cpu_tensor, temp_file) + + # If we created a CPU copy from CUDA, delete it to free memory + if t.is_cuda: + del cpu_tensor + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Load with mmap - this doesn't load all data into RAM + mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) + + # Register cleanup callback - will be called when tensor is garbage collected + def _cleanup(): + try: + if os.path.exists(temp_file): + os.remove(temp_file) + logging.debug(f"Cleaned up mmap file: {temp_file}") + except Exception: + pass + + weakref.finalize(mmap_tensor, _cleanup) + + # # Save original 'to' method + # original_to = mmap_tensor.to + + # # Create custom 'to' method that cleans up file when moving to CUDA + # def custom_to(*args, **kwargs): + # # Determine target device + # target_device = None + # if len(args) > 0: + # if isinstance(args[0], torch.device): + # target_device = args[0] + # elif isinstance(args[0], str): + # target_device = torch.device(args[0]) + # if 'device' in kwargs: + # target_device = kwargs['device'] + # if isinstance(target_device, str): + # target_device = torch.device(target_device) + # + # # Call original 'to' method first to move data + # result = original_to(*args, **kwargs) + # + # # NOTE: Cleanup disabled to avoid blocking model load performance + # # If moved to CUDA, cleanup the mmap file after the move + # if target_device is not None and target_device.type == 'cuda': + # _cleanup() + # + # return result + + # # Replace the 'to' method + # mmap_tensor.to = custom_to + + return mmap_tensor + +def model_to_mmap(model: torch.nn.Module): + """Convert all parameters and buffers to memory-mapped tensors + + This function mimics PyTorch's Module.to() behavior but converts + tensors to memory-mapped format instead, using _apply() method. + + Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 + + Note: For Parameters, we modify .data in-place because + MemoryMappedTensor cannot be wrapped in torch.nn.Parameter. + For buffers, _apply() will automatically update the reference. 
+
+def model_to_mmap(model: torch.nn.Module):
+    """Convert all parameters and buffers of a module to memory-mapped tensors.
+
+    This mimics PyTorch's Module.to() behavior, but converts tensors to
+    memory-mapped storage instead, using the _apply() method.
+
+    Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244
+
+    Note: for Parameters, .data is replaced in place and the Parameter object
+    itself is kept, so existing references to it stay valid. For buffers
+    (plain Tensors), _apply() updates the module's reference automatically.
+
+    Args:
+        model: PyTorch module to convert
+
+    Returns:
+        The same model with all tensors converted to memory-mapped storage
+    """
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
+
+    def convert_fn(t):
+        """Convert function for _apply()
+
+        - For Parameters: modify .data and return the Parameter object
+        - For buffers (plain Tensors): return the new mmap-backed tensor
+        """
+        if isinstance(t, torch.nn.Parameter):
+            # For parameters, modify data in-place and return the parameter
+            if isinstance(t.data, torch.Tensor):
+                t.data = to_mmap(t.data)
+            return t
+        elif isinstance(t, torch.Tensor):
+            # For buffers (plain tensors), return the converted tensor
+            return to_mmap(t)
+        return t
+
+    new_model = model._apply(convert_fn)
+    free_cpu_mem = get_free_memory(torch.device("cpu"))
+    logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB")
+    return new_model
 
 
 def string_to_seed(data):
@@ -486,6 +613,7 @@ def get_model_object(self, name: str) -> torch.nn.Module:
         return comfy.utils.get_attr(self.model, name)
 
     def model_patches_to(self, device):
+        # TODO(sf): also offload these patches via mmap
         to = self.model_options["transformer_options"]
         if "patches" in to:
             patches = to["patches"]
@@ -782,9 +910,15 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
            self.model.current_weight_patches_uuid = None
            self.backup.clear()
+
        if device_to is not None:
-            self.model.to(device_to)
+            if need_mmap():
+                # Offload the weights to disk-backed mmap files instead of CPU RAM
+                model_to_mmap(self.model)
+            else:
+                self.model.to(device_to)
            self.model.device = device_to
+
            self.model.model_loaded_weight_memory = 0
 
        for m in self.model.modules():
@@ -837,7 +971,14 @@ def partially_unload(self, device_to, memory_to_free=0):
                bias_key = "{}.bias".format(n)
                if move_weight:
                    cast_weight = self.force_cast_weights
-                    m.to(device_to)
+                    if need_mmap():
+                        if get_free_disk() < module_mem:
+                            logging.warning(f"Not enough disk space to offload {n} to mmap: free disk {get_free_disk()/(1024*1024*1024)} GB < module size {module_mem/(1024*1024*1024)} GB")
+                            break
+                        # Offload this module's weights to disk-backed mmap files
+                        model_to_mmap(m)
+                    else:
+                        m.to(device_to)
                    module_mem += move_weight_functions(m, device_to)
                    if lowvram_possible:
                        if weight_key in self.patches:
diff --git a/comfy/sd.py b/comfy/sd.py
index 28bee248dae1..3651da5e7552 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1321,6 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
            logging.warning("{} {}".format(diffusers_keys[k], k))
 
    offload_device = model_management.unet_offload_device()
+    logging.debug(f"load_diffusion_model_state_dict: offload device is {offload_device}")
    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    if model_config.scaled_fp8 is not None:
        weight_dtype = None
diff --git a/comfy/utils.py b/comfy/utils.py
index 0fd03f165b7c..be6ab759655e 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -60,6 +60,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
    metadata = None
    if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
        try:
+            if not DISABLE_MMAP:
+                logging.debug("load_torch_file of safetensors into mmap True")
            with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
                sd = {}
                for k in f.keys():
@@ -80,6 +82,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
        else:
            torch_args = {}
            if MMAP_TORCH_FILES:
+                logging.debug(f"load_torch_file of 
torch state dict into mmap True") torch_args["mmap"] = True if safe_load or ALWAYS_SAFE_LOAD: diff --git a/tests/execution/test_model_mmap.py b/tests/execution/test_model_mmap.py new file mode 100644 index 000000000000..7a608c9316b1 --- /dev/null +++ b/tests/execution/test_model_mmap.py @@ -0,0 +1,282 @@ +import pytest +import torch +import torch.nn as nn +import psutil +import os +import gc +import tempfile +from comfy.model_patcher import model_to_mmap, to_mmap + + +class LargeModel(nn.Module): + """A simple model with large parameters for testing memory mapping""" + + def __init__(self, size_gb=10): + super().__init__() + # Calculate number of float32 elements needed for target size + # 1 GB = 1024^3 bytes, float32 = 4 bytes + bytes_per_gb = 1024 * 1024 * 1024 + elements_per_gb = bytes_per_gb // 4 # float32 is 4 bytes + total_elements = int(size_gb * elements_per_gb) + + # Create a large linear layer + # Split into multiple layers to avoid single tensor size limits + self.layers = nn.ModuleList() + elements_per_layer = 500 * 1024 * 1024 # 500M elements per layer (~2GB) + num_layers = (total_elements + elements_per_layer - 1) // elements_per_layer + + for i in range(num_layers): + if i == num_layers - 1: + # Last layer gets the remaining elements + remaining = total_elements - (i * elements_per_layer) + in_features = int(remaining ** 0.5) + out_features = (remaining + in_features - 1) // in_features + else: + in_features = int(elements_per_layer ** 0.5) + out_features = (elements_per_layer + in_features - 1) // in_features + + # Create layer without bias to control size precisely + self.layers.append(nn.Linear(in_features, out_features, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def get_process_memory_gb(): + """Get current process memory usage in GB""" + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + return mem_info.rss / (1024 ** 3) # Convert to GB + + +def get_model_size_gb(model): + """Calculate model size in GB""" + total_size = 0 + for param in model.parameters(): + total_size += param.nelement() * param.element_size() + for buffer in model.buffers(): + total_size += buffer.nelement() * buffer.element_size() + return total_size / (1024 ** 3) + + +def test_model_to_mmap_memory_efficiency(): + """Test that model_to_mmap reduces memory usage for a 10GB model to less than 1GB + + The typical use case is: + 1. Load a large model on CUDA + 2. Convert to mmap to offload from GPU to disk-backed memory + 3. 
This frees GPU memory and reduces CPU RAM usage + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection before starting + gc.collect() + torch.cuda.empty_cache() + + # Record initial memory + initial_cpu_memory = get_process_memory_gb() + initial_gpu_memory = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"\nInitial CPU memory: {initial_cpu_memory:.2f} GB") + print(f"Initial GPU memory: {initial_gpu_memory:.2f} GB") + + # Create a 10GB model + print("Creating 10GB model...") + model = LargeModel(size_gb=10) + + # Verify model size + model_size = get_model_size_gb(model) + print(f"Model size: {model_size:.2f} GB") + assert model_size >= 9.5, f"Model size {model_size:.2f} GB is less than expected 10 GB" + + # Move model to CUDA + print("Moving model to CUDA...") + model = model.cuda() + torch.cuda.synchronize() + + # Memory after moving to CUDA + cpu_after_cuda = get_process_memory_gb() + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after moving to CUDA: {cpu_after_cuda:.2f} GB") + print(f"GPU memory after moving to CUDA: {gpu_after_cuda:.2f} GB") + + # Convert to mmap (this should move model from GPU to disk-backed memory) + # Note: model_to_mmap modifies the model in-place via _apply() + # so model and model_mmap will be the same object + print("Converting model to mmap...") + model_mmap = model_to_mmap(model) + + # Verify that model and model_mmap are the same object (in-place modification) + assert model is model_mmap, "model_to_mmap should modify the model in-place" + + # Force garbage collection and clear CUDA cache + # The original CUDA tensors should be automatically freed when replaced + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Memory after mmap conversion + cpu_after_mmap = get_process_memory_gb() + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after mmap: {cpu_after_mmap:.2f} GB") + print(f"GPU memory after mmap: {gpu_after_mmap:.2f} GB") + + # Calculate memory changes from CUDA state (the baseline we're converting from) + cpu_increase = cpu_after_mmap - cpu_after_cuda + gpu_decrease = gpu_after_cuda - gpu_after_mmap # Should be positive (freed) + print(f"\nCPU memory increase from CUDA: {cpu_increase:.2f} GB") + print(f"GPU memory freed: {gpu_decrease:.2f} GB") + + # Verify that CPU memory usage increase is less than 1GB + # The mmap should use disk-backed storage, keeping CPU RAM usage low + # We use 1.5 GB threshold to account for overhead + assert cpu_increase < 1.5, ( + f"CPU memory increase after mmap ({cpu_increase:.2f} GB) should be less than 1.5 GB. " + f"CUDA state: {cpu_after_cuda:.2f} GB, After mmap: {cpu_after_mmap:.2f} GB" + ) + + # Verify that GPU memory has been freed + # We expect at least 9 GB to be freed (original 10GB model with some tolerance) + assert gpu_decrease > 9.0, ( + f"GPU memory should be freed after mmap. 
" + f"Freed: {gpu_decrease:.2f} GB (from {gpu_after_cuda:.2f} to {gpu_after_mmap:.2f} GB), expected > 9 GB" + ) + + # Verify the model is still functional (basic sanity check) + assert model_mmap is not None + assert len(list(model_mmap.parameters())) > 0 + + print(f"\n✓ Test passed!") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 1.5 GB") + print(f" GPU memory freed: {gpu_decrease:.2f} GB > 9.0 GB") + print(f" Model successfully offloaded from GPU to disk-backed memory") + + # Cleanup (model and model_mmap are the same object) + del model, model_mmap + gc.collect() + torch.cuda.empty_cache() + + +def test_to_mmap_cuda_cycle(): + """Test CUDA -> mmap -> CUDA cycle + + This test verifies: + 1. CUDA tensor can be converted to mmap tensor + 2. CPU memory increase is minimal when using mmap (< 0.1 GB) + 3. GPU memory is freed when converting to mmap + 4. mmap tensor can be moved back to CUDA + 5. Data remains consistent throughout the cycle + 6. mmap file is automatically cleaned up via garbage collection + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection + gc.collect() + torch.cuda.empty_cache() + + print("\nTest: CUDA -> mmap -> CUDA cycle") + + # Record initial CPU memory + initial_cpu_memory = get_process_memory_gb() + print(f"Initial CPU memory: {initial_cpu_memory:.2f} GB") + + # Step 1: Create a CUDA tensor + print("\n1. Creating CUDA tensor...") + original_data = torch.randn(5000, 5000).cuda() + original_sum = original_data.sum().item() + print(f" Shape: {original_data.shape}") + print(f" Device: {original_data.device}") + print(f" Sum: {original_sum:.2f}") + + # Record GPU and CPU memory after CUDA allocation + cpu_after_cuda = get_process_memory_gb() + gpu_before_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_before_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_cuda:.2f} GB") + + # Step 2: Convert to mmap tensor + print("\n2. Converting to mmap tensor...") + mmap_tensor = to_mmap(original_data) + del original_data + gc.collect() + torch.cuda.empty_cache() + + print(f" Device: {mmap_tensor.device}") + print(f" Sum: {mmap_tensor.sum().item():.2f}") + + # Verify GPU memory is freed + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + cpu_after_mmap = get_process_memory_gb() + print(f" GPU memory freed: {gpu_before_mmap - gpu_after_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_mmap:.2f} GB") + + # Verify GPU memory is freed + assert gpu_after_mmap < 0.1, f"GPU memory should be freed, but {gpu_after_mmap:.2f} GB still allocated" + + # Verify CPU memory increase is minimal (should be close to 0 due to mmap) + cpu_increase = cpu_after_mmap - cpu_after_cuda + print(f" CPU memory increase: {cpu_increase:.2f} GB") + assert cpu_increase < 0.1, f"CPU memory should increase minimally, but increased by {cpu_increase:.2f} GB" + + # Get the temp file path (we'll check if it gets cleaned up) + # The file should exist at this point + temp_files_before = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files exist: {temp_files_before}") + + # Step 3: Move back to CUDA + print("\n3. 
Moving back to CUDA...") + cuda_tensor = mmap_tensor.to('cuda') + torch.cuda.synchronize() + + print(f" Device: {cuda_tensor.device}") + final_sum = cuda_tensor.sum().item() + print(f" Sum: {final_sum:.2f}") + + # Verify GPU memory is used again + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_after_cuda:.2f} GB") + + # Step 4: Verify data consistency + print("\n4. Verifying data consistency...") + sum_diff = abs(original_sum - final_sum) + print(f" Original sum: {original_sum:.2f}") + print(f" Final sum: {final_sum:.2f}") + print(f" Difference: {sum_diff:.6f}") + assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}" + + # Step 5: Verify file cleanup (delayed until garbage collection) + print("\n5. Verifying file cleanup...") + # Delete the mmap tensor reference to trigger garbage collection + del mmap_tensor + gc.collect() + import time + time.sleep(0.1) # Give OS time to clean up + temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files after GC: {temp_files_after}") + # File should be cleaned up after garbage collection + assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection" + + print("\n✓ Test passed!") + print(" CUDA -> mmap -> CUDA cycle works correctly") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)") + print(" Data consistency maintained") + print(" File cleanup successful (via garbage collection)") + + # Cleanup + del cuda_tensor # mmap_tensor already deleted in Step 5 + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Run the tests directly + test_model_to_mmap_memory_efficiency() + test_to_mmap_cuda_cycle() +
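+
+# To exercise the new mmap offload path end to end (illustrative; assumes a CUDA machine with
+# roughly 10 GB of free disk space for the temporary comfy_mmap_*.pt files):
+#   MMAP_MEM_THRESHOLD_GB=16 python main.py                 # offload to mmap once free RAM drops below 16 GB
+#   pytest tests/execution/test_model_mmap.py -v -s         # run the tests above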