36 changes: 27 additions & 9 deletions test/inductor/test_torchinductor.py
@@ -6658,6 +6658,22 @@ def fn(a):
with self.assertRaises(RuntimeError):
torch.compile(fn)(a)

def test_ir_node_str(self):
@torch.compile
def fn(x: torch.Tensor) -> torch.Tensor:
return x.sin(), torch.nn.Softmax(dim=1)(x.cos())

def run_node_alt(*args, **kwargs):
rv = run_node(*args, **kwargs)
strings.append(str(rv))
return rv

strings = []
run_node = GraphLowering.run_node
with patch.object(GraphLowering, "run_node", run_node_alt):
fn(torch.randn([8, 128]))
self.assertGreater(len(strings), 3)


if HAS_CUDA and not TEST_WITH_ASAN:
import triton
@@ -7394,16 +7410,18 @@ def fn():
self.assertEqual(fn_opt(), fn())

def test_split_op_with_sym(self):
- for dynamic_shapes in [True, False]:
- torch._dynamo.config.dynamic_shapes = dynamic_shapes
-
- def fn(x: torch.Tensor) -> torch.Tensor:
- # split(tensor, sympy.Integer), split(tensor, sympy.Expr)
- return torch.split(x, x.shape[0]), torch.split(x, x.shape[0] // 2)
+ def fn(x: torch.Tensor) -> torch.Tensor:
+ # split(tensor, sympy.Integer), split(tensor, sympy.Expr)
+ return torch.split(x, x.shape[0]), torch.split(x, x.shape[0] // 2)

- fn_opt = torch._dynamo.optimize("inductor", dynamic=dynamic_shapes)(fn)
- inps = torch.randn([5, 5])
- fn_opt(inps)
+ for dynamic_shapes in [True, False]:
+ with torch._dynamo.config.patch(dynamic_shapes=dynamic_shapes):
+ torch._dynamo.reset()
+ fn_opt = torch._dynamo.optimize("inductor", dynamic=dynamic_shapes)(
+ fn
+ )
+ inps = torch.randn([5, 5])
+ fn_opt(inps)


class ExprPrinterTests(TestCase):
16 changes: 16 additions & 0 deletions torch/_inductor/codegen/cpp.py
@@ -405,6 +405,14 @@ def acos(x):
def asin(x):
return f"{x}.asin()"

@staticmethod
def cosh(x):
return f"{x}.cosh()"

@staticmethod
def sinh(x):
return f"{x}.sinh()"

@staticmethod
def log10(x):
return f"{x}.log10()"
@@ -702,6 +710,14 @@ def acos(x):
def acosh(x):
return f"std::acosh({x})"

@staticmethod
def cosh(x):
return f"std::cosh({x})"

@staticmethod
def sinh(x):
return f"std::sinh({x})"

@staticmethod
def asin(x):
return f"std::asin({x})"
24 changes: 9 additions & 15 deletions torch/_inductor/ir.py
@@ -394,12 +394,8 @@ def _index(ranges, prefix="i"):

@cache_on_self
def inner_fn_str(self):
- formatter = V.KernelFormatterHandler(V.MockHandler())
- with V.set_ops_handler(formatter), patch.object(
- FlexibleLayout, "allow_indexing", True
- ):
- result = self.inner_fn(self._index(self.ranges))
- return formatter.getvalue(result)
+ index = self._index(self.ranges)
+ return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index)

def is_zero_elements(self):
return any(r == 0 for r in self.ranges)
@@ -515,15 +511,13 @@ def index_length(self):

@cache_on_self
def inner_fn_str(self):
- formatter = V.KernelFormatterHandler(V.MockHandler())
- with V.set_ops_handler(formatter), patch.object(
- FlexibleLayout, "allow_indexing", True
- ):
- result = self.inner_fn(
- self._index(self.ranges),
- self._index(self.reduction_ranges, "r"),
- )
- return formatter.getvalue(result)
+ index = self._index(self.ranges)
+ rindex = self._index(self.reduction_ranges, "r")
+ return V.KernelFormatterHandler.ir_to_string(
+ self.inner_fn,
+ index,
+ rindex,
+ )

def constant_to_device(self, device):
"""Move this to a given device. Requires that all reads are to constants."""
33 changes: 30 additions & 3 deletions torch/_inductor/virtualized.py
@@ -2,6 +2,7 @@
from contextlib import contextmanager
from itertools import chain
from threading import local
from unittest.mock import patch

import sympy

@@ -66,13 +67,13 @@ def __getattr__(self, name):
def inner(*args, **kwargs):
fargs = [_arg_str(a) for a in args]
fargs.extend(f"{k}={v}" for k, v in kwargs.items())
return f"{name}({', '.join(fargs)})"
return f"ops.{name}({', '.join(fargs)})"

return inner

@staticmethod
def masked(mask, body, other):
return f"masked({mask}, {body()}, {other})"
return f"ops.masked({mask}, {body()}, {other})"

@staticmethod
def indirect_indexing(index_var):
@@ -96,9 +97,35 @@ def inner(*args):
class KernelFormatterHandler:
def __init__(self, parent_handler):
self.parent_handler = parent_handler
- self.output = IndentedBuffer()
+ self.output = IndentedBuffer(1)
self.var_counter = itertools.count()

@staticmethod
def ir_to_string(ir_fn, index, rindex=None):
from .ir import FlexibleLayout

args = [index, rindex] if rindex is not None else [index]
names = ["index", "rindex"] if rindex is not None else ["index"]
formatter = KernelFormatterHandler(MockHandler())

with formatter.output.indent(-1):
formatter.output.writeline(f"def inner_fn({', '.join(names)}):")
for name, arg in zip(names, args):
if arg:
lhs = ", ".join(
[
Review comment (Collaborator): This can be a generator instead of a list comprehension.
Reply (Contributor Author): I don't think it matters since this list will be <10 elements.
(See the sketch after this hunk for the generator form.)

str("_" if isinstance(v, (int, sympy.Integer)) else v)
for v in arg
]
)
formatter.output.writeline(f"{lhs} = {name}")

with V.set_ops_handler(formatter), patch.object(
FlexibleLayout, "allow_indexing", True
):
result = ir_fn(*args)
return formatter.getvalue(result)
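
For reference, a minimal sketch (not part of this PR) of the generator-expression form suggested in the review thread above; the `format_lhs` helper name is hypothetical, and `arg` is assumed to be a short sequence of ints and sympy symbols, as in the diff:

import sympy

def format_lhs(arg):
    # str.join accepts any iterable, so the intermediate list can be dropped.
    return ", ".join(
        str("_" if isinstance(v, (int, sympy.Integer)) else v) for v in arg
    )

# Integers collapse to "_" placeholders, symbolic entries keep their names:
print(format_lhs([sympy.Symbol("i0"), 1, sympy.Symbol("i1")]))  # i0, _, i1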

def __getattr__(self, name):
def inner(*args, **kwargs):
line = getattr(self.parent_handler, name)(*args, **kwargs)