TensorFlow 2.0 Preview - TypeError: 'Attribute' object is not iterable when using tf.function #25281

Closed · mr-ubik opened this issue Jan 29, 2019 · 9 comments
Labels: comp:data (tf.data related issues), TF 2.0 (Issues relating to TensorFlow 2.0)

mr-ubik commented Jan 29, 2019

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes, the code is attached below
  • OS Platform and Distribution: Arch Linux
  • TensorFlow installed from (source or binary): PyPI
  • TensorFlow version: tf-nightly-gpu-2.0-preview==2.0.0.dev20190129 (the latest nightly)
  • Python version: 3.6.8
  • CUDA/cuDNN version: 10.0
  • GPU model and memory: GTX 1080Ti

Describe the current behavior
Running the train() method provided below fails when it is decorated with @tf.function.

Describe the expected behavior
The training loop should run without errors, as described in the "Effective TensorFlow 2.0" guide.

Code to reproduce the issue

"""
Implement DCGAN using the new TF 2.0 API.

Also test tensorflow-datasets.

Celeb-A dataset.
"""

from typing import Dict
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow import keras as k


def bce(x: tf.Tensor, label: tf.Tensor, label_smoothing: float = 0.0) -> tf.Tensor:
    """Returns the discrete binary cross entropy between x and the discrete label
    Args:
        x: a 2D tensor
        label: the discrete label, aka, the distribution to match
        label_smoothing: if greater than zero, smooth the labels

    Returns:
        The binary cross entropy
    """
    # FIXME: Fix the warning
    # assert len(x.shape) == 2 and len(label.shape) == 0

    return k.losses.BinaryCrossentropy()(tf.ones_like(x) * label, x)


def min_max(
    positive: tf.Tensor, negative: tf.Tensor, label_smoothing: float = 0.0
) -> tf.Tensor:
    """Returns the discriminator (min max) loss
    Args:
        positive: the discriminator output for the positive class: 2D tensor
        negative: the discriminator output for the negative class: 2D tensor
        label_smoothing: if greater than zero, applies one-sided label smoothing
    Returns:
        The sum of 2 BCE
    """

    one = tf.constant(1.0)
    zero = tf.constant(0.0)
    d_loss = bce(positive, one, label_smoothing) + bce(negative, zero)
    return d_loss


class Generator(k.Model):
    def __init__(self) -> None:
        super(Generator, self).__init__()
        self.fc1 = k.layers.Dense(4 * 4 * 1024)
        self.batchnorm1 = k.layers.BatchNormalization()

        self.conv2 = k.layers.Conv2DTranspose(
            filters=512,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )
        self.batchnorm2 = k.layers.BatchNormalization()

        self.conv3 = k.layers.Conv2DTranspose(
            filters=256,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )
        self.batchnorm3 = k.layers.BatchNormalization()

        self.conv4 = k.layers.Conv2DTranspose(
            filters=128,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )
        self.batchnorm4 = k.layers.BatchNormalization()

        self.conv5 = k.layers.Conv2DTranspose(
            filters=3,
            kernel_size=(5, 5),
            strides=(2, 2),
            padding="same",
            use_bias=False,
        )
        self.batchnorm5 = k.layers.BatchNormalization()

    def call(self, x: tf.Tensor, training: bool = True) -> tf.Tensor:
        x = self.fc1(x)
        x = self.batchnorm1(x, training=training)
        x = tf.nn.relu(x)
        x = tf.reshape(x, shape=(-1, 4, 4, 1024))

        x = self.conv2(x)
        x = self.batchnorm2(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv3(x)
        x = self.batchnorm3(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv4(x)
        x = self.batchnorm4(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv5(x)
        x = self.batchnorm5(x, training=training)

        x = tf.nn.tanh(x)
        return x


class Discriminator(k.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = k.layers.Conv2D(128, (5, 5), strides=(2, 2), padding="same")
        self.conv2 = k.layers.Conv2D(256, (5, 5), strides=(2, 2), padding="same")
        self.batchnorm2 = k.layers.BatchNormalization()
        self.conv3 = k.layers.Conv2D(512, (5, 5), strides=(2, 2), padding="same")
        self.batchnorm3 = k.layers.BatchNormalization()
        self.conv4 = k.layers.Conv2D(1024, (5, 5), strides=(2, 2), padding="same")
        self.batchnorm4 = k.layers.BatchNormalization()
        self.flatten = k.layers.Flatten()
        self.fc5 = k.layers.Dense(1)

    def call(self, x, training=True):
        x = self.conv1(x)
        x = tf.nn.leaky_relu(x)

        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = tf.nn.leaky_relu(x)

        x = self.conv3(x)
        x = self.batchnorm3(x)
        x = tf.nn.leaky_relu(x)

        x = self.conv4(x)
        x = self.batchnorm4(x)
        x = tf.nn.leaky_relu(x)

        x = self.flatten(x)
        x = self.fc5(x)
        return x


class GAN:
    def __init__(self, generator, discriminator, encoder=None):
        """
        GAN initializer.

        Args:
            generator: A ``tensorflow.keras.Model`` to use as Generator.
            discriminator: A ``tensorflow.keras.Model`` to use as Discriminator.
            encoder: A ``tensorflow.keras.Model`` to use as Encoder.

        """
        self.G = generator()
        self.D = discriminator()
        self.E = encoder() if encoder is not None else None
        self.latent_vector_dims = 100

        self.G_opt = k.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)
        self.D_opt = k.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)

    @tf.function()
    def train(self, dataset: tf.data.Dataset):
        """
        Train.
        """
        for step, features in enumerate(dataset, start=1):
            x = features["image"]
            z = tf.random.normal((x.shape[0], self.latent_vector_dims))

            # We record all the operations in the tape
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                G_z = self.G(z, training=True)

                D_x = self.D(x, training=True)
                D_Gz = self.D(G_z, training=True)

                g_loss = bce(D_Gz, tf.constant(1.0))
                d_loss = min_max(D_x, D_Gz, label_smoothing=0.0)

            # We retrieve the gradients from our records
            G_grads = gen_tape.gradient(g_loss, self.G.trainable_variables)
            D_grads = disc_tape.gradient(d_loss, self.D.trainable_variables)

            # Optimize and apply the gradients
            self.G_opt.apply_gradients(zip(G_grads, self.G.trainable_variables))
            self.D_opt.apply_gradients(zip(D_grads, self.D.trainable_variables))

            if step % 10 == 0:
                print(f"--------------------------")
                print(f"STEP: {step}")
                print(f"D_LOSS: {d_loss}")
                print(f"G_LOSS: {g_loss}")


class InputPipeline:
    def __init__(
        self, dataset, batch_size, epochs, shuffle_buffer, prefetched_items, size
    ):
        self.batch_size = batch_size
        self.dataset_name = dataset
        self.epochs = epochs
        self.prefetched_items = prefetched_items
        self.shuffle_buffer = shuffle_buffer
        self.size = size

    def get_input_fn(self) -> tf.data.Dataset:
        """Input fn."""
        return self.input_fn

    def load_public_dataset(self):
        """
        Load one of the publicly available datasets, will merge together all the splits.

        Args:
            chosen_dataset: dataset to use.

        Return:
            The chosen dataset as a ``tf.data.Dataset``

        """
        # Construct a tf.data.Dataset
        datasets = tfds.load(name=self.dataset_name, split=tfds.Split.ALL)
        return datasets

    def resize_images(self, features: Dict) -> Dict:
        """
        Overwrite the \"image\" feature in order to resize them.

        Args:
            features: features dictionary.
            size: desired target size.

        Returns:
            Features with \"image\" resized to the correct shape.

        """
        features["image"] = tf.image.resize(features["image"], self.size)
        return features

    def input_fn(self):
        dataset = self.load_public_dataset()
        dataset = (
            dataset.map(self.resize_images)
            .shuffle(self.shuffle_buffer)
            .batch(self.batch_size)
            .prefetch(self.prefetched_items)
            .repeat(self.epochs)
        )
        return dataset


def main():

    # TODO: replace with CLI
    CHOICE = "celeb_a"
    EPOCHS = 10
    BATCH_SIZE = 64
    PREFETCH = 10
    SHUFFLE_BUFFER = 10000

    # See available datasets
    public_datasets = tfds.list_builders()

    gan = GAN(Generator, Discriminator)
    input_pipeline = InputPipeline(
        dataset=CHOICE,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        prefetched_items=PREFETCH,
        shuffle_buffer=SHUFFLE_BUFFER,
        size=(64, 64),
    )
    dataset = input_pipeline.input_fn()
    gan.train(dataset=dataset)


if __name__ == "__main__":
    main()

Other info / logs

Full Traceback

Traceback (most recent call last):
  File "dcgan-tf2.py", line 289, in <module>
    main()
  File "dcgan-tf2.py", line 285, in main
    gan.train(dataset=dataset)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 383, in __call__
    self._initialize(args, kwds)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 355, in _initialize
    *args, **kwds))
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1097, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1322, in _maybe_define_function
    arg_names=arg_names), self._function_attributes)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 540, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 298, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1803, in bound_method_wrapper
    return wrapped_fn(weak_instance(), *args, **kwargs)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 533, in wrapper
    ), *args, **kwargs)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 293, in converted_call
    experimental_partial_types=partial_types)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 415, in to_graph
    arg_values, arg_types)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/conversion.py", line 222, in entity_to_graph
    entity_to_graph(candidate, program_ctx, {}, {})
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/conversion.py", line 175, in entity_to_graph
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/conversion.py", line 376, in function_to_graph
    node = node_to_graph(node, context)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/impl/conversion.py", line 435, in node_to_graph
    node = converter.apply_(node, context, call_trees)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/core/converter.py", line 507, in apply_
    node = converter_module.transform(node, context)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/converters/call_trees.py", line 350, in transform
    return CallTreeTransformer(ctx).visit(node)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/core/converter.py", line 440, in visit
    return super(Base, self).visit(node)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/pyct/transformer.py", line 484, in visit
    result = super(Base, self).visit(node)
  File "/usr/lib64/python3.6/ast.py", line 253, in visit
    return visitor(node)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/converters/call_trees.py", line 282, in visit_FunctionDef
    node.returns = self.visit_block(node.returns)
  File "/home/ubik/.python_envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/autograph/pyct/transformer.py", line 368, in visit_block
    for node in nodes:
TypeError: 'Attribute' object is not iterable

I had previously opened a question about this on Stack Overflow.

EDIT: the error persists even with a simpler input pipeline that drops tensorflow-datasets and uses the built-in Keras datasets.
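
A hypothetical sketch of that simpler pipeline (the dataset choice and names here are illustrative assumptions, not the exact code used):

import tensorflow as tf

# Built-in Keras dataset in place of tensorflow-datasets; per the EDIT above,
# feeding a pipeline like this to the same @tf.function train loop still hit
# the same TypeError on the affected nightlies.
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
dataset = (
    tf.data.Dataset.from_tensor_slices(x_train.astype("float32") / 255.0)
    .shuffle(1024)
    .batch(64)
)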

CC @galeone

mr-ubik (Author) commented Jan 29, 2019

We have since traced the issue to the type annotations on the two loss functions.

If we remove only the return type annotation, we get a "tf is undefined" error; if we remove all the annotations, train works properly under tf.function.
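
For clarity, a minimal sketch of the failure mode (assuming one of the affected 2.0 preview nightlies; the function bodies are placeholders, only the annotations matter):

import tensorflow as tf


def annotated_loss(x: tf.Tensor) -> tf.Tensor:
    # Annotated helper: AutoGraph trips over the tf.Tensor annotations while
    # converting this function on the affected nightlies.
    return tf.reduce_mean(x)


def plain_loss(x):
    # Identical body with the annotations removed: converts fine.
    return tf.reduce_mean(x)


@tf.function
def train_step(x):
    # On the affected builds this raises
    # TypeError: 'Attribute' object is not iterable.
    return annotated_loss(x)
    # Swapping in plain_loss(x) makes the conversion succeed.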

@jvishnuvardhan jvishnuvardhan self-assigned this Jan 29, 2019
@jvishnuvardhan jvishnuvardhan added TF 2.0 Issues relating to TensorFlow 2.0 comp:data tf.data related issues labels Jan 29, 2019
jvishnuvardhan (Contributor) commented:

@mr-ubik Is there still an issue here, or is it okay to close this? Thanks!

@jvishnuvardhan jvishnuvardhan added the stat:awaiting response Status - Awaiting response from author label Jan 29, 2019
mr-ubik (Author) commented Feb 1, 2019

@jvishnuvardhan I guess you could close it, though the issue with type hints remains.

@mr-ubik mr-ubik closed this as completed Feb 1, 2019
@jvishnuvardhan jvishnuvardhan added stat:awaiting tensorflower Status - Awaiting response from tensorflower and removed stat:awaiting response Status - Awaiting response from author labels Feb 1, 2019
@martinwicke martinwicke reopened this Feb 1, 2019
martinwicke (Member) commented:

This should be open.

@alexbw we should support type annotations. With Python 3 these will be increasingly pervasive, and not supporting them will cause trouble.

martinwicke (Member) commented:

@mr-ubik we should at least say clearly that type annotations are not supported in AutoGraph; better yet, throw a meaningful error.

A real fix may take longer.

alexbw (Contributor) commented Feb 1, 2019

I agree that type annotations should not blow us up. Noted.

Also, actually using type information will be important in the future, as it will help us catch "easy" cases where type information has been user-provided. We have avoided that strategy so far because we can't unambiguously determine types in all cases in Python, so we focused on a purely runtime type-discovery strategy. Adding even a little ahead-of-time type information will help, and we'll get there.

@mdanatg @aaandrewww @brilee

mdanatg commented Feb 1, 2019

Thank you for the detailed investigation!

Thankfully, this is a simpler bug in autograph and only incidentally caused by annotations (could have been any other field). We have a change in progress that will land over the next few days, and that should fix this particular error.

We didn't run extensive tests with type-annotated code, which is why this bug slipped through (so there may be more), but I'm not aware of any reason why annotations should not be supported, in the sense that autograph should just let them pass through. Anything that doesn't do that should be treated as a bug, please let us know about it!
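
To illustrate: the failing frames at the bottom of the traceback iterate over node.returns, but in Python's ast module a function's return annotation is a single expression node, not a list of statements. A standard-library-only sketch (illustrative, not AutoGraph's actual code) reproduces the same error:

import ast

# The return annotation is stored as a single AST node, not a list:
tree = ast.parse("def f(x) -> tf.Tensor: return x")
fn = tree.body[0]
print(type(fn.returns).__name__)  # Attribute (the tf.Tensor expression)

# A visitor that expects a list of statements and iterates over it fails
# exactly like the traceback above:
try:
    for node in fn.returns:
        pass
except TypeError as e:
    print(e)  # 'Attribute' object is not iterable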

Will update this thread as the fix lands.

Cheers,
Dan

@mdanatg mdanatg assigned mdanatg and unassigned martinwicke Feb 1, 2019
@tensorflowbutler tensorflowbutler removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Feb 2, 2019
mdanatg commented Feb 14, 2019

Quick update: the TypeError has now been fixed, but a second fix is needed to make sure that the symbols the type annotations refer to resolve correctly during the conversion process. That fix is largely a refactoring and should land over the coming weeks; I will post an update here.
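
The symptom of the remaining problem is the "tf is undefined" error mentioned earlier in the thread; a minimal sketch of the kind of code affected (an assumed example, not a confirmed failing case):

import tensorflow as tf


# The annotations below reference the module-level name tf. The converted
# function must still be able to resolve that name during conversion; the
# pending fix makes such annotation symbols resolve correctly.
def loss(x: tf.Tensor) -> tf.Tensor:
    return tf.reduce_mean(x)


@tf.function
def step(x):
    return loss(x)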

mdanatg commented Apr 25, 2019

Just double-checked this against tf-nightly and it runs. BTW, I had to make a few changes to get it to run well with tf.function, mainly in the train function. The parts of the script that are unchanged from the original are elided below:

# The imports, bce, min_max, Generator, Discriminator, and GAN.__init__ are
# identical to the original script above and are elided here.


class GAN:
    # ... __init__ as in the original script ...
    @tf.function()
    def train(self, dataset: tf.data.Dataset):
        """
        Train.
        """
        step = tf.constant(0)
        for features in dataset:
            step += 1

            x = features["image"]
            z = tf.random.normal((tf.shape(x)[0], self.latent_vector_dims))

            # We record all the operations in the tape
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                G_z = self.G(z, training=True)

                D_x = self.D(x, training=True)
                D_Gz = self.D(G_z, training=True)

                g_loss = bce(D_Gz, tf.constant(1.0))
                d_loss = min_max(D_x, D_Gz, label_smoothing=0.0)

            # We retrieve the gradients from our records
            G_grads = gen_tape.gradient(g_loss, self.G.trainable_variables)
            D_grads = disc_tape.gradient(d_loss, self.D.trainable_variables)

            # Optimize and apply the gradients
            self.G_opt.apply_gradients(zip(G_grads, self.G.trainable_variables))
            self.D_opt.apply_gradients(zip(D_grads, self.D.trainable_variables))

            if tf.equal(step % 10, 0):
                tf.print("--------------------------")
                tf.print("STEP:", step)
                tf.print("D_LOSS:", d_loss)
                tf.print("G_LOSS:", g_loss)


class InputPipeline:
    # __init__, get_input_fn, load_public_dataset, and resize_images are
    # identical to the original script above and are elided here.
    def input_fn(self):
        dataset = self.load_public_dataset()
        dataset = (
            dataset.map(self.resize_images)
            .take(10 * self.batch_size)
            .shuffle(self.shuffle_buffer)
            .batch(self.batch_size)
            .prefetch(self.prefetched_items)
            .repeat(self.epochs)
        )
        return dataset


def main():

    tf.autograph.set_verbosity(0, True)

    # ... the rest of main() and the __main__ guard are unchanged from the
    # original script above ...

@mdanatg mdanatg closed this as completed Apr 25, 2019