Commit 3d1a30f: Move more core code to tf.compat.v1

ppwwyyxx committed Mar 19, 2019
1 parent 4057c53 commit 3d1a30f
Showing 15 changed files with 41 additions and 28 deletions.
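Every hunk below replaces a TF1-only symbol with an alias imported from tensorpack's internal compat module (the `tfv1` and `is_tfv2` names seen in the diffs). As a rough sketch of what such a shim amounts to (the actual body of tensorpack/compat.py is an assumption here):

    # hedged sketch of a TF1/TF2 compat shim; the real tensorpack/compat.py may differ
    import tensorflow as tf

    def get_tf_version_tuple():
        # e.g. '1.13.1' -> (1, 13)
        return tuple(map(int, tf.__version__.split('.')[:2]))

    def is_tfv2():
        return get_tf_version_tuple() >= (2, 0)

    # under TF2, route all v1 symbols through the compat namespace
    tfv1 = tf.compat.v1 if is_tfv2() else tf

With this in place, call sites only need `from ..compat import tfv1` instead of sprinkling `tf.compat.v1` everywhere.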
4 changes: 2 additions & 2 deletions examples/ImageNetModels/imagenet_utils.py
@@ -13,7 +13,7 @@
 from tensorpack import ModelDesc
 from tensorpack.dataflow import AugmentImageComponent, BatchData, MultiThreadMapData, PrefetchDataZMQ, dataset, imgaug
 from tensorpack.input_source import QueueInput, StagingInput
-from tensorpack.models import regularize_cost
+from tensorpack.models import regularize_cost, l2_regularizer
 from tensorpack.predict import FeedfreePredictor, PredictConfig
 from tensorpack.tfutils.summary import add_moving_summary
 from tensorpack.utils import logger
@@ -339,7 +339,7 @@ def build_graph(self, image, label):
 
         if self.weight_decay > 0:
             wd_loss = regularize_cost(self.weight_decay_pattern,
-                                      tf.contrib.layers.l2_regularizer(self.weight_decay),
+                                      l2_regularizer(self.weight_decay),
                                       name='l2_regularize_loss')
             add_moving_summary(loss, wd_loss)
             total_cost = tf.add_n([loss, wd_loss], name='cost')
2 changes: 1 addition & 1 deletion examples/basics/mnist-visualizations.py
@@ -98,7 +98,7 @@ def build_graph(self, image, label):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
 
-        tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, label, 1)), tf.float32, name='accuracy')
+        tf.reduce_mean(tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32), name='accuracy')
 
         wd_cost = tf.multiply(1e-5,
                               regularize_cost('fc.*/W', tf.nn.l2_loss),
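The old line here was a genuine bug, not just a style fix: `tf.float32` was passed as the second positional argument of `tf.reduce_mean` (its axis parameter) instead of to `tf.cast`. The corrected pattern, as a standalone sketch:

    # cast the boolean top-k hits to float *before* averaging them
    correct = tf.nn.in_top_k(logits, label, 1)   # bool, shape [batch]
    accuracy = tf.reduce_mean(tf.cast(correct, tf.float32), name='accuracy')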
3 changes: 2 additions & 1 deletion tensorpack/graph_builder/training.py
@@ -10,6 +10,7 @@
 import tensorflow as tf
 from six.moves import range, zip
 
+from ..compat import tfv1
 from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.gradproc import ScaleGradient
 from ..tfutils.tower import TrainTowerContext
@@ -101,7 +102,7 @@ def call_for_each_tower(
         device = devices[idx] if devices is not None else '/gpu:{}'.format(t)
         usevs = use_vs[idx] if use_vs is not None else False
         reuse = not usevs and idx > 0
-        with tf.device(device), _maybe_reuse_vs(reuse), TrainTowerContext(
+        with tfv1.device(device), _maybe_reuse_vs(reuse), TrainTowerContext(
                 tower_names[idx],
                 vs_name=tower_names[idx] if usevs else '',
                 index=idx, total=len(towers)):
3 changes: 2 additions & 1 deletion tensorpack/graph_builder/utils.py
@@ -6,6 +6,7 @@
 from contextlib import contextmanager
 import tensorflow as tf
 
+from ..compat import tfv1
 from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.scope_utils import cached_name_scope, under_name_scope
 from ..tfutils.varreplace import custom_getter_scope
@@ -82,7 +83,7 @@ def __call__(self, op):
         # from tensorflow.python.training.device_util import canonicalize
         # from tensorflow.python.distribute.device_util import canonicalize
         def canonicalize(name):  # tensorflow/tensorflow#11484
-            return tf.DeviceSpec.from_string(name).to_string()
+            return tfv1.DeviceSpec.from_string(name).to_string()
 
         if op.device:
             return op.device
5 changes: 3 additions & 2 deletions tensorpack/models/nonlin.py
@@ -4,6 +4,7 @@
 
 import tensorflow as tf
 
+from ..compat import tfv1
 from .batch_norm import BatchNorm
 from .common import VariableHolder, layer_register
@@ -50,8 +51,8 @@ def PReLU(x, init=0.001, name='output'):
     * ``alpha``: learnable slope.
     """
-    init = tf.constant_initializer(init)
-    alpha = tf.get_variable('alpha', [], initializer=init)
+    init = tfv1.constant_initializer(init)
+    alpha = tfv1.get_variable('alpha', [], initializer=init)
     x = ((1 + alpha) * x + (1 - alpha) * tf.abs(x))
     ret = tf.multiply(x, 0.5, name=name)
 
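A side note on the unchanged context: tensorpack computes PReLU without a branch because 0.5 * ((1 + alpha) * x + (1 - alpha) * |x|) equals x for x > 0 and alpha * x for x < 0. A quick numeric check of that identity (standalone sketch, not part of the diff):

    import numpy as np

    def prelu_branchfree(x, a):
        # the form used in nonlin.py: 0.5 * ((1 + a) * x + (1 - a) * |x|)
        return 0.5 * ((1 + a) * x + (1 - a) * np.abs(x))

    def prelu_reference(x, a):
        return np.where(x > 0, x, a * x)

    x = np.linspace(-3.0, 3.0, 13)
    assert np.allclose(prelu_branchfree(x, 0.25), prelu_reference(x, 0.25))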
3 changes: 2 additions & 1 deletion tensorpack/models/regularize.py
@@ -25,7 +25,8 @@ def _log_once(msg):
     l2_regularizer = tf.contrib.layers.l2_regularizer
     l1_regularizer = tf.contrib.layers.l1_regularizer
 else:
-    l2_regularizer = tf.keras.regularizers.l2
+    # oh these little dirty details
+    l2_regularizer = lambda x: tf.keras.regularizers.l2(x * 0.5)  # noqa
     l1_regularizer = tf.keras.regularizers.l1
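The 0.5 factor is the "dirty detail" the new comment alludes to: `tf.contrib.layers.l2_regularizer(scale)` penalizes scale * sum(x**2) / 2 (it is built on `tf.nn.l2_loss`, which includes the 1/2), while `tf.keras.regularizers.l2(l)` penalizes l * sum(x**2) with no 1/2. Halving the scale keeps the keras-backed version numerically identical to the contrib one. A numeric sketch of the equivalence:

    import numpy as np

    x = np.array([1.0, -2.0, 3.0])
    scale = 1e-4
    contrib_penalty = scale * np.sum(x ** 2) / 2.0   # contrib semantics (tf.nn.l2_loss)
    keras_penalty = (scale * 0.5) * np.sum(x ** 2)   # keras l2 semantics, halved scale
    assert np.isclose(contrib_penalty, keras_penalty)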
3 changes: 2 additions & 1 deletion tensorpack/predict/concurrency.py
@@ -8,6 +8,7 @@
 import tensorflow as tf
 from six.moves import queue, range
 
+from ..compat import tfv1
 from ..tfutils.model_utils import describe_trainable_vars
 from ..utils import logger
 from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
@@ -162,7 +163,7 @@ def __init__(self, predictors, batch_size=5):
 
     def start(self):
         if self._need_default_sess:
-            assert tf.get_default_session() is not None, \
+            assert tfv1.get_default_session() is not None, \
                 "Not session is bind to predictors, " \
                 "MultiThreadAsyncPredictor.start() has to be called under a default session!"
         for t in self.threads:
2 changes: 1 addition & 1 deletion tensorpack/predict/config.py
@@ -3,7 +3,7 @@
 
 
 import six
-import tensorflow as tf
+from ..compat import tfv1 as tf
 
 from ..graph_builder import ModelDescBase
 from ..tfutils import get_default_sess_config
4 changes: 4 additions & 0 deletions tensorpack/tfutils/argscope.py
@@ -6,7 +6,9 @@
 from contextlib import contextmanager
 from functools import wraps
 from inspect import getmembers, isfunction
+import tensorflow as tf
 
+from ..compat import is_tfv2
 from ..utils import logger
 from .tower import get_current_tower_context
 
@@ -138,6 +140,8 @@ def enable_argscope_for_module(module, log_shape=True):
     Args:
         log_shape (bool): print input/output shapes of each function.
     """
+    if is_tfv2() and module == tf.layers:
+        module = tf.compat.v1.layers
     for name, obj in getmembers(module):
         if isfunction(obj):
             setattr(module, name, enable_argscope_for_function(obj,
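For context, `enable_argscope_for_module` is typically invoked as `enable_argscope_for_module(tf.layers)`; the new guard keeps that call working under TF2 by patching `tf.compat.v1.layers` instead. A hedged usage sketch, assuming a TF build that still exposes `tf.layers` (TF1.x or the early TF2 builds this commit targets):

    import tensorflow as tf
    from tensorpack.tfutils.argscope import argscope, enable_argscope_for_module

    # monkey-patch every function in the module so it respects argscope
    enable_argscope_for_module(tf.layers)

    image = tf.compat.v1.placeholder(tf.float32, [None, 28, 28, 3])
    with argscope([tf.layers.conv2d], padding='same', activation=tf.nn.relu):
        x = tf.layers.conv2d(image, 32, 3)   # picks up padding/activation
        x = tf.layers.conv2d(x, 64, 3)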
18 changes: 10 additions & 8 deletions tensorpack/tfutils/export.py
@@ -12,6 +12,7 @@
 from tensorflow.python.platform import gfile
 from tensorflow.python.tools import optimize_for_inference_lib
 
+from ..compat import is_tfv2, tfv1
 from ..input_source import PlaceholderInput
 from ..tfutils.common import get_tensors_by_names, get_tf_version_tuple
 from ..tfutils.tower import PredictTowerContext
@@ -60,7 +61,7 @@ def export_compact(self, filename, optimize=True, toco_compatible=False):
 
         self.config.session_init._setup_graph()
         # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
-        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+        sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
         self.config.session_init._run_init(sess)
 
         dtypes = [n.dtype for n in input_tensors]
@@ -88,7 +89,7 @@ def export_compact(self, filename, optimize=True, toco_compatible=False):
         logger.info("Output graph written to {}.".format(filename))
 
     def export_serving(self, filename,
-                       tags=[tf.saved_model.tag_constants.SERVING],
+                       tags=[tf.saved_model.SERVING if is_tfv2() else tf.saved_model.tag_constants.SERVING],
                        signature_name='prediction_pipeline'):
         """
         Converts a checkpoint and graph to a servable for TensorFlow Serving.
@@ -121,21 +122,22 @@ def export_serving(self, filename,
                 self.config.tower_func(*input.get_input_tensors())
 
             input_tensors = get_tensors_by_names(self.config.input_names)
-            inputs_signatures = {t.name: tf.saved_model.utils.build_tensor_info(t) for t in input_tensors}
+            saved_model = tfv1.saved_model.utils
+            inputs_signatures = {t.name: saved_model.build_tensor_info(t) for t in input_tensors}
             output_tensors = get_tensors_by_names(self.config.output_names)
-            outputs_signatures = {t.name: tf.saved_model.utils.build_tensor_info(t) for t in output_tensors}
+            outputs_signatures = {t.name: saved_model.build_tensor_info(t) for t in output_tensors}
 
             self.config.session_init._setup_graph()
             # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
-            sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+            sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
            self.config.session_init._run_init(sess)
 
-            builder = tf.saved_model.builder.SavedModelBuilder(filename)
+            builder = tfv1.saved_model.builder.SavedModelBuilder(filename)
 
-            prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
+            prediction_signature = tfv1.saved_model.signature_def_utils.build_signature_def(
                 inputs=inputs_signatures,
                 outputs=outputs_signatures,
-                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
+                method_name=tfv1.saved_model.signature_constants.PREDICT_METHOD_NAME)
 
             builder.add_meta_graph_and_variables(
                 sess, tags,
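All the graph-mode SavedModel pieces used here (`build_tensor_info`, `SavedModelBuilder`, `build_signature_def`) live only under the v1 namespace in TF2, hence the wholesale move to `tfv1.saved_model`. For reference, a servable exported this way can be read back with the matching v1 loader; a hedged sketch (the export directory is an assumption, the signature name is the default from the code above):

    import tensorflow as tf

    tfv1 = tf.compat.v1
    export_dir = '/tmp/exported_model'   # hypothetical path passed to export_serving

    with tfv1.Session(graph=tf.Graph()) as sess:
        meta_graph = tfv1.saved_model.loader.load(
            sess, [tfv1.saved_model.tag_constants.SERVING], export_dir)
        sig = meta_graph.signature_def['prediction_pipeline']
        print(sig.inputs, sig.outputs)   # tensor names to feed and fetch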
7 changes: 4 additions & 3 deletions tensorpack/tfutils/gradproc.py
@@ -8,6 +8,7 @@
 import six
 import tensorflow as tf
 
+from ..compat import tfv1
 from ..utils import logger
 from .summary import add_moving_summary
 from .symbolic_functions import print_stat, rms
@@ -40,11 +41,11 @@ def process(self, grads):
 
         # reuse the old name_scope, if process() is called multiple times
         if self._name_scope is None:
-            with tf.name_scope(type(self).__name__) as scope:
+            with tfv1.name_scope(type(self).__name__) as scope:
                 self._name_scope = scope
                 return self._process(grads)
         else:
-            with tf.name_scope(self._name_scope):
+            with tfv1.name_scope(self._name_scope):
                 return self._process(grads)
@@ -175,7 +176,7 @@ def _mapper(self, grad, var):
             return grad
         if name not in SummaryGradient._summaried_gradient:
             SummaryGradient._summaried_gradient.add(name)
-            tf.summary.histogram(name + '-grad', grad, collections=self._coll)
+            tfv1.summary.histogram(name + '-grad', grad, collections=self._coll)
             add_moving_summary(rms(grad, name=name + '/rms'))
         return grad
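The caching in `process()` leans on a TF1 name-scope behavior: entering `tf.name_scope` with a bare name opens a new, uniquified scope, but re-entering it with a captured scope string (which ends in '/') reuses that exact scope instead of creating a `Name_1` sibling. A standalone sketch of the behavior, assuming TF2 with the v1 compat API:

    import tensorflow as tf

    tfv1 = tf.compat.v1
    tfv1.disable_eager_execution()

    with tfv1.name_scope('Proc') as scope:   # scope == 'Proc/'
        a = tfv1.constant(1, name='x')       # op name 'Proc/x'

    with tfv1.name_scope(scope):             # re-enters 'Proc/' verbatim
        b = tfv1.constant(2, name='y')       # op name 'Proc/y', not 'Proc_1/y'

    print(a.op.name, b.op.name)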
2 changes: 1 addition & 1 deletion tensorpack/tfutils/optimizer.py
@@ -20,7 +20,7 @@ class ProxyOptimizer(tfv1.train.Optimizer):
     A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
     """
     def __init__(self, opt, name='ProxyOptimizer'):
-        assert isinstance(opt, tf.train.Optimizer), opt
+        assert isinstance(opt, tfv1.train.Optimizer), opt
         super(ProxyOptimizer, self).__init__(False, name)
         self._opt = opt
2 changes: 1 addition & 1 deletion tensorpack/tfutils/sessinit.py
@@ -4,8 +4,8 @@
 import os
 import numpy as np
 import six
-import tensorflow as tf
 
+from ..compat import tfv1 as tf
 from ..utils import logger
 from .common import get_op_tensor_name
 from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varname, is_training_name
9 changes: 5 additions & 4 deletions tensorpack/tfutils/varmanip.py
@@ -7,6 +7,7 @@
 import six
 import tensorflow as tf
 
+from ..compat import tfv1
 from ..utils import logger
 from .common import get_op_tensor_name
 
@@ -84,7 +85,7 @@ def upcast(vartype, valtype):
         return None
 
     if hasattr(value, 'dtype'):
-        vartype = var.value().dtype
+        vartype = var.dtype
         if vartype != value.dtype:
             msg = "Variable {} has dtype {} but was given a value of dtype {}.".format(name, vartype, value.dtype)
             newtype = upcast(var.dtype.base_dtype, value.dtype)
@@ -172,7 +173,7 @@ def get_checkpoint_path(model_path):
     if os.path.basename(model_path) == model_path:
         model_path = os.path.join('.', model_path)  # avoid #4921 and #6142
     if os.path.basename(model_path) == 'checkpoint':
-        assert tf.gfile.Exists(model_path), model_path
+        assert tfv1.gfile.Exists(model_path), model_path
         model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
         # to be consistent with either v1 or v2
 
@@ -186,7 +187,7 @@ def get_checkpoint_path(model_path):
         logger.info(
             "Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
         model_path = new_path
-    assert tf.gfile.Exists(model_path) or tf.gfile.Exists(model_path + '.index'), model_path
+    assert tfv1.gfile.Exists(model_path) or tfv1.gfile.Exists(model_path + '.index'), model_path
     return model_path
 
 
@@ -200,7 +201,7 @@ def load_chkpt_vars(model_path):
         dict: a name:value dict
     """
     model_path = get_checkpoint_path(model_path)
-    reader = tf.train.NewCheckpointReader(model_path)
+    reader = tfv1.train.NewCheckpointReader(model_path)
     var_names = reader.get_variable_to_shape_map().keys()
     result = {}
     for n in var_names:
2 changes: 1 addition & 1 deletion tox.ini
@@ -10,7 +10,7 @@ exclude = .git,
           examples,
           docs/conf.py
           snippet,
-          examples-old,
+          examples_v2,
           _test.py,
 
 [isort]
