
Conversation

yaochengji
Collaborator

Note this PR should be used together with the pytorch PR.

@yaochengji
Collaborator Author

@davidel @JackCaoG Could you help review?

@JackCaoG
Collaborator

JackCaoG commented Dec 1, 2020

@yaochengji Sorry for the delay, will take a look today.

@yaochengji
Collaborator Author

Thanks, @JackCaoG. Note the torch_xla/amp folder is almost the same as the amp folder in pytorch.

@ailzhang
Contributor

ailzhang commented Dec 1, 2020

@yaochengji Thanks for contributing to pytorch/xla! Would you mind adding a bit more context here, e.g. what's your main use case for pytorch/xla + amp, and with this PR have you seen meaningful perf gains on some common models? Thanks a lot!

@yaochengji
Collaborator Author

yaochengji commented Dec 1, 2020

Hi @ailzhang, I mainly use torch/xla to accelerate PyTorch training on GPUs, and for almost all CV models it is better to enable AMP.

With the tensorflow fix, torch/xla AMP training of resnet-50 on a V100 can reach 1150 images/s, which is roughly a 50% speedup compared to pytorch-amp (refer here), and only a little slower than tf amp (refer here).

@JackCaoG
Collaborator

JackCaoG commented Dec 1, 2020

Wow, that performance improvement with xla:gpu is amazing. Copying files from pytorch doesn't sound like a very good idea, though. If the amp logic is device agnostic, is there a way to make xla share the same code instead of copying it?

@yaochengji
Collaborator Author

Hi @JackCaoG, I made some modifications in pytorch so that I could simplify the code change here.

@davidel
Collaborator

davidel commented Dec 3, 2020

Sorry for the delay, I need a bit of time to take a look at this one.

@ailzhang ailzhang self-requested a review December 4, 2020 22:52
@yaochengji yaochengji force-pushed the add-amp branch 3 times, most recently from 6305eaa to 3d3662c Compare December 11, 2020 06:46
@ailzhang ailzhang requested a review from JackCaoG December 16, 2020 22:35
@JackCaoG
Collaborator

Sorry for the delay, I will try to take a look soon

@yaochengji
Collaborator Author

Thanks, @JackCaoG.

Collaborator

@JackCaoG JackCaoG left a comment


@yaochengji Thanks for contributing! I have not finished reviewing the whole thing, will circle back after vacation.

Collaborator

@JackCaoG JackCaoG left a comment


Mostly LGTM, some nits

scatter_dnums);
}

std::vector<xla::XlaOp> BuildAmpForachNonFiniteCheckAndUnscale(
Collaborator


I think you mentioned "And my XlaOp could only handle the scaler one." Could you point me to where that limitation is coming from?

Collaborator


Sorry, could you give a bit more detail? Why does the ReduceAll return type matter here?

Collaborator Author


The result of xla::ReduceAll is a scalar, which cannot be used together with an XlaOp of shape (1,).
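
(For illustration only, not the code in this PR: a minimal C++ sketch of the shape issue being discussed, assuming the XLA client API of that era — xla::ReduceAll collapses its input to a rank-0 scalar, so matching it against a shape-(1,) XlaOp takes an explicit Reshape. The helper name is hypothetical.)

```cpp
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical helper: fold a per-tensor finiteness check into a shape-(1,)
// found_inf flag.
xla::XlaOp MergeIntoFoundInf(xla::XlaBuilder* b, xla::XlaOp check,
                             xla::XlaOp found_inf /* shape (1,) */) {
  // ReduceAll returns a rank-0 scalar (shape ()), not shape (1,).
  xla::XlaOp scalar = xla::ReduceAll(
      check, xla::ConstantR0<float>(b, 0.0f),
      xla::CreateScalarMaxComputation(xla::F32, b));
  // Reshape the scalar to (1,) so both operands of Max have the same shape.
  xla::XlaOp as_1d = xla::Reshape(scalar, /*new_sizes=*/{1});
  return xla::Max(as_1d, found_inf);
}
```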



BuildAmpForachNonFiniteCheckAndUnscale — is "Forach" a typo for "foreach"?

Collaborator Author


Good catch! Fixed.

Collaborator

@JackCaoG JackCaoG left a comment


Thanks for making the change, I think we are pretty close!


@yaochengji yaochengji force-pushed the add-amp branch 3 times, most recently from eb9c01d to efd803d Compare January 8, 2021 06:37
@JackCaoG
Collaborator

JackCaoG commented Jan 9, 2021

@yaochengji Can you rebase this branch? I think you need #2685

@yaochengji
Collaborator Author

@JackCaoG rebased, but CI cannot pass until the corresponding pytorch PR is merged.

@JackCaoG
Collaborator

@yaochengji You can add a torch_patches/.torch_pin file, which will make CI use the specified branch of pytorch. https://github.com/pytorch/xla/pull/2718/files is an example of using a pinned version of pytorch.

@JackCaoG
Collaborator

gpu test failed with message

/var/lib/jenkins/.local/lib/python3.6/site-packages/torchvision/datasets/mnist.py:480: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:143.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
1654784it [00:00, 2972434.46it/s]                          
8192it [00:00, 33651.25it/s]/opt/conda/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 12 leaked semaphores to clean up at shutdown
  len(cache))
/opt/conda/lib/python3.6/multiprocessing/semaphore_tracker.py:143: UserWarning: semaphore_tracker: There appear to be 1 leaked semaphores to clean up at shutdown
  len(cache))
Traceback (most recent call last):
  File "test/test_train_mp_mnist_amp.py", line 194, in <module>
    xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores)
  File "/opt/conda/lib/python3.6/site-packages/torch_xla-1.9-py3.6-linux-x86_64.egg/torch_xla/distributed/xla_multiprocessing.py", line 394, in spawn
    start_method=start_method)
  File "/tmp/pytorch/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
  File "/tmp/pytorch/torch/multiprocessing/spawn.py", line 136, in join
    signal_name=name
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGBUS

Exited with code exit status 1

@yaochengji Could you take a look? Thanks!

@yaochengji
Collaborator Author

> gpu test failed with message ... torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGBUS ... @yaochengji Could you take a look? Thanks!

I'm looking into it. It is because I changed inv_scale back to shape (1,) on the pytorch side, which then hits a broadcast error when multiplying inv_scale with a non-scalar tensor. Do you know how to correct this with XLA primitives? I tried applying xla::Broadcast manually and using xla::ReduceAll, but neither seemed to work.
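
(An illustrative sketch of the approach described in the next comment, not the exact code that landed: reduce the shape-(1,) inv_scale to a rank-0 scalar with xla::ReduceAll; a scalar operand can then multiply a tensor of any shape via XLA's implicit scalar broadcast. The helper name is hypothetical and assumes the XLA client headers of that era.)

```cpp
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Hypothetical helper: unscale one gradient tensor by a shape-(1,) inv_scale.
xla::XlaOp UnscaleGrad(xla::XlaBuilder* b, xla::XlaOp grad,
                       xla::XlaOp inv_scale /* shape (1,) */) {
  // Summing the single element yields a rank-0 scalar holding the same value.
  xla::XlaOp inv_scale_scalar = xla::ReduceAll(
      inv_scale, xla::ConstantR0<float>(b, 0.0f),
      xla::CreateScalarAddComputation(xla::F32, b));
  // A scalar operand broadcasts implicitly against grad's shape in xla::Mul.
  return xla::Mul(grad, inv_scale_scalar);
}
```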

@yaochengji
Collaborator Author

I could fix the inv_scale shape (1,) issue by using xla::ReduceAll to reduce inv_scale to a scalar first, but I still hit another problem when testing _amp_update_scale:

Invalid argument: Run-time shape mismatch for XRTExecute argument[0] (3804975852334223). Expected element_type: F32
layout {
  format: DENSE
}
; got element_type: F32
dimensions: 1
layout {
  minor_to_major: 0
  format: DENSE
}
is_dynamic_dimension: false

         [[{{node XRTExecute}}]]

@yaochengji
Collaborator Author

yaochengji commented Feb 18, 2021

@JackCaoG, I have already fixed the shape (1,) issue.

And when I ran GPU_NUM_DEVICES=2 python3 test/test_train_mp_mnist.py on the master branch, the process 0 terminated with signal SIGBUS error still occurred. Could you help check it on your side?

BTW, GPU_NUM_DEVICES=2 python3 test/test_train_mp_mnist_amp.py --fake_data on the add-amp branch runs successfully.

@JackCaoG
Collaborator

@yaochengji thanks for the update, I will try to take a look later today. So the SIGBUS error is from the master? From the circleCI run, it seems like test_train_mnist.py --tidy passed but it failed on test_train_mp_mnist_amp.

@yaochengji
Collaborator Author

> So the SIGBUS error is from the master? From the circleCI run, it seems like test_train_mnist.py --tidy passed but it failed on test_train_mp_mnist_amp.

Yes, it seems to come from the master branch. I'm double-checking on another machine in case of an environment issue.

@JackCaoG
Collaborator

Thanks @yaochengji! BTW, did you ever try a 4 GPU setup? Did you ever encounter the error described in #2758?

@yaochengji
Collaborator Author

yaochengji commented Feb 18, 2021

> #2758

I just reproduced the error on another clean machine on the master branch.

The errors in the 2 GPU and 4 GPU setups are the same: process 0 terminated with signal SIGBUS.

The 1 GPU setup passes, which is why test_train_mnist.py --tidy succeeded.

@yaochengji
Collaborator Author

To summarize, on the master branch:

  1. GPU_NUM_DEVICES=2 python3 test/test_train_mp_mnist.py failed
  2. GPU_NUM_DEVICES=1 python3 test/test_train_mp_mnist.py passed
  3. GPU_NUM_DEVICES=2 python3 test/test_train_mp_mnist.py --fake_data passed

@JackCaoG
Collaborator

Thanks for the confirmation. The team is pretty busy with the upcoming 1.8 release and the tpuvm public preview work, but I will try to take a look before tomorrow. Since this error is not introduced by this PR, I think we can work around it for now using --fake_data. @ailzhang wdyt?

Contributor

@ailzhang ailzhang left a comment


We can merge this to master first, I will take a look at the multi GPU test failure tmr.
Thanks @yaochengji for contributing!

@JackCaoG
Collaborator

Great work @yaochengji, thanks for the contribution! My hands are a bit tied with tpuvm stuff right now, but I will try to review and import the XRT PR when I am a bit more free.

@JackCaoG JackCaoG merged commit e2aedb7 into pytorch:master Feb 19, 2021
@yaochengji
Collaborator Author

@JackCaoG @ailzhang Thanks for your time reviewing the PR. I'm experimenting with torch/xla at my company and am willing to contribute more if I can.


Labels: lowering (ATen Operation lowering), REMOVE_TORCH_PIN