I encountered an issue (and a fix) when loading data with a PyTorch DataLoader used together with JAX (and, I think, Equinox). I am not sure this belongs exactly here, so please feel free to tell me and I will move it somewhere else. I also mention #137, since I feel it is related.
The setup:
I train a small MLP on a classification task on MNIST. I use the data loader given in the JAX documentation (https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html), notably the NumpyLoader and the associated collate function. When I run the training script I get incoherent losses and accuracies. If I use the standard PyTorch DataLoader instead, I do not face the issue.
I link two files (one failing and one passing). If anyone has an idea why it does not work, I would love to know. I hope this can also help others, since I have been looking into it for two days now.
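For reference, the collate function from that notebook (reproduced here from memory, so double-check against the link above) recursively stacks a batch into NumPy arrays; `NumpyLoader` is just a `DataLoader` configured with `collate_fn=numpy_collate`:

```python
import numpy as np

def numpy_collate(batch):
    # Recursively stack a batch of numpy arrays (or nests of them),
    # so the loader yields numpy arrays instead of torch tensors.
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

# In the notebook, NumpyLoader subclasses torch.utils.data.DataLoader
# and simply forwards collate_fn=numpy_collate to the parent constructor.
```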
I think I can pinpoint the location of the problem a bit more accurately. I replaced the PyTorch DataLoader and Dataset with custom classes, to make sure that no PyTorch mechanism was interfering with JAX/Equinox during training. The code is very simple and provided in the link below.
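For concreteness, a minimal custom Dataset/DataLoader pair along these lines might look like the following (names and details are hypothetical; the actual code is in the gist linked in this thread):

```python
import numpy as np

class SimpleDataset:
    # Minimal stand-in for torch.utils.data.Dataset.
    def __init__(self, xs, ys):
        self.xs, self.ys = xs, ys

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, i):
        return self.xs[i], self.ys[i]

class SimpleLoader:
    # Minimal single-process loader: shuffles indices each epoch and
    # yields numpy batches, with no PyTorch machinery involved.
    def __init__(self, dataset, batch_size, *, seed=0):
        self.dataset, self.batch_size = dataset, batch_size
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        idx = self.rng.permutation(len(self.dataset))
        for start in range(0, len(idx), self.batch_size):
            batch = [self.dataset[i] for i in idx[start:start + self.batch_size]]
            xs, ys = zip(*batch)
            yield np.stack(xs), np.stack(ys)
```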
Additionally, I provide a second file that trains an MLP on MNIST. The training will:
fail if we extract the data from the PyTorch MNIST dataset (uncomment lines 57-59)
pass if we extract the data from the PyTorch DataLoader (uncomment lines 55-56, comment out lines 57-59)
Note that even though I extract the data from PyTorch Dataset/DataLoader I still train using my custom Dataset/DataLoader.
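I cannot tell from the description alone, but one common source of exactly this mismatch is that reading a torchvision dataset's raw storage (e.g. `mnist.data`) bypasses its `transform`, yielding uint8 pixels in [0, 255], while iterating the DataLoader applies the transform (typically producing float32 in [0, 1]). A stand-in illustration, with random data in place of MNIST:

```python
import numpy as np

# Stand-in for mnist.data: raw uint8 images, no transform applied.
raw = np.random.randint(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Stand-in for what the DataLoader yields after a typical
# ToTensor-style transform: float32, scaled to [0, 1], flattened.
transformed = raw.astype(np.float32).reshape(4, -1) / 255.0

# An MLP trained on `raw` sees inputs up to 255x larger than the ones
# it sees via `transformed`, which is enough to produce incoherent
# losses even though the labels are identical.
```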
I suggest against the NumPy collate in general; instead use `tree_map(lambda tensor: tensor.numpy(), batch)` on the collated batch. The reason is that torch tensors get special treatment when passed between processes, whereas NumPy arrays get the standard serialize/deserialize treatment, resulting in a big performance hit last I checked. Maybe this also makes your bug go away?
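In other words: keep PyTorch's default collate so that, as I understand it, tensors cross worker-process boundaries via torch's shared-memory path, and convert to NumPy only once the batch arrives, e.g. with `jax.tree_util.tree_map(lambda t: t.numpy(), batch)`. A dependency-free sketch of the same idea, using a duck-typed check for anything exposing a `.numpy()` method:

```python
import numpy as np

def tensors_to_numpy(batch):
    # Hand-rolled tree map: walk lists/tuples/dicts and convert anything
    # with a .numpy() method (i.e. torch tensors) to a numpy array.
    # In practice you would just call
    # jax.tree_util.tree_map(lambda t: t.numpy(), batch).
    if hasattr(batch, "numpy"):
        return batch.numpy()
    if isinstance(batch, (list, tuple)):
        return type(batch)(tensors_to_numpy(x) for x in batch)
    if isinstance(batch, dict):
        return {k: tensors_to_numpy(v) for k, v in batch.items()}
    return batch
```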
Link to the files: https://gist.github.com/pablo2909/3a2cec869a43421859520750990f263e