Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] Fix shape mismatch in sdpa pattern matcher #115038

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 49 additions & 0 deletions test/inductor/test_cpu_repro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2463,6 +2463,55 @@ def fn(x1, x2):
)
self.assertEqual(metrics.generated_kernel_count, 1)

def test_attention_size_mismatch(self):
class Attention(torch.nn.Module):
def __init__(self, hidden_size, num_heads):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_size = hidden_size // num_heads
self.query = torch.nn.Linear(hidden_size, hidden_size)
self.key = torch.nn.Linear(hidden_size, hidden_size)
self.value = torch.nn.Linear(hidden_size, hidden_size)
self.inv_scale = torch.nn.Parameter(
torch.Tensor([1 / self.head_size**0.5]), requires_grad=False
)

def forward(self, x):
query = self.query(x)
key = self.key(x)
value = self.value(x)
(batch_size, seq_len, hidden_size) = query.size()
query = query.view(
batch_size, seq_len, self.num_heads, self.head_size
).permute(0, 2, 1, 3)
key = key.view(
batch_size, seq_len, self.num_heads, self.head_size
).permute(0, 2, 3, 1)
value = value.view(
batch_size, seq_len, self.num_heads, self.head_size
).permute(0, 2, 1, 3)
attention_weights = (
torch.matmul(query, key).div(self.inv_scale).softmax(dim=-1)
)
output = torch.matmul(attention_weights, value)
return output

torch.manual_seed(123)
hidden_size = 16
num_heads = 1
seq_len = 4
batch_size = 1
x = torch.randn(batch_size, seq_len, hidden_size)

func = Attention(hidden_size, num_heads).to("cpu")

with torch.no_grad():
res1 = func(x)
jit_func = torch.compile(func)
res2 = jit_func(x)
self.assertEqual(res1, res2)

def test_scalar_mul_bfloat16(self):
def f(x):
return torch.ops.aten.mul.Tensor(x, 1.7015043497085571)
Expand Down
10 changes: 9 additions & 1 deletion torch/_inductor/pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,15 @@ def check_fn(match: Match):
device=args[i].device,
requires_grad=grad,
)
specific_graph = trace_fn(search_fn, args)
try:
specific_graph = trace_fn(search_fn, args)
except RuntimeError as e:
log.info(
"Replacement pattern %s failed to apply due to shape mismatch: %s",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is shape mismatch the only exception that could be thrown here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be, since we already passed the shape-agnostic match.

search_fn.__name__,
e,
)
return False
specific_pattern = fx_to_pattern(
specific_graph,
argnames=argnames,
Expand Down