
Fix double broadcast with tch #1026

Merged — 3 commits merged into main on Dec 1, 2023

Conversation

nathanielsimard (Member):

@dcvz Can you validate it fixes the problem you have with diffuser-rs?

dcvz (Contributor) commented Nov 30, 2023:

@nathanielsimard I can confirm that the tests that were failing in my project, pass now 🚀

@@ -190,17 +190,30 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
lhs: TchTensor<i64, D>,
rhs: TchTensor<i64, D>,
) -> TchTensor<i64, D> {
-        TchOps::div(lhs, rhs)
+        let copy = false;
dcvz (Contributor):
What does this actually do? It seems like you're just forcing the float div to be used. I took out all the changes in this file, and the failing tests on my repo (and here) still pass.

nathanielsimard (Member, Author):
Yeah, somehow I had to update this for the CI: I cast the i64 tensor to an f64 tensor without copying the data, since there seem to be problems with integer division in tch.
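The idea behind the workaround can be illustrated element-wise in plain Rust. This is a hypothetical sketch, not the actual tch call (burn reinterprets the tensor's dtype without copying, which a slice-level example cannot show); `int_div_via_float` is an illustrative helper name:

```rust
// Hypothetical illustration of the workaround: instead of dividing i64
// values directly, cast to f64, divide, and cast back. The final `as i64`
// truncates the quotient toward zero, e.g. 7 / 2 -> 3, -7 / 2 -> -3.
fn int_div_via_float(lhs: &[i64], rhs: &[i64]) -> Vec<i64> {
    lhs.iter()
        .zip(rhs)
        .map(|(&a, &b)| (a as f64 / b as f64) as i64)
        .collect()
}

fn main() {
    let out = int_div_via_float(&[7, -7, 9], &[2, 2, 3]);
    println!("{:?}", out); // [3, -3, 3]
}
```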

@@ -30,6 +30,17 @@ mod tests {
assert_eq!(data_expected, data_actual);
}

#[test]
fn test_mul_broadcast_2_dims() {
dcvz (Contributor):
Do we care about adding a test for Int?

Member:
I don't think it's necessary; even if tch supports it, in Burn we do not expose int matmul in our API.

nathanielsimard (Member, Author):
This is a normal mul that is available on the int tensor, but theoretically this test validates the broadcasting logic that is present for all binary operations.
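The broadcasting logic in question follows the usual NumPy/PyTorch rule: align shapes at the trailing dimension, treat missing dimensions as size 1, and require each pair of sizes to either match or have one side equal to 1. A minimal, self-contained sketch of that shape rule (`broadcast_shape` is a hypothetical helper, not Burn's actual implementation):

```rust
// Hypothetical sketch of NumPy/PyTorch-style broadcast shape resolution,
// the rule the binary-op broadcasting tests exercise. Not Burn's actual code.
fn broadcast_shape(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    let ndim = lhs.len().max(rhs.len());
    let mut out = vec![0; ndim];
    for i in 0..ndim {
        // Align at the trailing dimension; missing leading dims count as 1.
        let l = if i + lhs.len() >= ndim { lhs[i + lhs.len() - ndim] } else { 1 };
        let r = if i + rhs.len() >= ndim { rhs[i + rhs.len() - ndim] } else { 1 };
        out[i] = match (l, r) {
            (a, b) if a == b => a, // equal sizes: keep
            (1, b) => b,           // size-1 side stretches to the other
            (a, 1) => a,
            _ => return None,      // incompatible sizes: no broadcast
        };
    }
    Some(out)
}

fn main() {
    // A [2, 1, 4] tensor combined with a [3, 1] tensor broadcasts to [2, 3, 4].
    println!("{:?}", broadcast_shape(&[2, 1, 4], &[3, 1])); // Some([2, 3, 4])
    println!("{:?}", broadcast_shape(&[2, 3], &[4, 3]));    // None
}
```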

louisfd (Member) left a comment:
LGTM

@louisfd louisfd merged commit b0de56d into main Dec 1, 2023
10 checks passed
@nathanielsimard nathanielsimard deleted the fix/double-broadcast-tch branch December 12, 2023 14:38