diff --git a/tests/test_modules.py b/tests/test_modules.py index f08e0e470e..b382106ef0 100755 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -196,6 +196,7 @@ class NNModuleTests(torchdynamo.testing.TestCase): test_iseval2 = make_test(IsEvalCheck()) test_viamodulecall = make_test(ViaModuleCall()) test_isnonelayer = make_test(IsNoneLayer()) + test_intarg = make_test(IntArg()) # not yet implemented # test_layerlist = make_test(LayerList()) diff --git a/torchdynamo/symbolic_convert.py b/torchdynamo/symbolic_convert.py index c375f46e71..4ce860cac4 100644 --- a/torchdynamo/symbolic_convert.py +++ b/torchdynamo/symbolic_convert.py @@ -4,6 +4,7 @@ import functools import inspect import itertools +from numbers import Real import operator import types import typing @@ -23,6 +24,7 @@ from .guards import GuardSource from .variable_tracker import AllowedFunctionOrModuleVariable, PythonModuleVariable from .variable_tracker import BaseListVariable +from .variable_tracker import BasicTypeVariable from .variable_tracker import BuiltinVariable from .variable_tracker import ConstDictVariable from .variable_tracker import ConstantVariable @@ -184,6 +186,13 @@ def wrap_local(self, name, value): value=value, guards={Guard(name, GuardSource.LOCAL, GuardBuilder.VALUE_MATCH)}, ) + elif isinstance(value, Real): + self.graphargs.append(LocalArg(name)) + return BasicTypeVariable( + proxy=self.create_graph_input(name), + state=TracingSupported.UNKNOWN, + guards={Guard(name, GuardSource.LOCAL, GuardBuilder.TYPE_MATCH)}, + ) elif type(value) in (tuple, list) and all( isinstance(x, torch.Tensor) for x in value ): diff --git a/torchdynamo/variable_tracker.py b/torchdynamo/variable_tracker.py index 2914886ad2..3f95d4ecc4 100644 --- a/torchdynamo/variable_tracker.py +++ b/torchdynamo/variable_tracker.py @@ -97,6 +97,15 @@ def as_proxy(self): return self.proxy +class BasicTypeVariable(TensorVariable): + """ + Points to a simple type, e.g. int, float, str. So far, we treat this + the same as TensorVariable + """ + + pass + + class NNModuleVariable(VariableTracker): def __init__(self, module_key: str, **kwargs): super(NNModuleVariable, self).__init__(**kwargs)