
Output of batch doesn't contain the output of a subset of a batch #31919

Closed

praateekmahajan opened this issue Jan 7, 2020 · 5 comments
@praateekmahajan commented Jan 7, 2020

Questions and Help

Might be an obvious question to some, but I am really struggling to understand what could go wrong.

If we pass a batch of inputs to a function, shouldn't the output for the whole batch also contain the output for any subset of that batch?

Programmatically, what I mean is:

f(batch)[:k] == f(batch[:k])

Shouldn't this always be true?

I tried to reproduce the issue with f being a Linear layer without bias (and also the barebone matmul operation).

I can imagine someone saying that it's not safe to compare floats.

But then how come this is always true:

f(batch)[:batch] == f(batch[:batch])

I wrote a function to reproduce the error:

import torch
import torch.nn as nn

def check(batch, input_dim, output_dim, topk=1):
  """
  Check whether the output of f(batch)[:topk] is different from f(batch[:topk]).

  This function tries to check that for different input_dim and output_dim values.

  f here is a linear layer (or the more barebone version, i.e. a matmul).
  """

  # input of shape (batch, input_dim)
  x = torch.randn(batch, input_dim)
  # weight matrix of shape (output_dim, input_dim), used for the matmul that replicates the linear layer
  w = torch.randn(output_dim, input_dim)
  # linear layer that takes an input of input_dim and spits out output_dim (bias set to False for simplicity)
  f = nn.Linear(input_dim, output_dim, bias=False)

  # boolean comparing the output of the linear layer when slicing before vs. after the call
  out1 = (f(x)[:topk] == f(x[:topk])).all()
  # boolean comparing the output of the matmul in the same way
  out2 = (x.matmul(w.t())[:topk] == x[:topk].matmul(w.t())).all()

  # both the linear and the matmul comparison should hold
  return (out1 & out2).item()

check(batch=300, input_dim=3, output_dim=100, topk=300) # True
check(batch=300, input_dim=3, output_dim=100, topk=1)   # False
@SsnL (Collaborator) commented Jan 7, 2020

You should check that the difference is within an eps; precision is a real issue.
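
For instance, a tolerance-based version of the check could look like this (a minimal sketch; check_close is a hypothetical helper that reuses the setup of the check() function above, and torch.allclose is used with its default tolerances):

import torch
import torch.nn as nn

def check_close(batch, input_dim, output_dim, topk=1):
  # same setup as check() above, but compare within a tolerance instead of bitwise
  x = torch.randn(batch, input_dim)
  f = nn.Linear(input_dim, output_dim, bias=False)
  # torch.allclose uses rtol=1e-05 and atol=1e-08 by default
  return torch.allclose(f(x)[:topk], f(x[:topk]))

check_close(batch=300, input_dim=3, output_dim=100, topk=1)  # expected True if the mismatch is only rounding noise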

@praateekmahajan (Author) commented Jan 7, 2020

@SsnL are you saying the issue is due to floating point precision? I am wondering how come this is always true then:

f(batch)[:batch] == f(batch[:batch])
@ptrblck (Contributor) commented Jan 7, 2020

@praateekmahajan
The first call might get bitwise accurate results, since the indexing might be a no-op (it slices all elements), but I'm not sure about it.

Also, the general assumption that

f(batch)[:k] == f(batch[:k])

is always true is not correct. Take batch norm layers as an example: since the normalization (during training or with track_running_stats=False) depends on the batch statistics, you won't get the same results. To compare floating point results, you could use torch.allclose() instead of ==.
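
A minimal sketch of the batch norm case (assuming nn.BatchNorm1d in its default training mode):

import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(4)   # training mode by default: normalizes with the current batch statistics
x = torch.randn(8, 4)

full = bn(x)[:2]         # normalized with the mean/var of all 8 samples
sub = bn(x[:2])          # normalized with the mean/var of only 2 samples

print(torch.allclose(full, sub))  # typically False: the two calls use different batch statistics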

@praateekmahajan (Author) commented Jan 7, 2020

Thanks @ptrblck, that's a good point about batch norm layers. Didn't think about them.

Not sure what you mean by

since the indexing might be a no-op (it slices all elements)

Thank you also for suggesting the torch.allclose() function.

It's still not entirely clear to me why we get precision errors when comparing a set with a subset of another set, but no errors when comparing two complete sets.

It looks like comparing two complete sets should have a higher chance of a precision error, because there are more floats involved, i.e. a higher chance of encountering a floating point precision mismatch.
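
One way to see how small the actual mismatch is (a minimal sketch reusing the shapes from check() above; the exact numbers will vary from run to run):

import torch
import torch.nn as nn

torch.manual_seed(0)

x = torch.randn(300, 3)
f = nn.Linear(3, 100, bias=False)

full = f(x)[:1]
sub = f(x[:1])

print((full == sub).all().item())        # may be False: bitwise equality is too strict
print(torch.allclose(full, sub))         # expected True: results agree within the default tolerances
print((full - sub).abs().max().item())   # tiny difference, on the order of float32 eps (~1e-7)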

@zou3519 (Contributor) commented Jan 9, 2020

In the future, please ask questions on our forums at discuss.pytorch.org. We'd like to keep the issue tracker for bug reports and feature requests.

@zou3519 closed this Jan 9, 2020