Skip to content

Commit 34042a9

Browse files
ksivaman authored and pytorchmergebot committed
Change intra-graph offset dtype to uint64_t (#164515)
Even though `offset_intragraph_` only tracks RNG consumption within a single graph replay, we have observed that the 32-bit storage for these offsets is easy to overflow, especially for large CUDA graph captures that include kernels generating a large number of random numbers. Pull Request resolved: #164515. Approved by: https://github.com/eee4017, https://github.com/eqy
1 parent 9d1ab4f commit 34042a9

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

aten/src/ATen/cuda/CUDAGeneratorImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@ void CUDAGeneratorState::increase(uint64_t increment) {
109109
offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4.");
110110
// Ensures the increment does not cause overflow.
111111
TORCH_INTERNAL_ASSERT(
112-
offset_intragraph_ <= std::numeric_limits<uint32_t>::max() - increment,
112+
offset_intragraph_ <= std::numeric_limits<uint64_t>::max() - increment,
113113
"Increment causes overflow in the offset value.");
114114
offset_intragraph_ += increment;
115115
} else {
@@ -461,7 +461,7 @@ void CUDAGeneratorImpl::unregister_graph(cuda::CUDAGraph* graph) {
461461
*/
462462
PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) {
463463
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
464-
uint32_t offset = state_->offset_intragraph_;
464+
uint64_t offset = state_->offset_intragraph_;
465465
state_->increase(increment);
466466
return PhiloxCudaState(
467467
state_->seed_extragraph_.data_ptr<int64_t>(),

aten/src/ATen/cuda/CUDAGeneratorImpl.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ struct CUDAGraph;
9696
struct CUDAGeneratorState : public c10::intrusive_ptr_target {
9797
uint64_t seed_;
9898
uint64_t philox_offset_per_thread_;
99-
uint32_t offset_intragraph_;
99+
uint64_t offset_intragraph_;
100100
bool capturing_{};
101101
std::unordered_set<cuda::CUDAGraph*> registered_graphs_;
102102
at::TensorBase seed_extragraph_{};
@@ -105,7 +105,7 @@ struct CUDAGeneratorState : public c10::intrusive_ptr_target {
105105
CUDAGeneratorState(
106106
uint64_t seed = default_rng_seed_val,
107107
uint64_t philox_offset_per_thread = 0,
108-
uint32_t offset_intragraph = 0)
108+
uint64_t offset_intragraph = 0)
109109
: seed_(seed),
110110
philox_offset_per_thread_(philox_offset_per_thread),
111111
offset_intragraph_(offset_intragraph) {}

aten/src/ATen/cuda/detail/PhiloxCudaStateRaw.cuh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ struct PhiloxCudaState {
1919
// Called if graph capture is underway
2020
PhiloxCudaState(int64_t* seed,
2121
int64_t* offset_extragraph,
22-
uint32_t offset_intragraph) {
22+
uint64_t offset_intragraph) {
2323
seed_.ptr = seed;
2424
offset_.ptr = offset_extragraph;
2525
offset_intragraph_ = offset_intragraph;
@@ -36,7 +36,7 @@ struct PhiloxCudaState {
3636

3737
Payload seed_{};
3838
Payload offset_{};
39-
uint32_t offset_intragraph_ = 0;
39+
uint64_t offset_intragraph_ = 0;
4040
bool captured_ = false;
4141
};
4242

0 commit comments

Comments
 (0)