Permalink
Browse files

Move ImageNet models together

  • Loading branch information...
ppwwyyxx committed Mar 8, 2018
1 parent e9a4df1 commit 0ba89131d6a21da62f8535d760a9fb1b81eadf98
@@ -0,0 +1,35 @@
ImageNet training code of ResNet, Inception, VGG, ShuffleNet, DoReFa-Net with tensorpack.
To train any of the models, just do `./{model}.py --data /path/to/ilsvrc`.
Expected format of data directory is described in [docs](http://tensorpack.readthedocs.io/en/latest/modules/dataflow.dataset.html#tensorpack.dataflow.dataset.ILSVRC12).
Pretrained models can be downloaded at [tensorpack model zoo](http://models.tensorpack.com/).
### ShuffleNet
Reproduce [ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices](https://arxiv.org/abs/1707.01083)
on ImageNet.
This is a 38Mflops ShuffleNet, corresponding to `ShuffleNet 0.5x g=3` in [version 2](https://arxiv.org/pdf/1707.01083v2) of the paper.
After 240 epochs (36 hours on 8 P100s) it reaches top-1 error of 42.32%, better than the paper's number.
To print flops:
```bash
./shufflenet.py --flops
```
It will print about 75Mflops, because the paper counts multiply+add as 1 flop.
Evaluate the [pretrained model](http://models.tensorpack.com/ShuffleNet/):
```
./shufflenet.py --eval --data /path/to/ilsvrc --load /path/to/model
```
### Inception-BN, VGG16
This Inception-BN script reaches 27% single-crop error after 300k steps with 6 GPUs.
This VGG16 script reaches 28.8% single-crop error after 100 epochs.
### ResNet, DoReFa-Net
See [ResNet examples](../ResNet) and [DoReFa-Net examples](../DoReFa-Net).
@@ -0,0 +1,239 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: imagenet_utils.py
import cv2
import numpy as np
import multiprocessing
import tensorflow as tf
from abc import abstractmethod
from tensorpack import imgaug, dataset, ModelDesc, InputDesc
from tensorpack.dataflow import (
AugmentImageComponent, PrefetchDataZMQ,
BatchData, MultiThreadMapData)
from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
from tensorpack.utils.stats import RatioCounter
from tensorpack.models import regularize_cost
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
class GoogleNetResize(imgaug.ImageAugmentor):
"""
crop 8%~100% of the original image
See `Going Deeper with Convolutions` by Google.
"""
def __init__(self, crop_area_fraction=0.08,
aspect_ratio_low=0.75, aspect_ratio_high=1.333,
target_shape=224):
self._init(locals())
def _augment(self, img, _):
h, w = img.shape[:2]
area = h * w
for _ in range(10):
targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area
aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high)
ww = int(np.sqrt(targetArea * aspectR) + 0.5)
hh = int(np.sqrt(targetArea / aspectR) + 0.5)
if self.rng.uniform() < 0.5:
ww, hh = hh, ww
if hh <= h and ww <= w:
x1 = 0 if w == ww else self.rng.randint(0, w - ww)
y1 = 0 if h == hh else self.rng.randint(0, h - hh)
out = img[y1:y1 + hh, x1:x1 + ww]
out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC)
return out
out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img)
out = imgaug.CenterCrop(self.target_shape).augment(out)
return out
def fbresnet_augmentor(isTrain):
"""
Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
"""
if isTrain:
augmentors = [
GoogleNetResize(),
imgaug.RandomOrderAug( # Remove these augs if your CPU is not fast enough
[imgaug.BrightnessScale((0.6, 1.4), clip=False),
imgaug.Contrast((0.6, 1.4), clip=False),
imgaug.Saturation(0.4, rgb=False),
# rgb-bgr conversion for the constants copied from fb.resnet.torch
imgaug.Lighting(0.1,
eigval=np.asarray(
[0.2175, 0.0188, 0.0045][::-1]) * 255.0,
eigvec=np.array(
[[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203]],
dtype='float32')[::-1, ::-1]
)]),
imgaug.Flip(horiz=True),
]
else:
augmentors = [
imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
imgaug.CenterCrop((224, 224)),
]
return augmentors
def get_imagenet_dataflow(
datadir, name, batch_size,
augmentors, parallel=None):
"""
See explanations in the tutorial:
http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
"""
assert name in ['train', 'val', 'test']
assert datadir is not None
assert isinstance(augmentors, list)
isTrain = name == 'train'
if parallel is None:
parallel = min(40, multiprocessing.cpu_count() // 2) # assuming hyperthreading
if isTrain:
ds = dataset.ILSVRC12(datadir, name, shuffle=True)
ds = AugmentImageComponent(ds, augmentors, copy=False)
if parallel < 16:
logger.warn("DataFlow may become the bottleneck when too few processes are used.")
ds = PrefetchDataZMQ(ds, parallel)
ds = BatchData(ds, batch_size, remainder=False)
else:
ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
aug = imgaug.AugmentorList(augmentors)
def mapf(dp):
fname, cls = dp
im = cv2.imread(fname, cv2.IMREAD_COLOR)
im = aug.augment(im)
return im, cls
ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
ds = BatchData(ds, batch_size, remainder=True)
ds = PrefetchDataZMQ(ds, 1)
return ds
def eval_on_ILSVRC12(model, sessinit, dataflow):
pred_config = PredictConfig(
model=model,
session_init=sessinit,
input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5']
)
pred = SimpleDatasetPredictor(pred_config, dataflow)
acc1, acc5 = RatioCounter(), RatioCounter()
for top1, top5 in pred.get_result():
batch_size = top1.shape[0]
acc1.feed(top1.sum(), batch_size)
acc5.feed(top5.sum(), batch_size)
print("Top1 Error: {}".format(acc1.ratio))
print("Top5 Error: {}".format(acc5.ratio))
class ImageNetModel(ModelDesc):
weight_decay = 1e-4
image_shape = 224
"""
uint8 instead of float32 is used as input type to reduce copy overhead.
It might hurt the performance a liiiitle bit.
The pretrained models were trained with float32.
"""
image_dtype = tf.uint8
def __init__(self, data_format='NCHW'):
self.data_format = data_format
def _get_inputs(self):
return [InputDesc(self.image_dtype, [None, self.image_shape, self.image_shape, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
def _build_graph(self, inputs):
image, label = inputs
image = ImageNetModel.image_preprocess(image, bgr=True)
if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2])
logits = self.get_logits(image)
loss = ImageNetModel.compute_loss_and_error(logits, label)
if self.weight_decay > 0:
wd_loss = regularize_cost('.*/W', tf.contrib.layers.l2_regularizer(self.weight_decay),
name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
else:
self.cost = tf.identity(loss, name='cost')
add_moving_summary(self.cost)
@abstractmethod
def get_logits(self, image):
"""
Args:
image: 4D tensor of 224x224 in ``self.data_format``
Returns:
Nx1000 logits
"""
def _get_optimizer(self):
lr = tf.get_variable('learning_rate', initializer=0.1, trainable=False)
tf.summary.scalar('learning_rate-summary', lr)
return tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
@staticmethod
def image_preprocess(image, bgr=True):
with tf.name_scope('image_preprocess'):
if image.dtype.base_dtype != tf.float32:
image = tf.cast(image, tf.float32)
image = image * (1.0 / 255)
mean = [0.485, 0.456, 0.406] # rgb
std = [0.229, 0.224, 0.225]
if bgr:
mean = mean[::-1]
std = std[::-1]
image_mean = tf.constant(mean, dtype=tf.float32)
image_std = tf.constant(std, dtype=tf.float32)
image = (image - image_mean) / image_std
return image
@staticmethod
def compute_loss_and_error(logits, label):
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
loss = tf.reduce_mean(loss, name='xentropy-loss')
def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
with tf.name_scope('prediction_incorrect'):
x = tf.logical_not(tf.nn.in_top_k(logits, label, topk))
return tf.cast(x, tf.float32, name=name)
wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))
wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))
return loss
if __name__ == '__main__':
import argparse
from tensorpack.dataflow import TestDataSpeed
parser = argparse.ArgumentParser()
parser.add_argument('--data', required=True)
parser.add_argument('--batch', type=int, default=32)
parser.add_argument('--aug', choices=['train', 'val'], default='val')
args = parser.parse_args()
if args.aug == 'val':
augs = fbresnet_augmentor(False)
elif args.aug == 'train':
augs = fbresnet_augmentor(True)
df = get_imagenet_dataflow(
args.data, 'train', args.batch, augs)
# For val augmentor, Should get >100 it/s (i.e. 3k im/s) here on a decent E5 server.
TestDataSpeed(df).start()
@@ -22,14 +22,6 @@
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
INPUT_SHAPE = 224
"""
Inception-BN model on ILSVRC12.
See "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", arxiv:1502.03167
This config reaches 73% single-crop validation accuracy after 300k steps with 6 GPUs.
"""
class Model(ModelDesc):
def _get_inputs(self):
return [InputDesc(tf.float32, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
Oops, something went wrong.

0 comments on commit 0ba8913

Please sign in to comment.