[INDUCTOR] [CPU] [GPT-FAST-MOE] large perf regression with coordinate_descent_tuning disabled #124697

Open
Valentine233 opened this issue Apr 23, 2024 · 3 comments
Labels: oncall: cpu inductor (CPU Inductor issues for Intel team to triage), oncall: pt2

Valentine233 commented Apr 23, 2024

🐛 Describe the bug

When the flag coordinate_descent_tuning is disabled, GPT-FAST-MOE encounters a large perf regression: 52s -> 1049s. On CPU, disabling the flag makes bmm and mm fall back instead of being decomposed.

Code snippet

import torch
from torch import nn, Tensor
from torch.nn import functional as F
import torch._inductor.config

torch._inductor.config.cpp.enable_kernel_profile = True
torch._inductor.config.coordinate_descent_tuning = False # True

dim = 4096
num_experts = 8
num_activated_experts = 2
intermediate_size = 14336

class ConditionalFeedForwardBit8(nn.Module):
    def __init__(self, target_dtype):
        super().__init__()

        self.target_dtype = target_dtype

        self.register_buffer("w1", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))
        self.register_buffer("w2", torch.empty(num_experts, dim, intermediate_size, dtype=target_dtype))
        self.register_buffer("w3", torch.empty(num_experts, intermediate_size, dim, dtype=target_dtype))

        self.register_buffer("scales1", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))
        self.register_buffer("scales2", torch.empty(num_experts, dim, dtype=torch.bfloat16))
        self.register_buffer("scales3", torch.empty(num_experts, intermediate_size, dtype=torch.bfloat16))

    def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
        w1_weights = self.w1.to(x.dtype)[expert_indices] # [T, A, D, D]
        w3_weights = self.w3.to(x.dtype)[expert_indices] # [T, A, D, D]
        w2_weights = self.w2.to(x.dtype)[expert_indices]
        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights) * self.scales1[expert_indices].to(x.dtype))
        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) * self.scales3[expert_indices].to(x.dtype)
        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) * self.scales2[expert_indices].to(x.dtype)  # [T, A, D, D]
        return expert_outs

def linear_forward_int8(x, weight_int8pack, scales, out_features):
    origin_x_size = x.size()
    x = x.reshape(-1, origin_x_size[-1])
    c = torch.ops.aten._weight_int8pack_mm(x, weight_int8pack, scales)
    new_shape = origin_x_size[:-1] + (out_features,)
    c = c.reshape(new_shape)
    return c

class WeightOnlyBit8Linear(torch.nn.Module):
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None, target_dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer("weight", torch.empty((out_features, in_features), dtype=target_dtype))
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return linear_forward_int8(
           input,
           self.weight, self.scales, self.out_features)

class MOEFeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.gate = WeightOnlyBit8Linear(dim, num_experts, bias=False, target_dtype=torch.int8)
        self.cond_ffn = ConditionalFeedForwardBit8(torch.int8)
        self.dim = dim
        self.num_activated_experts = num_activated_experts

    def forward(self, x: Tensor) -> Tensor:
        x = x.view(-1, self.dim)
        # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
        # x: [T, D]
        scores = self.gate(x) # [T, E]
        expert_weights = F.softmax(scores, dim=-1)
        expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A]
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A]
        expert_outs = self.cond_ffn(x, expert_indices)
        return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)


mod = MOEFeedForward()
input_shape = (1, 1, dim,)
x = torch.randn(input_shape, dtype=torch.bfloat16)
compiled_mod = torch.compile(mod)

for i in range(5):
  y = compiled_mod(x)

prof = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU])
with prof:
  y = compiled_mod(x) 
prof.export_chrome_trace("test_moe.json")
print(prof.key_averages().table(sort_by="self_cpu_time_total"))

Profiling

coordinate_descent_tuning=True


---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
      graph_0_cpp_fused_bmm_sum_1        82.07%     884.479us        82.07%     884.479us     884.479us             1
            Torch-Compiled Region        11.72%     126.292us        98.82%       1.065ms       1.065ms             1
        aten::_weight_int8pack_mm         2.65%      28.601us         3.03%      32.667us      32.667us             1
                       aten::topk         1.33%      14.381us         1.33%      14.381us      14.381us             1
         TorchDynamo Cache Lookup         1.18%      12.767us         1.18%      12.767us      12.767us             1
    inductor::_reinterpret_tensor         0.60%       6.493us         0.60%       6.493us       2.164us             3
                      aten::empty         0.38%       4.066us         0.38%       4.066us       4.066us             1
     graph_0_cpp_fused__softmax_0         0.06%       0.698us         0.06%       0.698us       0.698us             1
---------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 1.078ms

coordinate_descent_tuning=False

-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
       graph_0_cpp_fused__to_copy_index_1        61.29%      96.359ms        61.29%      96.359ms      96.359ms             1
                              aten::copy_        33.98%      53.422ms        33.98%      53.422ms      17.807ms             3
                                aten::bmm         4.16%       6.542ms        38.20%      60.056ms      15.014ms             4
                    Torch-Compiled Region         0.35%     550.429us        99.98%     157.187ms     157.187ms             1
       graph_0_cpp_fused_index_mul_silu_2         0.07%     106.154us         0.07%     106.154us     106.154us             1
                aten::_weight_int8pack_mm         0.02%      37.146us         0.03%      42.080us      42.080us             1
            inductor::_reinterpret_tensor         0.02%      37.033us         0.02%      37.033us       3.086us            12
                              aten::clone         0.02%      33.289us        34.02%      53.486ms      17.829ms             3
                 TorchDynamo Cache Lookup         0.02%      31.353us         0.02%      31.353us      31.353us             1
                               aten::topk         0.02%      24.819us         0.02%      24.819us      24.819us             1
                              aten::empty         0.01%      22.710us         0.01%      22.710us       5.678us             4
                         aten::contiguous         0.01%      16.971us        34.03%      53.503ms      17.834ms             3
                         aten::empty_like         0.01%      12.382us         0.02%      30.158us      10.053us             3
    graph_0_cpp_fused_div_index_mul_sum_3         0.01%       9.925us         0.01%       9.925us       9.925us             1
                         aten::as_strided         0.00%       7.801us         0.00%       7.801us       3.900us             2
                       aten::resolve_conj         0.00%       3.070us         0.00%       3.070us       0.384us             8
             graph_0_cpp_fused__softmax_0         0.00%       1.890us         0.00%       1.890us       1.890us             1
-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 157.218ms

Analysis

According to the current analysis, there are two main reasons:

  • With the flag disabled, the bmm fallback breaks the original fused C++ kernel, and redundant computation is generated to store the uint8 weights as bf16 and then convert them to fp32. The uint8-to-bf16 conversion takes a long time (see the sketch after this list).
  • Many of the bmm calls perform an additional contiguous copy inside the MKLDNN matmul, because the inputs are in a non-contiguous and non-transposed format; this is not solved by [Inductor] add contiguous layout optm for bmm input #122599.
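
To make the first point concrete, here is a minimal timing sketch (my illustration, not code from the issue) of the gather-and-convert pattern `self.w1.to(x.dtype)[expert_indices]` from the repro. In eager mode the whole int8 weight is converted to bf16 before indexing, while the compiled kernel graph_0_cpp_fused__to_copy_index_1 fuses the conversion with the gather so only the selected experts are converted:

import time
import torch

# Shapes from the repro: 8 experts, intermediate_size=14336, dim=4096, one token, two activated experts.
num_experts, intermediate_size, dim = 8, 14336, 4096
w1 = torch.empty(num_experts, intermediate_size, dim, dtype=torch.int8)
expert_indices = torch.tensor([[0, 3]])  # [T=1, A=2]

def gather_and_convert():
    # int8 -> bf16 conversion followed by the expert gather, as in `self.w1.to(x.dtype)[expert_indices]`.
    return w1.to(torch.bfloat16)[expert_indices]

for _ in range(3):  # warm-up
    gather_and_convert()
start = time.perf_counter()
out = gather_and_convert()
print(out.shape, f"{(time.perf_counter() - start) * 1e3:.1f} ms")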

Versions

PyTorch: 34bce27

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @jgong5 @leslie-fang-intel @yanbing-j @mingfeima

Valentine233 self-assigned this on Apr 23, 2024
Valentine233 (Collaborator, Author) commented:

With the following optimizations:

  1. Add VecConvert for uint8 to bf16.
    Improvement of graph_0_cpp_fused__to_copy_index_1: 96.359ms -> 35.693ms.
cpp_fused__to_copy_index_1 = async_compile.cpp_pybinding(['const int64_t*', 'const signed char*', 'const signed char*', 'const signed char*', 'bfloat16*', 'bfloat16*', 'bfloat16*'], '''
#include <ATen/record_function.h>
#include "/tmp/torchinductor_liaoxuan/nd/cndd7co72iqjtof53ikp4l7yibmqrbjkni3cu6xj5p7hywloe5yg.h"
extern "C" void kernel(const int64_t* in_ptr0,
                       const signed char* in_ptr1,
                       const signed char* in_ptr2,
                       const signed char* in_ptr3,
                       bfloat16* out_ptr0,
                       bfloat16* out_ptr1,
                       bfloat16* out_ptr2)
{
    RECORD_FUNCTION("graph_0_cpp_fused__to_copy_index_1", c10::ArrayRef<c10::IValue>({}));
    #pragma omp parallel num_threads(112)
    {
        int tid = omp_get_thread_num();
        {
            #pragma omp for
            for(long x0=static_cast<long>(0L); x0<static_cast<long>(2L); x0+=static_cast<long>(1L))
            {
                for(long x1=static_cast<long>(0L); x1<static_cast<long>(58720256L); x1+=static_cast<long>(16L))
                {
                    auto tmp0 = in_ptr0[static_cast<long>(x0)];
                    auto tmp1 = decltype(tmp0)(tmp0 + 8);
                    auto tmp2 = tmp0 < 0;
                    auto tmp3 = tmp2 ? tmp1 : tmp0;
                    TORCH_CHECK((0 <= tmp3) & (tmp3 < 8L), "index out of bounds: 0 <= tmp3 < 8L")
                    auto tmp4 = at::vec::Vectorized<signed char>::loadu(in_ptr1 + static_cast<long>(x1 + (58720256L*tmp3)), 16);
                    auto tmp5 = at::vec::convert<bfloat16>(tmp4);
                    auto tmp6 = at::vec::Vectorized<signed char>::loadu(in_ptr2 + static_cast<long>(x1 + (58720256L*tmp3)), 16);
                    auto tmp7 = at::vec::convert<bfloat16>(tmp6);
                    auto tmp8 = at::vec::Vectorized<signed char>::loadu(in_ptr3 + static_cast<long>(x1 + (58720256L*tmp3)), 16);
                    auto tmp9 = at::vec::convert<bfloat16>(tmp8);
                    tmp5.store(out_ptr0 + static_cast<long>(x1 + (58720256L*x0)), 16);
                    tmp7.store(out_ptr1 + static_cast<long>(x1 + (58720256L*x0)), 16);
                    tmp9.store(out_ptr2 + static_cast<long>(x1 + (58720256L*x0)), 16);
                }
            }
        }
    }
}
''')
  2. Extend the condition of is_mkldnn_optimized_format in Mkldnn Matmul: accept stride[0]=0 (see the layout sketch below).
    Input example: size=[1, 4096, 2], stride=[0, 1, 4096].
    Improvement of aten::bmm: 60.056ms -> 3.738ms.
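
For reference, a small sketch of this second point (mine; the helper name is hypothetical, and the real is_mkldnn_optimized_format condition in the MKLDNN matmul code may differ) showing how a size=[1, 4096, 2], stride=[0, 1, 4096] input arises and a predicate that additionally accepts a broadcasted batch stride of 0:

import torch

# Build a tensor with the layout from the example above:
# size [1, 4096, 2], stride (0, 1, 4096) -- a broadcasted batch dim over a transposed 2D slice.
base = torch.randn(2, 4096, dtype=torch.bfloat16)
x = base.as_strided(size=(1, 4096, 2), stride=(0, 1, 4096))
print(x.size(), x.stride())  # torch.Size([1, 4096, 2]) (0, 1, 4096)

def is_bmm_friendly_layout(t: torch.Tensor) -> bool:
    # Hypothetical sketch of the relaxed check: accept plain or transposed
    # last-two-dim layouts, and treat a broadcasted batch stride of 0 like a
    # contiguous batch stride so no extra contiguous copy is needed.
    batch, rows, cols = t.size()
    s_b, s_r, s_c = t.stride()
    plain = s_r == cols and s_c == 1       # row-major [rows, cols] slice
    transposed = s_r == 1 and s_c == rows  # column-major (transposed) slice
    batch_ok = s_b in (rows * cols, 0)     # contiguous batch OR broadcasted batch
    return (plain or transposed) and batch_ok

print(is_bmm_friendly_layout(x))  # True once stride[0] == 0 is accepted

With such a relaxed condition, broadcasted-batch inputs like this can be handed to the MKLDNN matmul without the extra contiguous copy noted in the analysis above.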

Overall profiling

-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
       graph_0_cpp_fused__to_copy_index_1        89.42%      35.693ms        89.42%      35.693ms      35.693ms             1
                                aten::bmm         9.35%       3.731ms         9.36%       3.738ms     934.462us             4
                    Torch-Compiled Region         0.70%     277.698us        99.94%      39.890ms      39.890ms             1
       graph_0_cpp_fused_index_mul_silu_2         0.24%      96.978us         0.24%      96.978us      96.978us             1
                aten::_weight_int8pack_mm         0.09%      35.892us         0.10%      41.160us      41.160us             1
                 TorchDynamo Cache Lookup         0.06%      24.675us         0.06%      24.675us      24.675us             1
                               aten::topk         0.05%      20.133us         0.05%      20.133us      20.133us             1
            inductor::_reinterpret_tensor         0.03%      12.107us         0.03%      12.107us       1.009us            12
    graph_0_cpp_fused_div_index_mul_sum_3         0.02%       9.435us         0.02%       9.435us       9.435us             1
                         aten::as_strided         0.01%       5.363us         0.01%       5.363us       2.682us             2
                              aten::empty         0.01%       5.268us         0.01%       5.268us       5.268us             1
                       aten::resolve_conj         0.00%       1.367us         0.00%       1.367us       0.171us             8
             graph_0_cpp_fused__softmax_0         0.00%       1.224us         0.00%       1.224us       1.224us             1
-----------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 39.915ms

malfet added the oncall: pt2 and oncall: cpu inductor (CPU Inductor issues for Intel team to triage) labels on Apr 23, 2024
Valentine233 (Collaborator, Author) commented:

With the bmm fallback, the weight is converted from int8 to bf16 and oneDNN uses the bf16 weight. With bmm decomposition, the type conversion and the bmm are fused into one C++ kernel. The bmm fallback causes the regression because this case is memory bound with batch size 1.
Synced with Jiong: it is better to decompose bmm for the memory-bound case in lowering.
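
As a rough back-of-the-envelope check of the memory-bound claim (my own estimate, using the shapes from the repro, not numbers from the issue): with one token, each expert matmul performs about 2 * 14336 * 4096 FLOPs while reading a 14336 x 4096 weight slice, so the arithmetic intensity is around one FLOP per byte and the weight traffic, not the compute, sets the runtime:

# Arithmetic-intensity estimate for one expert's w1 matvec in the repro
# (T=1 token, dim=4096, intermediate_size=14336); illustration only.
dim, inter = 4096, 14336

flops = 2 * inter * dim               # one [inter, dim] @ [dim] matrix-vector product
bf16_weight_bytes = inter * dim * 2   # bf16 weight read by the fallback bmm (after a separate conversion pass)
int8_weight_bytes = inter * dim * 1   # int8 weight read by a kernel that converts in registers

print(f"FLOPs per expert matvec:         {flops / 1e6:.1f} MFLOP")
print(f"bf16 weight traffic (fallback):  {bf16_weight_bytes / 1e6:.1f} MB")
print(f"int8 weight traffic (fused):     {int8_weight_bytes / 1e6:.1f} MB")
print(f"FLOPs per byte on the bf16 path: {flops / bf16_weight_bytes:.2f}")

At roughly one FLOP per byte the kernel is bandwidth bound, so avoiding the separately materialized bf16 weight (as the decomposed, fused kernel does) is what recovers the performance.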

pytorchmergebot pushed a commit that referenced this issue Apr 27, 2024
The perf benefit was found in #124697 (comment).

The PR adds intrinsic specializations between int8/uint8 and bf16/fp16.

Pull Request resolved: #124828
Approved by: https://github.com/jgong5, https://github.com/jansel
pytorch-bot bot pushed a commit that referenced this issue May 3, 2024
Fixes #124697. Resolve the issue of large regression of GPT-FAST MOE with `coordinate_descent_tuning` disabled.

To get better perf for memory bound case, we decompose bmm in lowering.

Pull Request resolved: #124826
Approved by: https://github.com/jgong5, https://github.com/jansel
Valentine233 (Collaborator, Author) commented:

The fix in PR #124826 could harm the perf of LLAMA2. Hence, we need to further investigate other optimization methods for this issue.

Valentine233 reopened this on Jun 3, 2024