16 commits
8e331db  [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT-based poin… (kadeng, Feb 9, 2024)
c02dc57  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 9, 2024)
4d0b89f  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 12, 2024)
cb91324  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 12, 2024)
c77535e  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 13, 2024)
49bbc6b  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 16, 2024)
9fa6409  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 20, 2024)
d5aadd6  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 20, 2024)
eedad10  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 21, 2024)
f6d3be5  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 21, 2024)
9d773d4  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 22, 2024)
492f161  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 22, 2024)
b884189  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 22, 2024)
c391b8d  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 22, 2024)
82e8a5b  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 26, 2024)
d2d5baa  Update on " [Inductor Cutlass backend] 2 of 2 - Enabling flexible EVT… (kadeng, Feb 28, 2024)
14 changes: 14 additions & 0 deletions torch/_inductor/codegen/cuda/cuda_template.py
@@ -200,6 +200,8 @@ def generate_kernel_source_for_benchmark(self, input_tensor_meta, kernel, kwargs
getattr(input_node, "layout", None) for input_node in self.input_nodes
]
try:
# Temporarily set the strides of input nodes with FlexibleLayout
# to the strides given by the input_tensor_meta
for input_node, input_tensor_meta_variant in zip(
self.input_nodes, input_tensor_meta
):
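The collapsed remainder of this hunk presumably restores the saved layouts afterwards. A minimal sketch of that save-override-restore pattern (the helper name is hypothetical, and `meta.strides` is an assumed attribute of the input_tensor_meta entries):

```python
from contextlib import contextmanager

# A minimal sketch of the save-override-restore pattern above. The helper
# name is hypothetical; `meta.strides` is an assumed attribute of the
# input_tensor_meta entries.
@contextmanager
def temporarily_pinned_strides(input_nodes, input_tensor_meta):
    saved = [getattr(getattr(n, "layout", None), "stride", None) for n in input_nodes]
    try:
        for node, meta in zip(input_nodes, input_tensor_meta):
            layout = getattr(node, "layout", None)
            if layout is not None:
                # Benchmark against the concrete strides the caller will use.
                layout.stride = list(meta.strides)
        yield
    finally:
        # Restore the original strides even if benchmarking raised.
        for node, stride in zip(input_nodes, saved):
            if stride is not None:
                node.layout.stride = stride
```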
@@ -264,6 +266,18 @@ def generate_variants_after_fusion(
epilogue_nodes: List[IRNode],
kwarg_override_variants: Sequence[Dict], # type: ignore[type-arg]
) -> Generator[CUDATemplateCaller, None, None]:
"""
Generates variants of the given CUDATemplateBuffer after fusion with the given epilogue nodes.
May be used to determine the best configuration for a fused kernel, which may differ from
the best configuration for the unfused kernel.

Args:
template_buffer (CUDATemplateBuffer): The CUDATemplateBuffer to generate variants of.
epilogue_nodes (List[IRNode]): The epilogue nodes to fuse with the given CUDATemplateBuffer.
kwarg_override_variants (Sequence[Dict]): A sequence of keyword argument overrides to use for
generating the variants. Will typically be used to override the "op" argument for
CUTLASSGemmTemplates and provide a sequence of ops to try.
"""
original_kwargs = template_buffer.make_kernel_render.render_kwargs
original_template = template_buffer.make_kernel_render.template
assert original_template is self
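A hedged sketch of how a caller might consume this generator to autotune the fused kernel; only `generate_variants_after_fusion` and its arguments are taken from the source, and the timing harness below is a placeholder:

```python
import time

# Hypothetical caller: benchmark each op variant after fusion and keep the
# fastest. The timing below is a placeholder; a real harness would compile
# and run each candidate kernel on representative inputs.
def pick_best_fused_variant(template, template_buffer, epilogue_nodes, candidate_ops):
    kwarg_override_variants = [{"op": op} for op in candidate_ops]
    best_caller, best_time = None, float("inf")
    for caller in template.generate_variants_after_fusion(
        template_buffer, epilogue_nodes, kwarg_override_variants
    ):
        start = time.perf_counter()
        # ... compile and run the caller's kernel on representative inputs ...
        elapsed = time.perf_counter() - start
        if elapsed < best_time:
            best_caller, best_time = caller, elapsed
    return best_caller
```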
179 changes: 176 additions & 3 deletions torch/_inductor/codegen/cuda/cutlass_epilogue_gen.py
@@ -17,6 +17,147 @@
_MAGIC_SYMPY_ERROR_STRING = "[!sympy: unsupported expr!]"


EVT_EXTRA_HEADER = """
namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Sm90AuxLoadDirect implementation (based on Sm90ColBroadcast),
// which loads directly from global memory, without using shared memory
template<
int _IgnoredStages,
class CtaTileShapeMNK,
class Element,
class StrideMNL,
class _IgnoredSmemLayoutAtom,
class _IgnoredCopyOpS2R,
int Alignment = 128 / sizeof_bits_v<Element>,
bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
>
struct Sm90AuxLoadDirect {
constexpr static int Stages = 0;
static_assert(Stages == 0, "Direct load only supports 0 stages");
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");

// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };

struct Arguments {
Element const* ptr_col = nullptr;
Element null_default = Element(0);
StrideMNL dCol = {};
};

using Params = Arguments;
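  // Host side is trivial: arguments pass through unchanged as params,
  // and no workspace is required (see the three functions below).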

template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}

template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}

template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) {
return cutlass::Status::kSuccess;
}

CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}

CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}

CUTLASS_HOST_DEVICE
Sm90AuxLoadDirect() { }

CUTLASS_HOST_DEVICE
Sm90AuxLoadDirect(Params const& params, SharedStorage const& shared_storage)
: params(params) { }

Params params;

template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}

template<class GTensor, class RTensor>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
: tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
params(params) {}

GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;

CUTLASS_DEVICE void
begin() {
if constexpr (EnableNullptr) {
if (params.ptr_col == nullptr) {
fill(tCrCol, params.null_default);
return;
}
}

// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_aligned(filter(tCgCol), filter(tCrCol));
}

template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<Element, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<Element, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}

return frg_col;
}

};

template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
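    // Partition the aux-input gmem tensor across the epilogue threads and
    // allocate a matching register fragment; begin() fills it from gmem.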

auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
cute::move(tCgCol), cute::move(tCrCol), params);
}
};

} // namespace cutlass::epilogue::fusion
"""


def _arg_parse(a):
if isinstance(a, sympy.Expr):
return a
@@ -63,6 +204,12 @@ def __init__(
- accumulator_node_name (str): The name of the output Buffer for the GEMM operation in the original (unfused)
IR graph.
- evt_type_name (str): The output name of the EVT type we are generating.
- pre_fused_evt (Optional[str]): Optional EVT expression declaration that is pre-fused into the template
(typically an addmm-style bias addition).
- c_operand_alias (Optional[str]): Optional name of the C operand.
- gemm_output_layout: Output layout of the GEMM operation.
- flip_mn: Whether to flip the M and N dimensions of the GEMM output layout.
- aux_load_direct: Whether to load auxiliary inputs directly from global memory instead of staging
them through shared memory.

"""
self.accumulator_node_name = accumulator_node_name
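As a standalone illustration of the `flip_mn` bookkeeping described above (plain lists stand in for Inductor layout objects):

```python
# Illustrative only: flip_mn swaps the last two (M and N) dimensions of the
# GEMM output layout; plain lists stand in for Inductor layout objects.
def flip_mn_layout(size, stride):
    size, stride = list(size), list(stride)
    size[-1], size[-2] = size[-2], size[-1]
    stride[-1], stride[-2] = stride[-2], stride[-1]
    return size, stride

# A row-major (128, 64) output is seen as a (64, 128) view:
assert flip_mn_layout([128, 64], [64, 1]) == ([64, 128], [1, 64])
```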
@@ -96,6 +243,12 @@ def ir_to_evt_string(
evt_type_name (str): The name of the EVT type.
epilogue_nodes (List[IRNode]): A list of IR nodes representing the epilogue nodes. As of now, these must be
ComputedBuffer nodes wrapping Pointwise nodes.
pre_fused_evt: Optional EVT expression declaration that is pre-fused into the template
(typically an addmm-style bias addition).
c_operand_alias: Optional name of the C operand.
gemm_output_layout: Output layout of the GEMM operation.
flip_mn: Whether to flip the M and N dimensions of the GEMM output layout.
aux_load_direct: Whether to load auxiliary inputs directly from global memory instead of staging
them through shared memory.

Returns:
A string representation of the IR nodes formatted according to the Cutlass EVT format.
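For intuition, a hedged example of the nested `Sm90EVT` expression string such a formatter produces, mirroring the ADDMM_EVT declaration shown later in this diff; the ReLU functor name is an assumption about the available CUTLASS ops:

```python
# Hypothetical output for a single fused ReLU epilogue (illustration only;
# the functor names follow CUTLASS conventions but are assumptions here).
relu_evt = (
    "cutlass::epilogue::fusion::Sm90EVT<"
    "cutlass::epilogue::fusion::Sm90Compute<"
    "cutlass::epilogue::thread::ReLu, ElementD, ElementCompute, RoundStyle>,"
    "cutlass::epilogue::fusion::Sm90AccFetch"  # leaf: the GEMM accumulator
    ">"
)
print(relu_evt)
```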
@@ -164,6 +317,19 @@ def check_range_and_stride_compatibility(gemm_output_layout, pnode)

@staticmethod
def create_pre_fused_addmm_evt_type() -> str:
"""returns the name of the ADDMM EVT type which has been declared like this:

using ADDMM_EVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add,
ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies,
ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
>>
"""
return "ADDMM_EVT"

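Numerically, this tree evaluates D = beta * C + alpha * acc; a small numpy check of that reading, with made-up values:

```python
import numpy as np

# The ADDMM_EVT tree computes D = beta * C + alpha * acc: the inner Sm90EVT
# yields alpha * acc, and the outer homogeneous_multiply_add node combines
# it with the Sm90SrcFetch (C) and Sm90ScalarBroadcast (beta) leaves.
alpha, beta = 2.0, 0.5
acc = np.arange(6, dtype=np.float32).reshape(2, 3)  # GEMM accumulator
C = np.ones((2, 3), dtype=np.float32)               # C operand

inner = alpha * acc   # Sm90Compute<multiplies>: alpha * acc
D = beta * C + inner  # Sm90Compute<homogeneous_multiply_add>: beta * C + inner

assert np.allclose(D, beta * C + alpha * acc)
```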
def __getattr__(self, name):
@@ -469,9 +635,6 @@ def _op_load(self, name, index_expr):
elif name == self.c_operand_alias:
return "{}"
else:
raise CUTLASSEVTOpNotImplementedError(
f"Operand {name} not found. Auxiliary inputs not supported yet."
)
if self.dry_run:
return f"{{ /* dry run placeholder for aux input {name} */ }}"
kernel = virtualized.V.kernel
@@ -492,6 +655,10 @@ def _op_load(self, name, index_expr):
strides: List[int] = map_pointwise_index_to_read_strides(
index_expr, self.gemm_output_layout, self.flip_mn
)
# For the sanity check, the output size dimensions need to be flipped
# if flip_mn=True, since the GEMM output layout (including sizes)
# might have been flipped relative to what is actually written.
gemm_write_sizes = list(self.gemm_output_layout.size)
if self.flip_mn:
gemm_write_sizes[-1], gemm_write_sizes[-2] = (
@@ -502,6 +669,7 @@
load_stride_max_idx = sum(
[stride * (dim - 1) for stride, dim in zip(strides, gemm_write_sizes)]
)
# The strides may reinterpret the buffer, but they must not read beyond its bounds; check that here.
assert (
load_stride_max_idx < aux_input_node.get_numel()
), f"Aux input would read beyond bounds (A): Load stride {strides} for node {aux_input_node.get_name()} with layout {aux_input_node.get_layout()} - accessed using index expr {index_expr} is too large for the node when mapped onto GEMM with output layout {self.gemm_output_layout}." # noqa: B950
@@ -612,6 +780,10 @@ def create_cutlass_aux_load_descriptor(
strides: List[int] = map_pointwise_index_to_read_strides(
index_expr, gemm_output_layout, flip_mn
)
# For the sanity check, the output size dimensions need to be flipped
# if flip_mn=True, since the GEMM output layout (including sizes)
# might have been flipped relative to what is actually written.
gemm_write_sizes = list(gemm_output_layout.size)
if flip_mn:
gemm_write_sizes[-1], gemm_write_sizes[-2] = (
Expand All @@ -621,6 +793,7 @@ def create_cutlass_aux_load_descriptor(
load_stride_max_idx = sum(
[stride * (dim - 1) for stride, dim in zip(strides, gemm_write_sizes)]
)
# The strides may reinterpret the buffer, but they must not read beyond its bounds; check that here.
assert (
load_stride_max_idx < node.get_numel()
), f"Aux input would read beyond bounds (B): Load stride {strides} for node {node.get_name()} with layout {node.get_layout()} - accessed using index expr {index_expr} is too large for the node when mapped onto GEMM with output layout {gemm_output_layout}." # noqa: B950
@@ -39,13 +39,35 @@ def __init__(self, operation_suffix=""):
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementAcc = ${element_accumulator};
using ElementD = ${element_d};
using ElementC = ${element_c};
using TileShapeMNK = cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>;
using ClusterShapeMNK = cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
TileShapeMNK,
EpilogueTileType,
ElementC,
ElementD,
EpilogueScheduleType
>;

using ADDMM_EVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add,
ElementD, ElementAcc, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementAcc,
ElementAcc, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAcc>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
>>;
${epilogue_functor};
using ${operation_name}_epilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
${arch}, ${opcode_class},
cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
cutlass::epilogue::collective::EpilogueTileAuto,
TileShapeMNK,
ClusterShapeMNK,
EpilogueTileType,
${element_accumulator}, ${element_epilogue},
${element_c}, ${layout_c}, ${align_c},
${element_d}, ${layout_d}, ${align_d},
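The `${...}` placeholders in this template are filled in at render time. A minimal stand-in for that substitution step using Python's string.Template (Inductor uses its own renderer; the values below are made up):

```python
from string import Template

# Minimal stand-in for the render step (illustration only: Inductor uses its
# own template renderer, and the values below are made up).
snippet = Template(
    "using TileShapeMNK = cute::Shape<"
    "cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>;"
)
print(snippet.substitute(tile_shape_m=128, tile_shape_n=128, tile_shape_k=64))
# using TileShapeMNK = cute::Shape<cute::_128, cute::_128, cute::_64>;
```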
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cuda/cutlass_utils.py
@@ -20,7 +20,7 @@

# This is a feature flag. If _DISABLE_CUTLASS_BACKEND is set to True,
# the CUDA / CUTLASS backend is entirely disabled, including its unit tests.
_DISABLE_CUTLASS_BACKEND = True
_DISABLE_CUTLASS_BACKEND = False
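With the flag flipped to False, the backend and its unit tests run again. A sketch of how a test module might gate on this flag (the test class is hypothetical; the import path is the file shown in this diff):

```python
import unittest

from torch._inductor.codegen.cuda.cutlass_utils import _DISABLE_CUTLASS_BACKEND

# Hypothetical test-side guard: skip the suite while the backend is disabled.
@unittest.skipIf(_DISABLE_CUTLASS_BACKEND, "CUTLASS backend is disabled")
class TestCutlassBackend(unittest.TestCase):
    def test_backend_enabled(self):
        self.assertFalse(_DISABLE_CUTLASS_BACKEND)

if __name__ == "__main__":
    unittest.main()
```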


def _rename_cutlass_import(content: str, cutlass_modules: List[str]) -> str: