Skip to content

Commit

Permalink
add fp8 vgg16 example
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed May 17, 2024
1 parent 5de9325 commit ea1053f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
50 changes: 50 additions & 0 deletions examples/int8/training/vgg16/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
# Regularisation and optimizer hyper-parameters.
PARSER.add_argument("--drop-ratio", type=float, default=0.0, help="Dropout ratio")
PARSER.add_argument("--momentum", type=float, default=0.9, help="Momentum")
PARSER.add_argument("--weight-decay", type=float, default=5e-4, help="Weight decay")
# FP8 quantization is opt-in: the default of 0 leaves it off, because the FP8
# path is only entered when this value is greater than zero.
PARSER.add_argument(
    "--fp8-epochs",
    type=int,
    default=0,
    help="Enable FP8 and specify the number of epochs after the regular training to quantize the model to FP8",
)
PARSER.add_argument(
"--ckpt-dir",
default="/tmp/vgg16_ckpts",
Expand Down Expand Up @@ -167,6 +173,50 @@ def main():
ckpt_dir=args.ckpt_dir,
)

if args.fp8_epochs > 0:
print("[PTQ] Quantizing model to FP8...")
import modelopt.torch.quantization as mtq
import torch_tensorrt as torchtrt
from modelopt.torch.quantization.utils import export_torch_mode

def calibrate_loop(model):
    """Run calibration forward passes and report loss/accuracy per pass.

    Intended to be handed to ``mtq.quantize`` as its ``forward_loop``. Every
    batch of ``training_dataloader`` is fed through ``model``, and the whole
    dataloader is traversed ``args.fp8_epochs`` times.

    Args:
        model: The module to run forward passes on.
    """
    # Calibration only needs forward passes. Disabling autograd here makes
    # sure no graph is built, however the caller invokes this loop.
    with torch.no_grad():
        for fp8_ep in range(args.fp8_epochs):
            print("Epoch: [%5d / %5d]" % (fp8_ep + 1, args.fp8_epochs))
            total = 0
            correct = 0
            loss = 0.0
            for data, labels in training_dataloader:
                data, labels = data.cuda(), labels.cuda(non_blocking=True)
                out = model(data)
                # Accumulate a plain Python float so the running sum does
                # not keep per-batch tensors alive.
                loss += crit(out, labels).item()
                preds = torch.max(out, 1)[1]
                total += labels.size(0)
                correct += (preds == labels).sum().item()

            if total == 0:
                # An empty dataloader would otherwise divide by zero below.
                print("No calibration data was provided; skipping metrics")
                continue

            # NOTE: these figures are computed on training_dataloader even
            # though the message is labelled "Test".
            print(
                "Test Loss: {:.5f} Test Acc: {:.2f}%".format(
                    loss / total, 100 * correct / total
                )
            )

quant_cfg = mtq.FP8_DEFAULT_CFG
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# The model now contains FP8 quantize/dequantize (Q/DQ) nodes.
with torch.no_grad():
with export_torch_mode():
input_tensor = images.cuda()
exp_program = torch.export.export(model, (input_tensor,))
trt_model = torchtrt.dynamo.compile(
exp_program,
inputs=[input_tensor],
enabled_precisions={torch.float8_e4m3fn},
min_block_size=1,
debug=False,
)
outputs_trt = trt_model(input_tensor)
print("TRT outputs:\n", outputs_trt)


def train(model, dataloader, crit, opt, epoch):
model.train()
Expand Down
2 changes: 2 additions & 0 deletions examples/int8/training/vgg16/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ nvidia-pyindex
--extra-index-url https://pypi.nvidia.com
pytorch-quantization
tqdm
nvidia-modelopt
--extra-index-url https://pypi.nvidia.com

0 comments on commit ea1053f

Please sign in to comment.