fix hpu storage serialization (#101680)
Change-Id: Ia534400a0e8972590374eceba5b62a2525b796e5

Fixes #ISSUE_NUMBER

Pull Request resolved: #101680
Approved by: https://github.com/mikaylagawarecki
ppiskorski authored and pytorchmergebot committed Jun 21, 2023
1 parent 9590228 commit 7fb2a92
Showing 3 changed files with 91 additions and 1 deletion.
32 changes: 32 additions & 0 deletions torch/_utils.py
@@ -48,6 +48,38 @@ def _type(self, dtype=None, non_blocking=False, **kwargs):
return dtype(self.size()).copy_(self, non_blocking)


def _hpu(self, device=None, non_blocking=False, **kwargs):
"""Returns a copy of this object in HPU memory.
If this object is already in HPU memory and on the correct device, then
no copy is performed and the original object is returned.
Args:
device (int): The destination HPU id. Defaults to the current device.
non_blocking (bool): If ``True`` and the source is in pinned memory,
the copy will be asynchronous with respect to the host. Otherwise,
the argument has no effect.
**kwargs: For compatibility, may contain the key ``async`` in place of
the ``non_blocking`` argument.
"""
non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs)
hpu = getattr(torch, "hpu", None)
assert hpu is not None, "HPU device module is not loaded"
if self.is_hpu:
if device is None:
device = hpu.current_device()
if self.get_device() == device:
return self
else:
if device is None:
device = -1
with hpu.device(device):
assert not self.is_sparse, "sparse storage is not supported for HPU tensors"
untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu"))
untyped_storage.copy_(self, non_blocking)
return untyped_storage


def _cuda(self, device=None, non_blocking=False, **kwargs):
"""Returns a copy of this object in CUDA memory.
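For context, a minimal usage sketch of the new storage method added above. It assumes an HPU-enabled PyTorch build (e.g. one where the Intel Gaudi / Habana plugin has registered the torch.hpu device module); on a stock build the guard simply skips the copy.

import torch

# Sketch only: requires an HPU-enabled build that registers torch.hpu.
storage = torch.empty(4, dtype=torch.float32).untyped_storage()

if getattr(torch, "hpu", None) is not None and torch.hpu.is_available():
    # Copies the CPU storage into HPU memory on the current device,
    # which is what _hpu() above does internally.
    hpu_storage = storage.hpu()        # or storage.hpu(device=0)
    print(hpu_storage.device)          # hpu:0
else:
    print("torch.hpu is not available; skipping the HPU copy")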
37 changes: 37 additions & 0 deletions torch/serialization.py
@@ -39,6 +39,7 @@
'register_package',
'check_module_version_greater_or_equal',
'validate_cuda_device',
'validate_hpu_device',
'location_tag',
'default_restore_location',
'normalize_storage_type',
@@ -145,6 +146,9 @@ def _cuda_tag(obj):
if obj.device.type == 'cuda':
return 'cuda:' + str(obj.device.index)

def _hpu_tag(obj):
if obj.device.type == 'hpu':
return 'hpu:' + str(obj.device.index)

def _mps_tag(obj):
if obj.device.type == 'mps':
@@ -198,6 +202,38 @@ def _cuda_deserialize(obj, location):
return obj.cuda(device)


def validate_hpu_device(location):
hpu = getattr(torch, "hpu", None)
assert hpu is not None, "HPU device module is not loaded"
device = hpu._utils._get_device_index(location, optional=True)

if not hpu.is_available():
raise RuntimeError('Attempting to deserialize object on a HPU '
'device but torch.hpu.is_available() is False. '
'If you are running on a CPU-only machine, '
'please use torch.load with map_location=torch.device(\'cpu\') '
'to map your storages to the CPU.')
device_count = hpu.device_count()
if device >= device_count:
raise RuntimeError('Attempting to deserialize object on HPU device '
f'{device} but torch.hpu.device_count() is {device_count}. Please use '
'torch.load with map_location to map your storages '
'to an existing device.')
return device


def _hpu_deserialize(obj, location):
hpu = getattr(torch, "hpu", None)
assert hpu is not None, "HPU device module is not loaded"
if location.startswith('hpu'):
device = validate_hpu_device(location)
if getattr(obj, "_torch_load_uninitialized", False):
with hpu.device(device):
return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
else:
return obj.hpu(device)


def _mps_deserialize(obj, location):
if location.startswith('mps'):
return obj.mps()
@@ -251,6 +287,7 @@ def _privateuse1_deserialize(obj, location):
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
register_package(23, _privateuse1_tag, _privateuse1_deserialize)
register_package(24, _hpu_tag, _hpu_deserialize)


def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
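To show what the new serialization hooks do end to end, here is a hedged round-trip sketch, again assuming an HPU-enabled build. On save, _hpu_tag records the storage location as 'hpu:N'; on load, _hpu_deserialize calls validate_hpu_device and restores the storage onto that device, while map_location can redirect the storages to the CPU on machines without HPUs.

import io
import torch

# Sketch only: the 'hpu' branch requires an HPU-enabled build.
if getattr(torch, "hpu", None) is not None and torch.hpu.is_available():
    t = torch.arange(8, dtype=torch.float32).to("hpu")

    buf = io.BytesIO()
    torch.save(t, buf)                   # storage tagged as 'hpu:0' via _hpu_tag
    buf.seek(0)
    restored = torch.load(buf)           # goes through validate_hpu_device/_hpu_deserialize
    print(restored.device)               # hpu:0

    # On a machine without HPUs, the same checkpoint raises the RuntimeError
    # from validate_hpu_device unless the storages are redirected:
    buf.seek(0)
    on_cpu = torch.load(buf, map_location=torch.device("cpu"))
    print(on_cpu.device)                 # cpu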
23 changes: 22 additions & 1 deletion torch/storage.py
@@ -1,7 +1,7 @@
import io

import torch
from ._utils import _type, _cuda
from ._utils import _type, _cuda, _hpu
from torch.types import Storage
from typing import Any, TypeVar, Type, Union, cast, Dict as _Dict
import copy
@@ -39,6 +39,7 @@ def size(self) -> int:

def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
def element_size(self) -> int: ... # noqa: E704

def get_device(self) -> int:
@@ -76,6 +77,8 @@ def _shared_incref(self, *args, **kwargs): ... # noqa: E704
def _free_weak_ref(cls, *args, **kwargs): ... # noqa: E704
@property
def is_cuda(self): ... # noqa: E704
@property
def is_hpu(self): ... # noqa: E704
@classmethod
def from_file(cls, filename, shared, nbytes) -> T: ... # noqa: E704
@classmethod
@@ -314,6 +317,10 @@ def __getitem__(self, *args, **kwargs):
def is_cuda(self):
return self.device.type == 'cuda'

@property
def is_hpu(self):
return self.device.type == 'hpu'

@_share_memory_lock_protected
def share_memory_(self, *args, **kwargs):
return super().share_memory_(*args, **kwargs)
@@ -332,6 +339,7 @@ def _load_from_bytes(b):

_StorageBase.type = _type # type: ignore[assignment]
_StorageBase.cuda = _cuda # type: ignore[assignment]
_StorageBase.hpu = _hpu # type: ignore[assignment]


@lru_cache(maxsize=None)
@@ -592,6 +600,11 @@ def is_cuda(self):
_warn_typed_storage_removal()
return self._untyped_storage.device.type == 'cuda'

@property
def is_hpu(self):
_warn_typed_storage_removal()
return self._untyped_storage.device.type == 'hpu'

def untyped(self):
"""Returns the internal :class:`torch.UntypedStorage`"""
_warn_typed_storage_removal()
@@ -735,6 +748,13 @@ def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
return self._new_wrapped_storage(cuda_storage)

def hpu(self, device=None, non_blocking=False, **kwargs) -> T:
_warn_typed_storage_removal()
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create HPU storage with quantized dtype")
hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(device, non_blocking, **kwargs)
return self._new_wrapped_storage(hpu_storage)

def element_size(self):
_warn_typed_storage_removal()
return self._element_size()
@@ -1101,6 +1121,7 @@ def _get_legacy_storage_class(self):

TypedStorage.type.__doc__ = _type.__doc__
TypedStorage.cuda.__doc__ = _cuda.__doc__
TypedStorage.hpu.__doc__ = _hpu.__doc__

class _LegacyStorageMeta(type):
dtype: torch.dtype
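Finally, a short sketch of the TypedStorage side under the same HPU-enabled assumption: hpu() delegates to the underlying untyped storage, and the new is_hpu property reflects where the data lives (quantized dtypes are rejected, mirroring the CUDA path).

import torch

# Sketch only: requires an HPU-enabled build that registers torch.hpu.
if getattr(torch, "hpu", None) is not None and torch.hpu.is_available():
    ts = torch.ones(3).storage()       # TypedStorage (emits the usual removal warning)
    print(ts.is_hpu)                   # False: still on the CPU
    hpu_ts = ts.hpu()                  # wraps self._untyped_storage.hpu(...)
    print(hpu_ts.is_hpu)               # True
    print(hpu_ts.untyped().device)     # hpu:0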
