
[JIT][Static Runtime] Memory optimization for output tensors #53867

Open
hlu1 opened this issue Mar 12, 2021 · 0 comments
Labels
oncall: jit Add this issue/PR to JIT oncall triage queue
hlu1 commented Mar 12, 2021

One of the unique features of Static Runtime is the MemoryPlanner, which aggregates all the memory allocations for the intermediate tensors into a single malloc and caches their TensorImpls inside Static Runtime. It speeds up inference by reducing the number of mallocs as well as the time spent creating/destroying Tensor objects and the associated refcount bumps on the fly. However, the MemoryPlanner only manages the intermediate tensors, which excludes model inputs and outputs. If we can extend the MemoryPlanner to include the output tensors, we can speed up models with multiple outputs dramatically.

First, we'll need some bookkeeping for the output tensors:

  std::vector<std::pair<size_t, std::vector<c10::StorageImpl*>>> managed_output_storage_;
  size_t managed_output_bytes_{0};
  at::DataPtr output_buffer_; // for outputs only

For implementation, see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/static/impl.cpp
Similar to the intermediates, the MemoryPlanner can only manage output tensors of ops with out variants. For ops without out variants, the output tensors are created dynamically by the op itself, and there is nothing the MemoryPlanner can do about them.

Do pay attention to aliases. We'll need to exclude model inputs and input aliases, and aliases of intermediate tensors and output tensors need to be handled carefully.

For testing, there are a lot of unit tests in https://github.com/pytorch/pytorch/blob/master/benchmarks/static_runtime/test_static_runtime.cc and https://github.com/pytorch/pytorch/blob/master/test/test_static_runtime.py.

cc @gmagogsfm

@hlu1 hlu1 added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 12, 2021
@github-actions github-actions bot added this to Need triage in JIT Triage Mar 12, 2021
@tugsbayasgalan tugsbayasgalan moved this from Need triage to Pending in JIT Triage Mar 16, 2021