From 0d30a2fe74e77f82ab01dcc2748a6aa5d4086e01 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 21 Jan 2022 17:04:20 +0000 Subject: [PATCH 1/4] Adding prototype flag on reference scripts. --- references/classification/train.py | 27 ++++++++++++++------ references/detection/train.py | 25 ++++++++++++------ references/optical_flow/train.py | 32 +++++++++++++++++------- references/segmentation/train.py | 25 ++++++++++++------ references/video_classification/train.py | 28 ++++++++++++++------- 5 files changed, 98 insertions(+), 39 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 507356c8048..439c6dce82a 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -16,9 +16,9 @@ try: - from torchvision.prototype import models as PM + from torchvision.prototype import models as PM, transforms as PT except ImportError: - PM = None + PM = PT = None def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): @@ -154,13 +154,18 @@ def load_data(traindir, valdir, args): print(f"Loading dataset_test from {cache_path}") dataset_test, _ = torch.load(cache_path) else: - if not args.weights: + if not args.prototype: preprocessing = presets.ClassificationPresetEval( crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation ) else: - weights = PM.get_weight(args.weights) - preprocessing = weights.transforms() + if args.weights: + weights = PM.get_weight(args.weights) + preprocessing = weights.transforms() + else: + preprocessing = PT.ImageNetEval( + crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation + ) dataset_test = torchvision.datasets.ImageFolder( valdir, @@ -186,8 +191,10 @@ def load_data(traindir, valdir, args): def main(args): - if args.weights and PM is None: + if args.prototype and PM is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") + if not args.prototype and args.weights: + raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -229,7 +236,7 @@ def main(args): ) print("Creating model") - if not args.weights: + if not args.prototype: model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) else: model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes) @@ -491,6 +498,12 @@ def get_args_parser(add_help=True): ) # Prototype models only + parser.add_argument( + "--prototype", + dest="prototype", + help="Use prototype model builders instead those from main area", + action="store_true", + ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") return parser diff --git a/references/detection/train.py b/references/detection/train.py index 0788895af20..5a2712e9d14 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -34,9 +34,9 @@ try: - from torchvision.prototype import models as PM + from torchvision.prototype import models as PM, transforms as PT except ImportError: - PM = None + PM = PT = None def get_dataset(name, image_set, transform, data_path): @@ -50,11 +50,14 @@ def get_dataset(name, image_set, transform, data_path): def get_transform(train, args): if train: return presets.DetectionPresetTrain(args.data_augmentation) - elif not args.weights: + elif not args.prototype: return presets.DetectionPresetEval() else: - weights = PM.get_weight(args.weights) - return weights.transforms() + if args.weights: + weights = PM.get_weight(args.weights) + return weights.transforms() + else: + return PT.CocoEval() def get_args_parser(add_help=True): @@ -141,6 +144,12 @@ def get_args_parser(add_help=True): parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") # Prototype models only + parser.add_argument( + "--prototype", + dest="prototype", + help="Use prototype model builders instead those from main area", + action="store_true", + ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") # Mixed precision training parameters @@ -150,8 +159,10 @@ def get_args_parser(add_help=True): def main(args): - if args.weights and PM is None: + if args.prototype and PM is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") + if not args.prototype and args.weights: + raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -193,7 +204,7 @@ def main(args): if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = args.rpn_score_thresh - if not args.weights: + if not args.prototype: model = torchvision.models.detection.__dict__[args.model]( pretrained=args.pretrained, num_classes=num_classes, **kwargs ) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 4f63483a688..c3449cab713 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -10,10 +10,9 @@ from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K try: - from torchvision.prototype import models as PM - from torchvision.prototype.models import optical_flow as PMOF + from torchvision.prototype import models as PM, transforms as PT except ImportError: - PM = PMOF = None + PM = PT = None def get_train_dataset(stage, dataset_root): @@ -133,9 +132,12 @@ def inner_loop(blob): def validate(model, args): val_datasets = args.val_dataset or [] - if args.weights: - weights = PM.get_weight(args.weights) - preprocessing = weights.transforms() + if args.prototype: + if args.weights: + weights = PM.get_weight(args.weights) + preprocessing = weights.transforms() + else: + preprocessing = PT.RaftEval() else: preprocessing = OpticalFlowPresetEval() @@ -192,10 +194,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args): def main(args): + if args.prototype and PM is None: + raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") + if not args.prototype and args.weights: + raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") utils.setup_ddp(args) - if args.weights: - model = PMOF.__dict__[args.model](weights=args.weights) + if args.prototype: + model = PM.optical_flow.__dict__[args.model](weights=args.weights) else: model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained) @@ -317,7 +323,6 @@ def get_args_parser(add_help=True): ) # TODO: resume, pretrained, and weights should be in an exclusive arg group parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights") - parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.") parser.add_argument( "--num_flow_updates", @@ -336,6 +341,15 @@ def get_args_parser(add_help=True): required=True, ) + # Prototype models only + parser.add_argument( + "--prototype", + dest="prototype", + help="Use prototype model builders instead those from main area", + action="store_true", + ) + parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.") + return parser diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 72a9bdb01f5..cf33100d610 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -12,9 +12,9 @@ try: - from torchvision.prototype import models as PM + from torchvision.prototype import models as PM, transforms as PT except ImportError: - PM = None + PM = PT = None def get_dataset(dir_path, name, image_set, transform): @@ -35,11 +35,14 @@ def sbd(*args, **kwargs): def get_transform(train, args): if train: return presets.SegmentationPresetTrain(base_size=520, crop_size=480) - elif not args.weights: + elif not args.prototype: return presets.SegmentationPresetEval(base_size=520) else: - weights = PM.get_weight(args.weights) - return weights.transforms() + if args.weights: + weights = PM.get_weight(args.weights) + return weights.transforms() + else: + return PT.VocEval(resize_size=520) def criterion(inputs, target): @@ -97,8 +100,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi def main(args): - if args.weights and PM is None: + if args.prototype and PM is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") + if not args.prototype and args.weights: + raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -130,7 +135,7 @@ def main(args): dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn ) - if not args.weights: + if not args.prototype: model = torchvision.models.segmentation.__dict__[args.model]( pretrained=args.pretrained, num_classes=num_classes, @@ -278,6 +283,12 @@ def get_args_parser(add_help=True): parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") # Prototype models only + parser.add_argument( + "--prototype", + dest="prototype", + help="Use prototype model builders instead those from main area", + action="store_true", + ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") # Mixed precision training parameters diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 0cd88e8022f..38af8e95dd9 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -13,9 +13,9 @@ from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler try: - from torchvision.prototype import models as PM + from torchvision.prototype import models as PM, transforms as PT except ImportError: - PM = None + PM = PT = None def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): @@ -96,9 +96,10 @@ def collate_fn(batch): def main(args): - if args.weights and PM is None: + if args.prototype and PM is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") - + if not args.prototype and args.weights: + raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) @@ -149,11 +150,14 @@ def main(args): print("Loading validation data") cache_path = _get_cache_path(valdir) - if not args.weights: - transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) + if not args.prototype: + transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112)) else: - weights = PM.get_weight(args.weights) - transform_test = weights.transforms() + if args.weights: + weights = PM.get_weight(args.weights) + transform_test = weights.transforms() + else: + transform_test = PT.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_test from {cache_path}") @@ -204,7 +208,7 @@ def main(args): ) print("Creating model") - if not args.weights: + if not args.prototype: model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) else: model = PM.video.__dict__[args.model](weights=args.weights) @@ -360,6 +364,12 @@ def parse_args(): parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") # Prototype models only + parser.add_argument( + "--prototype", + dest="prototype", + help="Use prototype model builders instead those from main area", + action="store_true", + ) parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") # Mixed precision training parameters From b0fc1d3128db7d58dd2064f7111eac655d6c17ae Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 21 Jan 2022 17:39:01 +0000 Subject: [PATCH 2/4] Import prototype instead of models/transforms. --- references/classification/train.py | 12 ++++++------ references/detection/train.py | 12 ++++++------ references/optical_flow/train.py | 12 ++++++------ references/segmentation/train.py | 12 ++++++------ references/video_classification/train.py | 12 ++++++------ 5 files changed, 30 insertions(+), 30 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 439c6dce82a..9d5e43a992a 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -16,9 +16,9 @@ try: - from torchvision.prototype import models as PM, transforms as PT + from torchvision import prototype except ImportError: - PM = PT = None + prototype = None def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None): @@ -160,10 +160,10 @@ def load_data(traindir, valdir, args): ) else: if args.weights: - weights = PM.get_weight(args.weights) + weights = prototype.models.get_weight(args.weights) preprocessing = weights.transforms() else: - preprocessing = PT.ImageNetEval( + preprocessing = prototype.transforms.ImageNetEval( crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation ) @@ -191,7 +191,7 @@ def load_data(traindir, valdir, args): def main(args): - if args.prototype and PM is None: + if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") @@ -239,7 +239,7 @@ def main(args): if not args.prototype: model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) else: - model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes) + model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes) model.to(device) if args.distributed and args.sync_bn: diff --git a/references/detection/train.py b/references/detection/train.py index 5a2712e9d14..a9095cb1877 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -34,9 +34,9 @@ try: - from torchvision.prototype import models as PM, transforms as PT + from torchvision import prototype except ImportError: - PM = PT = None + prototype = None def get_dataset(name, image_set, transform, data_path): @@ -54,10 +54,10 @@ def get_transform(train, args): return presets.DetectionPresetEval() else: if args.weights: - weights = PM.get_weight(args.weights) + weights = prototype.models.get_weight(args.weights) return weights.transforms() else: - return PT.CocoEval() + return prototype.transforms.CocoEval() def get_args_parser(add_help=True): @@ -159,7 +159,7 @@ def get_args_parser(add_help=True): def main(args): - if args.prototype and PM is None: + if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") @@ -209,7 +209,7 @@ def main(args): pretrained=args.pretrained, num_classes=num_classes, **kwargs ) else: - model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) + model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index c3449cab713..1a1d7b10ff9 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -10,9 +10,9 @@ from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K try: - from torchvision.prototype import models as PM, transforms as PT + from torchvision import prototype except ImportError: - PM = PT = None + prototype = None def get_train_dataset(stage, dataset_root): @@ -134,10 +134,10 @@ def validate(model, args): if args.prototype: if args.weights: - weights = PM.get_weight(args.weights) + weights = prototype.models.get_weight(args.weights) preprocessing = weights.transforms() else: - preprocessing = PT.RaftEval() + preprocessing = prototype.transforms.RaftEval() else: preprocessing = OpticalFlowPresetEval() @@ -194,14 +194,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args): def main(args): - if args.prototype and PM is None: + if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") utils.setup_ddp(args) if args.prototype: - model = PM.optical_flow.__dict__[args.model](weights=args.weights) + model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights) else: model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained) diff --git a/references/segmentation/train.py b/references/segmentation/train.py index cf33100d610..fb3ddf98de3 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -12,9 +12,9 @@ try: - from torchvision.prototype import models as PM, transforms as PT + from torchvision import prototype except ImportError: - PM = PT = None + prototype = None def get_dataset(dir_path, name, image_set, transform): @@ -39,10 +39,10 @@ def get_transform(train, args): return presets.SegmentationPresetEval(base_size=520) else: if args.weights: - weights = PM.get_weight(args.weights) + weights = prototype.models.get_weight(args.weights) return weights.transforms() else: - return PT.VocEval(resize_size=520) + return prototype.transforms.VocEval(resize_size=520) def criterion(inputs, target): @@ -100,7 +100,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi def main(args): - if args.prototype and PM is None: + if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") @@ -142,7 +142,7 @@ def main(args): aux_loss=args.aux_loss, ) else: - model = PM.segmentation.__dict__[args.model]( + model = prototype.models.segmentation.__dict__[args.model]( weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss ) model.to(device) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 38af8e95dd9..879e16db502 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -13,9 +13,9 @@ from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler try: - from torchvision.prototype import models as PM, transforms as PT + from torchvision import prototype except ImportError: - PM = PT = None + prototype = None def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None): @@ -96,7 +96,7 @@ def collate_fn(batch): def main(args): - if args.prototype and PM is None: + if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") @@ -154,10 +154,10 @@ def main(args): transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112)) else: if args.weights: - weights = PM.get_weight(args.weights) + weights = prototype.models.get_weight(args.weights) transform_test = weights.transforms() else: - transform_test = PT.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171)) + transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171)) if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_test from {cache_path}") @@ -211,7 +211,7 @@ def main(args): if not args.prototype: model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) else: - model = PM.video.__dict__[args.model](weights=args.weights) + model = prototype.models.video.__dict__[args.model](weights=args.weights) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) From d2b78858a79011309fd2ccc4a18007da2bcb1301 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 21 Jan 2022 17:58:54 +0000 Subject: [PATCH 3/4] Correcting exception type. --- references/classification/train.py | 2 +- references/detection/train.py | 2 +- references/optical_flow/train.py | 2 +- references/segmentation/train.py | 2 +- references/video_classification/train.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 9d5e43a992a..73fb4ea1f1d 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -194,7 +194,7 @@ def main(args): if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: - raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") + raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) diff --git a/references/detection/train.py b/references/detection/train.py index a9095cb1877..f466118ad19 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -162,7 +162,7 @@ def main(args): if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: - raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") + raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index 1a1d7b10ff9..d21fe9c2504 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -197,7 +197,7 @@ def main(args): if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: - raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") + raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") utils.setup_ddp(args) if args.prototype: diff --git a/references/segmentation/train.py b/references/segmentation/train.py index fb3ddf98de3..83b861b5664 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -103,7 +103,7 @@ def main(args): if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: - raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") + raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 879e16db502..db93a7860bb 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -99,7 +99,7 @@ def main(args): if args.prototype and prototype.models is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: - raise ImportError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") + raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") if args.output_dir: utils.mkdir(args.output_dir) From e8d7a827b706a6eabcae1660d14245f53d4a5b23 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 21 Jan 2022 18:05:31 +0000 Subject: [PATCH 4/4] fixing none referencing --- references/classification/train.py | 2 +- references/detection/train.py | 2 +- references/optical_flow/train.py | 2 +- references/segmentation/train.py | 2 +- references/video_classification/train.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 73fb4ea1f1d..c87471ba657 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -191,7 +191,7 @@ def load_data(traindir, valdir, args): def main(args): - if args.prototype and prototype.models is None: + if args.prototype and prototype is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") diff --git a/references/detection/train.py b/references/detection/train.py index f466118ad19..765f8144364 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -159,7 +159,7 @@ def get_args_parser(add_help=True): def main(args): - if args.prototype and prototype.models is None: + if args.prototype and prototype is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") diff --git a/references/optical_flow/train.py b/references/optical_flow/train.py index d21fe9c2504..fdb9e4d8d7a 100644 --- a/references/optical_flow/train.py +++ b/references/optical_flow/train.py @@ -194,7 +194,7 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args): def main(args): - if args.prototype and prototype.models is None: + if args.prototype and prototype is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 83b861b5664..436dd491dca 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -100,7 +100,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi def main(args): - if args.prototype and prototype.models is None: + if args.prototype and prototype is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") diff --git a/references/video_classification/train.py b/references/video_classification/train.py index db93a7860bb..df8687ff6c2 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -96,7 +96,7 @@ def collate_fn(batch): def main(args): - if args.prototype and prototype.models is None: + if args.prototype and prototype is None: raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") if not args.prototype and args.weights: raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")