
tf.keras model.fit calls slow with TPU distribute strategy #30162

Closed
capilano opened this issue Jun 26, 2019 · 15 comments
Assignees
Labels
comp:dist-strat Distribution Strategy related issues comp:tpus tpu, tpuestimator stat:awaiting response Status - Awaiting response from author TF 1.15 for issues seen on TF 1.15 type:performance Performance Issue

Comments

@capilano

capilano commented Jun 26, 2019


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Win 10
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary):
  • TensorFlow version (use command below): 1.14
  • Python version: 3.6
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory: N/A


Describe the current behavior
The TPU distribution strategy does not support model.fit_generator, and repeated model.fit calls result in a ~50x slowdown, presumably because each call adds operations to the graph.

Describe the expected behavior

Code to reproduce the issue

resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_WORKER)
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
  model = ...  # your tf.keras model
  model.compile(loss=custom_loss, optimizer=custom_optimizer)  # placeholder loss/optimizer

for i in range(num_its):
  data, labels = next(generator_fn())
  model.fit(data, labels)

Other info / logs

@dynamicwebpaige dynamicwebpaige added the comp:keras Keras related issues label Jun 26, 2019
@achandraa achandraa self-assigned this Jun 27, 2019
@achandraa

To expedite the troubleshooting process, please provide a code snippet that reproduces the issue reported here. Thanks!

@achandraa achandraa added the stat:awaiting response Status - Awaiting response from author label Jun 27, 2019
@capilano
Author

I have just added a few lines to illustrate the issue. I have a very big workflow, so I cannot add the whole code here. The issue is that when model.fit is called in a loop (as shown above), training slows down considerably.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jun 28, 2019
@gadagashwini-zz
Contributor

@capilano It is difficult to reproduce the issue as reported. Would it be possible to provide a minimal code snippet that reproduces the issue, so that we can replicate it in our environment for faster resolution? Thanks!

@gadagashwini-zz gadagashwini-zz added the stat:awaiting response Status - Awaiting response from author label Jul 1, 2019
@capilano
Author

capilano commented Jul 1, 2019

I have not actually tested this code; I just wrote it here on the fly, so there may be minor errors.
The TPU distribute strategy does not support calls to model.fit_generator: it throws an exception explicitly stating that fit_generator is not supported with the current TPU distribute strategy. So for a normal use case, the alternative is to write a data generator function and make repeated calls to model.fit, and this results in a ~100x slowdown compared to, say, using TPU estimators instead.

#Importing Libraries
import numpy as np
import time
import os
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Conv2D

#Define network

def net():
  inp = Input(shape=(224, 224, 3))
  x = Conv2D(64, kernel_size=(3, 3), padding='same', activation='relu', strides=2)(inp)
  x = Conv2D(128, kernel_size=(3, 3), padding='same', activation='relu', strides=2)(x)
  x = Conv2D(256, kernel_size=(3, 3), padding='same', activation='relu', strides=1)(x)
  x = Conv2D(256, kernel_size=(3, 3), padding='same', activation='relu', strides=2)(x)
  x = Conv2D(256, kernel_size=(3, 3), padding='same', activation='relu', strides=2)(x)
  x = Conv2D(512, kernel_size=(3, 3), padding='same', activation='relu', strides=2)(x)
  x = Conv2D(512, kernel_size=(3, 3), padding='same', activation='relu', strides=2)(x)
  x = GlobalAveragePooling2D()(x)
  out = Dense(10)(x)
  model = Model(inputs=inp, outputs=out)
  return model

#TPU_init
resolver = tf.contrib.cluster_resolver.TPUClusterResolver()
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
with strategy.scope():
  model = net()
  model.compile(loss='categorical_crossentropy', optimizer='adam')

#Data load
x = np.ones((224, 224, 3), dtype=np.float32)
n = np.random.randint(0, 9)  # label index for the one-hot variant mentioned in the thread (unused when labels are all zeros)
y = np.zeros((10,), dtype=np.float32)
data = []
labels = []
for i in range(1024):
  data.append(x)
  labels.append(y)

data = np.array(data)
labels = np.array(labels)

#Fit function (Really slow. Should do this 100x faster)
for i in range(1000):
  time1 = time.time()
  model.fit(data, labels, batch_size=1024, epochs=1)
  time2 = time.time()
  print(time2-time1)

Comments:

Ideally, the distribute strategy should support the fit_generator method, because that would make it possible to use a TFRecord Dataset and load data directly from GCS buckets; it is almost never possible to preload the data into memory, especially when there is a data augmentation step in the input pipeline.

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jul 2, 2019
@gadagashwini-zz
Contributor

@capilano I tried reproducing the issue with the provided code, but I received
name 'TPU_WORKER' is not defined
Can you help us reproduce the issue? Thanks!

@gadagashwini-zz gadagashwini-zz added the stat:awaiting response Status - Awaiting response from author label Jul 2, 2019
@capilano
Author

capilano commented Jul 2, 2019

TPU_WORKER is the TPU address. If you are using Google Colab (with the TPU accelerator), I think you can leave it blank and just call the function without passing any argument. If that does not work, use:
TPU_WORKER = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
I have also slightly updated the code; it works when eager execution is enabled. Otherwise, please replace the last two lines, which make a one-hot vector, with np.zeros((10,)).

@gadagashwini-zz
Contributor

@capilano I tried reproducing the issue by adding tf.keras.layers.GlobalAveragePooling2D()(x), since GlobalAveragePooling2D(x) was not defined, but I am still unable to replicate your issue. Please provide the full minimal code snippet; it will indeed help us move faster.

@capilano
Author

capilano commented Jul 2, 2019

OK, I have changed the code; if you copy/paste it, it should work. I have also changed the network: I am using just a few Conv layers to check, since maybe the ResNet50 model has some unsupported layers. In the data loading part, the last two lines are outside the for loop. I am not able to indent the code here for some reason.

@gadagashwini-zz gadagashwini-zz added comp:tpus tpu, tpuestimator type:bug Bug and removed stat:awaiting response Status - Awaiting response from author labels Jul 2, 2019
@gadagashwini-zz
Contributor

@capilano Thanks for the complete code. I am able to reproduce the issue now with TF 1.14.0. Thanks!

@ymodak ymodak added comp:dist-strat Distribution Strategy related issues and removed comp:keras Keras related issues comp:tpus tpu, tpuestimator labels Jul 3, 2019
@ymodak ymodak assigned sb2nov and unassigned ymodak Jul 3, 2019
@sb2nov sb2nov removed their assignment Jul 25, 2019
@oanush oanush assigned oanush and unassigned oanush Aug 21, 2019
@anki-xyz

anki-xyz commented Oct 8, 2019

I'm facing the same issue. Since fit_generator does not work, I am using the same approach (repeated fit calls) because of the large dataset and the data augmentation involved, but calling fit is extremely slow.
Things that come to my mind:

  • re-loading the model to the TPU?
  • data transfer?

Even when I load more data per call to reduce the number of fit calls (batch size > 2048), the sessions just fail. This makes me suspect that it is the data transfer...

@capilano
Author

capilano commented Oct 8, 2019

You can use a TFRecord dataset with fit and run your data augmentation pipeline through a map function, as long as you can express the augmentations with TensorFlow ops.
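
A minimal sketch of that suggestion, assuming a TF 1.14-style setup and the in-memory data/labels arrays and compiled model from the repro script above; the flip/brightness augmentations and the epoch/step counts are placeholders, not part of the original report:

def augment(image, label):
  # Augmentations written as TensorFlow ops, so they run inside the input pipeline
  image = tf.image.random_flip_left_right(image)
  image = tf.image.random_brightness(image, max_delta=0.1)
  return image, label

dataset = tf.data.Dataset.from_tensor_slices((data, labels))
dataset = dataset.map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.repeat()
dataset = dataset.batch(1024, drop_remainder=True)  # static batch shape for the TPU
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

# One fit call consumes the dataset; no Python loop over model.fit is needed
model.fit(dataset, epochs=10, steps_per_epoch=data.shape[0] // 1024)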

@jvishnuvardhan
Contributor

I could reproduce the issue, but I am not sure whether it is 100x slower or not. Here is the gist. Thanks!

@jvishnuvardhan jvishnuvardhan added stat:awaiting tensorflower Status - Awaiting response from tensorflower type:performance Performance Issue TF 1.15 for issues seen on TF 1.15 and removed type:bug Bug labels Oct 16, 2019
@capilano
Author

@jvishnuvardhan Just to give you a heads up: a dataset can be passed directly to model.fit, so multiple calls to fit are not really necessary if the pipeline uses only TensorFlow ops for data augmentation.

@frankchn
Contributor

I am not surprised that the notebook is slow, as the data processing all happens on the Colab VM rather than on the TPU system (which has much more processing power than the Colab VMs).

With the notebook, the data variable contains 224 x 224 x 3 x 4 (bytes/float) x 1024 = 588 MB of data, which has to be transferred per step. Transferring this amount of data over the network to the TPU + encoding and decoding overhead would be non-trivial.

For performance reasons, especially on non-trivial image models, you need to use tf.data Datasets with TF-supported ops and load the raw data from GCS.
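
A rough sketch of that recommendation (the GCS path, feature names, and TFRecord layout below are hypothetical, and the model is assumed to be compiled under the TPU strategy scope as in the repro above): the records are read and decoded inside the tf.data pipeline, so only serialized examples travel from the bucket rather than pre-expanded float tensors from the Colab VM.

def parse_example(serialized):
  # Hypothetical feature spec; adjust to however the TFRecords were written
  features = tf.io.parse_single_example(serialized, {
      'image': tf.io.FixedLenFeature([], tf.string),
      'label': tf.io.FixedLenFeature([], tf.int64),
  })
  image = tf.image.decode_jpeg(features['image'], channels=3)
  image = tf.image.resize_images(image, (224, 224)) / 255.0  # fixed spatial shape for the TPU
  label = tf.one_hot(features['label'], depth=10)
  return image, label

files = tf.data.Dataset.list_files('gs://your-bucket/train-*.tfrecord')  # hypothetical path
dataset = files.interleave(tf.data.TFRecordDataset, cycle_length=4)
dataset = dataset.map(parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.shuffle(2048).repeat()
dataset = dataset.batch(1024, drop_remainder=True)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

model.fit(dataset, epochs=10, steps_per_epoch=1000)  # placeholder step count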

@frankchn frankchn added stat:awaiting response Status - Awaiting response from author and removed stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Oct 17, 2019
@frankchn frankchn assigned capilano and jvishnuvardhan and unassigned frankchn Oct 17, 2019
@Santosh-Gupta

Is fit_generator still not supported?
