tensor.py
import itertools
import operator
from typing import Dict, List
import torch.fx
import torch.random
from .. import config, variables
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..source import AttrSource
from ..utils import (
get_fake_value,
get_real_value,
product,
proxy_args_kwargs,
tensortype_to_dtype,
)
from .base import VariableTracker
from .constant import ConstantVariable
from .lists import ShapeVariable, SizeVariable
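
# The classes below are VariableTracker subclasses used while tracing: they
# model tensors and tensor-derived values symbolically, typically pairing an
# fx.Proxy into the output graph with whatever static properties (dtype,
# device, sizes, ...) are known at trace time.
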
class TensorVariable(VariableTracker):
"""A torch.Tensor input or an intermediate value in the FX graph"""
_nonvar_fields = [
"proxy",
"dtype",
"device",
"layout",
"ndim",
"size",
"stride",
"requires_grad",
"is_quantized",
"is_contiguous",
]
def get_real_value(self):
"""
Get the actual value represented by this variable if computation is run
using the user-provided inputs.
NOTE: this runs actual tensor computation and may be
slow and memory-intensive.
"""
return get_real_value(self.proxy.node, self.proxy.tracer)
def __init__(
self,
proxy: torch.fx.Proxy,
dtype=None,
device=None,
layout=None,
ndim=None,
size=None,
stride=None,
requires_grad=None,
is_quantized=None,
is_contiguous=None,
is_sparse=None,
class_type=torch.Tensor,
specialized_value=None,
**kwargs,
):
super(TensorVariable, self).__init__(**kwargs)
self.proxy = proxy
self.dtype = dtype
self.device = device
self.layout = layout
self.ndim = ndim
self.size = size
self.stride = stride
self.requires_grad = requires_grad
self.is_quantized = is_quantized
self.is_contiguous = is_contiguous
self.is_sparse = is_sparse
self.class_type = class_type
self.specialized_value = specialized_value
def as_proxy(self):
return self.proxy
def python_type(self):
return self.class_type
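
    # Emulate isinstance(tensor, <type>) checks. Legacy tensor types such as
    # torch.FloatTensor are really (class, dtype) pairs, so they are matched
    # against the recorded dtype rather than the Python class hierarchy.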
def call_isinstance(self, tensor_type):
def check_type(ty):
if ty not in tensortype_to_dtype:
return issubclass(self.python_type(), ty)
dtypes = tensortype_to_dtype[ty]
return self.dtype in dtypes
if type(tensor_type) is tuple:
            return any(check_type(ty) for ty in tensor_type)
else:
return check_type(tensor_type)
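
    # Snapshot the statically known properties of a real tensor; the resulting
    # dict is used to populate the matching TensorVariable fields.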
@staticmethod
def specialize(value: torch.Tensor):
props = {
"dtype": value.dtype,
"device": value.device,
"layout": value.layout,
"ndim": int(value.ndim),
"requires_grad": value.requires_grad,
"is_quantized": value.is_quantized,
"is_sparse": value.is_sparse,
"class_type": type(value),
}
if not config.dynamic_shapes:
props["size"] = tuple(value.size())
props["stride"] = tuple(value.stride())
props["is_contiguous"] = tuple(
[
x
for x in torch._prims_common._memory_formats
if value.is_contiguous(memory_format=x)
]
)
return props
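
    # Resolve attribute reads such as x.ndim, x.dtype, or x.shape to constants
    # when the value is already recorded, otherwise fall back to the equivalent
    # method call (size(), dim(), ...) so it is traced into the graph.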
def var_getattr(self, tx, name):
from . import ConstantVariable, TorchVariable
result = None
options = VariableTracker.propagate(self)
if name == "ndim" and self.ndim is not None:
result = ConstantVariable(self.ndim, **options)
elif name == "dtype" and self.dtype is not None:
result = TorchVariable(self.dtype, **options)
elif name == "device" and self.device is not None:
result = TorchVariable(self.device, **options)
elif name == "layout" and self.layout is not None:
result = TorchVariable(self.layout, **options)
elif name == "is_cuda" and self.device is not None:
result = ConstantVariable(self.device.type == "cuda", **options)
elif name == "shape" and self.size is not None:
sizes = [variables.ConstantVariable(x) for x in self.size]
result = ShapeVariable(sizes, **options)
elif name == "requires_grad" and self.requires_grad is not None:
result = ConstantVariable(self.requires_grad, **options)
elif name == "is_quantized" and self.is_quantized is not None:
result = ConstantVariable(self.is_quantized, **options)
elif name == "is_sparse" and self.is_sparse is not None:
result = ConstantVariable(self.is_sparse, **options)
elif name == "shape" and self.size is None:
result = self.call_method(tx, "size", [], {})
elif name == "ndim" and self.ndim is None:
result = self.call_method(tx, "dim", [], {})
elif name == "data":
result = self.call_method(tx, "detach", [], {})
elif name == "T":
args = [variables.ConstantVariable(i) for i in range(self.ndim - 1, -1, -1)]
result = self.call_method(tx, "permute", args, {})
if name == "__class__":
return TorchVariable(self.python_type(), **options)
        # Add a guard for type matching; these guards are checked before tensor
        # guards. In some cases a <tensor>.<attr> guard can be evaluated first,
        # and break if <tensor> is later changed to another type.
if result is not None and self.source is not None:
result = result.add_guard(self.make_guard(GuardBuilder.TYPE_MATCH))
if result is None:
raise NotImplementedError()
return result
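
    # Unpacking (e.g. ``a, b = tensor``) needs a statically known first
    # dimension; each element is wrapped as a proxy for ``self[i]``.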
def unpack_var_sequence(self, tx, idxes=None):
from .builder import wrap_fx_proxy
if idxes is None:
if self.size:
idxes = range(self.size[0])
else:
return super(TensorVariable, self).unpack_var_sequence(tx)
options = VariableTracker.propagate(self)
return [wrap_fx_proxy(tx, self.as_proxy()[i], **options) for i in idxes]
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from . import ConstantVariable, TorchVariable, TupleVariable
from .builder import wrap_fx_proxy
kwargs = dict(kwargs)
options = VariableTracker.propagate(self, args, kwargs.values())
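        # First try to constant-fold metadata queries (stride/size/numel/dim/
        # is_contiguous/...) from properties recorded at wrap time;
        # constant_result stays None when nothing is known statically.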
if name == "stride" and self.stride is not None:
constant_result = ConstantVariable(self.stride, **options)
elif name == "size" and self.size is not None:
sizes = [variables.ConstantVariable(x) for x in self.size]
constant_result = SizeVariable(sizes, **options)
elif name == "size" and self.size is None and config.dynamic_shapes:
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
elif name in ("numel", "nelement") and self.size is not None:
constant_result = ConstantVariable(product(self.size), **options)
elif name in ("ndimension", "dim") and self.ndim is not None:
constant_result = ConstantVariable(self.ndim, **options)
elif name == "is_floating_point" and self.dtype is not None:
constant_result = ConstantVariable(self.dtype.is_floating_point, **options)
elif name == "is_contiguous" and self.is_contiguous is not None:
if "memory_format" in kwargs:
memory_format = kwargs.pop("memory_format").as_python_constant()
else:
memory_format = torch.contiguous_format
constant_result = ConstantVariable(
memory_format in self.is_contiguous, **options
)
elif (
name == "type"
and self.dtype is not None
and len(args) == 0
and isinstance(self.device, torch.device)
):
tensortype = [k for k, v in tensortype_to_dtype.items() if self.dtype in v][
0
]
if self.device.type == "cuda":
constant_result = ConstantVariable(
f"torch.cuda.{tensortype.__name__}", **options
)
else:
constant_result = ConstantVariable(
f"torch.{tensortype.__name__}", **options
)
elif name == "get_device" and isinstance(self.device, torch.device):
index = self.device.index if self.device.type != "cpu" else -1
constant_result = ConstantVariable(index, **options)
else:
constant_result = None
if constant_result:
assert not kwargs, f"Tensor.{name}() unhandled kwargs"
if len(args) == 1:
return constant_result.getitem_const(args[0])
elif args:
return TupleVariable(
[constant_result.getitem_const(a) for a in args], **options
)
return constant_result
elif (
name == "repeat"
and not all(
x.is_python_constant() for x in itertools.chain(args, kwargs.values())
)
and not config.dynamic_shapes
):
unimplemented("dynamic Tensor.repeat")
elif name in ("tolist", "numpy", "backward", "data_ptr"):
unimplemented(f"Tensor.{name}")
elif name == "nonzero" and not config.dynamic_shapes:
unimplemented(f"Tensor.{name}")
elif name == "item":
if config.capture_scalar_outputs:
example_value = get_fake_value(self.proxy.node, tx)
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
"item",
(self.as_proxy(),),
{},
),
example_value=example_value,
**options,
)
else:
unimplemented(f"Tensor.{name}")
elif name == "__len__":
if self.size:
assert not config.dynamic_shapes
return ConstantVariable(self.size[0], **options)
else:
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_function",
len,
(self.as_proxy(),),
{},
),
**options,
)
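        # __setitem__ mutates the tensor in place: emit the operator.setitem
        # call into the graph for its side effect and return None to the caller.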
elif name == "__setitem__":
tx.output.guards.update(options["guards"])
tx.output.create_proxy(
"call_function",
operator.setitem,
*proxy_args_kwargs([self] + list(args), kwargs),
)
return ConstantVariable(None, **options)
elif name in ("resize_", "resize_as_"):
if "memory_format" in kwargs:
memory_format = kwargs["memory_format"].as_python_constant()
else:
memory_format = torch.contiguous_format
if name == "resize_":
self.size = args[0].as_python_constant()
self.is_contiguous = (memory_format,)
else:
assert isinstance(args[0], TensorVariable)
if self.size and args[0].size:
if (
self.size == args[0].size
or memory_format is torch.preserve_format
):
self.is_contiguous = args[0].is_contiguous
else:
self.size = args[0].size
self.stride = args[0].stride
self.ndim = args[0].ndim
self.is_contiguous = (memory_format,)
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
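        # add_/addcdiv_ with an alpha/value kwarg are decomposed into explicit
        # mul/div calls followed by a plain two-argument add_.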
elif (
name == "add_" and len(args) == 1 and len(kwargs) == 1 and "alpha" in kwargs
):
result = TorchVariable(torch.mul, **options).call_function(
tx, args + [kwargs["alpha"]], {}
)
return self.call_method(tx, "add_", [result], {})
elif (
name == "addcdiv_"
and len(args) == 2
and len(kwargs) == 1
and "value" in kwargs
):
result = TorchVariable(torch.div, **options).call_function(tx, args, {})
result = TorchVariable(torch.mul, **options).call_function(
tx, [result, kwargs["value"]], {}
)
return self.call_method(tx, "add_", [result], {})
else:
# Convert x.new(torch.Size) into x.new_empty(torch.Size),
# as Tensor.new acts differently with a Size input versus a tuple input.
if (
name == "new"
and len(args) == 1
and isinstance(args[0], (SizeVariable, ShapeVariable))
and not config.dynamic_shapes
):
name = "new_empty"
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
class DynamicShapeVariable(VariableTracker):
"""
Represents a symbolic size, e.g., as returned by tensor.size(0)
"""
@classmethod
def create(cls, tx, proxy, dyn_shape, **options):
if "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == dyn_shape
if dyn_shape is None:
dyn_shape = get_fake_value(proxy.node, tx)
proxy.node.meta["example_value"] = dyn_shape
return DynamicShapeVariable(proxy, dyn_shape, **options)
def __init__(self, proxy, dyn_shape, **kwargs):
super(DynamicShapeVariable, self).__init__(**kwargs)
self.proxy = proxy
self.dyn_shape = dyn_shape
def python_type(self):
return type(self.dyn_shape)
def unpack_var_sequence(self, tx):
super(DynamicShapeVariable, self).unpack_var_sequence(tx)
def as_proxy(self):
return self.proxy
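
    # Resolve the symbolic value to a concrete value via the shape environment;
    # non-symbolic values pass through unchanged.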
def evaluate_expr(self, output_graph):
if not isinstance(self.dyn_shape, torch.SymInt):
return self.dyn_shape
return output_graph.shape_env.evaluate_expr(self.dyn_shape.get_pyobj().expr)
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
from .builder import wrap_fx_proxy
options = VariableTracker.propagate(self, args, kwargs.values())
return wrap_fx_proxy(
tx,
tx.output.create_proxy(
"call_method",
name,
*proxy_args_kwargs([self] + list(args), kwargs),
),
**options,
)
class TensorWithTFOverrideVariable(VariableTracker):
"""
Represents a tensor subclass instance with a __torch_function__ override.
"""
def __init__(
self,
tensor_variable,
orig_tensor_variable_source,
subclass_torch_function__func,
subclass_type,
**kwargs,
):
super(TensorWithTFOverrideVariable, self).__init__(**kwargs)
self.tensor_variable = tensor_variable
self.orig_tensor_variable_source = orig_tensor_variable_source
self.subclass_torch_function__func = subclass_torch_function__func
self.subclass_type = subclass_type
def call_method(
self,
tx,
name,
args: "List[VariableTracker]",
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
# This code block implements inlining the __torch_function__ override
# of `call_method`.
from . import GetAttrVariable
options = VariableTracker.propagate(self, args, kwargs.values())
# insert unwrapped version of self as the first argument
# TODO: This is wrong! When you call the internal __torch_function__,
# you still get the wrapped version of self, and if you call functions
# inside __torch_function__, they should come back here. If we unwrap
# the tensor immediately, that will not happen.
# See https://github.com/pytorch/torchdynamo/issues/1951
args = list(args)
args.insert(0, self.tensor_variable)
func_var = GetAttrVariable(self.tensor_variable, name)
unwrapped = TensorWithTFOverrideVariable.inline_torch_function_unwrapped(
tx,
func_var,
self.orig_tensor_variable_source,
self.subclass_torch_function__func,
self.subclass_type,
options,
args,
kwargs,
)
# TODO(future PR): implement rewrapping conditional on method presence
# in `torch.overrides.get_default_nowrap_function()`. It's unclear how
# to do this easily in the current codebase since the resolution of
# `GetAttrVariable` depends on the type of the underlying object.
return TensorWithTFOverrideVariable(
unwrapped,
self.orig_tensor_variable_source,
self.subclass_torch_function__func,
self.subclass_type,
)
@staticmethod
def inline_torch_function_unwrapped(
tx,
original_func_var,
tensor_with_tf_override_source,
tf_func,
subclass_type,
options,
args,
kwargs,
):
"""
This function inlines the `__torch_function__` override for `original_func_var`.
For example, if the user code is
x1 = torch.sigmoid(x0)
And `x0` has an override, then:
* `original_func_var` will be a `VariableTracker` object wrapping `torch.sigmoid`
* `tensor_with_tf_override_source` will be the `Source` object from
the original tensor override instance in the beginning of the program
* `tf_func` will be the custom `__torch_function__` function
* `subclass_type` will be `type(x0)`
The caller is expected to properly massage args and kwargs before
passing them into this function.
The caller is responsible for wrapping the return value, if needed.
"""
from . import UserDefinedClassVariable
from .builder import TupleVariable, VariableBuilder
source = AttrSource(
AttrSource(tensor_with_tf_override_source, "__torch_function__"),
"__func__",
)
tf_func_var = VariableBuilder(tx, source)(tf_func)
type_var = UserDefinedClassVariable(subclass_type, **options)
# signature:
# def __torch_function__(cls, func, types, args=(), kwargs=None):
tf_args = (
type_var, # cls
original_func_var, # func
(type_var,), # types
TupleVariable(args), # args
kwargs, # kwargs
)
# Disable __torch_function__ here to prevent the clone of the
# example tensor from going into the override.
with torch._C.DisableTorchFunction():
return tx.inline_user_function_return(tf_func_var, tf_args, {})
class UnspecializedNumpyVariable(TensorVariable):
"""
    This is a 1-element tensor that represents an unspecialized numpy float/int.
"""
def __init__(self, proxy: torch.fx.Proxy, **kwargs):
raw_value = kwargs.pop("raw_value", None)
super(UnspecializedNumpyVariable, self).__init__(proxy, **kwargs)
self.raw_value = raw_value
@classmethod
def from_tensor_variable(cls, tensor_variable, raw_value):
# Convert a `TensorVariable` instance into an `UnspecializedNumpyVariable` instance.
return UnspecializedNumpyVariable(
**dict(tensor_variable.__dict__), raw_value=raw_value
)
def as_specialized(self, tx):
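        # Remove this value's graph input and downgrade its volatile guards to
        # constant matches, re-specializing the trace on the raw Python value.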
for graph_arg in tx.output.graphargs:
if graph_arg.source is self.source:
graph_arg.erase()
for g in self.guards:
if g.is_volatile:
g.create_fn = GuardBuilder.CONSTANT_MATCH
return ConstantVariable(value=self.raw_value, guards=self.guards)
class UnspecializedPythonVariable(TensorVariable):
"""
    This is a 1-element tensor that represents an unspecialized Python float/int.
"""
def __init__(self, proxy: torch.fx.Proxy, **kwargs):
raw_value = kwargs.pop("raw_value", None)
need_unwrap = kwargs.pop("need_unwrap", True)
super(UnspecializedPythonVariable, self).__init__(proxy, **kwargs)
self.raw_value = raw_value
self.need_unwrap = need_unwrap
@classmethod
def from_tensor_variable(cls, tensor_variable, raw_value, need_unwrap=True):
# Convert a `TensorVariable` instance into an `UnspecializedPythonVariable` instance.
return UnspecializedPythonVariable(
**dict(tensor_variable.__dict__),
raw_value=raw_value,
need_unwrap=need_unwrap,
)
def as_specialized(self, tx):
for graph_arg in tx.output.graphargs:
if graph_arg.source is self.source:
graph_arg.erase()
for g in self.guards:
if g.is_volatile:
g.create_fn = GuardBuilder.CONSTANT_MATCH
return ConstantVariable(value=self.raw_value, guards=self.guards)
class FakeItemVariable(TensorVariable):
"""An unspecialized python variable which prevents access to the underlying raw value.
This is needed if item is called on a FakeTensor."""
def __init__(self, proxy: torch.fx.Proxy, **kwargs):
need_unwrap = kwargs.pop("need_unwrap", False)
super(FakeItemVariable, self).__init__(proxy, **kwargs)
self.need_unwrap = need_unwrap
@classmethod
def from_tensor_variable(cls, tensor_variable):
return FakeItemVariable(**dict(tensor_variable.__dict__))