Fix MiniCPM GPU initialization / low-memory mode errors
cgli authored and TylunasLi committed Mar 2, 2024
1 parent 4505bc9 commit cc9552f
Showing 3 changed files with 38 additions and 44 deletions.
9 changes: 9 additions & 0 deletions include/models/minicpm.h
@@ -15,6 +15,8 @@ namespace fastllm {
public:
MiniCpmModel(); // constructor

+ virtual void InitParams(); // initialize parameter info

// inference
virtual int Forward(
const Data &inputIds,
@@ -65,6 +67,13 @@ namespace fastllm {
virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // generate the prompt from the chat history and the current input

virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // update the history with the current reply

+ private:
+     float embed_scale = 1.f;
+
+     float attention_scale = 1.f / std::sqrt(block_cnt);
+
+     float rms_scale = 1.f / 4096.f;
};
}

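The three new members cache the scaling factors MiniCPM applies to the token embeddings, the residual branches, and the final hidden state, so the Forward variants no longer re-parse them from weight.dicts on every call. Below is a minimal standalone sketch of how the scales relate to the config keys that InitParams() reads; the numeric values are hypothetical placeholders, not MiniCPM's actual config:

// Sketch only: reproduces the scale arithmetic with placeholder config values.
#include <cmath>
#include <cstdio>

int main() {
    float scale_emb      = 12.0f;  // placeholder for config key "scale_emb"
    float scale_depth    = 1.4f;   // placeholder for config key "scale_depth"
    int   num_layers     = 40;     // placeholder for "num_hidden_layers" (block_cnt)
    int   hidden_size    = 2304;   // placeholder for "hidden_size" (embed_dim)
    int   dim_model_base = 256;    // placeholder for "dim_model_base"

    float embed_scale     = scale_emb;                                   // multiplies the embeddings
    float attention_scale = scale_depth / std::sqrt((float)num_layers);  // scales each residual branch
    float rms_scale       = 1.f / (hidden_size / (float)dim_model_base); // scales states before lm_head

    std::printf("embed_scale=%g attention_scale=%g rms_scale=%g\n",
                embed_scale, attention_scale, rms_scale);
    return 0;
}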
2 changes: 1 addition & 1 deletion src/model.cpp
@@ -108,7 +108,7 @@ namespace fastllm {
model = new LlamaModel();
model->model_type = "qwen";
} else if (modelType=="minicpm") {
- model = (basellm*)(new MiniCpmModel());
+ model = new MiniCpmModel();
} else if (modelType == "qwen") {
model = (basellm *) (new QWenModel());
model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
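This change only drops a redundant cast: a MiniCpmModel* converts to the factory's basellm* implicitly. A minimal illustration with hypothetical stand-in types:

// Sketch only: derived-to-base pointer conversion needs no explicit cast.
struct Base { virtual ~Base() = default; };
struct Derived : Base {};

int main() {
    Base *model = new Derived(); // implicit upcast; no (Base*) cast required
    delete model;                // virtual destructor makes this safe
    return 0;
}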
71 changes: 28 additions & 43 deletions src/models/minicpm.cpp
@@ -47,12 +47,6 @@ namespace fastllm {
MiniCpmModel::MiniCpmModel() {
this->model_type = "minicpm";

- // use the alpaca prompt and instruction by default
- /*
- this->pre_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n";
- this->user_role = "### Instruction:\n";
- this->bot_role = "\n\n### Response:";
- */
this->history_sep = "";
this->pre_prompt = "";
this->user_role = "";
@@ -87,6 +81,21 @@ namespace fastllm {
weight.embeddingNames.insert("model.embed_tokens.weight");
}

+ void MiniCpmModel::InitParams() {
+     basellm::InitParams();
+     if (this->weight.dicts.find("scale_emb") != this->weight.dicts.end()) {
+         this->embed_scale = std::stof(this->weight.dicts["scale_emb"]);
+     }
+     if (this->weight.dicts.find("scale_depth") != this->weight.dicts.end()) {
+         float scale_depth = std::stof(this->weight.dicts["scale_depth"]);
+         this->attention_scale = scale_depth / std::sqrt(block_cnt);
+     }
+     if (this->weight.dicts.find("dim_model_base") != this->weight.dicts.end()) {
+         int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]);
+         this->rms_scale = 1.f / (this->embed_dim / dim_model_base);
+     }
+ }

int MiniCpmModel::Forward(const fastllm::Data &inputIds, const fastllm::Data &attentionMask,
const fastllm::Data &positionIds, std::vector<std::pair<Data, Data>> &pastKeyValues,
const GenerationConfig &generationConfig, const LastTokensManager &lastTokens,
@@ -105,15 +114,8 @@ namespace fastllm {
Data attenLastOutput;
Data w1, w2, w3;

- float scale_emb = std::stof(this->weight.dicts["scale_emb"]);
- float scale_depth = std::stof(this->weight.dicts["scale_depth"]);
- int32_t num_hidden_layers = std::stoi(this->weight.dicts["num_hidden_layers"]);
- int32_t dim_model = std::stoi(this->weight.dicts["hidden_size"]);
- int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]);
- float rms_scale = 1.f / (dim_model / dim_model_base);

Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
- Mul(hiddenStates, scale_emb, hiddenStates);
+ Mul(hiddenStates, embed_scale, hiddenStates);
for (int i = 0; i < block_cnt; i++) {
ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".input_layernorm.weight"],
@@ -213,18 +215,16 @@ namespace fastllm {
attenOutput.Reshape({bsz, seqlen, -1});

Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);

- Mul(attenLastOutput, scale_depth / std::sqrt(num_hidden_layers), attenLastOutput);
+ Mul(attenLastOutput, this->attention_scale, attenLastOutput);
AddTo(hiddenStates, attenLastOutput);

// 2. mlp
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
Silu(w1, w1);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
- Mul(w2, scale_depth / std::sqrt(num_hidden_layers), w2);
+ Mul(w2, this->attention_scale, w2);
AddTo(hiddenStates, w2);
}
Data logits, topk;
@@ -241,8 +241,8 @@ namespace fastllm {
{
auto &hiddenStates = *lastHiddenStates;
RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates);
- Mul(hiddenStates, rms_scale, hiddenStates);
- Linear(hiddenStates, weight["model.embed_tokens.weight"], Data(), logits);
+ Mul(hiddenStates, this->rms_scale, hiddenStates);
+ Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
if (generationConfig.output_logits && retLogits != nullptr) {
int size = logits.dims.back();
logits.ToDevice(DataDevice::CPU);
@@ -278,16 +278,9 @@ namespace fastllm {
Data attenWeights, attenOutput;
Data attenLastOutput;
Data w1, w2, w3;

- float scale_emb = std::stof(this->weight.dicts["scale_emb"]);
- float scale_depth = std::stof(this->weight.dicts["scale_depth"]);
- int32_t num_hidden_layers = std::stoi(this->weight.dicts["num_hidden_layers"]);
- int32_t dim_model = std::stoi(this->weight.dicts["hidden_size"]);
- int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]);
- float rms_scale = 1.f / (dim_model / dim_model_base);

Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
- Mul(hiddenStates, scale_emb, hiddenStates);
+ Mul(hiddenStates, embed_scale, hiddenStates);
int seqlen = hiddenStates.dims[1];
for (int i = 0; i < block_cnt; i++) {
ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
@@ -391,7 +384,7 @@ namespace fastllm {
PermuteSelf(attenOutput, {1, 0, 2});

Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);
- Mul(attenLastOutput, scale_depth / std::sqrt(num_hidden_layers), attenLastOutput);
+ Mul(attenLastOutput, this->attention_scale, attenLastOutput);
AddTo(hiddenStates, attenLastOutput);
// 2. mlp
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
@@ -400,7 +393,7 @@ namespace fastllm {
Silu(w1, w1);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
- Mul(w2, scale_depth / std::sqrt(num_hidden_layers), w2);
+ Mul(w2, this->attention_scale, w2);
AddTo(hiddenStates, w2);
}

@@ -418,7 +411,7 @@ namespace fastllm {
{
auto &hiddenStates = *lastHiddenStates;
RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates);
- Mul(hiddenStates, rms_scale, hiddenStates);
+ Mul(hiddenStates, this->rms_scale, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
if (generationConfig.IsSimpleGreedy()) {
TopK(logits, topk, 1);
@@ -459,15 +452,8 @@ namespace fastllm {
Data attenLastOutput;
Data w1, w2, w3;

- float scale_emb = std::stof(this->weight.dicts["scale_emb"]);
- float scale_depth = std::stof(this->weight.dicts["scale_depth"]);
- int32_t num_hidden_layers = std::stoi(this->weight.dicts["num_hidden_layers"]);
- int32_t dim_model = std::stoi(this->weight.dicts["hidden_size"]);
- int32_t dim_model_base = std::stoi(this->weight.dicts["dim_model_base"]);
- float rms_scale = 1.f / (dim_model / dim_model_base);

Embedding(inputIds, this->weight["model.embed_tokens.weight"], hiddenStates);
- Mul(hiddenStates, scale_emb, hiddenStates);
+ Mul(hiddenStates, embed_scale, hiddenStates);
int seqlen = hiddenStates.dims[1];
for (int i = 0; i < block_cnt; i++) {
ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
@@ -594,23 +580,22 @@ namespace fastllm {
}

Linear(attenOutput, weight[oWeightName], Data(), attenLastOutput);
- Mul(attenLastOutput, scale_depth / std::sqrt(num_hidden_layers), attenLastOutput);
+ Mul(attenLastOutput, this->attention_scale, attenLastOutput);
AddTo(hiddenStates, attenLastOutput);

// 2. mlp
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-5, attenInput);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
Silu(w1, w1);
MulTo(w1, w3);
Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
- Mul(w2, scale_depth / std::sqrt(num_hidden_layers), w2);
+ Mul(w2, this->attention_scale, w2);
AddTo(hiddenStates, w2);
}

Data logits, curLogit;
RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-5, hiddenStates);
- Mul(hiddenStates, rms_scale, hiddenStates);
+ Mul(hiddenStates, this->rms_scale, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
std::vector <int> lastRet;
int total = 0;
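For context, attention_scale implements depth-scaled residual connections in all three Forward variants: each attention and MLP branch output is multiplied by scale_depth / sqrt(num_layers) before being added back, and the final hidden state is multiplied by rms_scale before the lm_head projection. Here is a toy sketch of that pattern with scalars standing in for tensors (all values hypothetical):

// Sketch only: the scaled-residual pattern from the diff, applied to scalars.
#include <cmath>
#include <cstdio>

int main() {
    const int   num_layers      = 4;         // placeholder for block_cnt
    const float scale_depth     = 1.4f;      // placeholder for "scale_depth"
    const float attention_scale = scale_depth / std::sqrt((float)num_layers);
    const float rms_scale       = 1.f / 9.f; // placeholder for 1 / (hidden_size / dim_model_base)

    float hidden = 1.0f;                     // stand-in for the hidden state
    for (int i = 0; i < num_layers; i++) {
        float attnOut = 0.5f * hidden;       // stand-in for the attention branch
        hidden += attention_scale * attnOut; // Mul(attenLastOutput, attention_scale, ...); AddTo(...)
        float mlpOut = 0.25f * hidden;       // stand-in for the MLP branch
        hidden += attention_scale * mlpOut;  // Mul(w2, attention_scale, ...); AddTo(...)
    }
    hidden *= rms_scale;                     // Mul(hiddenStates, rms_scale, ...) before lm_head
    std::printf("final hidden stand-in: %g\n", hidden);
    return 0;
}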
