From 00ad59700ed09d5678ff3909dd94131b65b41454 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 10 Nov 2024 13:01:52 -0800 Subject: [PATCH 1/6] manage buffer in compiler Signed-off-by: youkaichao --- vllm/compilation/backends.py | 39 +++++++++++++++++++++++++++++++++++- vllm/compilation/config.py | 1 + 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index c3c670422def..a1cdaea3fcf8 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -461,7 +461,44 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self._called = True - return self.split_gm + if not self.compilation_configs.use_cudagraph or \ + not self.compilation_configs.cudagraph_copy_input_buffers: + return self.split_gm + + # if we need to copy input buffers for cudagraph + from torch._guards import detect_fake_mode + fake_mode = detect_fake_mode() + fake_args = [ + fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in example_inputs + ] + + # index of tensors that have symbolic shapes (batch size) + sym_tensor_indices = [ + i for i, x in enumerate(fake_args) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) + ] + + # keep reference to the tensors that have symbolic shapes + # they have the maximum size among all the tensors + # and we use them as static buffers for cudagraph + tensor_buffers = [example_inputs[x] for x in sym_tensor_indices] + + def copy_and_call(*args): + list_args = list(args) + for i, index in enumerate(sym_tensor_indices): + runtime_tensor = list_args[index] + runtime_shape = runtime_tensor.shape[0] + static_tensor = tensor_buffers[i][:runtime_shape] + + # copy the tensor to the static buffer + static_tensor.copy_(runtime_tensor) + + # replace the tensor in the list_args to the static buffer + list_args[index] = static_tensor + return self.split_gm(*list_args) + + return copy_and_call @dataclasses.dataclass diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py index 72377533140b..07262b5d44e0 100644 --- a/vllm/compilation/config.py +++ b/vllm/compilation/config.py @@ -78,6 +78,7 @@ class CompilationConfig(BaseModel): non_cudagraph_ops: List[str] = Field(default_factory=list) cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[List[int]] = None + cudagraph_copy_input_buffers: bool = True dump_graph_stages: List[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) From 16fb5af8d52b5acea404f23168f2a0c58b97cfb6 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 10 Nov 2024 13:26:56 -0800 Subject: [PATCH 2/6] change tests Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index d151d62516b0..a40ef4c57c3e 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -94,13 +94,14 @@ def test_simple_piecewise_compile(): with set_compile_context([1, 2]): model(input_buffer) - model(input_buffer[:2]) - model(input_buffer[:1]) + model(torch.randn(2).cuda()) + model(torch.randn(1).cuda()) - input_buffer[:2].zero_() + input = torch.randn(2).cuda() + input.zero_() global global_counter global_counter = 0 - output = model(input_buffer[:2]) + output = model(input) assert global_counter == 2 assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) From c95eea2a7955d89671e9e005cec47e9f3f1f04ad Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 10 Nov 2024 13:34:58 -0800 Subject: [PATCH 3/6] clone and manage buffers Signed-off-by: youkaichao --- vllm/compilation/backends.py | 23 +++++++++++++++-------- vllm/compilation/config.py | 7 ++++++- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a1cdaea3fcf8..5682faa15806 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -389,6 +389,8 @@ class VllmBackend: returned_callable: Callable # Inductor passes to run on the graph pre-defunctionalization post_grad_passes: Sequence[Callable] + sym_tensor_indices: List[int] + input_buffers: List[torch.Tensor] def __init__(self, post_grad_passes: Sequence[Callable] = ()): global global_graph_pool @@ -401,6 +403,9 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()): self.graph_pool = global_graph_pool self.post_grad_passes = post_grad_passes + self.sym_tensor_indices = [] + self.input_buffers = [] + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -462,7 +467,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self._called = True if not self.compilation_configs.use_cudagraph or \ - not self.compilation_configs.cudagraph_copy_input_buffers: + not self.compilation_configs.cudagraph_copy_inputs: return self.split_gm # if we need to copy input buffers for cudagraph @@ -474,22 +479,24 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: ] # index of tensors that have symbolic shapes (batch size) - sym_tensor_indices = [ + self.sym_tensor_indices = [ i for i, x in enumerate(fake_args) if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) ] - # keep reference to the tensors that have symbolic shapes - # they have the maximum size among all the tensors - # and we use them as static buffers for cudagraph - tensor_buffers = [example_inputs[x] for x in sym_tensor_indices] + # compiler managed cudagraph input buffers + # we assume the first run with symbolic shapes + # has the maximum size among all the tensors + self.input_buffers = [ + example_inputs[x].clone() for x in self.sym_tensor_indices + ] def copy_and_call(*args): list_args = list(args) - for i, index in enumerate(sym_tensor_indices): + for i, index in enumerate(self.sym_tensor_indices): runtime_tensor = list_args[index] runtime_shape = runtime_tensor.shape[0] - static_tensor = tensor_buffers[i][:runtime_shape] + static_tensor = self.input_buffers[i][:runtime_shape] # copy the tensor to the static buffer static_tensor.copy_(runtime_tensor) diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py index 07262b5d44e0..55e2c933dee3 100644 --- a/vllm/compilation/config.py +++ b/vllm/compilation/config.py @@ -32,6 +32,11 @@ class CompilationConfig(BaseModel): It means the first several runs will be treated as warmup runs. Only after that, the execution will be recorded, and the recorded cudagraph will be used for subsequent runs. + - cudagraph_copy_inputs: whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. - Inductor compilation: - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. @@ -78,7 +83,7 @@ class CompilationConfig(BaseModel): non_cudagraph_ops: List[str] = Field(default_factory=list) cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[List[int]] = None - cudagraph_copy_input_buffers: bool = True + cudagraph_copy_inputs: bool = True dump_graph_stages: List[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) From a60006a09f591a4848dc86b55b531824af0bf6f5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 10 Nov 2024 17:56:15 -0800 Subject: [PATCH 4/6] Update tests/compile/piecewise/test_simple.py Co-authored-by: Woosuk Kwon --- tests/compile/piecewise/test_simple.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index a40ef4c57c3e..286be82acaa2 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -97,8 +97,7 @@ def test_simple_piecewise_compile(): model(torch.randn(2).cuda()) model(torch.randn(1).cuda()) - input = torch.randn(2).cuda() - input.zero_() + input = torch.zeros(2).cuda() global global_counter global_counter = 0 output = model(input) From 470bff5b4ffd8e297fe5afbc217221b5ae80656c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 10 Nov 2024 17:57:26 -0800 Subject: [PATCH 5/6] rename Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 286be82acaa2..fcfe80d8e404 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -80,7 +80,7 @@ def test_simple_piecewise_compile(): config = os.path.join(directory, "piecewise_compilation_config.json") os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config - input_buffer = torch.randn(100).cuda() + inputs = torch.randn(100).cuda() with compilation_counter.expect( num_graphs_seen=1, # one graph for the model @@ -92,7 +92,7 @@ def test_simple_piecewise_compile(): ): with set_compile_context([1, 2]): - model(input_buffer) + model(inputs) model(torch.randn(2).cuda()) model(torch.randn(1).cuda()) From 232aa5aa3e217231a5b2f6b2587cc302b7ef61ef Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 11 Nov 2024 10:25:51 -0800 Subject: [PATCH 6/6] default to false Signed-off-by: youkaichao --- tests/compile/piecewise/piecewise_compilation_config.json | 3 ++- vllm/compilation/config.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/compile/piecewise/piecewise_compilation_config.json b/tests/compile/piecewise/piecewise_compilation_config.json index 03d077b76f62..798a34e8dd92 100644 --- a/tests/compile/piecewise/piecewise_compilation_config.json +++ b/tests/compile/piecewise/piecewise_compilation_config.json @@ -1,4 +1,5 @@ { "use_cudagraph": true, - "non_cudagraph_ops": ["silly.attention"] + "non_cudagraph_ops": ["silly.attention"], + "cudagraph_copy_inputs": true } \ No newline at end of file diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py index 55e2c933dee3..3e663505c627 100644 --- a/vllm/compilation/config.py +++ b/vllm/compilation/config.py @@ -36,7 +36,7 @@ class CompilationConfig(BaseModel): cudagraph. If the caller can guarantee that the same input buffers are always used, it can set this to False. Otherwise, it should set this to True, and the compiler will copy the input to an - internally managed buffer. + internally managed buffer. Default is False. - Inductor compilation: - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. @@ -83,7 +83,7 @@ class CompilationConfig(BaseModel): non_cudagraph_ops: List[str] = Field(default_factory=list) cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[List[int]] = None - cudagraph_copy_inputs: bool = True + cudagraph_copy_inputs: bool = False dump_graph_stages: List[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path("."))