Skip to content

Conversation

yf225
Copy link
Contributor

@yf225 yf225 commented Oct 3, 2025

Simplify A[tile_k.begin * 2 : tile_k.begin * 2 + tile_k.block_size * 2] to A[0:tile_k.block_size * 2], so that the tile_k.block_size symbol will be reused in the output shape to avoid downstream symbol mismatch errors.

Fixes #753.

@yf225 yf225 requested review from jansel and oulgen October 3, 2025 23:34
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 3, 2025
@yf225 yf225 force-pushed the int4_gemm_symbol_error branch from 9ec21d9 to 24a04ce Compare October 3, 2025 23:37
@yf225 yf225 marked this pull request as draft October 4, 2025 00:04
@yf225 yf225 force-pushed the int4_gemm_symbol_error branch 6 times, most recently from 017543b to a8ac5c0 Compare October 7, 2025 17:42
@yf225 yf225 force-pushed the int4_gemm_symbol_error branch 9 times, most recently from 601ebd4 to 4ac6c7d Compare October 7, 2025 19:22
self.assertExpectedJournal(code)

expected = y_true[:, : y_pred.size(0)].sum() / y_pred.size(0)
expected = y_true[:, :].sum() / y_pred.size(0)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original expected value is buggy - to match the intended kernel behavior, it should have been a sum on y_true[:, :] not y_true[:, : y_pred.size(0)].

@yf225 yf225 marked this pull request as ready for review October 7, 2025 19:25
@yf225 yf225 changed the title Apply simplification to range indexing in order to reuse symbols Apply simplification to range indexing in order to reuse block size symbols Oct 7, 2025
@yf225 yf225 mentioned this pull request Oct 7, 2025
@yf225 yf225 force-pushed the int4_gemm_symbol_error branch from 4ac6c7d to 893eee9 Compare October 7, 2025 19:45
@yf225 yf225 force-pushed the int4_gemm_symbol_error branch from 893eee9 to 7ecaeb0 Compare October 7, 2025 19:46
@yf225 yf225 merged commit ebbd2c4 into main Oct 7, 2025
13 checks passed
@yf225 yf225 deleted the int4_gemm_symbol_error branch October 7, 2025 20:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Mismatch reduction dim in inner loop using torch.mm

3 participants