bring some sense to import
ppwwyyxx committed Apr 20, 2016
1 parent bd0ca73 commit 3e87659
Showing 11 changed files with 42 additions and 19 deletions.
2 changes: 1 addition & 1 deletion examples/ResNet/svhn_resnet.py
@@ -21,7 +21,7 @@

"""
ResNet-110 for SVHN Digit Classification.
-Reach 1.9% validation error after 90 epochs, with 2 TitanX, 2it/s.
+Reach 1.8% validation error after 70 epochs, with 2 TitanX. 2it/s.
You might need to adjust the learning rate schedule when running with 1 GPU.
"""

26 changes: 11 additions & 15 deletions examples/mnist_convnet.py
@@ -8,14 +8,10 @@
import os, sys
import argparse

-from tensorpack.train import *
-from tensorpack.models import *
-from tensorpack.utils import *
-from tensorpack.tfutils.symbolic_functions import *
-from tensorpack.tfutils.summary import *
-from tensorpack.tfutils import *
+import tensorpack as tp
+from tensorpack.models import *
+from tensorpack.utils import *
+from tensorpack.callbacks import *
+from tensorpack.dataflow import *

"""
MNIST ConvNet example.
@@ -60,7 +56,7 @@ def _get_cost(self, input_vars, is_training):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

# compute the number of failed samples, for ClassificationError to use at test time
-wrong = prediction_incorrect(logits, label)
+wrong = tp.symbolic_functions.prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
@@ -72,7 +68,7 @@ def _get_cost(self, input_vars, is_training):
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)

-add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W
+tp.summary.add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W
return tf.add_n([wd_cost, cost], name='cost')

def get_config():
@@ -81,22 +77,22 @@ def get_config():
os.path.join('train_log', basename[:basename.rfind('.')]))

# prepare dataset
-dataset_train = BatchData(dataset.Mnist('train'), 128)
-dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
+dataset_train = tp.BatchData(tp.dataset.Mnist('train'), 128)
+dataset_test = tp.BatchData(tp.dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()

# prepare session
-sess_config = get_default_sess_config()
+sess_config = tp.get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5

lr = tf.train.exponential_decay(
learning_rate=1e-3,
-global_step=get_global_step_var(),
+global_step=tp.get_global_step_var(),
decay_steps=dataset_train.size() * 10,
decay_rate=0.3, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)

-return TrainConfig(
+return tp.TrainConfig(
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([
@@ -125,5 +121,5 @@ def get_config():
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
-SimpleTrainer(config).train()
+tp.SimpleTrainer(config).train()

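Taken together, the mnist_convnet.py changes trade a pile of star imports for one package alias plus a few targeted star imports. A minimal sketch of the resulting style, using only names that appear in the diff above (tensorpack API as of this 2016 commit; details differ in later releases):

    import tensorflow as tf
    import tensorpack as tp
    from tensorpack.models import *      # these four star imports survive the cleanup
    from tensorpack.utils import *
    from tensorpack.callbacks import *   # Callbacks(...) is still used unqualified
    from tensorpack.dataflow import *

    # everything else is now qualified through the single `tp` alias
    dataset_train = tp.BatchData(tp.dataset.Mnist('train'), 128)
    sess_config = tp.get_default_sess_config()
    lr = tf.train.exponential_decay(
        learning_rate=1e-3,
        global_step=tp.get_global_step_var(),
        decay_steps=dataset_train.size() * 10,
        decay_rate=0.3, staircase=True, name='learning_rate')

The single alias only works because of the new top-level tensorpack/__init__.py introduced in the next file.
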
17 changes: 17 additions & 0 deletions tensorpack/__init__.py
@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import models
import train
import utils
import tfutils
import callbacks
import dataflow

from .train import *
from .models import *
from .utils import *
from .tfutils import *
from .callbacks import *
from .dataflow import *
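
The new top-level __init__.py is what makes the `tp.` style above possible: each subpackage is imported as a module and then star-imported, so its public names also land directly on the tensorpack package. A rough illustration of the effect (the assertions are illustrative, not repository code, and assume the subpackages expose these names through their __all__ or module globals):

    import tensorpack as tp

    # the same object is reachable under both spellings
    assert tp.BatchData is tp.dataflow.BatchData        # lifted by `from .dataflow import *`
    assert tp.SimpleTrainer is tp.train.SimpleTrainer   # lifted by `from .train import *`

    # the old flat style keeps working as well
    from tensorpack import *
    assert SimpleTrainer is tp.SimpleTrainer

Note that the bare `import models`, `import train`, ... lines are Python 2 implicit relative imports; they bind each subpackage as an attribute of the tensorpack package.
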
1 change: 1 addition & 0 deletions tensorpack/callbacks/__init__.py
@@ -8,6 +8,7 @@
def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]

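The one-line addition repeated across these __init__.py files removes the submodule binding that the import machinery leaves on the package, so after the re-export loop only the public names remain visible, not the module object itself. A self-contained sketch of the pattern as it would sit in a hypothetical mypkg/__init__.py (illustrative; the submodule name `foo` is made up):

    def _global_import(name):
        # Relative import of a submodule; the import machinery also binds
        # `name` as an attribute of this package, i.e. into globals().
        p = __import__(name, globals(), locals(), level=1)
        # Re-export the submodule's declared API, or everything it defines.
        lst = p.__all__ if '__all__' in dir(p) else dir(p)
        # The newly added line: drop the stray submodule binding so that
        # mypkg.foo is no longer an attribute of the package.
        del globals()[name]
        for k in lst:
            globals()[k] = p.__dict__[k]

    _global_import('foo')   # mypkg.<public name> works; mypkg.foo does not

The same deletion lands in dataflow, imgaug, tfutils, train and utils below; only its position relative to the loop differs.
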
4 changes: 3 additions & 1 deletion tensorpack/dataflow/__init__.py
@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import walk_packages
+import importlib
import os
import os.path

@@ -12,10 +13,11 @@
def _global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]

-__SKIP = ['dftools', 'dataset']
+__SKIP = ['dftools', 'dataset', 'imgaug']
for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_') and \
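
The other dataflow change adds 'imgaug' to the skip list of the auto-import loop, presumably so that augmentor names stay namespaced under tensorpack.dataflow.imgaug rather than being flattened into dataflow itself. A sketch of that loop as it would appear inside dataflow/__init__.py; the truncated condition and the loop body are filled in as assumptions:

    from pkgutil import walk_packages
    import os

    __SKIP = ['dftools', 'dataset', 'imgaug']
    for _, module_name, _ in walk_packages([os.path.dirname(__file__)]):
        if not module_name.startswith('_') and module_name not in __SKIP:
            _global_import(module_name)   # the helper defined earlier in this file
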
1 change: 1 addition & 0 deletions tensorpack/dataflow/imgaug/__init__.py
@@ -10,6 +10,7 @@
def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]

4 changes: 2 additions & 2 deletions tensorpack/dataflow/imgaug/paste.py
@@ -63,8 +63,8 @@ def _augment(self, img):

background = self.background_filler.fill(
self.background_shape, img.arr)
-h0 = (self.background_shape[0] - img_shape[0]) * 0.5
-w0 = (self.background_shape[1] - img_shape[1]) * 0.5
+h0 = int((self.background_shape[0] - img_shape[0]) * 0.5)
+w0 = int((self.background_shape[1] - img_shape[1]) * 0.5)
background[h0:h0+img_shape[0], w0:w0+img_shape[1]] = img.arr
img.arr = background
if img.coords:
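
The paste.py fix casts the paste offsets to int before they are used as slice indices: numpy deprecates, and in newer releases rejects, float slice indices, so the bare `* 0.5` arithmetic would break. A standalone illustration with plain numpy (not tensorpack code):

    import numpy as np

    background = np.zeros((40, 40, 3), dtype='uint8')
    img = np.ones((28, 28, 3), dtype='uint8')

    h0 = (background.shape[0] - img.shape[0]) * 0.5        # 6.0 -- a float
    # background[h0:h0 + 28] fails (or warns, depending on numpy version)

    h0 = int((background.shape[0] - img.shape[0]) * 0.5)   # 6 -- valid index
    w0 = int((background.shape[1] - img.shape[1]) * 0.5)
    background[h0:h0 + img.shape[0], w0:w0 + img.shape[1]] = img
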
1 change: 1 addition & 0 deletions tensorpack/tfutils/__init__.py
@@ -8,6 +8,7 @@
def _global_import(name):
    p = __import__(name, globals(), None, level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]

3 changes: 3 additions & 0 deletions tensorpack/tfutils/summary.py
@@ -8,6 +8,9 @@
from ..utils import *
from . import get_global_step_var

+__all__ = ['create_summary', 'add_param_summary', 'add_activation_summary',
+           'summary_moving_average']

def create_summary(name, v):
"""
Return a tf.Summary object with name and simple scalar value v
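
Declaring __all__ in tfutils/summary.py matters because the _global_import helpers above prefer p.__all__ over dir(p): without it, the module's own imports (tensorflow, everything pulled in by `from ..utils import *`) would be re-exported alongside the four summary functions. A small self-contained illustration of that preference (the module here is faked on the fly and is not tensorpack code):

    import types

    # Build a throwaway module that, like summary.py, has a helper import
    # plus a declared public API.
    mod = types.ModuleType('summary_like')
    exec("import os\n"
         "__all__ = ['create_summary']\n"
         "def create_summary(name, v):\n"
         "    return (name, v)\n",
         mod.__dict__)

    # Mirrors the package loaders: prefer __all__, else export everything.
    lst = mod.__all__ if '__all__' in dir(mod) else dir(mod)
    print(lst)    # ['create_summary'] -- 'os' is not dragged along
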
1 change: 1 addition & 0 deletions tensorpack/train/__init__.py
@@ -11,6 +11,7 @@ def global_import(name):
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = p.__dict__[k]
+    del globals()[name]

for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
1 change: 1 addition & 0 deletions tensorpack/utils/__init__.py
@@ -13,6 +13,7 @@
def _global_import(name):
    p = __import__(name, globals(), None, level=1)
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]
_global_import('naming')
