6 changes: 5 additions & 1 deletion convert_hf_to_gguf.py
@@ -3767,8 +3767,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
         elif "conv1d" in name:
             data_torch = data_torch.squeeze()
+        elif "q_proj.weight" in name:
+            q_proj, gate = data_torch.chunk(2, dim=0)
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid), gate)
+            data_torch = q_proj

-        return Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)
+        yield from Qwen2MoeModel.modify_tensors(self, data_torch, name, bid)


 @ModelBase.register("GPT2LMHeadModel")
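The conversion hunk above assumes the HF checkpoint fuses the attention output gate into q_proj along the output dimension, Q rows in the first half and gate rows in the second; chunk(2, dim=0) then recovers the two projections. A minimal sketch of that split, with hypothetical shapes not read from any model config:

import torch

n_embd = 2048                # hidden size (hypothetical)
n_head, head_dim = 16, 128   # attention heads (hypothetical)

# Fused weight: [2 * n_head * head_dim, n_embd], Q rows first, gate rows second
# (this layout is an assumption; the split mirrors the converter's chunk call).
fused = torch.randn(2 * n_head * head_dim, n_embd)

q_proj, gate = fused.chunk(2, dim=0)
assert q_proj.shape == gate.shape == (n_head * head_dim, n_embd)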
3 changes: 3 additions & 0 deletions gguf-py/gguf/constants.py
@@ -433,6 +433,7 @@ class MODEL_TENSOR(IntEnum):
     ATTN_NORM_2 = auto()
     ATTN_OUT_NORM = auto()
     ATTN_POST_NORM = auto()
+    ATTN_GATE = auto()
     ATTN_ROT_EMBD = auto()
     ATTN_SINKS = auto()
     FFN_GATE_INP = auto()
@@ -776,6 +777,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
     MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
     MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
+    MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate",
     MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
     MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
     MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
@@ -1478,6 +1480,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.ATTN_GATE,
         MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_GATE_INP_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
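The new entry follows the existing per-block template: at write time {bid} is formatted with the block index, and the resulting name must line up with the "blk.%d.attn_gate" pattern the C++ loader registers in src/llama-arch.cpp below. A trivial, purely illustrative check:

# The gguf writer fills {bid} with the block index, producing names the
# C++ side matches against "blk.%d.attn_gate".
template = "blk.{bid}.attn_gate"
assert template.format(bid=3) == "blk.3.attn_gate"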
2 changes: 2 additions & 0 deletions src/llama-arch.cpp
@@ -769,6 +769,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
         { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
         { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
         { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+        { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
         { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
         { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
         { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
@@ -2245,6 +2246,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -381,6 +381,7 @@ enum llm_tensor {
     LLM_TENSOR_ATTN_Q_A_NORM,
     LLM_TENSOR_ATTN_KV_A_NORM,
     LLM_TENSOR_ATTN_SUB_NORM,
+    LLM_TENSOR_ATTN_GATE,
     LLM_TENSOR_FFN_SUB_NORM,
     LLM_TENSOR_DEC_ATTN_NORM,
     LLM_TENSOR_DEC_ATTN_Q,
61 changes: 31 additions & 30 deletions src/llama-model.cpp
@@ -2434,9 +2434,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
                     layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);

-                    if ((i + 1) % 4 == 0) { // TODO: magic 4
-                        // Attention layers
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_ff }, 0);
+                    if (!hparams.is_recurrent(i)) {
+                        // Attention layers
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
                         layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
                         layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
@@ -2445,6 +2445,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
                         layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);

+                        // attn gate
+                        layer.wq_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+
                     } else {
                         // Linear attention (gated delta net) specific tensors
                         // Create tensors with calculated dimensions
@@ -2454,7 +2457,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
                         layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_projection_size }, 0);
                         layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
-                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
                     }

                     layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
@@ -19032,30 +19035,27 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
                              const int64_t n_embd_head,
                              const int il) {

-        // QKV projection with gating
-        ggml_tensor * qkv_g = build_lora_mm(model.layers[il].wq, cur);
-        cb(qkv_g, "qkv_g", il);
-
-        // Split into Q and gate
-        const int64_t n_embd_q = hparams.n_head(il) * n_embd_head;
-        ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv_g, n_embd_head, hparams.n_head(il), n_tokens,
-                n_embd_head * sizeof(float), qkv_g->nb[1], 0);
-        ggml_tensor * gate = ggml_view_3d(ctx0, qkv_g, n_embd_head, hparams.n_head(il), n_tokens,
-                n_embd_head * sizeof(float), qkv_g->nb[1], n_embd_q * ggml_element_size(qkv_g));
-
-        // K and V projections
-        ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
-        ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+        ggml_tensor * gate = build_lora_mm(model.layers[il].wq_gate, cur);
+
+        // compute Q and K and RoPE them
+        struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+        cb(Qcur, "Qcur", il);
+
+        struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+        cb(Kcur, "Kcur", il);
+
+        struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+        cb(Vcur, "Vcur", il);

-        Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens);
-        Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
-        Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

         // Apply Q/K normalization
         Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
         Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
cb(Kcur, "Qcur_normed", il);
cb(Kcur, "Kcur_normed", il);

// Apply RoPE
Qcur = ggml_rope_ext(
@@ -19079,7 +19079,6 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
             Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);

         // Apply gating
-        gate = ggml_reshape_2d(ctx0, ggml_cont(ctx0, gate), n_embd_q, n_tokens);
         cur = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
         cb(cur, "attn_gated", il);

@@ -19182,16 +19181,10 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {

         GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));

-        // Softplus would be nice...
-        ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); // a + dt_bias
-        ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
-        ggml_tensor * one_tensor = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); // Create scalar tensor
-        ggml_exp(ctx0, one_tensor); // make it a 1
-        ggml_tensor * one_plus_exp = ggml_add1(ctx0, alpha_exp, one_tensor); // 1 + exp(a + dt_bias)
-        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
+        ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
         ggml_tensor * A_log_exp = ggml_exp(ctx0, model.layers[il].ssm_a); // A_log.exp()
         ggml_tensor * gate_scaled = ggml_mul(ctx0, alpha_softplus, A_log_exp); // A_log.exp() * softplus
-        ggml_tensor * gate = ggml_neg(ctx0, gate_scaled); // - (A_log.exp() * softplus)
+        ggml_tensor * gate = ggml_scale(ctx0, gate_scaled, -1.0f); // - (A_log.exp() * softplus)

         // Get convolution weights and bias
         ggml_tensor * conv_weight = model.layers[il].ssm_conv1d;
@@ -19325,6 +19318,14 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {

         return cur;
     }
+
+    ggml_tensor * softplus(ggml_tensor * alpha, ggml_tensor * dt_bias) {
+        ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, dt_bias); // a + dt_bias
+        ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
+        ggml_tensor * one_plus_exp = ggml_scale_bias(ctx0, alpha_exp, 1.0f, 1.0f); // 1 + exp(a + dt_bias)
+        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
+        return alpha_softplus;
+    }
 };


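To make the graph changes above concrete: the attention path now gates its output elementwise with a sigmoid of a separate gate projection (wq_gate), and the gated delta net builds its decay term as -(exp(A_log) * softplus(a + dt_bias)), which is what the new softplus helper feeds. A minimal NumPy sketch of both formulas; names and shapes here are illustrative stand-ins, not the ggml API:

import numpy as np

def softplus(x):
    # log(1 + exp(x)); np.logaddexp(0, x) is the overflow-safe form.
    # Note the ggml helper above computes the naive log(1 + exp(x)) directly.
    return np.logaddexp(0.0, x)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
n_tokens, n_embd = 4, 32                            # toy sizes
x        = rng.standard_normal((n_tokens, n_embd))  # layer input
attn_out = rng.standard_normal((n_tokens, n_embd))  # attention output
W_gate   = rng.standard_normal((n_embd, n_embd))    # stand-in for wq_gate

# Attention output gating: cur = cur * sigmoid(gate), with gate = x @ W_gate
gated = attn_out * sigmoid(x @ W_gate)

# Delta-net decay term: gate = -(A_log.exp() * softplus(a + dt_bias))
a, dt_bias, A_log = 0.3, -1.2, 0.5                  # scalar stand-ins
decay_gate = -(np.exp(A_log) * softplus(a + dt_bias))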
1 change: 1 addition & 0 deletions src/llama-model.h
@@ -228,6 +228,7 @@ struct llama_layer {
     struct ggml_tensor * wk_enc = nullptr;
     struct ggml_tensor * wv_enc = nullptr;
     struct ggml_tensor * wo_enc = nullptr;
+    struct ggml_tensor * wq_gate = nullptr;

     // attention bias
     struct ggml_tensor * bq = nullptr;