[Keras] use inputs_desc/targets_desc explicitly, to avoid hacks (#160)
ppwwyyxx committed Jan 2, 2018
1 parent f1ee183 commit ac02c62
Showing 2 changed files with 37 additions and 41 deletions.
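In short: instead of inferring input/output metadata by instantiating a throwaway Keras model on a proxy object, KerasModel now takes the input and target signatures explicitly. A minimal sketch of the new constructor call, condensed from the example diff below (model_func, IMAGE_SIZE, and dataset_train are defined in the example):

import tensorflow as tf
from tensorpack import InputDesc, QueueInput
from tensorpack.contrib.keras import KerasModel

M = KerasModel(
    model_func,  # a callable: list of input tensors -> keras.models.Model
    # explicit signatures replace the old metadata-inference hack:
    inputs_desc=[InputDesc(tf.float32, [None, IMAGE_SIZE, IMAGE_SIZE, 1], 'images')],
    targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
    input=QueueInput(dataset_train))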
27 changes: 15 additions & 12 deletions examples/mnist-keras-v2.py
@@ -10,7 +10,7 @@
 KL = keras.layers
 
 
-from tensorpack.input_source import QueueInput
+from tensorpack import InputDesc, QueueInput
 from tensorpack.dataflow import dataset, BatchData, MapData
 from tensorpack.utils import logger
 from tensorpack.contrib.keras import KerasModel
@@ -22,8 +22,7 @@
 def get_data():
     def f(dp):
         im = dp[0][:, :, None]
-        onehot = np.zeros(10, dtype='int32')
-        onehot[dp[1]] = 1
+        onehot = np.eye(10)[dp[1]]
         return [im, onehot]
 
     train = BatchData(MapData(dataset.Mnist('train'), f), 128)
@@ -34,11 +33,14 @@ def f(dp):
 if __name__ == '__main__':
     logger.auto_set_dir()
 
-    def model_func(input_tensors):
+    def model_func(inputs):
+        """
+        Keras model has to be created inside this function to be used with tensorpack.
+        """
         M = keras.models.Sequential()
-        M.add(KL.InputLayer(
-            input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1],
-            input_tensor=input_tensors[0]))
+        # input_tensor has to be used here for the tensorpack trainer to function properly.
+        # Just use inputs[1], inputs[2] if you have multiple inputs.
+        M.add(KL.InputLayer(input_tensor=inputs[0]))
         M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
         M.add(KL.MaxPooling2D())
         M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
@@ -51,18 +53,19 @@ def model_func(input_tensors):
         M.add(KL.Dropout(0.5))
         M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
         M.add(KL.Activation('softmax'))
 
         return M
 
     dataset_train, dataset_test = get_data()
-
-    # from tensorpack import *
-    # trainer = SyncMultiGPUTrainerReplicated(2)
-    M = KerasModel(model_func, QueueInput(dataset_train))
+    M = KerasModel(
+        model_func,
+        inputs_desc=[InputDesc(tf.float32, [None, IMAGE_SIZE, IMAGE_SIZE, 1], 'images')],
+        targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
+        input=QueueInput(dataset_train))
     M.compile(
         optimizer=tf.train.AdamOptimizer(1e-3),
         loss='categorical_crossentropy',
-        metrics=['categorical_accuracy']
+        metrics='categorical_accuracy'
     )
     M.fit(
         validation_data=dataset_test,
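An aside on the one-hot change above: np.eye(10)[dp[1]] picks row dp[1] of the identity matrix, which is exactly the one-hot vector the removed two-line version built. A quick self-contained check (plain NumPy; the label value is an arbitrary example):

import numpy as np

label = 3  # an example MNIST label
# old style: allocate zeros, then set one index
onehot_old = np.zeros(10, dtype='int32')
onehot_old[label] = 1
# new style: row `label` of the 10x10 identity matrix is already one-hot
onehot_new = np.eye(10)[label]
assert (onehot_old == onehot_new).all()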
51 changes: 22 additions & 29 deletions tensorpack/contrib/keras.py
@@ -8,8 +8,8 @@
 from tensorflow.python.keras import metrics as metrics_module
 
 from ..models.regularize import regularize_cost_from_collection
-from ..graph_builder import InputDesc
-from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer, DistributedTrainerBase
+from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
+from ..train.trainers import DistributedTrainerBase
 from ..callbacks import (
     Callback, InferenceRunner, CallbackToHook,
     ScalarStats)
@@ -45,33 +45,24 @@ def __init__(self, get_model):
         self.cached_model = None
 
     def __call__(self, input_tensors):
+        """
+        Returns:
+            output tensors of this tower, evaluated with the input tensors.
+        """
         reuse = tf.get_variable_scope().reuse
         if self.cached_model is None:
             assert not reuse
             self.cached_model = self.get_model(input_tensors)
             return self.cached_model.outputs
 
         if reuse:
+            # use the cached Keras model to mimic reuse
             return self.cached_model.call(input_tensors)
         else:
+            # create new Keras model if not reuse
             M = self.get_model(input_tensors)
             return M.outputs
-
-    def call_virtual(self):
-
-        class NoneTensorProxy(object):
-            def __getitem__(self, index):
-                return None
-
-            def __len__(self):
-                raise NotImplementedError(
-                    "Do not call `len(inputs)` because it's only a virtual object "
-                    "for the moment! Use `inputs[index]` directly!")
-
-        G_tmp = tf.Graph()  # we need a model instance to know metadata about inputs/outputs
-        with G_tmp.as_default():
-            return self.get_model(NoneTensorProxy())
 
 
 # Keras needs an extra input if learning_phase is used by the model
 # This cb will be used by
@@ -97,8 +88,9 @@ def _before_run(self, ctx):


 def setup_keras_trainer(
-        trainer, get_model, input,
-        optimizer, loss, metrics):
+        trainer, get_model,
+        inputs_desc, targets_desc,
+        input, optimizer, loss, metrics):
     """
     Args:
         trainer (SingleCostTrainer):
@@ -113,17 +105,11 @@
     assert isinstance(metrics, list), metrics
     model_caller = KerasModelCaller(get_model)
 
-    M_tmp = model_caller.call_virtual()
-
-    inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
-                   for i, t in enumerate(M_tmp.inputs)]
-    outputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'output{}'.format(i))
-                    for i, t in enumerate(M_tmp.outputs)]
     nr_inputs = len(inputs_desc)
 
     def get_cost(*inputs):
-        assert len(inputs) == len(inputs_desc) + len(outputs_desc), \
-            "Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(outputs_desc))
+        assert len(inputs) == len(inputs_desc) + len(targets_desc), \
+            "Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(targets_desc))
         ctx = get_current_tower_context()
         input_tensors = list(inputs[:nr_inputs])
         target_tensors = list(inputs[nr_inputs:])
@@ -173,7 +159,7 @@ def get_cost(*inputs):
         return total_loss
 
     trainer.setup_graph(
-        inputs_desc + outputs_desc,
+        inputs_desc + targets_desc,
         input,
         get_cost,
         lambda: optimizer)
@@ -182,20 +168,26 @@ def get_cost(*inputs):


 class KerasModel(object):
-    def __init__(self, get_model, input, trainer=None):
+    def __init__(self, get_model, inputs_desc, targets_desc,
+                 input, trainer=None):
         """
         Args:
             get_model ( -> keras.model.Model):
+            inputs_desc ([InputDesc]):
+            targets_desc ([InputDesc]):
             input (InputSource):
             trainer (Trainer): the default will check the number of available
                 GPUs and use them all.
         """
         self.get_model = get_model
+        self.inputs_desc = inputs_desc
+        self.targets_desc = targets_desc
         if trainer is None:
             nr_gpu = get_nr_gpu()
             if nr_gpu <= 1:
                 trainer = SimpleTrainer()
             else:
+                # the default multigpu trainer
                 trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
         assert isinstance(trainer, Trainer), trainer
         assert not isinstance(trainer, DistributedTrainerBase)
@@ -219,6 +211,7 @@ def compile(self, optimizer, loss, metrics=None):
         self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME]
         setup_keras_trainer(
             self.trainer, get_model=self.get_model,
+            inputs_desc=self.inputs_desc, targets_desc=self.targets_desc,
             input=self.input,
             optimizer=optimizer,
             loss=loss,

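The example's comment ("Just use inputs[1], inputs[2] if you have multiple inputs") implies one InputDesc per model input, plus one per target. A hedged sketch of a two-input model under the new API — two_input_model, my_dataflow, and the tensor names are hypothetical, not part of this commit:

import tensorflow as tf
import keras
from tensorpack import InputDesc, QueueInput
from tensorpack.contrib.keras import KerasModel

KL = keras.layers

def two_input_model(inputs):
    # each Keras Input is bound to one tensorpack input tensor,
    # in the order given by inputs_desc
    a = KL.Input(tensor=inputs[0])
    b = KL.Input(tensor=inputs[1])
    x = KL.concatenate([KL.Flatten()(a), KL.Flatten()(b)])
    out = KL.Dense(10, activation='softmax')(x)
    return keras.models.Model(inputs=[a, b], outputs=out)

M = KerasModel(
    two_input_model,
    inputs_desc=[InputDesc(tf.float32, [None, 28, 28, 1], 'imageA'),
                 InputDesc(tf.float32, [None, 28, 28, 1], 'imageB')],
    targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
    input=QueueInput(my_dataflow))  # my_dataflow: a hypothetical DataFlow yielding [imgA, imgB, onehot]
M.compile(
    optimizer=tf.train.AdamOptimizer(1e-3),
    loss='categorical_crossentropy',
    metrics='categorical_accuracy')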