Merge pull request PaddlePaddle#6 from tianyan01/v2.4.2
add fmt support for int8 and fix some int8 bugs
laipaang committed Nov 28, 2023
2 parents 456ab5b + a96bfb1 commit 142bef2
Showing 21 changed files with 2,222 additions and 452 deletions.
@@ -837,6 +837,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel<T> {
ffn1_dropout_mask_data,
ffn1_in_scale[i],
ffn1_out_scales[i]->data<float>(),
0,
ffn2_in_scale[i],
quant_round_type,
quant_max_bound,
@@ -19,6 +19,7 @@ namespace paddle {
namespace operators {

using Tensor = phi::DenseTensor;
// #define _DEBUG_FUSED_MULTI_TRANSFORMER

template <typename T>
static void PrintMatrix(const T* mat_d, int num, std::string name) {
@@ -72,7 +73,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
auto quant_max_bound = ctx.Attr<float>("quant_max_bound");
auto quant_min_bound = ctx.Attr<float>("quant_min_bound");

// dequant output scales, tensor, size = [num_layers, n], n is gemm output
// dequant output scales, vector<tensor>, size = [num_layers, n], n is gemm output
// size
auto qkv_out_scales = ctx.MultiInput<Tensor>("QKVOutScale");
auto out_linear_out_scales =
@@ -164,7 +165,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
qktv_out.Resize({{bsz, num_head, seq_len, dim_head}});
auto *qktv_out_data =
dev_ctx.Alloc<T>(&qktv_out, qktv_out.numel() * sizeof(T));
fmha_out.Resize({{bsz_seq, num_head, dim_head}});
fmha_out.Resize({{bsz, seq_len, num_head, dim_head}});
auto *fmha_out_data =
dev_ctx.Alloc<T>(&fmha_out, fmha_out.numel() * sizeof(T));

@@ -231,7 +232,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
topk_value.Resize({{sliced_bsz_seq, topk}});
dev_ctx.Alloc<T>(&topk_value, topk_value.numel() * sizeof(T));
topk_idx.Resize({{sliced_bsz_seq, topk}});
dev_ctx.Alloc<T>(&topk_idx, topk_idx.numel() * sizeof(T));
dev_ctx.Alloc<int64_t>(&topk_idx, topk_idx.numel() * sizeof(int64_t));
// local expert count, global expert count
Tensor local_expert_count, global_expert_count;
local_expert_count.Resize({{tot_expert}});
@@ -424,7 +425,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step3.2 out linear";
#endif
// T -> int8
// T -> int32
out_linear_compute.ComputeForwardTToINT8(out_linear_weights[i],
out_linear_in_scale[i],
&fmha_out,
@@ -444,7 +445,6 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step4";
#endif

// step5. ln(residual + dropout(input + bias))
auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
@@ -455,7 +455,6 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
// changed: skip the output scaling here for now; the output is fp16 and is written to buf0
AffineQuantStore<T, LayerNormComputeType, T, false, true> store(buf0.data<T>(), dim_embed, ln_scale_data, ln_bias_data);
DispatchLayerNorm<decltype(load), decltype(store), LayerNormComputeType>(dev_ctx.stream(), load, store, bsz_seq, dim_embed, epsilon, ln_mean_data, ln_var_data);

#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
VLOG(0) << "step5";
#endif
@@ -564,9 +563,6 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
all_expert_out.Resize({{fwd_bsz, dim_embed}});
dev_ctx.Alloc<T>(&all_expert_out, all_expert_out.numel() * sizeof(T));

// global_scatter_out.Resize({{fwd_bsz, dim_embed}});
// all_expert_out.Resize({{fwd_bsz, dim_embed}});

// step 5, MOEScatter
// step 5.1, index select
// suppose tmp_pos->shape != [0]
@@ -614,19 +610,16 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
int end = cur_expert_count + last_index;

Tensor expert_in_tmp; // int8_t
expert_in_tmp.Resize({{(cur_expert_count * dim_feedforward + 31) / 32 * 32 }});
expert_in_tmp.Resize({{cur_expert_count, dim_feedforward}});
dev_ctx.Alloc<int8_t>(&expert_in_tmp, expert_in_tmp.numel() * sizeof(int8_t));

Tensor expert_out1; // int32_t
expert_out1.Resize({{(cur_expert_count * dim_feedforward + 31) / 32 * 32}});
expert_out1.Resize({{cur_expert_count, dim_feedforward}});
dev_ctx.Alloc<int32_t>(&expert_out1, expert_out1.numel() * sizeof(int32_t));

Tensor expert_out2; // T(fp16)
expert_out2.Resize({{cur_expert_count, dim_embed}});
dev_ctx.Alloc<T>(&expert_out2, expert_out2.numel() * sizeof(T));
// act_bias_out.Resize({{cur_expert_count, dim_feedforward}}); maybe int8_t?
// maybe use input_workspace and output workspace?
// dev_ctx.Alloc<T>(&act_bias_out, act_bias_out.numel() * sizeof(T));

// input is int32_t, output is int8_t
FusedDropoutHelper<T, uint8_t, int32_t, int8_t> fused_act_dropout_helper(
@@ -654,7 +647,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
expert_out1.data<int32_t>(),
expert_biases1[expert_idx]->data<T>(),
"gelu",
expert_in_tmp.data<int8_t>(),
expert_in_tmp.data<int8_t>(), // output
nullptr,
expert_weight1_in_scale[expert_idx],
expert_weight1_out_scales[expert_idx]->data<float>(),
@@ -668,7 +661,7 @@ class FusedMultiTransformerMoeINT8OpKernel : public framework::OpKernel<T> {
MatMulINT8ToT<T>(dev_ctx,
expert_weights2[expert_idx],
expert_weight2_in_scale[expert_idx],
&expert_in_tmp,
&expert_in_tmp, // input
expert_biases2[expert_idx],
&expert_out2,
&expert_out1, // output_tmp
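Note: as a reading aid for the per-expert INT8 path above (MatMulTToINT8 → fused bias/GELU/quant → MatMulINT8ToT), here is a minimal NumPy sketch of the same dataflow. The scale conventions, the tanh-approximated GELU, and all variable names are illustrative assumptions, not the Paddle kernels themselves.

# Minimal NumPy reference of the expert FFN dataflow in this hunk:
# quantize -> int8 GEMM (int32 accumulate) -> per-channel dequant + bias + GELU
# -> re-quantize -> int8 GEMM -> per-channel dequant + bias.
import numpy as np

def quantize(x, in_scale, qmax=127.0, qmin=-127.0):
    # float -> int8 with round-to-nearest and clipping (assumed convention)
    return np.clip(np.round(x * in_scale), qmin, qmax).astype(np.int8)

def dequantize(acc_i32, out_scales):
    # int32 GEMM accumulator -> float via per-output-channel scales
    return acc_i32.astype(np.float32) * out_scales

def gelu(x):
    # tanh approximation, used here only for illustration
    return 0.5 * x * (1.0 + np.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

rng = np.random.default_rng(0)
tokens, dim_embed, dim_feedforward = 4, 8, 16
x = rng.standard_normal((tokens, dim_embed)).astype(np.float32)             # expert input
w1_q = rng.integers(-127, 128, (dim_embed, dim_feedforward), dtype=np.int8)  # pre-quantized weights
w2_q = rng.integers(-127, 128, (dim_feedforward, dim_embed), dtype=np.int8)
b1 = rng.standard_normal(dim_feedforward).astype(np.float32)
b2 = rng.standard_normal(dim_embed).astype(np.float32)
w1_in_scale = w2_in_scale = 127.0 / 4.0                                      # assumed activation scales
w1_out_scales = rng.uniform(1e-4, 1e-3, dim_feedforward).astype(np.float32)
w2_out_scales = rng.uniform(1e-4, 1e-3, dim_embed).astype(np.float32)

x_q = quantize(x, w1_in_scale)                         # cf. expert_in_tmp (int8)
acc1 = x_q.astype(np.int32) @ w1_q.astype(np.int32)    # cf. expert_out1 (int32)
h = gelu(dequantize(acc1, w1_out_scales) + b1)         # dequant + bias + activation
h_q = quantize(h, w2_in_scale)                         # back to int8 for the second GEMM
acc2 = h_q.astype(np.int32) @ w2_q.astype(np.int32)
out = dequantize(acc2, w2_out_scales) + b2             # cf. expert_out2 (fp16 in the kernel)
print(out.shape)                                       # (tokens, dim_embed)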
@@ -208,7 +208,7 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel<T> {
topk_value.Resize({{sliced_bsz_seq, topk}});
dev_ctx.Alloc<T>(&topk_value, topk_value.numel() * sizeof(T));
topk_idx.Resize({{sliced_bsz_seq, topk}});
dev_ctx.Alloc<T>(&topk_idx, topk_idx.numel() * sizeof(T));
dev_ctx.Alloc<int64_t>(&topk_idx, topk_idx.numel() * sizeof(int64_t));
// local expert count, global expert count
Tensor local_expert_count, global_expert_count;
local_expert_count.Resize({{tot_expert}});
@@ -642,7 +642,6 @@ class FusedMultiTransformerMoeOpKernel : public framework::OpKernel<T> {

Tensor tmp_inp = global_scatter_out.Slice(last_index, end);
int expert_idx = i * num_expert + idx;

// linear1 matmul
// VLOG(0) << "moe, Expert Computation, linear1 mul";
phi::MatMulAndAdd<T>(dev_ctx,
10 changes: 5 additions & 5 deletions paddle/fluid/operators/fused/fused_multi_transformer_moe_op.h
@@ -216,11 +216,11 @@ void MatMulTToINT8(const phi::GPUContext& dev_ctx,
dev_ctx.stream());

helper->GEMM(input_tmp->data<int8_t>(),
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx.stream(),
(void*)workspace->data<int8_t>(),
workspace->numel());
weight->data<int8_t>(),
output->data<int32_t>(),
dev_ctx.stream(),
(void*)workspace->data<int8_t>(),
workspace->numel());
}

template <typename T>
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/fused_residual_dropout_bias.h
@@ -107,8 +107,8 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
for (int ii = 0; ii < VecSize; ii++) {
T tmp;
if (std::is_same<InType, int32_t>::value) {
T tmp0 = static_cast<T>(static_cast<float>(src_vec[ii]) *
quant_last_in_scale / quant_out_scale_vec[ii]);
T tmp0 = static_cast<T>(static_cast<float>(src_vec[ii]) *
quant_out_scale_vec[ii]);
tmp = tmp0 + bias_vec[ii];
} else {
tmp = static_cast<T>(src_vec[ii]) + bias_vec[ii];
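Note: the hunk above drops the division by a separate input scale, which appears to assume that quant_out_scale_vec now carries the full per-channel dequant factor on its own. A minimal NumPy sketch of the corrected step, with illustrative names; dropout and the residual add that follow in the real kernel are omitted.

# Sketch of the corrected int32 -> T dequant step, assuming quant_out_scale_vec
# already folds every needed factor into one per-channel multiplier.
import numpy as np

def dequant_add_bias(src_i32, quant_out_scale, bias):
    tmp0 = src_i32.astype(np.float32) * quant_out_scale   # new single-multiply form
    return tmp0 + bias

src = np.array([[120, -30], [15, 7]], dtype=np.int32)      # int32 GEMM output
scale = np.array([2e-3, 5e-3], dtype=np.float32)           # per-column dequant scale
bias = np.array([0.1, -0.2], dtype=np.float32)
print(dequant_add_bias(src, scale, bias))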
8 changes: 5 additions & 3 deletions paddle/fluid/operators/fused/layernorm_quant_dequant.h
@@ -1044,7 +1044,9 @@ struct DequantSkipLoadAndStoreResidual {
src_pack.storage = *(reinterpret_cast<const PackType<InputType, N>*>(src) + offset);
bias_pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(bias) + bias_offset);
skip_pack.storage = *(reinterpret_cast<const PackType<SRC, N>*>(skip) + offset);
dequant_scale_pack.storage = *(reinterpret_cast<const PackType<float, N>*>(dequant_scale) + bias_offset); // equal to col.
if (do_dequant) {
dequant_scale_pack.storage = *(reinterpret_cast<const PackType<float, N>*>(dequant_scale) + bias_offset); // equal to col.
}
#pragma unroll
for (int i = 0; i < N; ++i) {
// First we need to cast src and dequant.
@@ -1053,8 +1055,8 @@ struct DequantSkipLoadAndStoreResidual {
+ bias_pack.elem[i]
+ skip_pack.elem[i]);
} else {
residual_out_pack.elem[i] = static_cast<DST>(static_cast<DST>(src_pack.elem[i]) + bias_pack.elem[i]
+ skip_pack.elem[i]);
// trick for smoe: don't add the bias.
residual_out_pack.elem[i] = static_cast<DST>(static_cast<DST>(src_pack.elem[i]) + skip_pack.elem[i]);
}
}
#pragma unroll
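Note: a NumPy reference of what DequantSkipLoadAndStoreResidual computes after this change, as read from the hunk: the dequant scale is only fetched when do_dequant is set, and the non-dequant branch adds only the skip connection (the bias is intentionally left out, per the "trick for smoe" comment). Names mirror the struct's fields, but the code below is an illustration, not the CUDA functor.

# NumPy reference of the residual-out computation after this change
# (vectorized packs and the store side are omitted).
import numpy as np

def residual_out(src, bias, skip, dequant_scale=None, do_dequant=False):
    if do_dequant:
        # int32 GEMM output: per-column dequant, then add bias and skip.
        return src.astype(np.float32) * dequant_scale + bias + skip
    # already-floating input: add the skip connection only, no bias.
    return src.astype(np.float32) + skip

src_fp = np.array([[1.0, 2.0]], dtype=np.float32)
skip = np.array([[0.5, -0.5]], dtype=np.float32)
bias = np.array([0.1, 0.1], dtype=np.float32)
print(residual_out(src_fp, bias, skip))                          # bias not added in this branch

src_i32 = np.array([[200, -40]], dtype=np.int32)
scale = np.array([1e-3, 2e-3], dtype=np.float32)
print(residual_out(src_i32, bias, skip, scale, do_dequant=True))  # dequant + bias + skip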
5 changes: 5 additions & 0 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -183,6 +183,8 @@ def pure_fp16_initialize(models):
for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True):
layer._casted_by_pure_fp16 = True
if isinstance(layer, paddle.incubate.nn.FusedMultiTransformerMoeINT8):
continue
if (layer._dtype == 'float16') or isinstance(
layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D,
paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D,
@@ -196,6 +198,9 @@ def pure_fp16_initialize(models):
paddle.incubate.nn.FusedMoELayer)):
layer._amp_decorate(dtype='float16')
continue
# if isinstance(layer, paddle.incubate.nn.FusedMultiTransformerMoeINT8):
# layer._amp_decorate(dtype='int8')
# continue
layer._to_impl(dtype='float16',
include_sublayers=False,
floating_only=True)
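Note: the effect of the new isinstance check is simply that FusedMultiTransformerMoeINT8 layers are skipped by pure_fp16_initialize instead of being cast to float16. A toy sketch of that skip pattern; the classes here are stand-ins, not Paddle layers.

# Toy sketch of the skip added to pure_fp16_initialize: layers of a chosen
# type are left untouched while everything else is cast to float16.
class Int8MoeLayer:            # stands in for FusedMultiTransformerMoeINT8
    dtype = 'int8'

class OrdinaryLayer:
    dtype = 'float32'

def pure_fp16_initialize(layers):
    for layer in layers:
        if isinstance(layer, Int8MoeLayer):
            continue           # keep quantized weights and scales as they are
        layer.dtype = 'float16'
    return layers

layers = pure_fp16_initialize([OrdinaryLayer(), Int8MoeLayer(), OrdinaryLayer()])
print([layer.dtype for layer in layers])   # ['float16', 'int8', 'float16']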
23 changes: 23 additions & 0 deletions python/paddle/fluid/framework.py
@@ -1281,6 +1281,29 @@ def _debug_string_(proto, throw_on_error=True):
return proto.__str__()


def _create_tensor(
type=core.VarDesc.VarType.LOD_TENSOR,
name=None,
shape=None,
dtype=None,
persistable=None,
**kwargs,
):
if dtype is not None:
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)

eager_tensor = core.eager.Tensor(
dtype if dtype else core.VarDesc.VarType.FP32,
list(shape) if shape else [],
name,
type if type else core.VarDesc.VarType.LOD_TENSOR,
True if persistable else False,
)
eager_tensor.retain_grads()
return eager_tensor


def _varbase_creator(
type=core.VarDesc.VarType.LOD_TENSOR,
name=None,
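Note: a hedged usage sketch of the new _create_tensor helper. The import path follows the file being patched; the shape, dtype, and name are arbitrary examples, and an eager-mode Paddle build that includes this patch is assumed.

# Hypothetical usage of the new private helper (arguments are arbitrary examples).
import numpy as np
from paddle.fluid.framework import _create_tensor

t = _create_tensor(shape=[2, 3], dtype=np.float16, name='tmp_buf', persistable=False)
# The helper fills in LOD_TENSOR as the default type and calls retain_grads()
# on the freshly created eager Tensor before returning it.
print(t.name, t.shape)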
2 changes: 2 additions & 0 deletions python/paddle/incubate/nn/__init__.py
@@ -20,13 +20,15 @@
from .layer.fused_transformer import FusedBiasDropoutResidualLayerNorm # noqa: F401
from .layer.fused_transformer import FusedMoELayer # noqa: F401
from .layer.fused_transformer import FusedMultiTransformerMoe # noqa: F401
from .layer.fused_transformer import FusedMultiTransformerMoeINT8 # noqa: F401

__all__ = [ #noqa
'FusedMultiHeadAttention',
'FusedFeedForward',
'FusedTransformerEncoderLayer',
'FusedMultiTransformer',
'FusedMultiTransformerMoe',
'FusedMultiTransformerMoeINT8',
'FusedLinear',
'FusedBiasDropoutResidualLayerNorm',
'FusedMoELayer',