Skip to content

Commit

Permalink
Add per model crop pct, interpolation defaults, tie it all together
Browse files Browse the repository at this point in the history
* create one resolve fn to pull together model defaults + cmd line args
* update attribution comments in some models
* test update train/validation/inference scripts
  • Loading branch information
rwightman committed Apr 13, 2019
1 parent c328b15 commit 0562b91
Show file tree
Hide file tree
Showing 15 changed files with 173 additions and 69 deletions.
17 changes: 12 additions & 5 deletions data/loader.py
Expand Up @@ -23,20 +23,20 @@ def __init__(self,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD):
self.loader = loader
self.stream = torch.cuda.Stream()
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
if rand_erase_prob:
if rand_erase_prob > 0.:
self.random_erasing = RandomErasingTorch(
probability=rand_erase_prob, per_pixel=rand_erase_pp)
else:
self.random_erasing = None

def __iter__(self):
stream = torch.cuda.Stream()
first = True

for next_input, next_target in self.loader:
with torch.cuda.stream(self.stream):
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
next_input = next_input.float().sub_(self.mean).div_(self.std)
Expand All @@ -48,7 +48,7 @@ def __iter__(self):
else:
first = False

torch.cuda.current_stream().wait_stream(self.stream)
torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target

Expand All @@ -64,28 +64,35 @@ def sampler(self):

def create_loader(
dataset,
img_size,
input_size,
batch_size,
is_training=False,
use_prefetcher=True,
rand_erase_prob=0.,
rand_erase_pp=False,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
distributed=False,
crop_pct=None,
):
if isinstance(input_size, tuple):
img_size = input_size[-2:]
else:
img_size = input_size

if is_training:
transform = transforms_imagenet_train(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
else:
transform = transforms_imagenet_eval(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
Expand Down
96 changes: 73 additions & 23 deletions data/transforms.py
Expand Up @@ -15,28 +15,66 @@
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)


def get_mean_and_std(model, args, num_chan=3):
if hasattr(model, 'default_cfg'):
mean = model.default_cfg['mean']
std = model.default_cfg['std']
else:
if args.mean is not None:
mean = tuple(args.mean)
if len(mean) == 1:
mean = tuple(list(mean) * num_chan)
else:
assert len(mean) == num_chan
def resolve_data_config(model, args, default_cfg={}, verbose=True):
new_config = {}
default_cfg = default_cfg
if not default_cfg and hasattr(model, 'default_cfg'):
default_cfg = model.default_cfg

# Resolve input/image size
# FIXME grayscale/chans arg to use different # channels?
in_chans = 3
input_size = (in_chans, 224, 224)
if args.img_size is not None:
# FIXME support passing img_size as tuple, non-square
assert isinstance(args.img_size, int)
input_size = (in_chans, args.img_size, args.img_size)
elif 'input_size' in default_cfg:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size

# resolve interpolation method
new_config['interpolation'] = 'bilinear'
if args.interpolation:
new_config['interpolation'] = args.interpolation
elif 'interpolation' in default_cfg:
new_config['interpolation'] = default_cfg['interpolation']

# resolve dataset + model mean for normalization
new_config['mean'] = get_mean_by_model(args.model)
if args.mean is not None:
mean = tuple(args.mean)
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
mean = get_mean_by_model(args.model)
if args.std is not None:
std = tuple(args.std)
if len(std) == 1:
std = tuple(list(std) * num_chan)
else:
assert len(std) == num_chan
assert len(mean) == in_chans
new_config['mean'] = mean
elif 'mean' in default_cfg:
new_config['mean'] = default_cfg['mean']

# resolve dataset + model std deviation for normalization
new_config['std'] = get_std_by_model(args.model)
if args.std is not None:
std = tuple(args.std)
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
std = get_std_by_model(args.model)
return mean, std
assert len(std) == in_chans
new_config['std'] = std
else:
new_config['std'] = default_cfg['std']

# resolve default crop percentage
new_config['crop_pct'] = DEFAULT_CROP_PCT
if 'crop_pct' in default_cfg:
new_config['crop_pct'] = default_cfg['crop_pct']

if verbose:
print('Data processing configuration for current model + dataset:')
for n, v in new_config.items():
print('\t%s: %s' % (n, str(v)))

return new_config


def get_mean_by_name(name):
Expand Down Expand Up @@ -104,6 +142,7 @@ def transforms_imagenet_train(
img_size=224,
scale=(0.1, 1.0),
color_jitter=(0.4, 0.4, 0.4),
interpolation='bilinear',
random_erasing=0.4,
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
Expand All @@ -112,7 +151,8 @@ def transforms_imagenet_train(

tfl = [
transforms.RandomResizedCrop(
img_size, scale=scale, interpolation=Image.BICUBIC),
img_size, scale=scale,
interpolation=Image.BILINEAR if interpolation == 'bilinear' else Image.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(*color_jitter),
]
Expand All @@ -135,14 +175,24 @@ def transforms_imagenet_train(
def transforms_imagenet_eval(
img_size=224,
crop_pct=None,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD):
crop_pct = crop_pct or DEFAULT_CROP_PCT
scale_size = int(math.floor(img_size / crop_pct))

if isinstance(img_size, tuple):
assert len(img_size) == 2
if img_size[0] == img_size[1]:
# fall-back to older behaviour so Resize scales to shortest edge if target is square
scale_size = int(math.floor(img_size[0] / crop_pct))
else:
scale_size = tuple([int(x[0] / crop_pct) for x in img_size])
else:
scale_size = int(math.floor(img_size / crop_pct))

tfl = [
transforms.Resize(scale_size, Image.BICUBIC),
transforms.Resize(scale_size, Image.BILINEAR if interpolation == 'bilinear' else Image.BICUBIC),
transforms.CenterCrop(img_size),
]
if use_prefetcher:
Expand Down
26 changes: 17 additions & 9 deletions inference.py
Expand Up @@ -12,7 +12,7 @@
import torch

from models import create_model, apply_test_time_pool
from data import Dataset, create_loader, get_mean_and_std
from data import Dataset, create_loader, resolve_data_config
from utils import AverageMeter

torch.backends.cudnn.benchmark = True
Expand All @@ -30,6 +30,12 @@
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=224, type=int,
metavar='N', help='Input image dimension')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int,
Expand All @@ -40,8 +46,8 @@
help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
help='use pre-trained model')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool')


def main():
Expand All @@ -58,8 +64,8 @@ def main():
print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))

data_mean, data_std = get_mean_and_std(model, args)
model, test_time_pool = apply_test_time_pool(model, args)
config = resolve_data_config(model, args)
model, test_time_pool = apply_test_time_pool(model, config, args)

if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
Expand All @@ -68,12 +74,14 @@ def main():

loader = create_loader(
Dataset(args.data),
img_size=args.img_size,
input_size=config['input_size'],
batch_size=args.batch_size,
use_prefetcher=True,
mean=data_mean,
std=data_std,
num_workers=args.workers)
interpolation=config['interpolation'],
mean=config['mean'],
std=config['std'],
num_workers=args.workers,
crop_pct=1.0 if test_time_pool else config['crop_pct'])

model.eval()

Expand Down
3 changes: 2 additions & 1 deletion models/densenet.py
@@ -1,4 +1,4 @@
"""Pytorch Densenet implementation tweaks
"""Pytorch Densenet implementation w/ tweaks
This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
fixed kwargs passthrough and addition of dynamic global avg/max pool.
"""
Expand All @@ -18,6 +18,7 @@
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 244), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'features.conv0', 'classifier': 'classifier',
}
Expand Down
1 change: 1 addition & 0 deletions models/dpn.py
Expand Up @@ -25,6 +25,7 @@
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
}
Expand Down
2 changes: 1 addition & 1 deletion models/helpers.py
Expand Up @@ -26,7 +26,6 @@ def load_checkpoint(model, checkpoint_path):


def resume_checkpoint(model, checkpoint_path, start_epoch=None):
start_epoch = 0 if start_epoch is None else start_epoch
optimizer_state = None
if os.path.isfile(checkpoint_path):
print("=> loading checkpoint '{}'".format(checkpoint_path))
Expand All @@ -46,6 +45,7 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
else:
model.load_state_dict(checkpoint)
start_epoch = 0 if start_epoch is None else start_epoch
return optimizer_state, start_epoch
else:
print("=> No checkpoint found at '{}'".format(checkpoint_path))
Expand Down
1 change: 1 addition & 0 deletions models/inception_resnet_v2.py
Expand Up @@ -14,6 +14,7 @@
'inception_resnet_v2': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.8975, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear',
}
Expand Down
1 change: 1 addition & 0 deletions models/inception_v4.py
Expand Up @@ -14,6 +14,7 @@
'inception_v4': {
'url': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth',
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'features.0.conv', 'classifier': 'classif',
}
Expand Down
10 changes: 9 additions & 1 deletion models/pnasnet.py
@@ -1,3 +1,10 @@
"""
pnasnet5large implementation grabbed from Cadene's pretrained models
Additional credit to https://github.com/creafz
https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py
"""
from __future__ import print_function, division, absolute_import
from collections import OrderedDict

Expand All @@ -13,9 +20,10 @@
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
'input_size': (3, 331, 331),
'pool_size': (11, 11),
'crop_pct': 0.875,
'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5),
'std': (0.5, 0.5, 0.5),
'crop_pct': 0.8975,
'num_classes': 1001,
'first_conv': 'conv_0.conv',
'classifier': 'last_linear',
Expand Down
9 changes: 6 additions & 3 deletions models/resnet.py
@@ -1,6 +1,8 @@
"""Pytorch ResNet implementation tweaks
"""Pytorch ResNet implementation w/ tweaks
This file is a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
additional dropout and dynamic global avg/max pool.
ResNext additions added by Ross Wightman
"""
import torch
import torch.nn as nn
Expand All @@ -18,7 +20,8 @@ def _cfg(url=''):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv1', 'classifier': 'fc',
}

Expand Down Expand Up @@ -271,7 +274,7 @@ def resnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
def resnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model.
"""
default_cfg = default_cfgs['resnext50_32x4d2']
default_cfg = default_cfgs['resnext50_32x4d']
model = ResNet(
Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
num_classes=num_classes, in_chans=in_chans, **kwargs)
Expand Down
9 changes: 8 additions & 1 deletion models/senet.py
@@ -1,4 +1,10 @@
"""
SEResNet implementation from Cadene's pretrained models
https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py
Additional credit to https://github.com/creafz
Original model: https://github.com/hujie-frank/SENet
ResNet code gently borrowed from
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""
Expand All @@ -20,7 +26,8 @@
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 244), 'pool_size': (7, 7),
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'layer0.conv1', 'classifier': 'last_linear',
}

Expand Down

0 comments on commit 0562b91

Please sign in to comment.