Tensorflow eager version fails, while Tensorflow static graph works #23407

Closed
galeone opened this Issue Oct 31, 2018 · 24 comments

galeone commented Oct 31, 2018

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): no
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Archlinux
  • TensorFlow installed from (source or binary): repository
  • TensorFlow version (use command below): 1.11
  • Python version: 3.7
  • CUDA/cuDNN version: cuda 10, cudnn 7
  • GPU model and memory: nvidia 1080ti

Describe the current behavior

I'm porting an ML model ( https://github.com/samet-akcay/ganomaly ) implemented in PyTorch to TensorFlow (using the Keras layers, since TF 2.0 is coming soon).

The first implementation used eager execution for training, but the model collapses and nothing works.

I then reused the same model definition to first build a static graph and then train the model: it works perfectly.

Describe the expected behavior

The static graph version and the eager version should have the same behavior.

Code to reproduce the issue

Model description (same for both eager and static)

from typing import Dict
import tensorflow as tf
import tensorflow.keras as k
import numpy as np

conv_initializer = k.initializers.random_normal(0.0, 0.02)
batchnorm_inizializer = k.initializers.random_normal(1.0, 0.02)

eps = 1e-5
momentum = 0.99


class Decoder(k.models.Model):
    """
    Decoder (Generator) Network
    """

    def __init__(self, output_depth: int = 1):
        super(Decoder, self).__init__()

        self.conv1 = k.layers.Conv2DTranspose(
            filters=256,
            kernel_size=(4, 4),
            strides=(1, 1),
            kernel_initializer=conv_initializer,
            input_shape=(-1, 1, 1, 100),
            use_bias=False,
        )
        self.batchnorm1 = k.layers.BatchNormalization(
            epsilon=eps,
            momentum=momentum,
            beta_initializer=batchnorm_inizializer,
            gamma_initializer=batchnorm_inizializer,
        )

        self.conv2 = k.layers.Conv2DTranspose(
            filters=128,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )
        self.batchnorm2 = k.layers.BatchNormalization(
            epsilon=eps,
            momentum=momentum,
            beta_initializer=batchnorm_inizializer,
            gamma_initializer=batchnorm_inizializer,
        )

        self.conv3 = k.layers.Conv2DTranspose(
            filters=64,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )
        self.batchnorm3 = k.layers.BatchNormalization(
            epsilon=eps,
            momentum=momentum,
            beta_initializer=batchnorm_inizializer,
            gamma_initializer=batchnorm_inizializer,
        )

        self.conv4 = k.layers.Conv2DTranspose(
            filters=output_depth,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )

    def call(self, x, training=True):
        # print("X.SHAPE: ", x.shape)

        x = self.conv1(x)
        x = self.batchnorm1(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        x = self.batchnorm2(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv3(x)
        x = self.batchnorm3(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv4(x)
        x = tf.nn.tanh(x)  # image

        # print("Decoder call output size: ", x.shape)

        return x


class Encoder(k.models.Model):

    def __init__(self, latent_dimensions: int = 100):
        super(Encoder, self).__init__()

        self.conv0 = k.layers.Conv2D(
            filters=64,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            kernel_initializer=conv_initializer,
            input_shape=(-1, 32, 32, 1),
            use_bias=False,
        )

        self.conv1 = k.layers.Conv2D(
            filters=128,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )
        self.batchnorm1 = k.layers.BatchNormalization(
            epsilon=eps,
            momentum=momentum,
            beta_initializer=batchnorm_inizializer,
            gamma_initializer=batchnorm_inizializer,
        )

        self.conv2 = k.layers.Conv2D(
            filters=256,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )
        self.batchnorm2 = k.layers.BatchNormalization(
            epsilon=eps,
            momentum=momentum,
            beta_initializer=batchnorm_inizializer,
            gamma_initializer=batchnorm_inizializer,
        )

        self.conv3 = k.layers.Conv2D(
            filters=latent_dimensions,
            kernel_size=(4, 4),
            strides=(1, 1),
            padding="valid",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )

    def call(self, x, training=True):
        # x = self.conv0(x, input_shape=x.shape[1:])

        x = self.conv0(x)
        x = tf.nn.leaky_relu(x)

        x = self.conv1(x)
        x = self.batchnorm1(x, training=training)
        x = tf.nn.leaky_relu(x)

        x = self.conv2(x)
        x = self.batchnorm2(x, training=training)
        x = tf.nn.leaky_relu(x)

        x = self.conv3(x)
        # x = tf.nn.tanh(x)       # latent space unitary sphere [-1,1] TODO: temporary?

        # print("Encoder call output size: ", x.shape)

        return x


class Discriminator(k.models.Model):

    def __init__(self):
        super(Discriminator, self).__init__()

        self.conv0 = k.layers.Conv2D(
            filters=64,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )

        self.conv1 = k.layers.Conv2D(
            filters=128,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )
        self.batchnorm1 = k.layers.BatchNormalization(
            epsilon=eps,
            momentum=momentum,
            beta_initializer=batchnorm_inizializer,
            gamma_initializer=batchnorm_inizializer,
        )

        self.conv2 = k.layers.Conv2D(
            filters=256,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="same",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )
        self.batchnorm2 = k.layers.BatchNormalization(
            epsilon=eps,
            momentum=momentum,
            beta_initializer=batchnorm_inizializer,
            gamma_initializer=batchnorm_inizializer,
        )

        self.conv3 = k.layers.Conv2D(
            filters=1,
            kernel_size=(4, 4),
            strides=(1, 1),
            padding="valid",
            kernel_initializer=conv_initializer,
            use_bias=False,
        )

    def call(self, x, training=True):
        x = self.conv0(x)
        x = tf.nn.leaky_relu(x)

        x = self.conv1(x)
        x = self.batchnorm1(x, training=training)
        x = tf.nn.leaky_relu(x)

        x = self.conv2(x)
        x = self.batchnorm2(x, training=training)
        x = tf.nn.leaky_relu(x)

        x = self.conv3(x)

        return x

Also, the following variable definitions are shared between the eager and static implementations:

            global_step = tf.train.get_or_create_global_step()
            generator_optimizer = tf.train.AdamOptimizer(2e-4, 0.5)
            discriminator_optimizer = tf.train.AdamOptimizer(2e-4, 0.5)
            discriminator = Discriminator()
            g_encoder = Encoder()
            g_decoder = Decoder()
            encoder = Encoder()

Eager training

I show just a single training update; it runs inside a training loop:

            # Discriminator training
            with tf.GradientTape() as tape:

                discriminator.trainable = True
                disc_x = tf.squeeze(discriminator(x, training=False), axis=[1, 2])

                disc_real_loss = tf.losses.sigmoid_cross_entropy(  # discriminator loss on result disc_x
                    multi_class_labels=tf.ones_like(disc_x), logits=disc_x
                )
                g_encoder.trainable = False
                g_decoder.trainable = False
                # recreate the data (=> x_hat), starting from real data x
                z = g_encoder(x, training=True)  # Not training
                x_hat = g_decoder(z, training=True)  # Not training

                disc_x_hat = tf.squeeze(discriminator(x_hat, training=False), axis=[1, 2])
                disc_gen_loss = tf.losses.sigmoid_cross_entropy(  # discriminator loss on result disc_x_hat
                    multi_class_labels=tf.zeros_like(disc_x_hat), logits=disc_x_hat
                )
                disc_loss = disc_real_loss + disc_gen_loss

            discriminator_gradients = tape.gradient(
                disc_loss, discriminator.trainable_variables
            )

            discriminator_optimizer.apply_gradients(
                zip(discriminator_gradients, discriminator.trainable_variables)
            )

            # Generator Training
            with tf.GradientTape() as tape:

                # err_g_bce
                g_encoder.trainable = True
                g_decoder.trainable = True
                encoder.trainable = True
                z = g_encoder(x, training=True)
                x_hat = g_decoder(z, training=True)
                disc_x_hat = tf.squeeze(
                    discriminator(x_hat, training=False), axis=[1, 2]
                ) 
                bce_loss = tf.losses.sigmoid_cross_entropy(
                    multi_class_labels=tf.ones_like(disc_x_hat),
                    logits=disc_x_hat,  # G wants to generate reals so ones_like
                )

                l1_loss = tf.losses.absolute_difference(x, x_hat)
                # err_g_enc
                z_hat = encoder(x_hat, training=True)
                l2_loss = tf.losses.mean_squared_error(z, z_hat)

                gen_loss = 1 * bce_loss + 50 * l1_loss + 1 * l2_loss
                
            trainable_variable_list = (
                g_encoder.trainable_variables
                + g_decoder.trainable_variables
                + encoder.trainable_variables
            )

            generator_gradients = tape.gradient(gen_loss, trainable_variable_list)

            generator_optimizer.apply_gradients(
                zip(generator_gradients, trainable_variable_list),
                global_step=global_step,
            )

Static graph training

First I show the graph definition, then how it is used in the training loop.

Graph def

    # Discriminator on real
    D_x = discriminator(x)
    D_x = tf.squeeze(D_x, axis=[1, 2])

    # Generate fake
    z = g_encoder(x)
    x_hat = g_decoder(z)

    D_x_hat = discriminator(x_hat)
    D_x_hat = tf.squeeze(D_x_hat, axis=[1, 2])
    ## Discriminator
    d_loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.ones_like(D_x), logits=D_x
    ) + tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.zeros_like(D_x_hat), logits=D_x_hat
    )

    # encode x_hat to z_hat
    z_hat = encoder(x_hat)

    ## Generator
    bce_loss = tf.losses.sigmoid_cross_entropy(tf.ones_like(D_x_hat), D_x_hat)
    l1_loss = tf.losses.absolute_difference(x, x_hat)
    l2_loss = tf.losses.mean_squared_error(z, z_hat)

    g_loss = 1 * bce_loss + 1 * l1_loss + 1 * l2_loss
    global_step = tf.train.get_or_create_global_step()

    # Define the D train op
    train_d = tf.train.AdamOptimizer(
        lr, beta1=0.5).minimize(
            d_loss, var_list=discriminator.trainable_variables)

    # train G_e G_d E
    train_g = tf.train.AdamOptimizer(
        lr, beta1=0.5).minimize(
            g_loss,
            global_step=global_step,
            var_list=g_encoder.trainable_variables + g_decoder.trainable_variables +
            encoder.trainable_variables)

And this is what's inside the training loop (executed in a MonitoredSession):

            # extract from tf.data.Dataset, x is a placeholder and x_ is the iterator.get_next
            real = sess.run(x_) 
            feed_dict = {x: real}

            # train D
            _, d_loss_value = sess.run([train_d, d_loss], feed_dict)

            # train G+E
            _, g_loss_value, step = sess.run([train_g, g_loss, global_step],
                                             feed_dict)
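For context, a minimal sketch of how this loop is wrapped (the exact hooks and checkpoint setup in my script may differ):

    with tf.train.MonitoredTrainingSession() as sess:
        while not sess.should_stop():
            real = sess.run(x_)
            feed_dict = {x: real}
            _, d_loss_value = sess.run([train_d, d_loss], feed_dict)
            _, g_loss_value, step = sess.run([train_g, g_loss, global_step],
                                             feed_dict)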

The model definition is the same, the loss definition is the same, and the training steps are the same; the only difference is whether eager mode is enabled. The results with eager on are:

D loss: collapses to zero, which is wrong since its aim is to stay around 0.5.
[image: collapsing D loss plot]

Generated images: wrong, bad reconstructions since D collapsed.
[image: bad generated samples]

When eager is off, the discriminator loss looks correct and the generated outputs are the ones expected:

D loss:
[image: stable D loss plot]

Generated output:
[image: good generated samples]

harshini-gadige commented Oct 31, 2018

@galeone - Please try with Python 3.6.

galeone commented Nov 2, 2018

@harshini-gadige I tried. Same exact results: D collapses in the eager version, and works correctly in the static graph version.

For this test:

  • Python 3.6.6
  • Tensorflow: 1.11.0
martinwicke (Member) commented Nov 5, 2018

I wonder whether it's possible to narrow this down a little. Maybe by logging a bunch of tensors to see where the differences first appear?

Also, I recommend Sequential; it'll dramatically simplify the model definition.
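For instance, something like the following (a minimal sketch; fixed_x is a hypothetical constant batch fed identically to both versions, and in the graph version the prints would become sess.run calls):

    # Log intermediate statistics for one fixed batch in the eager run.
    z = g_encoder(fixed_x, training=False)
    x_hat = g_decoder(z, training=False)
    print("z     mean/min/max:", tf.reduce_mean(z).numpy(),
          tf.reduce_min(z).numpy(), tf.reduce_max(z).numpy())
    print("x_hat mean/min/max:", tf.reduce_mean(x_hat).numpy(),
          tf.reduce_min(x_hat).numpy(), tf.reduce_max(x_hat).numpy())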

galeone commented Nov 5, 2018

@martinwicke I'll add debug operations in both versions and try to spot where the difference starts. I'll update the thread as soon as I have some insight.

alextp (Member) commented Nov 5, 2018

The eager code you shared should also work in graph mode to produce training ops (one for the generator and one for the discriminator, which you can session.run separately). Can you do that and see if you also observe the collapse?

galeone commented Nov 6, 2018

@alextp can you please show me how I should change the code to run it in a Session? I'm not used to mixing session evaluation and eager mode, so I'm not sure where to place the session creation and when to call sess.run (in static-graph mode there would be no problem).

alextp (Member) commented Nov 7, 2018

I meant replacing your static graph training code block with:

            # Discriminator training
            x = tf.placeholder(....)
            with tf.GradientTape() as tape:

                discriminator.trainable = True
                disc_x = tf.squeeze(discriminator(x, training=False), axis=[1, 2])

                disc_real_loss = tf.losses.sigmoid_cross_entropy(  # discriminator loss on result disc_x
                    multi_class_labels=tf.ones_like(disc_x), logits=disc_x
                )
                g_encoder.trainable = False
                g_decoder.trainable = False
                # recreate the data (=> x_hat), starting from real data x
                z = g_encoder(x, training=True)  # Not training
                x_hat = g_decoder(z, training=True)  # Not training

                disc_x_hat = tf.squeeze(discriminator(x_hat, training=False), axis=[1, 2])
                disc_gen_loss = tf.losses.sigmoid_cross_entropy(  # discriminator loss on result disc_x_hat
                    multi_class_labels=tf.zeros_like(disc_x_hat), logits=disc_x_hat
                )
                disc_loss = disc_real_loss + disc_gen_loss

            discriminator_gradients = tape.gradient(
                disc_loss, discriminator.trainable_variables
            )

            d_loss = disc_loss
            train_d = discriminator_optimizer.apply_gradients(
                zip(discriminator_gradients, discriminator.trainable_variables)
            )

            # Generator Training
            with tf.GradientTape() as tape:

                # err_g_bce
                g_encoder.trainable = True
                g_decoder.trainable = True
                encoder.trainable = True
                z = g_encoder(x, training=True)
                x_hat = g_decoder(z, training=True)
                disc_x_hat = tf.squeeze(
                    discriminator(x_hat, training=False), axis=[1, 2]
                ) 
                bce_loss = tf.losses.sigmoid_cross_entropy(
                    multi_class_labels=tf.ones_like(disc_x_hat),
                    logits=disc_x_hat,  # G wants to generate reals so ones_like
                )

                l1_loss = tf.losses.absolute_difference(x, x_hat)
                # err_g_enc
                z_hat = encoder(x_hat, training=True)
                l2_loss = tf.losses.mean_squared_error(z, z_hat)

                gen_loss = 1 * bce_loss + 50 * l1_loss + 1 * l2_loss
                
            trainable_variable_list = (
                g_encoder.trainable_variables
                + g_decoder.trainable_variables
                + encoder.trainable_variables
            )

            generator_gradients = tape.gradient(gen_loss, trainable_variable_list)

            g_loss = gen_loss
            train_g = generator_optimizer.apply_gradients(
                zip(generator_gradients, trainable_variable_list),
                global_step=global_step,
            )

and then running your session.run loop for training unchanged:

            # extract from tf.data.Dataset, x is a placeholder and x_ is the iterator.get_next
            real = sess.run(x_) 
            feed_dict = {x: real}

            # train D
            _, d_loss_value = sess.run([train_d, d_loss], feed_dict)

            # train G+E
            _, g_loss_value, step = sess.run([train_g, g_loss, global_step],
                                             feed_dict)
galeone commented Nov 8, 2018

Here we go:

The discriminator does not collapse:

[0] d: 1.4126055240631104 - g: 51.892757415771484
[100] d: 1.4121590852737427 - g: 10.285645484924316
[200] d: 1.136518955230713 - g: 8.836201667785645
[300] d: 1.3435794115066528 - g: 7.9710540771484375
[400] d: 1.2257683277130127 - g: 8.381233215332031
[500] d: 1.3684046268463135 - g: 7.7494797706604
[600] d: 1.274878740310669 - g: 8.12108039855957
[700] d: 1.0551540851593018 - g: 7.315877437591553
[800] d: 1.1707713603973389 - g: 6.710018634796143

and the generator works correctly:
[image: generated samples look correct]

Therefore there's something different in the execution between eager mode and graph mode.

akshaym (Member) commented Nov 8, 2018

Hi @galeone,

I tried to reproduce the problem, but I'm not able to. The only changes I made were to add this loop (and print the losses):

def main(_):
  tf.enable_eager_execution()

  (train_images, train_labels), _ = k.datasets.fashion_mnist.load_data()

  train_images = train_images / 255.0
  train_images = train_images.astype(np.float32)
  train_images = np.expand_dims(train_images, -1)

  train_dataset = tf.data.Dataset.from_tensor_slices(
      (train_images, train_labels)).shuffle(10000).repeat().batch(128).map(lambda x, y: (tf.image.resize_images(x, [32, 32]), y))

  global_step = tf.train.get_or_create_global_step()
  generator_optimizer = tf.train.AdamOptimizer(2e-4, 0.5)
  discriminator_optimizer = tf.train.AdamOptimizer(2e-4, 0.5)
  discriminator = Discriminator()
  g_encoder = Encoder()
  g_decoder = Decoder()
  encoder = Encoder()

  it = iter(train_dataset)
  for i in range(10000):
    x, _ = it.next()
    train(encoder, g_encoder, g_decoder, generator_optimizer, discriminator,
          discriminator_optimizer, global_step, x)

I get after 2000 steps:

step <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=1000>, g: 4.06344842911, d: 1.27861332893                                                                                                                                                                                   
step <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=2000>, g: 3.98709654808, d: 0.57758295536

Note that I tried this with/without a GPU, and on 1.11 as well as nightly and wasn't able to reproduce it.

Perhaps you can paste your full code so that I can reproduce this better?

galeone commented Nov 9, 2018

Yes, no problem.

The models are the ones posted earlier and live in the models/ganomaly.py file.

The dataset definition (common to both versions) is placed inside ops/input_fn/fashion_mnist.py:

"""Load Fashion MNIST Dataset."""

from typing import Callable, Dict, Tuple

import numpy as np
import sklearn.model_selection as sk
from skimage.transform import resize
from matplotlib import pyplot as plt
import tensorflow as tf
import tensorflow.keras as k

Dataset = Dict[str, np.ndarray]
CompleteDataset = Dict[str, Dataset]


def arrays_dataset_to_generator(dataset: Dataset) -> Callable:
    """Wrap an arrays dataset ({"x", "y"} arrays) in a generator function for tf.data."""
    def _generator():
        for i in range(len(dataset["x"])):
            yield dataset["x"][i], dataset["y"][i]

    return _generator


def extract_label(
    features: np.ndarray, labels: np.ndarray, target_label: int
) -> Tuple[Dataset, Dataset]:
    """
    Extract all features whose label is ``target_labels``.

    Args:
        features: Dataset features
        labels: Dataset labels
        target_label: The label used as filter

    Returns:
        All the features with the ``target_labels`` as label.

    """
    bool_filter = labels == target_label
    normal_features = features[~bool_filter]
    normal_labels = labels[~bool_filter]
    anomalous_features = features[bool_filter]
    anomalous_labels = labels[bool_filter]
    return (
        {"x": normal_features, "y": normal_labels},
        {"x": anomalous_features, "y": anomalous_labels},
    )


def _process_features(feature, image_size):
    feature = tf.squeeze(tf.image.resize_images(tf.expand_dims(tf.expand_dims(feature, axis=2), axis=0), size=(32, 32)),
                         axis=0)  # resize change to float32

    # Normalize features between -1 and 1
    feature = (feature * 2) - 1

    return feature.numpy()


class FashionMNIST:
    """
    FashionMNIST Dataset.

    Attributes:
        train_datasets (Dict[str, np.ndarray]): Dictionary of features for training.
            The dictionary has keys in the form ``anomaly_{label}``, meaning that
            each key holds 80% of the features whose label is not ``{label}``
        test_datasets (Dict[str, np.ndarray]): Dictionary of features for testing.
            The dictionary has keys in the form ``anomaly_{label}``, meaning that
            each key holds all the features whose label is ``{label}`` plus 20% of
            all the other features as non-anomalies

    """

    train_datasets: CompleteDataset
    test_datasets: CompleteDataset

    # def __init__(self, image_size: Tuple[int, int, int] = (28, 28, 1)) -> None:
    def __init__(self, image_size: Tuple[int, int, int] = (32, 32, 1)) -> None:
        """
        Fetch the dataset and partition it for training and testing.

        Retrieve the dataset from the Keras builtin, split each class 80-20 and
        using these splits build all the datasets required for training and testing.

        Args:
            image_size: Size of each image

        Returns:
            None.

        """
        self.image_size = image_size
        self.train_datasets = {}
        self.test_datasets = {}

        mnist_train_set, _ = k.datasets.fashion_mnist.load_data()
        mnist_x, mnist_y = mnist_train_set

        mnist_x = np.array([_process_features(x, self.image_size) for x in mnist_x])

        for n in range(10):
            # print(n)
            normal_data, anomalous_data = extract_label(mnist_x, mnist_y, n)
            n_train_x, n_test_x, n_train_y, n_test_y = sk.train_test_split(
                normal_data["x"], normal_data["y"], test_size=0.2, random_state=42
            )

            # print(n_train_x.shape[0], n_test_x.shape[0], anomalous_data["x"].shape[0])

            self.train_datasets[f"anomaly_{n}"] = {"x": n_train_x, "y": n_train_y}
            self.test_datasets[f"anomaly_{n}"] = {
                "x": np.concatenate((n_test_x, anomalous_data["x"])),
                "y": np.concatenate((n_test_y, anomalous_data["y"])),
            }

    def to_tf_dataset(self, arrays_dataset: Dataset, hyper: Dict) -> tf.data.Dataset:
        """
        Convert ``features`` and ``labels`` into a `tf.data.Datasets``.

        Args:
            features: Array of features
            labels: Array of labels
            hyper: Dictionary of hyperparameters

        Returns:
            `tf.data.Dataset` of batch(image, label).

        """
        generator = arrays_dataset_to_generator(arrays_dataset)
        dataset = tf.data.Dataset.from_generator(generator, (tf.float32, tf.uint8))
        dataset = (
            dataset.shuffle(hyper["buffer_size"], seed=42)
            .batch(hyper["batch_size"], drop_remainder=True)
            .prefetch(1)
        )
        dataset = dataset.repeat(hyper["epochs"]) if hyper.get("epochs") else dataset

        return dataset

Eager version

This is the training definition:

from __future__ import annotations

import os
import time
from typing import Callable, Dict, Optional

import tensorflow as tf
import tensorflow.keras as k
from models.ganomaly import Discriminator
import statistics as s
import copy

LOGGED: Dict = {}


class GANomaly:
    """
    GANomaly
    """

    g_encoder: k.models.Model  # bow-tie encoder (encoder)
    g_decoder: k.models.Model  # bow-tie decoder (generator)
    encoder: k.models.Model  # encoder
    discriminator: k.models.Model  # discriminator

    generator_optimizer: tf.train.AdamOptimizer
    discriminator_optimizer: tf.train.AdamOptimizer

    dataset: tf.data.Dataset

    checkpoint: tf.train.Checkpoint
    logging_fn: Optional[Callable]

    def __init__(
        self,
        g_encoder: k.models.Model,
        g_decoder: k.models.Model,
        encoder: k.models.Model,
        discriminator: k.models.Model,
        dataset: tf.data.Dataset,
        model_dir: str,
        hyper: Dict,
        logging_fn: Optional[Callable] = None,
    ) -> None:

        # Model
        self.g_encoder = g_encoder()
        self.g_decoder = g_decoder()
        self.encoder = encoder()
        self.discriminator = discriminator()
        self.hyper = hyper

        # Optimizers

        # this is for g_encode, g_decoder and encoder
        self.global_step = tf.train.get_or_create_global_step()
        self.learning_rate = self.hyper["learning_rate"]

        # with decay, 2 optimizer
        self.generator_optimizer = tf.train.AdamOptimizer(
            self.learning_rate, hyper["beta1"]
        )

        self.discriminator_optimizer = tf.train.AdamOptimizer(
            self.learning_rate, hyper["beta1"]
        )
        # Data
        self.dataset = dataset

        # Checkpoints
        self.checkpoint = tf.train.Checkpoint(
            generator_optimizer=self.generator_optimizer,
            discriminator_optimizer=self.discriminator_optimizer,
            g_encoder=self.g_encoder,
            g_decoder=self.g_decoder,
            encoder=self.encoder,
            discriminator=self.discriminator,
        )

        self.checkpoint_prefix = os.path.join(model_dir, "ckpt")
        self.summary_writer = tf.contrib.summary.create_file_writer(
            model_dir, flush_millis=10000
        )

        self.saved_model_dir = model_dir

        # Logging
        self.logging_fn = logging_fn

    def train(  # TODO: reduce complexity
        self,
        steps_per_epoch: int,
        batch_size: int,
        noise_dims: int,
        epochs: float,
        checkpoint_frequency: int,
        logging_enabled: bool = False,
        discriminator_passes: int = 2,
    ) -> None:
        # global_step = tf.train.get_or_create_global_step()

        # self.g_enc_dec_weights = []
        epoch = 0
        epoch_start = time.time()
        for x, _ in self.dataset:
            tf.Print(self.global_step, [self.global_step], "GLOBAL STEP")
            if int(self.global_step.numpy()) % steps_per_epoch == 0:
                tf.logging.info("---------------[NEW EPOCH]---------------")
                tf.logging.info(f"Current Epoch {epoch + 1} | Total Epochs: {epochs}")
                tf.logging.info(
                    f"Step {self.global_step.numpy()} "
                    f"| Total Steps: {steps_per_epoch * epochs}"
                )
                tf.logging.info(f"Epoch time: {time.time() - epoch_start}")
                epoch_start = time.time()
                epoch += 1
            step_start = time.time()

            # Discriminator training
            with tf.GradientTape() as tape:

                self.discriminator.trainable = True
                disc_x = tf.squeeze(
                    self.discriminator(x, training=False), axis=[1, 2]
                )  # discriminator on real data x. Training on.

                # I save the weights here because here I know they are not zero
                if int(self.global_step.numpy()) == 0:
                    print("Saving weights...")
                    self.discriminator.save_weights()

                disc_real_loss = tf.losses.sigmoid_cross_entropy(  # discriminator loss on result disc_x
                    multi_class_labels=tf.ones_like(disc_x), logits=disc_x
                )

                self.g_encoder.trainable = False
                self.g_decoder.trainable = False
                # recreate the data (=> x_hat), starting from real data x
                z = self.g_encoder(x, training=True)  # Not training
                x_hat = self.g_decoder(z, training=True)  # Not training

                disc_x_hat = tf.squeeze(
                    self.discriminator(x_hat, training=False), axis=[1, 2]
                )  # discriminator on recreated data. Training on.
                disc_gen_loss = tf.losses.sigmoid_cross_entropy(  # discriminator loss on result disc_x_hat
                    multi_class_labels=tf.zeros_like(disc_x_hat), logits=disc_x_hat
                )
                disc_loss = disc_real_loss + disc_gen_loss

            discriminator_gradients = tape.gradient(
                disc_loss, self.discriminator.trainable_variables
            )

            self.discriminator_optimizer.apply_gradients(
                zip(discriminator_gradients, self.discriminator.trainable_variables)
            )

            # Generator Training
            with tf.GradientTape() as tape:

                # err_g_bce
                self.g_encoder.trainable = True
                self.g_decoder.trainable = True
                self.encoder.trainable = True
                z = self.g_encoder(x, training=True)
                x_hat = self.g_decoder(z, training=True)
                disc_x_hat = tf.squeeze(
                    self.discriminator(x_hat, training=False), axis=[1, 2]
                )  # Training false
                bce_loss = tf.losses.sigmoid_cross_entropy(
                    multi_class_labels=tf.ones_like(disc_x_hat),
                    logits=disc_x_hat,  # G wants to generate reals so ones_like
                )

                # print("disc_x_hat.shape::::::::", disc_x_hat.shape)

                # err_g_l1l
                l1_loss = tf.losses.absolute_difference(x, x_hat)

                # err_g_enc
                z_hat = self.encoder(x_hat, training=True)
                l2_loss = tf.losses.mean_squared_error(z, z_hat)

                # final generator loss
                gen_loss = (
                    self.hyper["adversarial_w"] * bce_loss
                    + self.hyper["contextual_w"] * l1_loss
                    + self.hyper["encoder_w"] * l2_loss
                )

            trainable_variable_list = (
                self.g_encoder.trainable_variables
                + self.g_decoder.trainable_variables
                + self.encoder.trainable_variables
            )
            generator_gradients = tape.gradient(gen_loss, trainable_variable_list)

            self.generator_optimizer.apply_gradients(
                zip(generator_gradients, trainable_variable_list),
                global_step=self.global_step,
            )

            if disc_loss < 1e-4:
                print("RESET WEIGHTS")
                self.discriminator.reset_weights()
                # self.discriminator = Discriminator()

            step_time_delta = time.time() - step_start

            LOGGED.update(
                {  # HACK: Find better way
                    "generated_data": x_hat,
                    "real_data": x,
                    "encoded_real": z,
                    "gen_loss": gen_loss,
                    "disc_loss": disc_loss,
                    "gen_bce_loss": bce_loss,
                    "gen_cont_loss": l1_loss,
                    "gen_enc_loss": l2_loss,
                    "step": self.global_step,
                    # "learning_rate": self.learning_rate(),
                    "learning_rate": self.learning_rate,
                }
            )

            # TODO: Move the logging to a separate function
            # TODO: Divide Epoch-wise logging from Step-wise logging
            # TODO: Add support for metrics

            if logging_enabled:
                if self.logging_fn:
                    self.logging_fn(self.summary_writer, LOGGED)
                else:
                    tf.logging.error("Logging enabled but no logging_fn was provided.")

        tf.logging.info("# ############ [Training Complete] ##############")
        self.checkpoint.save(file_prefix=self.checkpoint_prefix)

And this is the file that runs the training:

"""
Run the model.
"""
from typing import Optional

import fire
import tensorflow as tf

from models.anogan import AnoGAN
import GANomaly as ganomaly
import models.ganomaly as m_ganomaly

tf.enable_eager_execution()
tf.logging.set_verbosity(tf.logging.INFO)


def main(gan: str, mnist_type: Optional[str] = None) -> None:
    """Execute the training script."""
    print("main -- checking fire arguments")
    if gan.lower() == "dummy":
        tf.logging.info(f"Starting DummyGAN")
        from models.dummy import Generator, Discriminator, stepwise_train_logging
        from ops.input_fn.dummy import DummyData

        hyper = {
            "mean": 10.0,
            "std": 0.01,
            "points": 10000,
            "batch_size": 1000,
            "epochs": 2200,
            "buffer_size": 10000,
            "noise_dims": 100,
            "learning_rate": 0.0002,
            "beta1": 0.5,
        }

        config = {"checkpoint_frequency": 500}

        dataset = DummyData(hyper["mean"], hyper["std"], hyper["points"])
        train_dataset = dataset.to_tf_dataset(
            dataset.train_features,
            batch_size=hyper["batch_size"],
            buffer_size=hyper["buffer_size"],
            epochs=hyper["epochs"],
        )

        dummy_gan = AnoGAN(
            Generator,
            Discriminator,
            train_dataset,
            "logs/dummy",
            logging_fn=stepwise_train_logging,
            hyper=hyper,
        )

        dummy_gan.train(
            steps_per_epoch=int(hyper["points"] / hyper["batch_size"]),
            batch_size=int(hyper["batch_size"]),
            noise_dims=int(hyper["noise_dims"]),
            checkpoint_frequency=int(config["checkpoint_frequency"]),
            logging_enabled=True,
            epochs=hyper["epochs"],
        )

    elif (gan.lower() == "mnist") and mnist_type is not None:
        print("Category: MNIST")
        tf.logging.info(f"Starting MnistGAN")
        from models.ganomaly import stepwise_train_logging
        from ops.input_fn.fashion_mnist import FashionMNIST
        from ops.input_fn.mnist import DigitsMNIST

        config = {"checkpoint_frequency": 1}

        # ----------------------------------- fashion mnist
        if mnist_type.lower() == "fashion":
            print("Sub-Category: FASHION")

            hyper = {
                "batch_size": 64,
                "epochs": 2000,     # 500
                "buffer_size": 6000,
                "noise_dims": 100,  # 128
                "dataset_length": 6000,
                "learning_rate": 0.0002,
                "beta1": 0.5,
                "adversarial_w": 1,    # 1
                "contextual_w": 50,     # 50
                "encoder_w": 1,        # 1
            }

            dataset = FashionMNIST()
            n = 0
            train_dataset = dataset.to_tf_dataset(
                dataset.train_datasets[f"anomaly_{n}"],
                hyper=hyper,
            )

            dummy_gan = ganomaly.GANomaly(
                m_ganomaly.Encoder,
                m_ganomaly.Decoder,
                m_ganomaly.Encoder,
                m_ganomaly.Discriminator,
                train_dataset,
                "logs/ganomaly",
                logging_fn=stepwise_train_logging,
                hyper=hyper,
            )

            dummy_gan.train(
                steps_per_epoch=int(hyper["dataset_length"] / hyper["batch_size"]),
                batch_size=int(hyper["batch_size"]),
                noise_dims=int(hyper["noise_dims"]),
                checkpoint_frequency=int(config["checkpoint_frequency"]),
                logging_enabled=True,
                epochs=hyper["epochs"],
            )

if __name__ == "__main__":
    fire.Fire(main)

^ This version produces the D collapse.

While inside a session (with the ops generated by the eager code, as @alextp suggested, or with the static graph definition I showed in the first post), it does not.

akshaym (Member) commented Nov 9, 2018

Hi @galeone,

I'm still unable to reproduce this (the scale of the loss is different, but it doesn't seem to collapse). I am able to fix the scaling by updating the _process_features function to include this as the first line:

feature = feature / 255.
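i.e., a sketch of the patched function, based on the version you posted (everything else unchanged):

def _process_features(feature, image_size):
    feature = feature / 255.  # scale raw uint8 pixels into [0, 1] first
    feature = tf.squeeze(
        tf.image.resize_images(
            tf.expand_dims(tf.expand_dims(feature, axis=2), axis=0),
            size=(32, 32)),
        axis=0)  # resize; converts to float32

    # Normalize features between -1 and 1
    feature = (feature * 2) - 1

    return feature.numpy()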

I am also not saving/restoring discriminator weights.

Could you try replacing your dataset with the dataset code I posted above, and see if it still collapses?

That might point to some environment differences which are contributing to this.

Thanks

galeone commented Nov 12, 2018

Update: the eager version with the feature /= 255. line does not collapse; it works.

But the question is: why does the missing feature scaling make the model collapse in my setup, while in yours it doesn't?
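For reference, a quick sanity check of the ranges involved (a sketch): without the / 255., the normalization maps raw pixels to [-1, 509], while the generator output goes through tanh and so always lies in [-1, 1], which would make real and generated samples trivially separable.

import numpy as np

raw = np.array([0.0, 255.0])  # raw pixel extremes, without the / 255. scaling
print((raw * 2) - 1)          # -> [ -1. 509.]: the range of the "real" data
# x_hat comes out of tf.nn.tanh, so generated samples always lie in [-1, 1]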

martinwicke (Member) commented Nov 12, 2018

galeone commented Nov 12, 2018

Python 3.6 and Python 3.7 (with TensorFlow compiled and packaged by the Arch Linux maintainers, with CUDA 10).

alextp (Member) commented Nov 12, 2018

galeone commented Nov 13, 2018

It might be, but I repeated the tests several times and the behavior was always the same.

akshaym (Member) commented Nov 14, 2018

Hi @galeone,

Taking a step back, I want to confirm: is it actually broken without the feature /= 255. line?

The images that were generated initially (from your screenshot in #23407 (comment)) seem reasonable enough (at the very least not broken/collapsed).

Can you print disc_loss in Python for the eager case as well? (As I mentioned, when I did this the loss was small, but not 0.0.)

Thanks!

galeone commented Nov 14, 2018

Yes, without the feature /= 255. line the eager code is broken.

The disc_loss value in eager mode, with the feature scaling, is:

D loss:  tf.Tensor(2.5863452, shape=(), dtype=float32)
GLOBAL STEP[1]
D loss:  tf.Tensor(3.7927284, shape=(), dtype=float32)
GLOBAL STEP[2]
D loss:  tf.Tensor(1.7489786, shape=(), dtype=float32)
GLOBAL STEP[3]
D loss:  tf.Tensor(2.4115205, shape=(), dtype=float32)
GLOBAL STEP[4]
D loss:  tf.Tensor(1.479794, shape=(), dtype=float32)
GLOBAL STEP[5]
D loss:  tf.Tensor(1.857255, shape=(), dtype=float32)
GLOBAL STEP[6]
D loss:  tf.Tensor(1.3435516, shape=(), dtype=float32)
GLOBAL STEP[7]
D loss:  tf.Tensor(1.4879988, shape=(), dtype=float32)
GLOBAL STEP[8]
D loss:  tf.Tensor(1.4008305, shape=(), dtype=float32)
GLOBAL STEP[9]
D loss:  tf.Tensor(1.3350534, shape=(), dtype=float32)
GLOBAL STEP[10]
D loss:  tf.Tensor(1.3525255, shape=(), dtype=float32)
GLOBAL STEP[11]
D loss:  tf.Tensor(1.3573906, shape=(), dtype=float32)
GLOBAL STEP[12]
D loss:  tf.Tensor(1.3513403, shape=(), dtype=float32)
GLOBAL STEP[13]
D loss:  tf.Tensor(1.3532543, shape=(), dtype=float32)
GLOBAL STEP[14]
D loss:  tf.Tensor(1.3461918, shape=(), dtype=float32)
GLOBAL STEP[15]
D loss:  tf.Tensor(1.351814, shape=(), dtype=float32)
GLOBAL STEP[16]
D loss:  tf.Tensor(1.3492448, shape=(), dtype=float32)
GLOBAL STEP[17]
D loss:  tf.Tensor(1.3633372, shape=(), dtype=float32)
GLOBAL STEP[18]
D loss:  tf.Tensor(1.3482394, shape=(), dtype=float32)
GLOBAL STEP[19]
D loss:  tf.Tensor(1.3703191, shape=(), dtype=float32)
GLOBAL STEP[20]
D loss:  tf.Tensor(1.3946171, shape=(), dtype=float32)
GLOBAL STEP[21]
D loss:  tf.Tensor(1.4548875, shape=(), dtype=float32)
GLOBAL STEP[22]
D loss:  tf.Tensor(1.6209935, shape=(), dtype=float32)
GLOBAL STEP[23]
D loss:  tf.Tensor(1.9088718, shape=(), dtype=float32)
GLOBAL STEP[24]
D loss:  tf.Tensor(2.0123127, shape=(), dtype=float32)
GLOBAL STEP[25]
D loss:  tf.Tensor(1.7655286, shape=(), dtype=float32)
GLOBAL STEP[26]
D loss:  tf.Tensor(1.5841687, shape=(), dtype=float32)
GLOBAL STEP[27]
D loss:  tf.Tensor(1.5094138, shape=(), dtype=float32)
GLOBAL STEP[28]
D loss:  tf.Tensor(1.4937377, shape=(), dtype=float32)
GLOBAL STEP[29]
D loss:  tf.Tensor(1.4909501, shape=(), dtype=float32)
GLOBAL STEP[30]
D loss:  tf.Tensor(1.5101626, shape=(), dtype=float32)

While with the feature scaling removed, the values become:

D loss:  tf.Tensor(2.5409622, shape=(), dtype=float32)
GLOBAL STEP[1]
D loss:  tf.Tensor(1.0496982, shape=(), dtype=float32)
GLOBAL STEP[2]
D loss:  tf.Tensor(0.05322905, shape=(), dtype=float32)
GLOBAL STEP[3]
D loss:  tf.Tensor(0.026622836, shape=(), dtype=float32)
GLOBAL STEP[4]
D loss:  tf.Tensor(0.00439855, shape=(), dtype=float32)
GLOBAL STEP[5]
D loss:  tf.Tensor(0.011926692, shape=(), dtype=float32)
GLOBAL STEP[6]
D loss:  tf.Tensor(0.0012371272, shape=(), dtype=float32)
GLOBAL STEP[7]
D loss:  tf.Tensor(0.0125416275, shape=(), dtype=float32)
GLOBAL STEP[8]
D loss:  tf.Tensor(0.0067688944, shape=(), dtype=float32)
GLOBAL STEP[9]
D loss:  tf.Tensor(0.0018801485, shape=(), dtype=float32)
GLOBAL STEP[10]
D loss:  tf.Tensor(0.012362722, shape=(), dtype=float32)
GLOBAL STEP[11]
D loss:  tf.Tensor(0.004985124, shape=(), dtype=float32)
GLOBAL STEP[12]
D loss:  tf.Tensor(0.001480296, shape=(), dtype=float32)
GLOBAL STEP[13]
D loss:  tf.Tensor(0.0015711666, shape=(), dtype=float32)
GLOBAL STEP[14]
D loss:  tf.Tensor(0.0035329773, shape=(), dtype=float32)
GLOBAL STEP[15]
D loss:  tf.Tensor(0.0061537297, shape=(), dtype=float32)
GLOBAL STEP[16]
D loss:  tf.Tensor(0.0015294439, shape=(), dtype=float32)
GLOBAL STEP[17]
D loss:  tf.Tensor(0.0018317825, shape=(), dtype=float32)
GLOBAL STEP[18]
D loss:  tf.Tensor(0.0015391075, shape=(), dtype=float32)
GLOBAL STEP[19]
D loss:  tf.Tensor(0.0017354895, shape=(), dtype=float32)
GLOBAL STEP[20]
D loss:  tf.Tensor(0.0014736701, shape=(), dtype=float32)
GLOBAL STEP[21]
D loss:  tf.Tensor(0.0015135503, shape=(), dtype=float32)
GLOBAL STEP[22]
D loss:  tf.Tensor(0.0013820046, shape=(), dtype=float32)
GLOBAL STEP[23]
D loss:  tf.Tensor(0.0013579719, shape=(), dtype=float32)
GLOBAL STEP[24]
D loss:  tf.Tensor(0.0020369224, shape=(), dtype=float32)
GLOBAL STEP[25]
D loss:  tf.Tensor(0.0014253389, shape=(), dtype=float32)
GLOBAL STEP[26]
D loss:  tf.Tensor(0.0012767054, shape=(), dtype=float32)
GLOBAL STEP[27]
D loss:  tf.Tensor(0.001459129, shape=(), dtype=float32)
GLOBAL STEP[28]
D loss:  tf.Tensor(0.0017556315, shape=(), dtype=float32)
GLOBAL STEP[29]
D loss:  tf.Tensor(0.007157793, shape=(), dtype=float32)
GLOBAL STEP[30]
D loss:  tf.Tensor(0.011590726, shape=(), dtype=float32)
akshaym (Member) commented Nov 15, 2018

Ah, so the loss is similar for me with/without scaling (which at the very least means we are on the same page with regard to reproducibility :). The loss is small, but I'd expect 0.0s or NaNs if the model had collapsed, which is why I was saying it hadn't collapsed.

Is it possible that the datasets and preprocessing steps are different between graph and eager? If so, can you try the graph version with an identical dataset?
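For example (a rough sketch; the helper name dump_batches is made up), you could materialize the exact batches once in eager mode and feed the same arrays to the graph version:

import numpy as np

def dump_batches(dataset, n_batches, path="batches.npy"):
    # Iterate the eager dataset once and save the exact batches it yields.
    batches = [x.numpy() for (x, _), _ in zip(dataset, range(n_batches))]
    np.save(path, np.stack(batches))

# Graph run: skip its own input pipeline and feed the saved arrays instead:
#   for real in np.load("batches.npy"):
#       sess.run([train_d, d_loss], {x: real})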

galeone commented Nov 15, 2018

But the model does collapse: if I let the training run for more steps, the loss goes to 0.0.

Yes, it is possible the input datasets are different. But when I ran the ops generated by the eager code inside a session, the dataset was the same and the behavior was still different (the graph generated by the eager code executed inside a session, versus eager training, both with the same input).

This is the strangest part.

galeone commented Nov 21, 2018

Update: maybe #23882 is related

alextp (Member) commented Nov 21, 2018

tensorflowbutler (Member) commented Dec 6, 2018

Nagging Assignee @alextp: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

alextp (Member) commented Dec 6, 2018

I think the RNG is probably the source, so I'll close this. Please reopen if there's evidence otherwise.
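One way to test that hypothesis, as a sketch: pin every seed in both runs, including explicit initializer seeds. Note that per-op seed derivation can still differ between eager and graph mode, so even this may not yield identical random streams.

import numpy as np
import tensorflow as tf
import tensorflow.keras as k

tf.set_random_seed(42)  # graph-level (and eager-context) seed
np.random.seed(42)

# Explicit per-initializer seeds remove one more source of divergence:
conv_initializer = k.initializers.random_normal(0.0, 0.02, seed=42)
batchnorm_inizializer = k.initializers.random_normal(1.0, 0.02, seed=42)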

alextp closed this Dec 6, 2018
