
[jit] Support torch.save for saving values during execution #18154

Closed · 27 commits

Conversation

driazati (Contributor) commented Mar 19, 2019

This PR makes `torch.save` call out to the pickler, which saves a tensor in the same format that `torch.save()` does. The file looks like `| pickle archive 1 (includes sizes, strides, requires_grad, etc.) | pickle archive 2 (list of tensor keys) | tensor binary data |` and can be read back in with `torch.load(my_file, pickle_module=torch.jit._pickle)`.

Fixes #18003

Unpickling in the JIT, for things such as model parallelism, will be a follow-up PR.
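For intuition, the three-part container described above can be mimicked with the standard library's `pickle` module. This is a minimal sketch, not the PR's actual code; `save_sketch`/`load_sketch` and the metadata fields are hypothetical names chosen for illustration. It relies on the fact that `pickle.load` stops at each archive's STOP opcode, so two archives can be concatenated and followed by raw bytes:

```python
import io
import pickle

def save_sketch(f, metadata, tensor_keys, tensor_bytes):
    pickle.dump(metadata, f)      # archive 1: sizes, strides, requires_grad, ...
    pickle.dump(tensor_keys, f)   # archive 2: list of tensor keys
    f.write(tensor_bytes)         # raw tensor storage bytes follow the archives

def load_sketch(f):
    metadata = pickle.load(f)     # each load consumes exactly one archive
    tensor_keys = pickle.load(f)
    tensor_bytes = f.read()       # everything after the archives is raw data
    return metadata, tensor_keys, tensor_bytes

buf = io.BytesIO()
save_sketch(buf, {"sizes": [2, 3], "requires_grad": False}, ["0"], b"\x00" * 24)
buf.seek(0)
meta, keys, data = load_sketch(buf)
```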

Differential Revision: D15015160

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 19, 2019
@driazati driazati changed the title [jit] Add torch.jit.save_ivalue [wip][jit] Add torch.jit.save_ivalue Mar 20, 2019
@driazati driazati requested review from suo and eellison March 20, 2019 01:19
@driazati driazati changed the title [wip][jit] Add torch.jit.save_ivalue [jit] Add torch.jit.save_ivalue for saving values during execution Mar 21, 2019
```cpp
// TODO: making IValues does a useless copy of the storage
IValue storage_bytes =
    std::string((char*)tensor.storage().data(), record_size);
IValue list = c10::ivalue::GenericList::create({storage_bytes, num_elements});
```
Member:

I don't understand: why are we pickling the literal tensor as a generic list?

driazati (Contributor, Author):

When un-pickling, only the last object on the stack gets popped off and passed to `__setstate__`, and to recreate the tensor we need both values, hence the list wrapper. Creating IValues from them lets us reuse the existing list serialization code instead of copying it here.
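The same constraint holds in Python's pickle protocol: `__setstate__` receives exactly one object, so multiple values must be packed into a single container. A small stdlib illustration (the `Pair` class is hypothetical, not part of this PR):

```python
import pickle

class Pair:
    def __init__(self, a=None, b=None):
        self.a, self.b = a, b
    def __getstate__(self):
        # Both values must travel as ONE object, hence the list wrapper.
        return [self.a, self.b]
    def __setstate__(self, state):
        # Unpickling passes exactly one object here.
        self.a, self.b = state

restored = pickle.loads(pickle.dumps(Pair(b"storage", 6)))
```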

```cpp
// Write file
std::fstream output(filename, std::ios::out | std::ios::binary);
output.write(p.stack().data(), p.stack().size());
output.close();
```
Member:

You don't need to manually close it; the stream's destructor will close the file when it goes out of scope (RAII).


```cpp
// Pickle the tensor
Pickler p;
p.start();
```
Member:

Aside: is there value in the explicit start() and finish() calls? Could they be part of the constructor/destructor to avoid mistakes?

driazati (Contributor, Author):

start() could go in the constructor, but finish() puts some necessary opcodes at the end of the binary blob, so it needs to run before the stack is stored somewhere; it can't be in the destructor. Because of that, I think it is clearer to have both start() and finish() be explicit.

torch/jit/__init__.py: outdated review comment, resolved
@driazati driazati changed the title [jit] Add torch.jit.save_ivalue for saving values during execution [wip][jit] Add torch.jit.save_ivalue for saving values during execution Apr 2, 2019
@driazati driazati changed the title [wip][jit] Add torch.jit.save_ivalue for saving values during execution [wip][jit] Support torch.save for saving values during execution Apr 18, 2019
@driazati driazati requested review from suo and zdevito April 19, 2019 00:05
@driazati driazati changed the title [wip][jit] Support torch.save for saving values during execution [jit] Support torch.save for saving values during execution Apr 19, 2019
torch/serialization.py: outdated review comment, resolved
zdevito (Contributor) left a comment:

Looks pretty good, but I am concerned about corner cases for tensor serialization, and I have a few API suggestions.

```cpp
void Pickler::pushClass(PicklerClass cls) {
  const auto& name = getClassName(cls);
  // Write it to the tensor table
void Pickler::pushGlobal(const std::string& name) {
  auto memo_entry = memo_map_.find(&name);
```
zdevito (Contributor):

I think `&name` is a bug here. Previously it was returning one-time-allocated strings; now it is returning things in a hash map, whose addresses are not guaranteed to stay the same.

driazati (Contributor, Author):

Since the maps are const, doesn't that imply that there will be no re-allocating / re-hashing, and so the pointers will always be valid?

zdevito (Contributor):

Based on how pushGlobal is called, no, it is not always valid. I see pushGlobal being called:

  1. with a const char* string
  2. with a string generated from a stringstream
  3. with a string built by string concatenation

Versions 2 and 3 fail. This kind of bug happens because the API looks like it does one thing (take any string), but it is really supposed to take only statically allocated strings. In that case, it is best not to have an API like this in the first place.
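The failure mode can be illustrated in Python by keying a memo table on object identity (the analogue of taking `&name`) versus keying it on value. The helper names below are illustrative only:

```python
# Identity-keyed memo: the analogue of memo_map_.find(&name).
memo_by_id = {}
def memo_id_key(s, _keep=[]):
    _keep.append(s)  # keep the object alive so its id() is not reused
    return memo_by_id.setdefault(id(s), len(memo_by_id))

# Value-keyed memo: a map from the string itself to a memo id.
memo_by_value = {}
def memo_value_key(s):
    return memo_by_value.setdefault(s, len(memo_by_value))

a = "torch.jit._pickle"                        # a statically allocated literal
b = "".join(["torch.jit.", "_pickle"])         # equal string built at runtime

bad = (memo_id_key(a), memo_id_key(b))         # identity keying: two memo ids
good = (memo_value_key(a), memo_value_key(b))  # value keying: one shared id
```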

driazati (Contributor, Author):

Regarding this and #20090, an API like this is nice to have so everything doesn't have to be statically spelled out. Adding a reference to pointer IValues and keeping the strings around in a table on the pickler should fix this.


```cpp
// All attributes get pushed into a list and their indices saved in the
// module def
push<OpCode>(OpCode::EMPTY_LIST);
push<OpCode>(OpCode::MARK);
wrap_in_list_ = wrap_in_list;
```
zdevito (Contributor):

Weird interface bool creep here. Why not have the thing that needs to wrap the result in a list explicitly call:

```cpp
pickler.start();
pickler.beginPushList();
pickler.endPushList();
pickler.end();
```
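For comparison, here is a toy Python sketch of that explicit-call shape, emitting real pickle opcodes (EMPTY_LIST, MARK, APPENDS). `ListSketch` and its method names are illustrative, not the PR's API:

```python
class ListSketch:
    # Single-byte pickle opcodes (see Python's pickle module).
    EMPTY_LIST = b"]"
    MARK = b"("
    APPENDS = b"e"

    def __init__(self):
        self.out = bytearray()

    def begin_push_list(self):
        # The caller opens the list explicitly instead of passing a
        # wrap_in_list flag into the constructor.
        self.out += self.EMPTY_LIST + self.MARK

    def end_push_list(self):
        self.out += self.APPENDS

sk = ListSketch()
sk.begin_push_list()
# ... items would be pickled here ...
sk.end_push_list()
```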

torch/csrc/jit/pickler.cpp: review comment, resolved
torch/csrc/jit/pickler.cpp: outdated review comment, resolved
torch/csrc/jit/pickler.cpp: outdated review comment, resolved
```cpp
auto numel_ptr = reinterpret_cast<const char*>(&numel);
stack_.insert(stack_.end(), numel_ptr, numel_ptr + sizeof(numel));

uint64_t record_size = tensor.element_size() * tensor.numel();
```
zdevito (Contributor):

Is the tensor contiguous? Is the tensor on the CPU? Otherwise this code is bogus. Look at how the other tensor serializer writes out tensors.
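The concern is that the raw underlying buffer only matches the logical element order for a contiguous tensor. A torch-free illustration of the same idea, using a strided `memoryview` from the standard library:

```python
# A stride-2 view over 8 bytes: logically [0, 2, 4, 6], but the
# underlying buffer still holds all 8 bytes.
data = bytes(range(8))
view = memoryview(data)[::2]

# Dumping the raw buffer verbatim would also emit the skipped bytes;
# a serializer must first materialize the logical (contiguous) contents.
logical = view.tobytes()
```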

```cpp
}
AT_ERROR("Unknown class name for unpickler: ", str);
}
const static std::unordered_map<std::string, PicklerClass> name_to_class{
```
zdevito (Contributor):

What was wrong with how it was before? The other one is almost certainly faster.

zdevito (Contributor):

This is also missing an entry for LITERAL_TENSOR, right?

driazati (Contributor, Author):

I think it's a little more readable, and this isn't accessed often enough to be a performance bottleneck or anything.

LITERAL_TENSOR wasn't necessary for this PR; it's only needed to unpickle these tensors in C++, which isn't crucial since torch.load can read them in.

torch/serialization.py: outdated review comment, resolved
@driazati driazati requested a review from zdevito May 2, 2019 16:11
zdevito (Contributor) left a comment:

This is looking good. Two things before it is ready:

  1. A bug in pushGlobal that can cause memoization to fail.
  2. More tests around how tensors are serialized, because I can't tell from the implementation whether it is correct or not.

torch/csrc/jit/pickler.h: outdated review comment, resolved
torch/csrc/jit/pickler.h: outdated review comment, resolved

torch/csrc/jit/pickler.cpp: review comment, resolved
test/test_jit.py: review comment, resolved
@driazati driazati requested a review from zdevito May 6, 2019 23:30
zdevito (Contributor) left a comment:

Pretty much ready to go. I think there is a bug with string memoization; if you can resolve that, I will just look at that change and approve.

```cpp
void Pickler::pushGlobal(const std::string& name_temp) {
  memoized_strings_.push_back(name_temp);
  auto name = memoized_strings_.back();
  auto memo_entry = memo_map_.find(&(memoized_strings_.back()));
```
zdevito (Contributor):

This doesn't seem right. What is the intention here? `&memoized_strings_.back()` is a pointer to a position in memoized_strings_. Since the string was just inserted, it will never be found in memo_map_, and if memoized_strings_ gets reallocated, the pointer will point to bogus data. memo_map_ only really works for reference IValue types; memoizing non-IValue strings will probably require a hash map from string to memo id.

@driazati driazati requested a review from zdevito May 7, 2019 21:38
zdevito (Contributor) left a comment:

Looks good. Minor API comment.

```cpp
pushString(name_temp);

// Push BINPUT without adding anything to the memo_map_
pushMemoization(nullptr);
```
zdevito (Contributor):

This is a weird API. It decides not to update the memo_map_ in pushMemoization, and then carefully uses memo_id here (something secretly updated inside pushMemoization).

Suggestion:

```cpp
memoized_string_map_.insert({name_temp, pushBinPutNext()});
```

pushBinPutNext() pushes the BINPUT opcode with an incremented memo_id and then returns it. Then pushMemoization becomes:

```cpp
void pushMemoization(void* item) {
  memo_map_[item] = pushBinPutNext();
}
```
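As a sanity check of that shape, here is a toy Python version in which pushBinPutNext() emits BINPUT with the next memo id and returns that id, and each caller decides which table records it. `MiniPickler` and the snake_case names are illustrative, not the PR's C++ API:

```python
class MiniPickler:
    BINPUT = b"q"  # pickle's BINPUT opcode: memoize with a 1-byte index

    def __init__(self):
        self.out = bytearray()
        self.next_memo_id = 0
        self.memo_map = {}          # object identity -> memo id (for IValues)
        self.memoized_strings = {}  # string value -> memo id (for globals)

    def push_binput_next(self):
        # Emit BINPUT with the next memo id and return that id; nothing
        # is updated secretly, the caller records the id where it belongs.
        memo_id = self.next_memo_id
        self.next_memo_id += 1
        self.out += self.BINPUT + bytes([memo_id])  # ids > 255 need LONG_BINPUT
        return memo_id

    def push_memoization(self, item):
        self.memo_map[id(item)] = self.push_binput_next()

    def push_global(self, name):
        self.memoized_strings[name] = self.push_binput_next()

p = MiniPickler()
p.push_global("torch.jit._pickle")
obj = []
p.push_memoization(obj)
```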

zdevito pushed a commit to zdevito/ATen that referenced this pull request May 9, 2019
Summary:
This PR makes `torch.save` call out to the pickler, which saves a tensor in the same format that `torch.save()` does. The file looks like `| pickle archive 1 (includes sizes, strides, requires_grad, etc.) | pickle archive 2 (list of tensor keys) | tensor binary data |` and can be read back in with `torch.load(my_file, pickle_module=torch.jit._pickle)`.

Fixes #18003

Unpickling in the JIT, for things such as model parallelism, will be a follow-up PR.

Pull Request resolved: pytorch/pytorch#18154

Pulled By: driazati

Differential Revision: D15015160

fbshipit-source-id: ef76a44b8c243f4794cd7e245ec8305e965bc59f
facebook-github-bot (Contributor):

@driazati merged this pull request in 8ebb86d.

Labels: oncall: jit (Add this issue/PR to JIT oncall triage queue)

Successfully merging this pull request may close these issues:

  [JIT] Allow Serialization to be exportable for debugging

4 participants