Skip to content

Commit

Permalink
single-pass inference
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 19, 2016
1 parent 76fe1b6 commit 174c3fc
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 142 deletions.
3 changes: 2 additions & 1 deletion examples/ResNet/cifar10_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
ClassificationError(dataset_test, prefix='validation'),
InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError()]),
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]),
Expand Down
3 changes: 2 additions & 1 deletion examples/ResNet/svhn_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
ClassificationError(dataset_test, prefix='validation'),
InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError() ]),
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (20, 0.01), (33, 0.001), (60, 0.0001)])
]),
Expand Down
4 changes: 2 additions & 2 deletions examples/cifar10_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
ClassificationError(dataset_test, prefix='test'),
InferenceRunner(dataset_test, ClassificationError())
]),
session_config=sess_config,
model=Model(),
step_per_epoch=step_per_epoch,
max_epoch=20,
max_epoch=300,
)

if __name__ == '__main__':
Expand Down
6 changes: 3 additions & 3 deletions examples/mnist_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _get_cost(self, input_vars, is_training):
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)

add_param_summary([('.*/W', ['histogram', 'sparsity'])]) # monitor histogram of all W
add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W
return tf.add_n([wd_cost, cost], name='cost')

def get_config():
Expand Down Expand Up @@ -102,8 +102,8 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
ValidationStatPrinter(dataset_test, ['cost:0']),
ClassificationError(dataset_test, prefix='validation'),
InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError() ])
]),
session_config=sess_config,
model=Model(),
Expand Down
3 changes: 2 additions & 1 deletion examples/svhn_digit_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
ClassificationError(test, prefix='test'),
InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError()])
]),
session_config=sess_config,
model=Model(),
Expand Down
195 changes: 195 additions & 0 deletions tensorpack/callbacks/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# -*- coding: UTF-8 -*-
# File: inference.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from six.moves import zip

from ..dataflow import DataFlow
from ..utils import *
from ..utils.stat import *
from ..tfutils import *
from ..tfutils.summary import *
from .base import Callback, TestCallbackType

__all__ = ['InferenceRunner', 'ClassificationError',
'ScalarStats', 'Inferencer']

class Inferencer(object):
__metaclass__ = ABCMeta

def before_inference(self):
"""
Called before a new round of inference starts.
"""
self._before_inference()

def _before_inference(self):
pass

def datapoint(self, dp, output):
"""
Called after complete running every data point
"""
self._datapoint(dp, output)

@abstractmethod
def _datapoint(self, dp, output):
pass

def after_inference(self):
"""
Called after a round of inference ends.
"""
self._after_inference()

def _after_inference(self):
pass

def get_output_tensors(self):
"""
Return a list of tensor names needed for this inference
"""
return self._get_output_vars()

@abstractmethod
def _get_output_tensors(self):
pass

class InferenceRunner(Callback):
"""
A callback that runs different kinds of inferencer.
"""
type = TestCallbackType()

def __init__(self, ds, vcs):
"""
:param ds: inference dataset. a `DataFlow` instance.
:param vcs: a list of `Inferencer` instance.
"""
assert isinstance(ds, DataFlow), type(ds)
self.ds = ds
if not isinstance(vcs, list):
self.vcs = [vcs]
else:
self.vcs = vcs
for v in self.vcs:
assert isinstance(v, Inferencer), str(v)

def _before_train(self):
self.input_vars = self.trainer.model.reuse_input_vars()
self._find_output_tensors()
for v in self.vcs:
v.trainer = self.trainer

def _find_output_tensors(self):
self.output_tensors = []
self.vc_to_vars = []
for vc in self.vcs:
vc_vars = vc._get_output_tensors()
def find_oid(var):
if var in self.output_tensors:
return self.output_tensors.index(var)
else:
self.output_tensors.append(var)
return len(self.output_tensors) - 1
vc_vars = [(var, find_oid(var)) for var in vc_vars]
self.vc_to_vars.append(vc_vars)

# convert name to tensors
def get_tensor(name):
_, varname = get_op_var_name(name)
return self.graph.get_tensor_by_name(varname)
self.output_tensors = map(get_tensor, self.output_tensors)

def _trigger_epoch(self):
for vc in self.vcs:
vc.before_inference()

sess = tf.get_default_session()
with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data():
feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
outputs = sess.run(self.output_tensors, feed_dict=feed)
for vc, varsmap in zip(self.vcs, self.vc_to_vars):
vc_output = [outputs[k[1]] for k in varsmap]
vc.datapoint(dp, vc_output)
pbar.update()

for vc in self.vcs:
vc.after_inference()

class ScalarStats(Inferencer):
"""
Write stat and summary of some scalar tensor.
The output of the given Ops must be a scalar.
The value will be averaged over all data points in the dataset.
"""
def __init__(self, names_to_print, prefix='validation'):
"""
:param names_to_print: list of names of tensors, or just a name
:param prefix: an optional prefix for logging
"""
if not isinstance(names_to_print, list):
self.names = [names_to_print]
else:
self.names = names_to_print
self.prefix = prefix

def _get_output_tensors(self):
return self.names

def _before_inference(self):
self.stats = []

def _datapoint(self, dp, output):
self.stats.append(output)

def _after_inference(self):
self.stats = np.mean(self.stats, axis=0)
assert len(self.stats) == len(self.names)

for stat, name in zip(self.stats, self.names):
opname, _ = get_op_var_name(name)
name = '{}_{}'.format(self.prefix, opname) if self.prefix else opname
self.trainer.summary_writer.add_summary(
create_summary(name, stat), get_global_step())
self.trainer.stat_holder.add_stat(name, stat)

class ClassificationError(Inferencer):
"""
Validate the accuracy from a `wrong` variable
The `wrong` variable is supposed to be an integer equal to the number of failed samples in this batch
This callback produce the "true" error,
taking account of the fact that batches might not have the same size in
testing (because the size of test set might not be a multiple of batch size).
In theory, the result could be different from what produced by ValidationStatPrinter.
"""
def __init__(self, wrong_var_name='wrong:0', prefix='validation'):
"""
:param wrong_var_name: name of the `wrong` variable
:param prefix: an optional prefix for logging
"""
self.wrong_var_name = wrong_var_name
self.prefix = prefix

def _get_output_tensors(self):
return [self.wrong_var_name]

def _before_inference(self):
self.err_stat = Accuracy()

def _datapoint(self, dp, outputs):
batch_size = dp[0].shape[0] # assume batched input
wrong = int(outputs[0])
self.err_stat.feed(wrong, batch_size)

def _after_inference(self):
self.trainer.summary_writer.add_summary(
create_summary('{}_error'.format(self.prefix), self.err_stat.accuracy),
get_global_step())
self.trainer.stat_holder.add_stat("{}_error".format(self.prefix), self.err_stat.accuracy)
133 changes: 0 additions & 133 deletions tensorpack/callbacks/validation_callback.py

This file was deleted.

0 comments on commit 174c3fc

Please sign in to comment.