
Commit 771e9d2

Committed Mar 21, 2025
Option 1 further modification - recommended set of options in tests/examples, including the warning in the backend and destroying processes while exiting
1 parent f43d17e commit 771e9d2
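The "recommended set of options" mentioned in the commit message amounts to a single torch.compile call with use_distributed_mode_trace enabled. A minimal sketch distilled from the diffs below (tp_model stands for the already-sharded tensor-parallel module built earlier in the example):

import torch

compiled_tp_model = torch.compile(
    tp_model,  # assumed: an already-sharded tensor-parallel nn.Module
    backend="torch_tensorrt",
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float32, torch.float16},
        "use_python_runtime": True,
        "min_block_size": 1,
        # Recommended whenever the inputs are distributed (DTensor) tensors.
        "use_distributed_mode_trace": True,
    },
    dynamic=None,
)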

File tree

3 files changed: +73 -91 lines changed


‎examples/distributed_inference/tensor_parallel_simple_example.py

+34 -45
@@ -2,6 +2,7 @@
 
 import tensorrt as trt
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch_tensorrt
 from tensor_parallel_initialize_dist import initialize_distributed_env
@@ -21,35 +22,6 @@
 """
 
 
-def compile_tp_model(tp_model, backend):
-    compile_options = {
-        "truncate_long_and_double": True,
-        "enabled_precisions": {torch.float32, torch.float16},
-        "use_python_runtime": True,
-        "min_block_size": 1,
-    }
-
-    try:
-        return torch.compile(
-            tp_model, backend=backend, options=compile_options, dynamic=None
-        )
-    except RuntimeError as e:
-        if (
-            "aot_export is not currently supported with traceable tensor subclass"
-            in str(e)
-        ):
-            logger.warning(
-                "It is recommended to run the model with use_distributed_mode_trace=True. Running with that option"
-            )
-            compile_options["use_distributed_mode_trace"] = True
-            return torch.compile(
-                tp_model, backend=backend, options=compile_options, dynamic=None
-            )
-        else:
-            logger.debug("The distributed model fails with the following error")
-            raise
-
-
 class ToyModel(nn.Module):
     """MLP based model"""
 
@@ -93,20 +65,37 @@ def forward(self, x):
 inp = torch.rand(20, 10, device="cuda")
 python_result = tp_model(inp)
 
-compile_tp_model(tp_model, backend="torch_tensorrt")
+backend = "torch_tensorrt"
+tp_model = torch.compile(
+    tp_model,
+    backend=backend,
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float32, torch.float16},
+        "use_python_runtime": True,
+        "min_block_size": 1,
+        "use_distributed_mode_trace": True,
+    },
+    dynamic=None,
+)
 
-for i in range(10):
-    # For TP, input needs to be same across all TP ranks.
-    # Setting the random seed is to mimic the behavior of dataloader.
-    torch.manual_seed(i)
-    inp = torch.rand(20, 10, device="cuda")
-    start = time.time()
-    output = tp_model(inp)
-    end = time.time()
-    if i == 0:
-        logger.info(f"Compilation time is {end-start}")
-        assert (
-            python_result - output
-        ).std() < 0.01, "Compilation result is not correct."
-    elif _rank == 0:
-        logger.info(f"Inference time is {end-start}")
+try:
+    for i in range(10):
+        # For TP, input needs to be same across all TP ranks.
+        # Setting the random seed is to mimic the behavior of dataloader.
+        torch.manual_seed(i)
+        inp = torch.rand(20, 10, device="cuda")
+        start = time.time()
+        output = tp_model(inp)
+        end = time.time()
+        if i == 0:
+            logger.info(f"Compilation time is {end-start}")
+            assert (
+                python_result - output
+            ).std() < 0.01, "Compilation result is not correct."
+        elif _rank == 0:
+            logger.info(f"Inference time is {end-start}")
+finally:
+    # This cleans up the distributed process group
+    if dist.is_initialized():
+        dist.destroy_process_group()
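The "destroying processes while exiting" part of the commit is the try/finally wrapper above. As a standalone teardown idiom (assuming the process group was initialized earlier, e.g. by initialize_distributed_env; run_inference is a placeholder for the compile-and-infer loop):

import torch.distributed as dist

try:
    run_inference()  # placeholder for the compilation/inference loop shown above
finally:
    # Destroy the process group even if compilation or inference raised,
    # so worker processes can exit cleanly instead of hanging at shutdown.
    if dist.is_initialized():
        dist.destroy_process_group()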

‎py/torch_tensorrt/dynamo/backend/backends.py

+5
@@ -10,6 +10,7 @@
 from torch._dynamo.backends.common import aot_autograd
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
+from torch.distributed.tensor import DTensor
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._compiler import compile_module
 from torch_tensorrt.dynamo.lowering import (
@@ -79,6 +80,10 @@ def aot_torch_tensorrt_aten_backend(
             fw_compiler=_pretraced_backend_autograd,
             decompositions=settings_aot_autograd["decompositions"],
         )(gm, sample_inputs)
+    if any(isinstance(tensor, DTensor) for tensor in sample_inputs):
+        logger.warning(
+            "It is recommended to run the model with use_distributed_mode_trace = True since there are distributed tensors in the input which is not supported aot_export_joint_simple"
+        )
     return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
 
 
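For context, the new guard in the backend only inspects the sample inputs for DTensors before falling through to the aot_export_joint_simple path, and emits a warning if any are found. A self-contained sketch of that check (the helper name and logger setup here are illustrative, not part of the actual module):

import logging
from typing import Any, Sequence

from torch.distributed.tensor import DTensor

logger = logging.getLogger(__name__)


def has_dtensor_inputs(sample_inputs: Sequence[Any]) -> bool:
    # True if any input is a DTensor, i.e. a sharded/distributed tensor.
    return any(isinstance(tensor, DTensor) for tensor in sample_inputs)


# Usage inside a backend entry point:
# if has_dtensor_inputs(sample_inputs):
#     logger.warning(
#         "It is recommended to run the model with use_distributed_mode_trace=True "
#         "since the inputs contain distributed tensors."
#     )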

‎tests/py/dynamo/distributed/test_distributed_simple_example.py

+34 -46
@@ -2,6 +2,7 @@
 
 import tensorrt as trt
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch_tensorrt
 from distributed_utils import initialize_distributed_env
@@ -16,36 +17,6 @@
     "./tensor_parallel_simple_example"
 )
 
-
-def compile_tp_model(tp_model, backend):
-    compile_options = {
-        "truncate_long_and_double": True,
-        "enabled_precisions": {torch.float32, torch.float16},
-        "use_python_runtime": True,
-        "min_block_size": 1,
-    }
-
-    try:
-        return torch.compile(
-            tp_model, backend=backend, options=compile_options, dynamic=None
-        )
-    except RuntimeError as e:
-        if (
-            "aot_export is not currently supported with traceable tensor subclass"
-            in str(e)
-        ):
-            logger.warning(
-                "It is recommended to run the model with use_distributed_mode_trace=True. Running with that option"
-            )
-            compile_options["use_distributed_mode_trace"] = True
-            return torch.compile(
-                tp_model, backend=backend, options=compile_options, dynamic=None
-            )
-        else:
-            logger.debug("The distributed model fails with the following error")
-            raise
-
-
 """
 This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
 """
@@ -90,20 +61,37 @@ def forward(self, x):
 inp = torch.rand(20, 10, device="cuda")
 python_result = tp_model(inp)
 
-compile_tp_model(tp_model, backend="torch_tensorrt")
+backend = "torch_tensorrt"
+tp_model = torch.compile(
+    tp_model,
+    backend=backend,
+    options={
+        "truncate_long_and_double": True,
+        "enabled_precisions": {torch.float32, torch.float16},
+        "use_python_runtime": True,
+        "min_block_size": 1,
+        "use_distributed_mode_trace": True,
+    },
+    dynamic=None,
+)
 
-for i in range(10):
-    # For TP, input needs to be same across all TP ranks.
-    # Setting the random seed is to mimic the behavior of dataloader.
-    torch.manual_seed(i)
-    inp = torch.rand(20, 10, device="cuda")
-    start = time.time()
-    output = tp_model(inp)
-    end = time.time()
-    if i == 0:
-        logger.info(f"Compilation time is {end-start}")
-        assert (
-            python_result - output
-        ).std() < 0.01, "Compilation result is not correct."
-    elif _rank == 0:
-        logger.info(f"Inference time is {end-start}")
+try:
+    for i in range(10):
+        # For TP, input needs to be same across all TP ranks.
+        # Setting the random seed is to mimic the behavior of dataloader.
+        torch.manual_seed(i)
+        inp = torch.rand(20, 10, device="cuda")
+        start = time.time()
+        output = tp_model(inp)
+        end = time.time()
+        if i == 0:
+            logger.info(f"Compilation time is {end-start}")
+            assert (
+                python_result - output
+            ).std() < 0.01, "Compilation result is not correct."
+        elif _rank == 0:
+            logger.info(f"Inference time is {end-start}")
+finally:
+    # This cleans up the distributed process group
+    if dist.is_initialized():
+        dist.destroy_process_group()

0 commit comments
