@@ -1457,7 +1457,7 @@ static bool whisper_encode(
1457
1457
layer.cross_attn_k_w ,
1458
1458
cur);
1459
1459
1460
- Kcross = ggml_scale (ctx0, Kcross, ggml_new_f32 (ctx0, pow (float (n_state)/n_head, -0.25 )));
1460
+ // Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
1461
1461
1462
1462
struct ggml_tensor * Vcross = ggml_mul_mat (ctx0,
1463
1463
layer.cross_attn_v_w ,
@@ -1579,14 +1579,14 @@ static bool whisper_decode(
1579
1579
Qcur),
1580
1580
Qcur);
1581
1581
1582
- Qcur = ggml_scale (ctxL, Qcur, ggml_new_f32 (ctxL, pow (float (n_state)/n_head, -0.25 )));
1582
+ // Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1583
1583
1584
1584
// note: no bias for Key
1585
1585
struct ggml_tensor * Kcur = ggml_mul_mat (ctxL,
1586
1586
layer.attn_k_w ,
1587
1587
cur);
1588
1588
1589
- Kcur = ggml_scale (ctxL, Kcur, ggml_new_f32 (ctxL, pow (float (n_state)/n_head, -0.25 )));
1589
+ // Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1590
1590
1591
1591
struct ggml_tensor * Vcur = ggml_mul_mat (ctxL,
1592
1592
layer.attn_v_w ,
@@ -1609,6 +1609,33 @@ static bool whisper_decode(
1609
1609
1610
1610
// ------
1611
1611
1612
+ #ifdef USE_FLASH_ATTN
1613
+ struct ggml_tensor * Q =
1614
+ ggml_permute (ctxL,
1615
+ ggml_cpy (ctxL,
1616
+ Qcur,
1617
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1618
+ 0 , 2 , 1 , 3 );
1619
+
1620
+ struct ggml_tensor * K =
1621
+ ggml_permute (ctxL,
1622
+ ggml_reshape_3d (ctxL,
1623
+ ggml_view_1d (ctxL, model.memory_k , (n_past + N)*n_state, il*n_ctx*ggml_element_size (model.memory_k )*n_state),
1624
+ n_state/n_head, n_head, n_past + N),
1625
+ 0 , 2 , 1 , 3 );
1626
+
1627
+ struct ggml_tensor * V =
1628
+ ggml_cpy (ctxL,
1629
+ ggml_permute (ctxL,
1630
+ ggml_reshape_3d (ctxL,
1631
+ ggml_view_1d (ctxL, model.memory_v , (n_past + N)*n_state, il*n_ctx*ggml_element_size (model.memory_v )*n_state),
1632
+ n_state/n_head, n_head, n_past + N),
1633
+ 1 , 2 , 0 , 3 ),
1634
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
1635
+ );
1636
+
1637
+ struct ggml_tensor * KQV = ggml_flash_attn (ctxL, Q, K, V, true );
1638
+ #else
1612
1639
struct ggml_tensor * Q =
1613
1640
ggml_permute (ctxL,
1614
1641
ggml_cpy (ctxL,
@@ -1626,13 +1653,13 @@ static bool whisper_decode(
1626
1653
// K * Q
1627
1654
struct ggml_tensor * KQ = ggml_mul_mat (ctxL, K, Q);
1628
1655
1629
- // struct ggml_tensor * KQ_scaled =
1630
- // ggml_scale(ctxL,
1631
- // KQ,
1632
- // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
1633
- // );
1656
+ struct ggml_tensor * KQ_scaled =
1657
+ ggml_scale (ctxL,
1658
+ KQ,
1659
+ ggml_new_f32 (ctxL, 1 .0f /sqrt (float (n_state)/n_head))
1660
+ );
1634
1661
1635
- struct ggml_tensor * KQ_masked = ggml_diag_mask_inf (ctxL, KQ , n_past);
1662
+ struct ggml_tensor * KQ_masked = ggml_diag_mask_inf (ctxL, KQ_scaled , n_past);
1636
1663
1637
1664
struct ggml_tensor * KQ_soft_max = ggml_soft_max (ctxL, KQ_masked);
1638
1665
@@ -1644,6 +1671,7 @@ static bool whisper_decode(
1644
1671
1 , 2 , 0 , 3 );
1645
1672
1646
1673
struct ggml_tensor * KQV = ggml_mul_mat (ctxL, V_trans, KQ_soft_max);
1674
+ #endif
1647
1675
1648
1676
struct ggml_tensor * KQV_merged = ggml_permute (ctxL, KQV, 0 , 2 , 1 , 3 );
1649
1677
@@ -1689,7 +1717,7 @@ static bool whisper_decode(
1689
1717
Qcur),
1690
1718
Qcur);
1691
1719
1692
- Qcur = ggml_scale (ctxL, Qcur, ggml_new_f32 (ctxL, pow (float (n_state)/n_head, -0.25 )));
1720
+ // Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
1693
1721
1694
1722
// NOTE: Kcross is no longer pre-scaled here - the ggml_scale() call on Kcross in whisper_encode is commented out
1695
1723
struct ggml_tensor * Kcross =
@@ -1704,6 +1732,24 @@ static bool whisper_decode(
1704
1732
1705
1733
// ------
1706
1734
1735
+ #ifdef USE_FLASH_ATTN
1736
+ struct ggml_tensor * Q =
1737
+ ggml_permute (ctxL,
1738
+ ggml_cpy (ctxL,
1739
+ Qcur,
1740
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
1741
+ 0 , 2 , 1 , 3 );
1742
+
1743
+ struct ggml_tensor * K = ggml_permute (ctxL, Kcross, 0 , 2 , 1 , 3 );
1744
+
1745
+ struct ggml_tensor * V =
1746
+ ggml_cpy (ctxL,
1747
+ ggml_permute (ctxL, Vcross, 1 , 2 , 0 , 3 ),
1748
+ ggml_new_tensor_3d (ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
1749
+ );
1750
+
1751
+ struct ggml_tensor * KQV = ggml_flash_attn (ctxL, Q, K, V, false );
1752
+ #else
1707
1753
struct ggml_tensor * Q =
1708
1754
ggml_permute (ctxL,
1709
1755
ggml_cpy (ctxL,
@@ -1730,6 +1776,7 @@ static bool whisper_decode(
1730
1776
struct ggml_tensor * V_trans = ggml_permute (ctxL, Vcross, 1 , 2 , 0 , 3 );
1731
1777
1732
1778
struct ggml_tensor * KQV = ggml_mul_mat (ctxL, V_trans, KQ_soft_max);
1779
+ #endif
1733
1780
1734
1781
struct ggml_tensor * KQV_merged = ggml_permute (ctxL, KQV, 0 , 2 , 1 , 3 );
1735
1782
0 commit comments