Multi GPU support #20

Closed
ma7555 opened this issue Oct 19, 2020 · 3 comments
Comments

ma7555 commented Oct 19, 2020

Using MirroredStrategy for distributed training results in the following error:

File "C:\Users\***\anaconda3\envs\tf\lib\site-packages\tensorflow_core\python\framework\ops.py", line 1619, in _create_c_op
    c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: The outer 2 dimensions of indices.shape=[2,12000,3] must match the outer 2 dimensions of updates.shape=[1,12000,64]: Dimension 0 in both shapes must be equal, but are 2 and 1. Shapes are [2,12000] and [1,12000]. for 'pillars/scatter_nd/ScatterNd' (op: 'ScatterNd') with input shapes: [2,12000,3], [1,12000,64], [4].

ma7555 mentioned this issue on Nov 7, 2020

ma7555 commented Nov 8, 2020

Issue explained:

The file networks.py hardcodes batch_size into the correct_batch_indices function:

def correct_batch_indices(tensor, batch_size):

This results in the wrong dimensionality during distributed training, because the batch size is divided by the number of GPUs (replicas) during .fit(), so each replica sees a smaller batch than the hardcoded batch_size.
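The mismatch can be illustrated with plain arithmetic. The helper below is hypothetical (it is not part of the repository or of TensorFlow) and assumes the global batch is split evenly across replicas; the values 2 and 2 are assumptions chosen to match the shapes in the error above.

```python
def per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int:
    """Batch size each replica sees, assuming an even split (hypothetical helper)."""
    if global_batch_size % num_replicas != 0:
        raise ValueError("global batch size must be divisible by the number of replicas")
    return global_batch_size // num_replicas


# Assumed values: a hardcoded batch_size of 2 and two GPUs.
hardcoded_batch_size = 2
actual_batch_size = per_replica_batch_size(global_batch_size=2, num_replicas=2)

# The indices are built for 2 samples, but each replica only receives 1,
# which corresponds to the [2, 12000] vs [1, 12000] mismatch in the traceback.
print(hardcoded_batch_size, actual_batch_size)  # prints "2 1"
```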

I have been thinking for a while about how to change this function, but nothing has worked. This is what I tried:

    def correct_batch_indices(tensor):
        seq = tf.range(tf.shape(tensor)[0])
        array = tf.Variable(lambda: tf.zeros_like(tensor))
        array = array[seq, :, 0].assign(seq)
        return tf.math.add(tensor, array)

Initializing a tf.Variable with a lambda here is a bad idea. If you can suggest something better, let me know.
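One possible direction, sketched here under assumptions and not necessarily what the eventual fix does: read the batch size from the tensor at run time with tf.shape and build the per-sample offsets by broadcasting, so no tf.Variable is needed. The sketch assumes the indices have shape [batch, n_pillars, 3] and that channel 0 should receive each sample's position in the batch, as in the attempt above.

```python
import tensorflow as tf


def correct_batch_indices(tensor):
    """Add each sample's batch position to channel 0 of its indices.

    Assumes `tensor` has shape [batch, n_pillars, 3]. The batch size is read
    dynamically, so it follows the per-replica batch under a distribution strategy.
    """
    batch = tf.shape(tensor)[0]
    # Batch positions 0..batch-1, cast to the dtype of the indices.
    seq = tf.cast(tf.range(batch), tensor.dtype)
    # Mask that selects channel 0 only.
    channel_mask = tf.constant([1, 0, 0], dtype=tensor.dtype)
    # Shape [batch, 1, 3]; broadcasts over the n_pillars axis when added.
    offsets = tf.reshape(seq, [-1, 1, 1]) * channel_mask
    return tensor + offsets
```

Because the offsets are derived from tf.shape(tensor), the same function should work whether a replica receives one sample or several.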

ma7555 commented Nov 9, 2020

Fixed for network.py; I will need to look at the generator tomorrow too.

ma7555 commented Nov 11, 2020

PR for the fix: #25

ma7555 closed this as completed on Nov 11, 2020