[inductor] Small fixes for huggingface models (#728)
jansel committed Aug 8, 2022
1 parent cbec88a commit 3b4f9bc
Showing 6 changed files with 70 additions and 16 deletions.
tests/test_torchinductor.py: 10 additions, 1 deletion
@@ -1428,7 +1428,7 @@ def fn(mask, value):
),
)

def test_pow(self):
def test_pow1(self):
def fn(x):
return [aten.pow(x, e) for e in range(-8, 9)]

@@ -1437,6 +1437,15 @@ def fn(x):
(torch.randn([16, 16]),),
)

def test_pow2(self):
def fn(x):
return aten.pow(1000, x), aten.pow(x, 1000)

self.common(
fn,
(torch.randn([16, 16]),),
)

def test_glu(self):
def fn(x):
return aten.glu(x, -1), aten.glu(x, 1), aten.glu(x, 2)
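The two expressions in test_pow2 hit the scalar-base overload of aten.pow and an integer exponent that is too large for the unrolled fast path added in lowering.py below. A minimal eager-mode illustration of what is being exercised (not part of the diff):

import torch

x = torch.randn(16, 16)
y1 = torch.pow(1000, x)   # scalar base, tensor exponent (aten.pow.Scalar)
y2 = torch.pow(x, 1000)   # integer exponent outside the small-power unrolling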
torchinductor/codegen/cpp.py: 4 additions, 0 deletions
@@ -178,6 +178,10 @@ def exp(x):
def sqrt(x):
return f"std::sqrt({x})"

@staticmethod
def pow(a, b):
return f"std::pow({a}, {b})"

@staticmethod
def log(x):
return f"std::log({x})"
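These override methods are pure string substitution: they receive already-rendered argument expressions and return the C++ expression that gets pasted into the generated kernel, so pow maps straight onto std::pow. A self-contained sketch of the pattern (the class name here is illustrative, not the real torchinductor one):

class CppOps:
    @staticmethod
    def pow(a, b):
        return f"std::pow({a}, {b})"

print(CppOps.pow("tmp0", "tmp1"))   # -> std::pow(tmp0, tmp1)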
torchinductor/codegen/triton.py: 8 additions, 1 deletion
@@ -159,6 +159,10 @@ def rand(seed, offset, _): # _ here to keep the contract identical to CPU rand
def randn(seed, offset, _): # _ here to keep the contract identical to CPU randn op
return f"tl.randn({seed}, {offset})"

@staticmethod
def pow(a, b):
return f"tl.libdevice.pow({a}, {b})"

@staticmethod
def log(x):
if has_triton_libdevice():
@@ -448,14 +452,17 @@ def initialize_range_tree(self, pid_cache):
def disable_reduction(self):
@contextlib.contextmanager
def ctx():
if not self.inside_reduction:
if self.numels[-1] == 1:
assert not self.inside_reduction
yield
return
# calling codegen_body() will flush all the pending buffers
# and write out a reduction loop
self.codegen_body()
self.inside_reduction = False
yield
# flush out any code before opening the next loop
self.codegen_body()
self.inside_reduction = True

return ctx()
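The disable_reduction() fix adds an early exit when the last numel is 1: such a kernel never entered a reduction, so there is no pending reduction loop to flush and no flag to toggle. A simplified, self-contained stand-in for the pattern (FakeKernel is illustrative, not the real TritonKernel):

import contextlib

class FakeKernel:
    def __init__(self, rnumel):
        self.numels = [64, rnumel]
        self.inside_reduction = rnumel != 1

    def codegen_body(self):
        print("flush pending buffers")

    def disable_reduction(self):
        @contextlib.contextmanager
        def ctx():
            if self.numels[-1] == 1:
                assert not self.inside_reduction
                yield
                return
            self.codegen_body()            # close out the pending reduction loop
            self.inside_reduction = False
            yield
            self.codegen_body()            # flush pointwise code before the next loop
            self.inside_reduction = True

        return ctx()

with FakeKernel(rnumel=1).disable_reduction():
    pass   # numel-1 case: nothing flushed, no flag flipped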
torchinductor/decomposition.py: 3 additions, 1 deletion
@@ -200,10 +200,12 @@ def rsub(a, b):
return b - a


@register_decomposition([aten.masked_fill.Scalar])
@register_decomposition([aten.masked_fill])
def masked_fill(value, mask, other):
if isinstance(other, numbers.Number):
other = torch.tensor(other, dtype=value.dtype, device=value.device)
if other.device != value.device and other.numel() == 1:
other = other.to(value.device)
value, mask, other = torch.broadcast_tensors(value, mask, other)
return torch.where(mask, other, value)

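The new branch moves a one-element other tensor onto value's device before broadcasting, which matters when a CPU scalar tensor is used to fill a CUDA tensor. The decomposition itself is just broadcast plus where; a CPU-only eager reference (illustrative):

import torch

value = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True], [False, True, False]])
other = torch.tensor(-1.0)   # a Python scalar is wrapped into a 0-d tensor the same way
value_b, mask_b, other_b = torch.broadcast_tensors(value, mask, other)
out = torch.where(mask_b, other_b, value_b)
assert torch.equal(out, value.masked_fill(mask, -1.0))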
torchinductor/lowering.py: 35 additions, 12 deletions
@@ -271,12 +271,26 @@ def _to_copy(


@register_lowering(aten.to)
def to(x, device_or_dtype, non_blocking=False, copy=False, memory_format=None):
def to(
x,
device_or_dtype=None,
non_blocking=False,
copy=False,
memory_format=None,
device=None,
dtype=None,
layout=None,
):
assert not memory_format, "TODO"
assert layout in (None, torch.strided)
if isinstance(device_or_dtype, torch.dtype):
return to_dtype(x, device_or_dtype)
if isinstance(device_or_dtype, torch.device):
return to_device(x, device_or_dtype)
if device is not None:
return to_device(x, device)
if dtype is not None:
return to_dtype(x, dtype)
assert False, device_or_dtype
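The widened signature accepts the keyword forms of Tensor.to that previously had no matching parameter. Roughly the eager-mode equivalents of each branch (illustrative; at lowering time these arrive as FX node args and kwargs):

import torch

x = torch.randn(4)
x.to(torch.float16)                # positional dtype   -> to_dtype(x, ...)
x.to(torch.device("cpu"))          # positional device  -> to_device(x, ...)
x.to(device=torch.device("cpu"))   # keyword device     -> to_device(x, ...)
x.to(dtype=torch.float64)          # keyword dtype      -> to_dtype(x, ...)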


@@ -503,7 +517,6 @@ def roll(a, shifts, dims=tuple()):
raise RuntimeError(
f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
)
assert len_dims > 1
tail_shifts = shifts[1:]
tail_dims = dims[1:]
first_dim_rolled = roll(a, shifts[0], dims[0])
@@ -531,10 +544,13 @@ def fn(index):

@register_lowering(aten.as_strided)
def as_strided(x, size, stride, storage_offset=None):
if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
# as_strided ignores views
x = x.data.unwrap_view()
x.realize()
if not ir.is_storage_and_layout(x):
if not ir.is_contiguous_storage_and_layout(x):
raise NotImplementedError(f"unrealized as_strided({x}, ...)")
storage, old_layout = ir.as_storage_and_layout(x)
storage, old_layout = ir.as_contiguous_storage_and_layout(x)
new_layout = ir.FixedLayout(
old_layout.device,
old_layout.dtype,
@@ -850,6 +866,13 @@ def arange(
end = start
start = 0

if isinstance(start, float) and int(start) == start:
start = int(start)
if isinstance(end, float) and int(end) == end:
end = int(end)
if isinstance(step, float) and int(step) == step:
step = int(step)

assert isinstance(start, int)
assert isinstance(end, int)
assert isinstance(step, int)
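The coercion lets integer-valued float arguments (which some traced models produce) fall through to the integer-only lowering guarded by the asserts above. Eager-mode illustration:

import torch

torch.arange(0.0, 8.0, 2.0)   # floats 0.0, 8.0, 2.0 are integer-valued, treated as 0, 8, 2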
@@ -2368,16 +2391,19 @@ def pow_recursive(x, y, dtype):
return result


@make_pointwise
def pow_native(a, b):
return ops.pow(a, b)


@register_lowering(aten.pow, broadcast=True)
def pow(a, b):
# see https://github.com/openai/triton/issues/506
# triton doesn't support pow, so need to rewrite it
# this is a lowering not a decomp, due to upstream pytorch being unstable
if isinstance(b, float) and b == int(b):
return pow(a, int(b))
elif isinstance(b, int) and b == 1:
return a
elif isinstance(b, int):
elif isinstance(b, int) and -32 < b < 32:
# Optimize away small fixed powers
loader = a.make_loader()

def fn(idx):
@@ -2390,10 +2416,7 @@ def fn(idx):
ranges=a.get_size(),
)
else:
assert False, "TODO: check correctness here"
# odd integer: torch.sign(a) * torch.exp(torch.log(torch.abs(a)) * b)
# even integer: torch.exp(torch.log(torch.abs(a)) * b)
# else: torch.exp(torch.log(a) * b)
return pow_native(a, b)


def mutate_to(changed, val):
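With this hunk, only integer exponents with |b| < 32 are unrolled into multiplies; everything else (large integers, non-integral floats, tensor exponents) now falls through to pow_native, i.e. the new ops.pow hooked up to tl.libdevice.pow and std::pow above, instead of hitting the old assert. The unrolling in pow_recursive is recursive squaring; a plain-number sketch of the idea (not the real IR-based implementation):

def pow_by_squaring(x, n):
    # integer power by repeated squaring, O(log n) multiplies
    if n < 0:
        return 1.0 / pow_by_squaring(x, -n)
    if n == 0:
        return 1
    half = pow_by_squaring(x, n // 2)
    result = half * half
    if n % 2:
        result *= x
    return result

assert pow_by_squaring(3, 5) == 243
assert pow_by_squaring(2.0, -3) == 0.125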
torchinductor/sizevars.py: 10 additions, 1 deletion
@@ -1,6 +1,7 @@
import collections
import dataclasses
import functools
import logging
from typing import Dict
from typing import List

@@ -11,6 +12,8 @@

from .virtualized import V

log = logging.getLogger(__name__)


@dataclasses.dataclass
class ZeroGuard:
@@ -278,7 +281,13 @@ def stride_hints(self, index: sympy.Expr, vars: List[sympy.Symbol]):
for v in index.free_symbols:
if str(v).startswith("indirect"):
index = index.subs({v: 0})
return [self.size_hint(s) for s in self.stride_vars(index, vars)]
result = []
for s in self.stride_vars(index, vars):
try:
result.append(self.size_hint(s))
except TypeError:
result.append(0)
return result

def stride_order(self, index: sympy.Expr, vars: List[sympy.Symbol]):
strides = tuple(map(abs, self.stride_hints(index, vars)))
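size_hint() raises TypeError when a stride expression cannot be reduced to a concrete integer (it still contains free symbols after substitution); the new fallback records such strides as 0 so stride_order() can still rank the remaining ones. A plain-sympy illustration of the failure mode, assuming size_hint ultimately performs an int() conversion:

import sympy

s0 = sympy.Symbol("s0")
try:
    hint = int(4 * s0)   # symbolic expression: cannot be converted to int
except TypeError:
    hint = 0             # fallback used for ordering purposes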
