how to train using multi-gpu? #15

Closed · GuangmingZhu opened this issue Nov 1, 2016 · 13 comments
@GuangmingZhu commented Nov 1, 2016

TensorFlow supports multi-GPU training. Does TensorLayer already support multi-GPU training?

@zsdonghao (Member) commented Nov 1, 2016

Hi, TensorLayer naturally supports all features of TensorFlow.
To use multiple GPUs, simply follow this tutorial: https://www.tensorflow.org/versions/r0.11/how_tos/using_gpu/index.html
i.e. build the graph with TL in a "multi-tower" fashion, with each tower under its own with tf.device(...): block.

@zsdonghao (Member) commented Nov 2, 2016

@GuangmingZhu here is a script for multi-GPU training; I hope it helps.

    # Note: x, y_, batch_size, u_net_2d and average_gradients are defined elsewhere
    # in the user's own script.
    num_gpus = 2
    opt = tf.train.AdamOptimizer(5e-5, beta1=0.9, beta2=0.999, epsilon=1e-08, use_locking=False)
    tower_grads = []
    for i in range(num_gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('%s_%d' % ("gpu", i)) as scope:
                tl.layers.set_name_reuse(True)
                ## inference
                network, outputs = u_net_2d(x, batch_size)
                ## cost
                dice_loss = 1 - tl.cost.dice_coe(outputs[:,:,:,0], y_[:,:,:,0], epsilon=1e-10)
                ## Reuse variables for the next tower.
                tf.get_variable_scope().reuse_variables()
                ## compute grads for every tower
                grads = opt.compute_gradients(dice_loss, var_list=network.all_params)
                tower_grads.append(grads)
    ## average the per-tower gradients and apply them in a single training op
    grads = average_gradients(tower_grads)
    train_op = opt.apply_gradients(grads)

The average_gradients function is borrowed from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
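
For reference, a minimal sketch of what that helper does, reconstructed along the lines of the CIFAR-10 multi-GPU example (the exact upstream code may differ; this version uses the TF 1.x tf.concat signature):

import tensorflow as tf

def average_gradients(tower_grads):
    """Average the gradients computed on each tower (GPU).

    tower_grads: a list with one entry per GPU, where each entry is the list of
    (gradient, variable) pairs returned by opt.compute_gradients() on that tower.
    Returns a single list of (averaged_gradient, variable) pairs.
    """
    average_grads = []
    # zip(*tower_grads) groups together the (grad, var) pairs that refer to the same variable.
    for grad_and_vars in zip(*tower_grads):
        # Stack the per-tower gradients along a new leading axis and take the mean.
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # Every tower shares the same variable object, so take it from the first tower.
        var = grad_and_vars[0][1]
        average_grads.append((grad, var))
    return average_grads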

@GuangmingZhu (Author)

@zsdonghao Thank you! How can I put variables on the CPU when I define layers using the TensorLayer APIs, since the APIs create the variables internally?

@zsdonghao (Member) commented Nov 3, 2016

@GuangmingZhu As TensorLayer is fully transparent to TensorFlow, you can simply define your layers inside a with tf.device('/cpu:0'): block.

This tutorial can help you understand how to put variables on the CPU: tutorial_cifar10_tfrecord.py
https://github.com/zsdonghao/tensorlayer/blob/master/tutorial_cifar10_tfrecord.py
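
A minimal sketch of that suggestion (the layer sizes below are illustrative, not taken from the tutorial):

import tensorflow as tf
import tensorlayer as tl

x = tf.placeholder(tf.float32, shape=[None, 784], name='x')

# Everything built inside this scope -- the variables and the forward ops alike --
# is placed on the CPU; see the later discussion in this thread about separating the two.
with tf.device('/cpu:0'):
    net = tl.layers.InputLayer(x, name='input')
    net = tl.layers.DenseLayer(net, n_units=800, act=tf.nn.relu, name='relu1')
    net = tl.layers.DenseLayer(net, n_units=10, act=tf.identity, name='output')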

@cobnut commented Nov 8, 2016

@zsdonghao Great implementation! But in tensorlayer (1.2.5) layers.py, class DynamicRNNLayer(Layer), line 2236:
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, n_steps, n_hidden])
n_steps is not defined; DynamicRNNLayer only has a maximum number of steps (n_steps(max)), not n_steps.
I think something may be wrong here.

@zsdonghao (Member)

@narrator-wong Thank you for reporting the bug. I just made a new commit.
The outputs should be reshaped by the max length, as follows:

max_length = tf.shape(self.outputs)[1]
self.outputs = tf.reshape(tf.concat(1, outputs), [-1, max_length, n_hidden])

The idea is the same as in advanced_indexing_op().

@zsdonghao (Member)

I think the problem has been solved?

@auroua commented Mar 9, 2018

I think the method in tutorial_cifar10_tfrecord.py is not correct. The official multi-GPU tutorial puts the inference, loss, and gradients all on the GPU, and the variables on the CPU. But tutorial_cifar10_tfrecord.py puts the inference procedure on the CPU.
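
For reference, the official tutorial pins only the variables to the CPU through a small helper roughly like the one below (a sketch from memory, so the real helper may differ), and then builds each tower's ops under tf.device('/gpu:i'):

import tensorflow as tf

def _variable_on_cpu(name, shape, initializer):
    """Create a variable that lives on the CPU, regardless of the surrounding device scope."""
    with tf.device('/cpu:0'):
        var = tf.get_variable(name, shape, initializer=initializer, dtype=tf.float32)
    return var

# Inference, loss and gradients are then constructed inside with tf.device('/gpu:%d' % i):,
# while every variable they touch stays on the CPU.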

@luomai reopened this Mar 10, 2018
@zsdonghao (Member)

@auroua Please read the code carefully; tutorial_cifar10_tfrecord.py puts the inference under with tf.device('/gpu:0'):

@auroua commented Mar 19, 2018

I wrote a simple script following tutorial_cifar10_tfrecord.py.

import tensorflow as tf
import tensorlayer as tl

with tf.device('/cpu:0'):
    def inference():
        x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
        network = tl.layers.InputLayer(x, name='input')
        network = tl.layers.DropoutLayer(network, keep=0.8, name='drop1')
        network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu1')
        network = tl.layers.DropoutLayer(network, keep=0.5, name='drop2')
        network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu2')
        network = tl.layers.DropoutLayer(network, keep=0.5, name='drop3')
        network = tl.layers.DenseLayer(network, n_units=10, act=tf.identity, name='output')
        return network


if __name__ == '__main__':
    with tf.device('/gpu:0'):
        network = inference()
        network.print_layers()
        sess = tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=True))
        tl.layers.initialize_global_variables(sess)

The following is the code output:

Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: TITAN Xp, pci bus id: 0000:02:00.0, compute capability: 6.1
/job:localhost/replica:0/task:0/device:GPU:1 -> device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0, compute capability: 6.1
output/b: (VariableV2): /job:localhost/replica:0/task:0/device:GPU:0
output/b/read: (Identity): /job:localhost/replica:0/task:0/device:GPU:0
output/b/Assign: (Assign): /job:localhost/replica:0/task:0/device:GPU:0
output/W: (VariableV2): /job:localhost/replica:0/task:0/device:GPU:0
output/W/read: (Identity): /job:localhost/replica:0/task:0/device:GPU:0
output/W/Initializer/truncated_normal/TruncatedNormal: (TruncatedNormal): /job:localhost/replica:0/task:0/device:GPU:0
output/W/Initializer/truncated_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
output/W/Initializer/truncated_normal: (Add): /job:localhost/replica:0/task:0/device:GPU:0
output/W/Assign: (Assign): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform/sub: (Sub): /job:localhost/replica:0/task:0/device:GPU:0
relu2/b: (VariableV2): /job:localhost/replica:0/task:0/device:GPU:0
relu2/b/read: (Identity): /job:localhost/replica:0/task:0/device:GPU:0
relu2/b/Assign: (Assign): /job:localhost/replica:0/task:0/device:GPU:0
relu2/W: (VariableV2): /job:localhost/replica:0/task:0/device:GPU:0
relu2/W/read: (Identity): /job:localhost/replica:0/task:0/device:GPU:0
relu2/W/Initializer/truncated_normal/TruncatedNormal: (TruncatedNormal): /job:localhost/replica:0/task:0/device:GPU:0
relu2/W/Initializer/truncated_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
relu2/W/Initializer/truncated_normal: (Add): /job:localhost/replica:0/task:0/device:GPU:0
relu2/W/Assign: (Assign): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform/sub: (Sub): /job:localhost/replica:0/task:0/device:GPU:0
relu1/b: (VariableV2): /job:localhost/replica:0/task:0/device:GPU:0
relu1/b/read: (Identity): /job:localhost/replica:0/task:0/device:GPU:0
relu1/b/Assign: (Assign): /job:localhost/replica:0/task:0/device:GPU:0
relu1/W: (VariableV2): /job:localhost/replica:0/task:0/device:GPU:0
relu1/W/read: (Identity): /job:localhost/replica:0/task:0/device:GPU:0
relu1/W/Initializer/truncated_normal/TruncatedNormal: (TruncatedNormal): /job:localhost/replica:0/task:0/device:GPU:0
relu1/W/Initializer/truncated_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
relu1/W/Initializer/truncated_normal: (Add): /job:localhost/replica:0/task:0/device:GPU:0
relu1/W/Assign: (Assign): /job:localhost/replica:0/task:0/device:GPU:0
init: (NoOp): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform/sub: (Sub): /job:localhost/replica:0/task:0/device:GPU:0
drop1/div: (RealDiv): /job:localhost/replica:0/task:0/device:GPU:0
drop1/Shape: (Shape): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform/RandomUniform: (RandomUniform): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop1/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop1/Floor: (Floor): /job:localhost/replica:0/task:0/device:GPU:0
drop1/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
relu1/MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
relu1/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
relu1/Relu: (Relu): /job:localhost/replica:0/task:0/device:GPU:0
drop2/div: (RealDiv): /job:localhost/replica:0/task:0/device:GPU:0
drop2/Shape: (Shape): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform/RandomUniform: (RandomUniform): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop2/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop2/Floor: (Floor): /job:localhost/replica:0/task:0/device:GPU:0
drop2/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
relu2/MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
relu2/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
relu2/Relu: (Relu): /job:localhost/replica:0/task:0/device:GPU:0
drop3/div: (RealDiv): /job:localhost/replica:0/task:0/device:GPU:0
drop3/Shape: (Shape): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform/RandomUniform: (RandomUniform): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop3/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop3/Floor: (Floor): /job:localhost/replica:0/task:0/device:GPU:0
drop3/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
output/MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
output/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
output/Identity: (Identity): /job:localhost/replica:0/task:0/device:GPU:0
output/b/Initializer/Const: (Const): /job:localhost/replica:0/task:0/device:GPU:0
output/W/Initializer/truncated_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:GPU:0
output/W/Initializer/truncated_normal/mean: (Const): /job:localhost/replica:0/task:0/device:GPU:0
output/W/Initializer/truncated_normal/shape: (Const): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform/max: (Const): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform/min: (Const): /job:localhost/replica:0/task:0/device:GPU:0
Placeholder_2: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0
relu2/b/Initializer/Const: (Const): /job:localhost/replica:0/task:0/device:GPU:0
relu2/W/Initializer/truncated_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:GPU:0
relu2/W/Initializer/truncated_normal/mean: (Const): /job:localhost/replica:0/task:0/device:GPU:0
relu2/W/Initializer/truncated_normal/shape: (Const): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform/max: (Const): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform/min: (Const): /job:localhost/replica:0/task:0/device:GPU:0
Placeholder_1: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0
relu1/b/Initializer/Const: (Const): /job:localhost/replica:0/task:0/device:GPU:0
relu1/W/Initializer/truncated_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:GPU:0
relu1/W/Initializer/truncated_normal/mean: (Const): /job:localhost/replica:0/task:0/device:GPU:0
relu1/W/Initializer/truncated_normal/shape: (Const): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform/max: (Const): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform/min: (Const): /job:localhost/replica:0/task:0/device:GPU:0
Placeholder: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0
x: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0

The parameters and the inference ops are all placed on the GPU.
The with tf.device('/cpu:0'): block seems to have no effect.

@auroua commented Mar 20, 2018

I modified tl.layers.DenseLayer and added tf.device('/cpu:0') directly around the tf.get_variable calls. The results seem correct.
This is the modified tl.layers.DenseLayer:

class DenseLayer(Layer):
    """
    The :class:`DenseLayer` class is a fully connected layer.

    Parameters
    ----------
    layer : a :class:`Layer` instance
        The `Layer` class feeding into this layer.
    n_units : int
        The number of units of the layer.
    act : activation function
        The function that is applied to the layer activations.
    W_init : weights initializer
        The initializer for initializing the weight matrix.
    b_init : biases initializer or None
        The initializer for initializing the bias vector. If None, skip biases.
    W_init_args : dictionary
        The arguments for the weights tf.get_variable.
    b_init_args : dictionary
        The arguments for the biases tf.get_variable.
    name : a string or None
        An optional name to attach to this layer.

    Examples
    --------
    >>> network = tl.layers.InputLayer(x, name='input_layer')
    >>> network = tl.layers.DenseLayer(
    ...                 network,
    ...                 n_units=800,
    ...                 act = tf.nn.relu,
    ...                 W_init=tf.truncated_normal_initializer(stddev=0.1),
    ...                 name ='relu_layer'
    ...                 )

    >>> Without TensorLayer, you can do as follow.
    >>> W = tf.Variable(
    ...     tf.random_uniform([n_in, n_units], -1.0, 1.0), name='W')
    >>> b = tf.Variable(tf.zeros(shape=[n_units]), name='b')
    >>> y = tf.nn.relu(tf.matmul(inputs, W) + b)

    Notes
    -----
    If the input to this layer has more than two axes, it need to flatten the
    input by using :class:`FlattenLayer` in this case.
    """

    def __init__(
            self,
            layer=None,
            n_units=100,
            act=tf.identity,
            W_init=tf.truncated_normal_initializer(stddev=0.1),
            b_init=tf.constant_initializer(value=0.0),
            W_init_args={},
            b_init_args={},
            name='dense_layer',
    ):
        Layer.__init__(self, name=name)
        self.inputs = layer.outputs
        if self.inputs.get_shape().ndims != 2:
            raise Exception("The input dimension must be rank 2, please reshape or flatten it")

        n_in = int(self.inputs.get_shape()[-1])
        self.n_units = n_units
        print("  [TL] DenseLayer  %s: %d %s" % (self.name, self.n_units, act.__name__))
        with tf.variable_scope(name) as vs:
            with tf.device('/cpu:0'):
                W = tf.get_variable(name='W', shape=(n_in, n_units), initializer=W_init, dtype=D_TYPE, **W_init_args)
            if b_init is not None:
                try:
                    with tf.device('/cpu:0'):
                        b = tf.get_variable(name='b', shape=(n_units), initializer=b_init, dtype=D_TYPE, **b_init_args)
                except:  # If initializer is a constant, do not specify shape.
                    with tf.device('/cpu:0'):
                        b = tf.get_variable(name='b', initializer=b_init, dtype=D_TYPE, **b_init_args)
                self.outputs = act(tf.matmul(self.inputs, W) + b)
            else:
                self.outputs = act(tf.matmul(self.inputs, W))

        # Hint : list(), dict() is pass by value (shallow), without them, it is
        # pass by reference.
        self.all_layers = list(layer.all_layers)
        self.all_params = list(layer.all_params)
        self.all_drop = dict(layer.all_drop)
        self.all_layers.extend([self.outputs])
        if b_init is not None:
            self.all_params.extend([W, b])
        else:
            self.all_params.extend([W])

This is the modified code:

def inference():
        x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
        network = tl.layers.InputLayer(x, name='input')
        network = tl.layers.DropoutLayer(network, keep=0.8, name='drop1')
        network = DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu1')
        network = tl.layers.DropoutLayer(network, keep=0.5, name='drop2')
        network = DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu2')
        network = tl.layers.DropoutLayer(network, keep=0.5, name='drop3')
        network = DenseLayer(network, n_units=10, act=tf.identity, name='output')
        return network


if __name__ == '__main__':
    with tf.device('/gpu:0'):
        network = inference()
        network.print_layers()
        sess = tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=True))
        tl.layers.initialize_global_variables(sess)

This is the placement result:

Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: TITAN Xp, pci bus id: 0000:02:00.0, compute capability: 6.1
/job:localhost/replica:0/task:0/device:GPU:1 -> device: 1, name: GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0, compute capability: 6.1
output/b: (VariableV2): /job:localhost/replica:0/task:0/device:CPU:0
output/b/read: (Identity): /job:localhost/replica:0/task:0/device:CPU:0
output/b/Assign: (Assign): /job:localhost/replica:0/task:0/device:CPU:0
output/W: (VariableV2): /job:localhost/replica:0/task:0/device:CPU:0
output/W/read: (Identity): /job:localhost/replica:0/task:0/device:CPU:0
output/W/Initializer/truncated_normal/TruncatedNormal: (TruncatedNormal): /job:localhost/replica:0/task:0/device:CPU:0
output/W/Initializer/truncated_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:CPU:0
output/W/Initializer/truncated_normal: (Add): /job:localhost/replica:0/task:0/device:CPU:0
output/W/Assign: (Assign): /job:localhost/replica:0/task:0/device:CPU:0
drop3/random_uniform/sub: (Sub): /job:localhost/replica:0/task:0/device:GPU:0
relu2/b: (VariableV2): /job:localhost/replica:0/task:0/device:CPU:0
relu2/b/read: (Identity): /job:localhost/replica:0/task:0/device:CPU:0
relu2/b/Assign: (Assign): /job:localhost/replica:0/task:0/device:CPU:0
relu2/W: (VariableV2): /job:localhost/replica:0/task:0/device:CPU:0
relu2/W/read: (Identity): /job:localhost/replica:0/task:0/device:CPU:0
relu2/W/Initializer/truncated_normal/TruncatedNormal: (TruncatedNormal): /job:localhost/replica:0/task:0/device:CPU:0
relu2/W/Initializer/truncated_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:CPU:0
relu2/W/Initializer/truncated_normal: (Add): /job:localhost/replica:0/task:0/device:CPU:0
relu2/W/Assign: (Assign): /job:localhost/replica:0/task:0/device:CPU:0
drop2/random_uniform/sub: (Sub): /job:localhost/replica:0/task:0/device:GPU:0
relu1/b: (VariableV2): /job:localhost/replica:0/task:0/device:CPU:0
relu1/b/read: (Identity): /job:localhost/replica:0/task:0/device:CPU:0
relu1/b/Assign: (Assign): /job:localhost/replica:0/task:0/device:CPU:0
relu1/W: (VariableV2): /job:localhost/replica:0/task:0/device:CPU:0
relu1/W/read: (Identity): /job:localhost/replica:0/task:0/device:CPU:0
relu1/W/Initializer/truncated_normal/TruncatedNormal: (TruncatedNormal): /job:localhost/replica:0/task:0/device:CPU:0
relu1/W/Initializer/truncated_normal/mul: (Mul): /job:localhost/replica:0/task:0/device:CPU:0
relu1/W/Initializer/truncated_normal: (Add): /job:localhost/replica:0/task:0/device:CPU:0
relu1/W/Assign: (Assign): /job:localhost/replica:0/task:0/device:CPU:0
init: (NoOp): /job:localhost/replica:0/task:0/device:CPU:0
drop1/random_uniform/sub: (Sub): /job:localhost/replica:0/task:0/device:GPU:0
drop1/div: (RealDiv): /job:localhost/replica:0/task:0/device:GPU:0
drop1/Shape: (Shape): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform/RandomUniform: (RandomUniform): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop1/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop1/Floor: (Floor): /job:localhost/replica:0/task:0/device:GPU:0
drop1/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
relu1/MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
relu1/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
relu1/Relu: (Relu): /job:localhost/replica:0/task:0/device:GPU:0
drop2/div: (RealDiv): /job:localhost/replica:0/task:0/device:GPU:0
drop2/Shape: (Shape): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform/RandomUniform: (RandomUniform): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop2/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop2/Floor: (Floor): /job:localhost/replica:0/task:0/device:GPU:0
drop2/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
relu2/MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
relu2/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
relu2/Relu: (Relu): /job:localhost/replica:0/task:0/device:GPU:0
drop3/div: (RealDiv): /job:localhost/replica:0/task:0/device:GPU:0
drop3/Shape: (Shape): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform/RandomUniform: (RandomUniform): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop3/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
drop3/Floor: (Floor): /job:localhost/replica:0/task:0/device:GPU:0
drop3/mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
output/MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
output/add: (Add): /job:localhost/replica:0/task:0/device:GPU:0
output/Identity: (Identity): /job:localhost/replica:0/task:0/device:GPU:0
output/b/Initializer/Const: (Const): /job:localhost/replica:0/task:0/device:CPU:0
output/W/Initializer/truncated_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:CPU:0
output/W/Initializer/truncated_normal/mean: (Const): /job:localhost/replica:0/task:0/device:CPU:0
output/W/Initializer/truncated_normal/shape: (Const): /job:localhost/replica:0/task:0/device:CPU:0
drop3/random_uniform/max: (Const): /job:localhost/replica:0/task:0/device:GPU:0
drop3/random_uniform/min: (Const): /job:localhost/replica:0/task:0/device:GPU:0
Placeholder_2: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0
relu2/b/Initializer/Const: (Const): /job:localhost/replica:0/task:0/device:CPU:0
relu2/W/Initializer/truncated_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:CPU:0
relu2/W/Initializer/truncated_normal/mean: (Const): /job:localhost/replica:0/task:0/device:CPU:0
relu2/W/Initializer/truncated_normal/shape: (Const): /job:localhost/replica:0/task:0/device:CPU:0
drop2/random_uniform/max: (Const): /job:localhost/replica:0/task:0/device:GPU:0
drop2/random_uniform/min: (Const): /job:localhost/replica:0/task:0/device:GPU:0
Placeholder_1: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0
relu1/b/Initializer/Const: (Const): /job:localhost/replica:0/task:0/device:CPU:0
relu1/W/Initializer/truncated_normal/stddev: (Const): /job:localhost/replica:0/task:0/device:CPU:0
relu1/W/Initializer/truncated_normal/mean: (Const): /job:localhost/replica:0/task:0/device:CPU:0
relu1/W/Initializer/truncated_normal/shape: (Const): /job:localhost/replica:0/task:0/device:CPU:0
drop1/random_uniform/max: (Const): /job:localhost/replica:0/task:0/device:GPU:0
drop1/random_uniform/min: (Const): /job:localhost/replica:0/task:0/device:GPU:0
Placeholder: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0
x: (Placeholder): /job:localhost/replica:0/task:0/device:GPU:0

@DEKHTIARJonathan (Member)

Shall it be merged into the main repository, @zsdonghao, or should we close the issue?

@zsdonghao (Member) commented May 17, 2018

@auroua @DEKHTIARJonathan I think the with tf.device('/cpu:0'): scope can be added outside of TensorLayer rather than inside the layer code.
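
One way to do that without touching the layer code is a device function that sends only variable ops to the CPU. The following is a hypothetical sketch (assign_to_device is not a TensorLayer or TensorFlow API; the helper and layer sizes are illustrative):

import tensorflow as tf
import tensorlayer as tl

def assign_to_device(device, ps_device='/cpu:0'):
    """Return a device function: variable ops go to ps_device, everything else to device."""
    VAR_OPS = ('Variable', 'VariableV2', 'VarHandleOp')
    def _assign(op):
        # op is the tf.Operation being placed; inspect its NodeDef to see what kind of op it is.
        return ps_device if op.node_def.op in VAR_OPS else device
    return _assign

x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
with tf.device(assign_to_device('/gpu:0')):
    # Variables land on /cpu:0 while the matmuls and activations run on /gpu:0.
    network = tl.layers.InputLayer(x, name='input')
    network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu1')
    network = tl.layers.DenseLayer(network, n_units=10, act=tf.identity, name='output')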

zsdonghao pushed a commit that referenced this issue May 4, 2019
fix a minor bug in test_auto_naming.py