Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4400,6 +4400,10 @@ def map_location(storage, loc):
self.assertEqual(type(tensor), torch.FloatTensor)
self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))

tensor = torch.load(test_file_path, map_location='cpu')
self.assertEqual(type(tensor), torch.FloatTensor)
self.assertEqual(tensor, torch.FloatTensor([[1.0, 2.0], [3.0, 4.0]]))

def test_from_buffer(self):
a = bytearray([1, 2, 3, 4])
self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4])
Expand Down
13 changes: 11 additions & 2 deletions torch/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
from contextlib import closing, contextmanager
from ._utils import _import_dotted_name
from ._six import string_classes as _string_classes
if sys.version_info[0] == 2:
import cPickle as pickle
else:
Expand Down Expand Up @@ -225,7 +226,10 @@ def load(f, map_location=None, pickle_module=pickle):
the right device. Otherwise, torch.load will fall back to the default behavior,
as if map_location wasn't specified.

If map_location is a dict, it will be used to remap location tags
If map_location is a string, it should be a device tag, where all tensors
should be loaded.

Otherwise, if map_location is a dict, it will be used to remap location tags
appearing in the file (keys), to ones that specify where to put the
storages (values).

Expand All @@ -236,14 +240,16 @@ def load(f, map_location=None, pickle_module=pickle):
f: a file-like object (has to implement fileno that returns a file
descriptor, and must implement seek), or a string containing a file
name
map_location: a function or a dict specifying how to remap storage
map_location: a function, string or a dict specifying how to remap storage
locations
pickle_module: module used for unpickling metadata and objects (has to
match the pickle_module used to serialize file)

Example:
>>> torch.load('tensors.pt')
# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location='cpu')
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
Expand Down Expand Up @@ -273,6 +279,9 @@ def _load(f, map_location, pickle_module):
def restore_location(storage, location):
location = map_location.get(location, location)
return default_restore_location(storage, location)
elif isinstance(map_location, _string_classes):
def restore_location(storage, location):
return default_restore_location(storage, map_location)
else:
def restore_location(storage, location):
result = map_location(storage, location)
Expand Down