
Using tf.device when creating Keras layers does not move layer computation to that device #53671

Open
psobot opened this issue Jan 6, 2022 · 10 comments
Labels: comp:core (issues related to core part of tensorflow), comp:gpu (GPU related issues), stat:awaiting tensorflower (Status - Awaiting response from tensorflower), TF 2.9 (Issues found in the TF 2.9 release or RCs), type:bug (Bug)

Comments

@psobot
Contributor

psobot commented Jan 6, 2022

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Google Cloud AI Platform (TensorFlow Enterprise 2.7)
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v2.7.0-rc1-69-gc256c071bb2 2.7.0
  • Python version: Python 3.7.12
  • CUDA/cuDNN version: 11.3, V11.3.109
  • GPU model and memory: NVIDIA Tesla K80 (x4)

Describe the current behavior

When creating a Keras model using the tf.device context manager, the resources used by that layer are placed on the requested device, but the TensorFlow ops involved in that layer do not execute on that device.

Describe the expected behavior

Using the tf.device context manager when constructing a Keras layer should perform that layer's computation on the specified device.

Standalone code to reproduce the issue

import tensorflow as tf
tf.debugging.set_log_device_placement(True)
        
_input = tf.keras.layers.Input(shape=(1,), dtype=tf.float32)
x = _input
with tf.device("/GPU:1"):
    x = tf.keras.layers.Dense(10, name="should_be_on_gpu")(x)
    x = tf.keras.layers.Dense(10, name="should_be_on_gpu_2")(x)
model = tf.keras.models.Model(inputs=[_input], outputs=[x])
model.compile('adam', 'mse')
model.summary()
model.fit([2], [4])

Hundreds of log lines are printed showing the placement of each op, but crucially, /GPU:1 holds the Dense layers' variables (i.e., the ReadVariableOp ops) while the computation (MatMul, in this case) happens elsewhere:

model_3/ExpandDims: (ExpandDims): /job:localhost/replica:0/task:0/device:GPU:0
model_3/Cast: (Cast): /job:localhost/replica:0/task:0/device:GPU:0
model_3/should_be_on_gpu/MatMul/ReadVariableOp: (ReadVariableOp): /job:localhost/replica:0/task:0/device:GPU:1
model_3/should_be_on_gpu/MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
model_3/should_be_on_gpu/BiasAdd/ReadVariableOp: (ReadVariableOp): /job:localhost/replica:0/task:0/device:GPU:1
model_3/should_be_on_gpu/BiasAdd: (BiasAdd): /job:localhost/replica:0/task:0/device:GPU:0
model_3/should_be_on_gpu_2/MatMul/ReadVariableOp: (ReadVariableOp): /job:localhost/replica:0/task:0/device:GPU:1
model_3/should_be_on_gpu_2/MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
model_3/should_be_on_gpu_2/BiasAdd/ReadVariableOp: (ReadVariableOp): /job:localhost/replica:0/task:0/device:GPU:1
model_3/should_be_on_gpu_2/BiasAdd: (BiasAdd): /job:localhost/replica:0/task:0/device:GPU:0
@psobot psobot added the type:bug Bug label Jan 6, 2022
@tilakrayal tilakrayal added TF 2.7 Issues related to TF 2.7.0 comp:gpu GPU related issues labels Jan 7, 2022
@tilakrayal
Contributor

tilakrayal commented Jan 7, 2022

@Saduf2019 ,
I was able to reproduce the issue in tf v2.5, v2.7, and nightly. Please find the gist of it here.

Update:
I was able to reproduce the issue with tf-nightly-2.11.0-dev20220829. Please find the gist here. Thank you!

@tilakrayal tilakrayal assigned Saduf2019 and unassigned tilakrayal Jan 7, 2022
@Saduf2019
Contributor

@psobot
Could you please refer to the nightly execution of the code shared? I do not see the error reported; the output is the same as on the CPU, without error.

@Saduf2019 Saduf2019 added the stat:awaiting response Status - Awaiting response from author label Jan 10, 2022
@psobot
Contributor Author

psobot commented Jan 10, 2022

Hi @Saduf2019! The Colab notebook shared above was run with a single GPU, but the code provided moves layers onto /GPU:1, which would require two different GPUs to verify.

I've just modified the Colab notebook to add the following code to create two virtual GPUs, allowing you to see the issue in that environment:

tf.debugging.set_log_device_placement(True)
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    raise ValueError("At least one GPU required for this test!")
if len(gpus) == 1:
    # Create two virtual GPUs for this test:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
         tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(f"{len(gpus)} physical GPUs, split into {len(logical_gpus)} logical GPUs")
    print(logical_gpus)

@Saduf2019 Saduf2019 added the comp:dist-strat Distribution Strategy related issues label Jan 10, 2022
@jvishnuvardhan jvishnuvardhan removed the comp:dist-strat Distribution Strategy related issues label Jan 10, 2022
@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Jan 12, 2022
@jvishnuvardhan jvishnuvardhan added comp:core issues related to core part of tensorflow stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Jan 13, 2022
@wangpengmit
Member

wangpengmit commented Jan 20, 2022

I tried the code snippet above but couldn't reproduce your log. Could you provide a Colab with the log as a reproduction?

My trial is similar to @tilakrayal's gist above. There is no "MatMul" anywhere in the prints.

@psobot
Contributor Author

psobot commented Jan 20, 2022

Hi @wangpengmit! Please see this Colab which combines the two code snippets above. The logs I pasted in above seem to be written by TensorFlow directly to the C++ stderr stream, which doesn't show up inline when executing the code in Colab - I've included instructions on how to find those logs in the Colab.

@wangpengmit
Member

Thanks for the reproduction! The two lines

x = tf.keras.layers.Dense(10, name="should_be_on_gpu")(x)
x = tf.keras.layers.Dense(10, name="should_be_on_gpu_2")(x)

only determine where the variables are placed. The four lines (or some subset of them) afterwards

model = tf.keras.models.Model(inputs=[_input], outputs=[x])
model.compile('adam', 'mse')
model.summary()
model.fit([2], [4])

determine where real computations like MatMul and BiasAdd are placed. So the current behavior is working-as-intended. If you want to place MatMul and BiasAdd on GPU:1, you should indent the last four lines:

with tf.device("/GPU:1"):
    x = tf.keras.layers.Dense(10, name="should_be_on_second_gpu")(x)
    x = tf.keras.layers.Dense(10, name="should_also_be_on_second_gpu")(x)
    model = tf.keras.models.Model(inputs=[_input], outputs=[x])
    model.compile('adam', 'mse')
    model.summary()
    model.fit([2], [4])

@psobot
Contributor Author

psobot commented Jan 27, 2022

Thanks @wangpengmit - unfortunately, I don't think that's a useful solution here, especially in more complex models that spread their variables (and/or layers) across multiple GPUs. Consider an example like this:

_input = tf.keras.layers.Input(shape=(1,), dtype=tf.float32)

with tf.device("/GPU:0"):
    x = _input
    x = tf.keras.layers.Dense(10, name="should_be_on_first_gpu")(x)
    x = tf.keras.layers.Dense(10, name="should_also_be_on_first_gpu")(x)
    gpu0 = x
with tf.device("/GPU:1"):
    x = _input
    x = tf.keras.layers.Dense(10, name="should_be_on_second_gpu")(x)
    x = tf.keras.layers.Dense(10, name="should_also_be_on_second_gpu")(x)
    gpu1 = x

model = tf.keras.models.Model(inputs=[_input], outputs=[gpu0, gpu1])
model.compile('adam', 'mse')
model.summary()
model.fit([2], [4])

In this case, all of the computation in this model will occur on GPU0, and the current API has no ability to move that computation onto each GPU independently. (This could be possible by manually writing a training loop, manually calling the required layers in a @tf.function, and manually managing the gradients to ensure that both the backward and forward pass are computed on the appropriate GPUs - but that's a huge change and results in much less flexible code.) This is obviously not a huge problem for these toy examples, but as models grow beyond the memory bounds of a single GPU, this becomes a blocker.

Is there any way with the current TensorFlow and Keras APIs to force computations to be colocated on the GPU where their inputs are?

@wangpengmit
Member

This touches on an unfortunate TF design flaw: a TF op's placement is statically determined by tf.device scopes at graph-construction time, instead of "following the inputs", so we can't build either a "placement-polymorphic" graph or a graph whose subparts go to different devices depending on the inputs.

In your case, you can use custom layers, something like

class GPU0Layer(tf.keras.layers.Layer):
  def call(self, x):
    with tf.device("GPU:0"):
      ...

class GPU1Layer(tf.keras.layers.Layer):
  def call(self, x):
    with tf.device("GPU:1"):
      ...

model = tf.keras.models.Model(
    inputs=[_input], outputs=[GPU0Layer()(_input), GPU1Layer()(_input)])

But I understand this is less ideal than a "follow-inputs" placement system.

We are investigating supporting "follow-inputs" placement semantics, but it'll take a while.
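For concreteness, here is a runnable sketch of that pattern (illustrative names only: the `DeviceDense` class is hypothetical, and `/CPU:0` stands in for `/GPU:0` and `/GPU:1` so the snippet runs on any machine):

```python
import tensorflow as tf

# Hypothetical helper: a Dense layer whose forward pass is pinned to a
# given device. "/CPU:0" is used so this runs anywhere; substitute
# "/GPU:0" and "/GPU:1" on a multi-GPU host.
class DeviceDense(tf.keras.layers.Layer):
    def __init__(self, units, device, **kwargs):
        super().__init__(**kwargs)
        self.device = device
        self.dense = tf.keras.layers.Dense(units)

    def call(self, x):
        # The nested Dense builds its variables on first call, inside this
        # device scope, so the variables and forward ops land on `device`.
        with tf.device(self.device):
            return self.dense(x)

_input = tf.keras.layers.Input(shape=(1,), dtype=tf.float32)
branch_a = DeviceDense(10, "/CPU:0", name="branch_a")(_input)
branch_b = DeviceDense(10, "/CPU:0", name="branch_b")(_input)
model = tf.keras.models.Model(inputs=[_input], outputs=[branch_a, branch_b])

out_a, out_b = model.predict(tf.constant([[2.0]]), verbose=0)
```

Note that this only pins the forward-pass ops; backward-pass placement is a separate question.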

@psobot
Contributor Author

psobot commented Jan 29, 2022

That's great detail, thanks!

Using custom layers like that will indeed work just fine for computing the forward pass, but unfortunately not the backward pass. I can't find anything in the documentation that allows layers to customize how/where to run their backward passes in TF2, and the best solution I've been able to come up with involves creating a @tf.function that uses a different GradientTape instance for each GPU.
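For reference, the rough shape of that workaround (a sketch only; all names are illustrative, and `/CPU:0` stands in for the two GPUs so it runs without any):

```python
import tensorflow as tf

# Sketch of a manual training step with one GradientTape per device
# branch. Names are illustrative; "/CPU:0" stands in for "/GPU:0" and
# "/GPU:1" so the snippet runs on a CPU-only machine.
dense_a = tf.keras.layers.Dense(10)
dense_b = tf.keras.layers.Dense(10)
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.constant([[2.0]])
y = tf.constant([[4.0] * 10])

# Build the layers' variables eagerly, before tracing the tf.function.
dense_a(x)
dense_b(x)

@tf.function
def train_step(x, y):
    # One tape per branch, each forward pass pinned to its device.
    with tf.device("/CPU:0"):  # would be "/GPU:0"
        with tf.GradientTape() as tape_a:
            loss_a = loss_fn(y, dense_a(x))
    with tf.device("/CPU:0"):  # would be "/GPU:1"
        with tf.GradientTape() as tape_b:
            loss_b = loss_fn(y, dense_b(x))
    variables = dense_a.trainable_variables + dense_b.trainable_variables
    grads = (tape_a.gradient(loss_a, dense_a.trainable_variables)
             + tape_b.gradient(loss_b, dense_b.trainable_variables))
    optimizer.apply_gradients(zip(grads, variables))
    return loss_a + loss_b

loss = train_step(x, y)
```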

@mohantym mohantym self-assigned this Apr 19, 2022
@wangpengmit
Member

Sorry for the delayed reply! Yes, this is another unfortunate problem with TF's device placement. The backward pass basically ignores all tf.device annotations in the forward pass and does its own analysis to determine where to put its ops. We have an internal bug tracking this issue.

@mohantym mohantym removed their assignment May 4, 2022
@gadagashwini gadagashwini added TF 2.9 Issues found in the TF 2.9 release (or RCs) and removed TF 2.7 Issues related to TF 2.7.0 labels Aug 30, 2022