From 29f5a7a5ac7803099216220fb0458e18455805dd Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 2 Feb 2022 10:05:54 +0000
Subject: [PATCH] Add --prototype flag to quantization scripts.

---
 .../classification/train_quantization.py      | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py
index efda98170de..c36dd0ac3b9 100644
--- a/references/classification/train_quantization.py
+++ b/references/classification/train_quantization.py
@@ -13,14 +13,16 @@
 
 try:
-    from torchvision.prototype import models as PM
+    from torchvision import prototype
 except ImportError:
-    PM = None
+    prototype = None
 
 
 def main(args):
-    if args.weights and PM is None:
+    if args.prototype and prototype is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+    if not args.prototype and args.weights:
+        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
@@ -54,10 +56,10 @@ def main(args):
 
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    if not args.weights:
+    if not args.prototype:
         model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
     else:
-        model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+        model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
@@ -264,6 +266,12 @@ def get_args_parser(add_help=True):
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
 
     # Prototype models only
+    parser.add_argument(
+        "--prototype",
+        dest="prototype",
+        help="Use prototype model builders instead those from main area",
+        action="store_true",
+    )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
     return parser