Commit: clean-ups

ppwwyyxx committed Dec 28, 2015
1 parent 585f083 commit 745ad4f
Showing 10 changed files with 56 additions and 47 deletions.
2 changes: 1 addition & 1 deletion LICENSE
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.

-   Copyright {yyyy} {name of copyright owner}
+   Copyright Yuxin Wu

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
7 changes: 4 additions & 3 deletions README.md
@@ -3,7 +3,8 @@ Neural Network Toolbox based on TensorFlow


 ## Features:
-+ Scoped Abstraction of common models.
++ Scoped abstraction of common models.
++ Provide callbacks to control training behavior (as in [Keras](http://keras.io)).
-+ Use `Dataflow` to fine-grained control data preprocessing.
 + Write a config file, tensorpack will do the rest.
++ Use `Dataflow` for fine-grained control over data preprocessing.
 + Automatically use the Queue operator in tensorflow to speed up input.
 + Training and testing graph are modeled together, automatically.
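A rough sketch of the config-driven workflow this list promises, pieced together from the `get_config()`/`start_train()` pattern visible in example_mnist.py and train.py below; the contents of the config object are elided in this diff, so they are assumptions:

```python
from train import start_train

config = get_config()   # assumed to assemble datasets, session config, callbacks, model
start_train(config)     # tensorpack runs the training loop and triggers the callbacks
```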
1 change: 0 additions & 1 deletion dataflow/base.py
@@ -8,7 +8,6 @@
 __all__ = ['DataFlow']

 class DataFlow(object):
-    # TODO private impl
     @abstractmethod
     def get_data(self):
         """
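A minimal concrete DataFlow, as a hedged sketch: the assumption that `get_data()` yields one datapoint (a list of components) at a time comes from how BatchData (below) consumes its wrapped DataFlow; `Constant` is a made-up name.

```python
from dataflow import DataFlow

class Constant(DataFlow):
    """Hypothetical DataFlow that yields the same datapoint `size` times."""
    def __init__(self, value, size):
        self.value = value
        self.size = size

    def get_data(self):
        for _ in range(self.size):
            yield [self.value]   # one datapoint: a list of components
```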
7 changes: 4 additions & 3 deletions dataflow/batch.py → dataflow/common.py
@@ -11,9 +11,10 @@
 class BatchData(DataFlow):
     def __init__(self, ds, batch_size, remainder=False):
         """
-        Args:
-        ds: a dataflow
-        remainder: whether to return the remaining data smaller than a batch_size
+        Group data in ds into batches.
+        ds: a DataFlow instance
+        remainder: whether to return the remaining data smaller than a batch_size.
+                   If set, might return a data point of a different shape.
         """
         self.ds = ds
         self.batch_size = batch_size
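A short usage sketch of the semantics documented above; `Mnist` is the dataset class used in example_mnist.py, and iterating via `get_data()` is an assumption carried over from the DataFlow interface:

```python
ds = BatchData(Mnist('test'), 256, remainder=True)
for dp in ds.get_data():
    # each dp is a batched datapoint; with remainder=True the last one may
    # hold fewer than 256 datapoints, i.e. a different shape than the rest
    pass
```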
4 changes: 2 additions & 2 deletions example_mnist.py 100644 → 100755
@@ -90,8 +90,8 @@ def get_config():

     dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
     dataset_test = BatchData(Mnist('test'), 256, remainder=True)
-    dataset_train = FixedSizeData(dataset_train, 20)
-    dataset_test = FixedSizeData(dataset_test, 20)
+    #dataset_train = FixedSizeData(dataset_train, 20)
+    #dataset_test = FixedSizeData(dataset_test, 20)

     sess_config = tf.ConfigProto()
     sess_config.device_count['GPU'] = 1
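The now-commented FixedSizeData wrappers read like a debugging aid. Assuming `FixedSizeData(ds, n)` truncates the wrapped DataFlow to n datapoints, a smoke-test toggle could look like this (`FAST_DEBUG` is a hypothetical flag, not in the repository):

```python
FAST_DEBUG = False
if FAST_DEBUG:
    # assumption: FixedSizeData(ds, n) stops after n datapoints, so one
    # "epoch" over these truncated streams finishes in seconds
    dataset_train = FixedSizeData(dataset_train, 20)
    dataset_test = FixedSizeData(dataset_test, 20)
```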
29 changes: 13 additions & 16 deletions models/_common.py
@@ -12,7 +12,12 @@

 def layer_register(summary_activation=False):
     """
-    summary_activation: default behavior of whether to summary the output of this layer
+    Register a layer.
+    Args:
+        summary_activation:
+            Define the default behavior of whether to
+            summarize the output (activation) of this layer.
+            Can be overridden when creating the layer.
     """
     def wrapper(func):
         def inner(*args, **kwargs):
@@ -26,24 +31,17 @@ def inner(*args, **kwargs):
             outputs = func(*args, **kwargs)
             if name not in _layer_logged:
                 # log shape info and add activation
-                if isinstance(inputs, list):
-                    shape_str = ",".join(
-                        map(str(x.get_shape().as_list()), inputs))
-                else:
-                    shape_str = str(inputs.get_shape().as_list())
-                logger.info("{} input: {}".format(name, shape_str))
+                logger.info("{} input: {}".format(
+                    name, get_shape_str(inputs)))
+                logger.info("{} output: {}".format(
+                    name, get_shape_str(outputs)))

-                if isinstance(outputs, list):
-                    shape_str = ",".join(
-                        map(str(x.get_shape().as_list()), outputs))
-                    if do_summary:
+                if do_summary:
+                    if isinstance(outputs, list):
                         for x in outputs:
                             add_activation_summary(x, scope.name)
-                else:
-                    shape_str = str(outputs.get_shape().as_list())
-                    if do_summary:
+                    else:
                         add_activation_summary(outputs, scope.name)
-                logger.info("{} output: {}".format(name, shape_str))
                 _layer_logged.add(name)
             return outputs
         return inner
@@ -63,4 +61,3 @@ def shape2d(a):
 def shape4d(a):
     # for use with tensorflow
     return [1] + shape2d(a) + [1]
-
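For context, a hedged sketch of how a layer would use this decorator; the exact calling convention (e.g. whether a layer name is passed as the first argument) is not shown in this hunk, so `MyRelu` and its signature are assumptions:

```python
@layer_register(summary_activation=True)
def MyRelu(x):
    # the wrapper logs "{name} input/output: {shape}" once per layer name,
    # and adds an activation summary when do_summary resolves to True
    return tf.nn.relu(x)
```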
6 changes: 2 additions & 4 deletions train.py
@@ -5,9 +5,8 @@

 import tensorflow as tf
 from utils import *
-from utils.concurrency import *
-from utils.callback import *
-from utils.summary import *
+from utils.concurrency import EnqueueThread, coordinator_guard
+from utils.summary import summary_moving_average, describe_model
 from dataflow import DataFlow
 from itertools import count
 import argparse
@@ -97,7 +96,6 @@ def start_train(config):

         # note that summary_op will take a data from the queue.
         callbacks.trigger_epoch()
-    sess.close()

 def main(get_config_func):
     parser = argparse.ArgumentParser()
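How an example script hands control to train.py, pieced together from `main(get_config_func)` above and example_mnist.py; the argparse flags main() defines are not shown here, so this wiring is an assumption:

```python
# in an example script such as example_mnist.py (sketch):
from train import main

if __name__ == '__main__':
    # assumption: main() parses command-line flags, then runs
    # start_train(get_config_func())
    main(get_config)
```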
14 changes: 0 additions & 14 deletions utils/__init__.py
@@ -27,20 +27,6 @@ def timed_operation(msg, log_start=False):
         logger.info('finished {}, time={:.2f}sec.'.format(
             msg, time.time() - start))
-
-def describe_model():
-    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    msg = [""]
-    total = 0
-    for v in train_vars:
-        shape = v.get_shape()
-        ele = shape.num_elements()
-        total += ele
-        msg.append("{}: shape={}, dim={}".format(
-            v.name, shape.as_list(), ele))
-    msg.append("Total dim={}".format(total))
-    logger.info("Model Params: {}".format('\n'.join(msg)))
-
 # TODO disable shape output in get_model
 @contextmanager
 def create_test_graph():
     G = tf.get_default_graph()
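timed_operation is a @contextmanager whose exit branch is shown above; a usage sketch (log_start=True presumably also logs a message on entry):

```python
with timed_operation('load mnist'):
    ds = Mnist('train')
# on exit this logs, e.g.: "finished load mnist, time=1.23sec."
```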
3 changes: 0 additions & 3 deletions utils/naming.py
@@ -3,9 +3,6 @@
 # File: naming.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-IS_TRAINING_OP_NAME = 'is_training'
-IS_TRAINING_VAR_NAME = 'is_training:0'
-
 GLOBAL_STEP_OP_NAME = 'global_step'
 GLOBAL_STEP_VAR_NAME = 'global_step:0'
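For context on the surviving OP_NAME/VAR_NAME pair: the ':0' suffix names the first output tensor of the op, which is exactly how utils/summary.py (below) fetches the variable:

```python
import tensorflow as tf
from utils.naming import GLOBAL_STEP_VAR_NAME

# ':0' selects the first output tensor of the 'global_step' op
global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
```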
30 changes: 30 additions & 0 deletions utils/summary.py
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
+import logger
 from .naming import *

 def create_summary(name, v):
@@ -44,6 +45,10 @@ def add_histogram_summary(regex):
         tf.histogram_summary(name, p)

 def summary_moving_average(cost_var):
+    """ Create a MovingAverage op and summary for all variables in
+        COST_VARS_KEY, SUMMARY_VARS_KEY, as well as the argument.
+        Return an op to maintain these averages.
+    """
     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     averager = tf.train.ExponentialMovingAverage(
         0.9, num_updates=global_step_var, name='avg')
@@ -54,3 +59,28 @@ def summary_moving_average(cost_var):
     for c in vars_to_summary:
         tf.scalar_summary(c.op.name, averager.average(c))
     return avg_maintain_op
+
+def describe_model():
+    """ Describe the current model parameters. """
+    train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+    msg = [""]
+    total = 0
+    for v in train_vars:
+        shape = v.get_shape()
+        ele = shape.num_elements()
+        total += ele
+        msg.append("{}: shape={}, dim={}".format(
+            v.name, shape.as_list(), ele))
+    msg.append("Total dim={}".format(total))
+    logger.info("Model Params: {}".format('\n'.join(msg)))
+
+
+def get_shape_str(tensors):
+    """ Return the shape string for a tensor or a list of tensors. """
+    if isinstance(tensors, list):
+        shape_str = ",".join(
+            str(x.get_shape().as_list()) for x in tensors)
+    else:
+        shape_str = str(tensors.get_shape().as_list())
+    return shape_str
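summary_moving_average() only returns the op that maintains the averages, so the caller has to run it with every step; grouping it with the optimizer step as below is an assumption about how train.py wires it, not something shown in this diff:

```python
avg_op = summary_moving_average(cost_var)
# hypothetical wiring: run the averaging op together with each train step
train_op = tf.group(optimizer.minimize(cost_var), avg_op)
```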
