diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index 3a4a8e1fd7f2d..847a2dab5e838 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -34,6 +34,12 @@ Tensor linear(const Tensor& input, const Tensor& weight, const c10::optionaldefined() && input.is_contiguous()) { + // Also hit the fused path for contiguous 3D input. + const auto input_sizes = input.sizes(); + const auto result = at::addmm(*bias, input.view({input_sizes[0] * input_sizes[1], input_sizes[2]}), weight.t()); + return result.view({input_sizes[0], input_sizes[1], result.size(1)}); + } auto output = at::matmul(input, weight.t()); if (bias->defined()) { output.add_(*bias);