diff --git a/examples/int8/training/vgg16/README.md b/examples/int8/training/vgg16/README.md
index 5aff4ca116..71d3539de4 100644
--- a/examples/int8/training/vgg16/README.md
+++ b/examples/int8/training/vgg16/README.md
@@ -53,10 +53,12 @@ Use the exporter script to create a torchscript module you can compile with Torc
 ### For PTQ
 
 ```
-python3 export_ckpt.py
+python3 export.py --ckpt <path-to-checkpoint> --ir torchscript --output vgg.ts
 ```
 
-The checkpoint file should be from the original training and not quatization aware fine tuning. THe script should produce a file called `trained_vgg16.jit.pt`
+* `--ckpt` : The checkpoint file should be from the original training and not from quantization-aware fine tuning.
+* `--ir` : Options include `torchscript` or `exported_program`. The saved module type is determined by this flag.
+* `--output` : Output file path
 
 ### For QAT
 
 To export a QAT model, you can run
diff --git a/examples/int8/training/vgg16/export.py b/examples/int8/training/vgg16/export.py
new file mode 100644
index 0000000000..de1a7b4b1f
--- /dev/null
+++ b/examples/int8/training/vgg16/export.py
@@ -0,0 +1,119 @@
+import argparse
+import os
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data as data
+import torchvision.datasets as datasets
+import torchvision.models as models
+import torchvision.transforms as transforms
+from pytorch_quantization import nn as quant_nn
+from pytorch_quantization import quant_modules
+from vgg16 import vgg16
+
+
+def test(model, dataloader, crit):
+    """
+    Run the model on a dataset and measure accuracy/loss
+    """
+    total = 0
+    correct = 0
+    loss = 0.0
+    class_probs = []
+    class_preds = []
+
+    with torch.no_grad():
+        for data, labels in dataloader:
+            data, labels = data.cuda(), labels.cuda(non_blocking=True)
+            out = model(data)
+            loss += crit(out, labels)
+            preds = torch.max(out, 1)[1]
+            class_probs.append([F.softmax(i, dim=0) for i in out])
+            class_preds.append(preds)
+            total += labels.size(0)
+            correct += (preds == labels).sum().item()
+
+    return loss / total, correct / total
+
+
+def evaluate(model):
+    """
+    Evaluate pre-trained model on CIFAR 10 dataset
+    """
+    testing_dataset = datasets.CIFAR10(
+        root="./data",
+        train=False,
+        download=True,
+        transform=transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+                ),
+            ]
+        ),
+    )
+
+    testing_dataloader = torch.utils.data.DataLoader(
+        testing_dataset, batch_size=32, shuffle=False, num_workers=2
+    )
+
+    crit = torch.nn.CrossEntropyLoss()
+
+    test_loss, test_acc = test(model, testing_dataloader, crit)
+    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
+
+
+def export_model(args):
+    """
+    Evaluate and export the model to Torchscript or exported program
+    """
+    # Define the VGG model
+    # model = vgg16(num_classes=10, init_weights=False)
+    model = models.vgg16(weights=None).eval().cuda()
+    # Load the checkpoint
+    ckpt = torch.load(args.ckpt)
+    weights = ckpt["model_state_dict"]
+    model.load_state_dict(weights)
+    # Setting eval here causes both JIT and TRT accuracy to tank in LibTorch; will follow up with the PyTorch team
+    # model.eval()
+    random_inputs = [torch.rand([32, 3, 32, 32]).to("cuda")]
+    if args.ir == "torchscript":
+        jit_model = torch.jit.trace(model, random_inputs)
+        jit_model.eval()
+        # Evaluating JIT model
+        evaluate(jit_model)
+        torch.jit.save(jit_model, args.output)
+    elif args.ir == "exported_program":
+        dim_x = torch.export.Dim("dim_x", min=1, max=32)
+        exp_program = torch.export.export(
+            model, tuple(random_inputs), dynamic_shapes={"x": {0: dim_x}}
+        )
+        evaluate(exp_program)
+        torch.export.save(exp_program, args.output)
+    else:
+        raise ValueError(
+            f"Invalid IR {args.ir} provided to export the VGG model. Select among torchscript | exported_program"
+        )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Export trained VGG")
+    parser.add_argument("--ckpt", type=str, help="Path to saved checkpoint")
+    parser.add_argument(
+        "--ir",
+        type=str,
+        default="torchscript",
+        help="IR to determine the output type of exported graph",
+    )
+    parser.add_argument(
+        "--output", type=str, default="vgg.ts", help="Output file path"
+    )
+    parser.add_argument(
+        "--qat",
+        action="store_true",
+        help="Perform QAT using pytorch-quantization toolkit",
+    )
+    args = parser.parse_args()
+    export_model(args)
diff --git a/examples/int8/training/vgg16/export_ckpt.py b/examples/int8/training/vgg16/export_ckpt.py
deleted file mode 100644
index 16f0426811..0000000000
--- a/examples/int8/training/vgg16/export_ckpt.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import argparse
-import os
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.data as data
-import torchvision.transforms as transforms
-import torchvision.datasets as datasets
-
-from vgg16 import vgg16
-
-
-def test(model, dataloader, crit):
-    global writer
-    global classes
-    total = 0
-    correct = 0
-    loss = 0.0
-    class_probs = []
-    class_preds = []
-
-    with torch.no_grad():
-        for data, labels in dataloader:
-            data, labels = data.cuda(), labels.cuda(non_blocking=True)
-            out = model(data)
-            loss += crit(out, labels)
-            preds = torch.max(out, 1)[1]
-            class_probs.append([F.softmax(i, dim=0) for i in out])
-            class_preds.append(preds)
-            total += labels.size(0)
-            correct += (preds == labels).sum().item()
-
-    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
-    test_preds = torch.cat(class_preds)
-    return loss / total, correct / total
-
-
-PARSER = argparse.ArgumentParser(description="Export trained VGG")
-PARSER.add_argument("ckpt", type=str, help="Path to saved checkpoint")
-
-args = PARSER.parse_args()
-model = vgg16(num_classes=10, init_weights=False)
-model = model.cuda()
-
-ckpt = torch.load(args.ckpt)
-weights = ckpt["model_state_dict"]
-
-if torch.cuda.device_count() > 1:
-    from collections import OrderedDict
-
-    new_state_dict = OrderedDict()
-    for k, v in weights.items():
-        name = k[7:]  # remove `module.`
-        new_state_dict[name] = v
-    weights = new_state_dict
-
-model.load_state_dict(weights)
-
-# Setting eval here causes both JIT and TRT accuracy to tank in LibTorch will follow up with PyTorch Team
-# model.eval()
-
-jit_model = torch.jit.trace(model, torch.rand([32, 3, 32, 32]).to("cuda"))
-jit_model.eval()
-
-testing_dataset = datasets.CIFAR10(
-    root="./data",
-    train=False,
-    download=True,
-    transform=transforms.Compose(
-        [
-            transforms.ToTensor(),
-            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-        ]
-    ),
-)
-
-testing_dataloader = torch.utils.data.DataLoader(
-    testing_dataset, batch_size=32, shuffle=False, num_workers=2
-)
-
-crit = torch.nn.CrossEntropyLoss()
-
-test_loss, test_acc = test(jit_model, testing_dataloader, crit)
-print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
-torch.jit.save(jit_model, "trained_vgg16.jit.pt")
diff --git a/examples/int8/training/vgg16/finetune_qat.py b/examples/int8/training/vgg16/finetune_qat.py
index 0414af00de..aa78a4c067 100644
--- a/examples/int8/training/vgg16/finetune_qat.py
+++ b/examples/int8/training/vgg16/finetune_qat.py
@@ -8,17 +8,15 @@
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
-import torchvision.transforms as transforms
 import torchvision.datasets as datasets
-
-from torch.utils.tensorboard import SummaryWriter
-
+import torchvision.models as models
+import torchvision.transforms as transforms
+from pytorch_quantization import calib
 from pytorch_quantization import nn as quant_nn
 from pytorch_quantization import quant_modules
 from pytorch_quantization.tensor_quant import QuantDescriptor
-from pytorch_quantization import calib
+from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-
 from vgg16 import vgg16
 
 PARSER = argparse.ArgumentParser(
@@ -231,7 +229,6 @@ def main():
 
     quant_modules.initialize()
 
-    model = vgg16(num_classes=num_classes, init_weights=False)
-    model = model.cuda()
+    model = vgg16(num_classes=num_classes, init_weights=False).cuda()
 
     crit = nn.CrossEntropyLoss()
diff --git a/examples/int8/training/vgg16/main.py b/examples/int8/training/vgg16/main.py
index 3f248a9283..2378119ad3 100644
--- a/examples/int8/training/vgg16/main.py
+++ b/examples/int8/training/vgg16/main.py
@@ -8,11 +8,10 @@
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
-import torchvision.transforms as transforms
 import torchvision.datasets as datasets
-
+import torchvision.models as models
+import torchvision.transforms as transforms
 from torch.utils.tensorboard import SummaryWriter
-
 from vgg16 import vgg16
 
 PARSER = argparse.ArgumentParser(
@@ -125,8 +124,7 @@ def main():
 
     num_classes = len(classes)
 
-    model = vgg16(num_classes=num_classes, init_weights=False)
-    model = model.cuda()
+    model = vgg16(num_classes=num_classes, init_weights=False).cuda()
 
     data = iter(training_dataloader)
     images, _ = next(data)
@@ -233,7 +231,7 @@ def test(model, dataloader, crit, epoch):
     test_preds = torch.cat(class_preds)
     for i in range(len(classes)):
         add_pr_curve_tensorboard(i, test_probs, test_preds, epoch)
-    # print(loss, total, correct, total)
+
     return loss / total, correct / total
diff --git a/examples/int8/training/vgg16/test_qat.py b/examples/int8/training/vgg16/test_qat.py
deleted file mode 100644
index d38d36f3fc..0000000000
--- a/examples/int8/training/vgg16/test_qat.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import argparse
-import os
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.data as data
-import torchvision.transforms as transforms
-import torchvision.datasets as datasets
-
-from vgg16 import vgg16
-
-from pytorch_quantization import quant_modules
-from pytorch_quantization import nn as quant_nn
-
-
-def test(model, dataloader, crit):
-    global writer
-    global classes
-    total = 0
-    correct = 0
-    loss = 0.0
-    class_probs = []
-    class_preds = []
-
-    with torch.no_grad():
-        for data, labels in dataloader:
-            data, labels = data.cuda(), labels.cuda(non_blocking=True)
-            out = model(data)
-            loss += crit(out, labels)
-            preds = torch.max(out, 1)[1]
-            class_probs.append([F.softmax(i, dim=0) for i in out])
-            class_preds.append(preds)
-            total += labels.size(0)
-            correct += (preds == labels).sum().item()
-
-    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
-    test_preds = torch.cat(class_preds)
-    return loss / total, correct / total
-
-
-PARSER = argparse.ArgumentParser(description="Export trained VGG")
-PARSER.add_argument("ckpt", type=str, help="Path to saved checkpoint")
-PARSER.add_argument(
-    "--enable_qat",
-    action="store_true",
-    help="Enable quantization aware training. This is recommended to perform on a pre-trained model.",
-)
-
-args = PARSER.parse_args()
-
-quant_modules.initialize()
-model = vgg16(num_classes=10, init_weights=False)
-model = model.cuda()
-
-ckpt = torch.load(args.ckpt)
-weights = ckpt["model_state_dict"]
-
-model.load_state_dict(weights)
-model.eval()
-
-testing_dataset = datasets.CIFAR10(
-    root="./data",
-    train=False,
-    download=True,
-    transform=transforms.Compose(
-        [
-            transforms.ToTensor(),
-            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
-        ]
-    ),
-)
-
-testing_dataloader = torch.utils.data.DataLoader(
-    testing_dataset, batch_size=32, shuffle=False, num_workers=2
-)
-
-crit = torch.nn.CrossEntropyLoss()
-
-#
-quant_nn.TensorQuantizer.use_fb_fake_quant = True
-with torch.no_grad():
-    data = iter(testing_dataloader)
-    images, _ = data.next()
-    jit_model = torch.jit.trace(model, images.to("cuda"))
-    torch.jit.save(jit_model, "trained_vgg16_qat.jit.pt")
-
-test_loss, test_acc = test(jit_model, testing_dataloader, crit)
-print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
-
-import torch_tensorrt as torchtrt
-
-compile_settings = {
-    "inputs": [torchtrt.Input([1, 3, 32, 32])],
-    "enabled_precisions": {torch.float, torch.half, torch.int8},  # Run with FP16
-}
-new_mod = torch.jit.load("trained_vgg16_qat.jit.pt")
-trt_ts_module = torchtrt.compile(new_mod, **compile_settings)
-testing_dataloader = torch.utils.data.DataLoader(
-    testing_dataset, batch_size=1, shuffle=False, num_workers=2
-)
-test_loss, test_acc = test(trt_ts_module, testing_dataloader, crit)
-print("[TRTorch] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
diff --git a/py/torch_tensorrt/__init__.py b/py/torch_tensorrt/__init__.py
index c015bd89db..422cc02321 100644
--- a/py/torch_tensorrt/__init__.py
+++ b/py/torch_tensorrt/__init__.py
@@ -85,10 +85,10 @@ def _find_lib(name: str, paths: List[str]) -> str:
 from torch_tensorrt._Device import Device  # noqa: F401
 from torch_tensorrt._enums import *  # noqa: F403
 from torch_tensorrt._Input import Input  # noqa: F401
-from torch_tensorrt.logging import *
-from torch_tensorrt.ptq import *
 from torch_tensorrt._utils import *  # noqa: F403
-from torch_tensorrt._utils import sanitized_torch_version
+from torch_tensorrt._utils import sanitized_torch_version  # noqa: F401
+from torch_tensorrt.logging import *  # noqa: F403
+from torch_tensorrt.ptq import *  # noqa: F403
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from torch_tensorrt import dynamo  # noqa: F401
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index d31be8a413..54dca9581a 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -36,6 +36,7 @@
 )
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes
 from torch_tensorrt.dynamo.utils import (
+    build_calibrator,
     get_torch_inputs,
     prepare_inputs,
     set_log_level,
@@ -62,7 +63,7 @@ def compile(
     dla_sram_size: int = 1048576,
     dla_local_dram_size: int = 1073741824,
     dla_global_dram_size: int = 536870912,
-    calibrator: object = None,
+    calibrator: Any = None,
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
     min_block_size: int = MIN_BLOCK_SIZE,
@@ -155,8 +156,12 @@ def compile(
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
     enabled_precisions = set(enabled_precisions)
-
     if (
+        torch.int8 in enabled_precisions
+        or torch_tensorrt.dtype.int8 in enabled_precisions
+    ):
+        precision = torch.int8
+    elif (
         torch.float16 in enabled_precisions
         or torch_tensorrt.dtype.half in enabled_precisions
     ):
@@ -192,6 +197,7 @@ def compile(
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
         "require_full_compilation": require_full_compilation,
+        "calibrator": build_calibrator(calibrator),
     }
 
     settings = CompilationSettings(**compilation_options)
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 103b5f7792..01fe04dbdd 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -15,6 +15,7 @@
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
 REQUIRE_FULL_COMPILATION = False
+CALIBRATOR = None
 
 
 def default_device() -> Device:
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index c9f4534cb8..8252c5bbee 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -1,9 +1,10 @@
 from dataclasses import dataclass, field
-from typing import Optional, Set
+from typing import Any, Optional, Set
 
 import torch
 from torch_tensorrt._Device import Device
 from torch_tensorrt.dynamo._defaults import (
+    CALIBRATOR,
     DEBUG,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     MAX_AUX_STREAMS,
@@ -46,6 +47,7 @@ class CompilationSettings:
         device (Device): GPU to compile the model on
         require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
             Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
+        calibrator (ptq.DataLoaderCalibrator | ptq.CacheCalibrator): Calibrator used for INT8 calibration of the model
     """
 
     precision: torch.dtype = PRECISION
@@ -63,3 +65,4 @@ class CompilationSettings:
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     device: Device = field(default_factory=default_device)
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION
+    calibrator: Optional[Any] = CALIBRATOR
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 1fa2806181..d0a9d03d5f 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -92,6 +92,7 @@ def _pretraced_backend(
             torchtrt_inputs = prepare_inputs(
                 sample_inputs, disable_memory_format_check=True
             )
+
             trt_compiled = compile_module(
                 gm,
                 torchtrt_inputs,
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 0f1c3b0c42..c597f0bd2d 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -131,6 +131,7 @@ def run(
         max_aux_streams: Optional[int] = None,
         version_compatible: bool = False,
         optimization_level: Optional[int] = None,
+        calibrator: Optional[Any] = None,
     ) -> TRTInterpreterResult:
         """
         Build TensorRT engine with some configs.
@@ -171,6 +172,8 @@ def run(
         build_engine_start_time = datetime.now()
 
         builder_config = self.builder.create_builder_config()
+        if calibrator:
+            builder_config.int8_calibrator = calibrator
 
         if workspace_size != 0:
             builder_config.set_memory_pool_limit(
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index 1cdea63680..8cff29bc07 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -65,6 +65,7 @@ def convert_module(
         max_aux_streams=settings.max_aux_streams,
         version_compatible=settings.version_compatible,
         optimization_level=settings.optimization_level,
+        calibrator=settings.calibrator,
     )
 
     if settings.use_python_runtime:
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 26de1fcb27..84d97c0553 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -4,12 +4,14 @@
 from dataclasses import fields, replace
 from typing import Any, Callable, Dict, Optional, Sequence, Union
 
+import tensorrt as trt
 import torch
 import torch_tensorrt
 from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._defaults import PRECISION
 from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.ptq import *  # noqa: F403
 
 from packaging import version
 
@@ -220,8 +222,12 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
     # TODO: Remove once Dynamo precisions refactoring is complete
     if "enabled_precisions" in kwargs:
         enabled_precisions = kwargs["enabled_precisions"]
-
         if (
+            torch.int8 in enabled_precisions
+            or torch_tensorrt.dtype.int8 in enabled_precisions
+        ):
+            settings.precision = torch.int8
+        elif (
             torch.float16 in enabled_precisions
             or torch_tensorrt.dtype.half in enabled_precisions
         ):
@@ -252,6 +258,9 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
             "If this is incorrect, please specify an input device, via the device keyword."
         )
 
+    if "calibrator" in kwargs:
+        settings.calibrator = build_calibrator(kwargs["calibrator"])
+
     # Ignore and warn about require_full_compilation flag
     if settings.require_full_compilation:
         logger.warning(
@@ -265,6 +274,91 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
 
     return settings
 
 
+def build_calibrator(calibrator: Union[DataLoaderCalibrator, CacheCalibrator]) -> Any:
+    if not calibrator:
+        return None
+    if not isinstance(
+        calibrator, (DataLoaderCalibrator, CacheCalibrator)
+    ):
+        raise AssertionError(
+            f"Invalid calibrator of type {type(calibrator)} provided. Only calibrator of type DataLoaderCalibrator or CacheCalibrator is supported"
+        )
+    algo_type = calibrator.algo_type
+    cache_file = calibrator.cache_file
+    attribute_mapping = {}
+    if isinstance(calibrator, DataLoaderCalibrator):
+        dataloader = calibrator.dataloader
+        use_cache = calibrator.use_cache
+        device = calibrator.device
+        if not isinstance(dataloader, torch.utils.data.DataLoader):
+            logger.error(
+                f"Dataloader type: {type(dataloader)} is not a valid instance of torch.utils.data.DataLoader"
+            )
+
+        if cache_file:
+            if use_cache:
+                logger.info(f"Using existing cache_file {cache_file} for calibration")
+            else:
+                logger.info("Overwriting existing calibration cache file.")
+        else:
+            if use_cache:
+                logger.error(
+                    "Input cache file is None but use_cache is set to True in INT8 mode."
+                )
+
+        # Define attributes and member functions for the calibrator class
+        attribute_mapping = {
+            "data_loader": dataloader,
+            "current_batch_idx": 0,
+            "batch_size": dataloader.batch_size,
+            "cache_file": cache_file,
+            "dataset_iterator": iter(dataloader),
+            "device": device,
+            "use_cache": use_cache,
+            "get_batch_size": get_batch_size,
+            "get_batch": get_cache_mode_batch if use_cache else get_batch,
+            "read_calibration_cache": read_calibration_cache,
+            "write_calibration_cache": write_calibration_cache,
+        }
+    elif isinstance(calibrator, CacheCalibrator):
+        attribute_mapping = {
+            "use_cache": True,
+            "cache_file": cache_file,
+            "get_batch_size": get_batch_size,
+            "get_batch": get_cache_mode_batch,
+            "read_calibration_cache": read_calibration_cache,
+            "write_calibration_cache": write_calibration_cache,
+        }
+
+    # Using type metaclass to construct calibrator class based on algorithm type
+    if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
+        calib_ec = type(
+            "Int8EntropyCalibrator", (trt.IInt8EntropyCalibrator,), attribute_mapping
+        )()
+        return calib_ec
+    elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
+        calib_ec2 = type(
+            "Int8EntropyCalibrator2",
+            (trt.IInt8EntropyCalibrator2,),
+            attribute_mapping,
+        )()
+        return calib_ec2
+    elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
+        calib_lc = type(
+            "Int8LegacyCalibrator", (trt.IInt8LegacyCalibrator,), attribute_mapping
+        )()
+        return calib_lc
+    elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
+        calib_mmc = type(
+            "Int8MinMaxCalibrator", (trt.IInt8MinMaxCalibrator,), attribute_mapping
+        )()
+        return calib_mmc
+    else:
+        raise ValueError(
+            "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION_2, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+        )
+
+
 def req_torch_version(min_torch_version: str = "2.dev") -> Callable[..., Any]:
     """
     Create a decorator which verifies the Torch version installed
diff --git a/py/torch_tensorrt/ptq.py b/py/torch_tensorrt/ptq.py
index 5d13ab9108..6da15a40b4 100644
--- a/py/torch_tensorrt/ptq.py
+++ b/py/torch_tensorrt/ptq.py
@@ -1,24 +1,21 @@
-import sys
-from typing import Any, List, Optional
-
-if sys.version_info >= (3, 11):
-    from typing import Self
-else:
-    from typing_extensions import Self
-
 import os
+from dataclasses import dataclass
 from enum import Enum
+from typing import Any, List, Optional
 
+import tensorrt as trt
 import torch
-from torch_tensorrt import _C
-from torch_tensorrt.logging import Level, log
-
-
-class CalibrationAlgo(Enum):
-    ENTROPY_CALIBRATION = _C.CalibrationAlgo.ENTROPY_CALIBRATION
-    ENTROPY_CALIBRATION_2 = _C.CalibrationAlgo.ENTROPY_CALIBRATION_2
-    LEGACY_CALIBRATION = _C.CalibrationAlgo.LEGACY_CALIBRATION
-    MINMAX_CALIBRATION = _C.CalibrationAlgo.MINMAX_CALIBRATION
+
+__all__ = [
+    "get_cache_mode_batch",
+    "get_batch_size",
+    "get_batch",
+    "read_calibration_cache",
+    "write_calibration_cache",
+    "CalibrationAlgo",
+    "DataLoaderCalibrator",
+    "CacheCalibrator",
+]
 
 
 def get_cache_mode_batch(self: object) -> None:
@@ -64,14 +61,15 @@ def write_calibration_cache(self: object, cache: bytes) -> None:
     return
 
 
-# deepcopy (which involves pickling) is performed on the compile_spec internally during compilation.
-# We register this __reduce__ function for pickler to identity the calibrator object returned by DataLoaderCalibrator during deepcopy.
-# This should be the object's local name relative to the module https://docs.python.org/3/library/pickle.html#object.__reduce__
-def __reduce__(self: object) -> str:
-    return self.__class__.__name__
+class CalibrationAlgo(Enum):
+    ENTROPY_CALIBRATION = trt.CalibrationAlgoType.ENTROPY_CALIBRATION
+    ENTROPY_CALIBRATION_2 = trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2
+    LEGACY_CALIBRATION = trt.CalibrationAlgoType.LEGACY_CALIBRATION
+    MINMAX_CALIBRATION = trt.CalibrationAlgoType.MINMAX_CALIBRATION
 
 
-class DataLoaderCalibrator(object):
+@dataclass
+class DataLoaderCalibrator:
     """
     Constructs a calibrator class in TensorRT and uses pytorch dataloader to load/preproces
     data which is passed during calibration.
@@ -83,84 +81,14 @@ class DataLoaderCalibrator:
         device: device on which calibration data is copied to.
     """
 
-    def __init__(self, **kwargs: Any):
-        pass
-
-    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
-        dataloader = args[0]
-        algo_type = kwargs.get("algo_type", CalibrationAlgo.ENTROPY_CALIBRATION_2)
-        cache_file = kwargs.get("cache_file", None)
-        use_cache = kwargs.get("use_cache", False)
-        device = kwargs.get("device", torch.device("cuda:0"))
-
-        if not isinstance(dataloader, torch.utils.data.DataLoader):
-            log(
-                Level.Error,
-                "Dataloader : {} is not a valid instance of torch.utils.data.DataLoader".format(
-                    dataloader
-                ),
-            )
-
-        if not cache_file:
-            if use_cache:
-                log(
-                    Level.Debug,
-                    "Using existing cache_file {} for calibration".format(cache_file),
-                )
-            else:
-                log(Level.Debug, "Overwriting existing calibration cache file.")
-        else:
-            if use_cache:
-                log(
-                    Level.Error,
-                    "Input cache file is None but use_cache is set to True in INT8 mode.",
-                )
-
-        # Define attributes and member functions for the calibrator class
-        attribute_mapping = {
-            "data_loader": dataloader,
-            "current_batch_idx": 0,
-            "batch_size": dataloader.batch_size,
-            "dataset_iterator": iter(dataloader),
-            "cache_file": cache_file,
-            "device": device,
-            "use_cache": use_cache,
-            "get_batch_size": get_batch_size,
-            "get_batch": get_cache_mode_batch if use_cache else get_batch,
-            "read_calibration_cache": read_calibration_cache,
-            "write_calibration_cache": write_calibration_cache,
-            "__reduce__": __reduce__,  # used when you deepcopy the DataLoaderCalibrator object
-        }
-
-        # Using type metaclass to construct calibrator class based on algorithm type
-        if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
-            calib_ec: Self = type(
-                "Int8EntropyCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping
-            )()
-            return calib_ec
-        elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
-            calib_ec2: Self = type(
-                "Int8EntropyCalibrator2",
-                (_C.IInt8EntropyCalibrator2,),
-                attribute_mapping,
-            )()
-            return calib_ec2
-        elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
-            calib_lc: Self = type(
-                "Int8LegacyCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping
-            )()
-            return calib_lc
-        elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
-            calib_mmc: Self = type(
-                "Int8MinMaxCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
-            )()
-            return calib_mmc
-        else:
-            raise ValueError(
-                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
-            )
+    dataloader: torch.utils.data.DataLoader
+    algo_type: CalibrationAlgo = CalibrationAlgo.ENTROPY_CALIBRATION_2
+    cache_file: str = ""
+    use_cache: bool = False
+    device: torch.device = torch.device("cuda:0")
 
 
+@dataclass
 class CacheCalibrator(object):
     """
     Constructs a calibrator class in TensorRT which directly uses pre-existing cache file for calibration.
     Args:
         cache_file: path to cache file.
         algo_type: choice of calibration algorithm.
     """
@@ -169,52 +97,5 @@ class CacheCalibrator(object):
-    def __init__(self, **kwargs: Any):
-        pass
-
-    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
-        cache_file = args[0]
-        algo_type = kwargs.get("algo_type", CalibrationAlgo.ENTROPY_CALIBRATION_2)
-
-        if os.path.isfile(cache_file):
-            log(
-                Level.Debug,
-                "Using existing cache_file {} for calibration".format(cache_file),
-            )
-        else:
-            log(Level.Error, "Invalid calibration cache file.")
-
-        # Define attributes and member functions for the calibrator class
-        attribute_mapping = {
-            "use_cache": True,
-            "cache_file": cache_file,
-            "get_batch_size": get_batch_size,
-            "get_batch": get_cache_mode_batch,
-            "read_calibration_cache": read_calibration_cache,
-            "write_calibration_cache": write_calibration_cache,
-        }
-        # Using type metaclass to construct calibrator class based on algorithm type
-        if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
-            calib_ec: Self = type(
-                "DataLoaderCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping
-            )()
-            return calib_ec
-        elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
-            calib_ec2: Self = type(
-                "DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
-            )()
-            return calib_ec2
-        elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
-            calib_lc: Self = type(
-                "DataLoaderCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping
-            )()
-            return calib_lc
-        elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
-            calib_mmc: Self = type(
-                "DataLoaderCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
-            )()
-            return calib_mmc
-        else:
-            raise ValueError(
-                "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
-            )
+    algo_type: CalibrationAlgo = CalibrationAlgo.ENTROPY_CALIBRATION_2
+    cache_file: str = ""
diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py
index b9a84152e1..2087194a41 100644
--- a/py/torch_tensorrt/ts/_compile_spec.py
+++ b/py/torch_tensorrt/ts/_compile_spec.py
@@ -1,18 +1,19 @@
 from __future__ import annotations
 
+import os
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Set
 
+import tensorrt as trt
 import torch
 import torch_tensorrt._C.ts as _ts_C
 from torch_tensorrt import _C, _enums
 from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
 from torch_tensorrt.logging import Level, log
+from torch_tensorrt.ptq import *  # noqa: F403
 from torch_tensorrt.ts._Input import TorchScriptInput
 
-import tensorrt as trt
-
 
 def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt._Input:
     clone = torch.classes.tensorrt._Input()
@@ -75,6 +76,112 @@ def _parse_enabled_precisions(precisions: Any) -> Set[_enums.dtype]:
     return parsed_precisions
 
 
+# deepcopy (which involves pickling) is performed on the compile_spec internally during compilation.
+# We register this __reduce__ function for pickler to identify the calibrator object returned by DataLoaderCalibrator during deepcopy.
+# This should be the object's local name relative to the module https://docs.python.org/3/library/pickle.html#object.__reduce__
+def __reduce__(self: object) -> str:
+    return self.__class__.__name__
+
+
+def _build_calibrator(calibrator: Any) -> Any:
+    if not calibrator:
+        return None
+    if not isinstance(
+        calibrator, (DataLoaderCalibrator, CacheCalibrator)
+    ):
+        raise AssertionError(
+            f"Invalid calibrator of type {type(calibrator)} provided. Only calibrator of type DataLoaderCalibrator or CacheCalibrator is supported"
+        )
+    algo_type = calibrator.algo_type
+    cache_file = calibrator.cache_file
+    attribute_mapping = {}
+    if isinstance(calibrator, DataLoaderCalibrator):
+        dataloader = calibrator.dataloader
+        use_cache = calibrator.use_cache
+        device = calibrator.device
+        if not isinstance(dataloader, torch.utils.data.DataLoader):
+            log(
+                Level.Error,
+                "Dataloader : {} is not a valid instance of torch.utils.data.DataLoader".format(
+                    dataloader
+                ),
+            )
+        if cache_file:
+            if use_cache:
+                log(
+                    Level.Debug,
+                    "Using existing cache_file {} for calibration".format(cache_file),
+                )
+            else:
+                log(Level.Debug, "Overwriting existing calibration cache file.")
+        else:
+            if use_cache:
+                log(
+                    Level.Error,
+                    "Input cache file is None but use_cache is set to True in INT8 mode.",
+                )
+
+        # Define attributes and member functions for the calibrator class
+        attribute_mapping = {
+            "data_loader": dataloader,
+            "current_batch_idx": 0,
+            "batch_size": dataloader.batch_size,
+            "dataset_iterator": iter(dataloader),
+            "cache_file": cache_file,
+            "device": device,
+            "use_cache": use_cache,
+            "get_batch_size": get_batch_size,
+            "get_batch": get_cache_mode_batch if use_cache else get_batch,
+            "read_calibration_cache": read_calibration_cache,
+            "write_calibration_cache": write_calibration_cache,
+            "__reduce__": __reduce__,  # used when you deepcopy the DataLoaderCalibrator object
+        }
+    elif isinstance(calibrator, CacheCalibrator):
+        if os.path.isfile(cache_file):
+            log(
+                Level.Debug,
+                "Using existing cache_file {} for calibration".format(cache_file),
+            )
+        else:
+            log(Level.Error, "Invalid calibration cache file.")
+        attribute_mapping = {
+            "use_cache": True,
+            "cache_file": cache_file,
+            "get_batch_size": get_batch_size,
+            "get_batch": get_cache_mode_batch,
+            "read_calibration_cache": read_calibration_cache,
+            "write_calibration_cache": write_calibration_cache,
+        }
+
+    # Using type metaclass to construct calibrator class based on algorithm type
+    if algo_type == CalibrationAlgo.ENTROPY_CALIBRATION:
+        calib_ec = type(
+            "Int8EntropyCalibrator", (_C.IInt8EntropyCalibrator,), attribute_mapping
+        )()
+        return calib_ec
+    elif algo_type == CalibrationAlgo.ENTROPY_CALIBRATION_2:
+        calib_ec2 = type(
+            "Int8EntropyCalibrator2",
+            (_C.IInt8EntropyCalibrator2,),
+            attribute_mapping,
+        )()
+        return calib_ec2
+    elif algo_type == CalibrationAlgo.LEGACY_CALIBRATION:
+        calib_lc = type(
+            "Int8LegacyCalibrator", (_C.IInt8LegacyCalibrator,), attribute_mapping
+        )()
+        return calib_lc
+    elif algo_type == CalibrationAlgo.MINMAX_CALIBRATION:
+        calib_mmc = type(
+            "Int8MinMaxCalibrator", (_C.IInt8MinMaxCalibrator,), attribute_mapping
+        )()
+        return calib_mmc
+    else:
+        raise ValueError(
+            "Invalid calibration algorithm type. Please select among ENTROPY_CALIBRATION, ENTROPY_CALIBRATION_2, LEGACY_CALIBRATION or MINMAX_CALIBRATION"
+        )
+
+
 def _parse_device_type(device: Any) -> _enums.DeviceType:
     if isinstance(device, torch.device):
         if device.type == "cuda":
@@ -399,7 +506,7 @@ def TensorRTCompileSpec(
     Returns:
         torch.classes.tensorrt.CompileSpec: List of methods and formated spec objects to be provided to ``torch._C._jit_to_tensorrt``
     """
-
+    ptq_calibrator = _build_calibrator(calibrator)
     compile_spec = {
         "inputs": inputs if inputs is not None else [],
         # "input_signature": input_signature,
@@ -417,7 +524,7 @@ def TensorRTCompileSpec(
         "dla_sram_size": dla_sram_size,  # Fast software managed RAM used by DLA to communicate within a layer.
         "dla_local_dram_size": dla_local_dram_size,  # Host RAM used by DLA to share intermediate tensor data across operations
         "dla_global_dram_size": dla_global_dram_size,  # Host RAM used by DLA to store weights and metadata for execution
-        "calibrator": calibrator,
+        "calibrator": ptq_calibrator,
         "truncate_long_and_double": truncate_long_and_double,
         "allow_shape_tensors": allow_shape_tensors,
     }
diff --git a/py/torch_tensorrt/ts/_compiler.py b/py/torch_tensorrt/ts/_compiler.py
index 4a9bb53dc0..0deff815b9 100644
--- a/py/torch_tensorrt/ts/_compiler.py
+++ b/py/torch_tensorrt/ts/_compiler.py
@@ -7,7 +7,11 @@
 from torch_tensorrt import _enums
 from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
-from torch_tensorrt.ts._compile_spec import _parse_compile_spec, _parse_device
+from torch_tensorrt.ts._compile_spec import (
+    _build_calibrator,
+    _parse_compile_spec,
+    _parse_device,
+)
 
 
 def compile(
@@ -125,6 +129,7 @@ def compile(
             f"require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: {torch_executed_ops}, torch_executed_modules: {torch_executed_modules}"
         )
 
+    ptq_calibrator = _build_calibrator(calibrator)
     spec = {
         "inputs": input_list,
         "input_signature": input_signature,
@@ -137,7 +142,7 @@ def compile(
         "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
         "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
         "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT
-        "calibrator": calibrator,
+        "calibrator": ptq_calibrator,
         "truncate_long_and_double": truncate_long_and_double,
         "torch_fallback": {
             "enabled": not require_full_compilation,
diff --git a/tests/py/dynamo/ptq/test_compile_ptq.py b/tests/py/dynamo/ptq/test_compile_ptq.py
new file mode 100644
index 0000000000..a8cb0b64c9
--- /dev/null
+++ b/tests/py/dynamo/ptq/test_compile_ptq.py
@@ -0,0 +1,92 @@
+import os
+import unittest
+
+import torch
+import torch.nn as nn
+import torch_tensorrt as torchtrt
+import torchvision
+import torchvision.models as models
+import torchvision.transforms as transforms
+from torch.nn import functional as F
+from torch_tensorrt.logging import *
+from torch_tensorrt.ptq import CalibrationAlgo, DataLoaderCalibrator
+from vgg16 import vgg16
+
+
+def compute_accuracy(testing_dataloader, model):
+    total = 0
+    correct = 0
+    loss = 0.0
+    class_probs = []
+    class_preds = []
+    device = torch.device("cuda:0")
+    with torch.no_grad():
+        idx = 0
+        for data, labels in testing_dataloader:
+            data, labels = data.to(device), labels.to(device)
+            out = model(data)
+            preds = torch.max(out, 1)[1]
+            class_probs.append([F.softmax(i, dim=0) for i in out])
+            class_preds.append(preds)
+            total += labels.size(0)
+            correct += (preds == labels).sum().item()
+            idx += 1
+
+    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
+    test_preds = torch.cat(class_preds)
+    return correct / total
+
+
+class TestAccuracy(unittest.TestCase):
+    def test_compile_script(self):
+        self.model = vgg16(num_classes=10, init_weights=False).eval().cuda()
+        self.testing_dataset = torchvision.datasets.CIFAR10(
+            root="./data",
+            train=False,
+            download=True,
+            transform=transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+                    ),
+                ]
+            ),
+        )
+
+        self.testing_dataloader = torch.utils.data.DataLoader(
+            self.testing_dataset, batch_size=100, shuffle=False, num_workers=1
+        )
+        self.calibrator = DataLoaderCalibrator(
+            self.testing_dataloader,
+            cache_file="./calibration.cache",
+            use_cache=False,
+            algo_type=CalibrationAlgo.ENTROPY_CALIBRATION_2,
+            device=torch.device("cuda:0"),
+        )
+
+        compile_spec = {
+            "inputs": [torchtrt.Input([100, 3, 32, 32])],
+            "enabled_precisions": {torch.int8},
+            "calibrator": self.calibrator,
+            "truncate_long_and_double": True,
+            "debug": True,
+            "require_full_compilation": True,
+            "enable_experimental_decompositions": True,
+            "min_block_size": 1,
+        }
+        trt_mod = torch.compile(
+            self.model, backend="torch_tensorrt", dynamic=False, options=compile_spec
+        )
+
+        fp32_test_acc = compute_accuracy(self.testing_dataloader, self.model)
+        log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
+
+        int8_test_acc = compute_accuracy(self.testing_dataloader, trt_mod)
+        log(Level.Info, "[TRT INT8] Test Acc: {:.2f}%".format(100 * int8_test_acc))
+        acc_diff = fp32_test_acc - int8_test_acc
+        self.assertTrue(abs(acc_diff) < 3)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/py/dynamo/ptq/test_export_ptq.py b/tests/py/dynamo/ptq/test_export_ptq.py
new file mode 100644
index 0000000000..24691e48e4
--- /dev/null
+++ b/tests/py/dynamo/ptq/test_export_ptq.py
@@ -0,0 +1,90 @@
+import os
+import unittest
+
+import torch
+import torch.nn as nn
+import torch_tensorrt as torchtrt
+import torchvision
+import torchvision.models as models
+import torchvision.transforms as transforms
+from torch.nn import functional as F
+from torch_tensorrt.logging import *
+from torch_tensorrt.ptq import CalibrationAlgo, DataLoaderCalibrator
+from vgg16 import vgg16
+
+
+def compute_accuracy(testing_dataloader, model):
+    total = 0
+    correct = 0
+    loss = 0.0
+    class_probs = []
+    class_preds = []
+    device = torch.device("cuda:0")
+    with torch.no_grad():
+        idx = 0
+        for data, labels in testing_dataloader:
+            data, labels = data.to(device), labels.to(device)
+            out = model(data)
+            out = out[0] if isinstance(out, tuple) else out
+            preds = torch.max(out, 1)[1]
+            class_probs.append([F.softmax(i, dim=0) for i in out])
+            class_preds.append(preds)
+            total += labels.size(0)
+            correct += (preds == labels).sum().item()
+            idx += 1
+
+    test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
+    test_preds = torch.cat(class_preds)
+    return correct / total
+
+
+class TestAccuracy(unittest.TestCase):
+    def test_compile_script(self):
+        self.model = vgg16(num_classes=10, init_weights=False).eval().cuda()
+        self.testing_dataset = torchvision.datasets.CIFAR10(
+            root="./data",
+            train=False,
+            download=True,
+            transform=transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    transforms.Normalize(
+                        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
+                    ),
+                ]
+            ),
+        )
+
+        self.testing_dataloader = torch.utils.data.DataLoader(
+            self.testing_dataset, batch_size=100, shuffle=False, num_workers=1
+        )
+        self.calibrator = DataLoaderCalibrator(
+            self.testing_dataloader,
+            cache_file="./calibration.cache",
+            use_cache=False,
+            algo_type=CalibrationAlgo.ENTROPY_CALIBRATION_2,
+            device=torch.device("cuda:0"),
+        )
+
+        compile_spec = {
+            "inputs": [torchtrt.Input([100, 3, 32, 32])],
+            "enabled_precisions": {torch.int8},
+            "calibrator": self.calibrator,
+            "truncate_long_and_double": True,
+            "debug": True,
+            "min_block_size": 1,
+            "enable_experimental_decompositions": True,
+        }
+        trt_mod = torchtrt.compile(self.model, **compile_spec)
+
+        fp32_test_acc = compute_accuracy(self.testing_dataloader, self.model)
+        log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
+
+        int8_test_acc = compute_accuracy(self.testing_dataloader, trt_mod)
+        log(Level.Info, "[TRT INT8] Test Acc: {:.2f}%".format(100 * int8_test_acc))
+        acc_diff = fp32_test_acc - int8_test_acc
+        self.assertTrue(abs(acc_diff) < 3)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py
index c5a84f301d..589762087e 100644
--- a/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py
+++ b/tests/py/ts/ptq/test_ptq_dataloader_calibrator.py
@@ -1,13 +1,13 @@
+import os
 import unittest
-import torch_tensorrt as torchtrt
-from torch_tensorrt.logging import *
+
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
+import torch_tensorrt as torchtrt
 import torchvision
 import torchvision.transforms as transforms
-
-import os
+from torch.nn import functional as F
+from torch_tensorrt.logging import *
 
 
 def find_repo_root(max_depth=10):
@@ -70,7 +70,7 @@ def test_compile_script(self):
         )
 
         self.testing_dataloader = torch.utils.data.DataLoader(
-            self.testing_dataset, batch_size=1, shuffle=False, num_workers=1
+            self.testing_dataset, batch_size=100, shuffle=False, num_workers=0
         )
         self.calibrator = torchtrt.ptq.DataLoaderCalibrator(
             self.testing_dataloader,
@@ -81,7 +81,7 @@ def test_compile_script(self):
         )
 
         compile_spec = {
-            "inputs": [torchtrt.Input([1, 3, 32, 32])],
+            "inputs": [torchtrt.Input([100, 3, 32, 32])],
             "enabled_precisions": {torch.float, torch.int8},
             "calibrator": self.calibrator,
             "truncate_long_and_double": True,
@@ -92,7 +92,8 @@ def test_compile_script(self):
                 "allow_gpu_fallback": False,
             },
         }
-        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        with torchtrt.logging.debug():
+            trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
 
         fp32_test_acc = compute_accuracy(self.testing_dataloader, self.model)
         log(Level.Info, "[Pyt FP32] Test Acc: {:.2f}%".format(100 * fp32_test_acc))
diff --git a/tests/py/ts/ptq/test_ptq_to_backend.py b/tests/py/ts/ptq/test_ptq_to_backend.py
index 3a0a5bf336..ff58708772 100644
--- a/tests/py/ts/ptq/test_ptq_to_backend.py
+++ b/tests/py/ts/ptq/test_ptq_to_backend.py
@@ -1,12 +1,13 @@
+import os
 import unittest
-import torch_tensorrt as torchtrt
-from torch_tensorrt.logging import *
+
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
+import torch_tensorrt as torchtrt
 import torchvision
 import torchvision.transforms as transforms
-import os
+from torch.nn import functional as F
+from torch_tensorrt.logging import *
 
 
 def find_repo_root(max_depth=10):
@@ -69,7 +70,7 @@ def test_compile_script(self):
         )
 
         self.testing_dataloader = torch.utils.data.DataLoader(
-            self.testing_dataset, batch_size=1, shuffle=False, num_workers=1
+            self.testing_dataset, batch_size=100, shuffle=False, num_workers=1
         )
         self.calibrator = torchtrt.ptq.DataLoaderCalibrator(
             self.testing_dataloader,
@@ -82,8 +83,8 @@ def test_compile_script(self):
         self.spec = {
             "forward": torchtrt.ts.TensorRTCompileSpec(
                 **{
-                    "inputs": [torchtrt.Input([1, 3, 32, 32])],
-                    "enabled_precisions": {torch.float, torch.half, torch.int8},
+                    "inputs": [torchtrt.Input([100, 3, 32, 32])],
+                    "enabled_precisions": {torch.float, torch.int8},
                     "calibrator": self.calibrator,
                     "truncate_long_and_double": True,
                     "device": {