
Enable bucketized all-reduce for gradients #7216

Merged — 11 commits merged into master on Jun 14, 2024

Conversation

@amithrm (Collaborator) commented Jun 7, 2024

This PR adds bucketing (aka coalescing) to all-reduce, to increase DMA utilization and reduce DMA overhead associated with small data transfers.

Replaces #6417.
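A minimal usage sketch for reviewers, assuming the bucketized path is gated by the ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB environment variable discussed below and that gradients flow through the existing xm.optimizer_step / xm.reduce_gradients entry points; the 50 MB value is only illustrative:

# Hedged sketch: enabling bucketized gradient all-reduce via the env var
# discussed in this PR. The exact gating behavior is an assumption; the
# 50 (MB) bucket size is only illustrative.
import os
os.environ["ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB"] = "50"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(1024, 1024).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

loss = model(torch.randn(8, 1024, device=device)).sum()
loss.backward()

# xm.optimizer_step() reduces gradients across replicas before stepping;
# with the env var set, gradients would be coalesced into ~50 MB buckets
# instead of being all-reduced one small tensor at a time.
xm.optimizer_step(optimizer)
xm.mark_step()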

@jeffhataws (Collaborator)

Something is wrong with the build, @JackCaoG. Maybe it is a one-off? I don't know how to restart the run, though.

@jeffhataws (Collaborator) left a review comment

(removed)

import torch_xla.distributed.xla_backend
import torch.distributed as dist


Collaborator

Do you need to set the ALLREDUCE_GRADIENTS_BUCKET_SIZE_MB envvar for this test?

Collaborator

Also, do you need to add this test to run_tests.sh?

Collaborator

Yeah, this test is not being run in CI.

Collaborator (Author)

We don't need the env flag; the test runs the bucketized version directly.
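For context, the coalescing primitive that bucketing builds on already exists: xm.all_reduce accepts a list of tensors and reduces them in one coalesced operation. Below is a minimal sketch contrasting the per-tensor and coalesced paths; it is not this PR's test, just the usual xmp.spawn pattern.

# Minimal sketch (not this PR's test): per-tensor vs. coalesced all-reduce.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  grads = [torch.ones(n, device=device) for n in (3, 7, 1024)]

  # Per-tensor path: one all-reduce per gradient, i.e. many small transfers.
  reduced_individually = [xm.all_reduce(xm.REDUCE_SUM, g.clone()) for g in grads]

  # Coalesced path: a single in-place all-reduce over the whole list; the
  # bucketing in this PR groups gradients up to a size cap and uses this form.
  xm.all_reduce(xm.REDUCE_SUM, grads)

  xm.mark_step()
  for a, b in zip(grads, reduced_individually):
    assert torch.allclose(a.cpu(), b.cpu())


if __name__ == '__main__':
  xmp.spawn(_mp_fn)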

@JackCaoG (Collaborator) commented Jun 7, 2024

Hmm, this is weird:

Attempt 1 of 5 failed with error: Unexpected token '<', "<!DOCTYPE "... is not valid JSON. Retrying request in 3000 ms...
Attempt 2 of 5 failed with error: Unexpected token '<', "<!DOCTYPE "... is not valid JSON. Retrying request in 4669 ms...
Attempt 3 of 5 failed with error: Unexpected token '<', "<!DOCTYPE "... is not valid JSON. Retrying request in 7065 ms...
Attempt 4 of 5 failed with error: Unexpected token '<', "<!DOCTYPE "... is not valid JSON. Retrying request in 13950 ms...
Error: Failed to FinalizeArtifact: Failed to make request after 5 attempts: Unexpected token '<', "<!DOCTYPE "... is not valid JSON

Let me rerun. If it still fails I will ask someone on our end to take a look. Sorry that CI gave you guys so much trouble.

@JackCaoG (Collaborator) commented Jun 7, 2024

The CPU failures look real:

2024-06-07T21:55:07.5781389Z ======================================================================
2024-06-07T21:55:07.5782463Z ERROR: test_all_reduce_no_op_with_one_replica (__main__.TestExperimentalPjrtMultiCpu)
2024-06-07T21:55:07.5783741Z TestExperimentalPjrtMultiCpu.test_all_reduce_no_op_with_one_replica
2024-06-07T21:55:07.5784883Z ----------------------------------------------------------------------
2024-06-07T21:55:07.5787364Z concurrent.futures.process._RemoteTraceback: 
2024-06-07T21:55:07.5788147Z """
2024-06-07T21:55:07.5788624Z Traceback (most recent call last):
2024-06-07T21:55:07.5789762Z   File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
2024-06-07T21:55:07.5790972Z     r = call_item.fn(*call_item.args, **call_item.kwargs)
2024-06-07T21:55:07.5792188Z   File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
2024-06-07T21:55:07.5793238Z     return [fn(*args) for args in chunk]
2024-06-07T21:55:07.5794366Z   File "/usr/local/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
2024-06-07T21:55:07.5795412Z     return [fn(*args) for args in chunk]
2024-06-07T21:55:07.5796704Z   File "/usr/local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
2024-06-07T21:55:07.5797661Z     return fn(*args, **kwargs)
2024-06-07T21:55:07.5798915Z   File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 78, in _run_thread_per_device
2024-06-07T21:55:07.5800088Z     replica_results = list(
2024-06-07T21:55:07.5801068Z   File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
2024-06-07T21:55:07.5802133Z     yield _result_or_cancel(fs.pop())
2024-06-07T21:55:07.5803199Z   File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
2024-06-07T21:55:07.5804274Z     return fut.result(timeout)
2024-06-07T21:55:07.5805210Z   File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 458, in result
2024-06-07T21:55:07.5806190Z     return self.__get_result()
2024-06-07T21:55:07.5807167Z   File "/usr/local/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
2024-06-07T21:55:07.5808175Z     raise self._exception
2024-06-07T21:55:07.5809053Z   File "/usr/local/lib/python3.10/concurrent/futures/thread.py", line 58, in run
2024-06-07T21:55:07.5810078Z     result = self.fn(*self.args, **self.kwargs)
2024-06-07T21:55:07.5811369Z   File "/usr/local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 71, in _thread_fn
2024-06-07T21:55:07.5812441Z     return fn()
2024-06-07T21:55:07.5813878Z   File "/__w/xla/xla/pytorch/xla/test/pjrt/test_runtime_multi_cpu.py", line 136, in _all_reduce_hlo
2024-06-07T21:55:07.5815107Z     return torch_xla._XLAC._get_xla_tensors_hlo([reduced])
2024-06-07T21:55:07.5816583Z RuntimeError: Error while lowering: [UNKNOWN_SCALAR[]] xla::device_data, xla_shape=f32[3,3]***1,0***, dynamic_dims: (), device=CPU:0
2024-06-07T21:55:07.5818214Z Error: ./torch_xla/csrc/runtime/pjrt_computation_client.h:185 : Check failed: HasValue() 
2024-06-07T21:55:07.5819213Z *** Begin stack trace ***
2024-06-07T21:55:07.5819799Z 	tsl::CurrentStackTrace()
2024-06-07T21:55:07.5820587Z 	torch_xla::runtime::PjRtComputationClient::PjRtData::GetHandle()
2024-06-07T21:55:07.5822634Z 	torch_xla::LoweringContext::GetParameter(std::shared_ptr<torch::lazy::BackendData> const&, std::unordered_set<unsigned int, std::hash<unsigned int>, std::equal_to<unsigned int>, std::allocator<unsigned int> > const&)
2024-06-07T21:55:07.5824574Z 	torch_xla::DeviceData::Lower(torch_xla::LoweringContext*) const
2024-06-07T21:55:07.5825586Z 	torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
2024-06-07T21:55:07.5826615Z 	torch_xla::LoweringContext::GetOutputOp(torch::lazy::Output const&)
2024-06-07T21:55:07.5827919Z 	torch_xla::LoweringContext::AddResult(torch::lazy::Output const&)
2024-06-07T21:55:07.5829336Z 	torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
2024-06-07T21:55:07.5832459Z 	torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)
2024-06-07T21:55:07.5834849Z 	
2024-06-07T21:55:07.5835264Z 	
2024-06-07T21:55:07.5835675Z 	
2024-06-07T21:55:07.5836135Z 	_PyObject_MakeTpCall
2024-06-07T21:55:07.5836690Z 	_PyEval_EvalFrameDefault
2024-06-07T21:55:07.5837242Z 	
2024-06-07T21:55:07.5837653Z 	
2024-06-07T21:55:07.5838182Z 	_PyEval_EvalFrameDefault
2024-06-07T21:55:07.5838725Z 	
2024-06-07T21:55:07.5839171Z 	_PyEval_EvalFrameDefault
2024-06-07T21:55:07.5839704Z 	
2024-06-07T21:55:07.5840157Z 	_PyEval_EvalFrameDefault
2024-06-07T21:55:07.5840717Z 	
2024-06-07T21:55:07.5841162Z 	_PyEval_EvalFrameDefault
2024-06-07T21:55:07.5841715Z 	
2024-06-07T21:55:07.5842160Z 	_PyEval_EvalFrameDefault
2024-06-07T21:55:07.5842693Z 	
2024-06-07T21:55:07.5843135Z 	_PyEval_EvalFrameDefault
2024-06-07T21:55:07.5843670Z 	
2024-06-07T21:55:07.5844077Z 	
2024-06-07T21:55:07.5844488Z 	
2024-06-07T21:55:07.5844901Z 	
2024-06-07T21:55:07.5845335Z 	
2024-06-07T21:55:07.5845742Z 	
2024-06-07T21:55:07.5846154Z 	clone
2024-06-07T21:55:07.5846616Z *** End stack trace ***
2024-06-07T21:55:07.5847298Z buffer with shape f32[3,3] on device CPU:0 is deleted

Seems like one of the buffers has been aliased (hence the buffer is deleted) but is being referenced again.
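Roughly what the failing path does, reconstructed from the traceback above (a sketch, not the actual test source; only _get_xla_tensors_hlo is taken from the traceback, the rest is inferred):

# Sketch reconstructed from the traceback above, not the actual test code.
import torch
import torch_xla
import torch_xla.core.xla_model as xm


def _all_reduce_hlo():
  device = xm.xla_device()
  x = torch.randn(3, 3, device=device)       # the f32[3,3] device_data in the error
  reduced = xm.all_reduce(xm.REDUCE_SUM, x)  # with one replica this should be a no-op
  # If the all-reduce aliased/donated x's buffer, lowering the graph here
  # would reference a deleted buffer, which matches the failure above.
  return torch_xla._XLAC._get_xla_tensors_hlo([reduced])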

@JackCaoG (Collaborator) commented Jun 7, 2024

CI might be glitching. You should add your test to test.sh with the correct env var and rerun.

@jeffhataws (Collaborator)

@JackCaoG, the error in the GPU run for torch_mp_op is not very clear. Do you know why it is failing?

@JackCaoG (Collaborator)

It is not relevant; we can ignore it. However, the test in this PR is not being run.

@jeffhataws (Collaborator)

Hi @JackCaoG, seems like there's some CI infra issue?

@JackCaoG (Collaborator)

Yeah, GitHub Actions is glitching; this affects all GitHub projects.

@jeffhataws changed the title from "Bucketing gradients" to "Enable bucketized all-reduce for gradients" on Jun 14, 2024
@JackCaoG merged commit 28f9887 into master on Jun 14, 2024
21 of 22 checks passed
@ManfeiBai (Collaborator) commented Oct 2, 2024

Hi, I saw we wanted to backport this PR into release 2.4, as mentioned in #7242, but it looks like it was not backported into release 2.4; please correct me if I'm wrong.

Now that we are preparing the 2.5 release, would you mind describing more about this PR's feature, its use case, and what inspired it? @amithrm

@jeffhataws (Collaborator)

@ManfeiBai thanks for checking. Yeah, we can drop the 2.4 backport request.

This change adds bucketing to all-reduce so that it avoids the small-tensor all-reduces that are inefficient for DMAs. The bucketing aggregates/coalesces small tensors until a specified size is reached, then issues one all-reduce on the aggregate. This feature is already part of all-gather/reduce-scatter.
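A rough sketch of the bucketing idea (illustrative only, not the code this PR adds; the helper name and the 50 MB cap are just for the sketch, and dtype/group handling is omitted):

# Illustrative bucketing sketch, not the PR's implementation: accumulate
# gradients until a size cap is hit, then issue one coalesced all-reduce
# (xm.all_reduce over a list) per bucket instead of one per tensor.
import torch
import torch_xla.core.xla_model as xm


def all_reduce_gradients_bucketized(gradients, bucket_cap_mb=50):
  cap_bytes = bucket_cap_mb * 1024 * 1024
  bucket, bucket_bytes = [], 0

  def flush():
    if bucket:
      # One in-place coalesced all-reduce for the whole bucket.
      xm.all_reduce(xm.REDUCE_SUM, bucket)
      bucket.clear()

  for grad in gradients:
    nbytes = grad.numel() * grad.element_size()
    if bucket and bucket_bytes + nbytes > cap_bytes:
      flush()
      bucket_bytes = 0
    bucket.append(grad)
    bucket_bytes += nbytes
  flush()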
