|
2 | 2 |
|
3 | 3 | import tensorrt as trt
|
4 | 4 | import torch
|
| 5 | +import torch.distributed as dist |
5 | 6 | import torch.nn as nn
|
6 | 7 | import torch_tensorrt
|
7 | 8 | from distributed_utils import initialize_distributed_env
|
|
16 | 17 | "./tensor_parallel_simple_example"
|
17 | 18 | )
|
18 | 19 |
|
19 |
| - |
20 |
| -def compile_tp_model(tp_model, backend): |
21 |
| - compile_options = { |
22 |
| - "truncate_long_and_double": True, |
23 |
| - "enabled_precisions": {torch.float32, torch.float16}, |
24 |
| - "use_python_runtime": True, |
25 |
| - "min_block_size": 1, |
26 |
| - } |
27 |
| - |
28 |
| - try: |
29 |
| - return torch.compile( |
30 |
| - tp_model, backend=backend, options=compile_options, dynamic=None |
31 |
| - ) |
32 |
| - except RuntimeError as e: |
33 |
| - if ( |
34 |
| - "aot_export is not currently supported with traceable tensor subclass" |
35 |
| - in str(e) |
36 |
| - ): |
37 |
| - logger.warning( |
38 |
| - "It is recommended to run the model with use_distributed_mode_trace=True. Running with that option" |
39 |
| - ) |
40 |
| - compile_options["use_distributed_mode_trace"] = True |
41 |
| - return torch.compile( |
42 |
| - tp_model, backend=backend, options=compile_options, dynamic=None |
43 |
| - ) |
44 |
| - else: |
45 |
| - logger.debug("The distributed model fails with the following error") |
46 |
| - raise |
47 |
| - |
48 |
| - |
49 | 20 | """
|
50 | 21 | This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
|
51 | 22 | """
|
@@ -90,20 +61,37 @@ def forward(self, x):
|
90 | 61 | inp = torch.rand(20, 10, device="cuda")
|
91 | 62 | python_result = tp_model(inp)
|
92 | 63 |
|
93 |
# Compile the tensor-parallel model through the Torch-TensorRT backend.
# NOTE(review): ``use_distributed_mode_trace=True`` is set up front because
# tracing a model that holds DTensor parameters otherwise fails with
# "aot_export is not currently supported with traceable tensor subclass".
backend = "torch_tensorrt"
compile_options = {
    "truncate_long_and_double": True,
    "enabled_precisions": {torch.float32, torch.float16},
    "use_python_runtime": True,
    "min_block_size": 1,
    "use_distributed_mode_trace": True,
}
tp_model = torch.compile(
    tp_model,
    backend=backend,
    options=compile_options,
    dynamic=None,
)
94 | 77 |
|
95 |
try:
    for step in range(10):
        # TP requires the same input on every rank; reseeding per iteration
        # mimics a dataloader handing each rank an identical batch.
        torch.manual_seed(step)
        inp = torch.rand(20, 10, device="cuda")
        start = time.time()
        output = tp_model(inp)
        end = time.time()
        if step == 0:
            # The first call triggers compilation; sanity-check the compiled
            # output against the eager-mode result captured earlier.
            logger.info(f"Compilation time is {end-start}")
            assert (
                python_result - output
            ).std() < 0.01, "Compilation result is not correct."
        elif _rank == 0:
            logger.info(f"Inference time is {end-start}")
finally:
    # Always tear down the distributed process group, even if the
    # benchmark loop raised, so ranks do not hang on exit.
    if dist.is_initialized():
        dist.destroy_process_group()
0 commit comments