[jit] Support torch.save for saving values during execution #18154
Conversation
torch/csrc/jit/pickler.cpp
Outdated
```cpp
// TODO: making IValues does a useless copy of the storage
IValue storage_bytes =
    std::string((char*)tensor.storage().data(), record_size);
IValue list = c10::ivalue::GenericList::create({storage_bytes, num_elements});
```
I don't understand: why are we pickling the literal tensor as a generic list?
When un-pickling, only the last object on the stack gets popped off and passed to `__setstate__`, and to recreate the tensor we need both values, hence the list wrapper. Creating `IValue`s from them lets us re-use the existing list serialization code instead of copying it here.
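The single-object constraint the reply describes can be sketched with Python's own pickle protocol. The `FakeTensor` class below is a hypothetical stand-in (not the real IValue machinery) showing that `__setstate__` receives exactly one object, which is why the two values are bundled into a list:

```python
import pickle

class FakeTensor:
    """Hypothetical stand-in: pickle pops exactly one object off the
    stack and passes it to __setstate__, so both values must be bundled."""
    def __getstate__(self):
        # Bundle both values into a single list, mirroring the
        # GenericList wrapper in pickler.cpp.
        return [self.storage_bytes, self.num_elements]

    def __setstate__(self, state):
        # `state` is the one object popped off the stack: the list.
        self.storage_bytes, self.num_elements = state

t = FakeTensor()
t.storage_bytes, t.num_elements = b"\x00\x01\x02\x03", 4
restored = pickle.loads(pickle.dumps(t))
print(restored.num_elements)  # 4
```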
torch/csrc/jit/register_prim_ops.cpp
Outdated
```cpp
// Write file
std::fstream output(filename, std::ios::out | std::ios::binary);
output.write(p.stack().data(), p.stack().size());
output.close();
```
you don't need to manually close, it will be RAII'd out
```cpp
// Pickle the tensor
Pickler p;
p.start();
```
Aside: is there value in the start() and finish() calls? Can we make them part of the constructor/destructor to avoid mistakes?
`start` could be in the constructor, but `finish` puts some necessary opcodes at the end of the binary blob, so it needs to run before the stack is stored somewhere (which means it can't be in the destructor). Because of that, I think it's clearer to have both `start` and `finish`.
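The ordering argument can be shown with a toy model (names and opcodes chosen for illustration, not the real Pickler): `finish()` appends a trailing opcode to the buffer, so it must run before any read of the buffer, which a destructor cannot guarantee.

```python
class SketchPickler:
    """Toy model of the start()/finish() question: finish() appends a
    trailing opcode (like pickle's STOP), so it must run before the
    buffer is read out, ruling out a destructor-based finish."""
    PROTO = b"\x80\x02"  # pickle protocol-2 header
    STOP = b"."          # pickle STOP opcode

    def __init__(self):
        self.buf = bytearray()
        self.buf += self.PROTO   # start(): header opcodes could live here

    def finish(self):
        self.buf += self.STOP    # trailing opcode, must precede the read

p = SketchPickler()
p.finish()
print(bytes(p.buf))  # b'\x80\x02.'
```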
Looks pretty good, but I am concerned about corner cases for tensor serialization, and have a few API suggestions.
torch/csrc/jit/pickler.cpp
Outdated
```cpp
void Pickler::pushClass(PicklerClass cls) {
  const auto& name = getClassName(cls);
  // Write it to the tensor table
void Pickler::pushGlobal(const std::string& name) {
  auto memo_entry = memo_map_.find(&name);
```
I think `&name` is a bug here. Previously it was returning one-time-allocated strings. Now it is returning things in a hash map, whose addresses are not guaranteed to stay the same.
Since the maps are `const`, doesn't that imply that there will be no re-allocating / re-hashing, and so the pointers will always be valid?
Based on how pushGlobal is called, no, it is not always valid. I see pushGlobal being called:
- with a `const char*` string
- with a string generated from a stringstream
- with a string built via string concatenation

Versions 2 and 3 fail. This kind of bug happens because the API looks like it does one thing (take a string), but it is really supposed to take only statically allocated strings. In that case, it is best not to have an API like this in the first place.
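The failure mode (keying the memo table on a string's address rather than its value) has a direct Python analog. In CPython, strings built at runtime are distinct objects even when their contents are equal, so an identity-keyed memo misses the second occurrence; the name `TensorID` below is made up for illustration:

```python
# Identity-keyed vs. value-keyed memoization, a Python analog of
# keying memo_map_ on &name.
prefix = "__main__."
a = prefix + "TensorID"   # e.g. built via a stringstream / concat
b = prefix + "TensorID"   # same contents, but a different object

memo_by_identity = {id(a): 0}   # analogous to keying on &name
memo_by_value = {a: 0}          # hash map from string -> memo id

print(id(b) in memo_by_identity)  # False: address-based lookup misses
print(b in memo_by_value)         # True: value-based lookup memoizes
```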
Regarding this and #20090, an API like this is nice to have so everything doesn't have to be statically spelled out. Adding a reference to pointer IValues and keeping strings around in a table on the pickler should fix this
torch/csrc/jit/pickler.cpp
Outdated
```cpp
// All attributes get pushed into a list and their indices saved in the
// module def
push<OpCode>(OpCode::EMPTY_LIST);
push<OpCode>(OpCode::MARK);
wrap_in_list_ = wrap_in_list;
```
Weird interface `bool` creep here. Why not have the thing that needs to wrap the result in a list explicitly call:

```cpp
pickler.start();
pickler.beginPushList();
pickler.endPushList();
pickler.end();
```
torch/csrc/jit/pickler.cpp
Outdated
```cpp
auto numel_ptr = reinterpret_cast<const char*>(&numel);
stack_.insert(stack_.end(), numel_ptr, numel_ptr + sizeof(numel));

uint64_t record_size = tensor.element_size() * tensor.numel();
```
Is the tensor contiguous? Is the tensor on the CPU? Otherwise this code is bogus. Look at how the other tensor serializer writes out tensors.
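The contiguity concern can be illustrated without torch. In the pure-Python sketch below (helper name made up), a transposed view shares the same underlying storage, but its logical element order no longer matches storage order, so dumping `element_size * numel` raw bytes from the storage pointer would scramble the data:

```python
def read_logical(storage, sizes, strides, offset=0):
    """Walk a 2-D strided view in row-major logical order."""
    out = []
    for i in range(sizes[0]):
        for j in range(sizes[1]):
            out.append(storage[offset + i * strides[0] + j * strides[1]])
    return out

storage = [0, 1, 2, 3, 4, 5]                        # underlying storage
contiguous = read_logical(storage, (2, 3), (3, 1))  # 2x3 view, contiguous
transposed = read_logical(storage, (3, 2), (1, 3))  # its transpose

print(contiguous == storage)  # True: raw storage dump matches logical order
print(transposed == storage)  # False: raw dump would be bogus here
```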
torch/csrc/jit/pickler.cpp
Outdated
```cpp
}
AT_ERROR("Unknown class name for unpickler: ", str);
}
const static std::unordered_map<std::string, PicklerClass> name_to_class{
```
What was wrong with how it was before? The other one is almost certainly faster.
This is also missing an entry for LITERAL_TENSOR, right?
It's a little more readable, I think, and this isn't accessed often enough to be a performance bottleneck or anything. `LITERAL_TENSOR` wasn't necessary for this PR (it's only needed to unpickle these tensors in C++, which isn't crucial since `torch.load` can read them in).
This is looking good. Two things before it is ready:
- Bug in pushGlobal that can cause memoization to fail.
- Needs more tests around how tensors are serialized because I can't tell from the implementation if it is correct or not.
…e for the duration of the pickler
Pretty much ready to go. I think there is a bug with string memoization, if you can resolve that I will just look at that change and approve.
torch/csrc/jit/pickler.cpp
Outdated
```cpp
void Pickler::pushGlobal(const std::string& name_temp) {
  memoized_strings_.push_back(name_temp);
  auto name = memoized_strings_.back();
  auto memo_entry = memo_map_.find(&(memoized_strings_.back()));
```
This doesn't seem right. What is the intention here? `&memoized_strings_.back()` is a pointer to a position in memoized_strings_. Since it was just inserted, it will never be in the memo_map_, and if memoized_strings_ gets reallocated, then it will be a pointer to bogus data. memo_map_ only really works for reference IValue types. Memoizing non-IValue strings will probably require a hash map from string -> memo id.
Looks good. Minor API comment.
torch/csrc/jit/pickler.cpp
Outdated
```cpp
pushString(name_temp);

// Push BINPUT without adding anything to the memo_map_
pushMemoization(nullptr);
```
This is a weird API. It decides not to update the memo_map_ in pushMemoization and then carefully uses memo_id here (something secretly updated in pushMemoization).

Suggestion:

```cpp
memoized_string_map_.insert({name_temp, pushBinPutNext()});
```

`pushBinPutNext()` pushes the BINPUT opcode with an incremented memo_id and then returns it. Then pushMemoization becomes:

```cpp
void pushMemoization(void* item) {
  memo_map_[item] = pushBinPutNext();
}
```
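The suggested shape, a single `pushBinPutNext` primitive that both memo tables build on, can be sketched in Python. All names here mirror the review suggestion but are hypothetical, not the real Pickler API:

```python
class MiniPickler:
    """Toy sketch of the suggested API: push_binput_next emits BINPUT
    with the next memo id and returns it, so both the identity-keyed
    IValue memo table and the value-keyed string table share it."""
    BINPUT = 0x71  # pickle's one-byte-argument BINPUT opcode ('q')

    def __init__(self):
        self.stack = bytearray()
        self.next_memo_id = 0
        self.memo_map = {}           # object identity -> memo id
        self.memoized_strings = {}   # string value -> memo id

    def push_binput_next(self):
        memo_id = self.next_memo_id
        self.next_memo_id += 1
        self.stack += bytes([self.BINPUT, memo_id])
        return memo_id

    def push_memoization(self, item):
        self.memo_map[id(item)] = self.push_binput_next()

    def push_global(self, name):
        if name not in self.memoized_strings:
            # pushString(name) would go here
            self.memoized_strings[name] = self.push_binput_next()
        # else: emit BINGET with self.memoized_strings[name]

p = MiniPickler()
p.push_global("my.Class")
p.push_global("my.Class")   # second push reuses memo id 0, no new BINPUT
print(p.memoized_strings)   # {'my.Class': 0}
```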
Summary: This PR makes `torch.save` call out to the pickler, which saves a tensor in the same format that `torch.save()` does. The file looks like

```
| pickle archive 1 (includes sizes, strides, requires_grad, etc...) | pickle archive 2 (list of tensor keys) | tensor binary data |
```

and can be read back in with `torch.load(my_file, pickle_module=torch.jit._pickle)`.

Fixes #18003

Unpickling in the JIT for things such as model parallelism will be a follow-up PR.

Pull Request resolved: pytorch/pytorch#18154
Pulled By: driazati
Differential Revision: D15015160
fbshipit-source-id: ef76a44b8c243f4794cd7e245ec8305e965bc59f