From ea1053fb09efdffa2f2283ea526cda4051d21c99 Mon Sep 17 00:00:00 2001
From: Evan Li
Date: Fri, 17 May 2024 16:53:50 -0700
Subject: [PATCH] add fp8 vgg16 example

---
Note: the blob hashes on the `index` lines below are carried over from the
original patch and do not reflect the reviewed post-image contents.

 examples/int8/training/vgg16/main.py          | 52 +++++++++++++++++++
 examples/int8/training/vgg16/requirements.txt |  1 +
 2 files changed, 53 insertions(+)

diff --git a/examples/int8/training/vgg16/main.py b/examples/int8/training/vgg16/main.py
index 3db8e9d4dd..3d36d6cb6f 100644
--- a/examples/int8/training/vgg16/main.py
+++ b/examples/int8/training/vgg16/main.py
@@ -25,6 +25,12 @@
 PARSER.add_argument("--drop-ratio", default=0.0, type=float, help="Dropout ratio")
 PARSER.add_argument("--momentum", default=0.9, type=float, help="Momentum")
 PARSER.add_argument("--weight-decay", default=5e-4, type=float, help="Weight decay")
+PARSER.add_argument(
+    "--fp8-epochs",
+    default=0,
+    type=int,
+    help="Enable FP8 and specify the number of epochs after the regular training to quantize the model to FP8",
+)
 PARSER.add_argument(
     "--ckpt-dir",
     default="/tmp/vgg16_ckpts",
@@ -167,6 +173,52 @@ def main():
                 ckpt_dir=args.ckpt_dir,
             )
 
+    if args.fp8_epochs > 0:
+        print("[PTQ] Quantizing model to FP8...")
+        import modelopt.torch.quantization as mtq
+        import torch_tensorrt as torchtrt
+        from modelopt.torch.quantization.utils import export_torch_mode
+
+        def calibrate_loop(model):
+            """Forward loop handed to mtq.quantize: runs `model` over training_dataloader."""
+            # every FP8 epoch is one full pass over training_dataloader
+            for fp8_ep in range(args.fp8_epochs):
+                print("Epoch: [%5d / %5d]" % (fp8_ep + 1, args.fp8_epochs))
+                total = 0
+                correct = 0
+                loss = 0.0
+                for data, labels in training_dataloader:
+                    data, labels = data.cuda(), labels.cuda(non_blocking=True)
+                    out = model(data)
+                    # accumulate a Python float (like `correct`), not a tensor
+                    loss += crit(out, labels).item()
+                    preds = torch.max(out, 1)[1]
+                    total += labels.size(0)
+                    correct += (preds == labels).sum().item()
+
+                print(
+                    "Test Loss: {:.5f} Test Acc: {:.2f}%".format(
+                        loss / total, 100 * correct / total
+                    )
+                )
+
+        quant_cfg = mtq.FP8_DEFAULT_CFG
+        mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+        # model has FP8 qdq nodes at this point
+        with torch.no_grad():
+            with export_torch_mode():
+                input_tensor = images.cuda()
+                exp_program = torch.export.export(model, (input_tensor,))
+                trt_model = torchtrt.dynamo.compile(
+                    exp_program,
+                    inputs=[input_tensor],
+                    enabled_precisions={torch.float8_e4m3fn},
+                    min_block_size=1,
+                    debug=False,
+                )
+                outputs_trt = trt_model(input_tensor)
+                print("TRT outputs:\n", outputs_trt)
+
 
 def train(model, dataloader, crit, opt, epoch):
     model.train()
diff --git a/examples/int8/training/vgg16/requirements.txt b/examples/int8/training/vgg16/requirements.txt
index d02af2c616..3b0b03f5d7 100644
--- a/examples/int8/training/vgg16/requirements.txt
+++ b/examples/int8/training/vgg16/requirements.txt
@@ -4,3 +4,4 @@ nvidia-pyindex
 --extra-index-url https://pypi.nvidia.com
 pytorch-quantization
 tqdm
+nvidia-modelopt