Skip to content

Commit

Permalink
Revert "Avoid the transposed X branch in the Z = X * Y matrix multipl…
Browse files Browse the repository at this point in the history
…ication (ggerganov#439)"

This reverts commit 483bab2.
  • Loading branch information
sw committed Apr 4, 2023
1 parent 53dbba7 commit 1cecdec
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,13 +860,11 @@ static bool llama_eval_internal(

// V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
struct ggml_tensor * V_trans =
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3);

// KQV = transpose(V) * KQ_soft_max
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
Expand Down

0 comments on commit 1cecdec

Please sign in to comment.