diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu index 287938962a34..e74d71fe1aff 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu @@ -75,12 +75,6 @@ Tensor two_four_sgemm( using LayoutC = LayoutOutput; constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementC, - AlignmentC, - NumEVTEpilogueStages>; using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< ThreadblockShape, WarpShape, @@ -94,7 +88,7 @@ Tensor two_four_sgemm( cutlass::epilogue::threadblock::VisitorScalarBroadcast; using BiasTensor = cutlass::epilogue::threadblock::VisitorColBroadcast< - BiasTileThreadMap, + OutputTileThreadMap, ElementC, cute::Stride>; using Bias = std::conditional_t; diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu index 3b0f3a3170ca..35d6559b62ce 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu @@ -70,12 +70,6 @@ void spgemm_cutlass( using LayoutC = LayoutOutput; constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; - using TensorCTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< - ThreadblockShape, - WarpShape, - ElementC, - AlignmentC, - NumEVTEpilogueStages>; using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< ThreadblockShape, WarpShape, @@ -105,7 +99,7 @@ void spgemm_cutlass( cutlass::epilogue::threadblock::VisitorScalarBroadcast; using TensorCTensor = cutlass::epilogue::threadblock::VisitorColBroadcast< - TensorCTileThreadMap, + OutputTileThreadMap, ElementC, cute::Stride>; using TensorC = std::conditional_t;