feat: Implement FP8 functionality #2763

Merged
merged 147 commits into from
May 30, 2024
Changes from 121 commits
Commits
9ad87ac
chore: Upgrade to TRT 10.0
peri044 Mar 12, 2024
a655c9a
chore: updates to trt api
peri044 Mar 12, 2024
cd86660
feat: Add save API for torch-trt compiled models
peri044 Mar 14, 2024
31285e5
feat: Add FP8 support including dtype and converters
peri044 Mar 5, 2024
7c9c646
chore: minor fixes
peri044 Mar 15, 2024
4eabeb0
Merge branch 'main' into trt_10
peri044 Mar 15, 2024
a320e56
Merge branch 'trt_10' into fp8_trt10
peri044 Mar 15, 2024
3ece71b
chore: resolve merge conflicts
peri044 Mar 15, 2024
eab0dba
chore: Fix save failures
peri044 Mar 18, 2024
b191d62
chore: update to 2.3 rc build
peri044 Mar 18, 2024
ce606fe
chore: rebase with release/2.3 branch
peri044 Mar 19, 2024
8674a3c
chore: minor fixes
peri044 Mar 19, 2024
f4e8fe9
chore: remove duplicate bert test case
peri044 Mar 20, 2024
4ae6ab9
chore: remove comments
peri044 Mar 20, 2024
fff1b80
chore: Upgrade to TRT 10.0
peri044 Mar 12, 2024
39ca77d
chore: more fixes
peri044 Mar 21, 2024
5431ee3
chore: update trt version
peri044 Mar 25, 2024
0c03de5
chore: more updates
peri044 Mar 26, 2024
982dbd2
parent f39e89e3964bc3d6ea3a6989b1e4099e1bb3e6dd
peri044 Mar 25, 2024
1ae46e9
chore: more updates
peri044 Mar 27, 2024
ae87fba
chore: rebase with save
peri044 Mar 27, 2024
beb5920
chore: Update versions
peri044 Mar 27, 2024
f0068c6
chore: update tensorrt version in CI
peri044 Mar 27, 2024
39261b9
chore: more updates
peri044 Mar 27, 2024
3753150
chore: more fixes
peri044 Apr 2, 2024
16a191c
Merge branch 'release/2.3' into trt_10
peri044 Apr 2, 2024
c355766
chore: remove NvUtils.h
peri044 Apr 2, 2024
2d237dc
chore: more updates
peri044 Apr 2, 2024
e4b4429
chore: change lib64 to lib in rhel BUILD file
peri044 Apr 2, 2024
fa4fb9c
chore: more updates
peri044 Apr 2, 2024
e11eb60
chore: fix TRT version
peri044 Apr 2, 2024
092feb2
chore: more updates
peri044 Apr 2, 2024
09ecf26
fix shape bug in bitwise ops
zewenli98 Apr 3, 2024
85e04c5
chore: update to rhel9
peri044 Apr 3, 2024
6a3664e
Merge branch 'trt_10' of github.com:pytorch/TensorRT into trt_10
peri044 Apr 3, 2024
41229d6
chore: change trt version
peri044 Apr 3, 2024
9d7a656
fix test bug and add more tests
zewenli98 Apr 3, 2024
5e911a9
chore: delete mirror of rules_pkg
peri044 Apr 3, 2024
dae0eb2
chore: fix conv test
peri044 Apr 3, 2024
2a32b13
Merge branch 'trt_10' of github.com:pytorch/TensorRT into trt_10
peri044 Apr 3, 2024
4676cd2
chore: fix trt version range
peri044 Apr 4, 2024
88efe8e
chore: fix trt rangfe
peri044 Apr 4, 2024
f9b40e6
chore: minor fix
peri044 Apr 4, 2024
b86aec2
chore: update rules_pkg
peri044 Apr 4, 2024
6630281
chore: minor fixes
peri044 Apr 4, 2024
fca55fe
chore: expt
peri044 Apr 4, 2024
1ca01e7
chore: update WORKSPACE tmpl
peri044 Apr 5, 2024
cdf5d07
chore: rebase with 2.3
peri044 Apr 5, 2024
6ffb85e
chore: fix
peri044 Apr 5, 2024
76af510
chore: remove cudnn dep
peri044 Apr 6, 2024
f9cf75a
chore: fix
peri044 Apr 6, 2024
33ba8b2
chore: updates
peri044 Apr 8, 2024
923377c
chore: update post-build script
peri044 Apr 9, 2024
89f04db
chore: remove trt dep
peri044 Apr 9, 2024
7620acc
chore: updates
peri044 Apr 9, 2024
62332fb
chore: set ld_library path in post script
peri044 Apr 9, 2024
96a8bf6
chore: updates
peri044 Apr 9, 2024
041f6a3
chore: updates
peri044 Apr 9, 2024
83e9a0b
chore: disable smoke test
peri044 Apr 9, 2024
e8529b0
chore: updates
peri044 Apr 9, 2024
1357112
chore: updates
peri044 Apr 9, 2024
608a6d2
chore: updates
peri044 Apr 10, 2024
1b34b32
chore: updates
peri044 Apr 10, 2024
89cb55a
chore: updates
peri044 Apr 10, 2024
4323e36
chore: updates
peri044 Apr 10, 2024
60b3e51
chore: update hw_compat
peri044 Apr 10, 2024
05627cd
chore: updates
peri044 Apr 12, 2024
d16585f
chore: update streams
peri044 Apr 12, 2024
16088e6
chore: updates
peri044 Apr 12, 2024
3d149ef
chore: updates
peri044 Apr 13, 2024
3addcae
chore: updates
peri044 Apr 13, 2024
b0e92d8
chore: update hw_compat.ts
peri044 Apr 15, 2024
d285d27
fix dynamic shape bugs for test_binary_ops_aten
zewenli98 Apr 15, 2024
d78a846
chore: revert layer_norm test
peri044 Apr 16, 2024
ba8a424
chore: rebase
peri044 Apr 16, 2024
097d887
Merge branch 'trt_10' of github.com:pytorch/TensorRT into trt_10
zewenli98 Apr 16, 2024
faaa0fa
chore: rebase with trt_10
peri044 Apr 17, 2024
68aab70
chore: updates
peri044 Apr 17, 2024
ffe7a52
chore: rebase with release/2.3
peri044 Apr 18, 2024
bac409a
chore: rebase
peri044 Apr 18, 2024
38642bb
chore: updates
peri044 Apr 18, 2024
c70c6dc
Merge branch 'trt_10' into fp8_trt10
peri044 Apr 18, 2024
ba286bd
chore: add fp8 test
peri044 Apr 18, 2024
d15dd72
chore: updates
peri044 Apr 19, 2024
dda88ee
Merge branch 'trt_10' into fp8_trt10
peri044 Apr 19, 2024
dee9aa0
chore: updates
peri044 Apr 19, 2024
fc6078b
Merge branch 'trt_10' into fp8_trt10
peri044 Apr 19, 2024
c05d675
chore: update stream in python runtime
peri044 Apr 19, 2024
2329657
chore: update hw_compat.ts
peri044 Apr 19, 2024
b8a8709
chore: updates
peri044 Apr 20, 2024
44778e1
chore: updates
peri044 Apr 20, 2024
0dbbcd7
chore: updates
peri044 Apr 20, 2024
55e4a1b
Merge branch 'trt_10' into fp8_trt10
peri044 Apr 22, 2024
89c3d76
chore: updates
peri044 Apr 22, 2024
bd70ef8
Merge branch 'trt_10' into fp8_trt10
peri044 Apr 22, 2024
3956749
chore: updates
peri044 Apr 23, 2024
0a2337b
chore: updates
peri044 Apr 23, 2024
358255d
chore: updates
peri044 Apr 23, 2024
dad5399
chore: rebase
peri044 Apr 23, 2024
e3e1d85
chore: rebase
peri044 May 7, 2024
7e717d6
chore: updates
peri044 May 14, 2024
c6d2f2a
chore: update to modelopt
peri044 May 14, 2024
ceec39d
chore: updates
peri044 May 14, 2024
a7e566b
chore: updates
peri044 May 14, 2024
707b10a
chore: updates
peri044 May 15, 2024
22066c5
chore: minor fix
peri044 May 15, 2024
6eed383
chore: fixes
peri044 May 15, 2024
ff231b5
chore: fixes
peri044 May 16, 2024
2f167c6
chore: updates
peri044 May 16, 2024
367eaf0
chore: updates
peri044 May 16, 2024
8cb6b91
chore: updates
peri044 May 16, 2024
4d38368
chore: updates
peri044 May 16, 2024
ee54da6
chore: updates
peri044 May 17, 2024
f4ccd62
chore: updates
peri044 May 17, 2024
681a6d1
chore: fixes
peri044 May 17, 2024
44071aa
chore: updates
peri044 May 17, 2024
3f6999d
chore: updates
peri044 May 17, 2024
5de9325
chore: updates
peri044 May 17, 2024
c677ef9
refactor vgg16 with fp8 and ptq example
zewenli98 May 21, 2024
f0b8d47
fix bugs
zewenli98 May 22, 2024
3ce9bed
chore: rebase
peri044 May 22, 2024
beb888d
chore: updates
peri044 May 23, 2024
e7989a0
chore: address review comments
peri044 May 23, 2024
96fd462
chore: updates
peri044 May 23, 2024
4030344
chore: updates
peri044 May 24, 2024
ad9d825
chore: updates
peri044 May 24, 2024
0059c1c
Update build-test-windows.yml
narendasan May 24, 2024
f98abd6
Update build-test-linux.yml
narendasan May 24, 2024
0d2021d
chore: updates
peri044 May 24, 2024
1940267
chore: updates
peri044 May 27, 2024
5814402
chore: disable all lower_linear tests
peri044 May 27, 2024
338a92b
chore: updates
peri044 May 27, 2024
59d0bd0
chore: fixes
peri044 May 27, 2024
020fe63
chore: updates
peri044 May 27, 2024
3f8297e
chore: updates
peri044 May 27, 2024
5ce0ee1
chore: updates
peri044 May 27, 2024
65c5c3e
chore: updates
peri044 May 28, 2024
d99989d
chore: updates
peri044 May 28, 2024
99dfbdc
chore: updates
peri044 May 28, 2024
ad996a5
chore: updates
peri044 May 28, 2024
6ada351
chore: updates
peri044 May 28, 2024
88fd7ee
chore: fixes
peri044 May 28, 2024
2511095
chore: updates
peri044 May 28, 2024
5346a45
chore: updates
peri044 May 29, 2024
d284b8f
chore: updates
peri044 May 29, 2024
c71c017
chore: updates
peri044 May 29, 2024
a983064
chore: updates
peri044 May 29, 2024
10 changes: 9 additions & 1 deletion .github/workflows/build-test.yml
@@ -66,6 +66,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-torchscript-fe
@@ -103,6 +104,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-dynamo-converters
@@ -131,6 +133,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-dynamo-fe
@@ -160,6 +163,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-dynamo-serde
@@ -188,6 +192,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-torch-compile-be
@@ -218,6 +223,7 @@ jobs:
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-dynamo-core
@@ -247,7 +253,9 @@ jobs:
- repository: pytorch/tensorrt
package-name: torch_tensorrt
pre-script: packaging/pre_build_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
post-script: packaging/post_build_script.sh
smoke-test-script: packaging/smoke_test_script.sh
uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
with:
job-name: tests-py-core
repository: "pytorch/tensorrt"
1 change: 1 addition & 0 deletions docsrc/index.rst
@@ -111,6 +111,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq

Python API Documentation
------------------------
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
@@ -10,3 +10,4 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 post-training quantization (PTQ) using ``torch_tensorrt.dynamo.compile``
251 changes: 251 additions & 0 deletions examples/dynamo/vgg16_fp8_ptq.py
@@ -0,0 +1,251 @@
"""
.. _vgg16_fp8_ptq:

Torch Compile VGG16 with FP8 and PTQ
======================================================

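This script is intended as a sample of the Torch-TensorRT workflow on a VGG16 model with FP8 post-training quantization (PTQ): the model is calibrated with `modelopt`, exported with `torch.export`, and compiled with `torch_tensorrt.dynamo.compile`.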
"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import argparse

import modelopt.torch.quantization as mtq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from modelopt.torch.quantization.utils import export_torch_mode


class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()

        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU(),
                ]
                in_channels = l

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def vgg16(num_classes=1000, init_weights=False):
    vgg16_cfg = [
        64,
        64,
        "pool",
        128,
        128,
        "pool",
        256,
        256,
        256,
        "pool",
        512,
        512,
        512,
        "pool",
        512,
        512,
        512,
        "pool",
    ]
    return VGG(vgg16_cfg, num_classes, init_weights)


PARSER = argparse.ArgumentParser(
    description="Load pre-trained VGG model and then tune with FP8 and PTQ"
)
PARSER.add_argument(
    "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
)
PARSER.add_argument(
    "--batch-size",
    default=128,
    type=int,
    help="Batch size for tuning the model with PTQ and FP8",
)

args = PARSER.parse_args()

model = vgg16(num_classes=10, init_weights=False)
model = model.cuda()

# %%
# Load the pre-trained model weights
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

ckpt = torch.load(args.ckpt)
weights = ckpt["model_state_dict"]

if torch.cuda.device_count() > 1:
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in weights.items():
        # Strip the `module.` prefix only when the key actually carries it
        name = k[7:] if k.startswith("module.") else k
        new_state_dict[name] = v
    weights = new_state_dict

model.load_state_dict(weights)
# Don't forget to set the model to evaluation mode!
model.eval()

# %%
# Load training dataset and define loss function for PTQ
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

training_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2
)

data = iter(training_dataloader)
images, _ = next(data)

crit = nn.CrossEntropyLoss()

# %%
# Define Calibration Loop for quantization
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


def calibrate_loop(model):
    # calibrate over the training dataset
    total = 0
    correct = 0
    loss = 0.0
    for data, labels in training_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        out = model(data)
        # .item() stores a Python float so autograd graphs are not kept alive across batches
        loss += crit(out, labels).item()
        preds = torch.max(out, 1)[1]
        total += labels.size(0)
        correct += (preds == labels).sum().item()

    print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))


# %%
# Tune the pre-trained model with FP8 and PTQ
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

quant_cfg = mtq.FP8_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has FP8 qdq nodes at this point

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Load the testing dataset
testing_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)

testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=2,
    # The engine below is compiled for one static batch size, so drop the final partial batch
    drop_last=True,
)

with torch.no_grad():
    with export_torch_mode():
        # Compile the model with Torch-TensorRT Dynamo backend
        input_tensor = images.cuda()
        exp_program = torch.export.export(model, (input_tensor,))
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions={torch.float8_e4m3fn},
            min_block_size=1,
            debug=False,
        )

        # Inference compiled Torch-TensorRT model over the testing dataset
        total = 0
        correct = 0
        loss = 0.0
        class_probs = []
        class_preds = []
        for data, labels in testing_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            # Run the compiled module, not the original PyTorch model
            out = trt_model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
        test_preds = torch.cat(class_preds)
        test_loss = loss / total
        test_acc = correct / total
        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
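Reviewer note on the example above: the comment "model has FP8 qdq nodes at this point" may be easier to follow with a numeric picture of what a quantize-dequantize (Q/DQ) pair does. The following is a minimal pure-Python sketch, not ModelOpt's or TensorRT's implementation. The helper names (`round_to_e4m3`, `calibrate_amax`, `fp8_qdq`) are hypothetical, max calibration is assumed as the calibration method, and saturating out-of-range values at ±448 is an assumption made for the sketch. Only the E4M3 layout itself (3 mantissa bits, smallest normal 2^-6, largest finite value 448) is taken as given.

```python
import math

E4M3_MAX = 448.0        # largest finite value of the E4M3 "fn" format
E4M3_MANTISSA_BITS = 3  # 3 explicit mantissa bits, so 8 steps per binade
E4M3_MIN_EXP = -6       # exponent of the smallest normal value (2**-6)


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3 value (saturating at +/-448 is an assumption)."""
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)
    # frexp gives a = m * 2**e with 0.5 <= m < 1, so floor(log2(a)) == e - 1.
    _, e = math.frexp(a)
    # Clamp the exponent so subnormal inputs all use the fixed 2**-9 step.
    exp = max(e - 1, E4M3_MIN_EXP)
    step = 2.0 ** (exp - E4M3_MANTISSA_BITS)
    # Python's round() rounds halves to even, i.e. round-to-nearest-even.
    return sign * round(a / step) * step


def calibrate_amax(batches):
    """Max calibration: the largest absolute value seen across all batches."""
    return max(abs(v) for batch in batches for v in batch)


def fp8_qdq(values, amax):
    """Quantize-dequantize: scale into the FP8 range, round, scale back."""
    scale = amax / E4M3_MAX
    return [round_to_e4m3(v / scale) * scale for v in values]


# Hypothetical calibration data standing in for activations seen in calibrate_loop.
calib_batches = [[0.1, -2.5, 0.7], [1.2, -0.05, 3.5]]
amax = calibrate_amax(calib_batches)
print("amax:", amax)
print("after Q/DQ:", fp8_qdq([0.1, 1.0, 3.5], amax))
```

The values that come back from `fp8_qdq` are still ordinary floats, but they are restricted to the grid that FP8 can represent at the calibrated scale. That is the rounding error the PTQ accuracy printed by `calibrate_loop` reflects.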
2 changes: 2 additions & 0 deletions examples/int8/training/vgg16/requirements.txt
@@ -4,3 +4,5 @@ nvidia-pyindex
--extra-index-url https://pypi.nvidia.com
pytorch-quantization
tqdm
nvidia-modelopt
--extra-index-url https://pypi.nvidia.com
3 changes: 0 additions & 3 deletions py/torch_tensorrt/__init__.py
@@ -6,7 +6,6 @@

from torch_tensorrt._version import ( # noqa: F401
__cuda_version__,
__cudnn_version__,
__tensorrt_version__,
__version__,
)
@@ -40,11 +39,9 @@ def _find_lib(name: str, paths: List[str]) -> str:
import tensorrt # noqa: F401
except ImportError:
cuda_version = _parse_semver(__cuda_version__)
cudnn_version = _parse_semver(__cudnn_version__)
tensorrt_version = _parse_semver(__tensorrt_version__)

CUDA_MAJOR = cuda_version["major"]
CUDNN_MAJOR = cudnn_version["major"]
TENSORRT_MAJOR = tensorrt_version["major"]

if sys.platform.startswith("win"):