[inductor] Small fixes for huggingface models #728

Merged: 3 commits merged on Aug 8, 2022

Changes from 2 commits
11 changes: 10 additions & 1 deletion tests/test_torchinductor.py
@@ -1426,7 +1426,7 @@ def fn(mask, value):
),
)

def test_pow(self):
def test_pow1(self):
def fn(x):
return [aten.pow(x, e) for e in range(-8, 9)]

@@ -1435,6 +1435,15 @@ def fn(x):
(torch.randn([16, 16]),),
)

def test_pow2(self):
def fn(x):
return aten.pow(1000, x), aten.pow(x, 1000)

self.common(
fn,
(torch.randn([16, 16]),),
)

def test_glu(self):
def fn(x):
return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2)
4 changes: 4 additions & 0 deletions torchinductor/codegen/cpp.py
@@ -140,6 +140,10 @@ def exp(x):
def sqrt(x):
return f"std::sqrt({x})"

@staticmethod
def pow(a, b):
return f"std::pow({a}, {b})"

@staticmethod
def log(x):
return f"std::log({x})"
9 changes: 8 additions & 1 deletion torchinductor/codegen/triton.py
@@ -155,6 +155,10 @@ def logical_or(a, b):
def rand(seed, offset, _): # _ here to keep the contract identical to CPU rand op
return f"tl.rand({seed}, {offset})"

@staticmethod
def pow(a, b):
return f"tl.libdevice.pow({a}, {b})"

@staticmethod
def log(x):
if has_triton_libdevice():
@@ -444,14 +448,17 @@ def initialize_range_tree(self, pid_cache):
def disable_reduction(self):
@contextlib.contextmanager
def ctx():
if not self.inside_reduction:
if self.numels[-1] == 1:
assert not self.inside_reduction
yield
return
# calling codegen_body() will flush all the pending buffers
# and write out a reduction loop
self.codegen_body()
self.inside_reduction = False
yield
# flush out any code before opening the next loop
self.codegen_body()
self.inside_reduction = True

return ctx()
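The disable_reduction() change above replaces the `if not self.inside_reduction` check with `if self.numels[-1] == 1`, so a kernel whose reduction numel is 1 takes the no-op path, and the added `codegen_body()` call flushes pending code before the next loop opens. A minimal, self-contained sketch of the same flush / disable / restore pattern, using an invented toy class rather than the real TritonKernel state:

```python
import contextlib


class ToyKernel:
    """Invented stand-in for TritonKernel, only to show the pattern."""

    def __init__(self, numels):
        self.numels = numels
        self.inside_reduction = numels[-1] != 1
        self.pending = []          # buffered lines of generated code

    def codegen_body(self):
        # The real method writes buffered code out (wrapping it in a reduction
        # loop when inside_reduction is set); here we just flush the buffer.
        flushed, self.pending = self.pending, []
        return flushed

    def disable_reduction(self):
        @contextlib.contextmanager
        def ctx():
            if self.numels[-1] == 1:
                # reduction numel of 1: there is no reduction to disable
                assert not self.inside_reduction
                yield
                return
            self.codegen_body()    # flush before leaving the reduction loop
            self.inside_reduction = False
            yield
            self.codegen_body()    # flush before opening the next loop
            self.inside_reduction = True

        return ctx()


k = ToyKernel(numels=[64, 1])
with k.disable_reduction():
    assert not k.inside_reduction
```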
4 changes: 3 additions & 1 deletion torchinductor/decomposition.py
@@ -202,10 +202,12 @@ def rsub(a, b):
return b - a


@register_decomposition([aten.masked_fill.Scalar])
@register_decomposition([aten.masked_fill])
def masked_fill(value, mask, other):
if isinstance(other, numbers.Number):
other = torch.tensor(other, dtype=value.dtype, device=value.device)
if other.device != value.device and other.numel() == 1:
other = other.to(value.device)
value, mask, other = torch.broadcast_tensors(value, mask, other)
return torch.where(mask, other, value)

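A hedged, standalone check of the decomposition above, written against the public torch API rather than torchinductor's registration machinery: masked_fill rewritten as broadcast plus where should agree with the eager op, including for Python-number fill values.

```python
import numbers

import torch


def masked_fill_decomp(value, mask, other):
    if isinstance(other, numbers.Number):
        other = torch.tensor(other, dtype=value.dtype, device=value.device)
    if other.device != value.device and other.numel() == 1:
        # a one-element fill tensor created on another device is moved to the
        # input's device instead of forcing a cross-device broadcast
        other = other.to(value.device)
    value, mask, other = torch.broadcast_tensors(value, mask, other)
    return torch.where(mask, other, value)


x = torch.randn(4, 4)
mask = torch.rand(4, 4) > 0.5
assert torch.equal(masked_fill_decomp(x, mask, -1.0), x.masked_fill(mask, -1.0))
```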
47 changes: 35 additions & 12 deletions torchinductor/lowering.py
@@ -271,12 +271,26 @@ def _to_copy(


@register_lowering(aten.to)
def to(x, device_or_dtype, non_blocking=False, copy=False, memory_format=None):
def to(
x,
device_or_dtype=None,
non_blocking=False,
copy=False,
memory_format=None,
device=None,
dtype=None,
layout=None,
):
assert not memory_format, "TODO"
assert layout in (None, torch.strided)
if isinstance(device_or_dtype, torch.dtype):
return to_dtype(x, device_or_dtype)
if isinstance(device_or_dtype, torch.device):
return to_device(x, device_or_dtype)
if device is not None:
return to_device(x, device)
Review comment (Contributor): this will be wrong if both device and dtype are specified?

Reply (Contributor Author): Good catch, will update.

if dtype is not None:
return to_dtype(x, dtype)
assert False, device_or_dtype
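As the review comment above points out, returning as soon as `device` is handled drops a simultaneously passed `dtype`. A hedged sketch of the follow-up fix (not the code that was actually merged), using small stand-ins for torchinductor's to_device / to_dtype lowerings:

```python
import torch


def to_sketch(x, device_or_dtype=None, device=None, dtype=None):
    # stand-ins for the to_device / to_dtype lowerings used above
    to_device = lambda t, d: t.to(device=d)
    to_dtype = lambda t, dt: t.to(dtype=dt)

    if isinstance(device_or_dtype, torch.dtype):
        return to_dtype(x, device_or_dtype)
    if isinstance(device_or_dtype, torch.device):
        return to_device(x, device_or_dtype)
    if device is not None:
        x = to_device(x, device)   # do not return yet: dtype may also be set
    if dtype is not None:
        x = to_dtype(x, dtype)
    if device is not None or dtype is not None:
        return x
    raise AssertionError(device_or_dtype)


y = to_sketch(torch.randn(2, 2), device=torch.device("cpu"), dtype=torch.float16)
assert y.dtype == torch.float16 and y.device.type == "cpu"
```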


@@ -503,7 +517,6 @@ def roll(a, shifts, dims=tuple()):
raise RuntimeError(
f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
)
assert len_dims > 1
tail_shifts = shifts[1:]
tail_dims = dims[1:]
first_dim_rolled = roll(a, shifts[0], dims[0])
@@ -531,10 +544,13 @@ def fn(index):

@register_lowering(aten.as_strided)
def as_strided(x, size, stride, storage_offset=None):
if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
# as_strided ignores views
x = x.data.unwrap_view()
x.realize()
if not ir.is_storage_and_layout(x):
if not ir.is_contiguous_storage_and_layout(x):
raise NotImplementedError(f"unrealized as_strided({x}, ...)")
storage, old_layout = ir.as_storage_and_layout(x)
storage, old_layout = ir.as_contiguous_storage_and_layout(x)
new_layout = ir.FixedLayout(
old_layout.device,
old_layout.dtype,
@@ -851,6 +867,13 @@ def arange(
end = start
start = 0

if isinstance(start, float) and int(start) == start:
start = int(start)
if isinstance(end, float) and int(end) == end:
end = int(end)
if isinstance(step, float) and int(step) == step:
step = int(step)

assert isinstance(start, int)
assert isinstance(end, int)
assert isinstance(step, int)
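The three coercions added above exist because the asserts below only accept integer start/end/step. A hedged sketch of the same narrowing as a standalone helper (that huggingface models pass integral floats such as 10.0 here is an assumption based on the PR title):

```python
def as_int_if_integral(v):
    # Narrow floats holding an exact integer value (e.g. 10.0 -> 10) so an
    # integer-only arange lowering can accept them; leave everything else alone.
    if isinstance(v, float) and int(v) == v:
        return int(v)
    return v


assert as_int_if_integral(10.0) == 10 and type(as_int_if_integral(10.0)) is int
assert as_int_if_integral(2.5) == 2.5
assert as_int_if_integral(7) == 7
```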
@@ -2367,16 +2390,19 @@ def pow_recursive(x, y, dtype):
return result


@make_pointwise
def pow_native(a, b):
return ops.pow(a, b)


@register_lowering(aten.pow, broadcast=True)
def pow(a, b):
# see https://github.com/openai/triton/issues/506
# triton doesn't support pow, so need to rewrite it
# this is a lowering not a decomp, due to upstream pytorch being unstable
if isinstance(b, float) and b == int(b):
return pow(a, int(b))
elif isinstance(b, int) and b == 1:
return a
elif isinstance(b, int):
elif isinstance(b, int) and -32 < b < 32:
# Optimize away small fixed powers
loader = a.make_loader()

def fn(idx):
@@ -2389,10 +2415,7 @@ def fn(idx):
ranges=a.get_size(),
)
else:
assert False, "TODO: check correctness here"
# odd integer: torch.sign(a) * torch.exp(torch.log(torch.abs(a)) * b)
# even integer: torch.exp(torch.log(torch.abs(a)) * b)
# else: torch.exp(torch.log(a) * b)
return pow_native(a, b)


def mutate_to(changed, val):
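Because Triton lacked a pow op (openai/triton#506), the lowering above expands small fixed integer exponents via pow_recursive and sends everything else to pow_native / ops.pow. pow_recursive's body isn't shown in this diff, but the standard way to turn an exponent in (-32, 32) into multiplications is exponentiation by squaring; a hedged sketch of that idea on plain Python floats:

```python
def pow_by_squaring(x: float, n: int) -> float:
    # Expand x**n into O(log |n|) multiplications; negative exponents go
    # through the reciprocal. No pow intrinsic is needed anywhere.
    if n < 0:
        return 1.0 / pow_by_squaring(x, -n)
    result = 1.0
    while n:
        if n & 1:
            result *= x
        x *= x
        n >>= 1
    return result


assert abs(pow_by_squaring(3.0, 5) - 243.0) < 1e-9
assert abs(pow_by_squaring(2.0, -3) - 0.125) < 1e-12
```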
11 changes: 10 additions & 1 deletion torchinductor/sizevars.py
@@ -1,6 +1,7 @@
import collections
import dataclasses
import functools
import logging
from typing import Dict
from typing import List

@@ -11,6 +12,8 @@

from .virtualized import V

log = logging.getLogger(__name__)
Review comment (Contributor): Used logging and removed for this PR?

Reply (Contributor Author): Yeah, I think it is fine to leave this line in for the future, though.



@dataclasses.dataclass
class ZeroGuard:
@@ -278,7 +281,13 @@ def stride_hints(self, index: sympy.Expr, vars: List[sympy.Symbol]):
for v in index.free_symbols:
if str(v).startswith("indirect"):
index = index.subs({v: 0})
return [self.size_hint(s) for s in self.stride_vars(index, vars)]
result = []
for s in self.stride_vars(index, vars):
try:
result.append(self.size_hint(s))
except TypeError:
result.append(0)
return result

def stride_order(self, index: sympy.Expr, vars: List[sympy.Symbol]):
strides = tuple(map(abs, self.stride_hints(index, vars)))
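The stride_hints change above swallows a TypeError from size_hint and substitutes 0. A hedged illustration of the failure mode being guarded against (the exact expression inside size_hint is an assumption; this only shows that forcing an unresolved sympy symbol to an int raises TypeError):

```python
import sympy


def size_hint_or_zero(expr):
    try:
        return int(expr)           # works for concrete sympy Integers
    except TypeError:
        return 0                   # symbolic value with no concrete hint


assert size_hint_or_zero(sympy.Integer(12)) == 12
assert size_hint_or_zero(sympy.Symbol("s0")) == 0
```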