diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index de32cabbe6d0..05deee7bd547 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,7 +1,9 @@ import copy import dataclasses import operator +from contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from unittest.mock import patch import torch import torch.fx as fx @@ -503,17 +505,29 @@ def __call__(self, *args) -> Any: entry.input_addresses = input_addresses cudagraph = torch.cuda.CUDAGraph() - # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): - # `output` is managed by pytorch's cudagraph pool - output = entry.runnable(*args) - if self.is_last_graph: - # by converting it to weak ref, - # the original `output` will immediately be released - # to save memory. It is only safe to do this for - # the last graph, because the output of the last graph - # will not be used by any other cuda graph. - output = weak_ref_tensors(output) + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context( + patch("torch.cuda.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.cuda.graph(cudagraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) # here we always use weak ref for the output # to save memory