Skip to content

Commit

Permalink
validation callback printer
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Feb 26, 2016
1 parent 9fe18ff commit 80622ae
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 39 deletions.
3 changes: 2 additions & 1 deletion example_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

"""
MNIST ConvNet example.
99.3% validation accuracy after 50 epochs.
99.25% validation accuracy after 50 epochs.
"""

BATCH_SIZE = 128
Expand Down Expand Up @@ -107,6 +107,7 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ValidationStatPrinter(dataset_test, ['cost:0']),
ValidationError(dataset_test, prefix='validation'),
]),
session_config=sess_config,
Expand Down
83 changes: 49 additions & 34 deletions tensorpack/callbacks/validation_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,87 +6,102 @@
import tensorflow as tf
import itertools
from tqdm import tqdm
from abc import ABCMeta

from ..utils import *
from ..utils.stat import *
from ..utils.summary import *
from .base import PeriodicCallback, Callback, TestCallback

__all__ = ['ValidationError', 'ValidationCallback']
__all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter']

class ValidationCallback(PeriodicCallback):
type = TestCallback()
"""
Basic routine for validation callbacks.
Base class for validation callbacks.
"""
def __init__(self, ds, prefix, period=1, cost_var_name='cost:0'):
def __init__(self, ds, prefix, period=1):
super(ValidationCallback, self).__init__(period)
self.ds = ds
self.prefix = prefix
self.cost_var_name = cost_var_name

def _before_train(self):
self.input_vars = self.trainer.model.reuse_input_vars()
self.cost_var = self.get_tensor(self.cost_var_name)
self._find_output_vars()

def get_tensor(self, name):
return self.graph.get_tensor_by_name(name)

@abstractmethod
def _find_output_vars(self):
pass
""" prepare output variables. Will be called in before_train"""

@abstractmethod
def _get_output_vars(self):
return []
""" return a list of output vars to eval"""

def _run_validation(self):
"""
Generator to return inputs and outputs
Eval the vars, generate inputs and outputs
"""
cnt = 0
cost_sum = 0

output_vars = self._get_output_vars()
output_vars.append(self.cost_var)
sess = tf.get_default_session()
with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data():
feed = dict(itertools.izip(self.input_vars, dp))

batch_size = dp[0].shape[0] # assume batched input

cnt += batch_size
outputs = sess.run(output_vars, feed_dict=feed)
cost = outputs[-1]
# each batch might not have the same size in validation
cost_sum += cost * batch_size
yield (dp, outputs[:-1])
yield (dp, outputs)
pbar.update()

cost_avg = cost_sum / cnt
self.trainer.summary_writer.add_summary(create_summary(
'{}_cost'.format(self.prefix), cost_avg), self.global_step)
self.trainer.stat_holder.add_stat("{}_cost".format(self.prefix), cost_avg)
@abstractmethod
def _trigger_periodic(self):
""" Implement the actual callback"""

class ValidationStatPrinter(ValidationCallback):
"""
Write stat and summary of some Op for a validation dataset.
The result of the given Op must be a scalar, and will be averaged over all batches in the validation set.
"""
def __init__(self, ds, names_to_print, prefix='validation', period=1):
super(ValidationStatPrinter, self).__init__(ds, prefix, period)
self.names = names_to_print

def _find_output_vars(self):
self.vars_to_print = [self.get_tensor(n) for n in self.names]

def _get_output_vars(self):
return self.vars_to_print

def _trigger_periodic(self):
stats = []
for dp, outputs in self._run_validation():
pass
stats.append(outputs)
stats = np.mean(stats, axis=0)
assert len(stats) == len(self.vars_to_print)

for stat, var in itertools.izip(stats, self.vars_to_print):
name = var.name.replace(':0', '')
self.trainer.summary_writer.add_summary(create_summary(
'{}_{}'.format(self.prefix, name), stat), self.global_step)
self.trainer.stat_holder.add_stat("{}_{}".format(self.prefix, name), stat)


class ValidationError(ValidationCallback):
running_graph = 'test'
"""
Validate the accuracy for the given wrong and cost variable
Use under the following setup:
wrong_var: integer, number of failed samples in this batch
ds: batched dataset
Validate the accuracy from a 'wrong' variable
wrong_var: integer, number of failed samples in this batch
ds: batched dataset
This callback produces the "true" error,
taking into account the fact that batches might not have the same size in
testing (because the size of the test set might not be a multiple of the batch size).
In theory, the result could be different from what is produced by ValidationStatPrinter.
"""
def __init__(self, ds, prefix,
def __init__(self, ds, prefix='validation',
period=1,
wrong_var_name='wrong:0',
cost_var_name='cost:0'):
super(ValidationError, self).__init__(
ds, prefix, period, cost_var_name)
wrong_var_name='wrong:0'):
super(ValidationError, self).__init__(ds, prefix, period)
self.wrong_var_name = wrong_var_name

def _find_output_vars(self):
Expand Down
9 changes: 5 additions & 4 deletions tensorpack/utils/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ def add_param_summary(summary_lists):
"""
def perform(var, action):
ndim = var.get_shape().ndims
name = var.name.replace(':0', '')
if action == 'scalar':
assert ndim == 0, "Scalar summary on high-dimension data. Maybe you want 'mean'?"
tf.scalar_summary(var.name, var)
tf.scalar_summary(name, var)
return
assert ndim > 0, "Cannot perform {} summary on scalar data".format(action)
if action == 'histogram':
tf.histogram_summary(var.name, var)
tf.histogram_summary(name, var)
return
if action == 'sparsity':
tf.scalar_summary(var.name + '/sparsity', tf.nn.zero_fraction(var))
tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(var))
return
if action == 'mean':
tf.scalar_summary(var.name + '/mean', tf.reduce_mean(var))
tf.scalar_summary(name + '/mean', tf.reduce_mean(var))
return
raise RuntimeError("Unknown action {}".format(action))

Expand Down

0 comments on commit 80622ae

Please sign in to comment.