xpu: implement xpu serialization (#125530)
Fixes: #125529

BC-breaking note:
The deprecated "async" argument to Storage.cuda and Storage.hpu has been removed. Use non_blocking instead.
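
A minimal before/after sketch of the call-site change (assuming a CUDA build; the same pattern applies to Storage.hpu):

import torch

storage = torch.randn(4).untyped_storage()

# Previously the copy accepted a deprecated "async" keyword (via **kwargs); it has been removed.
# Pass non_blocking instead:
if torch.cuda.is_available():
    gpu_storage = storage.cuda(non_blocking=True)  # asynchronous copy only if the source is pinned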

CC: @jbschlosser, @frank-wei @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @albanD

Pull Request resolved: #125530
Approved by: https://github.com/guangyey, https://github.com/albanD
dvrogozh authored and ZelboK committed May 19, 2024
1 parent b24a9e3 commit a2e563d
Showing 5 changed files with 164 additions and 170 deletions.
40 changes: 27 additions & 13 deletions test/test_serialization.py
@@ -493,6 +493,15 @@ def test_serialization_map_location(self):
def map_location(storage, loc):
return storage

def generate_map_locations(device_type):
return [
{'cuda:0': device_type + ':0'},
device_type,
device_type + ':0',
torch.device(device_type),
torch.device(device_type, 0)
]

def load_bytes():
with open(test_file_path, 'rb') as f:
return io.BytesIO(f.read())
@@ -504,34 +513,39 @@ def load_bytes():
'cpu',
torch.device('cpu'),
]
gpu_0_map_locations = [
{'cuda:0': 'cuda:0'},
'cuda',
'cuda:0',
torch.device('cuda'),
torch.device('cuda', 0)
]
gpu_0_map_locations = generate_map_locations('cuda')
gpu_last_map_locations = [
f'cuda:{torch.cuda.device_count() - 1}',
]
xpu_0_map_locations = generate_map_locations('xpu')
xpu_last_map_locations = [
f'xpu:{torch.xpu.device_count() - 1}',
]

def check_map_locations(map_locations, tensor_class, intended_device):
def check_map_locations(map_locations, dtype, intended_device):
for fileobject_lambda in fileobject_lambdas:
for map_location in map_locations:
tensor = torch.load(fileobject_lambda(), map_location=map_location)

self.assertEqual(tensor.device, intended_device)
self.assertIsInstance(tensor, tensor_class)
self.assertEqual(tensor, tensor_class([[1.0, 2.0], [3.0, 4.0]]))
self.assertEqual(tensor.dtype, dtype)
self.assertEqual(tensor, torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=dtype, device=intended_device))

check_map_locations(cpu_map_locations, torch.FloatTensor, torch.device('cpu'))
check_map_locations(cpu_map_locations, torch.float, torch.device('cpu'))
if torch.cuda.is_available():
check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor, torch.device('cuda', 0))
check_map_locations(gpu_0_map_locations, torch.float, torch.device('cuda', 0))
check_map_locations(
gpu_last_map_locations,
torch.cuda.FloatTensor,
torch.float,
torch.device('cuda', torch.cuda.device_count() - 1)
)
if torch.xpu.is_available():
check_map_locations(xpu_0_map_locations, torch.float, torch.device('xpu', 0))
check_map_locations(
xpu_last_map_locations,
torch.float,
torch.device('xpu', torch.xpu.device_count() - 1)
)

@unittest.skipIf(torch.cuda.is_available(), "Testing torch.load on CPU-only machine")
def test_load_nonexistent_device(self):
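As a standalone illustration of the map_location behavior exercised above, here is a sketch assuming an XPU-enabled build:

import io
import torch

buf = io.BytesIO()
torch.save(torch.tensor([[1.0, 2.0], [3.0, 4.0]]), buf)

if torch.xpu.is_available():
    # Strings, torch.device objects, and remap dicts are all accepted forms of map_location.
    for loc in ({'cpu': 'xpu:0'}, 'xpu', 'xpu:0', torch.device('xpu'), torch.device('xpu', 0)):
        buf.seek(0)
        t = torch.load(buf, map_location=loc)
        assert t.device.type == 'xpu'
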
35 changes: 35 additions & 0 deletions test/test_xpu.py
@@ -1,6 +1,7 @@
# Owner(s): ["module: intel"]

import sys
import tempfile
import unittest

import torch
@@ -270,6 +271,40 @@ def convert_boolean_tensors(x):

self.assertEqual(expect, actual)

def test_serialization_array_with_storage(self):
x = torch.randn(5, 5).xpu()
y = torch.zeros(2, 5, dtype=torch.int, device="xpu")
q = [x, y, x, y.storage()]
with tempfile.NamedTemporaryFile() as f:
torch.save(q, f)
f.seek(0)
q_copy = torch.load(f)
self.assertEqual(q_copy, q, atol=0, rtol=0)
q_copy[0].fill_(5)
self.assertEqual(q_copy[0], q_copy[2], atol=0, rtol=0)
self.assertEqual(q_copy[0].dtype, torch.float)
self.assertEqual(q_copy[1].dtype, torch.int)
self.assertEqual(q_copy[2].dtype, torch.float)
self.assertTrue(isinstance(q_copy[3], torch.storage.TypedStorage))
self.assertTrue(isinstance(q_copy[3]._untyped_storage, torch.UntypedStorage))
q_copy[1].fill_(10)
y.fill_(10)
self.assertEqual(q_copy[3], y.storage())

def test_serialization_array_with_empty(self):
x = [
torch.randn(4, 4).xpu(),
torch.tensor([], dtype=torch.float, device=torch.device("xpu")),
]
with tempfile.NamedTemporaryFile() as f:
torch.save(x, f)
f.seek(0)
x_copy = torch.load(f)
for original, copy in zip(x, x_copy):
self.assertEqual(copy, original)
self.assertIs(type(copy), type(original))
self.assertEqual(copy.get_device(), original.get_device())


instantiate_device_type_tests(TestXpu, globals(), only_for="xpu")

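The fill_ checks above rely on torch.save preserving storage sharing across a round trip; a condensed, device-agnostic sketch of that property:

import io
import torch

x = torch.zeros(3)
buf = io.BytesIO()
torch.save([x, x], buf)  # the same tensor saved twice
buf.seek(0)
a, b = torch.load(buf)

a.fill_(7)
assert torch.equal(a, b)  # both entries still share one storage after loading
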
79 changes: 24 additions & 55 deletions torch/_utils.py
@@ -52,71 +52,40 @@ def _type(self, dtype=None, non_blocking=False, **kwargs):
return dtype(self.size()).copy_(self, non_blocking)


def _hpu(self, device=None, non_blocking=False, **kwargs):
"""Returns a copy of this object in HPU memory.
def _to(self, device, non_blocking=False):
"""Returns a copy of this object in device memory.
If this object is already in HPU memory and on the correct device, then
no copy is performed and the original object is returned.
If this object is already on the correct device, then no copy is performed
and the original object is returned.
Args:
device (int): The destination HPU id. Defaults to the current device.
device (int): The destination device.
non_blocking (bool): If ``True`` and the source is in pinned memory,
the copy will be asynchronous with respect to the host. Otherwise,
the argument has no effect.
**kwargs: For compatibility, may contain the key ``async`` in place of
the ``non_blocking`` argument.
"""
non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs)
hpu = getattr(torch, "hpu", None)
assert hpu is not None, "HPU device module is not loaded"
if self.is_hpu:
if device is None:
device = hpu.current_device()
if self.get_device() == device:
return self
else:
if device is None:
device = -1
with hpu.device(device):
assert not self.is_sparse, "sparse storage is not supported for HPU tensors"
untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu"))
untyped_storage.copy_(self, non_blocking)
return untyped_storage


def _cuda(self, device=None, non_blocking=False, **kwargs):
"""Returns a copy of this object in CUDA memory.
If this object is already in CUDA memory and on the correct device, then
no copy is performed and the original object is returned.
if self.device == device:
return self

Args:
device (int): The destination GPU id. Defaults to the current device.
non_blocking (bool): If ``True`` and the source is in pinned memory,
the copy will be asynchronous with respect to the host. Otherwise,
the argument has no effect.
**kwargs: For compatibility, may contain the key ``async`` in place of
the ``non_blocking`` argument.
"""
non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs)
if self.is_cuda:
if device is None:
device = torch.cuda.current_device()
if self.get_device() == device:
return self
else:
if device is None:
device = -1
with torch.cuda.device(device):
if self.is_sparse:
new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
indices = torch.Tensor._indices(self).cuda(device, non_blocking)
values = torch.Tensor._values(self).cuda(device, non_blocking)
device_module = getattr(torch, device.type, None)
assert (
device_module is not None
), f"{device.type.upper()} device module is not loaded"
with device_module.device(device):
if self.is_sparse and hasattr(device_module, "sparse"):
new_type = getattr(device_module.sparse, self.__class__.__name__)
indices = getattr(torch.Tensor._indices(self), device.type)(
device, non_blocking
)
values = getattr(torch.Tensor._values(self), device.type)(
device, non_blocking
)
return new_type(indices, values, self.size())
else:
untyped_storage = torch.UntypedStorage(
self.size(), device=torch.device("cuda")
)
assert (
not self.is_sparse
), f"sparse storage is not supported for {device.type.upper()} tensors"
untyped_storage = torch.UntypedStorage(self.size(), device=device)
untyped_storage.copy_(self, non_blocking)
return untyped_storage

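The consolidated helper above replaces the per-backend _cuda/_hpu copies with a single device-module lookup. A simplified sketch of that dispatch idea (not the exact _to code; it assumes the backend module, e.g. torch.cuda or torch.xpu, exposes a device() context manager):

import torch

def copy_storage_to(storage, device, non_blocking=False):
    device = torch.device(device)
    if storage.device == device:
        return storage
    # Resolve the backend module (torch.cuda, torch.xpu, ...) from the device type.
    device_module = getattr(torch, device.type, None)
    assert device_module is not None, f"{device.type} device module is not loaded"
    with device_module.device(device):
        dst = torch.UntypedStorage(storage.size(), device=device)
        dst.copy_(storage, non_blocking)
        return dst
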
120 changes: 29 additions & 91 deletions torch/serialization.py
@@ -1,4 +1,5 @@
import difflib
import functools
import os
import io
import shutil
@@ -252,14 +253,6 @@ def _cpu_tag(obj):
return 'cpu'


def _cuda_tag(obj):
if obj.device.type == 'cuda':
return 'cuda:' + str(obj.device.index)

def _hpu_tag(obj):
if obj.device.type == 'hpu':
return 'hpu:' + str(obj.device.index)

def _mps_tag(obj):
if obj.device.type == 'mps':
return 'mps'
@@ -270,8 +263,9 @@ def _meta_tag(obj):
return 'meta'


def _privateuse1_tag(obj):
backend_name = torch._C._get_privateuse1_backend_name()
def _backend_tag(backend_name, obj):
if backend_name == 'privateuse1':
backend_name = torch._C._get_privateuse1_backend_name()
if obj.device.type == backend_name:
if obj.device.index is None:
return backend_name
@@ -284,66 +278,6 @@ def _cpu_deserialize(obj, location):
return obj


def validate_cuda_device(location):
device = torch.cuda._utils._get_device_index(location, True)

if not torch.cuda.is_available():
raise RuntimeError('Attempting to deserialize object on a CUDA '
'device but torch.cuda.is_available() is False. '
'If you are running on a CPU-only machine, '
'please use torch.load with map_location=torch.device(\'cpu\') '
'to map your storages to the CPU.')
device_count = torch.cuda.device_count()
if device >= device_count:
raise RuntimeError('Attempting to deserialize object on CUDA device '
f'{device} but torch.cuda.device_count() is {device_count}. Please use '
'torch.load with map_location to map your storages '
'to an existing device.')
return device


def _cuda_deserialize(obj, location):
if location.startswith('cuda'):
device = validate_cuda_device(location)
if getattr(obj, "_torch_load_uninitialized", False):
with torch.cuda.device(device):
return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
else:
return obj.cuda(device)


def validate_hpu_device(location):
hpu = getattr(torch, "hpu", None)
assert hpu is not None, "HPU device module is not loaded"
device = hpu._utils._get_device_index(location, optional=True)

if not hpu.is_available():
raise RuntimeError('Attempting to deserialize object on a HPU '
'device but torch.hpu.is_available() is False. '
'If you are running on a CPU-only machine, '
'please use torch.load with map_location=torch.device(\'cpu\') '
'to map your storages to the CPU.')
device_count = hpu.device_count()
if device >= device_count:
raise RuntimeError('Attempting to deserialize object on HPU device '
f'{device} but torch.hpu.device_count() is {device_count}. Please use '
'torch.load with map_location to map your storages '
'to an existing device.')
return device


def _hpu_deserialize(obj, location):
if location.startswith('hpu'):
hpu = getattr(torch, "hpu", None)
assert hpu is not None, "HPU device module is not loaded"
device = validate_hpu_device(location)
if getattr(obj, "_torch_load_uninitialized", False):
with hpu.device(device):
return torch.UntypedStorage(obj.nbytes(), device=torch.device(location))
else:
return obj.hpu(device)


def _mps_deserialize(obj, location):
if location.startswith('mps'):
return obj.mps()
@@ -354,18 +288,18 @@ def _meta_deserialize(obj, location):
return torch.UntypedStorage(obj.nbytes(), device='meta')


def _validate_privateuse1_device(location, backend_name):
def _validate_device(location, backend_name):
'''
Check whether the device index of privateuse1 is valid
Check whether the device index of the specified backend is valid
Register a device_module of privateuse1 by torch._register_device_module.
Implement the following methods in device_module like cuda:
device_module._utils._get_device_index(location, True),
In the case of the privateuse1 backend, you must first register a device_module for
privateuse1 using torch._register_device_module. Implement the following
methods in device_module like cuda: device_module._utils._get_device_index(location, True),
device_module.device_count().
Args:
location: string of device
backend_name: the name of privateuse1, which can be renamed
backend_name: the backend name or the name of privateuse1, which can be renamed
Returns:
device_index: int
@@ -378,6 +312,7 @@ def _validate_privateuse1_device(location, backend_name):
device_module = getattr(torch, backend_name)
if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
device_index = device_module._utils._get_device_index(location, True)
device = torch.device(backend_name, device_index)
else:
device = torch.device(location)
device_index = device.index if device.index else 0
@@ -394,29 +329,32 @@ def _validate_privateuse1_device(location, backend_name):
f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
'Please use torch.load with map_location to map your storages '
'to an existing device.')
return device_index
return device


def validate_cuda_device(location):
return _validate_device(location, 'cuda').index


def _privateuse1_deserialize(obj, location):
backend_name = torch._C._get_privateuse1_backend_name()
def validate_hpu_device(location):
return _validate_device(location, 'hpu').index


def _deserialize(backend_name, obj, location):
if backend_name == 'privateuse1':
backend_name = torch._C._get_privateuse1_backend_name()
if location.startswith(backend_name):
if not hasattr(obj, backend_name):
raise RuntimeError(f'Attempting to load the storages to the {backend_name.upper()} device '
f'but torch.storage._StorageBase.{backend_name}() or '
f'torch.storage.TypedStorage.{backend_name}() is not generated. '
'Please use torch.utils.generate_methods_for_privateuse1_backend '
f'to generate storage.{backend_name}() method first.')
device_index = _validate_privateuse1_device(location, backend_name)
return getattr(obj, backend_name)(device_index)
device = _validate_device(location, backend_name)
return obj.to(device=device)


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
register_package(23, _privateuse1_tag, _privateuse1_deserialize)
register_package(24, _hpu_tag, _hpu_deserialize)

register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))

def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
for _, tagger, _ in _package_registry:
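Out-of-tree backends can hook into the same registry via the public torch.serialization.register_package API; a hedged sketch with an illustrative backend name and priority:

import torch
from torch.serialization import register_package

def _mydev_tag(obj):
    # Return a location tag for storages on the hypothetical "mydev" backend.
    if obj.device.type == "mydev":
        return f"mydev:{obj.device.index}"

def _mydev_deserialize(obj, location):
    if location.startswith("mydev"):
        return obj.to(device=torch.device(location))

register_package(30, _mydev_tag, _mydev_deserialize)  # priority 30 is queried after the built-in entries
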
[Diff for the fifth changed file not shown]
