Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2114,6 +2114,36 @@ def fn(x, y):

self.common(fn, (torch.randn(4), torch.randn(4)), check_lowp=False)

@requires_multigpu()
def test_recompile_on_index(self):
    """Check that a compiled function is guarded on the tensor device index.

    The function is first compiled with inputs on ``cuda:0`` and then called
    with inputs on ``cuda:1``. The second call is expected to fail a guard,
    which is observed through the ``guard_fail_fn`` callback, and the
    failure reason must mention the device index mismatch.
    """
    # The precision setting is process-global; remember the previous value
    # so this test does not leak its setting into other tests.
    prev_precision = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("high")
    try:

        def gemm(x, y):
            return x @ y

        failed_guard = None

        def fail(guard):
            # Record the guard failure reported by dynamo.
            nonlocal failed_guard
            failed_guard = guard

        gemm_opt = torch._dynamo.optimize("inductor", guard_fail_fn=fail)(gemm)

        # First call: inputs live on GPU 0.
        x0 = torch.randn(1024, 1024, device="cuda:0")
        y0 = torch.randn(1024, 1024, device="cuda:0")

        gemm_opt(x0, y0)

        # Second call: same shapes/dtypes but on GPU 1. Only the device
        # index differs, which is why the test requires multiple GPUs.
        x1 = torch.randn(1024, 1024, device="cuda:1")
        y1 = torch.randn(1024, 1024, device="cuda:1")

        gemm_opt(x1, y1)
        self.assertIsNotNone(failed_guard)
        self.assertIn(
            "tensor 'x' Tensor device index mismatch. Expected device index to be",
            failed_guard.reason,
        )
    finally:
        torch.set_float32_matmul_precision(prev_precision)

def test_unbind(self):
def fn(a):
return torch.unbind(a), torch.unbind(a, -1)
Expand Down
11 changes: 11 additions & 0 deletions torch/csrc/dynamo/guards.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class TensorCheck {
: pytype(pt),
dispatch_key_(state.apply(v.key_set()).raw_repr()),
dtype_(v.dtype().toScalarType()),
device_index_(v.device().index()),
requires_grad_(state.grad_mode_enabled && v.requires_grad()),
dynamic_shapes_(dynamic_shapes) {
auto ndim = v.ndimension();
Expand All @@ -46,6 +47,7 @@ class TensorCheck {
bool check(const LocalState& state, const at::Tensor& v) {
if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
dtype_ != v.dtype().toScalarType() ||
device_index_ != v.device().index() ||
requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) {
return false;
}
Expand Down Expand Up @@ -85,6 +87,11 @@ class TensorCheck {
fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
<< v.dtype().toScalarType();
return fail_reason.str();
} else if (device_index_ != v.device().index()) {
fail_reason
<< "Tensor device index mismatch. Expected device index to be "
<< device_index_ << ", actual " << v.device().index();
return fail_reason.str();
} else if (
requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) {
// return fmt::format("tensor requires_grad mismatch. expected {}",
Expand Down Expand Up @@ -128,6 +135,10 @@ class TensorCheck {
private:
uint64_t dispatch_key_; // DispatchKeySet includes device/layout
at::ScalarType dtype_;
// Note(voz): dispatch_key_ is representative of the device in that its keys
// are granular and device specific, but it does not necessarily capture the
// device index correctly, so the index is tracked separately here.
at::DeviceIndex device_index_;
bool requires_grad_;
bool dynamic_shapes_;
std::vector<int64_t> sizes_;
Expand Down