Support Huggingface models from safetensors #1249
Description
🚀 The feature, motivation and pitch
There are many models on Huggingface that are published as safetensors rather than model.pth checkpoints. The request here is to support converting and loading those checkpoints into a format that is usable with torchchat.
There are several places where this limitation is currently enforced:
- The `_download_hf_snapshot` method explicitly ignores `safetensors` files.
- `convert_hf_checkpoint` explicitly looks for `pytorch_model.bin.index.json`, which would be named differently for models that use `safetensors` (e.g. `model.safetensors.index.json`).
- `convert_hf_checkpoint` only supports `torch.load` to load the `state_dict` rather than `safetensors.torch.load`.
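
To make the first limitation concrete, here is a minimal sketch of how the download step could choose which weight files to skip instead of always ignoring `safetensors`. It assumes the download goes through `huggingface_hub.snapshot_download` with its `ignore_patterns` argument. The helper names `choose_ignore_patterns` and `download_snapshot` are hypothetical, and the "prefer safetensors when both formats exist" policy is only one possible choice, not torchchat's current behavior.

```python
from typing import Iterable, List

# File suffixes treated as torch-pickle weights (an assumption for this sketch).
PICKLE_SUFFIXES = (".bin", ".pth", ".pt")


def choose_ignore_patterns(repo_files: Iterable[str]) -> List[str]:
    """Return glob patterns to skip so only one weight format is downloaded.

    Hypothetical policy: if the repo ships both safetensors and torch-pickle
    weights, skip the pickle files; otherwise skip nothing.
    """
    files = list(repo_files)
    has_safetensors = any(f.endswith(".safetensors") for f in files)
    has_pickle = any(f.endswith(PICKLE_SUFFIXES) for f in files)
    if has_safetensors and has_pickle:
        return ["*.bin", "*.bin.index.json", "*.pth", "*.pt"]
    return []


def download_snapshot(repo_id: str, local_dir: str) -> str:
    """Download a repo snapshot, skipping the redundant weight format."""
    # Imported lazily so choose_ignore_patterns works without huggingface_hub.
    from huggingface_hub import HfApi, snapshot_download

    repo_files = HfApi().list_repo_files(repo_id)
    return snapshot_download(
        repo_id,
        local_dir=local_dir,
        ignore_patterns=choose_ignore_patterns(repo_files),
    )
```

Listing the repo files before downloading is what lets the policy avoid fetching the same weights twice. The cost is one extra API call per model.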
Alternatives
Currently, the safetensors -> model.pth conversion can be done manually after downloading a model locally, so this could be solved with documentation instead of code.
Additional context
This issue is a piece of the puzzle for adding support for Granite Code 3b/8b, which use the llama architecture in transformers but take advantage of several pieces of the architecture that are not currently supported by torchchat. The work-in-progress for Granite Code can be found on my fork: https://github.com/gabe-l-hart/torchchat/tree/GraniteCodeSupport
RFC (Optional)
I have a working implementation to support safetensors during download and conversion that I plan to submit as a PR. The changes address the three points in code referenced above:
- Allow the download of `safetensors` files in `_download_hf_snapshot`
  - I'm not yet sure how to avoid double-downloading weights for models that have both `safetensors` and `model.pth`, so will look to solve this before concluding the work
- When looking for the tensor index file, search for all files ending in `.index.json`, and if a single file is found, use that one
- When loading the `state_dict`, use the correct method based on the type of file (`torch.load` or `safetensors.torch.load`)
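
The last two points above could look roughly like the sketch below. It is not the planned PR: `find_index_file` and `load_state_dict` are hypothetical helper names, and the loader is picked from the file suffix alone. The sketch uses `safetensors.torch.load_file`, which reads from a path, whereas `safetensors.torch.load` expects the file contents as bytes.

```python
from pathlib import Path
from typing import Optional


def find_index_file(model_dir: Path) -> Optional[Path]:
    """Return the only *.index.json file in model_dir, or None.

    None is returned both when no index exists and when several do,
    since the sketch has no rule for choosing between them.
    """
    candidates = sorted(model_dir.glob("*.index.json"))
    return candidates[0] if len(candidates) == 1 else None


def load_state_dict(path: Path):
    """Load a state_dict with the loader that matches the file type."""
    if path.suffix == ".safetensors":
        # Imported lazily so the index lookup works without safetensors.
        from safetensors.torch import load_file

        return load_file(str(path))

    # Anything else is assumed to be a torch-pickle checkpoint.
    import torch

    return torch.load(str(path), map_location="cpu", weights_only=True)
```

With a single index file present, the shard names it lists can each be passed to `load_state_dict`, so sharded `safetensors` and `pytorch_model` checkpoints follow the same path.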