Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Float8 types serializeable #114662

Closed
wants to merge 3 commits into from
Closed

Conversation

malfet
Copy link
Contributor

@malfet malfet commented Nov 28, 2023

By finally breaking FC promise on new dtypes by serializing untyped
storage and tensor dtypes

  • Add _rebuild_tensor_v3 that takes an extra dtype argument
  • In Tensor.__reduce_ex__ serialize tensor using untyped storage for
    v3_dtypes (which are at the moment limited to float8 dtypes)

Test plan: python -c "import torch;x=torch.arange(10).to(dtype=torch.float8_e4m3fn);torch.save(x, 'pt.pt');print(torch.load('pt.pt'))"

Fixes #114634

Copy link

pytorch-bot bot commented Nov 28, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114662

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 79e4b18 with merge base b6a30bb (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@malfet malfet requested a review from ngimel November 28, 2023 02:10
torch/_tensor.py Outdated
# need to wrap with TypedStorage
args = (
torch.storage.TypedStorage(
v3_dtypes = [torch.float8_e5m2, torch.float8_e4m3fn]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also uz dtypes?

Copy link
Collaborator

@ngimel ngimel Nov 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably makes sense to serialize all tensors whose dtype is not in the dict like this. No bc breaking, as those tensors couldn't have been serialized at all.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, so do the inverse. Makes sense, will do in follow up PR (and probably past the branch cut, as it feels a bit risky)

@malfet malfet added enhancement Not as big of a feature, but technically not a bug. Should be easy to fix release notes: python_frontend release notes category labels Nov 28, 2023
By finally breaking FC promise on new dtypes by serializing untyped
storage and tensor dtypes

- Add `_rebuild_tensor_v3` that takes an extra dtype argument
- In `Tensor.__reduce_ex__` serialize tensor using untyped storage for
  v3_dtypes (which are at the moment limited to float8 dtypes)

Test plan: `python -c "import torch;x=torch.arange(10).to(dtype=torch.float8_e4m3fn);torch.save(x, 'pt.pt');print(torch.load('pt.pt'))"`
@malfet
Copy link
Contributor Author

malfet commented Nov 29, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 29, 2023
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@malfet malfet deleted the malfet/make-f8-serializable branch November 30, 2023 22:17
hyperfraise pushed a commit to hyperfraise/pytorch that referenced this pull request Dec 21, 2023
By finally breaking FC promise on new dtypes by serializing untyped
storage and tensor dtypes

- Add `_rebuild_tensor_v3` that takes an extra dtype argument
- In `Tensor.__reduce_ex__` serialize tensor using untyped storage for
  v3_dtypes (which are at the moment limited to float8 dtypes)

Test plan: `python -c "import torch;x=torch.arange(10).to(dtype=torch.float8_e4m3fn);torch.save(x, 'pt.pt');print(torch.load('pt.pt'))"`

Fixes pytorch#114634

Pull Request resolved: pytorch#114662
Approved by: https://github.com/ngimel
hyperfraise pushed a commit to hyperfraise/pytorch that referenced this pull request Dec 21, 2023
By finally breaking FC promise on new dtypes by serializing untyped
storage and tensor dtypes

- Add `_rebuild_tensor_v3` that takes an extra dtype argument
- In `Tensor.__reduce_ex__` serialize tensor using untyped storage for
  v3_dtypes (which are at the moment limited to float8 dtypes)

Test plan: `python -c "import torch;x=torch.arange(10).to(dtype=torch.float8_e4m3fn);torch.save(x, 'pt.pt');print(torch.load('pt.pt'))"`

Fixes pytorch#114634

Pull Request resolved: pytorch#114662
Approved by: https://github.com/ngimel
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request enhancement Not as big of a feature, but technically not a bug. Should be easy to fix Merged release notes: python_frontend release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

float8 and bits tensors cannot be serialized
3 participants