diff --git a/references/classification/README.md b/references/classification/README.md
index 04db3837016..da5cd98867d 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -248,7 +248,7 @@ Note that `--val-resize-size` was optimized in a post-training step, see their `
 ### MaxViT
 ```
 torchrun --nproc_per_node=8 --n_nodes=4 train.py\
---model $MODEL --epochs 400 --batch-size 128 --opt adamw --lr 3e-3 --weight-decay 0.05 --lr-scheduler cosineannealinglr --lr-min 1e-5 --lr-warmup-method linear --lr-warmup-epochs 32 --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 1.0 --interpolation bicubic --auto-augment ta_wide --policy-magnitude 15 --train-center-crop --model-ema --val-resize-size 224
+--model $MODEL --epochs 400 --batch-size 128 --opt adamw --lr 3e-3 --weight-decay 0.05 --lr-scheduler cosineannealinglr --lr-min 1e-5 --lr-warmup-method linear --lr-warmup-epochs 32 --label-smoothing 0.1 --mixup-alpha 0.8 --clip-grad-norm 1.0 --interpolation bicubic --auto-augment ta_wide --policy-magnitude 15 --model-ema --val-resize-size 224\
 --val-crop-size 224 --train-crop-size 224 --amp --model-ema-steps 32 --transformer-embedding-decay 0 --sync-bn
 ```
 Here `$MODEL` is `maxvit_t`.
diff --git a/references/classification/presets.py b/references/classification/presets.py
index c6028a3417b..5d1bf1cc714 100644
--- a/references/classification/presets.py
+++ b/references/classification/presets.py
@@ -16,13 +16,8 @@ def __init__(
         ra_magnitude=9,
         augmix_severity=3,
         random_erase_prob=0.0,
-        center_crop=False,
     ):
-        trans = (
-            [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
-            if center_crop
-            else [transforms.CenterCrop(crop_size)]
-        )
+        trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
         if hflip_prob > 0:
             trans.append(transforms.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
diff --git a/references/classification/train.py b/references/classification/train.py
index f359739b113..00af6301831 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -113,11 +113,10 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    val_resize_size, val_crop_size, train_crop_size, center_crop = (
+    val_resize_size, val_crop_size, train_crop_size = (
         args.val_resize_size,
         args.val_crop_size,
         args.train_crop_size,
-        args.train_center_crop,
     )
     interpolation = InterpolationMode(args.interpolation)

@@ -136,7 +135,6 @@ def load_data(traindir, valdir, args):
         dataset = torchvision.datasets.ImageFolder(
             traindir,
             presets.ClassificationPresetTrain(
-                center_crop=center_crop,
                 crop_size=train_crop_size,
                 interpolation=interpolation,
                 auto_augment_policy=auto_augment_policy,
@@ -501,11 +499,6 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
-    parser.add_argument(
-        "--train-center-crop",
-        action="store_true",
-        help="use center crop instead of random crop for training (default: False)",
-    )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
     parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
     parser.add_argument(
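After this patch, `ClassificationPresetTrain` always starts its pipeline with `RandomResizedCrop`; there is no longer a `center_crop` keyword or a `--train-center-crop` flag. The sketch below is not part of the patch: it shows how the preset would be constructed with the MaxViT recipe flags from the README hunk, assuming it is run from `references/classification/` so that `import presets` resolves, and using illustrative argument values only.

```python
# Minimal sketch (not part of the patch): building the training preset once the
# --train-center-crop flag is removed. Values mirror the MaxViT recipe above and
# are illustrative; run from references/classification/ so `import presets` works.
from torchvision.transforms import InterpolationMode

import presets

train_preset = presets.ClassificationPresetTrain(
    crop_size=224,                            # --train-crop-size 224
    interpolation=InterpolationMode.BICUBIC,  # --interpolation bicubic
    auto_augment_policy="ta_wide",            # --auto-augment ta_wide
)

# The composed pipeline now unconditionally begins with RandomResizedCrop(crop_size);
# the former CenterCrop branch controlled by center_crop is gone.
print(train_preset.transforms)
```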