Skip to content

Commit

Permalink
fix periodic bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 16, 2016
1 parent 9a4e6d9 commit 32feff4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 19 deletions.
3 changes: 0 additions & 3 deletions examples/mnist_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

import numpy as np
import os, sys
import argparse
Expand All @@ -18,7 +16,6 @@
from tensorpack.tfutils import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
from IPython import embed; embed()

"""
MNIST ConvNet example.
Expand Down
24 changes: 8 additions & 16 deletions tensorpack/callbacks/validation_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,21 @@
from ..utils.stat import *
from ..tfutils import *
from ..tfutils.summary import *
from .base import PeriodicCallback, Callback, TestCallbackType
from .base import Callback, TestCallbackType

__all__ = ['ClassificationError', 'ValidationCallback', 'ValidationStatPrinter']

class ValidationCallback(PeriodicCallback):
class ValidationCallback(Callback):
"""
Base class for validation callbacks.
"""
type = TestCallbackType()

def __init__(self, ds, prefix, period=1):
def __init__(self, ds, prefix):
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
super(ValidationCallback, self).__init__(period)
self.ds = ds
self.prefix = prefix

Expand Down Expand Up @@ -63,23 +61,18 @@ def _run_validation(self):
yield (dp, outputs)
pbar.update()

@abstractmethod
def _trigger_periodic(self):
""" Implement the actual callback"""

class ValidationStatPrinter(ValidationCallback):
"""
Write stat and summary of some Op for a validation dataset.
The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set.
"""
def __init__(self, ds, names_to_print, prefix='validation', period=1):
def __init__(self, ds, names_to_print, prefix='validation'):
"""
:param ds: validation dataset. must be a `DataFlow` instance.
:param names_to_print: names of variables to print
:param prefix: name to use for this validation.
:param period: period to perform validation.
"""
super(ValidationStatPrinter, self).__init__(ds, prefix, period)
super(ValidationStatPrinter, self).__init__(ds, prefix)
self.names = names_to_print

def _find_output_vars(self):
Expand All @@ -89,7 +82,7 @@ def _find_output_vars(self):
def _get_output_vars(self):
return self.vars_to_print

def _trigger_periodic(self):
def _trigger_epoch(self):
stats = []
for dp, outputs in self._run_validation():
stats.append(outputs)
Expand All @@ -114,13 +107,12 @@ class ClassificationError(ValidationCallback):
In theory, the result could be different from what produced by ValidationStatPrinter.
"""
def __init__(self, ds, prefix='validation',
period=1,
wrong_var_name='wrong:0'):
"""
:param ds: a batched `DataFlow` instance
:param wrong_var_name: name of the `wrong` variable
"""
super(ClassificationError, self).__init__(ds, prefix, period)
super(ClassificationError, self).__init__(ds, prefix)
self.wrong_var_name = wrong_var_name

def _find_output_vars(self):
Expand All @@ -129,7 +121,7 @@ def _find_output_vars(self):
def _get_output_vars(self):
return [self.wrong_var]

def _trigger_periodic(self):
def _trigger_epoch(self):
err_stat = Accuracy()
for dp, outputs in self._run_validation():
batch_size = dp[0].shape[0] # assume batched input
Expand Down

0 comments on commit 32feff4

Please sign in to comment.