Problem with where function #28

Open
zarif98sjs opened this issue Mar 16, 2024 · 4 comments

@zarif98sjs

zarif98sjs commented Mar 16, 2024

Shouldn't the where function be this?

import torch

def where(q, a, b):
    "Use this function to replace an if-statement."
    # logical_not is correct for bool and integer q; ~q is bitwise NOT on ints.
    return (q * a) + torch.logical_not(q) * b

Otherwise, if we negate with ~q as the current version does, isn't the output technically wrong for the function's intended behavior?

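If we use ~q, where(arange(4) * 0, 0, 1) returns tensor([-1, -1, -1, -1]): arange(4) * 0 is an integer tensor of zeros, and bitwise NOT of the integer 0 is -1, so the second term becomes -1 * 1.
But the desired output is tensor([1, 1, 1, 1]).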
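A minimal sketch reproducing the difference, assuming PyTorch is installed. `where_bitwise` is the ~q variant and `where_logical` is the torch.logical_not variant proposed above; both names are hypothetical and used only for this comparison:

```python
import torch

# Hypothetical names for the two variants discussed in this issue.
def where_bitwise(q, a, b):
    # ~q is bitwise NOT: logical on bool tensors, but ~0 == -1 on integer tensors.
    return (q * a) + (~q) * b

def where_logical(q, a, b):
    # torch.logical_not returns a bool tensor that is True exactly where q == 0.
    return (q * a) + torch.logical_not(q) * b

zeros = torch.arange(4) * 0                       # int64 tensor of zeros
print(where_bitwise(zeros, 0, 1))                 # tensor([-1, -1, -1, -1])
print(where_logical(zeros, 0, 1))                 # tensor([1, 1, 1, 1])

mask = torch.tensor([1, 0, 1, 0])                 # 0/1 integer mask
print(where_bitwise(mask, 5, 7))                  # tensor([-9, -7, -9, -7])
print(where_logical(mask, 5, 7))                  # tensor([5, 7, 5, 7])

cond = torch.tensor([True, False, True, False])   # bool condition
print(where_bitwise(cond, 5, 7))                  # tensor([5, 7, 5, 7])
print(where_logical(cond, 5, 7))                  # tensor([5, 7, 5, 7])
```

With a bool condition both variants agree, so the bug only shows up for integer q. Both variants still assume q holds only 0s and 1s (or booleans): an entry of 2 would contribute 2 * a rather than a.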

@shunzh

shunzh commented Apr 22, 2024

I agree. ~ is bitwise NOT, so the behavior is unexpected when q is an integer tensor rather than a boolean one.
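Plain Python integers show the same distinction without PyTorch. In this minimal sketch (the helper name `negations` is hypothetical), `~x` is bitwise NOT, which in two's complement equals `-x - 1`, while `not x` is logical negation and is True only for 0:

```python
def negations(x):
    """Return (bitwise NOT, logical NOT) of a Python int."""
    # ~x flips every bit of the two's-complement representation: ~x == -x - 1
    # not x is logical negation: True only when x == 0
    return ~x, not x

for x in [0, 1, 2, -1]:
    print(x, negations(x))
# 0 (-1, True)
# 1 (-2, False)
# 2 (-3, False)
# -1 (0, False)
```

PyTorch's ~ follows the same rule on integer tensors, but on bool tensors it acts as logical NOT, which is why the ~q version works as intended when q is boolean.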

@srush
Owner

srush commented Apr 23, 2024

Oops, will fix if I do a new version.

@zarif98sjs
Copy link
Author

Ah, nice! When creating the issue, I was wondering why nobody had noticed all these years 😅 I can send a PR if you want.

@srush
Owner

srush commented Apr 23, 2024

No, I should just do a v2 this summer. Lots of small fixes abound.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants