Skip to content

Commit

Permalink
[PyTorch] faster empty_tensor_restride via smarter refresh_contiguous
Browse files Browse the repository at this point in the history
`empty_tensor_restride` already knows which memory format it has just
restrided to. Don't throw that information away.

Differential Revision: [D25539840](https://our.internmc.facebook.com/intern/diff/D25539840/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D25539840/)!

ghstack-source-id: 118540305
Pull Request resolved: #49349
  • Loading branch information
swolchok committed Dec 14, 2020
1 parent ad631b5 commit 93302e6
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion c10/core/TensorImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
}
// recompute contiguous flag, as currently NHWC/NCHW flags are not mutually
// exclusive see #24090
refresh_contiguous();
refresh_contiguous(memory_format);
}

bool is_strides_like_channels_last() const {
Expand Down Expand Up @@ -1540,6 +1540,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* or strides.
*/
void refresh_contiguous() {
// NOTE: Make sure to keep the other overload in sync with this implementation!
is_contiguous_ = compute_contiguous();
// Note:
// Dim 0, 1, 2 will never be a channels last 2d/3d format
Expand Down Expand Up @@ -1573,6 +1574,42 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
}
}

/**
 * Faster implementation of refresh_contiguous() that can be used if
 * we already know the MemoryFormat the tensor was just restrided to.
 *
 * Knowing the format lets us hard-code the flags the format fully
 * determines and only fall back to the general compute_* helpers where
 * the format leaves the answer open.
 */
void refresh_contiguous(MemoryFormat memory_format) {
  // NOTE: Make sure to keep the other overload in sync with this
  // implementation!
  //
  // Contiguous format implies is_contiguous_; otherwise recompute it.
  is_contiguous_ =
      memory_format == MemoryFormat::Contiguous || compute_contiguous();
  switch (memory_format) {
    case MemoryFormat::Contiguous:
      // A contiguous-restrided tensor cannot be channels-last (2d or 3d)
      // and is trivially non-overlapping and dense.
      is_channels_last_contiguous_ = false;
      is_channels_last_3d_contiguous_ = false;
      is_channels_last_ = false;
      is_channels_last_3d_ = false;
      is_non_overlapping_and_dense_ = true;
      break;
    case MemoryFormat::ChannelsLast:
      is_channels_last_contiguous_ = compute_channels_last_contiguous_2d();
      is_channels_last_ = true;
      is_channels_last_3d_ = false;
      is_channels_last_3d_contiguous_ = false;
      // Either contiguity flag already implies non-overlapping-and-dense;
      // only run the full check when neither fast path applies.
      is_non_overlapping_and_dense_ = is_contiguous_ ||
          is_channels_last_contiguous_ || compute_non_overlapping_and_dense();
      break;
    case MemoryFormat::ChannelsLast3d:
      is_channels_last_contiguous_ = false;
      is_channels_last_ = false;
      is_channels_last_3d_ = true;
      is_channels_last_3d_contiguous_ = compute_channels_last_contiguous_3d();
      // Mirror the ChannelsLast case: channels-last-3d contiguity also
      // implies non-overlapping-and-dense, so short-circuit on it before
      // paying for the full computation.
      is_non_overlapping_and_dense_ = is_contiguous_ ||
          is_channels_last_3d_contiguous_ ||
          compute_non_overlapping_and_dense();
      break;
    case MemoryFormat::Preserve:
      // Preserve is not expected here (callers restride to a concrete
      // format — TODO confirm); fall back to the general recomputation,
      // which overwrites every flag set above, to stay safe.
      refresh_contiguous();
      break;
  }
}

/**
* Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset)
* from one TensorImpl to another TensorImpl.
Expand Down

0 comments on commit 93302e6

Please sign in to comment.