tf.device scope not working correctly #44510

Closed
JohnTaylor2000 opened this issue Nov 2, 2020 · 19 comments
Labels: comp:gpu (GPU related issues), comp:keras (Keras related issues), stale (marks the issue/PR stale, to be closed automatically if no activity), stat:awaiting response (awaiting response from author), TF 2.2 (issues related to TF 2.2), type:support (support issues)

Comments

@JohnTaylor2000

JohnTaylor2000 commented Nov 2, 2020


System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): ppc64le-linux
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary)
  • TensorFlow version (use command below):
  • Python version: python3.7
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 10.1.243
  • GPU model and memory: V100 16GB


Describe the current behavior

The tf.device scope does not assign tf.keras layers to the requested GPU on a node with 4 GPUs, so model parallelism cannot be implemented. According to the output of tf.debugging.set_log_device_placement(True), all layers are placed on GPU:0, apart from some I/O operations.
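One way to check placement without nvidia-smi is the `.device` attribute of eager tensors and variables, which reports where TensorFlow actually put them. The following minimal sketch is not from the original report: it assumes a TF 2.x release that provides `tf.config.set_logical_device_configuration`, and it splits the host CPU into two logical devices as a stand-in for two GPUs so that it runs anywhere (on the 4-GPU node, `/GPU:1` would be used instead).

```python
import tensorflow as tf

# Assumption: split the single physical CPU into two logical devices
# (/CPU:0 and /CPU:1) to stand in for two GPUs. This must run before
# TensorFlow initializes its runtime.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()]
)

# In eager mode the tf.device scope applies directly to ops and variables.
with tf.device("/CPU:1"):
    w = tf.Variable(tf.ones((784, 256)))
    y = tf.matmul(tf.ones((1, 784)), w)

print(w.device)  # e.g. /job:localhost/replica:0/task:0/device:CPU:1
print(y.device)
```

In eager code the scope is honored op by op; the report above concerns layers built through the Keras functional API, where the reporter observes the scope being ignored.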

Describe the expected behavior
tf.keras layers are placed on the device given by the enclosing tf.device scope.

Standalone code to reproduce the issue

import tensorflow as tf
from tensorflow import keras

tf.debugging.set_log_device_placement(True)

print("On GPU:1")
inputs = keras.Input(shape=(784,))
with tf.device("/device:GPU:1"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.
    x = keras.layers.Dense(256, activation="relu")(inputs)
    print(x)
    assert x.device.endswith("/GPU:1")


I have a larger test problem that runs on a 4-GPU node. If you comment out the assert statement, nvidia-smi shows that all of the memory use and compute happen on GPU:0, with almost nothing assigned to the other GPUs. Happy to supply this code if needed.
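As a programmatic alternative to watching nvidia-smi, newer TensorFlow releases expose per-device allocator statistics through `tf.config.experimental.get_memory_info` (added around TF 2.5, so it is assumed not to be available in the 2.1/2.2 versions used here). The helper name `gpu_memory_usage` is hypothetical; on a machine without GPUs it simply returns an empty dict.

```python
import tensorflow as tf


def gpu_memory_usage():
    """Return {"GPU:i": bytes currently allocated by TensorFlow} for each GPU."""
    usage = {}
    for i in range(len(tf.config.list_physical_devices("GPU"))):
        name = f"GPU:{i}"
        # get_memory_info returns a dict with "current" and "peak" byte counts.
        usage[name] = tf.config.experimental.get_memory_info(name)["current"]
    return usage


# Call during or after training to see which GPUs actually hold tensors.
for device, current in gpu_memory_usage().items():
    print(f"{device}: {current / 2**20:.1f} MiB in use")
```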

Other info / logs

Tensor("dense/Identity:0", shape=(None, 256), dtype=float32)
Traceback (most recent call last):
  File "py_test.py", line 11, in <module>
    assert x.device.endswith("/GPU:1")
AssertionError

@ravikyram ravikyram added comp:gpu GPU related issues type:support Support issues and removed type:bug Bug labels Nov 2, 2020
@ravikyram
Contributor

@JohnTaylor2000

Would it be possible to share the complete code snippet? Please also share the output of tf.config.list_physical_devices(). Thanks!

@ravikyram ravikyram added the stat:awaiting response Status - Awaiting response from author label Nov 2, 2020
@JohnTaylor2000
Author

JohnTaylor2000 commented Nov 3, 2020

Note: this uses the input data file mnist.npz.

import tensorflow as tf
from tensorflow import keras

import time
import socket

import os

tf.debugging.set_log_device_placement(True)

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
gpus = tf.config.list_physical_devices('GPU')
print(' gpus = ', gpus)

def get_compiled_model():
    # Make a deep densely-connected network whose layers are split across
    # four GPUs: six hidden Dense layers per GPU, plus a 10-unit output layer.
    inputs = keras.Input(shape=(784,))
    x = inputs
    print("On GPU:0")
    with tf.device("/device:GPU:0"):
        # Each layer consumes the previous layer's output so that every
        # layer is part of the model between `inputs` and `outputs`.
        x = keras.layers.Dense(256, activation="relu")(x)
        #assert x.device.endswith("/GPU:0")
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
    print("On GPU:1")
    with tf.device("/device:GPU:1"):
        x = keras.layers.Dense(256, activation="relu")(x)
        #assert x.device.endswith("/GPU:1")
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
    print("On GPU:2")
    with tf.device("/device:GPU:2"):
        x = keras.layers.Dense(256, activation="relu")(x)
        #assert x.device.endswith("/GPU:2")
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
    print("On GPU:3")
    with tf.device("/device:GPU:3"):
        x = keras.layers.Dense(256, activation="relu")(x)
        #assert x.device.endswith("/GPU:3")
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        x = keras.layers.Dense(256, activation="relu")(x)
        outputs = keras.layers.Dense(10)(x)
    model = keras.Model(inputs, outputs)

    opt = keras.optimizers.Adam()
    model.compile(
        optimizer=opt,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
        experimental_run_tf_function=False,
    )
    return model


def get_dataset():
    batch_size = 32
    num_val_samples = 10000

    # Return the MNIST dataset in the form of a `tf.data.Dataset`.
    path = '/g/g92/jtaylor/workspace/TFnew/mnist.npz'
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data(path)

    # Preprocess the data (these are Numpy arrays)
    x_train = x_train.reshape(-1, 784).astype("float32") / 255
    x_test = x_test.reshape(-1, 784).astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    # Reserve num_val_samples samples for validation
    x_val = x_train[-num_val_samples:]
    y_val = y_train[-num_val_samples:]
    x_train = x_train[:-num_val_samples]
    y_train = y_train[:-num_val_samples]
    return (
        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),
        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
    )

model = get_compiled_model()

# Train the model (no distribution strategy is used; placement comes only
# from the tf.device scopes in get_compiled_model).
train_dataset, val_dataset, test_dataset = get_dataset()
model.fit(train_dataset, epochs=2, validation_data=val_dataset)

# Test the model.
model.evaluate(test_dataset)

@JohnTaylor2000
Author

JohnTaylor2000 commented Nov 3, 2020

Output of tf.config.list_physical_devices(), as requested:

Num GPUs Available:  4
 gpus =  [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'), 
PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'), 
PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'), 
PhysicalDevice(name='/physical_device:GPU:3', device_type='GPU')]

@ravikyram
Contributor

@JohnTaylor2000

Please let us know which TF version you are using.
Thanks!

@JohnTaylor2000
Author

JohnTaylor2000 commented Nov 3, 2020

PowerAI TensorFlow version 2.1.0
Also TensorFlow 2.2.0 on x86_64

@ravikyram ravikyram added the TF 2.2 Issues related to TF 2.2 label Nov 3, 2020
@ravikyram ravikyram assigned rmothukuru and unassigned ravikyram Nov 3, 2020
@ravikyram ravikyram removed the stat:awaiting response Status - Awaiting response from author label Nov 3, 2020
@rmothukuru
Contributor

@JohnTaylor2000,
If I understand your requirement correctly, you want to utilize all the GPUs while training your model. If that is the case, the best API to use is distributed training (tf.distribute). The dataset can also be distributed using a distribution strategy.

Thanks!
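For reference, a minimal sketch of the data-parallel approach suggested here, using `tf.distribute.MirroredStrategy` with `model.fit`. It assumes a TF 2.x release with `tf.config.set_logical_device_configuration` and uses two logical CPUs in place of GPUs so it can run anywhere; on the 4-GPU node, `MirroredStrategy()` without arguments would pick up all visible GPUs. Note that this strategy keeps a full copy of the model on every device, so by itself it does not help when the model is too large for one GPU.

```python
import numpy as np
import tensorflow as tf

# Assumption: two logical CPUs stand in for two GPUs.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()]
)

strategy = tf.distribute.MirroredStrategy(devices=["/CPU:0", "/CPU:1"])

# Variables created inside the scope are mirrored: each device holds a full copy.
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(256, activation="relu"),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )

# Random stand-in data; each global batch of 64 is split across the replicas.
x = np.random.rand(256, 784).astype("float32")
y = np.random.randint(0, 10, size=(256,))
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(64)
history = model.fit(dataset, epochs=1, verbose=0)
print("replicas:", strategy.num_replicas_in_sync)
```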

@JohnTaylor2000
Author

The actual model that I am working on is too large to fit into GPU memory. I have a data-parallel code using Horovod that runs on hundreds of GPUs, but I now need to use a larger model. To do this I need to spread the layers across multiple GPUs.
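Splitting a model across devices in this way can be expressed with lower-level ops, where each stage's variables and computation are placed explicitly. The sketch below is a hypothetical two-stage example, not the author's model; it assumes a TF 2.x release with `tf.config.set_logical_device_configuration` and uses two logical CPUs as stand-ins for `/GPU:0` and `/GPU:1`. The names `DEVICES`, `w0`, `w1` and `forward` are illustrative.

```python
import tensorflow as tf

# Assumption: two logical CPUs stand in for two GPUs.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()]
)
DEVICES = ["/CPU:0", "/CPU:1"]  # e.g. ["/GPU:0", "/GPU:1"] on a multi-GPU node

# Stage i's weights live on DEVICES[i].
with tf.device(DEVICES[0]):
    w0 = tf.Variable(tf.random.normal((784, 256), stddev=0.05))
with tf.device(DEVICES[1]):
    w1 = tf.Variable(tf.random.normal((256, 10), stddev=0.05))


@tf.function
def forward(x):
    # Device scopes inside a tf.function are recorded in the traced graph,
    # so each matmul runs next to its weights; TensorFlow copies the
    # intermediate activation `h` from DEVICES[0] to DEVICES[1].
    with tf.device(DEVICES[0]):
        h = tf.nn.relu(tf.matmul(x, w0))
    with tf.device(DEVICES[1]):
        return tf.matmul(h, w1)


logits = forward(tf.random.normal((32, 784)))
print(logits.shape)  # (32, 10)
```

A training step would wrap `forward` in `tf.GradientTape` as usual; the gradients for `w0` and `w1` are computed on the devices where the corresponding forward ops ran.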

@rmothukuru rmothukuru assigned ymodak and jvishnuvardhan and unassigned rmothukuru and ymodak Nov 6, 2020
@jvishnuvardhan jvishnuvardhan added comp:keras Keras related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Nov 7, 2020
@JohnTaylor2000
Author

Hi all, just wondering if you have been able to run the test code that I have provided and/or need any further help?

@JohnTaylor2000
Author

Wondering if you can confirm that this is an issue based on my supplied bug report, and whether you have a solution? This is a roadblock for my research, so I am keen to have it resolved. Any help greatly appreciated!

@JohnTaylor2000
Author

Any developments on this issue please?

@JohnTaylor2000
Author

Any help available? I have had no response since 5 November.

@laplacericky

I think TF currently does not support model parallelism with a Keras model written the way you have. However, I believe it does support model parallelism with primitive operations. There also seem to be almost no docs about model parallelism, which is an issue.
Maybe you can try wrapping the __init__, build and call methods inside with tf.device(<the device you want>).
Example code:

import tensorflow as tf
import numpy as np


tf.debugging.set_log_device_placement(True)

class Dense_in_gpu(tf.keras.layers.Layer):
    def __init__(self, gpu_to_use, units, activation=None, **kwargs):
        # Initialize the base Layer first so attribute tracking is set up
        # before the wrapped Dense layer is assigned.
        super().__init__(**kwargs)
        self.gpu_to_use = gpu_to_use
        with tf.device(f"/GPU:{self.gpu_to_use}"):
            self.dense = tf.keras.layers.Dense(units, activation=activation)

    def call(self, inputs):
        # Run the wrapped Dense layer on the chosen GPU.
        with tf.device(f"/GPU:{self.gpu_to_use}"):
            return self.dense(inputs)


inputs = tf.keras.Input(shape=(784,))
#run in the first gpu
x = Dense_in_gpu(0,64, activation="relu")(inputs)
#run in the second gpu
x = Dense_in_gpu(1,64, activation="relu")(x)
#run in the third gpu
outputs = Dense_in_gpu(2,10)(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.predict_on_batch(np.random.rand(64,784))

tf_dataset=tf.data.Dataset.from_tensor_slices((np.random.rand(64,784),np.random.rand(64,10)))
tf_dataset=tf_dataset.batch(64)
model.compile(optimizer=tf.keras.optimizers.RMSprop(), loss=tf.keras.losses.MeanSquaredError())
model.fit(tf_dataset)
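A variant of this wrapper that also creates the weights inside the device scope (by overriding `build`, as suggested above) can be checked without GPUs. This sketch assumes a TF 2.x release with `tf.config.set_logical_device_configuration` and uses two logical CPUs in place of GPUs; `DenseOnDevice` and `target_device` are hypothetical names. It only exercises eager calls, so whether the same placement holds once `model.fit` traces the model should still be confirmed with `tf.debugging.set_log_device_placement(True)`. With the Keras 2 version of `tf.keras`, `second.dense.kernel.device` also shows where the weights were created.

```python
import tensorflow as tf

# Assumption: two logical CPUs stand in for two GPUs.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration(), tf.config.LogicalDeviceConfiguration()]
)


class DenseOnDevice(tf.keras.layers.Layer):
    """Hypothetical wrapper: a Dense layer pinned to one device."""

    def __init__(self, target_device, units, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.target_device = target_device
        self.dense = tf.keras.layers.Dense(units, activation=activation)

    def build(self, input_shape):
        # Create the kernel and bias inside the device scope.
        with tf.device(self.target_device):
            self.dense.build(input_shape)
        super().build(input_shape)

    def call(self, inputs):
        # Run the matmul, bias add and activation on the target device.
        with tf.device(self.target_device):
            return self.dense(inputs)


first = DenseOnDevice("/CPU:0", 64, activation="relu")
second = DenseOnDevice("/CPU:1", 10)

# Calling the layers eagerly builds them and runs each on its own device.
out = second(first(tf.random.normal((8, 784))))
print(out.device)  # ends with CPU:1
```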

@JohnTaylor2000
Author

Thank you for the interesting solution to this problem. I am investigating whether this will work on the model that I have developed.

I noticed that in the latest versions the tensors returned by tf.keras.layers no longer have a 'device' attribute, which means, as you say, that Keras no longer supports model parallelism where we assign layers to a device.

Interestingly, the development of GPUs with much larger memory, e.g. 80 GB on the A100, and the proposed new Grace architecture, which allows high-bandwidth access to CPU memory, will reduce the need for model parallelism. However, this will likely be offset by the desire to build bigger, more complex model systems using the Keras functional API.

@Saduf2019
Contributor

@JohnTaylor2000
Is this still an issue? Could you please try the latest TF version and let us know?

@Saduf2019 Saduf2019 added the stat:awaiting response Status - Awaiting response from author label Nov 2, 2021
@JohnTaylor2000
Author

I have tested this with TensorFlow 2.6.0 using a test code. I do not have access to 2.7.0 at the moment.

I am still seeing the same problem. When using a dense layer (tf.keras.layers.Dense), the response was
AttributeError: 'KerasTensor' object has no attribute 'device'. I also observed that the code intended to run on 4 GPUs runs primarily on a single GPU, i.e. the 'with tf.device("GPU:3"):' scope is ignored.

Here is a simple test where the print call raises the error above:

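import tensorflow as tf

num_hidden_units = 256  # illustrative value; not given in the original snippet
x = tf.keras.Input(shape=(784,))  # illustrative input; the original starts from an existing x
with tf.device("GPU:3"):
    x = tf.keras.layers.Dense(num_hidden_units, activation="relu")(x)
    x = tf.keras.layers.Dense(num_hidden_units, activation="relu")(x)
    x = tf.keras.layers.Dense(num_hidden_units, activation="relu")(x)
    x = tf.keras.layers.Dense(num_hidden_units, activation="relu")(x)
    x = tf.keras.layers.Dense(num_hidden_units, activation="relu")(x)
    # On TF 2.6 this raises AttributeError: 'KerasTensor' object has no attribute 'device'
    print(x.device)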

@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Nov 11, 2021
@Saduf2019 Saduf2019 removed the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Nov 11, 2021
@mohantym mohantym self-assigned this Mar 29, 2022
@mohantym
Contributor

Hi @JohnTaylor2000! You can use set_visible_devices to disable and enable specific GPUs during operation. Have you tried this with version 2.8 yet?
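For completeness, `tf.config.set_visible_devices` controls which physical GPUs the TensorFlow runtime will use in a process, and it must be called before the GPUs are initialized. It does not by itself split one model across GPUs, so for model parallelism it would still need to be combined with explicit per-layer placement. A minimal sketch that keeps only the first GPU visible (on a machine without GPUs it just reports zero devices):

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    # Hide every GPU except the first; must run before any GPU is initialized.
    tf.config.set_visible_devices(gpus[:1], "GPU")

visible = tf.config.get_visible_devices("GPU")
logical = tf.config.list_logical_devices("GPU")
print(f"{len(gpus)} physical GPU(s), {len(visible)} visible, {len(logical)} logical")
```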

@mohantym mohantym added the stat:awaiting response Status - Awaiting response from author label Apr 12, 2022
@google-ml-butler

This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.

@google-ml-butler google-ml-butler bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Apr 19, 2022
@google-ml-butler

Closing as stale. Please reopen if you'd like to work on this further.



9 participants