-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
shape_analysis.cpp
2200 lines (2077 loc) · 89.1 KB
/
shape_analysis.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/autograd/variable.h>
#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
#include <exception>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
// Pull the c10 prim:: symbols (prim::If, prim::Loop, prim::Constant, ...) into
// torch::jit::prim so they can be written as prim::X in this file.
namespace prim {
using namespace ::c10::prim;
}
// Exception used to abandon shape propagation for a single node; the node's
// outputs are then given unshaped types (see PropagateShapeOnBlock).
struct propagation_error : public std::exception {};
// Throws propagation_error when `cond` is false. PropagateShapeOnBlock catches
// propagation_error and gives the offending node unshaped output types.
// The do/while(0) wrapper makes the macro a single statement, so it cannot
// capture an `else` that follows it (dangling-else with a bare `if`).
#define SHAPE_ASSERT(cond)       \
  do {                           \
    if (!(cond))                 \
      throw propagation_error(); \
  } while (0)
namespace {
// Returns true if `v` may be fed to an operator when shapes are inferred by
// actually running it (see PropagateShapeOnNodeByRunningIt):
//  - constants are always accepted;
//  - tensors must have a complete type, because representativeValue can only
//    materialize complete tensor types, and must not be of an integral
//    scalar type (bool is accepted);
//  - any other value is accepted only if it is a float.
bool isValidArgumentForRunning(Value* v) {
  // allow constants
  if (toIValue(v))
    return true;
  if (TensorTypePtr tt = v->type()->cast<TensorType>()) {
    // representativeValue() throws for tensor types that are not complete,
    // so reject them here instead of failing while running the op.
    if (!tt->scalarType() || !tt->isComplete()) {
      return false;
    }
    // Integral inputs often indicate indexing, and a zero representative
    // tensor would cause divide-by-zero in integer division.
    return !at::isIntegralType(*tt->scalarType(), /*includeBool=*/false);
  }
  return v->type()->isSubtypeOf(FloatType::get());
}
// An output of an op that is run for shape inference must be either a Tensor
// or a number; anything else disqualifies the op.
bool isValidReturnForRunning(Value* v) {
  const TypePtr& ty = v->type();
  if (ty->isSubtypeOf(TensorType::get())) {
    return true;
  }
  return ty->isSubtypeOf(NumberType::get());
}
bool containsTensorType(const TypePtr& t) {
auto n_contained = t->containedTypes().size();
if (n_contained == 1) {
return t->containedTypes().at(0)->isSubtypeOf(TensorType::get());
} else if (n_contained > 1) {
return std::any_of(
t->containedTypes().begin(),
t->containedTypes().end(),
containsTensorType);
}
return false;
}
class ShapePropagator {
public:
// Runs alias analysis over `graph` and, before any propagation happens,
// records in resized_alias_set every input that a resizing node writes to.
explicit ShapePropagator(const std::shared_ptr<Graph>& graph)
    : aliasDb_(graph) {
  Block* const root_block = graph->block();
  collectResizeSet(root_block);
}
// Propagates shapes through each node of `block`, in order.
// A propagation_error raised for a node gives that node unshaped output
// types; any other std::exception is rethrown as an ErrorReport attached to
// the node's source range.
void PropagateShapeOnBlock(Block* block, bool insert_expands = true) {
  for (Node* current : block->nodes()) {
    try {
      PropagateShapeOnNode(current, insert_expands);
    } catch (const propagation_error&) {
      setUnshapedType(current);
    } catch (const std::exception& ex) {
      throw ErrorReport(current->sourceRange())
          << ExceptionMessage(ex)
          << "\nThe above operation failed shape propagation in this context";
    }
  }
}
private:
// Inputs that a possibly-resizing node (see resizesInput) writes to; filled by
// collectResizeSet from the constructor and checked at the start of
// PropagateShapeOnNode via setUnshapedTypeIfAliasResizedSet.
ValueSet resized_alias_set;
// Alias analysis for the graph being propagated, built in the constructor.
const AliasDb aliasDb_;
// Returns true if `n` may change the size of one of its inputs in place:
// either its kind is in a fixed list of resizing ops, or its schema has a
// keyword-only Tensor argument named "out" that receives the result.
bool resizesInput(Node* n) {
  static std::unordered_set<Symbol> resize_ops{
      aten::resize_,
      aten::resize_as_,
      aten::copy_,
      aten::set_,
      aten::unsqueeze_,
      aten::t_,
      aten::transpose_,
  };
  if (resize_ops.count(n->kind()) != 0) {
    return true;
  }
  if (!n->maybeSchema()) {
    return false;
  }
  const auto& schema = n->schema();
  // ops which take the result and write to input "out"
  const auto out_arg_index = schema.argumentIndexWithName("out");
  if (!out_arg_index) {
    return false;
  }
  const auto& out_arg = schema.arguments().at(*out_arg_index);
  return out_arg.kwarg_only() &&
      out_arg.type()->isSubtypeOf(TensorType::get());
}
// Walks `block` and all nested blocks; for every node that may resize its
// inputs, adds to resized_alias_set each input that alias analysis says the
// node writes to.
void collectResizeSet(Block* block) {
  for (Node* n : block->nodes()) {
    for (Block* nested : n->blocks()) {
      collectResizeSet(nested);
    }
    if (!resizesInput(n)) {
      continue;
    }
    for (Value* input : n->inputs()) {
      if (aliasDb_.writesToAlias(n, {input})) {
        resized_alias_set.insert(input);
      }
    }
  }
}
// Replaces the type of `o` with unshapedType() of its current type.
void setUnshapedType(Value* o) {
  auto stripped = unshapedType(o->type());
  o->setType(std::move(stripped));
}
// Applies setUnshapedType to every output of `node`.
void setUnshapedType(Node* node) {
  for (Value* out : node->outputs()) {
    setUnshapedType(out);
  }
}
// Turns a negative dimension index into its non-negative equivalent by adding
// the rank (sizes.size()). The result is not range-checked; callers must.
int64_t wrapDim(int64_t dim, at::IntArrayRef sizes) {
  return dim < 0 ? dim + static_cast<int64_t>(sizes.size()) : dim;
}
// TODO: Would be better to make JIT not assume that CUDA devices
// are the only thing that exist.
// Maps a JIT device index to a Device: -1 is the CPU, any other value is
// taken as a CUDA device ordinal.
static at::Device jitDeviceIndexToDevice(int device) {
  if (device == -1) {
    return at::kCPU;
  }
  return at::Device(at::kCUDA, device);
}
// Returns a value that can stand in for `v` when running an op to infer its
// output types:
//  - the constant itself, if `v` is a constant;
//  - a zero-filled tensor with `v`'s dtype, device, sizes and strides, if `v`
//    has a complete TensorType;
//  - 0.f, if `v` is a float.
// Throws std::runtime_error for anything else (isValidArgumentForRunning is
// expected to have filtered such inputs out).
IValue representativeValue(Value* v) {
  TypePtr type_ = v->type();
  // if the value is actually constant, just use it!
  if (auto iv = toIValue(v)) {
    return *iv;
  }
  if (TensorTypePtr type = type_->cast<TensorType>()) {
    if (type->isComplete()) {
      // Build the options from the recorded device rather than choosing
      // between at::CPU and at::CUDA, which would create a CUDA tensor for
      // every non-CPU device type.
      const auto options = at::TensorOptions()
                               .dtype(*type->scalarType())
                               .device(*type->device());
      at::DeviceGuard device_guard(*type->device());
      return at::empty_strided(
                 *type->sizes().concrete_sizes(),
                 *type->strides().concrete_sizes(),
                 options)
          .zero_();
    }
    // fallthrough
  } else if (type_->isSubtypeOf(FloatType::get())) {
    return 0.f;
  }
  // we should not get here because isValidArgumentForRunning should have
  // prevented it
  std::stringstream ss;
  ss << "unable to create representative value for: " << type_->str()
     << ". File a bug report";
  throw std::runtime_error(ss.str());
}
// Collects, in schema order, the TensorType of the value passed for every
// Tensor argument of `node`'s schema; non-Tensor arguments are skipped.
// Returns c10::nullopt when the node has no schema, the schema is vararg, any
// argument is a Tensor[] list, a Tensor argument's value is not typed as a
// TensorType, or (when `complete` is true) such a type is not complete.
c10::optional<std::vector<TensorTypePtr>> gatherTensorTypes(
    Node* node,
    bool complete = false) {
  auto schema_ptr = node->maybeSchema();
  // can't handle varargs primitives because we don't know what should be a
  // Tensor
  if (!schema_ptr || schema_ptr->is_vararg()) {
    return c10::nullopt;
  }
  const auto& args = schema_ptr->arguments();
  std::vector<TensorTypePtr> tensor_types;
  for (size_t idx = 0; idx < args.size(); ++idx) {
    const auto& arg_type = args[idx].type();
    if (arg_type->isSubtypeOf(ListType::ofTensors())) {
      return c10::nullopt;
    }
    if (!arg_type->isSubtypeOf(TensorType::get())) {
      continue; // non-tensor argument
    }
    auto value_type = node->input(idx)->type()->cast<TensorType>();
    if (!value_type || (complete && !value_type->isComplete())) {
      return c10::nullopt;
    }
    tensor_types.push_back(std::move(value_type));
  }
  return tensor_types;
}
// Folds `next` into the running promotion result `original`; Undefined acts as
// "nothing seen yet" and is simply replaced.
c10::ScalarType unionScalarTypes(
    c10::ScalarType original,
    c10::ScalarType next) {
  return original == c10::ScalarType::Undefined
      ? next
      : c10::promoteTypes(original, next);
}
// Promotes result types for arithmetic operations on Tensor operands using
// new type promotion logic. See tensor_attributes.rst for details.
// This doesn't handle the case of arithmetic ops with Scalar arguments (when
// `Tensor.getUnsafeTensorImpl()->is_wrapped_number()` would return true).
//
// Only the first two inputs are inspected (a third input, when present, is
// alpha). Categories are ordered bool < integral < floating < complex; the
// dimensioned tensors decide the dtype unless a zero-dim tensor belongs to a
// higher category.
// Returns c10::nullopt if either input is not a TensorType with a known
// scalar type, or when the result dtype is not derived here (see below).
// NOTE(review): inputs of unknown rank are treated as zero-dim, which can give
// the wrong dtype if they are actually dimensioned; confirm with callers.
c10::optional<c10::ScalarType> getPromotedTypeForArithmeticOp(Node* node) {
  c10::ScalarType dimmed = c10::ScalarType::Undefined;
  c10::ScalarType zerodim = c10::ScalarType::Undefined;
  // binary arithmetic ops, more than 2 args is alpha.
  for (size_t i = 0; i < 2; i++) {
    // cast<> (not expect<>) so a non-tensor input yields nullopt, and test
    // the pointer before dereferencing it.
    auto dtt = node->inputs()[i]->type()->cast<TensorType>();
    if (!dtt || !dtt->scalarType()) {
      return c10::nullopt;
    }
    const c10::ScalarType inputDtype = *dtt->scalarType();
    if (dtt->dim() && *dtt->dim() > 0) {
      dimmed = unionScalarTypes(dimmed, inputDtype);
    } else {
      // no dimensions (or unknown rank)
      zerodim = unionScalarTypes(zerodim, inputDtype);
    }
  }
  // complex is the highest category: a complex dimensioned tensor decides.
  if (isComplexType(dimmed)) {
    return dimmed;
  }
  if (isComplexType(zerodim)) {
    // floating_tensor * zero_dim_complex -> complex with the floating
    // tensor's precision, which is not derived here; report an unknown dtype
    // instead of the (wrong) floating dtype.
    if (isFloatingType(dimmed)) {
      return c10::nullopt;
    }
    // int/bool_tensor * zero_dim_complex -> complex; this also covers the
    // case with no dimensioned tensors.
    return zerodim;
  }
  // a floating dimensioned tensor is now of the highest remaining category,
  // so zero-dim tensors cannot change the result.
  if (isFloatingType(dimmed)) {
    return dimmed;
  }
  // int_tensor * zero_dim_floating -> floating_tensor
  if (isIntegralType(dimmed, false) && isFloatingType(zerodim)) {
    return zerodim;
  }
  // bool_tensor * non_bool_scalar -> non_bool_tensor
  if (c10::ScalarType::Bool == dimmed &&
      c10::ScalarType::Undefined != zerodim) {
    return zerodim;
  }
  // types of dimensioned tensors generally take precedence over zero-dim
  // tensors if not promoting due to category. e.g.:
  // int_tensor * long -> int_tensor
  if (c10::ScalarType::Undefined != dimmed) {
    return dimmed;
  }
  // no dimmed tensors. e.g. zero_dim_tensor + zero_dim_tensor.
  return zerodim;
}
// For each position i, sets outputs[i] to the unification of lhs[i] and
// rhs[i] (defaulting to Any when they do not unify). Returns true if any
// output's type changed. `outputs` may be the same values as `lhs` (as in
// the prim::Loop fixed point): each lhs[i] is read before outputs[i] is set.
bool mergeTypes(
    ArrayRef<Value*> lhs,
    ArrayRef<Value*> rhs,
    ArrayRef<Value*> outputs) {
  AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
  bool any_changed = false;
  for (size_t idx = 0; idx < lhs.size(); ++idx) {
    Value* out = outputs[idx];
    const TypePtr previous = out->type();
    auto unified =
        unifyTypes(lhs[idx]->type(), rhs[idx]->type(), /*default_to_any=*/true);
    AT_ASSERT(unified);
    out->setType(*unified);
    any_changed = any_changed || (*previous != *out->type());
  }
  return any_changed;
}
// Makes node's inputs at positions idx1 and idx2 explicitly broadcast to
// their common size by inserting aten::expand nodes before `node` where a
// size differs, then refreshes the matching entries of `types`.
// Requires both types to have concrete sizes. `types` is indexed like the
// node's inputs: the same index is used for types.at() and
// node->inputs().at().
void broadcastBinary(
    Node* node,
    std::vector<TensorTypePtr>& types,
    size_t idx1,
    size_t idx2) {
  auto expected_size = at::infer_size(
      *types[idx1]->sizes().concrete_sizes(),
      *types[idx2]->sizes().concrete_sizes());
  auto broadcast = [&](size_t input_idx) {
    TensorTypePtr input_type = types.at(input_idx);
    if (input_type->sizes() == expected_size)
      return;
    auto graph = node->owningGraph();
    WithInsertPoint point_guard{node};
    Node* expand = graph
                       ->create(
                           aten::expand,
                           {node->inputs().at(input_idx),
                            graph->insertConstant(expected_size),
                            graph->insertConstant(false)})
                       ->insertBefore(node);
    PropagateShapeOnNode(expand);
    node->replaceInput(input_idx, expand->output());
  };
  broadcast(idx1);
  broadcast(idx2);
  // Write the refreshed types back at idx1/idx2 (not at 0/1), otherwise
  // unrelated entries would be overwritten whenever the indices are not
  // (0, 1).
  types[idx1] = node->inputs().at(idx1)->type()->expect<TensorType>();
  types[idx2] = node->inputs().at(idx2)->type()->expect<TensorType>();
}
// Ops that canPropagateShapeByRunningIt always rejects, so their shapes are
// never inferred by running them on representative zero-filled inputs
// (presumably because these linear-algebra ops fail on singular inputs —
// confirm).
OperatorSet cannot_propagate_shape_by_running_it{
    "aten::solve(Tensor self, Tensor A) -> (Tensor, Tensor)",
    "aten::inverse(Tensor self) -> Tensor",
};
// Check if this node depends on a value that has been mutated previously. If
// it has, then it's not safe to run this node in isolation, since we don't
// know whether the dependency has been executed.
// Results are memoized per node in dependsOnMutationMemo_.
std::unordered_map<Node*, bool> dependsOnMutationMemo_;
bool dependsOnMutation(Node* node) {
  const auto cached = dependsOnMutationMemo_.find(node);
  if (cached != dependsOnMutationMemo_.end()) {
    return cached->second;
  }
  bool result = false;
  if (aliasDb_.hasWriters(node)) {
    // If something could have written to a value used by this node, we can't
    // guarantee the result is the same when running it in isolation.
    result = true;
  } else {
    // recursively check the producers of its inputs. We need to do this if
    // the mutable value has been laundered through a pure function:
    //   a += 1
    //   c = a + b
    //   d = c + 1
    // In this case, `d` cares whether `a` has been mutated even though it's
    // not a direct input. Every producer is visited so that each one's
    // result gets memoized.
    for (Value* input : node->inputs()) {
      const bool input_depends = dependsOnMutation(input->node());
      result = result || input_depends;
    }
  }
  dependsOnMutationMemo_[node] = result;
  return result;
}
// Returns true if output types of `node` may be inferred by executing it on
// representative inputs. The op must not be in
// cannot_propagate_shape_by_running_it and must not depend on a possibly
// mutated value. Every input must pass isValidArgumentForRunning and every
// output must pass isValidReturnForRunning.
bool canPropagateShapeByRunningIt(Node* node) {
  if (node->isMemberOf(cannot_propagate_shape_by_running_it) ||
      dependsOnMutation(node)) {
    return false;
  }
  const auto inputs = node->inputs();
  const auto outputs = node->outputs();
  return std::all_of(inputs.begin(), inputs.end(), isValidArgumentForRunning) &&
      std::all_of(outputs.begin(), outputs.end(), isValidReturnForRunning);
}
// If there's no Tensor in outputs, e.g float / float,
// we don't need to propagate shape.
// Returns true when containsTensorType (which inspects each output type's
// contained types) is false for every output of `node`.
bool DoesntRefineOutputs(Node* node) {
  const auto outs = node->outputs();
  return std::none_of(outs.begin(), outs.end(), [](Value* out) {
    return containsTensorType(out->type());
  });
}
// Infers output types by executing `op` (or the node's own operation when `op`
// is null) on representative input values. Returns false, without running
// anything, if canPropagateShapeByRunningIt rejects the node. Otherwise
// every output typed as a TensorType whose produced value is a tensor gets
// the type of that tensor (keeping the output's previous requires_grad);
// other outputs are untouched. Returns true.
bool PropagateShapeOnNodeByRunningIt(Node* node, Operation op = nullptr) {
  if (!canPropagateShapeByRunningIt(node)) {
    return false;
  }
  if (!op) {
    op = node->getOperation();
  }
  Stack stack;
  for (Value* input : node->inputs()) {
    stack.push_back(representativeValue(input));
  }
  // XXX: we're not catching any exceptions from the op for now. This
  // is to uncover any mistakes we could make when editing this code,
  // and eventually it shouldn't matter, because this phase should be
  // preceded by schema checking.
  op(&stack);
  AT_ASSERT(stack.size() == node->outputs().size());
  for (size_t i = 0; i < stack.size(); ++i) {
    const IValue& produced = stack[i];
    Value* out = node->outputs()[i];
    auto declared = out->type()->cast<TensorType>();
    // some ops may have mixed tensor/primitive outputs; primitives already
    // have their most constrained type, so they are left alone.
    if (!produced.isTensor() || !declared) {
      continue;
    }
    // gradient information isn't always available or part of representative
    // inputs, so keep the original requires_grad property.
    const auto requires_grad = declared->requiresGrad();
    out->setType(
        TensorType::create(produced.toTensor())->withRequiresGrad(requires_grad));
  }
  return true;
}
// Sets the output type of aten::cat / prim::FusedConcat.
// The complete path applies when:
//  - every concatenated value has a complete TensorType;
//  - `dim` is a constant within range;
//  - all inputs have the same rank and the same sizes except along `dim`.
// The output then gets the first input's type with the sizes along `dim`
// summed. Otherwise it gets the dimensioned-only type of the first value
// typed as a TensorType. If neither applies, or (for aten::cat) the tensor
// list is not produced by prim::ListConstruct, the output is made unshaped.
void PropagateCatShape(Node* cat_node) {
  // Deliberately not `static`: this lambda captures `this`. A function-local
  // static is initialized once, so it would keep the `this` of the first
  // ShapePropagator to reach this line and call wrapDim through a dangling
  // pointer from every later instance.
  const auto propagate_complete =
      [this](Node* node, at::ArrayRef<Value*> tensors) -> bool {
    auto input_types =
        fmap(tensors, [](Value* v) { return v->type()->cast<TensorType>(); });
    if (!std::all_of(
            input_types.begin(),
            input_types.end(),
            [](const TensorTypePtr& tp) {
              return tp != nullptr && tp->isComplete();
            })) {
      return false;
    }
    if (!node->is_constant(attr::dim))
      return false;
    std::vector<int64_t> sizes = *input_types[0]->sizes().concrete_sizes();
    const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
    const int64_t ndim = sizes.size();
    if (dim < 0 || dim >= ndim)
      return false;
    sizes[dim] = 0;
    for (auto& tp : input_types) {
      auto tp_sizes = tp->sizes().concrete_sizes().value();
      if (sizes.size() != tp_sizes.size())
        return false;
      for (int64_t i = 0; i < ndim; ++i) {
        if (sizes[i] != tp_sizes[i] && i != dim) {
          return false;
        }
      }
      sizes[dim] += tp_sizes[dim];
    }
    node->output()->setType(input_types[0]->withSizes(sizes));
    return true;
  };
  // Captures nothing, so being static is safe here.
  static const auto propagate = [](Node* node,
                                   at::ArrayRef<Value*> tensors) -> bool {
    for (Value* v : tensors) {
      if (auto type = v->type()->cast<TensorType>()) {
        node->output()->setType(type->dimensionedOnly());
        return true;
      }
    }
    return false;
  };
  auto list_node =
      ((cat_node->kind() == prim::FusedConcat)
           ? cat_node
           : cat_node->namedInput(attr::tensors)->node());
  if (list_node->kind() == prim::ListConstruct ||
      cat_node->kind() == prim::FusedConcat) {
    auto tensors = list_node->inputs();
    if (!tensors.empty()) {
      if (propagate_complete(cat_node, tensors)) {
        return;
      } else if (propagate(cat_node, tensors)) {
        return;
      }
    }
  }
  setUnshapedType(cat_node);
}
// Sets the output type of aten::tensor / aten::as_tensor when the first input
// is a (possibly nested) list:
//  - rank: the list nesting depth;
//  - dtype: the `dtype` argument when not None, else derived from the
//    innermost element type;
//  - device: the `device` argument when not None, else CPU.
// If `dtype` or `device` exists but is not a constant, the output type is left
// unchanged.
void propagateTorchTensorShape(Node* node) {
  // Peel off nested list types to find the rank and innermost element type.
  size_t dims = 0;
  TypePtr input_base_type = node->inputs().at(0)->type();
  for (auto list_type = input_base_type->cast<ListType>(); list_type;
       list_type = input_base_type->cast<ListType>()) {
    ++dims;
    input_base_type = list_type->getElementType();
  }
  at::optional<at::ScalarType> default_type =
      tryScalarTypeFromJitType(input_base_type);
  if (auto dtype_index = node->schema().argumentIndexWithName("dtype")) {
    auto dtype_arg = toIValue(node->inputs().at(*dtype_index));
    if (!dtype_arg) {
      return; // dtype is not a constant
    }
    if (!dtype_arg->isNone()) {
      default_type = dtype_arg->toScalarType();
    }
  }
  at::Device default_device = at::kCPU;
  if (auto device_index = node->schema().argumentIndexWithName("device")) {
    auto device_arg = toIValue(node->inputs().at(*device_index));
    if (!device_arg) {
      return; // device is not a constant
    }
    if (!device_arg->isNone()) {
      default_device = device_arg->toDevice();
    }
  }
  node->output()->setType(TensorType::create(
      default_type, default_device, dims, /*requires_grad=*/c10::nullopt));
}
// Gives every value in `vs` that may alias a member of resized_alias_set its
// unshaped type. Returns whether any such value was found.
bool setUnshapedTypeIfAliasResizedSet(at::ArrayRef<Value*> vs) {
  bool found = false;
  for (Value* candidate : vs) {
    if (!aliasDb_.mayAlias(ValueSet{candidate}, resized_alias_set)) {
      continue;
    }
    setUnshapedType(candidate);
    found = true;
  }
  return found;
}
// Propagates type and shape information through a single node.
// Steps, in order:
//  1. If any input may alias a member of resized_alias_set, every output is
//     made unshaped.
//  2. Kinds with dedicated handling in the switch below (control flow, number
//     conversions, tuples, constants, optional unwrapping, calls, attribute
//     access, ...) are handled there and return early.
//  3. Nodes with side effects keep their current output types.
//  4. aten::cat and prim::FusedConcat go to PropagateCatShape.
//  5. If every Tensor argument has a complete type,
//     PropagateCompleteShapeOnNode is tried; then PropagateTensorShapeOnNode.
//  6. If DoesntRefineOutputs holds, stop; otherwise try running the op, and
//     finally make the outputs unshaped.
// `insert_expands` says whether expand nodes may be inserted; it is forwarded
// to PropagateCompleteShapeOnNode and PropagateTensorShapeOnNode.
void PropagateShapeOnNode(Node* node, bool insert_expands = true) {
// Certain ops like resize_ change the input tensors size. Because our
// analysis is flow invariant, we set any Tensor that can alias a resized
// Tensor to the base Tensor Type without size information.
if (setUnshapedTypeIfAliasResizedSet(node->inputs())) {
return setUnshapedType(node);
}
// These don't require the types, and have complicated schema. Return early
// after we process them.
switch (node->kind()) {
case prim::If: {
// Propagate through both branches; each output of the If takes the
// unification of the two branch outputs (mergeTypes defaults to Any).
auto then_block = node->blocks().at(0);
auto else_block = node->blocks().at(1);
PropagateShapeOnBlock(then_block);
PropagateShapeOnBlock(else_block);
mergeTypes(
then_block->outputs(), else_block->outputs(), node->outputs());
return;
}
case prim::Loop: {
auto body_block = node->blocks().at(0);
// propagate counter type
body_block->inputs().at(0)->setType(node->inputs().at(0)->type());
// propagate loop-carried input types to block inputs
auto loop_carried_inputs = node->inputs().slice(2); // skip max, cond
auto loop_carried_block = body_block->inputs().slice(1); // skip trip
for (size_t i = 0; i < loop_carried_inputs.size(); ++i) {
loop_carried_block[i]->setType(loop_carried_inputs[i]->type());
}
auto loop_carried_outputs = body_block->outputs().slice(1); // skip cond
// Iterate to a fixed point: re-propagate the body until merging the body
// outputs into the loop-carried block inputs changes no type.
do {
PropagateShapeOnBlock(body_block, /*insert_expands=*/false);
// note: inserting expands is unsafe at this point, we don't know
// if the types are stable yet, so the arguments to expand may change
} while (mergeTypes(
loop_carried_block, loop_carried_outputs, loop_carried_block));
// now that the types are stable, we can insert the expands
PropagateShapeOnBlock(body_block, /*insert_expands=*/true);
for (size_t i = 0; i < loop_carried_inputs.size(); ++i) {
node->outputs()[i]->setType(loop_carried_block[i]->type());
}
return;
}
case aten::Bool:
case aten::Int:
case aten::Float:
case aten::ScalarImplicit:
case aten::FloatImplicit:
case aten::IntImplicit:
return; // correct num type is already set
case prim::NumToTensor: {
// int/bool -> 0-dim CPU Long tensor; float -> 0-dim CPU Double tensor;
// any other input type leaves the output type unchanged.
TypePtr typ = node->input()->type();
if (typ->isSubtypeOf(IntType::get()) ||
typ->isSubtypeOf(BoolType::get())) {
node->output()->setType(TensorType::create(
at::kLong, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
} else if (node->input()->type()->isSubtypeOf(FloatType::get())) {
node->output()->setType(TensorType::create(
at::kDouble, at::kCPU, 0, /*requires_grad=*/c10::nullopt));
}
return;
}
case aten::tensor:
case aten::as_tensor: {
// as_tensor has an overloaded schema and can either have a tensor or
// a list as the first input, if the input is a tensor, we delegate
// the shape propagation in PropagateTensorShapeOnNode
if (node->inputs().at(0)->type()->isSubtypeOf(TensorType::get())) {
break;
}
return propagateTorchTensorShape(node);
}
case prim::TupleConstruct: {
// We refresh the tuple type, because the input types could have been
// refined.
auto orig_type = node->output()->type()->expect<TupleType>();
auto new_types =
fmap(node->inputs(), [](Value* v) { return v->type(); });
node->output()->setType(
orig_type->createWithContained(std::move(new_types)));
return;
}
case prim::TupleUnpack: {
// Each output takes the type of the corresponding tuple element.
auto tuple_type = node->input()->type()->cast<TupleType>();
AT_ASSERT(
tuple_type &&
tuple_type->elements().size() == node->outputs().size());
auto elems = tuple_type->elements();
for (size_t i = 0; i < node->outputs().size(); ++i) {
node->output(i)->setType(elems[i]);
}
return;
}
case prim::Constant: {
// Tensor constants take their type from the stored tensor value.
if (node->output()->type()->isSubtypeOf(TensorType::get())) {
node->output()->inferTypeFrom(node->t(attr::value));
}
return;
}
case prim::unchecked_unwrap_optional: {
// If we have specialized the optional type to the element type,
// we want to pass it down. We write this as input.isSubtypeOf(output)
// to be sure that we don't screw up nested optionals.
if (node->input()->type()->isSubtypeOf(node->output()->type())) {
node->output()->setType(node->input()->type());
}
return;
}
case prim::ConstantChunk: {
// Every chunk keeps the input's dimensioned-only type (sizes dropped).
Value* tensor = node->input();
if (auto type = tensor->type()->cast<TensorType>()) {
type = type->dimensionedOnly();
for (Value* output : node->outputs()) {
output->setType(type);
}
} else {
setUnshapedType(node);
}
return;
}
case prim::grad: {
// NOTE(review): `tt` is otherwise unused; expect<TensorType>() appears to
// serve only as a check that the input is a Tensor.
auto tt = node->input()->type()->expect<TensorType>();
// grad may be undefined
// requires_grad may be required
auto grad_type = TensorType::get()->withPossiblyUndefined();
node->output()->setType(grad_type);
return;
}
case prim::CallFunction:
case prim::CallMethod:
case prim::AutogradZero: {
// Callees are not analyzed here; outputs become unshaped.
setUnshapedType(node);
return;
}
case prim::GetAttr: {
auto cls = node->input()->type()->expect<ClassType>();
// propagate any type specializations encoded in the type of the class
node->output()->setType(cls->getAttribute(node->s(attr::name)));
return;
}
case aten::_unwrap_optional: {
// If we have specialized the optional type to the element type,
// we want to pass it down. We write this as input.isSubtypeOf(output)
// to be sure that we don't screw up nested optionals.
if (node->input()->type()->isSubtypeOf(node->output()->type())) {
node->output()->setType(node->input()->type());
}
return;
}
default:
break; // fall-through
}
// Nodes with side effects keep their current output types.
if (node->hasSideEffects()) {
return;
}
if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor") ||
node->kind() == prim::FusedConcat) {
return PropagateCatShape(node);
}
if (auto maybe_complete_types =
gatherTensorTypes(node, /*complete=*/true)) {
if (PropagateCompleteShapeOnNode(
node, insert_expands, std::move(*maybe_complete_types))) {
return;
}
}
if (PropagateTensorShapeOnNode(node, insert_expands)) {
return;
}
if (DoesntRefineOutputs(node)) {
return;
}
if (PropagateShapeOnNodeByRunningIt(node)) {
return;
}
// Nothing could refine the outputs; drop shape information.
return setUnshapedType(node);
}
// Returns the number of elements of the list value `list` when it is known
// statically: either `list` is a constant int list, or it is produced by a
// prim::ListConstruct node. Returns c10::nullopt otherwise.
static c10::optional<size_t> determineListSize(Value* list) {
  AT_ASSERT(list->type()->cast<ListType>());
  if (auto constant_list = constant_as<c10::List<int64_t>>(list)) {
    return constant_list->size();
  }
  Node* producer = list->node();
  if (producer->kind() != prim::ListConstruct) {
    return c10::nullopt;
  }
  return producer->inputs().size();
}
// When is it OK to try to run an op to infer its output types?
// If an input is a constant, then we assume that the input is valid
// and we can try to run it.
// Otherwise:
// Integral-typed _inputs_ are often an indicator that we're indexing into
// a tensor, so we should special-case these ops in the shape propagation.
// Additionally, passing a zero representative tensor into an integer
// division op causes divide-by-zero errors.
// _Outputs_ must be tensors or primitives.
// We will call inferTypeFrom on the tensors, and ignore the primitives.
// However, we allow primitive returns because we want to support mixed
// primitive/tensor outputs.
bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) {
static const auto broadcast =
[](std::vector<TensorTypePtr>& tensor_types,
c10::optional<at::ScalarType> t) -> TensorTypePtr {
if (tensor_types.size() == 1) {
return tensor_types[0]->dimensionedOnly()->withScalarType(t);
}
AT_ASSERT(!tensor_types.empty());
auto any_type = tensor_types[0];
auto max_dims = any_type->dim();
for (auto& type : tensor_types) {
if (!max_dims || !type->dim()) {
max_dims = c10::nullopt;
} else {
max_dims = std::max(*max_dims, *type->dim());
}
}
return TensorType::create(
t,
any_type->device(),
max_dims,
/*requires_grad=*/c10::nullopt);
};
using type_vec_t = std::vector<TensorTypePtr>;
// Formula is expected to return a vector of length equal to the number of
// tensor outputs of the node, or an empty vector which implies that it
// failed to propagate.
using formula_t = std::function<type_vec_t(Node*)>;
static std::mutex shape_formulas_mutex;
static std::vector<std::pair<OperatorSet, formula_t>> shape_formulas;
struct register_formula_for {
register_formula_for(OperatorSet operators, formula_t formula) {
std::unique_lock<std::mutex> lock{shape_formulas_mutex};
shape_formulas.emplace_back(std::move(operators), std::move(formula));
}
};
// Requirements:
// dims : preserved
// scalar type : preserved
// device : preserved
// tensor inputs : 1
// tensor outputs : 1
// Additionally:
// - First input should be the only tensor input
static const register_formula_for simple_unary_ops{
{
"aten::acos(Tensor self) -> Tensor",
"aten::neg(Tensor self) -> Tensor",
"aten::t(Tensor self) -> Tensor",
"aten::sigmoid(Tensor self) -> Tensor",
"aten::logit(Tensor self, float? eps=None) -> Tensor",
"aten::tanh(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
"aten::asin(Tensor self) -> Tensor",
"aten::atan(Tensor self) -> Tensor",
"aten::ceil(Tensor self) -> Tensor",
"aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
"aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)",
"aten::bernoulli(Tensor self, *, Generator? generator) -> Tensor",
"aten::celu(Tensor self, Scalar alpha) -> Tensor",
"aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
"aten::clamp_max(Tensor self, Scalar max) -> Tensor",
"aten::clamp_min(Tensor self, Scalar min) -> Tensor",
"aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
"aten::bernoulli(Tensor self, float p, *, Generator? generator) -> Tensor",
"aten::cos(Tensor self) -> Tensor",
"aten::cosh(Tensor self) -> Tensor",
"aten::digamma(Tensor self) -> Tensor",
"aten::dropout(Tensor input, float p, bool train) -> Tensor",
"aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
"aten::erf(Tensor self) -> Tensor",
"aten::erfc(Tensor self) -> Tensor",
"aten::erfinv(Tensor self) -> Tensor",
"aten::exp(Tensor self) -> Tensor",
"aten::expm1(Tensor self) -> Tensor",
"aten::log(Tensor self) -> Tensor",
"aten::log10(Tensor self) -> Tensor",
"aten::log1p(Tensor self) -> Tensor",
"aten::log2(Tensor self) -> Tensor",
"aten::log_sigmoid(Tensor self) -> Tensor",
"aten::floor(Tensor self) -> Tensor",
"aten::frac(Tensor self) -> Tensor",
"aten::flip(Tensor self, int[] dims) -> Tensor",
"aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
"aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
"aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
"aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
"aten::glu(Tensor self, int dim) -> Tensor",
"aten::inverse(Tensor self) -> Tensor",
"aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
"aten::lgamma(Tensor self) -> Tensor",
"aten::mvlgamma(Tensor self, int p) -> Tensor",
"aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
"aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
"aten::permute(Tensor self, int[] dims) -> Tensor",
"aten::pin_memory(Tensor(a) self) -> Tensor(a)",
"aten::pinverse(Tensor self, float rcond) -> Tensor",
"aten::reciprocal(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",
"aten::round(Tensor self) -> Tensor",
"aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator? generator) -> Tensor",
"aten::rsqrt(Tensor self) -> Tensor",
"aten::selu(Tensor self) -> Tensor",
"aten::gelu(Tensor self) -> Tensor",
"aten::sigmoid(Tensor self) -> Tensor",
"aten::sign(Tensor self) -> Tensor",
"aten::sin(Tensor self) -> Tensor",
"aten::sinh(Tensor self) -> Tensor",
"aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor",
"aten::softshrink(Tensor self, Scalar lambd) -> Tensor",
"aten::sqrt(Tensor self) -> Tensor",
"aten::tan(Tensor self) -> Tensor",
"aten::tanh(Tensor self) -> Tensor",
"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
"aten::transpose(Tensor self, int dim0, int dim1) -> Tensor",
"aten::tril(Tensor self, int diagonal) -> Tensor",
"aten::triu(Tensor self, int diagonal) -> Tensor",
"aten::trunc(Tensor self) -> Tensor",
"aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
"aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
"aten::slice(Tensor self, int dim, int? start=0, int? end=9223372036854775807, int step=1) -> Tensor",
"aten::alias(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {
auto input_type = node->input(0)->type()->cast<TensorType>();
return input_type ? type_vec_t{input_type->dimensionedOnly()}
: type_vec_t{};
}};
// Requirements:
// dims : preserved
// scalar type : preserved, except complex maps to float
// device : preserved
// tensor inputs : 1
// tensor outputs : 1
// Additionally:
// - First input should be the only tensor input
static const register_formula_for simple_unary_ops_complex_to_float{
{
"aten::abs(Tensor self) -> Tensor",
},
[](Node* node) -> type_vec_t {
auto input_type = node->input(0)->type()->cast<TensorType>();
// Maps complex -> float
if (input_type->scalarType()) {
const auto scalar_type = *(input_type->scalarType());
if (isComplexType(scalar_type)) {
const auto out_type = c10::toValueType(scalar_type);
return type_vec_t{
input_type->dimensionedOnly()->withScalarType(out_type)};
}
}
return input_type ? type_vec_t{input_type->dimensionedOnly()}
: type_vec_t{};
}};
// Requirements:
// dims : broadcast all tensor args
// scalar type : promoted from input dtypes
// device : always matching and preserved
// tensor inputs : *
// tensor outputs : 1
static const register_formula_for broadcasting_ops_arithmetic{
{
// Tensor-Tensor operators
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
"aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
"aten::mul(Tensor self, Tensor other) -> Tensor",
"aten::div(Tensor self, Tensor other) -> Tensor",
},
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types = gatherTensorTypes(node)) {
AT_ASSERT(maybe_tensor_types->size() >= 2);
auto dtype = getPromotedTypeForArithmeticOp(node);
return {broadcast(*maybe_tensor_types, dtype)};
}
return {};
}};
// Requirements:
// dims : broadcast all tensor args
// scalar type : always matching and preserved
// device : always matching and preserved
// tensor inputs : *
// tensor outputs : 1
static const register_formula_for broadcasting_ops{
{
"aten::pow(Tensor self, Tensor exponent) -> Tensor",
"aten::fmod(Tensor self, Tensor other) -> Tensor",
"aten::remainder(Tensor self, Tensor other) -> Tensor",
"aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
"aten::lerp(Tensor self, Tensor end, Tensor weight) -> Tensor",
"aten::max(Tensor self, Tensor other) -> Tensor",
"aten::min(Tensor self, Tensor other) -> Tensor",
"aten::__and__(Tensor self, Tensor other) -> Tensor",
"aten::__or__(Tensor self, Tensor other) -> Tensor",
"aten::__xor__(Tensor self, Tensor other) -> Tensor",
"aten::__lshift__(Tensor self, Tensor other) -> Tensor",
"aten::__rshift__(Tensor self, Tensor other) -> Tensor",
"aten::__iand__(Tensor self, Tensor other) -> Tensor",
"aten::__ior__(Tensor self, Tensor other) -> Tensor",
"aten::__ixor__(Tensor self, Tensor other) -> Tensor",
"aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
"aten::__irshift__(Tensor self, Tensor other) -> Tensor",
// Ops with Tensor-Tensor overloads only
"aten::atan2(Tensor self, Tensor other) -> Tensor",
},
[this](Node* node) -> type_vec_t {
if (auto maybe_tensor_types = gatherTensorTypes(node)) {
AT_ASSERT(maybe_tensor_types->size() >= 2);
auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType();
auto second_scalar_type = (*maybe_tensor_types)[1]->scalarType();
if (!first_scalar_type || !second_scalar_type) {
return {};
}
size_t arg_for_type = 0;
if (c10::promoteTypes(*first_scalar_type, *second_scalar_type) !=
first_scalar_type) {
arg_for_type = 1;
}
auto t = (*maybe_tensor_types)[arg_for_type]->scalarType();
return {broadcast(*maybe_tensor_types, *t)};
}
return {};
}};
static const register_formula_for fused_accum_binary_ops{
{
// Non-binary ops
"aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
},