
PyTorch DataLoading issue (with Equinox ?) #248

Closed
pablo2909 opened this issue Dec 19, 2022 · 3 comments
Labels
question User queries

Comments

@pablo2909

Hello,

I encountered an issue (and a fix) with loading data through a PyTorch DataLoader when used with JAX (and, I think, Equinox). I am not sure this belongs exactly here, so please feel free to tell me and I will move it somewhere else. I also mention #137, since I feel this is related.

The setup:

I train a small MLP on a classification task on MNIST. I use the data loader given in the JAX documentation (https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html), notably the NumpyLoader and the associated collate function. When I run the training script I get incoherent losses and accuracies; if I use the standard PyTorch DataLoader instead, I do not face the issue.
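For context, the relevant part of that notebook is a DataLoader whose collate function stacks samples into numpy arrays rather than torch tensors. A simplified sketch of the idea (paraphrased here, not copied verbatim from the notebook) looks roughly like this:

```python
import numpy as np
from torch.utils import data

def numpy_collate(batch):
    # Recursively stack samples into numpy arrays instead of torch tensors.
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        return [numpy_collate(samples) for samples in zip(*batch)]
    else:
        return np.array(batch)

class NumpyLoader(data.DataLoader):
    # A DataLoader that yields numpy arrays by swapping in the collate above.
    def __init__(self, dataset, batch_size=1, shuffle=False, num_workers=0):
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=numpy_collate,
        )
```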

I link to two files (one failing and one passing). If anyone has an idea why it does not work, I would love to know. I hope this can also help others, since I have been looking into it for two days now.

Thank you for any input

Link to files: https://gist.github.com/pablo2909/3a2cec869a43421859520750990f263e

@patrick-kidger patrick-kidger added the question User queries label Dec 20, 2022
@pablo2909
Author

Update:

I think I can pinpoint the location of the problem a bit more accurately. I replaced the PyTorch DataLoader and Dataset with custom classes, to make sure that no PyTorch mechanism was interfering with JAX/Equinox during training. The code is very simple and is provided in the link below (see also the sketch after the list that follows).
Additionally, I provide a second file that trains an MLP on MNIST. The training will:

  • fail if we extract the data from the PyTorch MNIST dataset (uncomment lines 57-59)
  • pass if we extract the data from the PyTorch DataLoader (uncomment lines 55-56 and comment out lines 57-59)

Note that even though I extract the data from the PyTorch Dataset/DataLoader, I still train using my custom Dataset/DataLoader.
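For reference, a purely illustrative sketch of what such custom classes can look like (the actual code used is in the gist linked below):

```python
import numpy as np

class ArrayDataset:
    # Minimal dataset over in-memory numpy arrays; no PyTorch involved.
    def __init__(self, images, labels):
        self.images = np.asarray(images)
        self.labels = np.asarray(labels)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

class SimpleLoader:
    # Minimal loader that shuffles indices and yields numpy batches.
    def __init__(self, dataset, batch_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        indices = np.arange(len(self.dataset))
        if self.shuffle:
            self.rng.shuffle(indices)
        for start in range(0, len(indices), self.batch_size):
            batch_idx = indices[start:start + self.batch_size]
            images, labels = zip(*(self.dataset[i] for i in batch_idx))
            yield np.stack(images), np.asarray(labels)
```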

https://gist.github.com/pablo2909/91127b9c7cb441b3b897bbebd9c0eff1

Thank you for any input

PS: Apologies for the length of the messages and code.

@jatentaki
Contributor

I suggest against a numpy collate function in general; instead, use `tree_map(lambda tensor: tensor.numpy(), batch)` on the batch returned by a standard DataLoader. The reason is that torch tensors get special treatment when passed between multiple processes, whereas numpy arrays get the standard serialize/deserialize treatment, resulting in a big performance hit last I checked. Maybe this also makes your bug go away?
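Sketching that suggestion out (the dataset/loader setup here is just for illustration and assumes torchvision is available): keep PyTorch's default collation so that torch tensors cross the worker/main process boundary via their special shared-memory path, and only convert to numpy once the batch arrives in the main process.

```python
import jax
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_dataset = datasets.MNIST(
    "data", train=True, download=True, transform=transforms.ToTensor()
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)

for batch in train_loader:
    # Convert every torch tensor leaf of the (images, labels) batch to a
    # numpy array; JAX/Equinox code can then consume the result directly.
    images, labels = jax.tree_util.tree_map(lambda tensor: tensor.numpy(), batch)
    # ... training step with the JAX/Equinox model goes here ...
```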

@pablo2909
Copy link
Author

Sorry, I failed to reply to this and close the issue. I can't recall exactly what the issue was, but I ended up doing that and it fixed it.
