-
Notifications
You must be signed in to change notification settings - Fork 444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
State serialization/deserialization overhaul #247
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wow. That was a quick with lots of awesome features.
I found a minor fix in the readme doc. Basically additional updates.
[`bincode`](https://github.com/bincode-org/bincode) (for compactness) and included as part of the | ||
final wasm output. The MNIST model is initialized with trained weights from memory during the | ||
runtime. | ||
|
||
The inference API for JavaScript is exposed with the help of | ||
[`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen)'s library and tools. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Under the future improvements you can now remove #201 and #202
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably you can also remove references to the wasm file since it is smaller, so that we don't have to keep updating it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The state is not saved with f16 anymore, since it doesn't compile correctly without std, so I'll keep the link to issue 202 as potential improvement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably you can also remove references to the wasm file since it is smaller, so that we don't have to keep updating it.
I'm not sure what you are refering to.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I meant in the readme file, the document talks about the file size of the wasm output under comparison section (e.g. 1,509,747 bytes).
It's okay. I'll update the readme file not to talk about the file sizes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's the same size because it's still in f32... for now :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. I had a few questions regarding f16 changes. Please see if we need to modify the code more.
[`bincode`](https://github.com/bincode-org/bincode) (for compactness) and included as part of the | ||
final wasm output. The MNIST model is initialized with trained weights from memory during the | ||
runtime. | ||
|
||
The inference API for JavaScript is exposed with the help of | ||
[`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen)'s library and tools. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I meant in the readme file, the document talks about the file size of the wasm output under comparison section (e.g. 1,509,747 bytes).
It's okay. I'll update the readme file not to talk about the file sizes.
@@ -64,14 +65,21 @@ pub fn run<B: ADBackend>(device: B::Device) { | |||
.metric_valid_plot(AccuracyMetric::new()) | |||
.metric_train_plot(LossMetric::new()) | |||
.metric_valid_plot(LossMetric::new()) | |||
.with_file_checkpointer::<f32>(2) | |||
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it still f16, since you reverted back your changes about f16.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The checkpoint during training are f16 with compressed bin format (bin.gz) but we now save the final model somewhere else, we don't use the checkpoints.
@@ -78,17 +78,18 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>( | |||
.metric_valid(AccuracyMetric::new()) | |||
.metric_train_plot(LossMetric::new()) | |||
.metric_valid_plot(LossMetric::new()) | |||
.with_file_checkpointer::<f32>(2) | |||
.with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need to revert back this change or does it work because serialization works for std?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
serialization works with std
fix #202 I saved models with f16, but It's also possible to save weights with bf16. I don't expect much difference in performance for both.
fix #201 by default we save states in compressed bincode, which is extremelly small :)