Revert "Add full support for serialization of MPS Tensors (#79465)"
This reverts commit 64c2a27.

Reverted #79465 on behalf of https://github.com/zengk95 because it broke linux-xenial-py3.7-clang7-onnx / test (default, 1, 2, linux.2xlarge). It is not clear why, since the job passed on the pull request.
pytorchmergebot committed Jun 14, 2022
1 parent 31ada13 commit ce6ce74
Showing 7 changed files with 2 additions and 63 deletions.
aten/src/ATen/mps/MPSDevice.h (1 addition, 1 deletion)

@@ -58,7 +58,7 @@ class TORCH_API MPSDevice {
 
 TORCH_API bool is_available();
 
-TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
+at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);
 
 } // namespace mps
 } // namespace at
test/test_mps.py (0 additions, 37 deletions)

@@ -7,7 +7,6 @@
 import unittest
 import warnings
 import subprocess
-import tempfile
 import os
 import torch
 import torch.nn as nn
@@ -4537,42 +4536,6 @@ def test_legacy_constructor(self):
 
         b = a.new(1)
 
-    def test_serialization_map_location(self):
-
-        # Ensures that cpu Tensor can be loaded on mps
-        with tempfile.NamedTemporaryFile() as f:
-            x = torch.rand(2)
-            torch.save(x, f)
-
-            f.seek(0)
-            x2 = torch.load(f, map_location="mps")
-
-            self.assertEqual(x, x2)
-            self.assertEqual(x2.device.type, "mps")
-
-        # Ensures that mps Tensors can be loaded on mps
-        with tempfile.NamedTemporaryFile() as f:
-            x = torch.rand(2, device="mps")
-            torch.save(x, f)
-
-            f.seek(0)
-            x2 = torch.load(f)
-
-            self.assertEqual(x, x2)
-            self.assertEqual(x2.device.type, "mps")
-
-        # Ensures that mps Tensors can be loaded on cpu
-        with tempfile.NamedTemporaryFile() as f:
-            x = torch.rand(2, device="mps")
-            torch.save(x, f)
-
-            f.seek(0)
-            x2 = torch.load(f, map_location="cpu")
-
-            self.assertEqual(x, x2)
-            self.assertEqual(x2.device.type, "cpu")
-
-
 
 
 if __name__ == "__main__":
torch/_tensor.py (1 addition, 1 deletion)

@@ -219,7 +219,7 @@ def _reduce_ex_internal(self, proto):
         # 2. Python list is not a good fit due to performance reasons.
         # `tolist()` converts every single element in the tensor into python objects
         # and serializes them one by one.
-        if self.device.type in ['xla', 'ort', 'mps', 'hpu']:
+        if self.device.type in ['xla', 'ort', 'hpu']:
             # Convert BFloat16 tensors to Float32 before conversion to numpy, as numpy doesn't
             # support BFloat16. The tensor rebuilt from numpy takes in the original self.dtype,
             # so this reconstructs the BFloat16 tensor from numpy.
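
For context on the comment in the hunk above: a minimal sketch, not the actual _reduce_ex_internal code, of the numpy round trip it describes. The helper names here are hypothetical.

    import torch

    def _serialize_via_numpy_sketch(t):
        # Hypothetical helper: move the tensor to CPU and return a numpy
        # array plus the original dtype so the tensor can be rebuilt later.
        dtype = t.dtype
        cpu_t = t.cpu()
        if cpu_t.dtype == torch.bfloat16:
            # numpy has no bfloat16, so upcast to float32 before .numpy().
            cpu_t = cpu_t.to(torch.float32)
        return cpu_t.numpy(), dtype

    def _rebuild_from_numpy_sketch(arr, dtype):
        # Cast back to the saved dtype (e.g. torch.bfloat16) on rebuild.
        return torch.from_numpy(arr).to(dtype)
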
torch/csrc/DynamicTypes.cpp (0 additions, 2 deletions)

@@ -38,8 +38,6 @@ at::DeprecatedTypeProperties* get_type_properties(
     backend = at::Backend::CPU;
   } else if (device_type == at::kCUDA) {
     backend = at::Backend::CUDA;
-  } else if (device_type == at::kMPS) {
-    backend = at::Backend::MPS;
   } else if (device_type == at::DeviceType::Meta) {
     backend = at::Backend::Undefined;
   } else {
torch/csrc/Storage.cpp (0 additions, 5 deletions)

@@ -4,7 +4,6 @@
 #endif
 #include <structmember.h>
 
-#include <ATen/mps/MPSDevice.h>
 #include <c10/core/CPUAllocator.h>
 #include <libshm.h>
 #include <torch/csrc/CudaIPCTypes.h>
@@ -94,10 +93,6 @@ static PyObject* THPStorage_pynew(
   } else if (device.type() == at::kCUDA) {
     at::globalContext().lazyInitCUDA();
     allocator = c10::cuda::CUDACachingAllocator::get();
 #endif
-#ifdef USE_MPS
-  } else if (device.type() == at::kMPS) {
-    allocator = at::mps::GetMPSAllocator();
-#endif
   } else if (device.type() == at::DeviceType::Meta) {
     allocator = c10::GetAllocator(device.type());
torch/serialization.py (0 additions, 10 deletions)

@@ -124,11 +124,6 @@ def _cuda_tag(obj):
         return 'cuda:' + str(obj.device.index)
 
 
-def _mps_tag(obj):
-    if obj.device.type == 'mps':
-        return 'mps'
-
-
 def _cpu_deserialize(obj, location):
     if location == 'cpu':
         return obj
@@ -161,14 +156,9 @@ def _cuda_deserialize(obj, location):
     else:
         return obj.cuda(device)
 
-def _mps_deserialize(obj, location):
-    if location == 'mps':
-        return obj.mps()
-
 
 register_package(10, _cpu_tag, _cpu_deserialize)
 register_package(20, _cuda_tag, _cuda_deserialize)
-register_package(21, _mps_tag, _mps_deserialize)
 
 
 def location_tag(storage: Union[Storage, torch.storage._TypedStorage, torch._UntypedStorage]):
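
For context, register_package pairs a save-time tagging function with a load-time deserializer, tried in priority order; the reverted MPS hooks above followed the same shape as the _cpu and _cuda pairs. A minimal sketch of a custom pair, with a made-up 'mydevice' location:

    import torch
    from torch.serialization import register_package

    def _mydevice_tag(obj):
        # Save path: return a location string for storages this hook owns;
        # returning None lets the other registered hooks try instead.
        if obj.device.type == 'mydevice':
            return 'mydevice'

    def _mydevice_deserialize(obj, location):
        # Load path: copy storages tagged 'mydevice' back onto that device,
        # mirroring the mps() storage method removed in torch/storage.py below.
        if location == 'mydevice':
            return torch._UntypedStorage(obj.size(), device='mydevice').copy_(obj, False)

    # Priority 21 slots the pair after cpu (10) and cuda (20), as the
    # reverted registration did.
    register_package(21, _mydevice_tag, _mydevice_deserialize)
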
torch/storage.py (0 additions, 7 deletions)

@@ -116,13 +116,6 @@ def cpu(self):
         else:
             return self
 
-    def mps(self):
-        """Returns an MPS copy of this storage if it's not already on MPS"""
-        if self.device.type != 'mps':
-            return torch._UntypedStorage(self.size(), device="mps").copy_(self, False)
-        else:
-            return self
-
     def _to(self, dtype):
         if not isinstance(dtype, torch.dtype):
             raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
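
Taken together, these deletions disable the round trip that the removed test in test_mps.py exercised. A minimal usage sketch of what the reverted feature enabled (assumes a build where torch.backends.mps.is_available() is True; the file name is illustrative):

    import torch

    x = torch.rand(2, device="mps")
    torch.save(x, "checkpoint.pt")

    # map_location reroutes each storage through the deserializer
    # registered for the target location tag.
    x_cpu = torch.load("checkpoint.pt", map_location="cpu")
    assert x_cpu.device.type == "cpu"
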
