Pytorch config extraction #1323

Merged: antimora merged 7 commits into tracel-ai:main from the pytorch-config branch on Feb 20, 2024

antimora (Collaborator) commented Feb 18, 2024

This PR enhances the PyTorch module of burn-import so that it can extract a model's configuration. For example, the Whisper model stores its configuration inside the .pt file. Instead of using Python, one can use the deserializer in burn-import to populate a struct with the extracted values.

This is a common pattern in other export formats as well. For example, the Safetensors and GGUF file formats contain metadata (information about the model) that is critical for reconstructing a model.

This is how it would look with Burn:

use std::collections::HashMap;

use burn::config::Config;
use burn_import::pytorch::config_from_file;

#[derive(Debug, Config)]
struct NetConfig {
    n_head: usize,
    n_layer: usize,
    d_model: usize,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_float: f64,
    some_int: i32,
    some_bool: bool,
    some_str: String,
    some_list_int: Vec<i32>,
    some_list_str: Vec<String>,
    // Candle's pickle has a bug with float serialization
    // https://github.com/huggingface/candle/issues/1729
    // some_list_float: Vec<f64>,
    some_dict: HashMap<String, String>,
}

fn main() {
    let path = "weights_with_config.pt";
    let top_level_key = Some("my_config");
    let config: NetConfig = config_from_file(path, top_level_key).unwrap();
    println!("{:#?}", config);

    // After extracting, it's recommended that you save the config as a JSON file.
    config.save("my_config.json").unwrap();
}
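
To illustrate the point above that other formats also embed metadata, here is a minimal sketch of pulling the JSON header out of a Safetensors buffer. It relies on the published Safetensors layout (an 8-byte little-endian header length followed by that many bytes of UTF-8 JSON, which may include a `__metadata__` map of strings). The function name `safetensors_header` is hypothetical and is not part of burn-import or this PR; parsing the JSON itself is left to a JSON library.

```rust
/// Returns the raw JSON header of a Safetensors buffer.
/// Hypothetical helper for illustration only; not part of burn-import.
fn safetensors_header(bytes: &[u8]) -> Result<&str, String> {
    if bytes.len() < 8 {
        return Err("buffer is shorter than the 8-byte length prefix".to_string());
    }

    // The first 8 bytes hold the header length as a little-endian u64.
    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(&bytes[..8]);
    let len = u64::from_le_bytes(prefix) as usize;

    // The header follows the prefix; tensor data comes after it and is ignored here.
    let header = bytes[8..]
        .get(..len)
        .ok_or_else(|| "header length exceeds buffer size".to_string())?;

    std::str::from_utf8(header).map_err(|e| e.to_string())
}
```

Given the bytes of a file, `safetensors_header(&bytes)` yields the header text, from which a `__metadata__` entry could be deserialized into a config struct much like the `.pt` case above.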

Pull Request Template

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Changes

Enhanced the pytorch module to support config extraction.

Testing

  1. Added new unit test
  2. Tested on Whisper model

codecov bot commented Feb 18, 2024

Codecov Report

Attention: 29 lines in your changes are missing coverage. Please review.

Comparison is base (9df2071) 84.41% compared to head (af7b21f) 84.49%.
Report is 47 commits behind head on main.

Files                                Patch %   Lines
burn-core/src/record/serde/data.rs   4.54%     21 Missing ⚠️
burn-import/src/pytorch/config.rs    88.57%    8 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1323      +/-   ##
==========================================
+ Coverage   84.41%   84.49%   +0.08%     
==========================================
  Files         549      571      +22     
  Lines       61952    63818    +1866     
==========================================
+ Hits        52295    53923    +1628     
- Misses       9657     9895     +238     


antimora (Collaborator, Author) commented Feb 19, 2024

The floating-point serialization bug has been fixed by the Candle team: huggingface/candle#1729

We can come back and fix the test after the fix is released.

nathanielsimard (Member) left a comment

LGTM just a minor comment.

Review comment on burn-import/src/pytorch/config.rs (outdated, resolved)
@antimora antimora merged commit e9bb273 into tracel-ai:main Feb 20, 2024
13 of 14 checks passed
@antimora antimora deleted the pytorch-config branch February 21, 2024 21:11
Labels
enhancement Enhance existing features
2 participants