Training with GPU on TF 2.0 is much slower than on TF 1.14 if set a large number to input_dim of tf.keras.layers.Embedding #32104

Closed
DeviLeo opened this issue Aug 30, 2019 · 27 comments
Labels: comp:keras (Keras related issues), TF 2.0 (Issues relating to TensorFlow 2.0), type:performance (Performance Issue)

@DeviLeo

DeviLeo commented Aug 30, 2019

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux-3.10.0-957.21.3.el7.x86_64 CentOS-7.3.1611-Core
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: None
  • TensorFlow installed from (source or binary): binary, pip install tensorflow-gpu
  • TensorFlow version (use command below): 2.0.0-rc0(v2.0.0-beta1-5101-gc75bb66), 1.14.0(v1.14.0-rc1-22-gaf24dc91b5)
  • Python version: 3.6.8
  • Bazel version (if compiling from source): None
  • GCC/Compiler version (if compiling from source): None
  • CUDA/cuDNN version: CUDA 10.0.130, cuDNN 7.6.3.30
  • GPU model and memory: RTX 2070 Super, 8GB

Describe the current behavior
I converted the Keras implementation of Neural Matrix Factorization (NeuMF) to tf.keras and it works well on TF 1.14.
But when I run it on TF 2.0.0-rc0, the training is much slower than on TF 1.14.
I used the profiling tools to check the timing, and I found that ReadVariableOp takes too much time when input_dim of tf.keras.layers.Embedding is set to a large number.

Tensorflow version:  2.0.0-rc0
Epoch 1/3
10000/10000 [==============================] - 5s 532us/sample - loss: 0.6935
Epoch 2/3
10000/10000 [==============================] - 4s 436us/sample - loss: 0.6903
Epoch 3/3
10000/10000 [==============================] - 4s 431us/sample - loss: 0.6851
Tensorflow version:  1.14.0
Epoch 1/3
10000/10000 [==============================] - 2s 212us/sample - loss: 0.7035
Epoch 2/3
10000/10000 [==============================] - 0s 28us/sample - loss: 0.6981
Epoch 3/3
10000/10000 [==============================] - 0s 29us/sample - loss: 0.6909

Describe the expected behavior
Training on TF 2.0 with a large input_dim for Embedding should be as fast as on TF 1.14, or faster.

Code to reproduce the issue
I have shared the code on Colab,
or check the code below.

# -*- coding:utf-8 -*-

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.regularizers import l1, l2
from tensorflow.keras.layers import Embedding, Input, Dense, Lambda, Flatten

def get_model(num_users, num_items, mf_dim=10, layers=[10], reg_layers=[0], reg_mf=0, alpha=0.5):
  assert len(layers) == len(reg_layers)
  num_layer = len(layers) #Number of layers in the MLP
  
  # Input variables
  user_input = Input(shape=(1,), dtype='int32', name = 'user_input')
  item_input = Input(shape=(1,), dtype='int32', name = 'item_input')
  
  # Embedding layer
  MF_Embedding_User = Embedding(input_dim = num_users, output_dim = mf_dim, name = 'mf_embedding_user', 
                                embeddings_initializer = keras.initializers.RandomNormal(mean=0.0, stddev=0.01, seed=None), embeddings_regularizer = l2(reg_mf), 
                                input_length=1)
  MF_Embedding_Item = Embedding(input_dim = num_items, output_dim = mf_dim, name = 'mf_embedding_item', 
                                embeddings_initializer = keras.initializers.RandomNormal(mean=0.0, stddev=0.01, seed=None), embeddings_regularizer = l2(reg_mf), 
                                input_length=1)

  MLP_Embedding_User = Embedding(input_dim = num_users, output_dim = int(layers[0]/2), name = "mlp_embedding_user", 
                                  embeddings_initializer = keras.initializers.RandomNormal(mean=0.0, stddev=0.01, seed=None), embeddings_regularizer = l2(reg_layers[0]), 
                                  input_length=1)
  MLP_Embedding_Item = Embedding(input_dim = num_items, output_dim = int(layers[0]/2), name = 'mlp_embedding_item', 
                                  embeddings_initializer = keras.initializers.RandomNormal(mean=0.0, stddev=0.01, seed=None), embeddings_regularizer = l2(reg_layers[0]), 
                                  input_length=1)

  # MF part
  mf_user_latent = Flatten()(MF_Embedding_User(user_input))
  mf_item_latent = Flatten()(MF_Embedding_Item(item_input))
  mf_vector = keras.layers.Multiply()([mf_user_latent, mf_item_latent])

  # MLP part
  mlp_user_latent = Flatten()(MLP_Embedding_User(user_input))
  mlp_item_latent = Flatten()(MLP_Embedding_Item(item_input))
  mlp_vector = keras.layers.Concatenate(axis=-1)([mlp_user_latent, mlp_item_latent])

  for idx in range(1, num_layer):
    mlp_vector = Dense(layers[idx], 
                      activation='relu', 
                      kernel_regularizer = l2(reg_layers[idx]), 
                      bias_regularizer = l2(reg_layers[idx]), 
                      name="layer%d" %idx)(mlp_vector)

  # Concatenate MF and MLP parts
  mf_vector = Lambda(lambda x: x * alpha)(mf_vector)
  mlp_vector = Lambda(lambda x : x * (1 - alpha))(mlp_vector)
  predict_vector = keras.layers.Concatenate(axis=-1)([mf_vector, mlp_vector])

  # Final prediction layer
  prediction = Dense(1, 
                    activation='sigmoid', 
                    kernel_initializer='lecun_uniform', 
                    bias_initializer ='lecun_uniform', 
                    name = "prediction")(predict_vector)

  model = keras.Model(inputs=[user_input, item_input], outputs=[prediction])
  return model

def generate_data(num_user, num_item, count=100):
    user_input = []
    item_input = []
    labels = []
    for _ in range(count):
        user = np.random.randint(0,num_user)
        item = np.random.randint(0,num_item)
        label = np.random.randint(0,2)
        user_input.append(user)
        item_input.append(item)
        labels.append(label)
    return np.asarray(user_input), np.asarray(item_input), np.asarray(labels)

def test_model():
    num_user = 1000000
    num_item = 100000
    count = 10000
    user_input, item_input, labels = generate_data(num_user, num_item, count)

    model = get_model(num_user, num_item)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.BinaryCrossentropy()
    )

    # Callbacks
    callbacks = [ tf.keras.callbacks.TensorBoard(log_dir='tb-logs') ]
    model.fit([user_input, item_input], labels, batch_size=256, epochs=3, callbacks=callbacks)

if __name__ == "__main__":
    print("Tensorflow version: ", tf.__version__)
    test_model()

Other info / logs
The attachment 'tb-logs.zip' contains the TensorBoard logs.

The profiling screenshot of the training on TF 2.0.0-rc0.
tf2-profile

The profiling screenshot of the training on TF 1.14.
tf114-profile
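
For reference, a minimal sketch of how such a profile trace can be captured with the TF 2.x TensorBoard callback's profile_batch argument (illustrative only, not necessarily how the traces above were produced):

import tensorflow as tf

# Profile batch 2 of the first epoch; the resulting trace appears under the
# Profile tab in TensorBoard, where per-op timings such as ReadVariableOp
# can be inspected.
tb_callback = tf.keras.callbacks.TensorBoard(log_dir="tb-logs", profile_batch=2)

# model.fit([user_input, item_input], labels, batch_size=256, epochs=3,
#           callbacks=[tb_callback])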

@gadagashwini-zz gadagashwini-zz self-assigned this Sep 3, 2019
@gadagashwini-zz gadagashwini-zz added TF 2.0-rc0 comp:keras Keras related issues labels Sep 3, 2019
@gadagashwini-zz
Contributor

@DeviLeo, I tried executing the code on both versions, and it looks like both performed about the same. Please take a look at the gists for TF 2.0.0-rc0 and TF 1.14.0. Please let me know if my understanding is wrong. Thanks!

@gadagashwini-zz gadagashwini-zz added the stat:awaiting response Status - Awaiting response from author label Sep 3, 2019
@DeviLeo
Author

DeviLeo commented Sep 3, 2019

@gadagashwini Could you try again with a GPU? When training with CPU or TPU, both versions are almost the same.

@DeviLeo DeviLeo changed the title Training on TF 2.0 is much slower than on TF 1.14 if set a large number to input_dim of tf.keras.layers.Embedding Training with GPU on TF 2.0 is much slower than on TF 1.14 if set a large number to input_dim of tf.keras.layers.Embedding Sep 3, 2019
@gadagashwini-zz gadagashwini-zz added type:performance Performance Issue and removed stat:awaiting response Status - Awaiting response from author labels Sep 3, 2019
@gadagashwini-zz
Contributor

@DeviLeo, I could reproduce the issue with GPU. Please take a look at the gist here. Thanks!

@gadagashwini-zz
Contributor

Please take a look at the Colab gist with TensorFlow 1.14.0.

@karmel

karmel commented Sep 3, 2019

@DeviLeo -- are you sure you're running on GPU? It's not clear how you are intending to place ops on GPU here. Have you tried with tf.distribute.MirroredStrategy?

@DeviLeo
Author

DeviLeo commented Sep 4, 2019

@karmel Yes, I'm sure it is running on GPU. At least nvidia-smi shows the python process.

I have just tried with tf.distribute.MirroredStrategy, and it worked with model.fit.
But in practical applications I use model.fit_generator, which raises "NotImplementedError: `fit_generator` is not supported for models compiled with tf.distribute.Strategy."

@karmel

karmel commented Sep 12, 2019

Can you use .fit instead? In the code snippet above, you appear to be using .fit; what requires the generator?

@DeviLeo
Author

DeviLeo commented Sep 13, 2019

Sorry, I'm afraid I cannot use `fit` instead.
The dataset for training and validation is about 50 million user-item pairs, so I have to use `fit_generator` to avoid OOM.
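
A rough, unverified sketch of what streaming the pairs through tf.data.Dataset.from_generator could look like, so that plain model.fit can be used (the generator body, dtypes, and sizes below are placeholders, not the real pipeline):

import numpy as np
import tensorflow as tf

def pair_generator():
    # Placeholder: in practice, stream (user, item, label) records from disk
    # or a database instead of holding all 50M pairs in memory.
    while True:
        user = np.random.randint(0, 1000000, size=(1,), dtype=np.int32)
        item = np.random.randint(0, 100000, size=(1,), dtype=np.int32)
        label = np.random.randint(0, 2, size=(1,), dtype=np.int32)
        yield (user, item), label

dataset = (
    tf.data.Dataset.from_generator(
        pair_generator,
        output_types=((tf.int32, tf.int32), tf.int32),
        output_shapes=(((1,), (1,)), (1,)),
    )
    .batch(256)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# With an endless generator, steps_per_epoch must be given:
# model.fit(dataset, epochs=3, steps_per_epoch=50_000_000 // 256)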

@cupdike

cupdike commented Sep 27, 2019

There seems to be a significant slowdown in general when using TF2 fit_generator, around a 3x slowdown in my own code between TF1 and TF2. It is easy to reproduce using the "Transfer Learning with TFHub" example from the TF2 official tutorials on Colab: https://www.tensorflow.org/beta/tutorials/images/hub_with_keras

To reproduce it, all I did was change model.fit to model.fit_generator. I ran these cases on both TF2 and TF1; TF1 was selected via this change to the first code cell:
%tensorflow_version 1.x

Here are the training runs for the four cases:

TF2 fit
Epoch 1/2
115/115 [======] - 24s 212ms/step - loss: 0.6619 - acc: 0.9062
Epoch 2/2
115/115 [======] - 20s 178ms/step - loss: 0.3309 - acc: 0.8125

TF2 fit_generator
Epoch 1/2
115/115 [======] - 56s 485ms/step - loss: 0.6666 - acc: 0.9375
Epoch 2/2
115/115 [======] - 49s 424ms/step - loss: 0.3345 - acc: 0.9688

TF1 fit
Epoch 1/2
115/115 [======] - 16s 136ms/step - loss: 0.6406 - acc: 0.8750
Epoch 2/2
115/115 [======] - 15s 129ms/step - loss: 0.3279 - acc: 0.8750

TF1 fit_generator
Epoch 1/2
115/115 [======] - 16s 139ms/step - loss: 0.7300 - acc: 0.8125
Epoch 2/2
115/115 [======] - 15s 132ms/step - loss: 0.3492 - acc: 0.9062

With TF1, there is no difference between fit and fit_generator, as you would hope. TF2 seems slower in general, and fit_generator in particular is 3x slower than TF1, at least for this tutorial and my own code.

BTW, Colab is using TF2 RC1 at the moment:

tf.__version__
'2.0.0-rc1'

@DanMinhNguyen

DanMinhNguyen commented Oct 2, 2019

Can confirm what @cupdike sees. I had a similar issue with my own project when switching to TF2 (stable; I waited for the official release a couple of days ago), with a 2x to 3x slowdown in training for the same data and code, as compared to TF1. After some Google searching and reading, I implemented the code using tf.data.Dataset.from_generator() instead, which allows me to use model.fit().

Unfortunately there was 0 performance benefit either way.

As for pseudocode (posting here just in case someone can point out something fundamentally wrong with my setup), the fit_generator version of my code went something like this. All my code uses the internal tf.keras instead of the external keras package:

def datagen(args):
    while True:
        # some code here to load and manipulate data into x and y, mostly numpy functions
        yield x, y

# some code here to create and compile the model

model.fit_generator(datagen(args), ...)

For the pseudocode using tf.data.Dataset.from_generator():

from tensorflow.compat.v2.data import Dataset

def datagen(args):
    while True:
        # some code here to load and manipulate data into x and y, mostly numpy functions
        yield x, y

# some code here to create and compile the model

train_data = Dataset.from_generator(generator=lambda: datagen(args), ...)
model.fit(train_data, ...)
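
As a usage note (my own filling-in of the elided arguments, with placeholder dtypes and shapes): from_generator needs the output types, and ideally shapes, declared up front, and with an endless while True generator, model.fit also needs steps_per_epoch:

import numpy as np
import tensorflow as tf

def datagen():
    # Stand-in for the data loading above; yields already-batched numpy arrays.
    while True:
        x = np.random.rand(32, 10).astype(np.float32)
        y = np.random.randint(0, 2, size=(32, 1)).astype(np.float32)
        yield x, y

train_data = tf.data.Dataset.from_generator(
    generator=datagen,
    output_types=(tf.float32, tf.float32),
    output_shapes=((32, 10), (32, 1)),
).prefetch(tf.data.experimental.AUTOTUNE)

# model.fit(train_data, steps_per_epoch=100, epochs=10)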

@robieta

robieta commented Oct 8, 2019

FYI this has been diagnosed in #33024. The issue is that Model.fit_generator is incorrectly running eagerly. (Consistent with @cupdike's observations)

@tabacof

tabacof commented Oct 8, 2019

I am not sure fit_generator explains all of it: I ported code that uses embedding models from TF1 to TF2 and there was a significant decrease in performance; it was 2-3x slower.

@Raukk

Raukk commented Oct 8, 2019

@tabacof Hi, I created the other issue linked.

There are two current resolutions to the issues people were having there: adding the line tf.compat.v1.disable_eager_execution() right after importing TF, or switching to model.fit() when using TF 2.0.

You say that you don't believe the issue is because of eager execution. The easiest way to prove that is to add the first fix, tf.compat.v1.disable_eager_execution(), right after importing TF. If this does not improve performance, then that should put the eager execution argument to rest.

The issue I encountered is that fit_generator kicks into eager execution no matter what, as @robieta says. And in my experience, eager execution is much slower in every case.

I hope that this can help you isolate the issue so that the cause can be identified.
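
For completeness, a minimal sketch of the first workaround (the tiny model here is just a placeholder):

import tensorflow as tf

# Must run right after importing TF, before any model or tensor is created.
tf.compat.v1.disable_eager_execution()

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

print("Executing eagerly:", tf.executing_eagerly())  # False once disabled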

@DeviLeo
Author

DeviLeo commented Oct 10, 2019

Thanks to all.

I've tried tf.compat.v1.disable_eager_execution() and model.fit(x=generator, ...), with and without tf.distribute.MirroredStrategy, but it did not help.
I think the key problem is the combination of a large input_dim in tf.keras.layers.Embedding and training with a generator.

The following cases were all tested with GPU on TF 2.0.0-rc2 and compared with TF 1.14.

  1. Small input_dim, model.fit without generator, without tf.distribute.MirroredStrategy. [Fast]
  2. Large input_dim, model.fit without generator, without tf.distribute.MirroredStrategy. [Slow]
  3. Large input_dim, model.fit without generator, with tf.distribute.MirroredStrategy. [Fast]
  4. Large input_dim, model.fit with generator, without tf.distribute.MirroredStrategy. [Slow]
  5. Large input_dim, model.fit with generator, with tf.distribute.MirroredStrategy. [Slow]

Here is the pseudocode I have tried.

import tensorflow as tf

# `disable_eager_execution` conflicts with `tf.distribute.MirroredStrategy`.
# "AssertionError: assert isinstance(x, dataset_ops.DatasetV2)" will be raised.
# tf.compat.v1.disable_eager_execution()

from tensorflow.keras.utils import Sequence

class MyGenerator(Sequence):
    def __init__(self, ...):
        # do something

    def __iter__(self):
        return self

    def __len__(self):
        return batches

    def __getitem__(self, index):
        # do something
        return tuple([x1, x2, ...]), y

def train():
    strategy = tf.distribute.MirroredStrategy()
    print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
    with strategy.scope():
        model = ... # Create model
        model.compile(...)
        
        my_gen = MyGenerator(...) # Create generator

        model.fit(
            x=my_gen,
            # Do not specify `y` and `batch_size` if x is a dataset, generator, or keras.utils.Sequence instance
            # Other arguments are the same as `fit_generator`'s
            ...
        )

if __name__ == '__main__':
    train()
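
A possible variation (an untested sketch on my part): wrap the Sequence in a tf.data.Dataset via from_generator so the MirroredStrategy code path receives a DatasetV2, which is what the assertion mentioned above expects. The output types and shapes below are placeholders.

import tensorflow as tf

def sequence_to_dataset(seq, output_types, output_shapes):
    """Hypothetical helper: expose a keras.utils.Sequence as a tf.data.Dataset."""
    def gen():
        while True:                      # repeat forever; fit() limits via steps_per_epoch
            for i in range(len(seq)):
                yield seq[i]
            seq.on_epoch_end()
    return tf.data.Dataset.from_generator(gen, output_types, output_shapes)

# dataset = sequence_to_dataset(
#     my_gen,
#     output_types=((tf.int32, tf.int32), tf.float32),      # placeholder
#     output_shapes=(((None, 1), (None, 1)), (None, 1)),     # placeholder
# )
# model.fit(dataset, steps_per_epoch=len(my_gen), epochs=...)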

@Shiro-LK

Regarding your gist on Colab, I suspect it is normal to see slower training with 2.0. When you do print(tf.test.is_gpu_available()) it currently returns False, so that means the GPU is not used, right?
Is it the case on your computer too?

@DeviLeo
Author

DeviLeo commented Nov 21, 2019

@Shiro-LK No, it returns True both on Colab and on my computer.

# python3
Python 3.6.8 |Anaconda, Inc.| (default, Dec 30 2018, 01:22:34)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import tensorflow as tf
>>> print("Tensorflow version: ", tf.__version__)
Tensorflow version:  2.0.0
>>> print("is_gpu_available: ", tf.test.is_gpu_available())
2019-11-21 17:15:31.360917: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-11-21 17:15:31.366375: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 3593275000 Hz
2019-11-21 17:15:31.367054: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x56555a7aa340 executing computations on platform Host. Devices:
2019-11-21 17:15:31.367080: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): Host, Default Version
2019-11-21 17:15:31.368710: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1
2019-11-21 17:15:31.977171: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1006] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-11-21 17:15:31.977604: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x56555a8622c0 executing computations on platform CUDA. Devices:
2019-11-21 17:15:31.977636: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): GeForce RTX 2070 SUPER, Compute Capability 7.5
2019-11-21 17:15:31.977803: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1006] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-11-21 17:15:31.978367: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1618] Found device 0 with properties:
name: GeForce RTX 2070 SUPER major: 7 minor: 5 memoryClockRate(GHz): 1.905
pciBusID: 0000:08:00.0
2019-11-21 17:15:31.978611: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.0
2019-11-21 17:15:31.979953: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10.0
2019-11-21 17:15:31.981247: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcufft.so.10.0
2019-11-21 17:15:31.981508: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcurand.so.10.0
2019-11-21 17:15:31.983163: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcusolver.so.10.0
2019-11-21 17:15:31.984398: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcusparse.so.10.0
2019-11-21 17:15:31.988258: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudnn.so.7
2019-11-21 17:15:31.988370: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1006] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-11-21 17:15:31.988985: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1006] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-11-21 17:15:31.989527: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1746] Adding visible gpu devices: 0
2019-11-21 17:15:31.989563: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.0
2019-11-21 17:15:31.990447: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1159] Device interconnect StreamExecutor with strength 1 edge matrix:
2019-11-21 17:15:31.990469: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1165]      0
2019-11-21 17:15:31.990482: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1178] 0:   N
2019-11-21 17:15:31.990590: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1006] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-11-21 17:15:31.991174: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:1006] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-11-21 17:15:31.991771: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1304] Created TensorFlow device (/device:GPU:0 with 7478 MB memory) -> physical GPU (device: 0, name: GeForce RTX 2070 SUPER, pci bus id: 0000:08:00.0, compute capability: 7.5)
is_gpu_available:  True
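
(Side note: on newer TF 2.x releases tf.test.is_gpu_available() is deprecated; listing the physical devices gives the same information, e.g.:)

import tensorflow as tf

# Reports the GPUs visible to TensorFlow without the deprecation warning
# that tf.test.is_gpu_available() emits on newer 2.x releases.
print("Visible GPUs:", tf.config.experimental.list_physical_devices("GPU"))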

@maximveksler

Hi guys, any updates?

@Saduf2019 Saduf2019 removed the comp:keras Keras related issues label Mar 23, 2020
@Saduf2019 Saduf2019 removed the type:performance Performance Issue label Mar 23, 2020
@goldiegadde
Contributor

@DeviLeo and @maximveksler, can you please try with 2.2.0-rc2? There have been multiple fixes improving 2.x performance.

Here is a Colab gist of the original issue with 2.2.0-rc2.

@tanzhenyu
Contributor

@DeviLeo Thanks for the report! Would you mind trying this on tf==2.2, since it contains many performance improvements for eager execution?

@DeviLeo
Author

DeviLeo commented Apr 1, 2020

@goldiegadde and @tanzhenyu I have tried on tf 2.2.0-rc2 with eager mode disabled and the issue is gone. But with eager mode enabled, both tf 2.2.0-rc2 and tf 1.15.2 are slower.

Eager mode: Disabled

Tensorflow version:  2.2.0-rc2
Tensorflow eager mode:  False
is_gpu_available:  True
Train on 10000 samples
Epoch 1/3
10000/10000 [==============================] - 1s 87us/sample - loss: 0.8003
Epoch 2/3
10000/10000 [==============================] - 1s 90us/sample - loss: 0.7889
Epoch 3/3
10000/10000 [==============================] - 1s 90us/sample - loss: 0.7758
Tensorflow version:  1.15.2
Tensorflow eager mode:  False
is_gpu_available:  True
Train on 10000 samples
Epoch 1/3
10000/10000 [==============================] - 1s 95us/sample - loss: 0.7089
Epoch 2/3
10000/10000 [==============================] - 0s 31us/sample - loss: 0.7039
Epoch 3/3
10000/10000 [==============================] - 0s 32us/sample - loss: 0.6972

Eager mode: Enabled

Tensorflow version:  2.2.0-rc2
Tensorflow eager mode:  True
is_gpu_available:  True
Epoch 1/3
40/40 [==============================] - 7s 181ms/step - loss: 0.7136
Epoch 2/3
40/40 [==============================] - 7s 178ms/step - loss: 0.7085
Epoch 3/3
40/40 [==============================] - 7s 177ms/step - loss: 0.7034
Tensorflow version:  1.15.2
Tensorflow eager mode:  True
is_gpu_available:  True
Train on 10000 samples
Epoch 1/3
10000/10000 [==============================] - 9s 947us/sample - loss: 0.9751
Epoch 2/3
10000/10000 [==============================] - 8s 845us/sample - loss: 0.9588
Epoch 3/3
10000/10000 [==============================] - 8s 848us/sample - loss: 0.9409

So, I think the issue is solved.

Thank you all.

@DeviLeo DeviLeo closed this as completed Apr 1, 2020
@Saduf2019 Saduf2019 added comp:keras Keras related issues type:performance Performance Issue labels Apr 1, 2020
@tanzhenyu
Contributor

(Quoting @DeviLeo's comment above: "I have tried on tf 2.2.0-rc2 with eager mode disabled and the issue is gone. ... So, I think the issue is solved. Thank you all.")

Awesome!

@lvenugopalan lvenugopalan added the TF 2.0 Issues relating to TensorFlow 2.0 label Apr 29, 2020
@DollarAkshay

DollarAkshay commented May 14, 2020

Just upgraded to TF 2.2. Having huge performance issues with Keras models.

Update:
Disabled eager execution and changed all my inputs to the model.predict() function from tensors to NumPy arrays. Looks like it's much faster now.

Update 2:
Never mind, I'm switching back to 1.14. My summary writer doesn't work if I use tf.compat.v1.disable_eager_execution(). I have no time to deal with all these bugs.

@nicholas-leonard

Having similar problems to DollarAkshay. For me, I can't get lookup table initializers to work with eager execution disabled. Strongly considering PyTorch at this point.

@tanzhenyu
Contributor

This seems contradictory to what @DeviLeo verified. Do you have a concrete code snippet to reproduce it?

@zpqiu

zpqiu commented Sep 20, 2020

Having the same problem. I have upgraded to TF 2.2. Moreover, I use the CPU to run the code. Analyzing the time profile, I also find that the embedding_lookup ReadVariableOp takes too much time.

@PWZER

PWZER commented Aug 9, 2021


I have the same problem; TensorFlow version is 2.5.0. ReadVariableOp took a long time, and memory usage is unstable.

@ShrimpLau


@PWZER I met the same issue and fixed it by upgrading to TF 2.6.
