74 changes: 55 additions & 19 deletions aten/src/ATen/native/cuda/RowwiseScaledMM.cu
@@ -139,11 +139,6 @@ void f8f8bf16_rowwise_impl(
int N = WQ.size(1);
int K = XQ.size(1);

TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
TORCH_CHECK(
WQ.is_cuda() && WQ.ndimension() == 2 && WQ.stride(1) == WQ.size(0) &&
WQ.stride(0) == 1);

using LayoutInputA = cutlass::layout::RowMajor;
constexpr int AlignmentInputA = 16 / sizeof(DtypeA);

@@ -235,11 +230,11 @@ void f8f8bf16_rowwise_impl(
using StrideOutput = typename Gemm::GemmKernel::StrideC;

StrideInputA stride_a = cutlass::make_cute_packed_stride(
StrideInputA{}, cute::make_shape(M, K, 1));
StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));

Contributor Author

Uhm, no, it's just saying that the innermost dimension (i.e., K) is contiguous. It would be equivalent to put XQ.stride(1) here, but since we have previously checked that this equals 1, I just hardcoded 1 right away.

Contributor
@drisspg Aug 22, 2024

This is a CUTLASS noob question, I guess: I'm surprised that the strides appear to be of length 3 instead of 2, e.g. (M, 1) versus (M, 1, 1).

It seems the trailing dim refers to an L (local tile)?
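
For background, a minimal hypothetical sketch (not from this PR), assuming the CUTLASS 3.x convention that GEMM problems are described as (M, N, K, L) and per-operand strides as rank-3 tuples, where the trailing mode is the batch count L (1 for a plain, non-batched GEMM) rather than a tile index; the helper name make_stride_for_a is made up for illustration.

// Hypothetical illustration: in CUTLASS 3.x, operand strides are rank-3 and the
// trailing mode is the batch count L; make_cute_packed_stride fills in the leading
// dimension (K for a row-major A) and the batch stride from the shape (M, K, L).
#include <cute/tensor.hpp>
#include <cutlass/util/packed_stride.hpp>

using StrideA = cute::Stride<int64_t, cute::Int<1>, int64_t>;  // row-major A: (ld, 1, batch)

StrideA make_stride_for_a(int M, int K, int L /* batch count */) {
  return cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
}

Under that reading, the trailing 1 passed to cute::make_shape in this file would be a batch count of one rather than a tile index.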

StrideInputB stride_b = cutlass::make_cute_packed_stride(
StrideInputB{}, cute::make_shape(N, K, 1));
StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(M, N, 1));
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
@@ -360,6 +355,57 @@ void dispatch_fp8_rowwise_kernel(
}
}

void check_inputs(
const at::Tensor& a,
const at::Tensor& b,
const at::Tensor& scale_a,
const at::Tensor& scale_b,
const std::optional<at::Tensor>& bias,
const at::Tensor& out) {
TORCH_CHECK(a.is_cuda());
TORCH_CHECK(a.device() == b.device());
TORCH_CHECK(scale_a.device() == a.device());
TORCH_CHECK(scale_b.device() == b.device());

TORCH_CHECK(a.dtype() == at::kFloat8_e4m3fn || a.dtype() == at::kFloat8_e5m2);
TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
TORCH_CHECK(scale_a.dtype() == at::kFloat);
TORCH_CHECK(scale_b.dtype() == at::kFloat);

TORCH_CHECK(a.dim() == 2);
TORCH_CHECK(b.dim() == 2);
TORCH_CHECK(a.size(1) == b.size(0));
TORCH_CHECK(scale_a.dim() == 2);
TORCH_CHECK(scale_b.dim() == 2);
TORCH_CHECK(scale_a.size(0) == a.size(0));
TORCH_CHECK(scale_a.size(1) == 1);
TORCH_CHECK(scale_b.size(0) == 1);
TORCH_CHECK(scale_b.size(1) == b.size(1));

TORCH_CHECK(a.stride(1) == 1);
TORCH_CHECK(a.stride(0) >= a.size(1));
TORCH_CHECK(b.stride(0) == 1);
TORCH_CHECK(b.stride(1) >= b.size(0));
TORCH_CHECK(scale_a.stride(0) == 1);
TORCH_CHECK(scale_b.stride(1) == 1);

if (bias.has_value()) {
TORCH_CHECK(bias->device() == b.device());
TORCH_CHECK(bias->dtype() == at::kFloat || bias->dtype() == at::kBFloat16);
TORCH_CHECK(bias->dim() == 1);
TORCH_CHECK(bias->size(0) == b.size(1));
TORCH_CHECK(bias->stride(0) == 1);
}

TORCH_CHECK(out.device() == a.device());
TORCH_CHECK(out.dtype() == at::kBFloat16);
TORCH_CHECK(out.dim() == 2);
TORCH_CHECK(out.size(0) == a.size(0));
TORCH_CHECK(out.size(1) == b.size(1));
TORCH_CHECK(out.stride(1) == 1);
TORCH_CHECK(out.stride(0) >= out.size(1));
}
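
As a caller-side illustration only (not part of this PR), a hedged sketch of tensors whose shapes, dtypes, and strides satisfy the rules enforced by check_inputs(); the sizes and helper name are made up, and the column-major requirement on b (stride(0) == 1) is obtained by transposing a row-major N x K tensor.

// Hypothetical sketch: layouts accepted by check_inputs().
#include <ATen/ATen.h>

void make_rowwise_fp8_inputs() {
  const int64_t M = 64, N = 128, K = 256;
  auto opts = at::TensorOptions().device(at::kCUDA);
  at::Tensor a = at::randn({M, K}, opts).to(at::kFloat8_e4m3fn);      // row-major M x K, stride (K, 1)
  at::Tensor b = at::randn({N, K}, opts).to(at::kFloat8_e4m3fn).t();  // column-major K x N, stride (1, K)
  at::Tensor scale_a = at::rand({M, 1}, opts.dtype(at::kFloat));      // one scale per row of a
  at::Tensor scale_b = at::rand({1, N}, opts.dtype(at::kFloat));      // one scale per column of b
  at::Tensor out = at::empty({M, N}, opts.dtype(at::kBFloat16));      // row-major bf16 output
}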

} // namespace

#endif // !defined(USE_ROCM)
@@ -374,22 +420,12 @@ void f8f8bf16_rowwise(
bool use_fast_accum,
at::Tensor& out) {
#if defined(BUILD_ROWWISE_FP8_KERNEL)
// Check datatypes.
TORCH_CHECK(
x_scale.dtype() == at::kFloat && w_scale.dtype() == at::kFloat,
"Scale tensors must be float32.");
if (bias.has_value()) {
TORCH_CHECK(
bias.value().dtype() == at::kFloat ||
bias.value().dtype() == at::kBFloat16,
"Bias type must be bfloat16 or float32 if provided.");
}
check_inputs(XQ, WQ, x_scale, w_scale, bias, out);

bool bf16_bias = bias.has_value() && bias->dtype() == at::kBFloat16;

// Templatize based on input dtype.
bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2;
TORCH_CHECK(WQ.dtype() == at::kFloat8_e4m3fn, "For RowWise scaling the second input is required to be a float8_e4m3fn dtype.");

if (bf16_bias) {
if (use_fast_accum) {