
[FEATURE] Script to convert weight from Jax to PyTorch #1601

Closed
yazdanbakhsh opened this issue Dec 23, 2022 · 6 comments
Labels
enhancement New feature or request

Comments

@yazdanbakhsh

Is your feature request related to a problem? Please describe.
I am trying to create multiple checkpoints of ViT at different iterations. Is there a systematic way to perform such a conversion?

Describe the solution you'd like
I would like to be able to convert a JAX ViT model to a PyTorch model, similar to this model (https://huggingface.co/google/vit-base-patch16-224)

Describe alternatives you've considered
I have tried to start pre-training HF models on A100s, but so far have not been successful in reaching the same accuracy.

@yazdanbakhsh yazdanbakhsh added the enhancement New feature or request label Dec 23, 2022
@rwightman
Collaborator

rwightman commented Dec 23, 2022

@yazdanbakhsh loading jax .npz checkpoints is integrated into the model code for the original Google jax implementations (big vision support being merged today)
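
A minimal sketch of what that looks like from the user side (the checkpoint path is a placeholder; this assumes `create_model`'s `checkpoint_path` handling, which dispatches `.npz` files to the ViT-specific npz weight loader):

```python
import timm

# Build a ViT and load an original jax/big_vision .npz checkpoint directly;
# timm routes .npz checkpoint paths through the model's own npz weight loader.
model = timm.create_model(
    'vit_base_patch16_224',
    checkpoint_path='/path/to/jax_checkpoint.npz',  # placeholder path
)
model.eval()
```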

However, I don't have support for the Hugging Face ViT models. Usually people go from timm -> Transformers, not the other way around. I honestly wouldn't recommend pretraining from scratch on the Transformers model; I don't think it's been well tested beyond use with pretrained weights, and other people have reported the same issue. timm is better tested for training from random weights.

@yazdanbakhsh
Author

@rwightman Thanks for the explanation. Which script would you suggest starting with for distributed training of ViT using timm?

@rwightman
Collaborator

@yazdanbakhsh just dumped some hparams, but as usual, they need adapting to your scenario / specific network https://gist.github.com/rwightman/943c0fe59293b44024bbd2d5d23e6303

@yazdanbakhsh
Author

@rwightman Thanks again for providing the details. timm is much more robust for training vision models; with this training script I could finally reach 76% accuracy for ViT-B/16. I also made a change to include cutout as part of the randaug to be more consistent with the BigVision repo.
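
For reference, a rough sketch of one way to make that change, registering a cutout op with timm's RandAugment machinery (`NAME_TO_OP`, `LEVEL_TO_ARG`, and `_RAND_TRANSFORMS` are timm internals that may shift between versions, and the magnitude-to-pad-size mapping here is an assumption):

```python
import random
from PIL import ImageDraw
from timm.data import auto_augment

def cutout(img, pad_size, **kwargs):
    # Fill a square patch of side 2*pad_size, at a random center, with gray.
    w, h = img.size
    x, y = random.randint(0, w), random.randint(0, h)
    box = (max(0, x - pad_size), max(0, y - pad_size),
           min(w, x + pad_size), min(h, y + pad_size))
    img = img.copy()
    ImageDraw.Draw(img).rectangle(box, fill=(128, 128, 128))
    return img

def _cutout_level_to_arg(level, _hparams):
    # Map RandAugment's 0..10 magnitude scale to a pad size in pixels (40px max assumed).
    return (int((level / 10.) * 40),)

# Register the op so the rand augment op sampler can pick it up.
auto_augment.NAME_TO_OP['Cutout'] = cutout
auto_augment.LEVEL_TO_ARG['Cutout'] = _cutout_level_to_arg
auto_augment._RAND_TRANSFORMS.append('Cutout')
```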

Wondering if you have any other suggestions on which configs to change to reach an accuracy comparable to the HF model?

@yazdanbakhsh
Author

The obvious changes that I want to make are as follows (based on bigvision; a sketch of how these map onto timm's factories follows the list):

  1. learning rate: 0.001
  2. weight decay: 0.0001
  3. warmup: 15 epochs (which, at my 632 steps per epoch, comes to around 10K steps)
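
A minimal sketch of wiring those values into timm's optimizer and scheduler factories (assuming AdamW with cosine decay; `t_initial` and `warmup_lr_init` below are example values, not from the gist):

```python
import timm
from timm.optim import create_optimizer_v2
from timm.scheduler import CosineLRScheduler

model = timm.create_model('vit_base_patch16_224')
optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=1e-4)

# 15 warmup epochs at ~632 optimizer steps per epoch is roughly 9.5K steps.
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=300,        # total training epochs (example value)
    warmup_t=15,          # warmup epochs
    warmup_lr_init=1e-6,  # assumed initial warmup lr
)
```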

@yazdanbakhsh
Author

yazdanbakhsh commented Dec 30, 2022

@rwightman do you have the yaml configs for ResNet, MobileNet, and other models? Would it be possible to share them? I have also tried to use the new JSD loss, but it gives me an error.
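
For reference, timm's JSD loss is built around augmentation splits (à la AugMix) and expects the batch to be a concatenation of several differently-augmented views of the same images; running it on a plain batch is a common source of errors. A minimal sketch, assuming `JsdCrossEntropy` from `timm.loss`:

```python
import torch
from timm.loss import JsdCrossEntropy

num_classes, batch = 1000, 8
loss_fn = JsdCrossEntropy(num_splits=3, smoothing=0.1)

# The logits batch stacks 3 splits of the same 8 images along the batch dim;
# cross-entropy is computed on the first (clean) split, JSD across all three.
logits = torch.randn(batch * 3, num_classes)
target = torch.randint(0, num_classes, (batch,))
loss = loss_fn(logits, target)
```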
