Skip to content

Commit 091c83f

Browse files
committedFeb 28, 2025
adding the test script and correction to the backend
1 parent 31cf035 commit 091c83f

File tree

5 files changed

+306
-12
lines changed

5 files changed

+306
-12
lines changed
 

‎py/torch_tensorrt/dynamo/backend/backends.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,12 @@ def aot_torch_tensorrt_aten_backend(
6969
to_delete = {
7070
key
7171
for key in settings_aot_autograd["decompositions"]
72-
if "transpose" in key._name
72+
if "transpose" in key._name or "detach" in key._name
7373
}
7474

7575
for key in to_delete:
7676
del settings_aot_autograd["decompositions"][key]
7777

78-
remove_detach(gm, settings)
7978
return aot_autograd(
8079
fw_compiler=_pretraced_backend_autograd,
8180
decompositions=settings_aot_autograd["decompositions"],
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import logging
2+
import os
3+
4+
import numpy as np
5+
import tensorrt as trt
6+
import torch
7+
import torch.distributed as dist
8+
from torch.distributed._tensor.device_mesh import init_device_mesh
9+
10+
11+
def set_environment_variables_pytest():
12+
os.environ["WORLD_SIZE"] = str(1)
13+
os.environ["RANK"] = str(0)
14+
os.environ["MASTER_ADDR"] = "127.0.0.1"
15+
os.environ["MASTER_PORT"] = str(29500)
16+
os.environ["USE_TRTLLM_PLUGINS"] = "1"
17+
18+
19+
def find_repo_root(max_depth=10):
20+
dir_path = os.path.dirname(os.path.realpath(__file__))
21+
for i in range(max_depth):
22+
files = os.listdir(dir_path)
23+
if "MODULE.bazel" in files:
24+
return dir_path
25+
else:
26+
dir_path = os.path.dirname(dir_path)
27+
28+
raise RuntimeError("Could not find repo root")
29+
30+
31+
def initialize_logger(rank, logger_file_name):
32+
logger = logging.getLogger()
33+
logger.setLevel(logging.INFO)
34+
fh = logging.FileHandler(logger_file_name + f"_{rank}.log", mode="w")
35+
fh.setLevel(logging.INFO)
36+
logger.addHandler(fh)
37+
return logger
38+
39+
40+
# This is required for env initialization since we use mpirun
41+
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
42+
local_rank = int(
43+
os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", rank % torch.cuda.device_count())
44+
)
45+
world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))
46+
47+
# Set up environment variable to run with mpirun
48+
os.environ["RANK"] = str(local_rank)
49+
os.environ["WORLD_SIZE"] = str(world_size)
50+
os.environ["MASTER_ADDR"] = "127.0.0.1"
51+
os.environ["MASTER_PORT"] = str(port)
52+
os.environ["TRTLLM_PLUGINS_PATH"] = (
53+
find_repo_root() + "/lib/libnvinfer_plugin_tensorrt_llm.so"
54+
)
55+
56+
# Necessary to assign a device to each rank.
57+
torch.cuda.set_device(local_rank)
58+
59+
# We use nccl backend
60+
dist.init_process_group("nccl")
61+
62+
# set a manual seed for reproducibility
63+
torch.manual_seed(1111)
64+
65+
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
66+
rank = device_mesh.get_rank()
67+
assert rank == local_rank
68+
logger = initialize_logger(rank, logger_file_name)
69+
device_id = (
70+
rank % torch.cuda.device_count()
71+
) # Ensure each rank gets a unique device
72+
torch.cuda.set_device(device_id)
73+
74+
return device_mesh, world_size, rank, logger
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import time
2+
3+
import tensorrt as trt
4+
import torch
5+
import torch.nn as nn
6+
import torch_tensorrt
7+
from distributed_utils import initialize_distributed_env
8+
from torch.distributed._tensor import Shard
9+
from torch.distributed.tensor.parallel import (
10+
ColwiseParallel,
11+
RowwiseParallel,
12+
parallelize_module,
13+
)
14+
15+
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
16+
"./tensor_parallel_simple_example"
17+
)
18+
19+
"""
20+
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
21+
"""
22+
23+
24+
class ToyModel(nn.Module):
25+
"""MLP based model"""
26+
27+
def __init__(self):
28+
super(ToyModel, self).__init__()
29+
self.in_proj = nn.Linear(10, 3200)
30+
self.relu = nn.ReLU()
31+
self.out_proj = nn.Linear(3200, 1600)
32+
self.in_proj2 = nn.Linear(1600, 500)
33+
self.out_proj2 = nn.Linear(500, 100)
34+
35+
def forward(self, x):
36+
x = self.out_proj(self.relu(self.in_proj(x)))
37+
x = self.relu(x)
38+
x = self.out_proj2(self.relu(self.in_proj2(x)))
39+
return x
40+
41+
42+
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
43+
44+
# # create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
45+
tp_model = ToyModel().to("cuda")
46+
47+
48+
# Custom parallelization plan for the model
49+
tp_model = parallelize_module(
50+
module=tp_model,
51+
device_mesh=device_mesh,
52+
parallelize_plan={
53+
"in_proj": ColwiseParallel(input_layouts=Shard(0)),
54+
"out_proj": RowwiseParallel(output_layouts=Shard(0)),
55+
"in_proj2": ColwiseParallel(input_layouts=Shard(0)),
56+
"out_proj2": RowwiseParallel(output_layouts=Shard(0)),
57+
},
58+
)
59+
torch.manual_seed(0)
60+
inp = torch.rand(20, 10, device="cuda")
61+
python_result = tp_model(inp)
62+
63+
64+
backend = "torch_tensorrt"
65+
tp_model = torch.compile(
66+
tp_model,
67+
backend=backend,
68+
options={
69+
"truncate_long_and_double": True,
70+
"enabled_precisions": {torch.float32, torch.float16},
71+
"use_python_runtime": True,
72+
"min_block_size": 1,
73+
"use_aot_joint_export": False,
74+
},
75+
dynamic=False,
76+
)
77+
78+
for i in range(10):
79+
# For TP, input needs to be same across all TP ranks.
80+
# Setting the random seed is to mimic the behavior of dataloader.
81+
torch.manual_seed(i)
82+
inp = torch.rand(20, 10, device="cuda")
83+
start = time.time()
84+
output = tp_model(inp)
85+
end = time.time()
86+
if i == 0:
87+
logger.info(f"Compilation time is {end-start}")
88+
assert (
89+
python_result - output
90+
).std() < 0.01, "Compilation result is not correct."
91+
elif _rank == 0:
92+
logger.info(f"Inference time is {end-start}")

‎tests/py/dynamo/distributed/test_nccl_ops.py

+2-10
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,11 @@
33
import torch
44
import torch.distributed as dist
55
import torch.nn as nn
6+
from distributed_utils import set_environment_variables_pytest
67
from parameterized import parameterized
78
from torch.testing._internal.common_utils import run_tests
89

9-
10-
def set_environment_variables():
11-
os.environ["WORLD_SIZE"] = str(1)
12-
os.environ["RANK"] = str(0)
13-
os.environ["MASTER_ADDR"] = "127.0.0.1"
14-
os.environ["MASTER_PORT"] = str(29500)
15-
os.environ["USE_TRTLLM_PLUGINS"] = "1"
16-
17-
18-
set_environment_variables()
10+
set_environment_variables_pytest()
1911
dist.init_process_group(backend="nccl", init_method="env://")
2012
group = dist.new_group(ranks=[0])
2113
group_name = group.group_name
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
#!/bin/bash
2+
3+
check_command() {
4+
command -v "$1" >/dev/null 2>&1
5+
}
6+
7+
ensure_installed() {
8+
local pkg="$1"
9+
if ! check_command "$pkg"; then
10+
echo "$pkg is not installed. Installing $pkg..."
11+
12+
# Determine if sudo is needed
13+
if check_command sudo; then
14+
SUDO="sudo"
15+
else
16+
SUDO=""
17+
fi
18+
19+
# Detect OS and install accordingly
20+
OS="$(uname -s)"
21+
if [[ "$OS" == "Linux" ]]; then
22+
if check_command apt-get; then
23+
$SUDO apt-get update && $SUDO apt-get install -y "$pkg"
24+
fi
25+
else
26+
echo "Unsupported OS: $OS. Please install $pkg manually."
27+
exit 1
28+
fi
29+
else
30+
echo "$pkg is already installed."
31+
fi
32+
}
33+
34+
ensure_mpi_installed() {
35+
local pkg="$1"
36+
if dpkg -l | grep -q "$pkg"; then
37+
echo "$pkg is already installed."
38+
else
39+
echo "$pkg is not installed. Installing $pkg..."
40+
41+
# Determine if sudo is needed
42+
if check_command sudo; then
43+
SUDO="sudo"
44+
else
45+
SUDO=""
46+
fi
47+
48+
# Detect OS and install accordingly
49+
OS="$(uname -s)"
50+
if [[ "$OS" == "Linux" ]]; then
51+
if check_command apt-get; then
52+
$SUDO apt-get update && $SUDO apt-get install -y "$pkg"
53+
fi
54+
else
55+
echo "Unsupported OS: $OS. Please install $pkg manually."
56+
exit 1
57+
fi
58+
fi
59+
}
60+
61+
ensure_pytest_installed(){
62+
if check_command pip; then
63+
echo "pip is installed, installing pytest..."
64+
pip install pytest
65+
else
66+
echo "pip is not installed. Please install pip first."
67+
exit 1
68+
fi
69+
}
70+
71+
echo "Setting up the environment"
72+
73+
OS="$(uname -s)"
74+
ARCH="$(uname -m)"
75+
PYTHON_VERSION="$(python3 -c 'import sys; print(f"cp{sys.version_info.major}{sys.version_info.minor}")')"
76+
77+
78+
#getting the file name for TensorRT-LLM download
79+
if [[ "$OS" == "Linux" && "$ARCH" == "x86_64" && "$PYTHON_VERSION" == "cp312" ]]; then
80+
FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_x86_64.whl"
81+
elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64" && "$PYTHON_VERSION" == "cp312" ]]; then
82+
FILE="tensorrt_llm-0.17.0.post1-cp312-cp312-linux_aarch64.whl"
83+
elif [[ "$OS" == "Linux" && "$ARCH" == "x86_64" && "$PYTHON_VERSION" == "cp310" ]]; then
84+
FILE="tensorrt_llm-0.17.0.post1-cp310-cp310-linux_x86_64.whl"
85+
elif [[ "$OS" == "Linux" && "$ARCH" == "aarch64" && "$PYTHON_VERSION" == "cp310" ]]; then
86+
FILE="tensorrt_llm-0.17.0.post1-cp310-cp310-linux_aarch64.whl"
87+
else:
88+
echo "Unsupported platform: OS=$OS ARCH=$ARCH PYTHON=$PYTHON_VERSION"
89+
exit 1
90+
fi
91+
92+
# Download the selected file
93+
URL="https://pypi.nvidia.com/tensorrt-llm/$FILE"
94+
echo "Downloading $FILE from $URL..."
95+
96+
echo "Downloading ...."
97+
#Installing wget
98+
ensure_installed wget
99+
#Downloading the package
100+
wget "$URL"
101+
echo "Download complete: $FILE"
102+
103+
UNZIP_DIR="tensorrt_llm_unzip"
104+
if [[ ! -d "$UNZIP_DIR" ]]; then
105+
echo "Creating directory: $UNZIP_DIR"
106+
mkdir -p "$UNZIP_DIR"
107+
echo "extracting $FILE to $UNZIP_DIR ..."
108+
#Installing unzip
109+
ensure_installed unzip
110+
#unzip the TensorRT-LLM package
111+
unzip -q "$FILE" -d "$UNZIP_DIR"
112+
echo "Unzip complete"
113+
fi
114+
115+
116+
export TRTLLM_PLUGINS_PATH="$(pwd)/${UNZIP_DIR}/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
117+
echo ${TRTLLM_PLUGINS_PATH}
118+
119+
ensure_mpi_installed libmpich-dev
120+
ensure_mpi_installed libopenmpi-dev
121+
122+
run_tests() {
123+
cd ..
124+
export PYTHONPATH=$(pwd) # Set PYTHONPATH to the current directory
125+
echo "Running pytest on distributed/test_nccl_ops.py..."
126+
pytest distributed/test_nccl_ops.py
127+
}
128+
129+
run_mpi_tests(){
130+
cd distributed
131+
echo "Running test_distributed_simple_example with mpirun..."---
132+
mpirun -n 1 --allow-run-as-root python test_distributed_simple_example.py
133+
}
134+
135+
ensure_pytest_installed
136+
run_tests
137+
run_mpi_tests

0 commit comments

Comments
 (0)
Failed to load comments.