backends/cuda/cuda_backend.py (14 changes: 1 addition & 13 deletions)
@@ -117,18 +117,6 @@ def preprocess(
             if node.op == "placeholder" and node.name in user_input_names:
                 user_input_placeholders.append(node.meta["val"])
 
-        # Create pseudo user inputs using torch.randn and metadata from input placeholders
-        faked_user_inputs = []
-        for placeholder in user_input_placeholders:
-            if isinstance(placeholder, torch.Tensor):
-                # Generate fake input with same shape and dtype, on CUDA
-                fake_input = torch.randn(
-                    placeholder.shape, dtype=placeholder.dtype, device="cuda"
-                )
-                faked_user_inputs.append(fake_input)
-
-        faked_user_inputs = tuple(faked_user_inputs)
-
         options: dict[str, typing.Any] = {
             # Embed CUDA kernel binaries directly into the compiled shared object
             "aot_inductor.embed_kernel_binary": True,
@@ -145,7 +133,7 @@ def preprocess(
         }
 
         with collect_unsupported_fallback_kernels():
-            so_path = torch._inductor.aot_compile(edge_program_module, faked_user_inputs, options=options)  # type: ignore[arg-type]
+            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
             if len(missing_fallback_kernels) > 0:
                 formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
                 raise RuntimeError(
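
After this change, `preprocess` no longer materializes real CUDA tensors with `torch.randn`; the FakeTensors already recorded in each placeholder's `node.meta["val"]` are passed straight to `torch._inductor.aot_compile` as the example inputs. The sketch below illustrates that flow in isolation; the helper name `compile_with_placeholder_metadata` and its exact signature are illustrative assumptions, not part of the PR, while the metadata gathering and the `aot_compile` call mirror the diff above.

```python
# Minimal sketch, assuming a GraphModule whose placeholder nodes carry FakeTensor
# metadata in node.meta["val"] (the helper name and signature are hypothetical).
import typing

import torch


def compile_with_placeholder_metadata(
    edge_program_module: torch.fx.GraphModule,
    user_input_names: set[str],
    options: dict[str, typing.Any],
) -> str:
    """Hand the placeholders' FakeTensor metadata to AOT Inductor as example inputs."""
    user_input_placeholders = []
    for node in edge_program_module.graph.nodes:
        if node.op == "placeholder" and node.name in user_input_names:
            # node.meta["val"] is a FakeTensor carrying shape, dtype, and device,
            # so no real CUDA tensor needs to be allocated with torch.randn.
            user_input_placeholders.append(node.meta["val"])

    # Returns the path of the compiled shared object, as in the diff above.
    return torch._inductor.aot_compile(  # type: ignore[arg-type]
        edge_program_module,
        tuple(user_input_placeholders),
        options=options,
    )
```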