New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make Float8 types serializeable #114662
Make Float8 types serializeable #114662
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/114662
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 79e4b18 with merge base b6a30bb (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
torch/_tensor.py
Outdated
# need to wrap with TypedStorage | ||
args = ( | ||
torch.storage.TypedStorage( | ||
v3_dtypes = [torch.float8_e5m2, torch.float8_e4m3fn] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also uz
dtypes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably makes sense to serialize all tensors whose dtype is not in the dict like this. No bc breaking, as those tensors couldn't have been serialized at all.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, so do the inverse. Makes sense, will do in follow up PR (and probably past the branch cut, as it feels a bit risky)
By finally breaking FC promise on new dtypes by serializing untyped storage and tensor dtypes - Add `_rebuild_tensor_v3` that takes an extra dtype argument - In `Tensor.__reduce_ex__` serialize tensor using untyped storage for v3_dtypes (which are at the moment limited to float8 dtypes) Test plan: `python -c "import torch;x=torch.arange(10).to(dtype=torch.float8_e4m3fn);torch.save(x, 'pt.pt');print(torch.load('pt.pt'))"`
c48da22
to
79e4b18
Compare
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
By finally breaking FC promise on new dtypes by serializing untyped storage and tensor dtypes - Add `_rebuild_tensor_v3` that takes an extra dtype argument - In `Tensor.__reduce_ex__` serialize tensor using untyped storage for v3_dtypes (which are at the moment limited to float8 dtypes) Test plan: `python -c "import torch;x=torch.arange(10).to(dtype=torch.float8_e4m3fn);torch.save(x, 'pt.pt');print(torch.load('pt.pt'))"` Fixes pytorch#114634 Pull Request resolved: pytorch#114662 Approved by: https://github.com/ngimel
By finally breaking FC promise on new dtypes by serializing untyped storage and tensor dtypes - Add `_rebuild_tensor_v3` that takes an extra dtype argument - In `Tensor.__reduce_ex__` serialize tensor using untyped storage for v3_dtypes (which are at the moment limited to float8 dtypes) Test plan: `python -c "import torch;x=torch.arange(10).to(dtype=torch.float8_e4m3fn);torch.save(x, 'pt.pt');print(torch.load('pt.pt'))"` Fixes pytorch#114634 Pull Request resolved: pytorch#114662 Approved by: https://github.com/ngimel
By finally breaking FC promise on new dtypes by serializing untyped
storage and tensor dtypes
_rebuild_tensor_v3
that takes an extra dtype argumentTensor.__reduce_ex__
serialize tensor using untyped storage forv3_dtypes (which are at the moment limited to float8 dtypes)
Test plan:
python -c "import torch;x=torch.arange(10).to(dtype=torch.float8_e4m3fn);torch.save(x, 'pt.pt');print(torch.load('pt.pt'))"
Fixes #114634