diff --git a/tests/compile/piecewise/piecewise_compilation_config.json b/tests/compile/piecewise/piecewise_compilation_config.json index 03d077b76f62..798a34e8dd92 100644 --- a/tests/compile/piecewise/piecewise_compilation_config.json +++ b/tests/compile/piecewise/piecewise_compilation_config.json @@ -1,4 +1,5 @@ { "use_cudagraph": true, - "non_cudagraph_ops": ["silly.attention"] + "non_cudagraph_ops": ["silly.attention"], + "cudagraph_copy_inputs": true } \ No newline at end of file diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index d151d62516b0..fcfe80d8e404 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -80,7 +80,7 @@ def test_simple_piecewise_compile(): config = os.path.join(directory, "piecewise_compilation_config.json") os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config - input_buffer = torch.randn(100).cuda() + inputs = torch.randn(100).cuda() with compilation_counter.expect( num_graphs_seen=1, # one graph for the model @@ -92,15 +92,15 @@ def test_simple_piecewise_compile(): ): with set_compile_context([1, 2]): - model(input_buffer) + model(inputs) - model(input_buffer[:2]) - model(input_buffer[:1]) + model(torch.randn(2).cuda()) + model(torch.randn(1).cuda()) - input_buffer[:2].zero_() + input = torch.zeros(2).cuda() global global_counter global_counter = 0 - output = model(input_buffer[:2]) + output = model(input) assert global_counter == 2 assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index c3c670422def..5682faa15806 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -389,6 +389,8 @@ class VllmBackend: returned_callable: Callable # Inductor passes to run on the graph pre-defunctionalization post_grad_passes: Sequence[Callable] + sym_tensor_indices: List[int] + input_buffers: List[torch.Tensor] def __init__(self, post_grad_passes: Sequence[Callable] = ()): global global_graph_pool @@ -401,6 +403,9 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()): self.graph_pool = global_graph_pool self.post_grad_passes = post_grad_passes + self.sym_tensor_indices = [] + self.input_buffers = [] + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -461,7 +466,46 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self._called = True - return self.split_gm + if not self.compilation_configs.use_cudagraph or \ + not self.compilation_configs.cudagraph_copy_inputs: + return self.split_gm + + # if we need to copy input buffers for cudagraph + from torch._guards import detect_fake_mode + fake_mode = detect_fake_mode() + fake_args = [ + fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in example_inputs + ] + + # index of tensors that have symbolic shapes (batch size) + self.sym_tensor_indices = [ + i for i, x in enumerate(fake_args) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) + ] + + # compiler managed cudagraph input buffers + # we assume the first run with symbolic shapes + # has the maximum size among all the tensors + self.input_buffers = [ + example_inputs[x].clone() for x in self.sym_tensor_indices + ] + + def copy_and_call(*args): + list_args = list(args) + for i, index in enumerate(self.sym_tensor_indices): + runtime_tensor = list_args[index] + runtime_shape = runtime_tensor.shape[0] + static_tensor = self.input_buffers[i][:runtime_shape] + + # copy the tensor to the static buffer + static_tensor.copy_(runtime_tensor) + + # replace the tensor in the list_args to the static buffer + list_args[index] = static_tensor + return self.split_gm(*list_args) + + return copy_and_call @dataclasses.dataclass diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py index 72377533140b..3e663505c627 100644 --- a/vllm/compilation/config.py +++ b/vllm/compilation/config.py @@ -32,6 +32,11 @@ class CompilationConfig(BaseModel): It means the first several runs will be treated as warmup runs. Only after that, the execution will be recorded, and the recorded cudagraph will be used for subsequent runs. + - cudagraph_copy_inputs: whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False. - Inductor compilation: - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. @@ -78,6 +83,7 @@ class CompilationConfig(BaseModel): non_cudagraph_ops: List[str] = Field(default_factory=list) cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[List[int]] = None + cudagraph_copy_inputs: bool = False dump_graph_stages: List[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path("."))