[fix] fix data accuracy issue for gemma (#126)
Fix the precision issues reported in huggingface/transformers#29402.

This diff:
* fixes item 1 (approximate GELU) and item 3 (sqrt(hidden_dim) applied in the embedding dtype) from that issue
* fixes head_dim for the gemma_7b_* models
guocuimi committed Apr 15, 2024
1 parent 2f9e768 commit 75d58a6
Showing 20 changed files with 124 additions and 57 deletions.
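To see why head_dim has to come from the model config rather than hidden_size / n_heads: gemma-2b's published config (hidden_size 2048, 8 heads, head_dim 256) happens to satisfy the division, but gemma-7b's does not (hidden_size 3072, 16 heads, head_dim 256, while 3072 / 16 = 192), so deriving it silently shrinks every attention head. This is also why o_proj's input dimension changes from hidden_size to n_heads * head_dim in gemma.h below. For illustration only (not code from this diff), a minimal standalone sketch of the config-first fallback the diff wires up through LOAD_ARG_OR_FUNC:

#include <cstdint>
#include <iostream>
#include <optional>

// Hypothetical standalone version of the fallback: prefer head_dim from
// config.json, otherwise derive it from hidden_size / n_heads.
int64_t resolve_head_dim(std::optional<int64_t> head_dim_from_config,
                         int64_t hidden_size,
                         int64_t n_heads) {
  return head_dim_from_config.value_or(hidden_size / n_heads);
}

int main() {
  // gemma-2b: derived and configured values agree (2048 / 8 == 256).
  std::cout << resolve_head_dim(std::nullopt, 2048, 8) << "\n";   // 256
  // gemma-7b: config says 256, but deriving it would give 3072 / 16 == 192.
  std::cout << resolve_head_dim(256, 3072, 16) << "\n";           // 256
  std::cout << resolve_head_dim(std::nullopt, 3072, 16) << "\n";  // 192 (the old bug)
  return 0;
}

The same fallback is added to the other model definitions below so that head_dim is always taken from the parsed ModelArgs instead of being re-derived at each call site.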
2 changes: 1 addition & 1 deletion src/engine/llm_engine.cpp
@@ -142,7 +142,7 @@ bool LLMEngine::init_model(const std::string& model_weights_path) {
const int64_t n_heads = args_.n_heads();
const int64_t n_kv_heads = args_.n_kv_heads().value_or(n_heads);
n_local_kv_heads_ = n_kv_heads / world_size;
head_dim_ = args_.hidden_size() / n_heads;
head_dim_ = args_.head_dim();
dtype_ = parse_dtype(args_.dtype(), devices_[0]);

// key + value for all layers
2 changes: 1 addition & 1 deletion src/layers/attention/handler.cpp
@@ -61,7 +61,7 @@ std::unique_ptr<AttentionHandler> AttentionHandler::create_handler_with_rope(
const ModelArgs& args,
bool interleaved,
const torch::TensorOptions& options) {
const int64_t head_dim = args.hidden_size() / args.n_heads();
const int64_t head_dim = args.head_dim();
// default to use head_dim if rotary_dim is not specified
int64_t rotary_dim = args.rotary_dim() > 0 ? args.rotary_dim() : head_dim;
// apply rotary_dim percentage
4 changes: 2 additions & 2 deletions src/layers/normalization_test.cpp
@@ -120,8 +120,8 @@ TEST(NormalizationTest, RMSNormKernel) {

EXPECT_TRUE(torch::allclose(output,
output_ref,
/*rtol=*/1e-03,
/*atol=*/1e-05));
/*rtol=*/1e-02,
/*atol=*/1e-03));
}

TEST(NormalizationTest, RMSNormResidualKernel) {
6 changes: 5 additions & 1 deletion src/models/huggingface/aquila.h
@@ -87,7 +87,7 @@ class AquilaAttentionImpl : public torch::nn::Module {
const int64_t hidden_size = args.hidden_size();
const int64_t n_heads = args.n_heads();
const int64_t n_kv_heads = args.n_kv_heads().value_or(n_heads);
const int64_t head_dim = hidden_size / n_heads;
const int64_t head_dim = args.head_dim();
const int64_t n_local_heads = n_heads / world_size;
const int64_t n_local_kv_heads = n_kv_heads / world_size;

@@ -414,5 +414,9 @@ REGISTER_MODEL_ARGS(aquila, [&] {
LOAD_ARG_OR(bos_token_id, "bos_token_id", 1);
LOAD_ARG_OR(eos_token_id, "eos_token_id", 2);
LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f);

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});
});
} // namespace llm::hf
8 changes: 5 additions & 3 deletions src/models/huggingface/baichuan.h
@@ -99,10 +99,8 @@ class BaichuanAttentionImpl : public torch::nn::Module {
const int32_t world_size = parallel_args.world_size();
const int64_t hidden_size = args.hidden_size();
const int64_t n_heads = args.n_heads();
// const int64_t n_kv_heads = args.n_kv_heads().value_or(n_heads);
const int64_t head_dim = hidden_size / n_heads;
const int64_t head_dim = args.head_dim();
const int64_t n_local_heads = n_heads / world_size;
// const int64_t n_local_kv_heads = n_kv_heads / world_size;

// size for local q, k, v
qkv_sizes_ = {n_local_heads * head_dim,
@@ -517,6 +515,10 @@ REGISTER_MODEL_ARGS(baichuan, [&] {
LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f);
LOAD_ARG_OR(rope_scaling, "rope_scaling", 1.0f);
// LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false);

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});
});

} // namespace llm::hf
6 changes: 5 additions & 1 deletion src/models/huggingface/bloom.h
@@ -87,7 +87,7 @@ class BloomAttentionImpl : public torch::nn::Module {
const int64_t n_heads = args.n_heads();
const int64_t n_local_heads = n_heads / world_size;
hidden_size_ = args.hidden_size();
head_dim_ = hidden_size_ / n_heads;
head_dim_ = args.head_dim();

// register submodules
query_key_value_ =
@@ -451,6 +451,10 @@ REGISTER_MODEL_ARGS(bloom, [&] {
LOAD_ARG_OR_FUNC(intermediate_size, "intermediate_size", [&] {
return args->hidden_size() * 4;
});

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});
});

} // namespace llm::hf
7 changes: 6 additions & 1 deletion src/models/huggingface/chatglm.h
@@ -3,6 +3,7 @@
#include <torch/torch.h>
#include <torch/types.h>

#include "chat_template/coded_chat_template.h"
#include "layers/activation.h"
#include "layers/attention/attention.h"
#include "layers/attention/handler.h"
@@ -88,7 +89,7 @@ class ChatGLMAttentionImpl : public torch::nn::Module {
const int64_t hidden_size = args.hidden_size();
const int64_t n_heads = args.n_heads();
const int64_t n_kv_heads = args.n_kv_heads().value_or(n_heads);
const int64_t head_dim = hidden_size / n_heads;
const int64_t head_dim = args.head_dim();
const int64_t n_local_heads = n_heads / world_size;
const int64_t n_local_kv_heads = n_kv_heads / world_size;

@@ -533,6 +534,10 @@ REGISTER_MODEL_ARGS(chatglm, [&] {
return hidden_size / n_heads;
});

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});

// stop token ids: "</s>", "<|user|>", "<|assistant|>", "<|observation|>"
SET_ARG(stop_token_ids,
std::unordered_set<int32_t>({2, 64795, 64796, 64797}));
73 changes: 37 additions & 36 deletions src/models/huggingface/gemma.h
@@ -1,5 +1,5 @@
#pragma once

#include <glog/logging.h>
#include <torch/torch.h>

#include "chat_template/coded_chat_template.h"
@@ -17,21 +17,14 @@
// gemma model compatible with huggingface weight
namespace llm::hf {

// TODO only support the gemma-2B now
enum GemmaType {
gemma_2B,
gemma_2B_it,
gemma_7B,
gemma_7B_it,
};

class GemmaMLPImpl : public torch::nn::Module {
public:
GemmaMLPImpl(const ModelArgs& args,
const QuantArgs& quant_args,
const ParallelArgs& parallel_args,
const torch::TensorOptions& options) {
act_with_mul_ = Activation::get_act_with_mul_func("gelu", options.device());
act_with_mul_ =
Activation::get_act_with_mul_func(args.hidden_act(), options.device());
CHECK(act_with_mul_ != nullptr);

const int64_t hidden_size = args.hidden_size();
@@ -94,7 +87,7 @@ class GemmaAttentionImpl : public torch::nn::Module {
const int32_t world_size = parallel_args.world_size();
const int64_t hidden_size = args.hidden_size();
const int64_t n_heads = args.n_heads();
const int64_t head_dim = hidden_size / n_heads;
const int64_t head_dim = args.head_dim();
const int64_t n_kv_heads = args.n_kv_heads().value_or(n_heads);
const int64_t n_local_heads = n_heads / world_size;
const int64_t n_local_kv_heads = n_kv_heads / world_size;
@@ -116,7 +109,7 @@
options));

o_proj_ = register_module("o_proj",
RowParallelLinear(hidden_size,
RowParallelLinear(n_heads * head_dim,
hidden_size,
/*bias=*/false,
/*input_is_parallelized=*/true,
Expand All @@ -125,10 +118,10 @@ class GemmaAttentionImpl : public torch::nn::Module {
options));

// initialize attention
const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
atten_ = register_module(
"atten", Attention(n_local_heads, n_local_kv_heads, head_dim, handler));
}

torch::Tensor forward(torch::Tensor x,
torch::Tensor positions,
KVCache& kv_cache,
@@ -254,6 +247,14 @@ class GemmaModelImpl : public torch::nn::Module {
ParallelEmbedding(
args.vocab_size(), args.hidden_size(), parallel_args, options));

// normalize the embedding by sqrt(hidden_size)
// N.B. the data type of the normalizer should be the same as the embedding
// ref to:
// https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/gemma_causal_lm.py#L426
const float normalizer = std::sqrt(args.hidden_size());
normalizer_ =
register_buffer("normalizer", torch::tensor({normalizer}, options));

norm_ = register_module(
"norm",
RMSNormResidual(args.hidden_size(), args.rms_norm_eps(), options));
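The comment above pins down a subtle detail: the keras-nlp reference casts sqrt(hidden_size) to the embedding dtype before multiplying, so registering the normalizer as a buffer in the embedding dtype (rather than scaling by a float scalar in forward(), as the old code did) is what makes the logits match. For illustration only (not part of this diff), a minimal libtorch sketch of how much that cast changes the constant, assuming bfloat16 weights and the gemma hidden sizes 2048 and 3072:

#include <cmath>
#include <iostream>
#include <torch/torch.h>

int main() {
  // Hypothetical standalone check: compare the float32 normalizer with the
  // value actually used once it is stored in the embedding dtype (bfloat16).
  for (int64_t hidden_size : {2048, 3072}) {
    const float normalizer = std::sqrt(static_cast<float>(hidden_size));
    auto as_bf16 = torch::tensor({normalizer}, torch::kBFloat16);
    std::cout << "hidden_size=" << hidden_size
              << "  float32=" << normalizer
              << "  bfloat16=" << as_bf16.item<float>() << "\n";
  }
  return 0;
}

With these sizes the constant lands on the nearest bfloat16 values (roughly 45.25 and 55.5 instead of 45.2548 and 55.4256), and every token embedding is scaled by that slightly different number, which is enough to nudge the logits away from the reference.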
@@ -278,10 +279,7 @@
std::vector<KVCache>& kv_caches,
const InputParameters& input_params) {
// embedding tokens
auto h = embed_tokens_(tokens);

// normalize the embedding by sqrt(hidden_size)
h *= sqrt(modelArgs_.hidden_size());
auto h = embed_tokens_(tokens) * normalizer_;

torch::Tensor residual;
for (int32_t i = 0; i < modelArgs_.n_layers(); i++) {
@@ -322,6 +320,9 @@ class GemmaModelImpl : public torch::nn::Module {
// embedding module
ParallelEmbedding embed_tokens_{nullptr};

// embedding normalizer
torch::Tensor normalizer_{nullptr};

RMSNormResidual norm_{nullptr};
// attention handler
std::unique_ptr<AttentionHandler> handler_{nullptr};
@@ -359,8 +360,7 @@ class GemmaForCausalLMImpl : public torch::nn::Module {
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const InputParameters& input_params) {
auto h = model_(tokens, positions, kv_caches, input_params);
return h;
return model_(tokens, positions, kv_caches, input_params);
}

// hidden_states: [num_tokens, hidden_size]
@@ -374,16 +374,13 @@ class GemmaForCausalLMImpl : public torch::nn::Module {
h = h.index_select(/*dim=*/0, seleted_idxes);
}

h = lm_head_(h);

return h;
return lm_head_(h);
}

// load the weight from the checkpoint
void load_state_dict(const StateDict& state_dict) {
model_->load_state_dict(state_dict.select("model."));

// lm_head is not used in vllm as it is tied with embed_token.
// Share the embedding weights with the final lm_head layer.
lm_head_->load_state_dict(state_dict.select("model.embed_tokens."));
}
@@ -398,14 +395,13 @@ class GemmaForCausalLMImpl : public torch::nn::Module {
GemmaModel model_{nullptr};

ColumnParallelLinear lm_head_{nullptr};
int index;
};
TORCH_MODULE(GemmaForCausalLM);

class GemmaChatTemplate final : public CodedChatTemplate {
public:
std::optional<std::string> get_prompt(
const std::string_view& system_message,
const std::string_view& /*system_message*/,
const std::vector<std::string_view>& messages) const override {
// at least one user message
if (messages.size() % 2 == 0) {
@@ -418,11 +414,6 @@ class GemmaChatTemplate final : public CodedChatTemplate {
* <start_of_turn>model
*/
std::stringstream ss;
// start with system message
if (!system_message.empty()) {
ss << "<bos> <start_of_turn> model\n"
<< system_message << "<end_of_turn>";
}
// then user message
for (size_t i = 0; i < messages.size(); i++) {
ss << "\n<start_of_turn> user\n" << messages[i] << "<end_of_turn>";
@@ -439,14 +430,11 @@ REGISTER_CAUSAL_MODEL(gemma, GemmaForCausalLM);
REGISTER_DEFAULT_CHAT_TEMPLATE(gemma, GemmaChatTemplate);

REGISTER_MODEL_ARGS(gemma, [&] {
// example config from vllm project:
// transformers/models/gemma/configuration_gemma.py
// example config from
// https://huggingface.co/google/gemma-2b/blob/main/config.json
LOAD_ARG_OR(model_type, "model_type", "gemma");
// LOAD_ARG_OR(attention_dropout,"attention_dropout",0f);
// LOAD_ARG_OR(attention_bias,"attention_bias",false);
LOAD_ARG_OR(bos_token_id, "bos_token_id", 2);
LOAD_ARG_OR(eos_token_id, "eos_token_id", 1);
LOAD_ARG_OR(hidden_act, "hidden_act", "gelu");
LOAD_ARG_OR(hidden_size, "hidden_size", 2048);
LOAD_ARG_OR(intermediate_size, "intermediate_size", 16384);
LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 8192);
@@ -458,7 +446,20 @@ REGISTER_MODEL_ARGS(gemma, [&] {
LOAD_ARG_OR(dtype, "torch_dtype", "bfloat16");
LOAD_ARG_OR(vocab_size, "vocab_size", 256000);

// LOAD_ARG_OR(pad_token_id,"pad_token_id",0);
LOAD_ARG_OR_FUNC(hidden_act, "hidden_activation", [&] {
const auto hidden_act = json.value<std::string>("hidden_act");
if (hidden_act.has_value()) {
LOG(WARNING) << "Gemma's activation function was initially released with "
"an incorrect setting. Override the "
"activation function from '"
<< hidden_act.value() << "' to 'gelu_pytorch_tanh'";
}
return "gelu_pytorch_tanh";
});

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});
});

} // namespace llm::hf
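One more note on the hidden_act override above: "gelu_pytorch_tanh" is the tanh approximation of GELU, which the Gemma reference implementation uses, whereas plain "gelu" (as shipped in early configs) is the exact erf-based form. For illustration only (not this project's Activation kernels), a minimal sketch of the two formulas:

#include <cmath>
#include <iostream>

// Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
float gelu_exact(float x) {
  return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}

// Tanh approximation, i.e. "gelu_pytorch_tanh" / approximate GELU.
float gelu_tanh(float x) {
  const float k = std::sqrt(2.0f / 3.14159265358979f);
  return 0.5f * x * (1.0f + std::tanh(k * (x + 0.044715f * x * x * x)));
}

int main() {
  // The per-element gap is small, but it compounds over every MLP in the model.
  for (float x : {-3.0f, -1.0f, 0.5f, 2.0f}) {
    std::cout << "x=" << x << "  exact=" << gelu_exact(x)
              << "  tanh=" << gelu_tanh(x) << "\n";
  }
  return 0;
}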
6 changes: 5 additions & 1 deletion src/models/huggingface/gpt2.h
@@ -86,7 +86,7 @@ class GPT2AttentionImpl : public torch::nn::Module {
const auto world_size = parallel_args.world_size();
const int64_t n_local_heads = args.n_heads() / world_size;
hidden_size_ = args.hidden_size();
head_dim_ = args.hidden_size() / args.n_heads();
head_dim_ = args.head_dim();

// register submodules
c_attn_ = register_module("c_attn",
@@ -385,6 +385,10 @@ REGISTER_MODEL_ARGS(gpt2, [&] {

LOAD_ARG_OR_FUNC(
intermediate_size, "n_inner", [&] { return args->hidden_size() * 4; });

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});
});

} // namespace llm::hf
6 changes: 5 additions & 1 deletion src/models/huggingface/gpt_j.h
@@ -80,7 +80,7 @@ class GPTJAttentionImpl : public torch::nn::Module {
AttentionHandler* handler) {
const int64_t n_local_heads = args.n_heads() / parallel_args.world_size();
const int64_t hidden_size = args.hidden_size();
const int64_t head_dim = args.hidden_size() / args.n_heads();
const int64_t head_dim = args.head_dim();

// register submodules
qkv_proj_ = register_module("qkv_proj",
@@ -365,5 +365,9 @@ REGISTER_MODEL_ARGS(gptj, [&] {
// set it to 4 times n_embd
return args->hidden_size() * 4;
});

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});
});
} // namespace llm::hf
6 changes: 5 additions & 1 deletion src/models/huggingface/gpt_neox.h
@@ -85,7 +85,7 @@ class GPTNeoXAttentionImpl : public torch::nn::Module {
const auto world_size = parallel_args.world_size();
const int64_t n_local_heads = args.n_heads() / world_size;
hidden_size_ = args.hidden_size();
head_dim_ = args.hidden_size() / args.n_heads();
head_dim_ = args.head_dim();

// register submodules
query_key_value_ =
@@ -411,6 +411,10 @@ REGISTER_MODEL_ARGS(gpt_neox, [&] {
LOAD_ARG_OR(bos_token_id, "bos_token_id", 0);
LOAD_ARG_OR(eos_token_id, "eos_token_id", 2);
LOAD_ARG_OR(use_parallel_residual, "use_parallel_residual", true);

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});
});

} // namespace llm::hf
7 changes: 6 additions & 1 deletion src/models/huggingface/internlm.h
@@ -2,6 +2,7 @@

#include <torch/torch.h>

#include "chat_template/coded_chat_template.h"
#include "layers/activation.h"
#include "layers/attention/attention.h"
#include "layers/attention/handler.h"
@@ -85,7 +86,7 @@ class InternlmAttentionImpl : public torch::nn::Module {
const int32_t world_size = parallel_args.world_size();
const int64_t hidden_size = args.hidden_size();
const int64_t n_heads = args.n_heads();
const int64_t head_dim = hidden_size / n_heads;
const int64_t head_dim = args.head_dim();
const int64_t n_local_heads = n_heads / world_size;

// register submodules
@@ -401,6 +402,10 @@ REGISTER_MODEL_ARGS(internlm, [&] {
LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
LOAD_ARG_OR(rope_theta, "rope_theta", 10000.0f);

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});

// stop token ids: [1, 103028]
SET_ARG(stop_token_ids, std::unordered_set<int32_t>({1, 103028}));
});