[AOTI] Handle empty input args
Summary: When the model takes no inputs, AOTInductor relies on checking the weights to figure out which device to compile the model for. Currently, recording the buffer device type happens too late; this PR moves that recording to buffer-registration time so the device is known when AOTInductor needs it.

ghstack-source-id: c625ba743a81e4021587c636b7da498b6dff3b63
Pull Request resolved: #114682
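
For context, here is a minimal sketch of the failing scenario, assuming `torch._export.aot_compile` as the AOTInductor entry point (the API at the time of this commit); the model mirrors the new `test_no_args` test added below:

```python
import torch

class NoArgModel(torch.nn.Module):
    def __init__(self, m, n):
        super().__init__()
        # The graph's only tensors are these parameters; with no runtime
        # inputs, the compile-target device must be inferred from them.
        self.weight = torch.nn.Parameter(torch.randn(m, n))
        self.alpha = torch.nn.Parameter(torch.randn(m, n))

    def forward(self):
        return self.weight * self.alpha

model = NoArgModel(6, 4).to("cuda")
# Empty args tuple: before this fix, buffer device types were recorded too
# late for AOTInductor to choose the right compile target.
so_path = torch._export.aot_compile(model, ())
```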
desertfire committed Nov 28, 2023
1 parent cc7a969 commit cf2c1da
Showing 4 changed files with 24 additions and 17 deletions.
15 changes: 14 additions & 1 deletion test/inductor/test_aot_inductor.py
```diff
@@ -187,7 +187,6 @@ def check_model(
         constraints,
         disable_constraint_solver,
     )
-
     self.assertTrue(same(actual, expected))
```


```diff
@@ -1386,6 +1385,20 @@ def forward(self, inp):
         inputs = (torch.rand(4, 4, 4, 4, device=self.device),)
         self.check_model(Model(4), inputs)
 
+    def test_no_args(self):
+        class Model(torch.nn.Module):
+            def __init__(self, m, n):
+                super().__init__()
+                self.weight = torch.nn.Parameter(
+                    torch.randn(m, n),
+                )
+                self.alpha = torch.nn.Parameter(torch.randn(m, n))
+
+            def forward(self):
+                return self.weight * self.alpha
+
+        self.check_model(Model(6, 4), ())
+
 
 common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate)
```

17 changes: 7 additions & 10 deletions torch/_inductor/graph.py
```diff
@@ -403,9 +403,10 @@ def warn_fallback(self, name):
         self._warned_fallback.add(name)
         perf_hint_log.info("Using FallbackKernel: %s", name)
 
-    def add_device_idx(self, idx: Optional[int]):
-        if idx is not None:
-            self.device_idxs.add(idx)
+    def add_device_info(self, device: torch.device):
+        self.device_types.add(device.type)
+        if device.index is not None:
+            self.device_idxs.add(device.index)
 
     @property
     def fake_mode(self):
```
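
The new helper folds the old pair of bookkeeping calls into one. A standalone sketch of its semantics, assuming `device_types` and `device_idxs` are plain sets on the graph object (as the hunk suggests):

```python
import torch

class DeviceInfoRecorder:
    """Mimics the GraphLowering bookkeeping shown in the hunk above."""

    def __init__(self):
        self.device_types = set()
        self.device_idxs = set()

    def add_device_info(self, device: torch.device):
        # Always record the type; record the index only when present,
        # since e.g. torch.device("cuda") carries index None.
        self.device_types.add(device.type)
        if device.index is not None:
            self.device_idxs.add(device.index)

rec = DeviceInfoRecorder()
rec.add_device_info(torch.device("cpu"))
rec.add_device_info(torch.device("cuda", 0))
assert rec.device_types == {"cpu", "cuda"}
assert rec.device_idxs == {0}
```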
```diff
@@ -452,6 +453,7 @@ def register_buffer(self, buffer: ir.Buffer):
         name = f"buf{len(self.buffers)}"
         self.buffers.append(buffer)
         self.name_to_buffer[name] = buffer
+        self.add_device_info(buffer.get_device())
         return name
 
     def register_list(self, buffer_names: List[str]):
```
```diff
@@ -576,8 +578,7 @@ def placeholder(self, target: str, args, kwargs):
         )
         self.graph_inputs[target] = tensor
         self.graph_inputs_original[target] = tensor.data.data
-        self.device_types.add(example.device.type)
-        self.add_device_idx(example.device.index)
+        self.add_device_info(example.device)
         return tensor
 
     def call_function(self, target, args, kwargs):
```
```diff
@@ -910,10 +911,6 @@ def init_wrapper_code(self):
             return
 
         device_types = self.device_types.copy()
-        # In terms of some operations that don't have input tensors, we need to
-        # check the device of the buffers.
-        for buffer in self.buffers:
-            device_types.add(buffer.get_device().type)
         device_types.discard("cpu")
         # TODO(Eikan): Only support mixing cpu and other device now.
         assert len(device_types) <= 1, "Does not support mixing {}".format(
```
```diff
@@ -946,7 +943,7 @@ def materialize(x):
             else:
                 assert isinstance(
                     x, torch.Tensor
-                ), "Unknown type when creating real inputs"
+                ), "Unknown type when creating real inputs" + str(type(x))
                 return x
 
         with torch.utils._python_dispatch._disable_current_modes():
```
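With buffer devices recorded eagerly at registration time, `init_wrapper_code` no longer rescans `self.buffers`; `self.device_types` is already complete when it runs. A simplified sketch of the selection logic it keeps (the real method also picks a `WrapperCodeGen` class, elided here):

```python
def pick_codegen_device(device_types: set) -> str:
    # device_types was populated via add_device_info() during lowering,
    # so it is complete even when the graph has no inputs at all.
    remaining = set(device_types)
    remaining.discard("cpu")
    # Only mixing CPU with at most one other device type is supported.
    assert len(remaining) <= 1, f"Does not support mixing {remaining}"
    return remaining.pop() if remaining else "cpu"

assert pick_codegen_device({"cpu", "cuda"}) == "cuda"
assert pick_codegen_device({"cpu"}) == "cpu"
```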
6 changes: 2 additions & 4 deletions torch/_inductor/ir.py
```diff
@@ -4149,10 +4149,8 @@ def create(cls, x, device):
         ):
             return x.constant_to_device(device)
 
-        V.graph.device_types.add(device.type)
-        V.graph.add_device_idx(device.index)
-        V.graph.device_types.add(x.get_device().type)
-        V.graph.add_device_idx(x.get_device().index)
+        V.graph.add_device_info(device)
+        V.graph.add_device_info(x.get_device())
 
         developer_warning("DeviceCopy in input program")
         return DeviceCopy(
```
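Both endpoints of a `DeviceCopy` now go through the same helper, so even a graph whose only operation is a cross-device copy records every device it touches. A hedged sketch of code that can hit this path (whether the copy survives as a `DeviceCopy` in the lowered graph depends on the surrounding program):

```python
import torch

# x.to("cuda") inside a compiled function may lower to a DeviceCopy, the
# case flagged by developer_warning("DeviceCopy in input program") above.
def move(x: torch.Tensor) -> torch.Tensor:
    return x.to("cuda")

if torch.cuda.is_available():
    out = torch.compile(move)(torch.randn(4))  # cpu in, cuda out
    assert out.device.type == "cuda"
```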
3 changes: 1 addition & 2 deletions torch/_inductor/scheduler.py
```diff
@@ -2106,8 +2106,7 @@ def create_backend(self, device: torch.device):
         assert (
             device.type != "cuda" or device.index is not None
         ), f"{device} should have been normalized in lowering"
-        V.graph.device_types.add(device.type)
-        V.graph.add_device_idx(device.index)
+        V.graph.add_device_info(device)
 
         device_scheduling = get_scheduling_for_device(device.type)
         if device_scheduling is None:
```
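The assertion kept in `create_backend` states the invariant that makes the single `add_device_info` call sufficient here: any CUDA device reaching the scheduler already carries an explicit index. A tiny self-contained restatement:

```python
import torch

def assert_normalized(device: torch.device) -> None:
    # Mirrors the assertion in create_backend: a bare "cuda" device
    # (index None) should have been normalized during lowering.
    assert (
        device.type != "cuda" or device.index is not None
    ), f"{device} should have been normalized in lowering"

assert_normalized(torch.device("cuda", 0))  # ok
assert_normalized(torch.device("cpu"))      # ok: non-cuda devices exempt
# assert_normalized(torch.device("cuda"))   # would raise: index is None
```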
