[Inductor] support masked vectorization for the tail_loop
ghstack-source-id: 57ca43d0ef1f782aa091cbd29c50f549a880c5aa
Pull Request resolved: #126526
jiayisunx authored and CaoE committed May 25, 2024
1 parent e2f0818 commit 1721a99
Showing 2 changed files with 242 additions and 70 deletions.
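For context on what "masked vectorization for the tail_loop" means at the C++ level: the generated kernels use ATen's at::vec::Vectorized wrappers, and the loadu(ptr, count) / store(ptr, count) overloads that appear in the codegen strings below touch only the first count lanes. The hand-written sketch below illustrates the idea; the function name, the copy body, and the <ATen/cpu/vec/vec.h> include are illustrative assumptions, not code produced by this commit.

#include <cstdint>
#include <ATen/cpu/vec/vec.h>

// Copy n floats: full-width vectors for the main loop, a counted ("masked")
// load/store for the tail instead of a scalar fallback loop.
void copy_with_masked_tail(const float* src, float* dst, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  constexpr int64_t kVLen = Vec::size();           // tiling factor
  const int64_t main_size = (n / kVLen) * kVLen;   // main-loop trip count
  const int64_t tail_size = n - main_size;         // leftover elements

  // Main loop: whole vectors.
  for (int64_t i = 0; i < main_size; i += kVLen) {
    Vec::loadu(src + i).store(dst + i);
  }
  // Tail loop: only the first tail_size lanes are read and written.
  if (tail_size > 0) {
    Vec::loadu(src + main_size, tail_size).store(dst + main_size, tail_size);
  }
}

The diff below threads a tail_size value through CppVecKernel so that its emitted loads, stores, and reduction combines use exactly this counted form for the leftover iterations, rather than handing the tail to the scalar kernel.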
167 changes: 125 additions & 42 deletions torch/_inductor/codegen/cpp.py
@@ -36,6 +36,7 @@
cache_on_self,
get_bounds_index_expr,
get_fused_kernel_name,
has_free_symbols,
is_welford_reduction,
parallel_num_threads,
Placeholder,
@@ -1574,7 +1575,7 @@ def _gen_parallel_reduction_buffers(
if (
reduction_type == "welford_reduce"
and welford_weight_reciprocal_vec_fn
and hasattr(self, "weight_recp_vec_range")
and hasattr(self, "reduction_main_size")
and "vec" in f"{acc_type}"
):
self.local_reduction_init.writeline(
@@ -2024,6 +2025,7 @@ def __init__(
tiling_factor=0,
tiling_idx=-1,
tiling_dtype=torch.float,
tail_size=None,
):
super().__init__(args, num_threads)
self.vec_isa = codecache.pick_vec_isa()
@@ -2032,6 +2034,7 @@ def __init__(
tiling_factor = self.vec_isa.nelements(dtype=tiling_dtype)
self.tiling_factor = tiling_factor
self.tiling_idx = tiling_idx
self.tail_size = tail_size

def _try_get_const_stride(self, index: sympy.Expr, itervar: sympy.Symbol):
if self.index_indirect_depends_on(index, itervar):
@@ -2110,7 +2113,7 @@ def _get_vec_load_line(
line = (
f"{load_mask_str}.template loadu<{cpp_type},{num_vectors}>({loadbuf})"
if load_mask_str
else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {self.tiling_factor})"
else f"{self._get_vec_type(dtype)}::loadu({loadbuf}, {self.tail_size if self.tail_size else self.tiling_factor})"
)
return line

@@ -2143,7 +2146,9 @@ def _load_or_store_non_contiguous(
buffer = self.loads

def get_result_size(dtype: torch.dtype) -> int:
if dtype.itemsize < 4:
if self.tail_size:
return self.tail_size
elif dtype.itemsize < 4:
return self.tiling_factor * (4 // dtype.itemsize)
else:
return self.tiling_factor
@@ -2206,11 +2211,17 @@ def vec_to_array(vec_var: CppCSEVariable) -> CppCSEVariable:
else:
load_mask = f"{self._load_mask} != 0"
if codecache.is_gcc():
code.writeline(f"#pragma GCC unroll {self.tiling_factor}")
code.writeline(
f"#pragma GCC unroll {self.tail_size if self.tail_size else self.tiling_factor}"
)
else:
code.writeline(f"#pragma unroll {self.tiling_factor}")
code.writeline(
f"#pragma unroll {self.tail_size if self.tail_size else self.tiling_factor}"
)
code.writeline(
f"for (long {itervar_inner} = 0; {itervar_inner} < {self.tiling_factor}; {itervar_inner}++)"
f"for (long {itervar_inner} = 0; "
+ f"{itervar_inner} < {self.tail_size if self.tail_size else self.tiling_factor}; "
+ f"{itervar_inner}++)"
)
with code.indent(), contextlib.ExitStack() as stack:
index_c = cexpr_index(index)
@@ -2287,10 +2298,9 @@ def _get_store_line(
stride = self._try_get_const_stride(index, tiling_var)
code = IndentedBuffer()
if stride == 1:
if dtype == torch.float:
code.writeline(f"{value}.store({var_expr});")
else:
code.writeline(f"{value}.store({var_expr}, {self.tiling_factor});")
code.writeline(
f"{value}.store({var_expr}, {self.tail_size if self.tail_size else self.tiling_factor});"
)
else:
self._load_or_store_non_contiguous(
var, index, dtype, buffer=code, store_value=value
@@ -2348,15 +2358,19 @@ def reduction(self, dtype, src_dtype, reduction_type, value):
self.reduction_prefix.writeline(
f"{acc_type_vec} {acc_vec} = {self.reduction_init_vec(reduction_type, dtype)};"
)
# save the reciprocal of weights for welford reduce if using static shape
reduction_size = functools.reduce(
lambda x, y: x * y, self.ranges[self.reduction_depth :]
)
if reduction_type == "welford_reduce":
reduction_factor = (
self.tiling_factor if self.tiling_idx >= self.reduction_depth else 1
# save the reciprocal of weights for welford reduce if using static shape
reduction_size = functools.reduce(
lambda x, y: x * y, self.ranges[self.reduction_depth :]
)
self.weight_recp_vec_range = FloorDiv(reduction_size, reduction_factor)
# calculate the reduction size that will be vectorized
reduction_inner_size = self.ranges[-1] if self.reduction_depth < len(self.ranges) - 1 else self.ranges[self.reduction_depth]

Check failure on line 2367 in torch/_inductor/codegen/cpp.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [operator]: Unsupported operand types for > ("int" and "None")
Check failure on line 2367 in torch/_inductor/codegen/cpp.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [index]: Invalid index type "Any | None" for "list[Expr]"; expected type "SupportsIndex"
# calculate the size of the reduction loops outside the vectorized loop
self.reduction_outer_size = reduction_size / reduction_inner_size
# calculate the main loop size
self.reduction_main_size = FloorDiv(reduction_inner_size, self.tiling_factor) * self.tiling_factor
# calculate the tail loop size
self.reduction_tail_size = reduction_inner_size - self.reduction_main_size
self.non_parallel_reduction_prefix.writeline(
self.welford_weight_reciprocal_vec(dtype, None)
)
@@ -2483,29 +2497,33 @@ def reduction_acc_type_vec(self, reduction_type, dtype):
return vec_type

def welford_weight_reciprocal_vec(self, dtype, num_threads=None):
vec_num_range_thread = (
CeilDiv(self.weight_recp_vec_range, num_threads)
reduction_main_size_thread = (
CeilDiv(self.reduction_main_size / self.tiling_factor, num_threads) * self.tiling_factor
if num_threads
else self.weight_recp_vec_range
else self.reduction_main_size
)
vec_num_range_thread_expr = cexpr_index(vec_num_range_thread)
return f"static WeightRecp<{self._get_vec_type(dtype)}> weight_recps({vec_num_range_thread_expr});"
reduction_main_size_thread_expr = cexpr_index(reduction_main_size_thread)
reduction_outer_size_expr = cexpr_index(self.reduction_outer_size)
reduction_tail_size_expr = cexpr_index(self.reduction_tail_size)
return (f"static WeightRecp<{self._get_vec_type(dtype)}> weight_recps"
f"({reduction_outer_size_expr}, "
f"{reduction_main_size_thread_expr}, "
f"{reduction_tail_size_expr});")

def reduction_combine_vec(
self, reduction_type, var, next_value, use_weight_recps=False
):
if reduction_type == "max":
return f"at::vec::maximum({var}, {next_value})"
elif reduction_type == "min":
return f"at::vec::minimum({var}, {next_value})"
elif reduction_type == "sum":
return f"{var} + {next_value}"
elif reduction_type == "prod":
return f"{var} * {next_value}"
elif reduction_type == "xor_sum":
return f"{var} ^ {next_value}"
if reduction_type in ["max", "min", "sum", "prod", "xor_sum"]:
if self.tail_size:
return (
f'reduce({var}, {next_value}, "{reduction_type}", {self.tail_size})'
)
else:
return f'reduce({var}, {next_value}, "{reduction_type}")'
elif reduction_type == "welford_reduce":
if use_weight_recps:
if self.tail_size:
return f"welford_combine({var}, {next_value}, {self.tail_size}, &weight_recps)"
elif use_weight_recps:
return f"welford_combine({var}, {next_value}, &weight_recps)"
else:
return f"welford_combine({var}, {next_value})"
@@ -2516,7 +2534,10 @@ def reduction_combine_vec(
else:
# When combining intermediate accumulators we have a Welford<T> struct
mean, m2, weight = reduction_project(reduction_type, next_value)
return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
if self.tail_size:
return f"welford_combine({var}, {{{mean}, {m2}, {weight}}}, {self.tail_size})"
else:
return f"welford_combine({var}, {{{mean}, {m2}, {weight}}})"
else:
raise NotImplementedError

@@ -2721,6 +2742,8 @@ def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1):

self.simd_vec = True

self.simd_masked_vec = True

self.fast_vec_list = []
for k, v in CppVecOverrides.__dict__.items():
if isinstance(v, staticmethod):
@@ -2739,10 +2762,21 @@ def __init__(self, args, num_threads, tiling_factor, tiling_idx=-1):
torch.int64,
]

self.supported_dtypes_for_masked_vec: List[torch.dtype] = [
torch.float,
torch.bfloat16,
]

def disable_vec(self, msg=None):
if schedule_log.isEnabledFor(logging.DEBUG):
schedule_log.debug("Disabled vectorization: %s", msg)
self.simd_vec = False
self.simd_masked_vec = False

def disable_masked_vec(self, msg=None):
if schedule_log.isEnabledFor(logging.DEBUG):
schedule_log.debug("Disabled masked vectorization: %s", msg)
self.simd_masked_vec = False

def load(self, name: str, index: sympy.Expr):
with RecordOptimizationContext(__name__) as node_ctx:
@@ -2753,6 +2787,14 @@ def load(self, name: str, index: sympy.Expr):
opt_ctx.dtype = load_dtype
var = self.cse.newvar()

if load_dtype not in self.supported_dtypes_for_masked_vec:
self.disable_masked_vec(
f"{load_dtype} not supported by masked vectorization"
)

if has_free_symbols(self.ranges):
self.disable_masked_vec("Symbolic ranges not supported by masked load")

if len(self.itervars) == 0:
self.disable_vec("not a loop")
return var
@@ -2768,12 +2810,20 @@ def load(self, name: str, index: sympy.Expr):

def store(self, name, index, value, mode=None):
with RecordOptimizationContext(__name__) as node_ctx:
store_dtype = V.graph.get_dtype(name)

if store_dtype not in self.supported_dtypes_for_masked_vec:
self.disable_masked_vec(
f"{store_dtype} not supported by masked vectorization"
)

if has_free_symbols(self.ranges):
self.disable_masked_vec("Symbolic ranges not supported by masked store")

if len(self.itervars) == 0:
self.disable_vec("not a loop")
return self.simd_vec

store_dtype = V.graph.get_dtype(name)

opt_ctx: OptimizationContext = node_ctx.get_opt_ctx()
assert opt_ctx
opt_ctx.dtype = store_dtype
@@ -2792,6 +2842,9 @@ def store(self, name, index, value, mode=None):
return self.simd_vec

def reduction(self, dtype, src_dtype, reduction_type, value):
if has_free_symbols(self.ranges):
self.disable_masked_vec("Symbolic ranges not supported by masked reduction")

if not (
(dtype == torch.float and src_dtype == torch.float)
or (dtype == torch.int64 and src_dtype == torch.int64)
@@ -2882,6 +2935,11 @@ def constant(val, dtype):
):
opt_ctx.dtype = torch.float32

if opt_ctx.dtype not in self.supported_dtypes_for_masked_vec:
self.disable_masked_vec(
f"{opt_ctx.dtype} not supported by masked vectorization"
)

if opt_ctx.dtype not in self.supported_dtypes:
self.disable_vec(f"constant dtype: {opt_ctx.dtype}")
return val
@@ -2955,6 +3013,11 @@ def masked(mask, body, other):

@staticmethod
def to_dtype(x, dtype, src_dtype=None):
if dtype not in self.supported_dtypes_for_masked_vec:
self.disable_masked_vec(
f"{dtype} not supported by masked vectorization"
)

if dtype not in self.supported_dtypes:
self.disable_vec(f"to_dtype: {dtype}")
return x
@@ -3318,6 +3381,7 @@ def select_tiling(dtype: torch.dtype = torch.float):
tiling_indices = select_tiling_indices(tiling_factor)
if tiling_indices:
could_vec = True
could_masked_vec = True
for tiling_indice in tiling_indices:
with CppVecKernelChecker(
deepcopy(self.kernel_group.args),
@@ -3327,21 +3391,28 @@ def select_tiling(dtype: torch.dtype = torch.float):
) as vec_checker:
run(vec_checker)
could_vec = could_vec and vec_checker.simd_vec
could_masked_vec = (
could_masked_vec and vec_checker.simd_masked_vec
)
if not could_vec:
break
if could_vec:
if len(tiling_indices) == 1:
return [tiling_factor], tiling_indices
return [tiling_factor], tiling_indices, could_masked_vec
if len(tiling_indices) == 2:
return [tiling_factor, tiling_factor], tiling_indices
return [], []
return (
[tiling_factor, tiling_factor],
tiling_indices,
could_masked_vec,
)
return [], [], False

# Kernels share the same global contexts, like V.graph.wrapper_code and V.kernel.args,
# and the generated scalar kernel has already updated these global contexts. Hence, the other
# kernels should not do this again, to avoid context conflicts. For now, we only control
# config.inplace_buffers. In the future, we could maintain more contexts.
with torch._inductor.config.patch(inplace_buffers=False):
tiling_factors, tiling_indices = select_tiling(vec_dtype)
tiling_factors, tiling_indices, could_masked_vec = select_tiling(vec_dtype)
assert len(tiling_factors) == len(tiling_indices)
if len(tiling_indices) == 1:
vec_kernel = codegen_kernel(
@@ -3352,9 +3423,21 @@ def select_tiling(dtype: torch.dtype = torch.float):
tiling_indices[0], factor=tiling_factors[0]
)
main_loop.set_kernel(vec_kernel)
tail_loop.set_kernel(scalar_kernel)
main_loop.simd_vec = True
tail_loop.simd_omp = True
if could_masked_vec:
tail_loop.steps = tail_loop.size - tail_loop.offset
masked_vec_kernel = codegen_kernel(
CppVecKernel,
tiling_factors[0],
tiling_indices[0],
vec_dtype,
tail_loop.steps,
)
tail_loop.set_kernel(masked_vec_kernel)
tail_loop.simd_vec = True
else:
tail_loop.set_kernel(scalar_kernel)
tail_loop.simd_omp = True
# We split the loop by nelements into two parts: the main loop and the tail loop.
# The main loop can straightforwardly be vectorized with nelements, but the tail
# loop can still be vectorized as well. For example,

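Tying the scheduling change back to the comment above about splitting the loop into a main loop and a tail loop: when could_masked_vec holds, the tail loop keeps a vector kernel, and its remaining trip count (tail_loop.size - tail_loop.offset) is passed in as tail_size. For a reduction, the resulting code is shaped roughly like the sketch below. The sum kernel, its names, and the horizontal-reduction spill at the end are illustrative assumptions; the commit's own codegen instead emits calls to helpers such as reduce(acc, next, "sum", tail_size) and welford_combine(acc, next, tail_size, &weight_recps), presumably defined in the second changed file, which is not shown here.

#include <cstdint>
#include <ATen/cpu/vec/vec.h>

// Sum n floats with a vector accumulator; the tail is folded in with a
// counted load rather than a scalar loop. This assumes loadu(ptr, count)
// zero-fills the unused lanes, which is harmless for a sum; other reduction
// types need the lane-aware helpers added by the commit.
float sum_with_masked_tail(const float* x, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  constexpr int64_t kVLen = Vec::size();           // tiling factor
  const int64_t main_size = (n / kVLen) * kVLen;   // reduction_main_size
  const int64_t tail_size = n - main_size;         // reduction_tail_size

  Vec acc(0.0f);
  // Main loop: full-width accumulation.
  for (int64_t i = 0; i < main_size; i += kVLen) {
    acc = acc + Vec::loadu(x + i);
  }
  // Tail loop: one masked-vector step over the last tail_size elements.
  if (tail_size > 0) {
    acc = acc + Vec::loadu(x + main_size, tail_size);
  }
  // Horizontal reduction: spill the accumulator and add up its lanes.
  float lanes[Vec::size()];
  acc.store(lanes);
  float total = 0.0f;
  for (int64_t i = 0; i < kVLen; ++i) {
    total += lanes[i];
  }
  return total;
}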