-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Move version_counter_ to TensorImpl #18223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4d65bce
cd378ec
bb95573
07c897a
6a4c907
c6d5609
7212564
31581a6
919d522
8021735
6222ad0
03910ac
a351fba
f6e416e
9ce6b5a
bec5c06
b020120
2f466db
1c9151a
27917a4
03e6486
9c54ad3
fdb5680
8750f5e
fb8f996
35c0806
0128a17
c93251a
92603b4
9a568b7
379da62
c65a081
2e94856
1204b19
940ab21
bbd8104
21cb700
0e1f9f9
eb8f1d5
939349a
60f520e
9cadb1e
9778abd
e0f8e9d
7e19358
5ee8185
1983ca5
5504057
75ee500
d08f490
34c6e1e
107b58d
f089c6e
54f79c4
68814a3
1ba87d5
e573d35
6cdf7f4
f424cf1
46e22e4
49b5104
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,6 +138,61 @@ struct C10_API AutogradMetaInterface { | |
virtual ~AutogradMetaInterface(); | ||
}; | ||
|
||
// NOTE [ Version Counter Sharing ] | ||
// | ||
// Every Tensor has a version counter. Version counters are incremented whenever the | ||
// data or size of a tensor changes through in-place Variable operations. Version | ||
// counters are used to detect modifications to saved variables which would result in | ||
// incorrect gradient calculations. Version counters may be shared between Variables: | ||
// | ||
// 1. A view shares the version counter of the base Variable, | ||
// 2. `x.detach()` shares the version counter of `x`, | ||
// 3. Unpacked saved variables share the version counter of the source. | ||
// | ||
// Version counters are not shared in these scenarios: | ||
// | ||
// 1. When we replace a `Variable`'s underlying `Tensor` by calling `set_data(...)`, | ||
// 2. `x.data` does not share the version counter of `x`. (See discussion at | ||
// https://github.com/pytorch/pytorch/issues/5396) | ||
// | ||
// Question: Why do we put the version counter in TensorImpl instead of AutogradMeta? | ||
// | ||
// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta when | ||
// its `requires_grad_` is false, but when we use this tensor in the forward pass of | ||
// a function that requires saving this tensor for backward, we need to keep track of | ||
// this tensor's version to make sure it's always valid in the autograd graph. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This note will be difficult for future readers to understand. The reason it is difficult to understand is that it is making a comment relative to a change, but in the future, no one will see the change — they will see the code as-is! Sometimes knowing the historical context is helpful to know why code is set up this way, but in this case, the issue should be explained from first principles. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I will remove the "After the Variable/Tensor merge" phrase after the merge is completed, change the future tense to present tense, and rework the comment to make it easier to understand. |
||
// | ||
// To achieve this goal, we put the version counter in TensorImpl instead of AutogradMeta, | ||
// and have it always be available. This allows us to have the optimization of not | ||
// carrying AutogradMeta when a tensor doesn't require gradient. | ||
// | ||
// A hypothetical alternative way to achieve this goal is to initialize AutogradMeta and | ||
// create the version counter for the non-requires-grad tensor only when it's saved for | ||
// backward. However, since saving a tensor for backward happens in the forward pass, and | ||
// our invariant is that forward pass needs to be thread-safe, lazy-initializing AutogradMeta | ||
// when saving a tensor can introduce race conditions when we are running the forward | ||
// pass in multi-thread scenarios, thus making the forward pass not thread-safe anymore, | ||
// which breaks the invariant. | ||
struct C10_API VariableVersion { | ||
public: | ||
// NOTE: As of C++11 and 14, default-constructing a std::atomic variable | ||
// leaves it in a persistently undefined state. See | ||
// https://cplusplus.github.io/LWG/issue2334. | ||
VariableVersion(uint32_t version = 0) | ||
: version_block_(std::make_shared<std::atomic<uint32_t>>(version)) {} | ||
|
||
void bump() noexcept { | ||
version_block_->fetch_add(1); | ||
} | ||
|
||
uint32_t current_version() const noexcept { | ||
return version_block_->load(); | ||
} | ||
|
||
private: | ||
std::shared_ptr<std::atomic<uint32_t>> version_block_; | ||
}; | ||
|
||
/** | ||
* The low-level representation of a tensor, which contains a pointer | ||
* to a storage (which contains the actual data) and metadata (e.g., sizes and | ||
|
@@ -845,13 +900,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |
return std::move(autograd_meta_); | ||
} | ||
|
||
// NOTE: `shallow_copy_and_detach()` does not copy the AutogradMeta pointer | ||
// because it is unique for each Variable. | ||
// NOTE: `shallow_copy_and_detach()` does not copy the following TensorImpl fields: | ||
// 1. the AutogradMeta pointer, because it is unique for each Variable. | ||
// 2. the version counter, because although it lives in TensorImpl, the version counter is managed | ||
// by autograd, and the call sites of `shallow_copy_and_detach()` (from autograd) should decide what | ||
// the version counter should be for each new TensorImpl. See NOTE [ Version Counter Sharing ] for details. | ||
// | ||
// NOTE: We don't set `allow_tensor_metadata_change_` to false here, because there are call sites | ||
// to this function that need to change the shallow copy's size or storage afterwards, and setting | ||
// `allow_tensor_metadata_change_` to false would prevent those changes from happening and is | ||
// undesirable. | ||
virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach() const { | ||
AT_ASSERT(!is_variable()); // TODO: remove this when Variable and Tensor are merged | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
auto impl = c10::make_intrusive<TensorImpl>(Storage(storage()), type_id(), is_variable()); | ||
impl->set_sizes_and_strides(sizes(), strides()); | ||
impl->storage_offset_ = storage_offset_; | ||
|
@@ -862,6 +922,19 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |
return impl; | ||
} | ||
|
||
void set_version_counter( | ||
const c10::VariableVersion& version_counter) noexcept { | ||
version_counter_ = version_counter; | ||
} | ||
|
||
// Read-only access to this tensor's version counter; callers may copy the
// returned value to share the underlying counter.
const c10::VariableVersion& version_counter() const noexcept {
  return version_counter_;
}
|
||
// Increment this tensor's version. Per NOTE [ Version Counter Sharing ],
// this is done whenever the tensor's data or size changes in place, so that
// autograd can detect stale saved variables.
void bump_version() noexcept {
  version_counter_.bump();
}
|
||
// Store a non-owning pointer to the Python wrapper object
// (the `pyobj_` field is a weak reference; no refcount is taken here).
inline void set_pyobj(PyObject* pyobj) noexcept {
  pyobj_ = pyobj;
}
|
@@ -1384,6 +1457,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |
// at a time). | ||
std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr; | ||
|
||
c10::VariableVersion version_counter_; | ||
|
||
PyObject* pyobj_ = nullptr; // weak reference | ||
|
||
// We could save a word or two by combining the SmallVector structs, | ||
|
@@ -1471,6 +1546,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |
// weak refcount | ||
// storage pointer | ||
// autograd metadata pointer | ||
// version counter (word 0) | ||
// version counter (word 1) | ||
// PyObject pointer | ||
// sizes SmallVector (begin) | ||
// sizes SmallVector (end) | ||
|
@@ -1495,7 +1572,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { | |
// miscellaneous bitfield | ||
// | ||
// Compile-time guard: on 64-bit platforms TensorImpl must remain exactly
// 29 words (the version counter accounts for two of them: the two pointers
// inside its shared_ptr). The first disjunct makes the check a no-op on
// platforms where pointers are not 64-bit.
static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit...
              sizeof(TensorImpl) == sizeof(int64_t) * 29,
              "You changed the size of TensorImpl on 64-bit arch."
              "See Note [TensorImpl size constraints] on how to proceed.");
|
||
|
This file was deleted.
Uh oh!
There was an error while loading. Please reload this page.