Skip to content

Commit

Permalink
bnrelu, classificationerror
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Apr 11, 2016
1 parent fec3a4a commit dcbd469
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 17 deletions.
16 changes: 8 additions & 8 deletions examples/cifar10_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

"""
CIFAR10 90% validation accuracy after 70k step.
91% validation accuracy after 36k step with 3 GPU.
"""

BATCH_SIZE = 128
Expand All @@ -46,15 +44,15 @@ def _get_cost(self, input_vars, is_training):

image = image / 4.0 # just to make range smaller
l = Conv2D('conv1.1', image, out_channel=64, kernel_shape=3)
l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training))
l = Conv2D('conv1.2', l, out_channel=64, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
l = MaxPooling('pool1', l, 3, stride=2, padding='SAME')

l = Conv2D('conv2.1', l, out_channel=128, kernel_shape=3)
l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training))
l = Conv2D('conv2.2', l, out_channel=128, kernel_shape=3, nl=BNReLU(is_training), use_bias=False)
l = MaxPooling('pool2', l, 3, stride=2, padding='SAME')

l = Conv2D('conv3.1', l, out_channel=128, kernel_shape=3, padding='VALID')
l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training))
l = Conv2D('conv3.2', l, out_channel=128, kernel_shape=3, padding='VALID', nl=BNReLU(is_training), use_bias=False)
l = FullyConnected('fc0', l, 1024 + 512,
b_init=tf.constant_initializer(0.1))
l = tf.nn.dropout(l, keep_prob)
Expand All @@ -69,7 +67,7 @@ def _get_cost(self, input_vars, is_training):
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

# compute the number of failed samples, for ValidationError to use at test time
# compute the number of failed samples, for ClassificationError to use at test time
wrong = prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
Expand Down Expand Up @@ -125,7 +123,7 @@ def get_config():
lr = tf.train.exponential_decay(
learning_rate=1e-2,
global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 15,
decay_steps=dataset_train.size() * 30 if nr_gpu == 1 else 20,
decay_rate=0.5, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)

Expand All @@ -135,7 +133,7 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ValidationError(dataset_test, prefix='test'),
ClassificationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
model=Model(),
Expand All @@ -155,6 +153,8 @@ def get_config():

if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

with tf.Graph().as_default():
config = get_config()
Expand Down
2 changes: 1 addition & 1 deletion examples/cifar10_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ValidationError(dataset_test, prefix='test'),
ClassificationError(dataset_test, prefix='test'),
ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]),
Expand Down
4 changes: 2 additions & 2 deletions examples/mnist_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _get_cost(self, input_vars, is_training):
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

# compute the number of failed samples, for ValidationError to use at test time
# compute the number of failed samples, for ClassificationError to use at test time
wrong = prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
Expand Down Expand Up @@ -106,7 +106,7 @@ def get_config():
StatPrinter(),
PeriodicSaver(),
ValidationStatPrinter(dataset_test, ['cost:0']),
ValidationError(dataset_test, prefix='validation'),
ClassificationError(dataset_test, prefix='validation'),
]),
session_config=sess_config,
model=Model(),
Expand Down
4 changes: 2 additions & 2 deletions examples/svhn_digit_convnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _get_cost(self, input_vars, is_training):
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

# compute the number of failed samples, for ValidationError to use at test time
# compute the number of failed samples, for ClassificationError to use at test time
wrong = prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
Expand Down Expand Up @@ -110,7 +110,7 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
PeriodicSaver(),
ValidationError(test, prefix='test'),
ClassificationError(test, prefix='test'),
]),
session_config=sess_config,
model=Model(),
Expand Down
7 changes: 3 additions & 4 deletions tensorpack/callbacks/validation_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..tfutils.summary import *
from .base import PeriodicCallback, Callback, TestCallbackType

__all__ = ['ValidationError', 'ValidationCallback', 'ValidationStatPrinter']
__all__ = ['ClassificationError', 'ValidationCallback', 'ValidationStatPrinter']

class ValidationCallback(PeriodicCallback):
"""
Expand Down Expand Up @@ -100,8 +100,7 @@ def _trigger_periodic(self):
'{}_{}'.format(self.prefix, name), stat), self.global_step)
self.trainer.stat_holder.add_stat("{}_{}".format(self.prefix, name), stat)


class ValidationError(ValidationCallback):
class ClassificationError(ValidationCallback):
"""
Validate the accuracy from a `wrong` variable
Expand All @@ -119,7 +118,7 @@ def __init__(self, ds, prefix='validation',
:param ds: a batched `DataFlow` instance
:param wrong_var_name: name of the `wrong` variable
"""
super(ValidationError, self).__init__(ds, prefix, period)
super(ClassificationError, self).__init__(ds, prefix, period)
self.wrong_var_name = wrong_var_name

def _find_output_vars(self):
Expand Down

0 comments on commit dcbd469

Please sign in to comment.