-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
TensorImpl.h
1825 lines (1654 loc) · 66.1 KB
/
TensorImpl.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#pragma once
#include <atomic>
#include <memory>
#include <numeric>
#include <c10/core/Backend.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorOptions.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/CopyBytes.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/Flags.h>
#include <c10/util/Logging.h>
#include <c10/util/python_stub.h>
// A global boolean flag controlling whether we free memory when a Tensor
// is shrunk to a smaller size. When enabled, a Tensor keeps the memory
// allocated for the largest capacity it has been reshaped to so far.
//
// This parameter is respected by "upper-case" methods which call Resize()
// (e.g., CopyFrom, ResizeLike); it is NOT respected by Tensor::resize_
// or ShrinkTo, both of which guarantee never to free memory.
C10_DECLARE_bool(caffe2_keep_on_shrink);
// Blob memory usage can vary widely across different inputs in the same
// run, so a blob is shrunk only if the memory gain (in bytes) is larger
// than this flag. This only applies to functions which respect
// caffe2_keep_on_shrink.
C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory);
namespace at {
class Tensor;
}
namespace c10 {
class Scalar;
struct Storage;
/**
 * Widens every `int` in `src` to int64_t and returns them, in order, in a
 * newly built std::vector<int64_t>.
 */
inline std::vector<int64_t> ToVectorint64_t(ArrayRef<int> src) {
  std::vector<int64_t> widened;
  widened.reserve(src.size());
  for (const int value : src) {
    widened.push_back(value);
  }
  return widened;
}
/**
 * Returns the product dims[k] * dims[k + 1] * ... * dims[dims.size() - 1].
 * The result is 1 when k >= dims.size(). Note that k is converted to size_t,
 * so a negative k becomes a huge index and also yields 1.
 */
inline int64_t size_from_dim_(int k, IntArrayRef dims) {
  int64_t product = 1;
  size_t idx = k;
  while (idx < dims.size()) {
    product *= dims[idx];
    ++idx;
  }
  return product;
}
// Returns the product dims[0] * ... * dims[k - 1]; dims[k] itself is excluded.
// k must lie in [0, dims.size()]. The unsigned cast in the check makes a
// negative k fail as well.
inline int64_t size_to_dim_(int k, IntArrayRef dims) {
  TORCH_CHECK((unsigned)k <= dims.size());
  int64_t product = 1;
  int idx = 0;
  while (idx < k) {
    product *= dims[idx];
    ++idx;
  }
  return product;
}
// Product of all dims strictly between k and l (dims[k] and dims[l] are
// excluded). k and l may be given in either order.
//
// Both k and l must be valid indices into dims, i.e. in [0, dims.size()).
// The (unsigned) casts make negative values fail the checks as well.
// Validating k is required: the loop reads dims[min(k, l) + 1] through
// dims[max(k, l) - 1], so an unchecked k (negative with k < l, or larger
// than dims.size()) would index outside dims.
inline int64_t size_between_dim_(int k, int l, IntArrayRef dims) {
  TORCH_CHECK((unsigned)l < dims.size());
  TORCH_CHECK((unsigned)k < dims.size());
  const int lo = k < l ? k : l;
  const int hi = k < l ? l : k;
  int64_t r = 1;
  for (int i = lo + 1; i < hi; ++i) {
    r *= dims[i];
  }
  return r;
}
// Maps an axis index in [-ndims, ndims) to its non-negative equivalent in
// [0, ndims), so that -1 refers to the last dimension. Any index outside
// [-ndims, ndims) fails a TORCH_CHECK.
inline int canonical_axis_index_(int axis_index, int ndims) {
  TORCH_CHECK(axis_index >= -ndims);
  TORCH_CHECK(axis_index < ndims);
  return axis_index < 0 ? axis_index + ndims : axis_index;
}
using PlacementDtor = void (*)(void*, size_t);
/*
* A Context that will call extra placement deleter during
* deconstruction.
*
* Accept a already constructed DataPtr and store it as member
* during destruction, we'll call extra deleter on the underlying
* data pointer before the DataPtr is destructed.
* `data_ptr_` owns the memory.
*/
struct C10_API PlacementDeleteContext {
DataPtr data_ptr_;
PlacementDtor placement_dtor_;
size_t size_;
PlacementDeleteContext(
DataPtr&& data_ptr,
PlacementDtor placement_dtor,
size_t size)
: data_ptr_(std::move(data_ptr)),
placement_dtor_(placement_dtor),
size_(size) {}
static DataPtr makeDataPtr(
DataPtr&& data_ptr,
PlacementDtor placement_dtor,
size_t size,
Device device);
~PlacementDeleteContext() {
placement_dtor_(data_ptr_.get(), size_);
// original memory will be freed when data_ptr_ is destructed
}
};
struct TensorImpl;
// Abstract interface to the autograd metadata a TensorImpl may carry
// (installed with TensorImpl::set_autograd_meta()). The concrete
// implementation lives outside this library; instances are created through
// impl::AutogradMetaFactory.
struct C10_API AutogradMetaInterface {
// Sets the requires-grad flag. `self_impl` is presumably the TensorImpl that
// owns this metadata (NOTE(review): confirm against the implementation).
virtual void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) = 0;
// Returns the current requires-grad flag.
virtual bool requires_grad() const = 0;
// Mutable access to the stored gradient tensor.
virtual at::Tensor& mutable_grad() = 0;
// Read-only access to the stored gradient tensor.
virtual const at::Tensor& grad() const = 0;
virtual ~AutogradMetaInterface();
};
namespace impl {
// Unfortunately, the definition of AutogradMeta lives in a different
// compilation unit from TensorImpl (libtorch.so versus libc10.so),
// which means that we cannot construct an AutogradMeta from TensorImpl,
// not even from the cpp file. So we have to indirect it through a factory
// function which will be initialized when we load libtorch.so.
struct C10_API AutogradMetaFactory {
virtual ~AutogradMetaFactory() = default;
// Creates a new AutogradMeta object, returned through its interface; the
// caller owns the result.
virtual std::unique_ptr<AutogradMetaInterface> make() const = 0;
// Returns a reference to an undefined at::Tensor. This lives on the factory
// only because at::Tensor (as opposed to TensorImpl) is an incomplete type
// in this header, so it cannot be produced here directly.
virtual const at::Tensor& undefined_tensor() const = 0;
};
// Setter and getter for the factory described above (defined out of line).
C10_API void SetAutogradMetaFactory(AutogradMetaFactory* factory);
C10_API AutogradMetaFactory* GetAutogradMetaFactory();
// Constructing an instance installs `factory` via SetAutogradMetaFactory().
// Presumably instantiated as a static object inside libtorch.so so that the
// factory is registered at load time (NOTE(review): confirm at call sites).
struct C10_API AutogradMetaFactoryRegisterer {
explicit AutogradMetaFactoryRegisterer(AutogradMetaFactory* factory) {
SetAutogradMetaFactory(factory);
}
};
} // namespace impl
// Abstract interface for named-tensor metadata attached to a TensorImpl
// (see TensorImpl::set_named_tensor_meta()). The default implementations of
// clone() and slow_dim() fail with an internal assertion, so subclasses are
// expected to override them.
struct C10_API NamedTensorMetaInterface {
  virtual ~NamedTensorMetaInterface() = default;
  // Returns a copy of this metadata object, owned by the caller.
  virtual std::unique_ptr<NamedTensorMetaInterface> clone() const {
    TORCH_INTERNAL_ASSERT(
        false,
        "Not implemented: NamedTensorMetaInterface::clone");
  }
  // Number of dimensions this metadata describes. DEBUG builds of
  // TensorImpl::set_named_tensor_meta() compare it against TensorImpl::dim().
  virtual int64_t slow_dim() const {
    TORCH_INTERNAL_ASSERT(
        false,
        "Not implemented: NamedTensorMetaInterface::slow_dim");
  }
};
// NOTE [ Version Counter Sharing ]
//
// Every Tensor has a version counter. Version counters are incremented whenever the
// data or size of a tensor changes through in-place Variable operations. Version
// counters are used to detect modifications to saved variables which would result in
// incorrect gradient calculations. Version counters may be shared between Variables:
//
// 1. A view shares the version counter of the base Variable,
// 2. `x.detach()` shares the version counter of `x`,
// 3. Unpacked saved variables share the version counter of the source.
//
// Version counters are not shared in these scenarios:
//
// 1. When we replace a `Variable`'s underlying `Tensor` by calling `set_data(...)`,
// 2. `x.data` does not share the version counter of `x`. (See discussion at
// https://github.com/pytorch/pytorch/issues/5396)
//
// Question: Why do we put the version counter in TensorImpl instead of AutogradMeta?
//
// Answer: After the Variable/Tensor merge, a tensor will not have AutogradMeta when
// its `requires_grad_` is false, but when we use this tensor in the forward pass of
// a function that requires saving this tensor for backward, we need to keep track of
// this tensor's version to make sure it's always valid in the autograd graph.
//
// To achieve this goal, we put the version counter in TensorImpl instead of AutogradMeta,
// and have it always be available. This allows us to have the optimization of not
// carrying AutogradMeta when a tensor doesn't require gradient.
//
// A hypothetical alternative way to achieve this goal is to initialize AutogradMeta and
// create the version counter for the non-requires-grad tensor only when it's saved for
// backward. However, since saving a tensor for backward happens in the forward pass, and
// our invariant is that forward pass needs to be thread-safe, lazy-initializing AutogradMeta
// when saving a tensor can introduce race conditions when we are running the forward
// pass in multi-thread scenarios, thus making the forward pass not thread-safe anymore,
// which breaks the invariant.
// Handle to a shared, atomically updated version counter (see
// NOTE [ Version Counter Sharing ]). The counter lives behind an
// intrusive_ptr, so copies of a VariableVersion refer to the same counter.
struct C10_API VariableVersion {
 private:
  // Refcounted holder for the atomic counter value.
  struct VersionCounter : intrusive_ptr_target {
    VersionCounter(uint32_t version) : version_(version) {}
    std::atomic<uint32_t> version_;
  };
  c10::intrusive_ptr<VersionCounter> version_counter_;

 public:
  // True when no other VariableVersion shares this counter.
  bool unique() const {
    return version_counter_.use_count() == 1;
  }
  // The counter always starts from an explicit value. As of C++11 and 14, a
  // default-constructed std::atomic is left in a persistently undefined
  // state; see https://cplusplus.github.io/LWG/issue2334.
  VariableVersion(uint32_t version = 0)
      : version_counter_(c10::make_intrusive<VersionCounter>(version)) {}
  // Atomically increments the shared version.
  void bump() noexcept {
    version_counter_->version_ += 1;
  }
  // Atomically reads the shared version.
  uint32_t current_version() const noexcept {
    return version_counter_->version_.load();
  }
};
/**
* The low-level representation of a tensor, which contains a pointer
* to a storage (which contains the actual data) and metadata (e.g., sizes and
* strides) describing this particular view of the data as a tensor.
*
* Some basic characteristics about our in-memory representation of
* tensors:
*
* - It contains a pointer to a storage struct (Storage/StorageImpl)
* which contains the pointer to the actual data and records the
* data type and device of the view. This allows multiple tensors
* to alias the same underlying data, which allows us to efficiently
* implement differing *views* on a tensor.
*
* - The tensor struct itself records view-specific metadata about
* the tensor, e.g., sizes, strides and offset into storage.
* Each view of a storage can have a different size or offset.
*
* - This class is intrusively refcounted. It is refcounted so that
* we can support prompt deallocation of large tensors; it is
* intrusively refcounted so that we can still perform reference
* counted operations on raw pointers, which is often more convenient
* when passing tensors across language boundaries.
*
* - For backwards-compatibility reasons, a tensor may be in an
* uninitialized state. A tensor may be uninitialized in the following
* two ways:
*
* - A tensor may be DTYPE UNINITIALIZED. A tensor of this
* form has an uninitialized dtype. This situation most
* frequently arises when a user writes Tensor x(CPU). The dtype
* is subsequently initialized when mutable_data<T>() is
* invoked for the first time.
*
* - A tensor may be STORAGE UNINITIALIZED. A tensor of this form
* has non-zero size, but has a storage with a null data pointer.
* This situation most frequently arises when a user calls
* Resize() or FreeMemory(). This is because Caffe2 historically
* does lazy allocation: allocation of data doesn't occur until
* mutable_data<T>() is invoked. A tensor with zero size is
* always storage initialized, because no allocation is necessary
* in this case.
*
* All combinations of these two uninitialized states are possible.
* Consider the following transcript in idiomatic Caffe2 API:
*
* Tensor x(CPU); // x is storage-initialized, dtype-UNINITIALIZED
* x.Resize(4); // x is storage-UNINITIALIZED, dtype-UNINITIALIZED
* x.mutable_data<float>(); // x is storage-initialized, dtype-initialized
* x.FreeMemory(); // x is storage-UNINITIALIZED, dtype-initialized.
*
* All other fields on tensor are always initialized. In particular,
* size is always valid. (Historically, a tensor declared as Tensor x(CPU)
* also had uninitialized size, encoded as numel == -1, but we have now
* decided to default to zero size, resulting in numel == 0).
*
* Uninitialized storages MUST be uniquely owned, to keep our model
* simple. Thus, we will reject operations which could cause an
* uninitialized storage to become shared (or a shared storage to
* become uninitialized, e.g., from FreeMemory).
*
* In practice, tensors which are storage-UNINITIALIZED and
* dtype-UNINITIALIZED are *extremely* ephemeral: essentially,
* after you do a Resize(), you basically always call mutable_data()
* immediately afterwards. Most functions are not designed to
* work if given a storage-UNINITIALIZED, dtype-UNINITIALIZED tensor.
*
* We intend to eliminate all uninitialized states, so that every
* tensor is fully initialized in all fields. Please do not write new code
* that depends on these uninitialized states.
*/
struct C10_API TensorImpl : public c10::intrusive_ptr_target {
// A TensorImpl can only be created through one of the constructors below.
TensorImpl() = delete;
/**
* Construct a 1-dim 0-size tensor backed by the given storage.
*/
TensorImpl(
Storage&& storage,
DispatchKeySet,
const caffe2::TypeMeta data_type);
/**
* Construct a 1-dim 0 size tensor that doesn't have a storage.
*/
TensorImpl(DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional<c10::Device> device_opt);
// Legacy constructors taking a single DispatchKey instead of a
// DispatchKeySet. Each wraps the key in a DispatchKeySet and delegates to
// the matching constructor above; they exist so that existing call sites
// don't need to be updated.
// TODO: When Variable is added, delete these constructors
TensorImpl(
Storage&& storage,
DispatchKey dispatch_key,
const caffe2::TypeMeta data_type)
: TensorImpl(
std::move(storage),
DispatchKeySet(dispatch_key),
data_type) {}
TensorImpl(DispatchKey dispatch_key, const caffe2::TypeMeta data_type, c10::optional<c10::Device> device_opt)
: TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {}
private:
// This constructor is private, because the data_type is redundant with
// storage. Still, we pass it in separately because it's easier to write
// the initializer list if we're not worried about storage being moved out
// from under us.
TensorImpl(Storage&& storage, DispatchKeySet, const caffe2::TypeMeta data_type, c10::optional<c10::Device>);
public:
// TensorImpl is not copyable; the move operations are defaulted.
TensorImpl(const TensorImpl&) = delete;
TensorImpl& operator=(const TensorImpl&) = delete;
TensorImpl(TensorImpl&&) = default;
TensorImpl& operator=(TensorImpl&&) = default;
/**
* Release (decref) storage, and any other external allocations. This
* override is for `intrusive_ptr_target` and is used to implement weak
* tensors. Defined out of line.
*/
virtual void release_resources() override;
/**
 * The complete set of DispatchKeys this tensor identifies as. Operator
 * dispatch on this tensor is driven by this set.
 */
DispatchKeySet key_set() const {
  return key_set_;
}
/**
* Return a reference to the sizes of this tensor. This reference remains
* valid as long as the tensor is live and not resized.
*/
virtual IntArrayRef sizes() const;
/**
* Return a reference to the strides of this tensor. This reference remains
* valid as long as the tensor is live and not restrided.
*/
virtual IntArrayRef strides() const;
/**
* Return the number of dimensions of this tensor. Note that 0-dimension
* represents a Tensor that is a Scalar, i.e., one that has a single element.
*/
virtual int64_t dim() const;
/**
* True if this tensor has storage. See storage() for details.
*/
virtual bool has_storage() const;
/**
* Return the underlying storage of a Tensor. Multiple tensors may share
* a single storage. A Storage is an impoverished, Tensor-like class
* which supports far fewer operations than Tensor.
*
* Avoid using this method if possible; try to use only Tensor APIs to perform
* operations.
*/
virtual const Storage& storage() const;
/**
 * Number of elements in the tensor.
 *
 * WARNING: with the old Caffe2 API, numel() == -1 signalled an
 * uninitialized tensor. That no longer holds: numel always accurately
 * reports the product of the tensor's sizes.
 */
virtual int64_t numel() const {
#ifdef DEBUG
  // Debug builds cross-check the cached count against a fresh computation.
  TORCH_INTERNAL_ASSERT(compute_numel() == numel_);
#endif
  return numel_;
}
bool unique_version() const {
return version_counter_.unique();
}
/**
* Whether or not a tensor is laid out in contiguous memory.
*
* Tensors with non-trivial strides are not contiguous. See
* compute_contiguous() for the exact definition of contiguity.
*
* `memory_format` selects the layout to test against; it defaults to
* at::MemoryFormat::Contiguous.
*/
virtual bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const;
// The predicates below are deliberately non-virtual. Each one only tests
// key_set_ for particular DispatchKeys, which avoids a virtual dispatch for
// performance reasons.
bool is_sparse() const {
  const auto& ks = key_set_;
  return ks.has(DispatchKey::SparseCPU) || ks.has(DispatchKey::SparseCUDA) ||
      ks.has(DispatchKey::SparseHIP);
}
bool is_quantized() const {
  const auto& ks = key_set_;
  return ks.has(DispatchKey::QuantizedCPU) ||
      ks.has(DispatchKey::QuantizedCUDA);
}
bool is_meta() const {
  const auto& ks = key_set_;
  return ks.has(DispatchKey::Meta);
}
bool is_cuda() const {
  const auto& ks = key_set_;
  return ks.has(DispatchKey::CUDA) || ks.has(DispatchKey::SparseCUDA) ||
      ks.has(DispatchKey::QuantizedCUDA);
}
bool is_hip() const {
  const auto& ks = key_set_;
  return ks.has(DispatchKey::HIP) || ks.has(DispatchKey::SparseHIP);
}
bool is_mkldnn() const {
  const auto& ks = key_set_;
  return ks.has(DispatchKey::MkldnnCPU);
}
bool is_vulkan() const {
  const auto& ks = key_set_;
  return ks.has(DispatchKey::Vulkan);
}
bool is_metal() const {
  const auto& ks = key_set_;
  return ks.has(DispatchKey::Metal);
}
// TODO: remove this once we don't automatically enabled Autograd dispatch keys
// in TensorImpl constructor.
// DON'T USE THIS API!! It's only created for testing purpose in
// file aten/src/ATen/core/boxing/impl/test_helpers.h
void remove_autograd_key() {
key_set_ = key_set_ - autograd_dispatch_keyset;
}
int64_t get_device() const {
TORCH_CHECK(
device_opt_.has_value(),
"tensor does not have a device");
// See NOTE [c10::optional operator usage in CUDA]
return (*device_opt_).index();
}
Device device() const {
TORCH_CHECK(
device_opt_.has_value(),
"tensor does not have a device");
// See NOTE [c10::optional operator usage in CUDA]
return *device_opt_;
}
// Layout derived from the dispatch keys: sparse first, then MKL-DNN,
// otherwise strided. Deliberately non-virtual to avoid a dispatch for perf.
Layout layout() const {
  if (is_sparse()) {
    return kSparse;
  }
  if (is_mkldnn()) {
    return kMkldnn;
  }
  return kStrided;
}
/**
 * True if this tensor was auto-wrapped from a C++ or Python number. For
 * example, in 't + 2' the literal 2 becomes a Tensor whose
 * `is_wrapped_number_` is true.
 *
 * When a mixed-type operation also involves Tensors that are not wrapped
 * numbers, wrapped numbers do not take part in the result type computation.
 * This lets 't + 2' work with any type of tensor, not just LongTensor (the
 * type Python integers map to). In every other respect they behave like
 * their non-wrapped equivalents. See [Result type computation] in
 * TensorIterator.h.
 *
 * Wrapping numbers, rather than adding an extra add(Tensor, Scalar)
 * overload, greatly reduces the code needed for ops like add: a
 * Tensor-Scalar addition is just a Tensor-Tensor addition with a 0-dim RHS
 * (promotion behavior aside).
 */
bool is_wrapped_number() const {
  const bool wrapped = is_wrapped_number_;
  return wrapped;
}
/**
 * Mark (or unmark) this tensor as auto-wrapped from a C++ or Python number.
 * Mostly useful for binding code; only 0-dim tensors may be marked.
 */
void set_wrapped_number(bool value) {
  TORCH_INTERNAL_ASSERT(dim() == 0);
  is_wrapped_number_ = value;
}
/**
 * Returns true if Tensor supports as_strided and as_strided_backward.
 * This is used in autograd to perform inplace update on view Tensors.
 * See Note [View + Inplace update for base tensor] and
 * [View + Inplace update for view tensor] for details.
 *
 * Note this method returns false only for the XLA backend: XLA simulates
 * strided Tensors to support most view ops, but it cannot fully support the
 * general `as_strided` case. Every other device type returns true.
 * The set of unsupported backends can be expanded as needed in the future,
 * e.g. sparse Tensors.
 *
 * Calls device(), so it fails if the tensor has no device.
 */
bool support_as_strided() const {
  return device().type() != at::kXLA;
}
// ~~~~~ Autograd API ~~~~~
// Some methods below are defined in TensorImpl.cpp because Tensor is an
// incomplete type. All four methods in this section are only declared here.
/**
* Set whether or not a tensor requires gradient.
*
* It is only valid to call this method on a Variable.
* See Note [Tensor versus Variable in C++].
*/
void set_requires_grad(bool requires_grad);
/**
* True if a tensor requires gradient. Tensors which require gradient
* have history tracked for any operations performed on them, so that
* we can automatically differentiate back to them. A tensor that
* requires gradient and has no history is a "leaf" tensor, which we
* accumulate gradients into.
*
* It is only valid to call this method on a Variable.
* See Note [Tensor versus Variable in C++].
*/
bool requires_grad() const;
/**
* Return a mutable reference to the gradient. This is conventionally
* used as `t.grad() = x` to set a gradient to a completely new tensor.
*
* It is only valid to call this method on a Variable.
* See Note [Tensor versus Variable in C++].
*/
at::Tensor& mutable_grad();
/**
* Return the accumulated gradient of a tensor. This gradient is written
* into when performing backwards, when this tensor is a leaf tensor.
*
* It is only valid to call this method on a Variable.
* See Note [Tensor versus Variable in C++].
*/
const at::Tensor& grad() const;
/**
 * Return a typed data pointer to the actual data which this tensor refers to.
 * This checks that the requested type (from the template parameter) matches
 * the internal type of the tensor.
 *
 * It is invalid to call data() on a dtype-uninitialized tensor, even if
 * the size is 0.
 *
 * If the storage's data pointer is null (possible for a tensor with zero
 * elements, which always counts as storage-initialized), nullptr is returned
 * without applying storage_offset_.
 *
 * WARNING: If a tensor is not contiguous, you MUST use strides when
 * performing index calculations to determine the location of elements in
 * the tensor. We recommend using 'TensorAccessor' to handle this computation
 * for you; this class is available from 'Tensor'.
 */
template <typename T>
inline T * data() const {
  TORCH_CHECK(has_storage(),
      "Cannot access data pointer of Tensor that doesn't have storage");
  TORCH_CHECK(
      storage_initialized(),
      "The tensor has a non-zero number of elements, but its data is not allocated yet. "
      "Caffe2 uses a lazy allocation, so you will need to call "
      "mutable_data() or raw_mutable_data() to actually allocate memory.");
  TORCH_CHECK(
      data_type_.Match<T>(),
      "Tensor type mismatch, caller expects elements to be ",
      caffe2::TypeMeta::TypeName<T>(),
      ", while tensor contains ",
      data_type_.name(),
      ". ");
  // We managed the type check ourselves
  T* const base = storage_.unsafe_data<T>();
  // A zero-element tensor can have a null data pointer together with a
  // non-zero storage_offset_ (e.g. a slice past the end). Adding a non-zero
  // offset to a null pointer is undefined behavior, so return nullptr as-is.
  if (base == nullptr) {
    return nullptr;
  }
  return base + storage_offset_;
}
/**
 * Return a void* data pointer to the actual data which this tensor refers to.
 *
 * It is invalid to call data() on a dtype-uninitialized tensor, even if the
 * size is 0.
 *
 * Unlike data<T>(), this does not require the storage to be initialized, so
 * the storage's data pointer may be null. In that case nullptr is returned
 * and storage_offset_ is not applied.
 *
 * WARNING: The data pointed to by this tensor may not contiguous; do NOT
 * assume that itemsize() * numel() is sufficient to compute the bytes that
 * can be validly read from this tensor.
 */
inline void* data() const {
  TORCH_CHECK(has_storage(),
      "Cannot access data pointer of Tensor that doesn't have storage");
  TORCH_CHECK(dtype_initialized(),
      "Cannot access data pointer of Tensor that doesn't have initialized dtype "
      "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data<T>() on x)");
  char* const base = static_cast<char*>(storage_.data());
  // The base may be null (e.g. a storage-UNINITIALIZED tensor, or one with
  // zero elements). Adding a non-zero byte offset to a null pointer is
  // undefined behavior, so hand back nullptr unchanged instead.
  if (base == nullptr) {
    return nullptr;
  }
  return static_cast<void*>(base + data_type_.itemsize() * storage_offset_);
}
/**
 * Unchecked counterpart of data<T>(): no storage, allocation, or dtype
 * checks are performed, so the caller must guarantee every invariant that
 * data<T>() would otherwise verify.
 */
template <typename T>
inline T * unsafe_data() const {
  T* const base = storage_.unsafe_data<T>();
  return base + storage_offset_;
}
/**
 * The TypeMeta describing this tensor's element type (e.g. int, float, ...).
 */
const caffe2::TypeMeta dtype() const {
  const caffe2::TypeMeta meta = data_type_;
  return meta;
}
/**
 * Size in bytes of a single element of this tensor. Fails if the dtype has
 * not been initialized yet.
 */
size_t itemsize() const {
  TORCH_CHECK(dtype_initialized(),
      "Cannot report itemsize of Tensor that doesn't have initialized dtype "
      "(e.g., caffe2::Tensor x(CPU), prior to calling mutable_data<T>() on x)");
  const size_t bytes_per_element = data_type_.itemsize();
  return bytes_per_element;
}
/**
 * Return the offset into the storage at which this tensor's data begins.
 * Most tensors have storage_offset() == 0, but, for example, an index into
 * a tensor will have a non-zero storage_offset().
 *
 * WARNING: This is counted in elements, NOT in bytes.
 *
 * This method is virtual, so subclasses may override it.
 */
virtual int64_t storage_offset() const {
  return storage_offset_;
}
/**
 * True when the tensor holds no elements, i.e. numel() == 0.
 */
inline bool is_empty() const {
  const int64_t element_count = numel();
  return element_count == 0;
}
/**
 * Change the size at some dimension. This DOES NOT update strides;
 * thus, most changes to size will not preserve contiguity. You probably
 * also want to call set_stride() when you call this.
 *
 * `dim` must be in [0, number of dimensions); negative (wrapped) indices are
 * not accepted. An out-of-range `dim` fails a TORCH_CHECK in every build
 * mode, instead of relying on the container's accessor to catch it.
 *
 * TODO: This should be jettisoned in favor of `set_sizes_and_strides`,
 * which is harder to misuse.
 */
virtual void set_size(int64_t dim, int64_t new_size) {
  TORCH_CHECK(allow_tensor_metadata_change(), "set_size ", err_msg_tensor_metadata_change_not_allowed);
  TORCH_CHECK(
      dim >= 0 && static_cast<size_t>(dim) < sizes_.size(),
      "set_size: dimension ", dim, " is out of range for a tensor with ",
      sizes_.size(), " dimensions");
  sizes_.at(dim) = new_size;
  // Recompute state derived from the sizes.
  refresh_numel();
  refresh_contiguous();
}
/**
 * Change the stride at some dimension.
 *
 * `dim` must be in [0, number of dimensions); negative (wrapped) indices are
 * not accepted. An out-of-range `dim` fails a TORCH_CHECK instead of writing
 * outside strides_ (the previous unchecked operator[] access was undefined
 * behavior for such inputs).
 *
 * TODO: This should be jettisoned in favor of `set_sizes_and_strides`,
 * which is harder to misuse.
 */
virtual void set_stride(int64_t dim, int64_t new_stride) {
  TORCH_CHECK(allow_tensor_metadata_change(), "set_stride ", err_msg_tensor_metadata_change_not_allowed);
  TORCH_CHECK(
      dim >= 0 && static_cast<size_t>(dim) < strides_.size(),
      "set_stride: dimension ", dim, " is out of range for a tensor with ",
      strides_.size(), " dimensions");
  strides_[dim] = new_stride;
  refresh_contiguous();
}
/**
* Set the offset into the storage of this tensor.
*
* WARNING: This does NOT check if the tensor is in bounds for the new
* location at the storage; the caller is responsible for checking this
* (and resizing if necessary.)
*/
virtual void set_storage_offset(int64_t storage_offset) {
TORCH_CHECK(allow_tensor_metadata_change(), "set_storage_offset ", err_msg_tensor_metadata_change_not_allowed);
storage_offset_ = storage_offset;
}
/**
 * Replace the sizes with `new_size` and recompute strides for
 * MemoryFormat::Contiguous (the set_sizes_and_strides analogue for
 * contiguous tensors).
 *
 * WARNING: No check is made that the new sizes/strides fit inside the
 * allocated storage; that is the caller's responsibility.
 */
void set_sizes_contiguous(IntArrayRef new_size) {
  TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous ", err_msg_tensor_metadata_change_not_allowed);
  const size_t ndim = new_size.size();
  sizes_.resize(ndim);
  size_t pos = 0;
  for (const int64_t extent : new_size) {
    sizes_[pos] = extent;
    ++pos;
  }
  refresh_numel();
  empty_tensor_restride(MemoryFormat::Contiguous);
}
/**
 * Replace both the sizes and the strides of this tensor. `new_size` and
 * `new_stride` must have the same length.
 *
 * A negative entry in `new_stride` is not stored as given. The innermost
 * dimension gets stride 1; any other dimension gets
 * max(size of the next dimension, 1) * (stride of the next dimension).
 *
 * WARNING: No check is made that the new sizes/strides fit inside the
 * allocated storage; that is the caller's responsibility.
 */
void set_sizes_and_strides(IntArrayRef new_size, IntArrayRef new_stride) {
  TORCH_CHECK(allow_tensor_metadata_change(), "set_sizes_and_strides ", err_msg_tensor_metadata_change_not_allowed);
  TORCH_CHECK(
      new_size.size() == new_stride.size(),
      "dimensionality of sizes (",
      new_size.size(),
      ") must match dimensionality of strides (",
      new_stride.size(),
      ")");
  const size_t ndim = new_size.size();
  sizes_.resize(ndim);
  for (size_t i = 0; i < ndim; ++i) {
    sizes_[i] = new_size[i];
  }
  strides_.resize(ndim);
  // Walk from the innermost dimension outwards, so that strides_[i + 1] is
  // already final whenever a negative stride at i must be replaced.
  for (size_t i = ndim; i-- > 0;) {
    const int64_t requested = new_stride[i];
    if (requested >= 0) {
      strides_[i] = requested;
    } else if (i + 1 == ndim) {
      // XXX: This behavior is surprising and may need to be removed to
      // support negative strides. Some pytorch functions rely on it:
      // for example, torch.cat (run TestTorch.test_cat_empty).
      strides_[i] = 1;
    } else {
      // Keep stride monotonically increasing to match NumPy.
      strides_[i] = std::max<int64_t>(sizes_[i + 1], 1) * strides_[i + 1];
    }
  }
  refresh_numel();
  refresh_contiguous();
}
/**
* Return the size of a tensor at dimension `d`. Declared here and defined
* out of line; virtual so subclasses may override it.
*/
virtual int64_t size(int64_t d) const;
/**
* Return the stride of a tensor at dimension `d`. Declared here and defined
* out of line; virtual so subclasses may override it.
*/
virtual int64_t stride(int64_t d) const;
/**
* Set whether a tensor allows changes to its metadata (e.g. sizes / strides / storage / storage_offset).
* See NOTE [ Metadata Change for a Detached Tensor ] for details.
*/
void set_allow_tensor_metadata_change(bool value) {
allow_tensor_metadata_change_ = value;
}
/**
* True if a tensor allows changes to its metadata (e.g. sizes / strides / storage / storage_offset).
* See NOTE [ Metadata Change for a Detached Tensor ] for details.
*/
bool allow_tensor_metadata_change() const {
return allow_tensor_metadata_change_;
}
/**
* Set the pointer to autograd metadata. Ownership of `autograd_meta` is
* passed in through the std::unique_ptr. Defined out of line.
*/
void set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta);
/**
* Return the pointer to autograd metadata. May return nullptr if the
* tensor does not track gradients. Presumably non-owning, since the setter
* takes ownership (NOTE(review): confirm in TensorImpl.cpp).
*/
c10::AutogradMetaInterface* autograd_meta() const;
/**
 * Attach (or, with nullptr, detach) named tensor metadata. The
 * DispatchKey::Named key is added to key_set_ exactly when metadata is
 * present afterwards. Emits a one-time warning that named tensors are
 * experimental.
 */
void set_named_tensor_meta(std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta) {
  TORCH_WARN_ONCE(
      "Named tensors and all their associated APIs are an experimental feature ",
      "and subject to change. Please do not use them for anything important ",
      "until they are released as stable.");
#ifdef DEBUG
  if (named_tensor_meta) {
    TORCH_INTERNAL_ASSERT(named_tensor_meta->slow_dim() == dim());
  }
#endif
  named_tensor_meta_ = std::move(named_tensor_meta);
  key_set_ = named_tensor_meta_ ? key_set_.add(DispatchKey::Named)
                                : key_set_.remove(DispatchKey::Named);
}
/**
 * Return the named tensor metadata, or nullptr if none is attached.
 *
 * The returned pointer is a non-owning view obtained via .get(); this
 * TensorImpl retains ownership. Both const and non-const overloads exist.
 */
const c10::NamedTensorMetaInterface* named_tensor_meta() const {
  return this->named_tensor_meta_.get();
}
c10::NamedTensorMetaInterface* named_tensor_meta() {
  return this->named_tensor_meta_.get();
}
bool has_named_tensor_meta() {
return named_tensor_meta_ != nullptr;
}
// NOTE [ TensorImpl Shallow-Copying ]
//
// TensorImpl shallow-copying is used when we want to have two Variables share the same tensor metadata
// (e.g. sizes / strides / storage pointer / storage_offset), but each with a different autograd history.
// Example call sites:
//
// 1. `var_detached = var.detach()` uses `shallow_copy_and_detach()` to create `var_detached` that shares
// the same tensor metadata with `var`, but with a completely new autograd history.
// 2. `var.set_data(tensor)` uses `shallow_copy_from()` to copy tensor metadata from
// `tensor` into `var`, while keeping `var`'s original AutogradMeta.
//
// Functions that shallow-copy a TensorImpl (such as `shallow_copy_and_detach()` / `shallow_copy_from()` /
// `copy_tensor_metadata()`) copy the tensor metadata fields (e.g. sizes / strides / storage pointer /
// storage_offset) by value. However, the following fields are not copied:
//
// 1. the AutogradMeta pointer, because it is unique for each Variable.
// 2. the version counter, because the destination TensorImpl's version counter is either set to the
// passed-in `version_counter` (in `shallow_copy_and_detach()` and `copy_tensor_metadata()`), or it is kept
// intact (in `shallow_copy_from()`). See NOTE [ Version Counter Sharing ] for details.
//
// In `shallow_copy_and_detach()` and `copy_tensor_metadata()`, the passed-in `allow_tensor_metadata_change`
// determines whether the TensorImpl shallow-copy allows changes to its metadata (e.g. sizes / strides /
// storage / storage_offset). See NOTE [ Metadata Change for a Detached Tensor ] for details.
//
// In `shallow_copy_from()`, we don't check the destination TensorImpl's `allow_tensor_metadata_change_`,
// because `shallow_copy_from()` is used for implementing functions such as `var.set_data(tensor)`, which
// changes `var`'s tensor metadata and expects its `allow_tensor_metadata_change_` to be ignored.
/**
 * One TensorImpl can be copied to another TensorImpl if they have the same
 * DispatchKeySet. The only special cases (for legacy reasons) are:
 *  - dense backends are mutually compatible: a key set containing CPU, CUDA
 *    or HIP is compatible with any other key set containing one of those;
 *  - sparse backends are mutually compatible: a key set containing
 *    SparseCPU, SparseCUDA or SparseHIP is compatible with any other key set
 *    containing one of those.
 *
 * @param from the DispatchKeySet of the TensorImpl being copied from.
 * @return true if shallow-copying from a TensorImpl with key set `from`
 *         into this TensorImpl is allowed.
 *
 * Const-qualified because it only reads key_set_.
 */
inline bool has_compatible_shallow_copy_type(DispatchKeySet from) const {
  // True if the key set contains any dense backend key.
  auto is_dense = [](DispatchKeySet ts) {
    return ts.has(DispatchKey::CPU) ||
        ts.has(DispatchKey::CUDA) ||
        ts.has(DispatchKey::HIP);
  };
  // True if the key set contains any sparse backend key.
  auto is_sparse = [](DispatchKeySet ts) {
    return ts.has(DispatchKey::SparseCPU) ||
        ts.has(DispatchKey::SparseCUDA) ||
        ts.has(DispatchKey::SparseHIP);
  };
  return (key_set_ == from) ||
      (is_dense(key_set_) && is_dense(from)) ||
      (is_sparse(key_set_) && is_sparse(from));
}
/**
 * Return a TensorImpl that is a shallow-copy of this TensorImpl.
 *
 * This overload receives the version counter by const reference (the
 * rvalue-reference overload below lets callers move it in instead).
 * Declared virtual so subclasses can produce a copy of their own type.
 *
 * For usage of `version_counter` and `allow_tensor_metadata_change`,
 * see NOTE [ TensorImpl Shallow-Copying ].
 */
virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const;
/**
 * Return a TensorImpl that is a shallow-copy of this TensorImpl.
 *
 * This overload receives the version counter by rvalue reference, so a
 * temporary counter can presumably be moved into the copy rather than
 * copied. Confirm against the out-of-line definition.
 *
 * For usage of `version_counter` and `allow_tensor_metadata_change`,
 * see NOTE [ TensorImpl Shallow-Copying ].
 */
virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const;
/**
 * Shallow-copy tensor metadata from `impl` into this TensorImpl.
 *
 * This TensorImpl's own version counter and metadata-change flag are passed
 * back into copy_tensor_metadata(), so both are kept rather than taken
 * from `impl`. For why this function does not check
 * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ].
 *
 * @param impl the TensorImpl whose metadata is copied.
 */
virtual void shallow_copy_from(const c10::intrusive_ptr<TensorImpl>& impl) {
  TensorImpl* const source = impl.get();
  const bool keep_metadata_flag = allow_tensor_metadata_change();
  copy_tensor_metadata(
      /*src_impl=*/source,
      /*dest_impl=*/this,
      /*version_counter=*/version_counter(),
      /*allow_tensor_metadata_change=*/keep_metadata_flag);
  // Recompute the cached numel / contiguity after the metadata copy.
  refresh_numel();
  refresh_contiguous();
}
/**
 * Replace this tensor's version counter with a copy of `version_counter`.
 */
void set_version_counter(const c10::VariableVersion& version_counter) noexcept {
  this->version_counter_ = version_counter;
}
/**
 * Replace this tensor's version counter, moving from `version_counter`.
 */
void set_version_counter(c10::VariableVersion&& version_counter) noexcept {
  this->version_counter_ = std::move(version_counter);
}
/**
 * Return this tensor's version counter by const reference (no copy).
 */
const c10::VariableVersion& version_counter() const noexcept {
  return this->version_counter_;
}
void bump_version() noexcept {
version_counter_.bump();
}
/**
 * Store the associated Python object pointer. Only the raw pointer is
 * recorded here; no reference counting happens in this setter.
 */
inline void set_pyobj(PyObject* pyobj) noexcept {
  this->pyobj_ = pyobj;
}
/**
 * Return the Python object pointer last stored by set_pyobj().
 */
inline PyObject* pyobj() const noexcept {
  return this->pyobj_;
}
private:
// See NOTE [c10::optional operator usage in CUDA]
// Kept private for now: we probably don't want to expose this publicly
// until the note is addressed. Returns a copy of the stored optional device.
c10::optional<c10::Device> device_opt() const {
  return this->device_opt_;
}
public: