Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@


try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
Expand Down Expand Up @@ -154,13 +154,18 @@ def load_data(traindir, valdir, args):
print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path)
else:
if not args.weights:
if not args.prototype:
preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()
if args.weights:
weights = prototype.models.get_weight(args.weights)
preprocessing = weights.transforms()
else:
preprocessing = prototype.transforms.ImageNetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)

dataset_test = torchvision.datasets.ImageFolder(
valdir,
Expand All @@ -186,8 +191,10 @@ def load_data(traindir, valdir, args):


def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)

Expand Down Expand Up @@ -229,10 +236,10 @@ def main(args):
)

print("Creating model")
if not args.weights:
if not args.prototype:
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
else:
model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model.to(device)

if args.distributed and args.sync_bn:
Expand Down Expand Up @@ -491,6 +498,12 @@ def get_args_parser(add_help=True):
)

# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

return parser
Expand Down
27 changes: 19 additions & 8 deletions references/detection/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@


try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None


def get_dataset(name, image_set, transform, data_path):
Expand All @@ -50,11 +50,14 @@ def get_dataset(name, image_set, transform, data_path):
def get_transform(train, args):
if train:
return presets.DetectionPresetTrain(args.data_augmentation)
elif not args.weights:
elif not args.prototype:
return presets.DetectionPresetEval()
else:
weights = PM.get_weight(args.weights)
return weights.transforms()
if args.weights:
weights = prototype.models.get_weight(args.weights)
return weights.transforms()
else:
return prototype.transforms.CocoEval()


def get_args_parser(add_help=True):
Expand Down Expand Up @@ -141,6 +144,12 @@ def get_args_parser(add_help=True):
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

# Mixed precision training parameters
Expand All @@ -150,8 +159,10 @@ def get_args_parser(add_help=True):


def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)

Expand Down Expand Up @@ -193,12 +204,12 @@ def main(args):
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
if not args.weights:
if not args.prototype:
model = torchvision.models.detection.__dict__[args.model](
pretrained=args.pretrained, num_classes=num_classes, **kwargs
)
else:
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
Expand Down
32 changes: 23 additions & 9 deletions references/optical_flow/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K

try:
from torchvision.prototype import models as PM
from torchvision.prototype.models import optical_flow as PMOF
from torchvision import prototype
except ImportError:
PM = PMOF = None
prototype = None


def get_train_dataset(stage, dataset_root):
Expand Down Expand Up @@ -133,9 +132,12 @@ def inner_loop(blob):
def validate(model, args):
val_datasets = args.val_dataset or []

if args.weights:
weights = PM.get_weight(args.weights)
preprocessing = weights.transforms()
if args.prototype:
if args.weights:
weights = prototype.models.get_weight(args.weights)
preprocessing = weights.transforms()
else:
preprocessing = prototype.transforms.RaftEval()
else:
preprocessing = OpticalFlowPresetEval()

Expand Down Expand Up @@ -192,10 +194,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):


def main(args):
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
utils.setup_ddp(args)

if args.weights:
model = PMOF.__dict__[args.model](weights=args.weights)
if args.prototype:
model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
else:
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)

Expand Down Expand Up @@ -317,7 +323,6 @@ def get_args_parser(add_help=True):
)
# TODO: resume, pretrained, and weights should be in an exclusive arg group
parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")

parser.add_argument(
"--num_flow_updates",
Expand All @@ -336,6 +341,15 @@ def get_args_parser(add_help=True):
required=True,
)

# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")

return parser


Expand Down
27 changes: 19 additions & 8 deletions references/segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@


try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None


def get_dataset(dir_path, name, image_set, transform):
Expand All @@ -35,11 +35,14 @@ def sbd(*args, **kwargs):
def get_transform(train, args):
if train:
return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
elif not args.weights:
elif not args.prototype:
return presets.SegmentationPresetEval(base_size=520)
else:
weights = PM.get_weight(args.weights)
return weights.transforms()
if args.weights:
weights = prototype.models.get_weight(args.weights)
return weights.transforms()
else:
return prototype.transforms.VocEval(resize_size=520)


def criterion(inputs, target):
Expand Down Expand Up @@ -97,8 +100,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi


def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)

Expand Down Expand Up @@ -130,14 +135,14 @@ def main(args):
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

if not args.weights:
if not args.prototype:
model = torchvision.models.segmentation.__dict__[args.model](
pretrained=args.pretrained,
num_classes=num_classes,
aux_loss=args.aux_loss,
)
else:
model = PM.segmentation.__dict__[args.model](
model = prototype.models.segmentation.__dict__[args.model](
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
)
model.to(device)
Expand Down Expand Up @@ -278,6 +283,12 @@ def get_args_parser(add_help=True):
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

# Mixed precision training parameters
Expand Down
30 changes: 20 additions & 10 deletions references/video_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler

try:
from torchvision.prototype import models as PM
from torchvision import prototype
except ImportError:
PM = None
prototype = None


def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
Expand Down Expand Up @@ -96,9 +96,10 @@ def collate_fn(batch):


def main(args):
if args.weights and PM is None:
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")

if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)

Expand Down Expand Up @@ -149,11 +150,14 @@ def main(args):
print("Loading validation data")
cache_path = _get_cache_path(valdir)

if not args.weights:
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
if not args.prototype:
transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112))
else:
weights = PM.get_weight(args.weights)
transform_test = weights.transforms()
if args.weights:
weights = prototype.models.get_weight(args.weights)
transform_test = weights.transforms()
else:
transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171))

if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_test from {cache_path}")
Expand Down Expand Up @@ -204,10 +208,10 @@ def main(args):
)

print("Creating model")
if not args.weights:
if not args.prototype:
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
else:
model = PM.video.__dict__[args.model](weights=args.weights)
model = prototype.models.video.__dict__[args.model](weights=args.weights)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
Expand Down Expand Up @@ -360,6 +364,12 @@ def parse_args():
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")

# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

# Mixed precision training parameters
Expand Down