This ticket is a twofold request: further enhance the recently added ResNet-Burn model, and come up with general requirements and a solution for the models added to the models repo.
Now that we are adding popular models to the burn-model repo, we should consider the end-user experience and come up with some basic top-level requirements for what is expected when a user adopts/uses a migrated model. This can evolve into a standard across other models.
Here is my proposal:
Each model should offer an automatic weights download from a known source, which can be overridden if needed. We should offer this both as a library and as a binary executable under the bin folder. The download destination can default to a cache location or be specified by the user.
If the source file is in a non-Burn format, we convert it once, and subsequent loading uses the Burn-native file.
(Optional) The converted file is uploaded to the HuggingFace portal under the Burn organization.
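To make the proposal concrete, here is a minimal sketch of the download-or-use-cache flow using only the Rust standard library. The URL, cache directory, file names, and helper functions are all hypothetical placeholders, not the actual implementation:

```rust
use std::fs;
use std::path::{Path, PathBuf};

// Hypothetical known source; a real model would pin a specific release URL.
const DEFAULT_WEIGHTS_URL: &str = "https://example.com/resnet18.pth";

/// Default cache location (~/.cache/resnet-burn), overridable by the caller.
fn default_cache_dir() -> PathBuf {
    let home = std::env::var("HOME").unwrap_or_else(|_| ".".into());
    Path::new(&home).join(".cache").join("resnet-burn")
}

/// Download and convert only when the Burn-native file is not cached yet.
fn needs_download(native_file: &Path) -> bool {
    !native_file.exists()
}

fn main() {
    let cache = default_cache_dir();
    let _ = fs::create_dir_all(&cache); // ensure the cache dir exists
    let native = cache.join("resnet18.mpk");
    if needs_download(&native) {
        // 1. download DEFAULT_WEIGHTS_URL (or a user-supplied override)
        // 2. convert the .pth checkpoint into Burn's native format
        // 3. write the converted file to `native` for subsequent loads
        println!("would download from {DEFAULT_WEIGHTS_URL}");
    } else {
        println!("using cached weights at {}", native.display());
    }
}
```

The same logic can back both the library API and the bin executable, with the source URL and destination exposed as optional parameters.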
Funny you mention that, I was just working on adding automatic loading of pre-trained weights to the ResNet models 😄 So great timing!
Since I haven't pushed any of my changes yet (PR should come soon), I'll summarize the way I am currently approaching this.
By default, the models support no_std. I've added a pretrained feature flag that requires std and adds optional dependencies: the burn-import crate to use the PyTorchFileRecorder, and burn/network (new since this PR) to use the download_file_as_bytes function with a download progress bar.
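A feature layout along these lines would match that description; the version numbers and exact feature wiring below are assumptions, not the actual Cargo.toml from the PR:

```toml
[features]
default = []
std = ["burn/std"]
# `pretrained` pulls in std-only dependencies for downloading and
# converting checkpoints; the base crate stays no_std-compatible.
pretrained = ["std", "dep:burn-import", "burn/network"]

[dependencies]
burn = { version = "0.13", default-features = false }
burn-import = { version = "0.13", optional = true }
```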
Regarding your specific points:
For storing the downloaded weights, I currently follow the default pattern I observed in burn: put them in the ~/.cache directory under the model name (e.g., ~/.cache/resnet-burn).
For loading the weights, I added resnet*_pretrained methods that do exactly what you described: download the .pth checkpoint and use the PyTorchFileRecorder to load it.
I haven't done anything in that regard yet, but we briefly talked about something like that with @nathanielsimard.