Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove *_devauto functions #1110

Merged
merged 3 commits into from Jan 6, 2024

Conversation

kpot
Copy link
Contributor

@kpot kpot commented Jan 4, 2024

Checklist

  • Confirmed that run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

This work is done as a part of #518, and implements this decision to remove all *_devauto functions that construct tensors without the need to explicitly specify a device.

Changes

All the tests, documentation examples etc. were updated.
I've also spotted and fixed a few remaining places that would still use a default device implicitly: one_hot, parameters in BatchNorm etc.

Sadly, one aspect I left untouched: de-serialization. It relies on this trait

/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
pub trait Record: Send + Sync {
    /// Type of the item that can be serialized and deserialized.
    type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;

    /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).
    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;

    /// Convert the given item into a record.
    fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self;
}

And currently I don't have a clear idea how to change it in a way that would make a de-serialization to an explicit device possible, while still remaining convenient for the #[derive(Record)] macro and non-tensor types like Vec.

So for now de-serialization is still going to be done on a default device first.

Testing

On my machine it passed all the tests I know of.

@kpot kpot mentioned this pull request Jan 4, 2024
@nathanielsimard
Copy link
Member

@kpot Thanks a lot! I'll do a proper review in the coming days. As for the serialization system, I think the best way will be to introduce a device somewhere in the trait, but this can be done in a following PR as well.

Copy link
Member

@nathanielsimard nathanielsimard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me!

Just doing this refactor fixes some problems that we have, in mask generation, for instance, where the default device needed to be the same as the one passed as an argument. I'm pretty happy about the change, way harder to make a mistake and pretty clear where you need to pay attention to the device. Thanks a lot for this work.

We will wait for the CI to be fixed before merging, @syl20bnr.

@syl20bnr
Copy link
Member

syl20bnr commented Jan 5, 2024

@kpot CI is fixed in main, you can rebase this PR.

Copy link

codecov bot commented Jan 6, 2024

Codecov Report

Attention: 13 lines in your changes are missing coverage. Please review.

Comparison is base (fab344c) 85.52% compared to head (b743424) 85.70%.

Files Patch % Lines
burn-core/src/optim/rmsprop.rs 86.84% 5 Missing ⚠️
burn-core/src/nn/loss/cross_entropy.rs 94.73% 2 Missing ⚠️
burn-core/src/record/tensor.rs 33.33% 2 Missing ⚠️
burn-tensor/src/tests/ops/chunk.rs 81.81% 2 Missing ⚠️
burn-import/src/burn/node/unary.rs 0.00% 1 Missing ⚠️
burn-tensor/src/tests/ops/iter_dim.rs 90.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1110      +/-   ##
==========================================
+ Coverage   85.52%   85.70%   +0.18%     
==========================================
  Files         511      511              
  Lines       55079    55825     +746     
==========================================
+ Hits        47106    47845     +739     
- Misses       7973     7980       +7     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@kpot
Copy link
Contributor Author

kpot commented Jan 6, 2024

@kpot CI is fixed in main, you can rebase this PR.

Done

@nathanielsimard nathanielsimard merged commit 9729753 into tracel-ai:main Jan 6, 2024
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants