From be122e1f00f19025e1174d1d6f8e40047ebd559f Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 21 May 2019 05:17:27 -0700 Subject: [PATCH] Add pretrained arg to reference scripts Allows for easily evaluating the pre-trained models in the modelzoo --- references/classification/train.py | 8 +++++++- references/detection/train.py | 9 ++++++++- references/segmentation/train.py | 10 +++++++++- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index ddc637dd6ef..8c83dab2bcf 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -144,7 +144,7 @@ def main(args): sampler=test_sampler, num_workers=args.workers, pin_memory=True) print("Creating model") - model = torchvision.models.__dict__[args.model]() + model = torchvision.models.__dict__[args.model](pretrained=args.pretrained) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -242,6 +242,12 @@ def parse_args(): help="Only test the model", action="store_true", ) + parser.add_argument( + "--pretrained", + dest="pretrained", + help="Use pre-trained models from the modelzoo", + action="store_true", + ) # distributed training parameters parser.add_argument('--world-size', default=1, type=int, diff --git a/references/detection/train.py b/references/detection/train.py index c9b6ffa83ed..0a187686dd8 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -76,7 +76,8 @@ def main(args): collate_fn=utils.collate_fn) print("Creating model") - model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes) + model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, + pretrained=args.pretrained) model.to(device) model_without_ddp = model @@ -156,6 +157,12 @@ def main(args): help="Only test the model", action="store_true", ) + parser.add_argument( + "--pretrained", + dest="pretrained", + help="Use pre-trained models from the modelzoo", + action="store_true", + ) # distributed training parameters parser.add_argument('--world-size', default=1, type=int, diff --git a/references/segmentation/train.py b/references/segmentation/train.py index d16fcdf997e..b1173d5323a 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -121,7 +121,9 @@ def main(args): sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn) - model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss) + model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, + aux_loss=args.aux_loss, + pretrained=args.pretrained) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -205,6 +207,12 @@ def parse_args(): help="Only test the model", action="store_true", ) + parser.add_argument( + "--pretrained", + dest="pretrained", + help="Use pre-trained models from the modelzoo", + action="store_true", + ) # distributed training parameters parser.add_argument('--world-size', default=1, type=int, help='number of distributed processes')