Skip to content

Commit

Permalink
Merge pull request #686 from ufal/tensorval
Browse files Browse the repository at this point in the history
Super-duper tensor logging functionality
  • Loading branch information
jindrahelcl committed Apr 4, 2018
2 parents fa63027 + fc7e727 commit 52d4c29
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 15 deletions.
7 changes: 7 additions & 0 deletions neuralmonkey/decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from functools import wraps

import tensorflow as tf

from neuralmonkey.model.model_part import ModelPart
from neuralmonkey.tf_utils import tf_print


def tensor(func):
Expand All @@ -12,6 +15,10 @@ def decorate(self, *args, **kwargs):
# jump out of the caller's scope and into the ModelPart's scope
with self.use_scope():
value = func(self, *args, **kwargs)
if isinstance(value, tf.Tensor):
value = tf_print(
value, "<{}.{}>".format(self.name, func.__name__),
"tensorval")
else:
value = func(self, *args, **kwargs)
setattr(self, attribute_name, value)
Expand Down
39 changes: 24 additions & 15 deletions neuralmonkey/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@
import os

# pylint: disable=unused-import
from typing import Any, Optional, List
from typing import Any, List
# pylint: enable=unused-import

from termcolor import colored


class Logging(object):

log_file = None # type: Optional[Any]
log_file = None # type: Any

# 'all' and 'none' are special symbols,
# others are filtered according the labels
debug_enabled = [
debug_enabled_for = [
os.environ.get("NEURALMONKEY_DEBUG_ENABLE", "none")] # type: List[str]
debug_disabled = [
debug_disabled_for = [
os.environ.get("NEURALMONKEY_DEBUG_DISABLE", "")] # type: List[str]
strict_mode = os.environ.get("NEURALMONKEY_STRICT") # type: str

Expand Down Expand Up @@ -76,24 +76,32 @@ def print_header(title: str, path: str) -> None:
log_print("")

@staticmethod
def debug(message: str, label: Optional[str] = None):
if "none" in Logging.debug_enabled:
return

if (label not in Logging.debug_enabled and
"all" not in Logging.debug_enabled):
return

if label in Logging.debug_disabled:
def debug(message: str, label: str = None):
if not debug_enabled(label):
return

if label:
prefix = "DEBUG ({}):".format(label)
prefix = "{}: DEBUG ({}): ".format(Logging._get_time(), label)
else:
prefix = "DEBUG:"
prefix = "{}: DEBUG: ".format(Logging._get_time())

log_print("{}{}".format(colored(prefix, color="cyan"), message))

@staticmethod
def debug_enabled(label: str = None):
if "none" in Logging.debug_enabled_for:
return False

if label is None:
return True

if (label in Logging.debug_disabled_for
or ("all" not in Logging.debug_enabled_for
and label not in Logging.debug_enabled_for)):
return False

return True


# pylint: disable=invalid-name
# we want these helper functions to have this exact name
Expand All @@ -102,3 +110,4 @@ def debug(message: str, label: Optional[str] = None):
debug = Logging.debug
warn = Logging.warn
notice = Logging.notice
debug_enabled = Logging.debug_enabled
40 changes: 40 additions & 0 deletions neuralmonkey/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
# pylint: disable=unused-import
from typing import Dict, Set
# pylint: enable=unused-import
import numpy as np
import tensorflow as tf

from neuralmonkey.logging import debug, debug_enabled


def _get_current_experiment():
# This is needed to avoid circular imports.
Expand Down Expand Up @@ -45,3 +48,40 @@ def get_variable(name: str,
name=name, shape=shape, dtype=dtype,
initializer=get_initializer(name, initializer),
**kwargs)


def tf_print(tensor: tf.Tensor,
message: str = None,
debug_label: str = None) -> tf.Tensor:
"""Print the value of a tensor to the debug log.
Better than tf.Print, logs to console only when the "tensorval" debug
subject is turned on.
Idea found at: https://stackoverflow.com/a/39649614
Args:
tensor: The tensor whose value to print
Returns:
As tf.Print, this function returns a tensor identical to the input
tensor, with the printing side-effect added.
"""
def print_tensor(x: np.ndarray) -> tf.Tensor:
if message is not None:
debug(
"{}, shape: {}:\n{}".format(message, x.shape, x), debug_label)
else:
debug("Shape: {}\n{}".format(x.shape, x), debug_label)
return x

# To save time, check if debug will print something
if not debug_enabled(debug_label):
return tensor

log_op = tf.py_func(print_tensor, [tensor], [tensor.dtype])[0]

with tf.control_dependencies([log_op]):
res = tf.identity(tensor)

return res

0 comments on commit 52d4c29

Please sign in to comment.