
State serialization/deserialization overhaul #247

Merged 11 commits into main on Mar 23, 2023
Conversation

@nathanielsimard (Member) commented on Mar 22, 2023:

Fixes #202: I saved models with f16, but it's also possible to save weights with bf16. I don't expect much performance difference between the two.

Fixes #201: by default we save states in compressed bincode, which is extremely small :)

@nathanielsimard changed the title from "Feat/serde formats" to "State serialization/deserialization overhaul" on Mar 22, 2023.
@antimora (Collaborator) left a comment:

Wow, that was quick, with lots of awesome features!

I found a minor fix needed in the README doc, basically some additional updates.

> [`bincode`](https://github.com/bincode-org/bincode) (for compactness) and included as part of the
> final wasm output. The MNIST model is initialized with trained weights from memory during the
> runtime.
>
> The inference API for JavaScript is exposed with the help of
> [`wasm-bindgen`](https://github.com/rustwasm/wasm-bindgen)'s library and tools.

@antimora (Collaborator):

Under the future improvements section, you can now remove #201 and #202.

@antimora (Collaborator):

You can probably also remove references to the wasm file since it is smaller, so that we don't have to keep updating it.

@nathanielsimard (Member, Author):

The state is not saved with f16 anymore, since it doesn't compile correctly without std, so I'll keep the link to issue #202 as a potential improvement.

@nathanielsimard (Member, Author):

> Probably you can also remove references to the wasm file since it is smaller, so that we don't have to keep updating it.

I'm not sure what you are referring to.

@antimora (Collaborator):

Sorry, I meant in the README file: the document talks about the file size of the wasm output under the comparison section (e.g. 1,509,747 bytes).

It's okay; I'll update the README file not to talk about the file sizes.

@nathanielsimard (Member, Author):

It's the same size because it's still in f32... for now :)

@antimora (Collaborator) left a comment:

Looks good to me. I had a few questions regarding the f16 changes; please see if we need to modify the code further.


```diff
@@ -64,14 +65,21 @@ pub fn run<B: ADBackend>(device: B::Device) {
     .metric_valid_plot(AccuracyMetric::new())
     .metric_train_plot(LossMetric::new())
     .metric_valid_plot(LossMetric::new())
-    .with_file_checkpointer::<f32>(2)
+    .with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
```
@antimora (Collaborator):

Is it still f16, since you reverted your f16 changes?

@nathanielsimard (Member, Author):

The checkpoints during training are f16 with the compressed bin format (bin.gz), but we now save the final model somewhere else; we don't use the checkpoints.

```diff
@@ -78,17 +78,18 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
     .metric_valid(AccuracyMetric::new())
     .metric_train_plot(LossMetric::new())
     .metric_valid_plot(LossMetric::new())
-    .with_file_checkpointer::<f32>(2)
+    .with_file_checkpointer::<burn::tensor::f16>(2, StateFormat::default())
```
@antimora (Collaborator):

Do you need to revert this change, or does it work because serialization works with std?

@nathanielsimard (Member, Author):

Serialization works with std.

@nathanielsimard nathanielsimard merged commit 6f43d98 into main Mar 23, 2023
@nathanielsimard nathanielsimard deleted the feat/serde-formats branch March 23, 2023 15:02
2 participants