Skip to content
Permalink
Browse files

Rename eval_on_ILSVRC12->eval_classification; fix #1194

  • Loading branch information...
ppwwyyxx committed May 17, 2019
1 parent 1e9342a commit 0831fe9df2a44a9da9f4c120487fd05495aacc67
@@ -18,7 +18,7 @@
from tensorpack.utils.gpu import get_num_gpu

from dorefa import get_dorefa, ternarize
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor, get_imagenet_dataflow
from imagenet_utils import ImageNetModel, eval_classification, fbresnet_augmentor, get_imagenet_dataflow

"""
This is a tensorpack script for the ImageNet results in paper:
@@ -219,7 +219,7 @@ def run_image(model, sess_init, inputs):
if args.eval:
BATCH_SIZE = 128
ds = get_data('val')
eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds)
eval_classification(Model(), get_model_loader(args.load), ds)
sys.exit()

nr_tower = max(get_num_gpu(), 1)
@@ -13,7 +13,7 @@
from tensorpack.tfutils.varreplace import remap_variables

from dorefa import get_dorefa
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, fbresnet_augmentor
from imagenet_utils import ImageNetModel, eval_classification, fbresnet_augmentor

"""
This script loads the pre-trained ResNet-18 model with (W,A,G) = (1,4,32)
@@ -163,7 +163,7 @@ def run_image(model, sess_init, inputs):
ds = dataset.ILSVRC12(args.data, 'val', shuffle=False)
ds = AugmentImageComponent(ds, get_inference_augmentor())
ds = BatchData(ds, 192, remainder=True)
eval_on_ILSVRC12(Model(), get_model_loader(args.load), ds)
eval_classification(Model(), get_model_loader(args.load), ds)
elif args.run:
assert args.load.endswith('.npz')
run_image(Model(), DictRestore(dict(np.load(args.load))), args.run)
@@ -264,7 +264,11 @@ def bad():
"""


def eval_on_ILSVRC12(model, sessinit, dataflow):
def eval_classification(model, sessinit, dataflow):
"""
Eval a classification model on the dataset. It assumes the model inputs are
named "input" and "label", and contains "wrong-top1" and "wrong-top5" in the graph.
"""
pred_config = PredictConfig(
model=model,
session_init=sessinit,
@@ -16,7 +16,7 @@
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu

from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
from imagenet_utils import ImageNetModel, eval_classification, get_imagenet_dataflow


@layer_register(log_shape=True)
@@ -251,7 +251,7 @@ def get_config(model, nr_tower):
if args.eval:
batch = 128 # something that can run on one gpu
ds = get_data('val', batch)
eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
eval_classification(model, get_model_loader(args.load), ds)
elif args.flops:
# manually build the graph with batch=1
with TowerContext('', is_training=False):
@@ -13,7 +13,7 @@
from tensorpack.train import SyncMultiGPUTrainerReplicated, TrainConfig, launch_train_with_config
from tensorpack.utils.gpu import get_num_gpu

from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow, get_imagenet_tfdata
from imagenet_utils import ImageNetModel, eval_classification, get_imagenet_dataflow, get_imagenet_tfdata
from resnet_model import (
preresnet_basicblock, preresnet_bottleneck, preresnet_group,
resnet_backbone, resnet_group,
@@ -143,7 +143,7 @@ def get_config(model):
if args.eval:
batch = 128 # something that can run on one gpu
ds = get_imagenet_dataflow(args.data, 'val', batch)
eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
eval_classification(model, get_model_loader(args.load), ds)
else:
if args.fake:
logger.set_logger_dir(os.path.join('train_log', 'tmp'), 'd')
@@ -16,7 +16,7 @@
from tensorpack.dataflow.dataset import ILSVRCMeta
from tensorpack.utils import logger

from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
from imagenet_utils import ImageNetModel, eval_classification, get_imagenet_dataflow
from resnet_model import resnet_bottleneck, resnet_group

DEPTH = None
@@ -172,6 +172,6 @@ def convert_param_name(param):

if args.eval:
ds = get_imagenet_dataflow(args.eval, 'val', 128, get_inference_augmentor())
eval_on_ILSVRC12(Model(), DictRestore(param), ds)
eval_classification(Model(), DictRestore(param), ds)
elif args.input:
run_test(param, args.input)
@@ -98,7 +98,7 @@ def clear_tower0_name_scope():
# NOTE: ctx.is_training won't be useful inside model,
# because inference will always use the cached Keras model
model = self.cached_model
outputs = model.call(input_tensors)
outputs = model.call(*input_tensors)
else:
# create new Keras model if not reuse
model = self.get_model(*input_tensors)
@@ -45,7 +45,7 @@ def catch(self):


class ImageFromFile(RNGDataFlow):
""" Produce images read from a list of files. """
""" Produce images read from a list of files as (h, w, c) arrays. """
def __init__(self, files, channel=3, resize=None, shuffle=False):
"""
Args:

0 comments on commit 0831fe9

Please sign in to comment.
You can’t perform that action at this time.