Skip to content

Commit

Permalink
Remove useless copy on zip file load (#36362)
Browse files Browse the repository at this point in the history
Summary:
Instead of copying to a buffer, then setting a tensor's storage with that buffer, create a storage directly from the file
Differential Revision: [D21057090](https://our.intern.facebook.com/intern/diff/21057090/)
Pull Request resolved: #36362

Pulled By: driazati

Differential Revision: D21057090

fbshipit-source-id: e3d30a3b09f4d67bf4bb7a0dd7f4f60c3dd1a47e
  • Loading branch information
davidriazati authored and facebook-github-bot committed May 22, 2020
1 parent 8e69c3b commit 455bf77
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
22 changes: 22 additions & 0 deletions torch/csrc/jit/python/init.cpp
Expand Up @@ -729,6 +729,28 @@ void initJITBindings(PyObject* module) {
std::tie(data, size) = self.getRecord(key);
return py::bytes(reinterpret_cast<const char*>(data.get()), size);
})
// Binding: read a record straight out of the zip archive into a Storage,
// skipping the intermediate buffer copy (the point of this commit).
.def(
"get_storage_from_record",
[](PyTorchStreamReader& self,
const std::string& key,
size_t numel,
py::object data_type_obj) {
// getRecord(key) yields (DataPtr, size) -- see the get_record binding
// above; here only the owning DataPtr is kept.
at::DataPtr data(std::get<0>(self.getRecord(key)));
// data_type_obj is assumed to be a torch.dtype (THPDtype) object;
// its scalar_type fixes the per-element size used below.
// NOTE(review): the cast is unchecked -- a non-dtype argument from
// Python would be undefined behavior here.
auto scalar_type =
reinterpret_cast<THPDtype*>(data_type_obj.ptr())->scalar_type;

// Adopt the record's bytes as a CPU storage of numel elements.
// Null allocator + resizable=false: the storage only borrows/owns the
// moved-in DataPtr and can never reallocate or grow.
auto storage = c10::Storage(
c10::Storage::use_byte_size_t(),
at::CPU(scalar_type).typeMeta(),
numel * elementSize(scalar_type),
std::move(data),
/*allocator=*/nullptr,
/*resizable=*/false);
// Wrap the storage in a TensorImpl with an empty DispatchKeySet and
// return it to Python as an at::Tensor.
auto ptr =
c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
std::move(storage), at::DispatchKeySet());
return at::Tensor(std::move(ptr));
})
.def("get_all_records", [](PyTorchStreamReader& self) {
return self.getAllRecords();
});
Expand Down
15 changes: 6 additions & 9 deletions torch/serialization.py
Expand Up @@ -473,7 +473,6 @@ def persistent_id(obj):
if storage.device.type == 'cpu':
# If it's on the CPU we can directly copy it into the zip file
num_bytes = storage.size() * storage.element_size()
buf = io.BytesIO()
zip_file.write_record(name, storage.data_ptr(), num_bytes)
else:
# Copy to a buffer, then serialize that
Expand Down Expand Up @@ -810,14 +809,12 @@ def _load(zip_file, map_location, pickle_module, **pickle_load_args):

loaded_storages = {}

def load_tensor(obj, size, key, location):
loaded_storages[key] = restore_location(obj, location)
def load_tensor(data_type, size, key, location):
name = 'data/{}'.format(key)
size_long = struct.pack("<Q", size)
tensor_file = io.BytesIO(size_long + zip_file.get_record(name))
offset = None
is_real_file = False
loaded_storages[key]._set_from_file(tensor_file, offset, is_real_file)
dtype = data_type(0).dtype

storage = zip_file.get_storage_from_record(name, size, dtype).storage()
loaded_storages[key] = restore_location(storage, location)

def persistent_load(saved_id):
assert isinstance(saved_id, tuple)
Expand All @@ -828,7 +825,7 @@ def persistent_load(saved_id):
"Unknown typename for persistent_load, expected 'storage' but got '{}'".format(typename)
data_type, key, location, size = data
if key not in loaded_storages:
load_tensor(data_type(size), size, key, _maybe_decode_ascii(location))
load_tensor(data_type, size, key, _maybe_decode_ascii(location))
storage = loaded_storages[key]
return storage

Expand Down

0 comments on commit 455bf77

Please sign in to comment.