diff --git a/neuralmonkey/decorators.py b/neuralmonkey/decorators.py index 62f968009..146ac986c 100644 --- a/neuralmonkey/decorators.py +++ b/neuralmonkey/decorators.py @@ -1,6 +1,9 @@ from functools import wraps +import tensorflow as tf + from neuralmonkey.model.model_part import ModelPart +from neuralmonkey.tf_utils import tf_print def tensor(func): @@ -12,6 +15,10 @@ def decorate(self, *args, **kwargs): # jump out of the caller's scope and into the ModelPart's scope with self.use_scope(): value = func(self, *args, **kwargs) + if isinstance(value, tf.Tensor): + value = tf_print( + value, "<{}.{}>".format(self.name, func.__name__), + "tensorval") else: value = func(self, *args, **kwargs) setattr(self, attribute_name, value) diff --git a/neuralmonkey/logging.py b/neuralmonkey/logging.py index fd6670916..8b92861a2 100644 --- a/neuralmonkey/logging.py +++ b/neuralmonkey/logging.py @@ -3,7 +3,7 @@ import os # pylint: disable=unused-import -from typing import Any, Optional, List +from typing import Any, List # pylint: enable=unused-import from termcolor import colored @@ -11,13 +11,13 @@ class Logging(object): - log_file = None # type: Optional[Any] + log_file = None # type: Any # 'all' and 'none' are special symbols, # others are filtered according the labels - debug_enabled = [ + debug_enabled_for = [ os.environ.get("NEURALMONKEY_DEBUG_ENABLE", "none")] # type: List[str] - debug_disabled = [ + debug_disabled_for = [ os.environ.get("NEURALMONKEY_DEBUG_DISABLE", "")] # type: List[str] strict_mode = os.environ.get("NEURALMONKEY_STRICT") # type: str @@ -76,24 +76,32 @@ def print_header(title: str, path: str) -> None: log_print("") @staticmethod - def debug(message: str, label: Optional[str] = None): - if "none" in Logging.debug_enabled: - return - - if (label not in Logging.debug_enabled and - "all" not in Logging.debug_enabled): - return - - if label in Logging.debug_disabled: + def debug(message: str, label: str = None): + if not debug_enabled(label): return if label: - prefix = "DEBUG ({}):".format(label) + prefix = "{}: DEBUG ({}): ".format(Logging._get_time(), label) else: - prefix = "DEBUG:" + prefix = "{}: DEBUG: ".format(Logging._get_time()) log_print("{}{}".format(colored(prefix, color="cyan"), message)) + @staticmethod + def debug_enabled(label: str = None): + if "none" in Logging.debug_enabled_for: + return False + + if label is None: + return True + + if (label in Logging.debug_disabled_for + or ("all" not in Logging.debug_enabled_for + and label not in Logging.debug_enabled_for)): + return False + + return True + # pylint: disable=invalid-name # we want these helper functions to have this exact name @@ -102,3 +110,4 @@ def debug(message: str, label: Optional[str] = None): debug = Logging.debug warn = Logging.warn notice = Logging.notice +debug_enabled = Logging.debug_enabled diff --git a/neuralmonkey/tf_utils.py b/neuralmonkey/tf_utils.py index c30b9b4a9..3af3da72e 100644 --- a/neuralmonkey/tf_utils.py +++ b/neuralmonkey/tf_utils.py @@ -3,8 +3,11 @@ # pylint: disable=unused-import from typing import Dict, Set # pylint: enable=unused-import +import numpy as np import tensorflow as tf +from neuralmonkey.logging import debug, debug_enabled + def _get_current_experiment(): # This is needed to avoid circular imports. @@ -45,3 +48,40 @@ def get_variable(name: str, name=name, shape=shape, dtype=dtype, initializer=get_initializer(name, initializer), **kwargs) + + +def tf_print(tensor: tf.Tensor, + message: str = None, + debug_label: str = None) -> tf.Tensor: + """Print the value of a tensor to the debug log. + + Better than tf.Print, logs to console only when the "tensorval" debug + subject is turned on. + + Idea found at: https://stackoverflow.com/a/39649614 + + Args: + tensor: The tensor whose value to print + + Returns: + As tf.Print, this function returns a tensor identical to the input + tensor, with the printing side-effect added. + """ + def print_tensor(x: np.ndarray) -> tf.Tensor: + if message is not None: + debug( + "{}, shape: {}:\n{}".format(message, x.shape, x), debug_label) + else: + debug("Shape: {}\n{}".format(x.shape, x), debug_label) + return x + + # To save time, check if debug will print something + if not debug_enabled(debug_label): + return tensor + + log_op = tf.py_func(print_tensor, [tensor], [tensor.dtype])[0] + + with tf.control_dependencies([log_op]): + res = tf.identity(tensor) + + return res