
Commit

Optimize chatglm
黄宇扬 committed Jul 10, 2024
1 parent 7ca2153 commit 55902fd
Showing 1 changed file with 20 additions and 89 deletions.
109 changes: 20 additions & 89 deletions src/models/chatglm.cpp
@@ -580,110 +580,41 @@ namespace fastllm {
             CatDirectBatch(keys, pointersK, 1);
             CatDirectBatch(values, pointersV, 1);
             if (all1 && batch > 1) {
+                contextLayer.ToDevice(q.dataDevice);
+                contextLayer.Resize({batch, 1, embed_dim});
+                contextLayer.Allocate();
                 for (int b = 0; b < batch; b++) {
                     qs[b] = (&curQs[b]);
                     keys[b] = (pastKeyValues[b * block_cnt + i].first);
                     values[b] = (pastKeyValues[b * block_cnt + i].second);
                     masks[b] = attentionMask[b];
+                    curContextLayer[b].FakeFrom(contextLayer, b * embed_dim * contextLayer.unitSize);
                     contexts[b] = (&curContextLayer[b]);

-                    outputSizes[b] = {1, qs[b]->dims[0], qs[b]->dims[1], keys[b]->dims[1]};
                 }
                 AttentionBatch(qs, keys, values, masks, contexts, qs[0]->dims[0] / values[0]->dims[0], 1.0 / scale_attn, 1);
             } else {
+                contextLayer.ToDevice(curQs[0].dataDevice);
+                contextLayer.Resize({total, 1, embed_dim});
+                contextLayer.Allocate();
+                int curLen = 0;
                 for (int b = 0; b < batch; b++) {
-                    auto &q = curQs[b];
-                    Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
-                    outputSizes[b] = {1, q.dims[0], q.dims[1], pastKey.dims[1]};
-                    q.Reshape({pastKey.dims[0], -1, q.dims[2]});
-                }
-
-                // 1.2 Attention
-                // 1.2.0 q * k^T
-                if (all1 && batch > 1) {
-                    for (int b = 0; b < batch; b++) {
-                        qs[b] = (&curQs[b]);
-                        keys[b] = (pastKeyValues[b * block_cnt + i].first);
-                        attns[b] = (&attnProbs[b]);
-                    }
-                    MatMulTransBBatch(qs, keys, attns, 1.0 / (scale_attn * (i + 1)));
-                } else {
-                    for (int b = 0; b < batch; b++) {
-                        auto &q = curQs[b];
-                        Data &pastKey = *pastKeyValues[b * block_cnt + i].first;
-                        MatMulTransB(q, pastKey, attnProbs[b], 1.0 / (scale_attn * (i + 1)));
-                    }
-                }
-
-                for (int b = 0; b < batch; b++) {
-                    attnProbs[b].Reshape(outputSizes[b]);
-                    // 1.2.1 Mask
-                    if (attentionMask[b] != nullptr) {
-                        AttentionMask(attnProbs[b], *attentionMask[b], -10000);
-                    }
-                }
-
-                // 1.2.2 softmax
-                for (int i = 0; i < attnProbs.size(); i++) {
-                    attns[i] = (&attnProbs[i]);
-                }
-                MulBatch(attns, i + 1, attns);
-                SoftmaxBatch(attns, attns, -1);
-
-                for (int b = 0; b < batch; b++) {
-                    Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
-                    outputSizes[b] = {1, num_attention_heads, -1, pastValue.dims[2]};
-                    attnProbs[b].Reshape({pastValue.dims[0], -1, attnProbs[b].dims[3]});
-                }
-
-                // 1.2.3 prob * v
-                if (all1 && batch > 1) {
-                    for (int b = 0; b < batch; b++) {
-                        attns[b] = (&attnProbs[b]);
-                        values[b] = (pastKeyValues[b * block_cnt + i].second);
-                        contexts[b] = (&curContextLayer[b]);
-                    }
-                    MatMulBatch(attns, values, contexts);
-                } else {
-                    for (int b = 0; b < batch; b++) {
-                        Data &pastValue = *pastKeyValues[b * block_cnt + i].second;
-                        MatMul(attnProbs[b], pastValue, curContextLayer[b]);
-                    }
-                }
-            }
-            if (all1) {
-                for (int b = 0; b < batch; b++) {
-                    curContextLayer[b].dims[0] = outputSizes[b][2];
-                    curContextLayer[b].dims[1] = outputSizes[b][0];
-                    curContextLayer[b].dims[2] = embed_dim;
-                    curContextLayer[b].strides[0] = curContextLayer[b].dims[1] * curContextLayer[b].dims[2];
-                    curContextLayer[b].strides[1] = curContextLayer[b].dims[2];
-                    curContextLayer[b].strides[2] = 1;
-                }
-            } else {
-                for (int b = 0; b < batch; b++) {
-                    curContextLayer[b].Reshape(outputSizes[b]);
-                    PermuteSelf(curContextLayer[b], {2, 0, 1, 3});
-                    curContextLayer[b].Reshape({curContextLayer[b].dims[0], curContextLayer[b].dims[1], embed_dim});
-                }
-            }
-
-            if (all1 && batch > 1) {
-                for (int b = 0; b < batch; b++) {
-                    contexts[b] = (&curContextLayer[b]);
-                }
-                CatBatch(contexts, 0, contextLayer);
-            } else {
-                for (int b = 0; b < batch; b++) {
-                    if (contextLayer.dims.size() == 0) {
-                        std::vector<int> dims = curContextLayer[b].dims;
-                        dims[0] = total;
-                        contextLayer.Expansion(dims);
+                    auto &q = curQs[b], &k = curKs[b], &v = curVs[b];
+                    Data &pastKey = *pastKeyValues[b * block_cnt + i].first, &pastValue = *pastKeyValues[b * block_cnt + i].second;
+                    curContextLayer[0].FakeFrom(contextLayer, curLen * embed_dim * contextLayer.unitSize);
+                    curLen += seqLens[b];
+
+                    // 1.2 Attention
+                    if (attentionMask[b] == nullptr) {
+                        Attention(q, pastKey, pastValue, Data(), curContextLayer[0], q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
+                    } else {
+                        Attention(q, pastKey, pastValue, *attentionMask[b], curContextLayer[0], q.dims[0] / pastKey.dims[0], 1.0 / scale_attn, 1);
                     }
-                    contextLayer.ToDevice(DataDevice::CUDA);
-                    CatDirect(contextLayer, curContextLayer[b], 0);
+                    PermuteSelf(curContextLayer[0], {1, 0, 2});
                 }
             }

             // 1.2.4 dense
             std::string denseWeightName = weightPre + std::to_string(i) + weightMiddle + ".dense.weight";
             std::string denseBiasName = weightPre + std::to_string(i) + weightMiddle + ".dense.bias";
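Note on the change: the deleted path materialized the attention scores explicitly — MatMulTransB / MatMulTransBBatch for q·K^T, AttentionMask, MulBatch plus SoftmaxBatch, then MatMul / MatMulBatch against the values — and needed the outputSizes bookkeeping, Reshape / PermuteSelf fixups, and a final CatBatch / CatDirect to stitch each element's result into contextLayer. The new path hands q, pastKey, pastValue, and the mask to the fused Attention / AttentionBatch ops, which write straight into FakeFrom views of a contextLayer buffer that is resized and allocated once up front, so the intermediate attnProbs tensors and the concatenation copies disappear. The softmax math appears unchanged: the old code scaled q·K^T by 1.0 / (scale_attn * (i + 1)) and multiplied by (i + 1) again in MulBatch before the softmax, which nets out to the 1.0 / scale_attn now passed to the fused ops.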
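To make the buffer-slicing half of the change concrete, here is a minimal, self-contained C++ sketch of the same pattern. It is not fastllm code: attentionOneQuery and every other name in it are invented for illustration, and plain float loops stand in for the library's fused kernels. It sizes one output buffer up front and lets each batch element write its attention result directly into its own slice, which is the role FakeFrom plays in the diff above.

// Standalone sketch (hypothetical names, not fastllm code): preallocate one
// output buffer, then let each batch element's attention write into its slice.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Single-query decode-step attention for one batch element:
// out = softmax(q.K^T * invScale) . V, with K and V laid out as [seqLen, dim].
static void attentionOneQuery(const float *q, const float *K, const float *V,
                              int seqLen, int dim, float invScale, float *out) {
    std::vector<float> score(seqLen);
    float maxScore = -1e30f;
    for (int t = 0; t < seqLen; t++) {
        float s = 0.0f;
        for (int d = 0; d < dim; d++) s += q[d] * K[t * dim + d];
        score[t] = s * invScale;
        maxScore = std::max(maxScore, score[t]);
    }
    float sum = 0.0f;
    for (int t = 0; t < seqLen; t++) {
        score[t] = std::exp(score[t] - maxScore);  // numerically stable softmax
        sum += score[t];
    }
    std::fill(out, out + dim, 0.0f);
    for (int t = 0; t < seqLen; t++) {
        float w = score[t] / sum;
        for (int d = 0; d < dim; d++) out[d] += w * V[t * dim + d];
    }
}

int main() {
    const int batch = 2, dim = 4;
    const int seqLens[] = {3, 5};  // per-element KV-cache lengths
    const float invScale = 1.0f / std::sqrt((float) dim);

    // Toy q/K/V per batch element; a real model reads K and V from the cache.
    std::vector<std::vector<float>> q(batch), K(batch), V(batch);
    for (int b = 0; b < batch; b++) {
        q[b].assign(dim, 0.1f * (b + 1));
        K[b].assign(seqLens[b] * dim, 0.05f * (b + 1));
        V[b].assign(seqLens[b] * dim, 1.0f * (b + 1));
    }

    // One output buffer sized up front -- the analogue of
    // contextLayer.Resize({batch, 1, embed_dim}); contextLayer.Allocate();
    std::vector<float> contextLayer(batch * dim);

    // Each element writes straight into its slice at offset b * dim, the way
    // FakeFrom aliases contextLayer at b * embed_dim * unitSize; nothing is
    // concatenated afterwards, which is why CatBatch/CatDirect could go.
    for (int b = 0; b < batch; b++) {
        attentionOneQuery(q[b].data(), K[b].data(), V[b].data(),
                          seqLens[b], dim, invScale,
                          contextLayer.data() + b * dim);
    }

    for (int b = 0; b < batch; b++) {
        std::printf("batch %d: context[0] = %.4f\n", b, contextLayer[b * dim]);
    }
    return 0;
}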
