
Commit

Optimize llama
ztxz16 committed Jul 6, 2024
1 parent 1376d59 commit 860bace
Showing 1 changed file with 18 additions and 18 deletions.
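
The change writes the attention output and the MLP intermediates into buffers that are already live in the layer loop (q, k, v, qkv, attenInput) instead of the dedicated attenOutput, attenLastOutput, w1, w2 and w3 temporaries, so those extra per-layer buffers are no longer needed. The sketch below illustrates the buffer-reuse pattern with a simplified Tensor type and a hypothetical SomeOp; it is a minimal stand-alone illustration, not the fastllm Data API.

// A minimal sketch of the buffer-reuse pattern, assuming a simplified Tensor
// type and a hypothetical SomeOp; this is NOT the fastllm Data API.
#include <cstdio>
#include <vector>

struct Tensor {
    std::vector<float> data;  // backing storage, kept across Resize calls
    std::vector<int>   dims;

    void Resize(const std::vector<int> &newDims) {
        size_t n = 1;
        for (int d : newDims) {
            n *= (size_t) d;
        }
        if (n > data.size()) {
            data.resize(n);   // grow only; nothing is freed per call
        }
        dims = newDims;
    }
};

// Hypothetical elementwise op that writes its result into `out`.
void SomeOp(const Tensor &in, Tensor &out) {
    out.Resize(in.dims);
    for (size_t j = 0; j < in.data.size(); j++) {
        out.data[j] = in.data[j] * 0.5f + 1.0f;
    }
}

int main() {
    Tensor hidden;
    hidden.Resize({4, 4096});
    for (float &x : hidden.data) {
        x = 1.0f;
    }

    // q and k stay live for the whole layer loop, so the storage they acquire
    // in layer 0 is recycled by every later layer; this mirrors writing the
    // attention output into qkv / attenInput instead of a fresh temporary.
    Tensor q, k;
    const float *qStorage = nullptr;

    for (int layer = 0; layer < 32; layer++) {
        SomeOp(hidden, q);   // overwrite q: its previous contents are dead here
        SomeOp(q, k);        // reuse k the same way
        SomeOp(k, hidden);   // write the layer result back into hidden
        if (layer == 0) {
            qStorage = q.data.data();
        }
    }

    // The backing pointer is unchanged after layer 0: no per-layer reallocation.
    std::printf("q storage stable across layers: %s\n",
                qStorage == q.data.data() ? "yes" : "no");
    return 0;
}

Because Resize only grows the backing storage, overwriting a long-lived tensor recycles the allocation made in the first layer instead of paying for a new one on every iteration.
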
36 changes: 18 additions & 18 deletions src/models/llama.cpp
@@ -376,7 +376,7 @@ namespace fastllm {
            // 1.2 Attention
            // 1.2.0 q * k^T
            if (alibiData.dims.size() == 0) {
-                Attention(q, pastKey, pastValue, attentionMask, attenOutput, q.dims[0] / pastKey.dims[0], 1.0 / sqrt(head_dim), 1);
+                Attention(q, pastKey, pastValue, attentionMask, qkv, q.dims[0] / pastKey.dims[0], 1.0 / sqrt(head_dim), 1);
            } else {
                MatMulTransB(q, pastKey, attenWeights, 1.0 / sqrt(head_dim), q.dims[0] / pastKey.dims[0]);
                attenWeights.Reshape({1, attenWeights.dims[0], attenWeights.dims[1], attenWeights.dims[2]});
@@ -389,39 +389,39 @@ namespace fastllm {
                }

                Softmax(attenWeights, attenWeights, -1);
-                MatMul(attenWeights, pastValue, attenOutput, 1.f, attenWeights.dims[1] / pastValue.dims[0]);
-                attenOutput.Reshape({attenOutput.dims[1], attenOutput.dims[2], attenOutput.dims[3]});
+                MatMul(attenWeights, pastValue, qkv, 1.f, attenWeights.dims[1] / pastValue.dims[0]);
+                qkv.Reshape({qkv.dims[1], qkv.dims[2], qkv.dims[3]});
            }

-            PermuteSelf(attenOutput, {1, 0, 2});
-            attenOutput.Reshape({seqlen, bsz, -1});
-            PermuteSelf(attenOutput, {1, 0, 2});
+            PermuteSelf(qkv, {1, 0, 2});
+            qkv.Reshape({seqlen, bsz, -1});
+            PermuteSelf(qkv, {1, 0, 2});

            Data oBias = (weight.weight.find(oBiasName) != weight.weight.end()) ? weight[oBiasName] : Data();
-            Linear(attenOutput, weight[oWeightName], oBias, attenLastOutput);
-            AddTo(hiddenStates, attenLastOutput);
+            Linear(qkv, weight[oWeightName], oBias, attenInput);
+            AddTo(hiddenStates, attenInput);
            // 2. mlp
            RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], rms_norm_eps, attenInput);
            if (this->mergeSwiglu) {
                std::string swigluWeightName = "model.layers." + std::to_string(i) + ".mlp.gateup_proj.weight";
                if (CanRunLinearEx(LinearExType::ExSwiglu)) {
-                    LinearEx(attenInput, weight[swigluWeightName], Data(), w1, LinearExType::ExSwiglu);
+                    LinearEx(attenInput, weight[swigluWeightName], Data(), q, LinearExType::ExSwiglu);
                } else {
-                    Linear(attenInput, weight[swigluWeightName], Data(), w3);
-                    Swiglu(w3, w1);
+                    Linear(attenInput, weight[swigluWeightName], Data(), v);
+                    Swiglu(v, q);
                }
            } else {
                if (CanRunLinearEx(LinearExType::ExSilu)) {
-                    LinearEx(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1, LinearExType::ExSilu);
+                    LinearEx(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), q, LinearExType::ExSilu);
                } else {
-                    Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
-                    Silu(w1, w1);
+                    Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), q);
+                    Silu(q, q);
                }
-                Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
-                MulTo(w1, w3);
+                Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), v);
+                MulTo(v, q);
            }
-            Linear(w1, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), w2);
-            AddTo(hiddenStates, w2);
+            Linear(q, weight["model.layers." + std::to_string(i) + ".mlp.down_proj.weight"], Data(), k);
+            AddTo(hiddenStates, k);
        }

        Data logits, topk;
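For reference, both MLP branches in this hunk compute the usual LLaMA feed-forward, down_proj(SiLU(gate_proj(x)) * up_proj(x)), with gate_proj and up_proj fused into a single gateup_proj matrix when mergeSwiglu is set; in the lines shown, the change only swaps which pre-existing buffers hold the intermediate results.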
