Commit e2aa556

committed Jan 7, 2023
whisper : experiments with Flash Attention in the decoder
1 parent f30b5d3

File tree

1 file changed (+57, -10 lines)

1 file changed

+57
-10
lines changed
 

whisper.cpp (+57, -10)

@@ -1457,7 +1457,7 @@ static bool whisper_encode(
                     layer.cross_attn_k_w,
                     cur);
 
-            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+            //Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
 
             struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                     layer.cross_attn_v_w,
@@ -1579,14 +1579,14 @@ static bool whisper_decode(
                         Qcur),
                     Qcur);
 
-            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
                     layer.attn_k_w,
                     cur);
 
-            Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
             struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
                     layer.attn_v_w,
@@ -1609,6 +1609,33 @@ static bool whisper_decode(
 
             // ------
 
+#ifdef USE_FLASH_ATTN
+            struct ggml_tensor * Q =
+                ggml_permute(ctxL,
+                        ggml_cpy(ctxL,
+                            Qcur,
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K =
+                ggml_permute(ctxL,
+                        ggml_reshape_3d(ctxL,
+                            ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+                            n_state/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_cpy(ctxL,
+                        ggml_permute(ctxL,
+                            ggml_reshape_3d(ctxL,
+                                ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+                                n_state/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
+                );
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
+#else
             struct ggml_tensor * Q =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
@@ -1626,13 +1653,13 @@ static bool whisper_decode(
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);
 
-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctxL,
-            //            KQ,
-            //            ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctxL,
+                        KQ,
+                        ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+                        );
 
-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);
 
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);
 
@@ -1644,6 +1671,7 @@ static bool whisper_decode(
                     1, 2, 0, 3);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+#endif
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
 
@@ -1689,7 +1717,7 @@ static bool whisper_decode(
                         Qcur),
                     Qcur);
 
-            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
 
             // Kcross is already scaled
             struct ggml_tensor * Kcross =
@@ -1704,6 +1732,24 @@ static bool whisper_decode(
 
             // ------
 
+#ifdef USE_FLASH_ATTN
+            struct ggml_tensor * Q =
+                ggml_permute(ctxL,
+                        ggml_cpy(ctxL,
+                            Qcur,
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_cpy(ctxL,
+                        ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
+                );
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
+#else
             struct ggml_tensor * Q =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
@@ -1730,6 +1776,7 @@ static bool whisper_decode(
             struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);
 
             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+#endif
 
             struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);
 
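Note on the change: the decoder gets an alternative attention path guarded by #ifdef USE_FLASH_ATTN, which builds Q, K and V tensors and evaluates them with a single ggml_flash_attn call (masked = true for self-attention over the n_past + N cached keys/values, masked = false for cross-attention against the encoder's Kcross/Vcross). The commit does not show where the macro is defined, so presumably it is toggled at build time. Alongside this, the pow(float(n_state)/n_head, -0.25) pre-scaling of Kcross, Qcur and Kcur is commented out, and the fallback path re-applies the equivalent single factor 1.0f/sqrt(float(n_state)/n_head) to KQ instead; the two conventions are equivalent because d^-1/4 * d^-1/4 = d^-1/2 for d = n_state/n_head. In the flash path, the scaling is presumably handled inside ggml_flash_attn.

For reference, below is a minimal sketch (not part of the commit) of the computation both paths perform for a single attention head, written out in plain C. The function name, the row-major float layout and the mask convention are illustrative assumptions, not the ggml API.

// Reference sketch: single-head scaled dot-product attention.
// The #else path in the diff builds the same chain (KQ, KQ_scaled,
// KQ_masked, KQ_soft_max, KQV) as separate ggml graph nodes;
// ggml_flash_attn evaluates it as one fused operation.
#include <stdlib.h>
#include <math.h>
#include <float.h>

// Q: n_q x d, K: n_kv x d, V: n_kv x d, out: n_q x d (row-major).
// masked != 0 applies the causal mask used in the decoder self-attention:
// query i may attend to keys 0 .. n_past + i, with n_past = n_kv - n_q.
static void attention_ref(const float * Q, const float * K, const float * V,
                          float * out, int n_q, int n_kv, int d, int masked) {
    const float scale = 1.0f/sqrtf((float) d); // same factor as KQ_scaled
    float * s = (float *) malloc(n_kv*sizeof(float));

    for (int i = 0; i < n_q; ++i) {
        // scores = scale * (K . q_i), with -inf at masked positions
        float smax = -FLT_MAX;
        for (int j = 0; j < n_kv; ++j) {
            float dot = 0.0f;
            for (int k = 0; k < d; ++k) {
                dot += Q[i*d + k]*K[j*d + k];
            }
            s[j] = (masked && j > (n_kv - n_q) + i) ? -INFINITY : scale*dot;
            if (s[j] > smax) smax = s[j];
        }

        // softmax over the scores (KQ_soft_max)
        float sum = 0.0f;
        for (int j = 0; j < n_kv; ++j) {
            s[j] = expf(s[j] - smax);
            sum += s[j];
        }

        // weighted sum of the value rows (KQV)
        for (int k = 0; k < d; ++k) {
            float acc = 0.0f;
            for (int j = 0; j < n_kv; ++j) {
                acc += s[j]*V[j*d + k];
            }
            out[i*d + k] = acc/sum;
        }
    }

    free(s);
}

With n_q = N, n_kv = n_past + N, d = n_state/n_head and masked = 1 this mirrors the decoder self-attention; with masked = 0 and n_kv = M it mirrors the cross-attention path.
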
0 commit comments
