Update base for Update on "Make CI error on inductor fallback when decomp is available"


Fixes #99446 

Remove the warning, as it annoyed end users who don't know what to do about it.

Instead, try to hold the line by preventing any decomp from being added without making
the corresponding change to inductor's fallbacks.

Note: we probably still need to document better how to update inductor's decomps;
for now it's pretty much "go ask the inductor team for advice".
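
For context, the check amounts to making sure inductor does not silently fall back on an op that already has a usable decomp. A rough, hypothetical sketch of that kind of comparison (the internal table names below are assumptions about where the decomp and lowering tables live, not the actual test added here):

# Hypothetical consistency check: flag ops that have a core decomposition but
# that inductor neither decomposes nor lowers, i.e. ops it would fall back on
# despite a decomp being available. Table locations are assumptions and may
# differ across PyTorch versions.
import torch._decomp as decomp
from torch._inductor import decomposition, lowering

core_decomps = set(decomp.decomposition_table)
inductor_known = set(decomposition.decompositions) | set(lowering.lowerings)

missing = sorted(str(op) for op in core_decomps - inductor_known)
assert not missing, (
    "Ops with a decomp but no inductor decomp/lowering "
    f"(update inductor's decomps or fallback list): {missing}"
)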

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
wconstab committed Apr 20, 2023
2 parents 6bb2149 + 7c3fa5c commit 3be4c1e
Showing 55 changed files with 873 additions and 150 deletions.
2 changes: 1 addition & 1 deletion .circleci/scripts/binary_windows_build.sh
@@ -8,7 +8,7 @@ export CUDA_VERSION="${DESIRED_CUDA/cu/}"
export USE_SCCACHE=1
export SCCACHE_BUCKET=ossci-compiler-cache
export SCCACHE_IGNORE_SERVER_IO_ERROR=1
export VC_YEAR=2022
export VC_YEAR=2019

if [[ "${DESIRED_CUDA}" =~ cu1[1-2][0-9] ]]; then
export BUILD_SPLIT_CUDA=ON
2 changes: 1 addition & 1 deletion .circleci/scripts/binary_windows_test.sh
@@ -4,7 +4,7 @@ set -eux -o pipefail
source "${BINARY_ENV_FILE:-/c/w/env}"

export CUDA_VERSION="${DESIRED_CUDA/cu/}"
export VC_YEAR=2022
export VC_YEAR=2019

pushd "$BUILDER_ROOT"

15 changes: 15 additions & 0 deletions aten/src/ATen/CPUGeneratorImpl.cpp
@@ -94,6 +94,21 @@ void CPUGeneratorImpl::set_current_seed(uint64_t seed) {
engine_ = mt19937(seed);
}

/**
* Sets the offset of RNG state.
* See Note [Acquire lock when using random generators]
*/
void CPUGeneratorImpl::set_offset(uint64_t offset) {
TORCH_CHECK(false, "CPU Generator does not use offset");
}

/**
* Gets the current offset of CPUGeneratorImpl.
*/
uint64_t CPUGeneratorImpl::get_offset() const {
TORCH_CHECK(false, "CPU Generator does not use offset");
}

/**
* Gets the current seed of CPUGeneratorImpl.
*/
2 changes: 2 additions & 0 deletions aten/src/ATen/CPUGeneratorImpl.h
@@ -15,6 +15,8 @@ struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl {
// CPUGeneratorImpl methods
std::shared_ptr<CPUGeneratorImpl> clone() const;
void set_current_seed(uint64_t seed) override;
void set_offset(uint64_t offset) override;
uint64_t get_offset() const override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
7 changes: 7 additions & 0 deletions aten/src/ATen/core/Generator.h
@@ -93,6 +93,13 @@ struct TORCH_API Generator {
}

void set_current_seed(uint64_t seed) { impl_->set_current_seed(seed); }
// Sets the offset of Generator state to the desired offset. This is currently
// supported for only Philox based Generators, i.e., CUDA and MPS.
void set_offset(uint64_t offset) { impl_->set_offset(offset); }

// Returns the offset of Generator state. This is currently supported for only
// Philox based Generators, i.e., CUDA and MPS.
uint64_t get_offset() const { return impl_->get_offset(); }

uint64_t current_seed() const { return impl_->current_seed(); }

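Since this hunk only adds the C++ surface, here is a hypothetical usage sketch from Python, assuming the new methods are exposed on torch.Generator as set_offset/get_offset (the bindings themselves are not part of the files shown here):

import torch

# CPU generators do not carry a Philox offset, so set_offset is expected to fail.
cpu_gen = torch.Generator()
try:
    cpu_gen.set_offset(0)
except (RuntimeError, AttributeError) as e:
    # RuntimeError("CPU Generator does not use offset") if the binding exists,
    # AttributeError if this build does not expose it on torch.Generator.
    print(e)

# Philox-based generators (CUDA, MPS) keep the offset in the engine state.
if torch.cuda.is_available() and hasattr(torch.cuda.default_generators[0], "set_offset"):
    cuda_gen = torch.cuda.default_generators[0]
    cuda_gen.manual_seed(0)
    cuda_gen.set_offset(8)   # disallowed while a CUDA graph capture is in progress
    assert cuda_gen.get_offset() == 8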
17 changes: 17 additions & 0 deletions aten/src/ATen/core/PhiloxRNGEngine.h
@@ -86,6 +86,23 @@ class philox_engine {
STATE = 0;
}

/**
* Set the offset field of Philox Generator to the desired offset.
*/
C10_HOST_DEVICE inline void set_offset(uint64_t offset) {
counter_[0] = static_cast<uint32_t>(offset);
counter_[1] = static_cast<uint32_t>(offset >> 32);
}

/**
* Gets the current offset of the Philox Generator.
*/
C10_HOST_DEVICE uint64_t get_offset() const {
uint64_t lo = static_cast<uint64_t>(counter_[0]);
uint64_t hi = static_cast<uint64_t>(counter_[1]) << 32;
return lo | hi;
}

/**
* Produces a unique 32-bit pseudo random number on every invocation. Bookeeps state to avoid waste.
*/
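The offset is a 64-bit value split across the engine's two 32-bit counter words; a minimal Python mirror of the packing above (pure arithmetic, no torch dependency):

MASK32 = 0xFFFFFFFF

def set_offset(offset):
    # counter_[0] takes the low 32 bits, counter_[1] the high 32 bits
    return offset & MASK32, (offset >> 32) & MASK32

def get_offset(counter0, counter1):
    # reassemble as lo | (hi << 32)
    return counter0 | (counter1 << 32)

offset = (1 << 40) + 123
assert get_offset(*set_offset(offset)) == offset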
21 changes: 21 additions & 0 deletions aten/src/ATen/cuda/CUDAGeneratorImpl.cpp
@@ -117,6 +117,27 @@ void CUDAGeneratorImpl::set_current_seed(uint64_t seed) {
no_reset_rnn_state_.clear();
}

/**
* Sets the offset to be used by curandStatePhilox4_32_10
*
* See Note [Acquire lock when using random generators]
*/
void CUDAGeneratorImpl::set_offset(uint64_t offset) {
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::set_offset");
philox_offset_per_thread_ = offset;
no_reset_rnn_state_.clear();
}

/**
* Gets the current offset of CUDAGeneratorImpl.
*/
uint64_t CUDAGeneratorImpl::get_offset() const {
// Debatable if get_offset() should be allowed in captured regions.
// Conservatively disallow it for now.
at::cuda::assertNotCapturing("Cannot call CUDAGeneratorImpl::get_offset");
return philox_offset_per_thread_;
}

#define CAPTURE_DEFAULT_GENS_MSG \
"In regions captured by CUDA graphs, you may only use the default CUDA RNG " \
"generator on the device that's current when capture begins. " \
2 changes: 2 additions & 0 deletions aten/src/ATen/cuda/CUDAGeneratorImpl.h
@@ -95,6 +95,8 @@ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
// CUDAGeneratorImpl methods
std::shared_ptr<CUDAGeneratorImpl> clone() const;
void set_current_seed(uint64_t seed) override;
void set_offset(uint64_t offset) override;
uint64_t get_offset() const override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
2 changes: 2 additions & 0 deletions aten/src/ATen/mps/MPSGeneratorImpl.h
@@ -31,6 +31,8 @@ struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
// MPSGeneratorImpl methods
std::shared_ptr<MPSGeneratorImpl> clone() const;
void set_current_seed(uint64_t seed) override;
void set_offset(uint64_t offset) override;
uint64_t get_offset() const override;
uint64_t current_seed() const override;
uint64_t seed() override;
void set_state(const c10::TensorImpl& new_state) override;
8 changes: 8 additions & 0 deletions aten/src/ATen/mps/MPSGeneratorImpl.mm
@@ -37,6 +37,14 @@ Generator createMPSGenerator(uint64_t seed_val) {
engine_.reset_state(seed);
}

void MPSGeneratorImpl::set_offset(uint64_t offset) {
engine_.set_offset(offset);
}

uint64_t MPSGeneratorImpl::get_offset() const {
return engine_.get_offset();
}

uint64_t MPSGeneratorImpl::current_seed() const {
return data_.seed;
}
2 changes: 2 additions & 0 deletions aten/src/ATen/test/cpu_rng_test.cpp
@@ -28,6 +28,8 @@ struct TestCPUGenerator : public c10::GeneratorImpl {
void set_next_float_normal_sample(c10::optional<float> randn) { next_float_normal_sample_ = randn; }
void set_next_double_normal_sample(c10::optional<double> randn) { next_double_normal_sample_ = randn; }
void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
void set_offset(uint64_t offset) override { throw std::runtime_error("not implemented"); }
uint64_t get_offset() const override { throw std::runtime_error("not implemented"); }
uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
uint64_t seed() override { throw std::runtime_error("not implemented"); }
void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); }
2 changes: 2 additions & 0 deletions c10/core/GeneratorImpl.h
@@ -67,6 +67,8 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target {

// Common methods for all generators
virtual void set_current_seed(uint64_t seed) = 0;
virtual void set_offset(uint64_t offset) = 0;
virtual uint64_t get_offset() const = 0;
virtual uint64_t current_seed() const = 0;
virtual uint64_t seed() = 0;
virtual void set_state(const c10::TensorImpl& new_state) = 0;
15 changes: 13 additions & 2 deletions functorch/experimental/_cond.py
@@ -22,6 +22,7 @@
_pop_mode_temporarily,
)
from torch.utils._pytree import tree_flatten
from torch._dynamo.exc import CondOpArgsMismatchError


@dataclass
@@ -56,12 +57,22 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):

flat_true_outs, _ = pytree.tree_flatten(true_outs)
flat_false_outs, _ = pytree.tree_flatten(false_outs)
assert(len(flat_true_outs) == len(flat_false_outs))
if len(flat_true_outs) != len(flat_false_outs):
raise CondOpArgsMismatchError(
f"Expected to return same number of outputs but got:"
f"\n {true_fn.__name__} returns {len(flat_true_outs)} item(s)"
f"\n {false_fn.__name__} returns {len(flat_false_outs)} item(s)"
)

for i in range(0, len(flat_true_outs)):
true_out = flat_true_outs[i]
false_out = flat_false_outs[i]
assert true_out.meta['tensor_meta'] == false_out.meta['tensor_meta']
if true_out.meta['tensor_meta'] != false_out.meta['tensor_meta']:
raise CondOpArgsMismatchError(
f"Expected each tensor to have same metadata but got:"
f"\n {true_fn.__name__} returns {true_out.meta['tensor_meta']}"
f"\n {false_fn.__name__} returns {false_out.meta['tensor_meta']}"
)

# There are probably better ways - I know that create_arg has some self incrementing name
# magic to it, but since we explicitly have to get the name for register_module,
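With this change, a branch mismatch detected while tracing cond surfaces as a CondOpArgsMismatchError with a readable message instead of a bare assert. A rough, hypothetical repro; the exact tracing entry point and where the error is raised may vary by version:

import torch
from torch._dynamo.exc import CondOpArgsMismatchError
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sum()    # 0-d tensor

def false_fn(x):
    return x * 2.0    # same shape as x, so its metadata differs from true_fn's output

@torch.compile
def f(pred, x):
    return cond(pred, true_fn, false_fn, [x])

try:
    f(torch.tensor(True), torch.randn(4))
except CondOpArgsMismatchError as e:
    print("branch mismatch:", e)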
2 changes: 2 additions & 0 deletions test/cpp_extensions/rng_extension.cpp
@@ -20,6 +20,8 @@ struct TestCPUGenerator : public c10::GeneratorImpl {
uint32_t random() { return static_cast<uint32_t>(value_); }
uint64_t random64() { return value_; }
void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
void set_offset(uint64_t offset) override { throw std::runtime_error("not implemented"); }
uint64_t get_offset() const override { throw std::runtime_error("not implemented"); }
uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
uint64_t seed() override { throw std::runtime_error("not implemented"); }
void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); }
14 changes: 7 additions & 7 deletions test/distributed/_spmd/test_transformation.py
@@ -106,9 +106,10 @@ def _ddp_train_step(model, optim, batch):
foreach=(not use_fused_optimizer),
fused=use_fused_optimizer,
)
for _ in range(num_iters):
for i in range(num_iters):
batch = torch.randn(batch_size, dim).cuda()
out = train_step(model, optim, batch)
kwargs = {} if i < num_iters - 1 else {"last_train_step": True}
out = train_step(model, optim, batch, **kwargs)
ddp_out = _ddp_train_step(ddp_model, ddp_optim, batch)
self.assertEqual(list(ddp_model.parameters()), list(model.parameters()))

@@ -120,7 +121,7 @@ def test_basic_transformation(self):
dim = 100
num_iters = 5

@compile(gm_transformation=GraphModuleTransformation(num_iters=num_iters))
@compile(gm_transformation=GraphModuleTransformation())
def train_step(model, optim, batch):
model(batch).sum().backward()
optim.step()
@@ -140,7 +141,7 @@ def test_inductor(self):

@compile(
gm_transformation=GraphModuleTransformation(
num_iters=num_iters, enable_inductor=True, dump_graphs=True
enable_inductor=True, dump_graphs=True
)
)
def train_step(model, optim, batch):
@@ -162,7 +163,6 @@ def test_graph_optimization_with_foreach(self):

@compile(
gm_transformation=GraphModuleTransformation(
num_iters=num_iters,
enable_graph_optimization=True,
dump_graphs=False,
)
@@ -184,7 +184,6 @@ def test_graph_optimization_with_fused(self):

@compile(
gm_transformation=GraphModuleTransformation(
num_iters=num_iters,
enable_graph_optimization=True,
dump_graphs=False,
)
@@ -218,6 +217,7 @@ def my_transformation(gm):
gm.graph.eliminate_dead_code()
gm.recompile()
self.assertEquals(len(get_all_fused_optimizer_blocks(gm, "_fused_adam")), 2)
gm.finalize_setup()
return gm

@compile(gm_transformation=my_transformation)
@@ -244,7 +244,7 @@ def my_transformation(gm):
schedule_comm_wait(gm)
remove_copy_from_optimizer(gm)
iter_move_grads_and_optimizers(gm, "all_reduce_default_1", "relu")
gm.setup(num_iters)
gm.finalize_setup()
return gm

@compile(gm_transformation=my_transformation)
3 changes: 0 additions & 3 deletions test/dynamo/test_after_aot.py
@@ -15,9 +15,6 @@ def strip_trailing_whitespace(r):


class TestAfterAot(torch._dynamo.test_case.TestCase):
def tearDown(self):
assert not torch.cuda.is_initialized()

def test_save_graph_repro(self):
return
buf = io.StringIO()
8 changes: 8 additions & 0 deletions test/dynamo/test_comptime.py
@@ -197,6 +197,14 @@ def _(ctx):
}
-
global '' DETERMINISTIC_ALGORITHMS
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
-
global '' DEFAULT_DEVICE
{
'guard_types': None,
'code': None,
