TensorGeometry.h
#pragma once

#include <ATen/core/TensorBase.h>
#include <c10/core/WrapDimMinimal.h>

namespace at {

// Returns true if the tensor geometry represented by `sizes` and `strides` is
// contiguous. Although we cache is_contiguous in tensor now, this is still
// useful because it allows checking whether a particular geometry is
// contiguous without explicitly constructing a tensor, e.g., when you want to
// choose a kernel strategy based on whether a subgeometry is contiguous.
TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides);
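// Illustrative usage sketch of geometry_is_contiguous (not part of this
// header); the sizes and strides below are made-up example values:
//
//   std::vector<int64_t> sub_sizes = {4, 5};
//   std::vector<int64_t> sub_strides = {5, 1}; // row-major, hence contiguous
//   if (at::geometry_is_contiguous(sub_sizes, sub_strides)) {
//     // dispatch to a fast contiguous kernel
//   } else {
//     // fall back to a generic strided kernel
//   }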
struct TORCH_API TensorGeometry {
  TensorGeometry() = default;

  explicit TensorGeometry(c10::SymIntArrayRef sizes)
      : sizes_(sizes.vec()),
        strides_(sizes.size()),
        has_symbolic_sizes_strides_(
            !c10::asIntArrayRefSlowOpt(sizes).has_value()) {
    int64_t dim = static_cast<int64_t>(sizes.size());
    c10::SymInt expected_stride = 1;
    for (int64_t i = dim - 1; i >= 0; i--) {
      strides_[i] = expected_stride;
      expected_stride *= sizes_[i];
    }
    numel_ = expected_stride;
  }
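  // Example (illustrative): for sizes {2, 3, 4} the size-only constructor
  // above produces the default contiguous strides {12, 4, 1} and numel_ == 24,
  // filling strides from the innermost dimension outwards.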
  explicit TensorGeometry(const TensorBase& t)
      : sizes_(t.sym_sizes().vec()),
        strides_(t.sym_strides().vec()),
        storage_offset_(t.sym_storage_offset()),
        numel_(t.sym_numel()),
        has_symbolic_sizes_strides_(
            t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {}

  // true if the tensor is contiguous
  bool is_contiguous() const;

  int64_t dim() const {
    return static_cast<int64_t>(sizes_.size());
  }

  int64_t size(int64_t dim) const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return sizes_.at(static_cast<size_t>(dim)).as_int_unchecked();
  }
  c10::IntArrayRef sizes() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return c10::asIntArrayRefUnchecked(sizes_);
  }
  int64_t stride(int64_t dim) const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return strides_.at(static_cast<size_t>(dim)).as_int_unchecked();
  }
  c10::IntArrayRef strides() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return c10::asIntArrayRefUnchecked(strides_);
  }
  int64_t storage_offset() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return storage_offset_.as_int_unchecked();
  }
  int64_t numel() const {
    TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_);
    return numel_.as_int_unchecked();
  }

  c10::SymInt sym_size(int64_t dim) const {
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return sizes_.at(static_cast<size_t>(dim));
  }
  c10::SymIntArrayRef sym_sizes() const {
    return sizes_;
  }
  c10::SymInt sym_stride(int64_t dim) const {
    dim = c10::maybe_wrap_dim(dim, this->dim());
    return strides_.at(static_cast<size_t>(dim));
  }
  c10::SymIntArrayRef sym_strides() const {
    return strides_;
  }
  c10::SymInt sym_storage_offset() const {
    return storage_offset_;
  }
  c10::SymInt sym_numel() const {
    return numel_;
  }
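  // Illustrative sketch (not part of this header): the concrete accessors
  // above assert that no symbolic sizes/strides are present, while the sym_*
  // accessors work in either case, so generic code typically prefers them.
  //
  //   void log_geometry(const at::TensorGeometry& g) { // hypothetical helper
  //     for (int64_t d = 0; d < g.dim(); ++d) {
  //       std::cout << g.sym_size(d) << " x stride " << g.sym_stride(d) << "\n";
  //     }
  //   }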
  TensorGeometry transpose(int64_t dim0, int64_t dim1) {
    TensorGeometry r = *this; // copy
    TORCH_CHECK(
        dim0 < dim(),
        "transpose: dim0=",
        dim0,
        " out of range (dim=",
        dim(),
        ")")
    TORCH_CHECK(
        dim1 < dim(),
        "transpose: dim1=",
        dim1,
        " out of range (dim=",
        dim(),
        ")")
    std::swap(r.sizes_[dim0], r.sizes_[dim1]);
    std::swap(r.strides_[dim0], r.strides_[dim1]);
    return r;
  }
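  // Example (illustrative): transpose above only swaps metadata. For sizes
  // {2, 3} with strides {3, 1}, transpose(0, 1) returns a copy with sizes
  // {3, 2} and strides {1, 3}; the original geometry is left unchanged.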
  std::vector<c10::SymInt>& mutable_sizes() {
    return sizes_;
  }
  std::vector<c10::SymInt>& mutable_strides() {
    return strides_;
  }
  c10::SymInt& mutable_storage_offset() {
    return storage_offset_;
  }
  void recompute() {
    // recalculate numel after a change
    c10::SymInt numel = 1;
    for (const auto& i : sizes_) {
      numel = numel * i;
    }
    numel_ = std::move(numel);
    has_symbolic_sizes_strides_ =
        !c10::asIntArrayRefSlowOpt(sizes_).has_value();
  }
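  // Illustrative sketch (not part of this header): after editing a geometry in
  // place via the mutable_* accessors, call recompute() so numel_ and the
  // symbolic flag stay consistent with the new sizes.
  //
  //   at::TensorGeometry g(t);        // t is some at::TensorBase
  //   g.mutable_sizes().push_back(1); // e.g. append a trailing unit dimension
  //   g.mutable_strides().push_back(1);
  //   g.recompute();                  // refresh numel_ and the symbolic flag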
 private:
  std::vector<c10::SymInt> sizes_;
  std::vector<c10::SymInt> strides_;
  c10::SymInt storage_offset_;
  c10::SymInt numel_;
  bool has_symbolic_sizes_strides_{false};
};

} // namespace at