diff --git a/.gitignore b/.gitignore index c468f8a4f..eaa3ba9a8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,15 +2,19 @@ *.npz *.pyc *~ + .DS_Store .idea .spyproject/ +.pytest_cache/ + +venv/ build/ -dist +data/ +dist/ docs/_build tensorlayer.egg-info -tensorlayer/__pacache__ -venv/ -.pytest_cache/ +tensorlayer/__pycache__ + update_tl.bat update_tl.py diff --git a/tensorlayer/layers/core.py b/tensorlayer/layers/core.py index 80ac83a22..8fdcd9945 100644 --- a/tensorlayer/layers/core.py +++ b/tensorlayer/layers/core.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- +import six import time +from abc import ABCMeta, abstractmethod import numpy as np import tensorflow as tf @@ -38,10 +40,16 @@ ] -class LayersConfig: +@six.add_metaclass(ABCMeta) +class LayersConfig(object): + tf_dtype = tf.float32 # TensorFlow DType set_keep = {} # A dictionary for holding tf.placeholders + @abstractmethod + def __init__(self): + pass + try: # For TF12 and later TF_GRAPHKEYS_VARIABLES = tf.GraphKeys.GLOBAL_VARIABLES diff --git a/tests/test_layers_core.py b/tests/test_layers_core.py index fb0cfb956..806c43954 100644 --- a/tests/test_layers_core.py +++ b/tests/test_layers_core.py @@ -6,6 +6,16 @@ import tensorlayer as tl +class Core_Helpers_Test(unittest.TestCase): + + def test_LayersConfig(self): + with self.assertRaises(TypeError): + tl.layers.LayersConfig() + + self.assertIsInstance(tl.layers.LayersConfig.tf_dtype, tf.DType) + self.assertIsInstance(tl.layers.LayersConfig.set_keep, dict) + + class Layer_Core_Test(unittest.TestCase): @classmethod