From e6a9d0e9eeeb07d59307083249acecc90b6d1da9 Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Mon, 22 Aug 2022 22:59:01 +0000
Subject: [PATCH 1/2] Add lowering for adaptive avg pool

---
 torchinductor/lowering.py | 110 +++++++++++++++++++++++++++++++++++---
 1 file changed, 103 insertions(+), 7 deletions(-)

diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index ecc98570f0..b2929b0862 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -1,6 +1,7 @@
 import functools
 import itertools
 import logging
+import math
 from collections.abc import Iterable
 from typing import List
 
@@ -1929,6 +1930,7 @@ def constant_boundary_condition_2d(x, fill_value, padding):
 
     def load(index):
         *prefix, ih, iw = index
+
         mask = ops.and_(
             range_mask(ih, h),
             range_mask(iw, w),
@@ -2123,6 +2125,107 @@ def fn(idx):
     )
 
 
+def pad_adaptive_loader(x):
+    *_, h, w = x.get_size()
+    x_loader = x.make_loader()
+
+    def load(prefix, increments, start_indices, end_indices):
+        ih, iw = increments
+        h_start_index, w_start_index = start_indices
+        h_end_index, w_end_index = end_indices
+
+        mask = ops.and_(
+            ops.lt(
+                ops.index_expr(h_start_index + ih, torch.int64),
+                ops.index_expr(h_end_index, torch.int64),
+            ),
+            ops.lt(
+                ops.index_expr(w_start_index + iw, torch.int64),
+                ops.index_expr(w_end_index, torch.int64),
+            ),
+        )
+
+        return ops.masked(
+            mask,
+            lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]),
+            0.0,
+        )
+
+    return load
+
+
+@register_lowering(aten._adaptive_avg_pool2d)
+def _adaptive_avg_pool2d(x, output_size):
+    assert isinstance(x, TensorBox)
+    assert len(output_size) == 2
+    x.realize_hint()
+
+    *batch, h_in, w_in = x.get_size()
+
+    h_in = V.graph.sizevars.guard_static_shape(h_in)
+    w_in = V.graph.sizevars.guard_static_shape(w_in)
+
+    h_out, w_out = output_size
+
+    # no-op if the same input and output
+    if h_in == h_out and w_in == w_out:
+        # TODO: do I need to copy ? _to_copy does not
+        return x
+
+    if h_in % h_out == 0 and w_in % w_out == 0:
+        kernel_size = [h_in // h_out, w_in // w_out]
+        return avg_pool2d(x, kernel_size)
+
+    h_kernel_max = math.ceil((h_in + h_out - 1) / h_out)
+    w_kernel_max = math.ceil((w_in + w_out - 1) / w_out)
+
+    new_size = list(batch) + [h_out, w_out]
+    dtype = x.get_dtype()
+
+    def fn_sum(idx, loader):
+        *prefix, bh, bw = idx
+
+        def start_index(index, out_dim, inp_dim):
+            return ir.IndexingDiv((index * inp_dim), out_dim)
+
+        def end_index(index, out_dim, inp_dim):
+            return ir.IndexingDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
+
+        h_start_index = start_index(bh, h_out, h_in)
+        h_end_index = end_index(bh, h_out, h_in)
+
+        w_start_index = start_index(bw, w_out, w_in)
+        w_end_index = end_index(bw, w_out, w_in)
+
+        total = None
+        for ih, iw in itertools.product(range(h_kernel_max), range(w_kernel_max)):
+            val = loader(
+                prefix,
+                [ih, iw],
+                [h_start_index, w_start_index],
+                [h_end_index, w_end_index],
+            )
+            if total is None:
+                total = val
+            else:
+                total = ops.add(val, total)
+        return total
+
+    ones_loader = pad_adaptive_loader(ones_like(x))
+
+    def fn(idx):
+        return ops.div(fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader))
+
+    rv = Pointwise.create(
+        device=x.get_device(),
+        dtype=dtype,
+        inner_fn=fn,
+        ranges=new_size,
+    )
+    # TODO: should we force these to be realized?
+    return rv
+
+
 @register_lowering(aten.avg_pool2d, type_promote=False)
 def avg_pool2d(
     x,
@@ -2346,13 +2449,6 @@ def fn(idx):
     return rv
 
 
-@register_lowering(aten._adaptive_avg_pool2d, type_promote=False)
-def _adaptive_avg_pool2d(x, output_size):
-    assert isinstance(x, TensorBox)
-    assert len(output_size) == 2
-    return TensorBox.create(ir.AdaptiveAvgPool2d.create(x, output_size))
-
-
 def _validate_reduction_axis(x, axis):
     size = x.get_size()
     if isinstance(axis, int):

From 7005b2319eb1e2f8226940cf8af4a904e1b718aa Mon Sep 17 00:00:00 2001
From: Elias Ellison
Date: Tue, 23 Aug 2022 16:09:53 +0000
Subject: [PATCH 2/2] respond to comments

---
 test/test_torchinductor.py | 12 ++++++++++++
 torchinductor/lowering.py  |  3 +--
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/test/test_torchinductor.py b/test/test_torchinductor.py
index 2b05b8546b..3f6a8c5aa4 100755
--- a/test/test_torchinductor.py
+++ b/test/test_torchinductor.py
@@ -1196,6 +1196,18 @@ def fn(x):
             (torch.randn(2, 4, 16, 16),),
         )
 
+        # lowering to avg_pool2d case
+        self.common(
+            fn,
+            (torch.randn(2, 4, 3, 3),),
+        )
+
+        # no-op case
+        self.common(
+            fn,
+            (torch.randn(2, 4, 6, 6),),
+        )
+
     def test_max_pool2d1(self):
         def fn(x):
             return aten.max_pool2d_with_indices(x, [3, 3], [2, 2])

diff --git a/torchinductor/lowering.py b/torchinductor/lowering.py
index b2929b0862..2fb2b5931c 100644
--- a/torchinductor/lowering.py
+++ b/torchinductor/lowering.py
@@ -2169,8 +2169,7 @@ def _adaptive_avg_pool2d(x, output_size):
 
     # no-op if the same input and output
     if h_in == h_out and w_in == w_out:
-        # TODO: do I need to copy ? _to_copy does not
-        return x
+        return clone(x)
 
     if h_in % h_out == 0 and w_in % w_out == 0:
         kernel_size = [h_in // h_out, w_in // w_out]
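
---

A note on the decomposition in PATCH 1/2: each output cell (bh, bw) averages
the input window [start(bh), end(bh)) x [start(bw), end(bw)), where
start(i) = floor(i * inp_dim / out_dim) and
end(i) = ceil((i + 1) * inp_dim / out_dim). Window sizes vary per output cell,
so the lowering pads every window to a fixed (h_kernel_max, w_kernel_max)
footprint, masks out-of-window taps to zero, and divides a masked sum of x by
a masked sum of ones. The eager-mode sketch below mirrors that scheme for
reference; the helper name adaptive_avg_pool2d_reference is illustrative only
and is not part of the patch.

import math

import torch
import torch.nn.functional as F


def adaptive_avg_pool2d_reference(x, output_size):
    # Eager-mode mirror of fn_sum + pad_adaptive_loader from the lowering.
    *batch, h_in, w_in = x.shape
    h_out, w_out = output_size

    def start_index(i, out_dim, inp_dim):
        # same as ir.IndexingDiv((index * inp_dim), out_dim)
        return (i * inp_dim) // out_dim

    def end_index(i, out_dim, inp_dim):
        # same as ir.IndexingDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
        return ((i + 1) * inp_dim + out_dim - 1) // out_dim

    # fixed upper bound on any window extent, as in the lowering
    h_kernel_max = math.ceil((h_in + h_out - 1) / h_out)
    w_kernel_max = math.ceil((w_in + w_out - 1) / w_out)

    out = x.new_zeros([*batch, h_out, w_out])
    for bh in range(h_out):
        for bw in range(w_out):
            h0, h1 = start_index(bh, h_out, h_in), end_index(bh, h_out, h_in)
            w0, w1 = start_index(bw, w_out, w_in), end_index(bw, w_out, w_in)
            total = x.new_zeros(batch)
            count = 0
            # fixed-trip-count loops with masking, instead of variable windows
            for ih in range(h_kernel_max):
                for iw in range(w_kernel_max):
                    if h0 + ih < h1 and w0 + iw < w1:
                        total = total + x[..., h0 + ih, w0 + iw]
                        count += 1
            # sum(x) / sum(ones) over the masked footprint
            out[..., bh, bw] = total / count
    return out


x = torch.randn(2, 4, 16, 16)
torch.testing.assert_close(
    adaptive_avg_pool2d_reference(x, (6, 6)),
    F.adaptive_avg_pool2d(x, (6, 6)),
)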