Conversation

@Sibylau (Contributor) commented Sep 19, 2025

meta-cla bot added the CLA Signed label Sep 19, 2025
Sibylau added a commit that referenced this pull request Sep 19, 2025
stack-info: PR: #635, branch: Sibylau/stack/3
@Sibylau (Contributor, Author) commented Sep 19, 2025

  • code generation test passes with python -m unittest test_examples.TestExamples.test_gather_gemv

@yf225 (Contributor) left a comment

Thanks @Sibylau! Left some nit comments; also, this might need a rebase to fix conflicts with the main branch.


def baseline_gather_gemv(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
    """PyTorch baseline implementation."""
    # A hard-wired fix for tritonbench baseline: w[idx].to(x.dtype) @ x

yf225 (review comment):
maybe can remove this comment

    outputs = []  # assumed initialization, elided from the original excerpt
    for idx_val in idx.tolist():
        outputs.append(w[idx_val].to(x.dtype) @ x)
    return torch.stack(outputs, dim=0)
    # return torch.stack([w[idx[0]].to(x.dtype) @ x, w[idx[1]].to(x.dtype) @ x])

yf225 (review comment):
maybe remove?
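
For context, a self-contained sketch of what this baseline computes; the shapes below are illustrative assumptions, not taken from the PR:

import torch
from torch import Tensor

def gather_gemv_reference(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
    # Vectorized form of the loop above: w[idx] gathers the selected
    # matrices into a [len(idx), n, k] batch, and the matmul broadcasts
    # the vector x over the leading dimension.
    return w[idx].to(x.dtype) @ x

# Hypothetical shapes for illustration:
w = torch.randn(8, 16, 32)              # 8 candidate [16, 32] matrices
idx = torch.tensor([2, 5])              # gather two of them
x = torch.randn(32)
out = gather_gemv_reference(w, idx, x)  # shape [2, 16]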

    args,
    expected(*args),
    fn_name="gather_gemv",
    block_sizes=[64, 64],

yf225 (review comment):
Not exactly sure why the AMD job fails, but I suspect changing the block sizes to some smaller values might help.
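
For illustration only, the suggested tweak might look like this; check_example is a stand-in for whatever helper the real test wraps around these arguments, so treat this as a sketch rather than the actual test code:

# Hypothetical: the same check with smaller block sizes, per the suggestion above.
self.assertExpectedJournal(
    check_example(
        "gather_gemv",
        args,
        expected(*args),
        fn_name="gather_gemv",
        block_sizes=[32, 32],  # smaller than the original [64, 64]
    )
)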

@Sibylau (Contributor, Author) commented Sep 19, 2025

@yf225 The CI test on ROCm fails due to a generated-code mismatch:

-         dot = tl.dot(tl.cast(gathered, tl.float32), tl.cast(load_1, tl.float32), input_precision='tf32', out_dtype=tl.float32)
+         dot = tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.dot(tl.reshape(tl.permute(tl.join(tl.cast(gathered, tl.float32), 

Do you know why the generated code for AMD is different? Can I put a @skipIfRocm on this kernel test?

    )
)

@skipIfRocm("failure on rocm")

yf225 (review comment):
Instead of using skipIfRocm, which skips the whole test including the output-equality check, maybe we can add a skip_rocm: bool arg to assertExpectedJournal that skips only the journal check when the device is ROCm. To check whether the device is ROCm, we can add something like this to _testing.py:

import triton

def is_rocm() -> bool:
    """Return True if running on ROCm (AMD GPU)."""
    # DEVICE is assumed to be the module-level device constant in _testing.py.
    return (
        triton.runtime.driver.active.get_current_target().backend == "hip"
        and DEVICE.type == "cuda"
    )

(Please feel free to do this in a follow-up PR. Thanks!)
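
A minimal sketch of how that could look, assuming assertExpectedJournal compares generated code against a stored expected journal (the lookup helper below is hypothetical, not the real _testing.py API):

def assertExpectedJournal(self, value: str, skip_rocm: bool = False) -> None:
    """Compare generated code against the stored expected output."""
    if skip_rocm and is_rocm():
        # Generated Triton code differs on ROCm, so skip only the journal
        # comparison; the test's output-equality checks still run.
        return
    expected = self._lookup_expected_journal()  # hypothetical lookup helper
    self.assertEqual(value, expected)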

Sibylau added a commit that referenced this pull request
stack-info: PR: #635, branch: Sibylau/stack/3
@Sibylau merged commit 7d09e0a into main Sep 23, 2025
13 checks passed