Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions references/classification/train_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@


try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None


def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)

Expand Down Expand Up @@ -54,10 +56,10 @@ def main(args):

print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
if not args.weights:
if not args.prototype:
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
else:
model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model.to(device)

if not (args.test_only or args.post_training_quantize):
Expand Down Expand Up @@ -264,6 +266,12 @@ def get_args_parser(add_help=True):
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

return parser
Expand Down