## importance score


In [None]:
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# ========== 0. Configuration ==========
# Any other checkpoint can be substituted here, e.g. "EleutherAI/gpt-neox-20b"
# or a custom GPT-NeoX-GQA model.
MODEL_NAME = "EleutherAI/pythia-410m-deduped"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ========== 1. Load pretrained GPT-NeoX weights ==========
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
if tok.pad_token is None:
    # No dedicated pad token: reuse EOS. From here on pad_token_id == eos_token_id.
    tok.pad_token = tok.eos_token

model = GPTNeoXForCausalLM.from_pretrained(MODEL_NAME)
model.to(DEVICE)
model.eval()
model.config.use_cache = False  # intended to keep the later backward pass a bit more stable

# Turn gradient checkpointing off if the model reports it as enabled.
if getattr(model, "gradient_checkpointing", False):
    model.gradient_checkpointing_disable()

print("num_layers:", model.config.num_hidden_layers)
print("num_heads:", model.config.num_attention_heads)
# Falls back to the attention-head count when the config has no KV-head field.
print("num_kv_heads:",
      getattr(model.config, "num_key_value_heads",
              model.config.num_attention_heads))

# ========== 2. A "reasonable" eval loss as a sanity check ==========
#   WikiText-2 validation split, concatenated and cut into fixed-size blocks,
#   scored with the standard causal-LM loss.

raw_val = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")

def make_packed(texts, tokenizer, block_size=1024):
    """Tokenize ``texts`` into one token stream and cut it into fixed-size blocks.

    Each text is encoded without special tokens and followed by the tokenizer's
    EOS id. The concatenated stream is split into consecutive chunks of exactly
    ``block_size`` ids; the incomplete tail is dropped.

    Returns:
        A ``Dataset`` with a single ``"input_ids"`` column, one block per row.
    """
    stream = []
    for text in texts:
        encoded = tokenizer(text, add_special_tokens=False)
        stream.extend(encoded["input_ids"] + [tokenizer.eos_token_id])

    n_blocks = len(stream) // block_size
    chunks = [
        stream[b * block_size:(b + 1) * block_size]
        for b in range(n_blocks)
    ]
    return Dataset.from_dict({"input_ids": chunks})

# Pull the raw strings out of the validation split and pack them into
# 1024-token blocks for evaluation.
val_texts = [ex["text"] for ex in raw_val]
val_packed = make_packed(val_texts, tok, block_size=1024)

def collate_lm(examples):
    """Stack equal-length examples into a causal-LM batch.

    Every example's ``"input_ids"`` becomes one row of a ``torch.long`` tensor.
    ``"labels"`` is an independent copy of that tensor.
    """
    rows = []
    for example in examples:
        rows.append(torch.tensor(example["input_ids"], dtype=torch.long))
    # Rows must already share one length; no padding is applied here.
    batch_ids = torch.stack(rows, dim=0)
    return {"input_ids": batch_ids, "labels": batch_ids.clone()}

# Evaluation loader over the packed validation blocks: fixed order, 4 blocks per batch.
val_loader = DataLoader(
    val_packed,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_lm,
)

@torch.no_grad()
def eval_loss(model, loader, device="cuda", max_batches=50):
    """Return the target-token-weighted mean loss over the first batches of ``loader``.

    Args:
        model: Callable taking ``input_ids=`` / ``labels=`` and returning an
            object whose ``.loss`` is the batch-mean loss. Assumed to be a
            Hugging Face style causal LM that shifts labels internally.
        loader: Iterable of dicts with ``"input_ids"`` and ``"labels"`` tensors
            of shape ``[batch, seq]``.
        device: Device the batch tensors are moved to.
        max_batches: Only the first ``max_batches`` batches are consumed.

    Returns:
        The mean loss per predicted token as a float, or ``nan`` when no batch
        contributed a target token (empty loader, ``max_batches <= 0``, ...).
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    for i, batch in enumerate(loader):
        if i >= max_batches:
            break
        inp = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        # With internally shifted labels, position 0 is never a target and
        # -100 entries are ignored, so the batch mean covers exactly these tokens.
        n_targets = int((labels[:, 1:] != -100).sum().item())
        if n_targets == 0:
            continue  # the batch mean would be NaN; nothing to score
        out = model(input_ids=inp, labels=labels)
        total_loss += out.loss.item() * n_targets
        total_tokens += n_targets
    if total_tokens == 0:
        return float("nan")
    return total_loss / total_tokens

# Score at most 50 validation batches and report loss plus perplexity (exp of the loss).
avg_loss = eval_loss(model, val_loader, device=DEVICE, max_batches=50)
print(f"[Eval] avg loss ≈ {avg_loss:.4f}, ppl ≈ {torch.exp(torch.tensor(avg_loss)).item():.2f}")

# ========== 3. Prepare a simple calibration set for the gradient pass ==========
#   Still WikiText-2; a leading slice of the train split is enough.

raw_train = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:40%]")

def encode_fn(ex):
    """Tokenize one dataset row with the module-level tokenizer ``tok``.

    The row's ``"text"`` is encoded without special tokens and truncated to
    512 ids. A row that encodes to nothing is replaced by a single EOS id so
    that every example has at least one token.

    Returns:
        ``{"input_ids": <list of token ids>}``.
    """
    encoding = tok(
        ex["text"],
        truncation=True,
        max_length=512,
        add_special_tokens=False,
    )
    token_ids = encoding["input_ids"]
    return {"input_ids": token_ids if len(token_ids) > 0 else [tok.eos_token_id]}

# Tokenize every calibration row; the original columns are dropped so only "input_ids" remains.
encoded = raw_train.map(encode_fn, remove_columns=raw_train.column_names)

def collate_batch(examples):
    """Right-pad variable-length examples and build the matching attention mask.

    Padding uses the module-level tokenizer's ``pad_token_id``. The mask is
    derived from each example's true length rather than from comparing against
    the pad id: when the pad token is aliased to EOS, a value comparison would
    also mask genuine EOS tokens inside the sequences.

    Returns:
        ``{"input_ids": LongTensor[batch, max_len],
           "attention_mask": LongTensor[batch, max_len]}`` with 1 on real tokens
        and 0 on padding.
    """
    sequences = [torch.tensor(e["input_ids"], dtype=torch.long)
                 for e in examples]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=tok.pad_token_id
    )
    lengths = torch.tensor([seq.numel() for seq in sequences], dtype=torch.long)
    positions = torch.arange(input_ids.size(1)).unsqueeze(0)  # [1, max_len]
    attention_mask = (positions < lengths.unsqueeze(1)).long()
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Calibration loader: fixed order, 2 sequences per batch, padded by collate_batch.
calib_loader = DataLoader(
    encoded,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_batch,
)

# ========== 4. 通用版：从 grad 里算每层 KV head importance ==========
def compute_kv_importance_from_grad(W_grad, config):
    """Score each KV head by the gradient norm of its key and value rows.

    Args:
        W_grad: ``[out_features, in_features]`` gradient of the fused
            ``query_key_value`` weight.
        config: Object exposing ``hidden_size``, ``num_attention_heads`` and
            optionally ``num_key_value_heads``.

    Two layouts are handled:
      1) Plain MHA (``num_key_value_heads`` missing or equal to the head count),
         ``out_features == 3 * hidden_size``. The rows follow the Hugging Face
         GPT-NeoX convention, where the fused projection is grouped per head:
         ``[head0: q, k, v | head1: q, k, v | ...]`` — NOT three stacked
         ``[Q; K; V]`` blocks.
      2) GQA (``num_key_value_heads < num_attention_heads``),
         ``out_features == (n_heads + 2 * n_kv_heads) * head_dim``.
         NOTE(review): this branch assumes head-blocked rows ordered
         ``[all Q heads | all K heads | all V heads]``; confirm against the
         custom GQA model it is meant for.

    Returns:
        1-D tensor of length ``n_kv_heads``: ``||grad_K_h||_F + ||grad_V_h||_F``.
    """
    hidden_size = config.hidden_size
    n_heads = config.num_attention_heads
    # Treat a missing or None KV-head count as "no GQA".
    n_kv_heads = getattr(config, "num_key_value_heads", None) or n_heads
    head_dim = hidden_size // n_heads

    out_features, in_features = W_grad.shape
    assert in_features == hidden_size, f"unexpected in_features: {in_features}"

    if n_kv_heads == n_heads:
        # Case 1: no GQA, per-head interleaved q/k/v rows.
        assert out_features == 3 * hidden_size, \
            f"expected 3*hidden_size={3*hidden_size}, got {out_features}"
        per_head = W_grad.reshape(n_heads, 3, head_dim, hidden_size)
        grad_k = per_head[:, 1]  # [n_heads, head_dim, hidden_size]
        grad_v = per_head[:, 2]
    else:
        # Case 2: GQA, head-blocked rows.
        expected_out = (n_heads + 2 * n_kv_heads) * head_dim
        assert out_features == expected_out, \
            f"expected out_features={expected_out}, got {out_features}"
        grad_3d = W_grad.reshape(n_heads + 2 * n_kv_heads, head_dim, hidden_size)
        grad_k = grad_3d[n_heads:n_heads + n_kv_heads]
        grad_v = grad_3d[n_heads + n_kv_heads:]

    # Frobenius norm per head, summed over the key and value parts.
    return (grad_k.pow(2).sum(dim=(1, 2)).sqrt() +
            grad_v.pow(2).sum(dim=(1, 2)).sqrt())


def compute_all_layer_importances(
    model,
    dataloader,
    max_batches=10,
    device="cuda",
):
    """Accumulate LM-loss gradients over calibration batches and score KV heads.

    Gradients from up to ``max_batches`` usable batches are summed (no optimizer
    step is taken); afterwards every layer's ``query_key_value.weight.grad`` is
    turned into per-head scores by ``compute_kv_importance_from_grad``.

    Padded positions are excluded from the loss by setting their labels to
    -100. Batches that contain no target token, or whose loss is not finite,
    are skipped so that a single NaN cannot poison the accumulated gradients.

    Args:
        model: Causal LM exposing ``gpt_neox.layers[*].attention.query_key_value``.
        dataloader: Yields dicts with ``"input_ids"`` and ``"attention_mask"``.
        max_batches: Number of batches whose gradients are accumulated.
        device: Device the batch tensors are moved to.

    Returns:
        A list with one CPU tensor of per-head scores per layer.

    Raises:
        RuntimeError: If a layer has no gradient (e.g. no batch was usable).
    """
    model.zero_grad()
    # Kept from the original flow. NOTE(review): backward does not require
    # train mode, and train mode enables dropout if the config has any.
    model.train()
    batches_run = 0

    for batch in dataloader:
        if batches_run >= max_batches:
            break

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Do not ask the model to predict padding.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        if not bool((labels[:, 1:] != -100).any()):
            continue  # e.g. every sequence is one token long: no targets

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss
        if not bool(torch.isfinite(loss)):
            continue
        loss.backward()  # gradients accumulate across batches

        batches_run += 1
        print(f"[Calib] batch {batches_run}, loss={loss.item():.4f}")

    # Back to inference mode.
    model.eval()

    all_layer_importances = []

    for layer_id, layer in enumerate(model.gpt_neox.layers):
        W_grad = layer.attention.query_key_value.weight.grad
        if W_grad is None:
            raise RuntimeError(
                f"Layer {layer_id} has no grad on query_key_value.weight; "
                "check that backward has been called."
            )
        kv_imp = compute_kv_importance_from_grad(W_grad, model.config)
        all_layer_importances.append(kv_imp.detach().cpu())

    return all_layer_importances

# Accumulate gradients over up to 128 calibration batches and score every layer's KV heads.
importance_scores = compute_all_layer_importances(
    model,
    calib_loader,
    max_batches=128,
    device=DEVICE,
)

# Print the per-head scores of every layer for a quick look.
for layer_id, imp in enumerate(importance_scores):
    imp_list = imp.tolist()
    print(f"Layer {layer_id}:")
    print("  KV head importance:", [round(x, 4) for x in imp_list])

# Persist the scores; the later cells reload this file.
torch.save(
    {"importance_scores": importance_scores},
    "kv_head_importance_neox.pt",
)
print("Saved importance scores to kv_head_importance_neox.pt")

# ========== 5. Visualization: layer × head importance heatmap ==========
# Shape: [num_layers, num_kv_heads]
imp_mat = torch.stack(importance_scores, dim=0)  # [L, H_kv]

plt.figure(figsize=(8, 6))
plt.imshow(imp_mat, aspect="auto")
plt.colorbar(label="KV head importance")
plt.xlabel("KV head index")
plt.ylabel("Layer index")
plt.title(f"KV head importance heatmap ({MODEL_NAME})")
plt.tight_layout()
plt.savefig("kv_head_importance_heatmap.png", dpi=150)
print("Saved heatmap to kv_head_importance_heatmap.png")
plt.close()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/396 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/911M [00:00<?, ?B/s]

num_layers: 24
num_heads: 16
num_kv_heads: 16


README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

[Eval] avg loss ≈ 3.1183, ppl ≈ 22.61


Map:   0%|          | 0/14687 [00:00<?, ? examples/s]

[Calib] batch 1, loss=10.3757
[Calib] batch 2, loss=9.7406
[Calib] batch 3, loss=3.6335
[Calib] batch 4, loss=11.7969
[Calib] batch 5, loss=9.7697
[Calib] batch 6, loss=8.3560
[Calib] batch 7, loss=12.0315
[Calib] batch 8, loss=9.9184
[Calib] batch 9, loss=4.0302
[Calib] batch 10, loss=12.0062
[Calib] batch 11, loss=9.3731
[Calib] batch 12, loss=9.6553
[Calib] batch 13, loss=11.7543
[Calib] batch 14, loss=9.7040
[Calib] batch 15, loss=12.0853
[Calib] batch 16, loss=7.5993
[Calib] batch 17, loss=11.6534
[Calib] batch 18, loss=9.9084
[Calib] batch 19, loss=4.7830
[Calib] batch 20, loss=9.6581
[Calib] batch 21, loss=12.1289
[Calib] batch 22, loss=10.0633
[Calib] batch 23, loss=11.5391
[Calib] batch 24, loss=3.5122
[Calib] batch 25, loss=9.5135
[Calib] batch 26, loss=11.5576
[Calib] batch 27, loss=9.3584
[Calib] batch 28, loss=9.2812
[Calib] batch 29, loss=12.0060
[Calib] batch 30, loss=9.5818
[Calib] batch 31, loss=11.7375
[Calib] batch 32, loss=7.8212
[Calib] batch 33, loss=9.5825
[Calib

## Importance-aware GQA

In [None]:
# ======================= exp_importance_gqa_24L_train.py =======================
import os, math, torch
from datasets import load_dataset, interleave_datasets, IterableDataset
from transformers import (AutoTokenizer, AutoConfig, GPTNeoXForCausalLM,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)
from packaging import version
import transformers as tf
from transformers.models.gpt_neox.modeling_gpt_neox import (
    GPTNeoXAttention,
    apply_rotary_pos_emb,
)



# NOTE(review): presumably meant to force the eager attention path — confirm
# that anything in the stack actually reads this variable.
os.environ["ATTENTION_BACKEND"] = "EAGER"


# ============================================================
# (0) load importance (24 layers × 16 heads)
# ============================================================
imp_obj = torch.load("kv_head_importance_neox.pt")
importance_scores = imp_obj["importance_scores"]    # list of 24 tensors [16]
print("Loaded importance scores from kv_head_importance_neox.pt")
# The second value is the shape of one layer's score vector (one entry per
# head), not a head dimension, so label it accordingly.
print(f"num_layers={len(importance_scores)}, per_layer_shape={importance_scores[0].shape}")


# ============================================================
# (1) helper: importance → head 分组 → kv_idx
# ============================================================

def assign_kv_groups_from_importance(
    head_imp: torch.Tensor,
    num_q_heads: int = 16,
    num_kv_heads: int = 8,
):
    """Partition query heads into KV groups using per-head importance scores.

    Procedure:
      1. Rank the heads by importance, highest first.
      2. Seed group ``g`` with the ``g``-th ranked head, so the top
         ``num_kv_heads`` heads each open their own group.
      3. Walk the remaining heads in rank order and put each one into the
         group whose accumulated importance is currently the smallest.

    Args:
        head_imp: ``num_q_heads`` importance values for one layer.
        num_q_heads: Number of query heads.
        num_kv_heads: Number of KV groups to form.

    Returns:
        A tuple ``(kv_idx, group_sizes, groups)``:
          - ``kv_idx``: LongTensor ``[num_q_heads]``; entry ``q`` is the group
            (0..num_kv_heads-1) used by query head ``q``.
          - ``group_sizes``: LongTensor ``[num_kv_heads]``; number of query
            heads in each group.
          - ``groups``: list of lists with the member head indices per group,
            in assignment order.
    """
    scores = torch.as_tensor(head_imp, dtype=torch.float32)
    assert scores.numel() == num_q_heads

    ranking = torch.argsort(scores, descending=True)

    # Seed every group with one of the top-ranked heads.
    groups = []
    totals = torch.zeros(num_kv_heads, dtype=torch.float32)
    for slot in range(num_kv_heads):
        head = ranking[slot].item()
        groups.append([head])
        totals[slot] += scores[head]

    # Greedy fill: each remaining head joins the currently lightest group.
    for ranked in ranking[num_kv_heads:]:
        head = ranked.item()
        slot = torch.argmin(totals).item()
        groups[slot].append(head)
        totals[slot] += scores[head]

    # Invert the grouping into a head -> group lookup.
    owner = [0] * num_q_heads
    for slot, members in enumerate(groups):
        for head in members:
            owner[head] = slot
    kv_idx = torch.tensor(owner, dtype=torch.long)
    group_sizes = torch.tensor([len(members) for members in groups], dtype=torch.long)

    assert kv_idx.numel() == num_q_heads
    assert group_sizes.sum().item() == num_q_heads
    return kv_idx, group_sizes, groups


from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention

def repeat_kv_nonuniform(k, kv_idx, head_axis=2):
    """Expand KV heads to one per query head using an arbitrary head mapping.

    Args:
        k: Key or value tensor whose ``head_axis`` dimension indexes KV heads.
            With the default ``head_axis=2`` this is ``[B, S, H_kv, D]``.
        kv_idx: Integer tensor with ``H_q`` entries; query head ``q`` reads KV
            head ``kv_idx[q]``.
        head_axis: Dimension of ``k`` that holds the heads. Pass ``1`` for a
            ``[B, H_kv, S, D]`` layout.

    Returns:
        A new tensor shaped like ``k`` but with ``H_q`` entries along
        ``head_axis``.

    The gather is done with plain indexing along ``head_axis`` instead of a
    flatten-and-``view`` round trip, so non-contiguous inputs (e.g. the result
    of a transpose) are handled as well.
    """
    kv_idx = kv_idx.to(k.device).reshape(-1)
    selector = [slice(None)] * k.dim()
    selector[head_axis] = kv_idx
    return k[tuple(selector)]


class CaGQAAttention(GPTNeoXAttention):
    """
    Importance-based non-uniform grouped-query attention (Ca-GQA).

    Overrides ``GPTNeoXAttention.forward`` and expands the K/V tensors with a
    per-layer ``kv_idx`` mapping (query head -> KV head) instead of a uniform
    repeat. ``kv_idx`` is expected to be attached to the instance from outside;
    without it the parent implementation is used.

    NOTE(review): this override relies on parent-class internals
    (``_split_heads``, ``_attn``, ``_merge_heads``, ``self.bias``,
    ``self.rotary_emb``). The recorded run of this notebook stopped with
    ``AttributeError: 'CaGQAAttention' object has no attribute '_split_heads'``,
    so these must be checked against the installed transformers version.
    """

    # Intended to enable Ca-GQA and bypass Flash Attention / SDPA so the eager
    # path is used. The author's note: self.attention_op must be None.
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: torch.FloatTensor,
        head_mask: torch.FloatTensor | None = None,
        layer_past: tuple[torch.Tensor] | None = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        position_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor] | None]:

        # 0. Is a kv_idx mapping attached?
        if not hasattr(self, "kv_idx") or self.kv_idx is None:
            # No mapping: defer to the parent's standard forward.
            # NOTE(review): position_ids, cache_position and **kwargs are not
            # forwarded here — confirm the parent signature accepts this call.
            return super().forward(
                hidden_states,
                attention_mask,
                head_mask,
                layer_past,
                use_cache,
                output_attentions,
            )

        # 1. Fused QKV projection.
        qkv = self.query_key_value(hidden_states)

        # 2. Split into Q, K, V.
        # NOTE(review): the original comments describe the results as
        # [B, S, H, D], while the cache concatenation below treats dim=-2 as
        # the sequence axis — the actual layout returned by _split_heads needs
        # to be confirmed.
        query_layer, key_layer, value_layer = self._split_heads(qkv)

        # 3. Rotary position embedding (RoPE).

        # Two call conventions are selected by the installed transformers version.
        if version.parse(tf.__version__) < version.parse("4.37.0"):
            # Older RoPE call convention.
            query_layer, key_layer = apply_rotary_pos_emb(
                query_layer, key_layer, self.rotary_emb
            )
        else:
            # Newer RoPE call convention (position ids are computed here; the
            # position_ids argument passed by the caller is overwritten).
            position_ids = self.bias.long().cumsum(-1).to(query_layer.device)
            k_seq_len = key_layer.shape[-2]
            position_ids = position_ids[:, -k_seq_len:]

            cos, sin = self.rotary_emb(value_layer, seq_len=k_seq_len)
            query_layer, key_layer = apply_rotary_pos_emb(
                query_layer, key_layer, cos, sin, position_ids
            )

        # 4. Key step: non-uniform KV expansion, H_kv heads -> H_q heads via kv_idx.
        # NOTE(review): the print below is a leftover debug trace that fires on
        # every forward call.
        print("here!")
        key_layer = repeat_kv_nonuniform(key_layer, self.kv_idx)
        value_layer = repeat_kv_nonuniform(value_layer, self.kv_idx)

        # 5. KV cache: prepend cached keys/values along dim=-2.
        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key, key_layer), dim=-2)
            value_layer = torch.cat((past_value, value_layer), dim=-2)

        if use_cache:
            present = (key_layer, value_layer)
        else:
            present = None

        # 6. Attention scores and weighted values (parent helper).
        attn_output, attn_weights = self._attn(
            query_layer, key_layer, value_layer, attention_mask, head_mask
        )

        # 7. Merge heads and apply the output projection.
        attn_output = self._merge_heads(attn_output)
        attn_output = self.dense(attn_output)

        return (attn_output, present, attn_weights)



# ============================================================
# (3) tokenizer  & dataset (same as baseline)
# ============================================================
# GPT-2 tokenizer with a 2048-token limit; EOS doubles as the pad token.
tok = AutoTokenizer.from_pretrained("gpt2")
tok.model_max_length = 2048
tok.pad_token = tok.eos_token

# Both corpora are opened in streaming mode.
fw_edu = load_dataset("HuggingFaceFW/fineweb-edu", split="train", streaming=True)
fw     = load_dataset("HuggingFaceFW/fineweb",     split="train", streaming=True)

# Sample 20% from fineweb-edu and 80% from fineweb, with a fixed seed.
train_stream = interleave_datasets([fw_edu, fw], probabilities=[0.2, 0.8], seed=42)

def stream_packer(stream, tokenizer, block_size=2048, buffer_chars=3_000_000):
    """Yield fixed-size token blocks packed from a stream of text records.

    Records contribute their ``"text"`` (or, failing that, ``"content"``)
    field; empty records are skipped. Texts are joined with newlines and
    tokenized in batches of at least ``buffer_chars`` characters, each batch
    followed by one EOS id. Tokens that do not fill a block are carried over
    into the next batch, and whatever text is still buffered when the stream
    ends is tokenized too, so short streams are not silently dropped. Only a
    final incomplete block is discarded.

    Args:
        stream: Iterable of dict-like records.
        tokenizer: Callable returning ``{"input_ids": [...]}``; must expose
            ``eos_token_id``.
        block_size: Number of token ids per yielded block.
        buffer_chars: Character threshold that triggers tokenization.

    Yields:
        ``{"input_ids": <list of block_size token ids>}``.
    """

    def _drain(texts, carry):
        """Tokenize ``texts`` after ``carry``; return (full blocks, remainder)."""
        ids = tokenizer(
            "\n".join(texts) + "\n",
            return_attention_mask=False,
            add_special_tokens=False,
        )["input_ids"]
        tokens = carry + list(ids) + [tokenizer.eos_token_id]
        full = (len(tokens) // block_size) * block_size
        blocks = [tokens[j:j + block_size] for j in range(0, full, block_size)]
        return blocks, tokens[full:]

    pending = []        # texts waiting to be tokenized
    pending_chars = 0   # their total length including "\n" separators
    leftover = []       # token ids not yet emitted as part of a block

    for ex in stream:
        text = ex.get("text") or ex.get("content") or ""
        if not text:
            continue
        pending.append(text)
        pending_chars += len(text) + 1
        if pending_chars >= buffer_chars:
            blocks, leftover = _drain(pending, leftover)
            for block in blocks:
                yield {"input_ids": block}
            pending, pending_chars = [], 0

    # Flush whatever is still buffered once the stream is exhausted.
    if pending:
        blocks, leftover = _drain(pending, leftover)
        for block in blocks:
            yield {"input_ids": block}

# Wrap the packer in an IterableDataset; the lambda builds a fresh generator
# with 2048-token blocks and a 3,000,000-character buffer each time it is called.
train_packed = IterableDataset.from_generator(
    lambda: stream_packer(train_stream, tok, 2048, 3_000_000)
)


# ============================================================
# (4) model: 24-layer GPT-NeoX + GQA = 8 KV heads
# ============================================================
# NOTE(review): verify that the installed GPT-NeoX config/attention classes
# actually honour num_key_value_heads; if they do not, query_key_value keeps
# its full 3 * hidden_size output and kv_idx (values 0..7) would only select
# among the first 8 of 16 key/value heads — TODO confirm.
config = AutoConfig.for_model(
    "gpt_neox",
    vocab_size=len(tok),
    hidden_size=1024,
    num_hidden_layers=24,       # aligned with Pythia
    num_attention_heads=16,
    num_key_value_heads=8,      # GQA mode
    intermediate_size=4096,
    max_position_embeddings=2048,
    rotary_emb_base=10000,
    layer_norm_eps=1e-5,
    use_cache=True,
    tie_word_embeddings=False,
)

# Randomly initialised model (no pretrained weights are loaded here).
model = GPTNeoXForCausalLM(config)
model.resize_token_embeddings(len(tok))

# The original note said "use the standard implementation first (keep
# flash_attention_2)", yet the line below asks for "eager".
# NOTE(review): this attribute is set after the model is built — confirm it
# has any effect on the attention path actually used.
model.config.attn_implementation = "eager"

H_q = config.num_attention_heads      # 16
H_kv = config.num_key_value_heads     # 8

# --- patch each layer with Ca-GQA + kv_idx ----
for layer_id, layer in enumerate(model.gpt_neox.layers):
    head_imp_16 = importance_scores[layer_id]          # [16]

    kv_idx, group_sizes, groups = assign_kv_groups_from_importance(
        head_imp_16,
        num_q_heads=H_q,
        num_kv_heads=H_kv,
    )

    print(f"[Layer {layer_id}] group_sizes={group_sizes.tolist()}, kv_idx={kv_idx.tolist()}")

    # Swap in a CaGQAAttention that carries over the old weights and attach
    # kv_idx as a non-persistent buffer (it is not saved in the state dict).
    # NOTE(review): confirm CaGQAAttention(config) matches the parent
    # constructor signature of the installed transformers version.
    new_attn = CaGQAAttention(config)
    new_attn.load_state_dict(layer.attention.state_dict())
    new_attn.register_buffer("kv_idx", kv_idx, persistent=False)
    new_attn.attention_op = None
    layer.attention = new_attn

print("layer.attention_op:", model.gpt_neox.layers[0].attention.attention_op)


# print("[Ca-GQA] patched all 24 layers (kv_idx 已挂载，但目前 forward 仍然是 uniform GQA 实现)。")


# ============================================================
# (5) training (same as baseline, 但把 dataloader_num_workers 改为 0)
# ============================================================
# Causal-LM collator (mlm=False): pads with the tokenizer and derives labels
# from the inputs.
collator = DataCollatorForLanguageModeling(tok, mlm=False)

# The evaluation-schedule keyword differs between transformers releases
# ("eval_strategy" in newer ones, "evaluation_strategy" in older ones).
# Pick the right name once and splat it into TrainingArguments below, so no
# keyword is passed that the installed version does not know.
eval_kwargs = {}
if version.parse(tf.__version__) >= version.parse("4.46.0"):
    eval_kwargs["eval_strategy"] = "steps"
else:
    eval_kwargs["evaluation_strategy"] = "steps"

args = TrainingArguments(
    output_dir="out_ca_gqa_24L",
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    weight_decay=0.1,
    warmup_ratio=0.02,
    lr_scheduler_type="cosine",
    max_steps=15000,
    logging_steps=50,
    logging_strategy="steps",      # log explicitly on a step schedule
    logging_first_step=True,       # optional: also log the very first step
    save_steps=1000,
    save_strategy="steps",         # keep in line with the logging schedule
    bf16=True,
    gradient_checkpointing=True,
    eval_steps=500,
    dataloader_num_workers=0,
    report_to="none",              # no external tracker; loss goes to the console
    ignore_data_skip=True,
    **eval_kwargs,
)



# ============================================================
# (6) validation dataset same as baseline
# ============================================================
# English Wikipedia (20231101 dump), opened in streaming mode, as the validation source.
val_stream = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", streaming=True)
def take_first_n(stream, n_docs=256):
    """Yield the first ``n_docs`` records of ``stream`` that have non-empty text.

    Records whose ``"text"`` field is missing or empty are skipped and do not
    count. Iteration stops right after the ``n_docs``-th document is yielded,
    without pulling another record from ``stream``. A non-positive ``n_docs``
    yields nothing.

    Yields:
        ``{"text": <str>}`` for each accepted record.
    """
    if n_docs <= 0:
        return
    taken = 0
    for ex in stream:
        text = ex.get("text", "")
        if not text:
            continue
        yield {"text": text}
        taken += 1
        if taken >= n_docs:
            return

from datasets import Dataset
def make_val_packed(texts, block_size=2048):
    """Pack ``texts`` into fixed-size validation blocks with the module-level ``tok``.

    Each text is encoded without special tokens and followed by the EOS id.
    The concatenated ids are cut into consecutive chunks of ``block_size``;
    the incomplete tail is dropped.

    Returns:
        A ``Dataset`` with a single ``"input_ids"`` column.
    """
    stream = []
    for text in texts:
        encoded = tok(text, add_special_tokens=False)
        stream.extend(encoded["input_ids"] + [tok.eos_token_id])

    n_blocks = len(stream) // block_size
    chunks = [
        stream[b * block_size:(b + 1) * block_size]
        for b in range(n_blocks)
    ]
    return Dataset.from_dict({"input_ids": chunks})

# Take the first 128 non-empty articles and pack them into 2048-token blocks.
val_texts = [ex["text"] for ex in take_first_n(val_stream, 128)]
val_packed = make_val_packed(val_texts, 2048)


# ============================================================
# (7) train Ca-GQA (当前实际上仍是 uniform GQA 的实现，只是 kv_idx 已经准备好)
# ============================================================
# Wire the patched model, streaming train set, packed validation set and collator together.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_packed,
    eval_dataset=val_packed,
    data_collator=collator,
)
trainer.train()

# Store the final model and the tokenizer side by side.
trainer.save_model("out_ca_gqa_24L/final")
tok.save_pretrained("out_ca_gqa_24L/final")



Loaded importance scores from kv_head_importance_neox.pt
num_layers=24, head_dim=torch.Size([16])


Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/27468 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/27468 [00:00<?, ?it/s]

[Layer 0] group_sizes=[1, 3, 2, 2, 2, 2, 2, 2], kv_idx=[4, 0, 4, 3, 6, 1, 7, 5, 1, 6, 5, 3, 2, 1, 2, 7]
[Layer 1] group_sizes=[2, 2, 2, 2, 2, 2, 2, 2], kv_idx=[6, 7, 3, 1, 4, 4, 0, 5, 0, 2, 7, 1, 5, 6, 2, 3]
[Layer 2] group_sizes=[1, 2, 2, 2, 3, 2, 2, 2], kv_idx=[1, 5, 3, 4, 5, 1, 2, 6, 2, 0, 7, 4, 6, 3, 4, 7]
[Layer 3] group_sizes=[1, 2, 3, 2, 2, 2, 2, 2], kv_idx=[7, 7, 2, 5, 5, 3, 4, 0, 4, 3, 6, 2, 1, 6, 1, 2]
[Layer 4] group_sizes=[1, 1, 2, 3, 3, 2, 2, 2], kv_idx=[5, 4, 2, 7, 0, 4, 6, 5, 4, 1, 3, 3, 2, 6, 3, 7]
[Layer 5] group_sizes=[2, 2, 2, 2, 2, 2, 2, 2], kv_idx=[7, 7, 1, 4, 3, 4, 0, 6, 3, 2, 6, 2, 1, 5, 0, 5]
[Layer 6] group_sizes=[2, 2, 2, 2, 2, 2, 2, 2], kv_idx=[6, 3, 2, 5, 7, 0, 7, 1, 4, 6, 0, 1, 5, 4, 3, 2]
[Layer 7] group_sizes=[2, 2, 2, 2, 2, 2, 2, 2], kv_idx=[7, 1, 4, 4, 2, 2, 7, 5, 3, 6, 6, 1, 3, 5, 0, 0]
[Layer 8] group_sizes=[2, 2, 2, 2, 2, 2, 2, 2], kv_idx=[1, 0, 3, 4, 3, 4, 5, 6, 0, 7, 5, 1, 6, 7, 2, 2]
[Layer 9] group_sizes=[1, 1, 2, 3, 3, 2, 2, 2], kv_idx=[6, 6, 3,

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (8668 > 2048). Running this sequence through the model will result in indexing errors


AttributeError: 'CaGQAAttention' object has no attribute '_split_heads'

## on my own implementation

In [None]:
from dataclasses import dataclass
import math, torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class GPTconfig:
  """Hyper-parameters shared by all modules of the hand-written GPT model."""
  # Vocabulary size (field name is spelled "vocal_size"); sizes the embedding
  # table and the output projection.
  vocal_size: int = 50257
  # Number of TransformerBlock layers.
  n_layers: int = 16
  # Model / embedding width.
  d_model: int = 1024
  # Hidden width of the feed-forward block.
  d_ff: int = 4096
  # Number of query heads.
  n_head: int = 16
  # Number of positions precomputed for the rotary tables.
  max_seq_len: int = 2048
  # Dropout probability used by the MLP and attention modules.
  dropout: float = 0.0
  # NOTE(review): not read by any module visible in this file.
  attn_type: str = "full"
  # When True the output projection shares its weight with the token embedding.
  tie_weights: bool = False
  # Epsilon added under the square root in RMSNorm.
  norm_eps: float = 1e-5
  # Base of the rotary frequency schedule.
  rope_base: int = 10000
  # Number of key/value heads; a falsy value falls back to n_head.
  n_kvhead: int = 8
  # Feed-forward variant selected in MLP: "gelu" or "swish".
  activation_fn: str = "gelu"

def count_parameters(model: nn.Module):
  """Return the total number of elements in ``model``'s trainable parameters."""
  total = 0
  for param in model.parameters():
    if param.requires_grad:
      total += param.numel()
  return total

class TokenEmbedding(nn.Module):
  """Lookup table that maps token ids to ``d_model``-dimensional vectors."""

  def __init__(self, config: GPTconfig):
    super().__init__()
    # One row per vocabulary entry, initialised from N(0, 0.02^2).
    table = nn.Embedding(config.vocal_size, config.d_model)
    nn.init.normal_(table.weight, mean=0.0, std=0.02)
    self.token_embedding = table

  def forward(self, x):
    # Ids of shape (B, T) come back as embeddings of shape (B, T, D).
    embedded = self.token_embedding(x)
    return embedded

class SwiGLU(nn.Module):
    """Gated feed-forward layer computing ``w3(silu(w1(x)) * w2(x))``.

    All three projections are bias-free and initialised from N(0, 0.02^2).
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)  # value branch
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self._init_weights()

    def _init_weights(self):
        # GPT/LLaMA-style normal initialisation, applied in w1, w2, w3 order.
        for proj in (self.w1, self.w2, self.w3):
            nn.init.normal_(proj.weight, mean=0.0, std=0.02)

    def forward(self, x):
        # SiLU (not sigmoid) gates the value branch element-wise.
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

class RMSNorm(nn.Module):
  """Root-mean-square normalisation over the last dimension with a learned gain.

  Computes ``gamma * x / sqrt(mean(x^2) + eps)``; ``eps`` sits inside the
  square root.
  """

  def __init__(self, config):
    super().__init__()
    self.eps = config.norm_eps
    self.gamma = nn.Parameter(torch.ones(config.d_model))

  def forward(self, x):
    # Mean of squares along the feature axis, kept as a trailing size-1 dim.
    mean_square = x.pow(2).mean(dim=-1, keepdim=True)
    rms = torch.sqrt(mean_square + self.eps)
    scaled = self.gamma * x
    return scaled / rms

class MLP(nn.Module):
  """Position-wise feed-forward block followed by dropout.

  ``config.activation_fn`` selects the variant:
    - ``"gelu"``: ``fc2(gelu(fc1(x)))`` with biased linear layers
      (weights ~ N(0, 0.02^2), biases zero).
    - ``"swish"``: a ``SwiGLU`` block whose hidden width is
      ``int(d_ff * 2 / 3)``.

  Raises:
    ValueError: If ``config.activation_fn`` is neither of the above. Without
      this check the module would be built without any layers and only fail
      later, inside ``forward``.
  """

  def __init__(self, config):
    super().__init__()
    if config.activation_fn == "gelu":
      self.fc1 = nn.Linear(config.d_model, config.d_ff)
      self.fc2 = nn.Linear(config.d_ff, config.d_model)
      self.dropout = nn.Dropout(config.dropout)
      nn.init.normal_(self.fc1.weight, mean=0.0, std=0.02)
      nn.init.normal_(self.fc2.weight, mean=0.0, std=0.02)
      nn.init.zeros_(self.fc1.bias)
      nn.init.zeros_(self.fc2.bias)
    elif config.activation_fn == "swish":
      self.ffn = SwiGLU(config.d_model, int(config.d_ff*2/3))
      self.dropout = nn.Dropout(config.dropout)
    else:
      raise ValueError(
          f"unsupported activation_fn {config.activation_fn!r}; "
          "expected 'gelu' or 'swish'"
      )
    self.activation = config.activation_fn

  def forward(self, x):
    if self.activation == "gelu":
      x = self.fc2(F.gelu(self.fc1(x)))
    else:  # "swish" — the constructor admits nothing else
      x = self.ffn(x)
    return self.dropout(x)


class RotaryEmbedding(nn.Module):
  """Precomputed sine/cosine tables for rotary position embeddings.

  The tables have shape ``(max_seq_len, head_dim / 2)`` and are stored as
  non-persistent buffers, so they are not part of the state dict.
  """

  def __init__(self,config: GPTconfig):
    super().__init__()
    self.config = config
    head_dim = config.d_model // config.n_head
    assert head_dim % 2 == 0, "head_dim must be divisible by 2"
    # One inverse frequency per pair of feature dimensions.
    exponents = torch.arange(0, head_dim, 2)/head_dim
    inv_freq = 1.0 / (config.rope_base ** exponents)
    positions = torch.arange(config.max_seq_len, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]
    self.register_buffer("sin_cached", torch.sin(angles), persistent=False)
    self.register_buffer("cos_cached", torch.cos(angles), persistent=False)

  def forward(self, T, device):
    # Returns the first T rows of each table, in (sin, cos) order.
    sin = self.sin_cached[:T, :].to(device)
    cos = self.cos_cached[:T, :].to(device)
    return sin, cos

def apply_RoPE(x, cos, sin):
  """Rotate consecutive feature pairs of ``x`` by position-dependent angles.

  Args:
    x: Tensor of shape (B, H, T, D).
    cos: Cosine table of shape (T, D/2).
    sin: Sine table of shape (T, D/2).

  Returns:
    Tensor of shape (B, H, T, D) in which each pair ``(x[2i], x[2i+1])`` has
    been rotated and written back to the same two positions.
  """
  even = x[..., 0::2]   # features 0, 2, ..., D-2 -> (B, H, T, D/2)
  odd = x[..., 1::2]    # features 1, 3, ..., D-1 -> (B, H, T, D/2)
  cos_b = cos[None, None, :, :]   # (1, 1, T, D/2), broadcast over batch and heads
  sin_b = sin[None, None, :, :]
  rotated_even = even*cos_b-odd*sin_b
  rotated_odd = even*sin_b+odd*cos_b
  # Re-interleave the two halves along the feature axis.
  return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2)



def _split_heads(x, n_head):
  B, T, D = x.shape
  x = x.view(B, T, n_head, D//n_head).transpose(1,2)
  return x

def _merge_heads(x):
  B, H, T, D =x.shape
  x = x.transpose(1,2).contiguous().view(B, T, D*H)
  return x

class CausalMHA_RoPE_GQA(nn.Module):
  """Causal self-attention with rotary embeddings and grouped KV heads.

  Queries use ``n_head`` heads; keys and values use ``n_kvhead`` heads
  (``n_head`` when the config value is falsy). KV heads are expanded to the
  query-head count either uniformly (``repeat_interleave``) or, when ``kv_idx``
  is given, by an arbitrary mapping: query head ``q`` reads KV head
  ``kv_idx[q]``.

  Differences from the earlier revision of this class:
    * The rotary tables are unpacked explicitly. ``RotaryEmbedding.forward``
      returns ``(sin, cos)`` while ``apply_RoPE`` takes ``(x, cos, sin)``;
      star-unpacking swapped the two. NOTE(review): checkpoints trained with
      the swapped tables are not numerically compatible with this fix.
    * Attention dropout is only applied in training mode;
      ``scaled_dot_product_attention`` applies ``dropout_p`` unconditionally.
    * The non-uniform expansion gathers along the head axis directly instead
      of calling a global ``repeat_kv_nonuniform`` — the file defines that
      name twice with different axis conventions.

  Args:
    config: Model hyper-parameters.
    kv_idx: Optional integer tensor with ``n_head`` entries, each in
      ``[0, n_kvhead)``. Stored as a non-persistent buffer.
  """

  def __init__(self,config: GPTconfig, kv_idx: torch.Tensor = None):
    super().__init__()
    assert config.d_model % config.n_head == 0
    self.n_head = config.n_head
    self.n_kvheads = config.n_kvhead if config.n_kvhead else config.n_head
    assert self.n_head % self.n_kvheads == 0, "n_head must be divisible by n_kvheads"
    self.head_dim = config.d_model // config.n_head
    self.kv_proj_dim = self.head_dim * self.n_kvheads
    self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
    self.k_proj = nn.Linear(config.d_model, self.kv_proj_dim, bias=False)
    self.v_proj = nn.Linear(config.d_model, self.kv_proj_dim, bias=False)
    self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)
    self.attn_dropout = nn.Dropout(config.dropout)
    self.rope_emb = RotaryEmbedding(config)

    for mod in [self.q_proj, self.k_proj, self.v_proj, self.o_proj]:
      nn.init.normal_(mod.weight, mean=0.0, std=0.02)

    if kv_idx is not None:
      kv_idx = torch.as_tensor(kv_idx, dtype=torch.long).reshape(-1)
      assert kv_idx.numel() == config.n_head
      assert 0 <= int(kv_idx.min()) and int(kv_idx.max()) < self.n_kvheads, \
          "kv_idx entries must lie in [0, n_kvheads)"
      self.register_buffer("kv_idx", kv_idx, persistent=False)
    else:
      # No mapping supplied: fall back to uniform GQA in forward.
      self.kv_idx = None

  def forward(self, x):
    # (B, T, D) -> (B, H, T, head_dim) for queries, (B, H_kv, T, head_dim) for K/V.
    q = _split_heads(self.q_proj(x), self.n_head)
    k = _split_heads(self.k_proj(x), self.n_kvheads)
    v = _split_heads(self.v_proj(x), self.n_kvheads)

    # The rotary module hands back (sin, cos); apply_RoPE wants (x, cos, sin).
    sin, cos = self.rope_emb(q.shape[2], q.device)
    q = apply_RoPE(q, cos, sin)
    k = apply_RoPE(k, cos, sin)

    if self.n_kvheads != self.n_head:
      if self.kv_idx is not None:
        # Non-uniform GQA: gather KV heads along the head axis (dim 1).
        head_map = self.kv_idx.to(k.device)
        k = k.index_select(1, head_map)
        v = v.index_select(1, head_map)
      else:
        # Uniform GQA: every KV head serves the same number of query heads.
        repeat_times = self.n_head // self.n_kvheads
        k = torch.repeat_interleave(k, repeat_times, dim=1)
        v = torch.repeat_interleave(v, repeat_times, dim=1)

    dropout_p = self.attn_dropout.p if self.training else 0.0
    y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True)
    y = _merge_heads(y)
    y = self.attn_dropout(self.o_proj(y))
    return y


class TransformerBlock(nn.Module):
  """Pre-norm transformer layer: attention and feed-forward, each with a residual.

  ``kv_idx`` is passed through to the attention module unchanged.
  """

  def __init__(self, config: GPTconfig, kv_idx: torch.Tensor = None):
    super().__init__()
    self.attn = CausalMHA_RoPE_GQA(config, kv_idx)
    self.norm1 = RMSNorm(config)
    self.ffn = MLP(config)
    self.norm2 = RMSNorm(config)

  def forward(self, x):
    # Attention sub-layer on the normalised input, added back to the input.
    attn_out = self.attn(self.norm1(x))
    x = attn_out + x
    # Feed-forward sub-layer, same pattern.
    ffn_out = self.ffn(self.norm2(x))
    return ffn_out + x


class GPTNeoX(nn.Module):
  """Decoder-only language model: embedding, transformer stack, norm, LM head.

  Args:
    config: Model hyper-parameters.
    all_kv_indices: Optional per-layer ``kv_idx`` tensors. Layer ``i`` uses
      ``all_kv_indices[i]`` when it exists; layers beyond the end of the list
      (or all layers, when it is ``None``) use uniform GQA.
  """

  def __init__(self, config: GPTconfig, all_kv_indices: list[torch.Tensor] = None):
    super().__init__()
    self.config = config
    self.token_embedder = TokenEmbedding(config)
    self.transformers = nn.ModuleList()
    for i in range(config.n_layers):
      has_mapping = all_kv_indices is not None and i < len(all_kv_indices)
      kv_idx = all_kv_indices[i] if has_mapping else None
      self.transformers.append(TransformerBlock(config, kv_idx))
    self.norm = RMSNorm(config)
    self.proj_to_vocab = nn.Linear(config.d_model, config.vocal_size, bias=False)
    nn.init.normal_(self.proj_to_vocab.weight, mean=0.0, std=0.02)

    if config.tie_weights:
      # Share one weight matrix between the LM head and the token embedding.
      self.proj_to_vocab.weight = self.token_embedder.token_embedding.weight

  def forward(self, input_ids, labels=None):
    """Return ``(logits, loss)``; ``loss`` is ``None`` when no labels are given.

    ``labels`` now defaults to ``None`` so the model can be called for
    inference with ``input_ids`` alone.

    NOTE(review): logits and labels are compared position by position, with
    no shift inside the model. This assumes the caller supplies labels that
    are already shifted to the next token — confirm against the training loop.
    """
    x = self.token_embedder(input_ids)
    for transformer in self.transformers:
      x = transformer(x)
    x = self.norm(x)
    logits = self.proj_to_vocab(x)

    loss = None
    if labels is not None:
      # Positions labelled -100 are excluded from the loss.
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)

    return logits, loss



In [None]:
# ======================== 核心函数：非均匀扩展 K/V ========================
def repeat_kv_nonuniform(k_or_v, kv_idx):
    """Expand grouped K (or V) heads to one head per query head.

    Unlike uniform GQA repetition, each query head picks its KV head through
    an explicit lookup table, so groups may have different sizes.

    Args:
        k_or_v: tensor laid out as [B, H_kv, T, D] (batch, KV heads, time, dim).
        kv_idx: index tensor of shape [H_q]; entry ``q`` is the KV head that
            query head ``q`` reads from.

    Returns:
        Tensor of shape [B, H_q, T, D] gathered along the head axis.
    """
    # Unpacking exactly four sizes doubles as a rank check on the input.
    _batch, _kv_heads, _seq_len, _head_dim = k_or_v.shape

    # The lookup table must live on the same device as the tensor it indexes.
    head_map = kv_idx.to(k_or_v.device)

    # Advanced indexing on dim 1 gathers the selected KV head for every query head.
    return k_or_v[:, head_map]
# =========================================================================

# ============================================================
# (0) load importance (24 layers × 16 heads)
# ============================================================
# Load the precomputed per-head importance scores from disk.
# NOTE(review): presumably written by the importance-score analysis earlier in
# this notebook — confirm. torch.load unpickles the file, so only load files
# you produced yourself.
imp_obj = torch.load("kv_head_importance_neox.pt")
importance_scores = imp_obj["importance_scores"]    # list of 24 tensors [16]
print("Loaded importance scores from kv_head_importance_neox.pt")
# The "head_dim=" label below actually reports the shape of the first layer's
# score tensor (i.e. the number of heads), not a per-head dimension.
print(f"num_layers={len(importance_scores)}, head_dim={importance_scores[0].shape}")


# ============================================================
# (1) helper: importance → head 分组 → kv_idx
# ============================================================

def assign_kv_groups_from_importance(
    head_imp: torch.Tensor,
    num_q_heads: int = 16,
    num_kv_heads: int = 8,
):
    """Partition one layer's query heads into KV groups based on importance.

    Strategy:
      1. Sort heads by importance, highest first.
      2. Seed each of the ``num_kv_heads`` groups with one of the top heads.
      3. Hand out the remaining heads, most important first, always to the
         group whose accumulated importance is currently the smallest.

    Args:
        head_imp: importance score per query head (tensor or sequence holding
            ``num_q_heads`` values). It is detached, moved to the CPU and
            flattened before use, so GPU tensors and [1, H] inputs work.
        num_q_heads: number of query heads in the layer.
        num_kv_heads: number of KV groups to build
            (``1 <= num_kv_heads <= num_q_heads``).

    Returns:
        A ``(kv_idx, group_sizes, groups)`` tuple:
          - ``kv_idx``: long tensor [num_q_heads]; entry ``q`` is the KV group
            (``0..num_kv_heads-1``) used by query head ``q``.
          - ``group_sizes``: long tensor [num_kv_heads] with the head count of
            every group.
          - ``groups``: list of ``num_kv_heads`` lists with the query-head ids
            of each group, in the order they were assigned.

    Raises:
        ValueError: if the head counts are inconsistent with each other or
            with the number of scores in ``head_imp``.
    """
    if not 1 <= num_kv_heads <= num_q_heads:
        raise ValueError(
            f"num_kv_heads must be in [1, num_q_heads={num_q_heads}], got {num_kv_heads}"
        )

    # Work on a flat CPU copy: the accumulator below lives on the CPU, and
    # adding a CUDA scalar into a CPU tensor in place is not allowed.
    head_imp = torch.as_tensor(head_imp, dtype=torch.float32).detach().cpu().reshape(-1)
    if head_imp.numel() != num_q_heads:
        raise ValueError(
            f"expected {num_q_heads} importance scores, got {head_imp.numel()}"
        )

    # 1) Head ids ordered by descending importance.
    ranked = torch.argsort(head_imp, descending=True).tolist()
    seeds, rest = ranked[:num_kv_heads], ranked[num_kv_heads:]

    # 2) Every group starts with one of the most important heads.
    groups = [[h] for h in seeds]
    group_imp_sum = head_imp[seeds]  # advanced indexing returns a fresh tensor

    # 3) Greedily place the remaining heads into the currently lightest group.
    for h in rest:
        g = int(torch.argmin(group_imp_sum))
        groups[g].append(h)
        group_imp_sum[g] += head_imp[h]

    # 4) Turn the grouping into lookup tensors.
    kv_idx = torch.empty(num_q_heads, dtype=torch.long)
    for g, heads in enumerate(groups):
        kv_idx[heads] = g  # query head h -> KV head g
    group_sizes = torch.tensor([len(heads) for heads in groups], dtype=torch.long)

    # Internal sanity check: every query head was assigned exactly once.
    assert group_sizes.sum().item() == num_q_heads
    return kv_idx, group_sizes, groups





Loaded importance scores from kv_head_importance_neox.pt
num_layers=24, head_dim=torch.Size([16])


In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os, math, time, json, random
from typing import Iterable, List

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset as TorchIterable

from datasets import load_dataset, interleave_datasets, IterableDataset
from transformers import AutoTokenizer
from contextlib import nullcontext

# ================== 新增导入 ==================
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np

# ================== CONFIG (edit here) ==================
# Global run configuration read by run_training() below. Edit values here;
# there is no command-line parsing.
args = {
    # Model
    'n_layers': 24,
    'n_head': 16,                  # query heads per layer
    'n_kvhead': 8,                 # KV heads (groups) per layer
    'd_model': 768,
    'd_ff': 3072,
    'max_seq_len': 2048,           # also used as the packing block size and tokenizer max length
    'dropout': 0.0,
    'tie_weights': False,

    # Train
    'batch_size': 8,               # sequences per micro-batch
    'grad_accum': 4,               # micro-batches accumulated per optimizer step
    'lr': 2e-4,
    'weight_decay': 0.1,
    'warmup_ratio': 0.1,           # fraction of max_steps spent in linear warmup
    'max_steps': 40000,            # try 300 for smoke test
    'clip_grad': 1.0,              # max gradient norm passed to clip_grad_norm_
    'seed': 42,

    # IO
    'output_dir': 'ckpt-colab-300m-imp',
    'save_every': 2000,            # optimizer steps between checkpoints (0 disables)
    'log_every': 50,               # optimizer steps between log lines
    'eval_every': 500,             # optimizer steps between evaluations (0 disables)
    'resume': '',                  # '', 'latest', or path/to/checkpoint.pt

    # Data
    'buffer_chars': 3_000_000,     # increase for more throughput
    'val_docs': 128,               # number of documents collected for the eval set
}

# ================== 新增: 训练跟踪器 ==================
class TrainingTracker:
    """Collect training/evaluation metrics, plot them and persist them as JSON.

    Attributes:
        train_losses / steps: parallel lists, one entry per recorded train loss.
        eval_losses / eval_ppls / eval_steps: parallel lists, one entry per
            evaluation. ``eval_steps`` stores the step at which each evaluation
            ran so that eval curves are drawn at the right x position.
    """

    def __init__(self, output_dir):
        self.output_dir = output_dir
        self.train_losses = []
        self.eval_losses = []
        self.eval_ppls = []
        self.steps = []
        self.eval_steps = []

        # Directory that receives the generated figures.
        self.plot_dir = os.path.join(output_dir, 'plots')
        os.makedirs(self.plot_dir, exist_ok=True)

    def __setstate__(self, state):
        # Trackers pickled before ``eval_steps`` existed lack that attribute.
        self.__dict__.update(state)
        self.__dict__.setdefault('eval_steps', [])

    def add_train_loss(self, step, loss):
        """Record one training loss together with its step."""
        self.train_losses.append(loss)
        self.steps.append(step)

    def add_eval_loss(self, step, loss, ppl):
        """Record one evaluation result together with the step it ran at."""
        self.eval_losses.append(loss)
        self.eval_ppls.append(ppl)
        self.eval_steps.append(step)

    def _eval_x(self, n):
        """Return x coordinates for ``n`` evaluation points."""
        if len(self.eval_steps) == n:
            return self.eval_steps
        # Legacy history without recorded eval steps: fall back to the old
        # approximation (first n training steps).
        return self.steps[:n]

    def plot_losses(self):
        """Save a two-panel figure (losses, eval perplexity) into ``plot_dir``.

        Does nothing until at least two training losses were recorded.
        """
        if len(self.train_losses) < 2:
            return

        plt.figure(figsize=(12, 4))

        # Left panel: train and eval loss.
        plt.subplot(1, 2, 1)
        plt.plot(self.steps[:len(self.train_losses)], self.train_losses, 'b-', label='Train Loss')
        if self.eval_losses:
            xs = self._eval_x(len(self.eval_losses))
            plt.plot(xs, self.eval_losses[:len(xs)], 'r-', label='Eval Loss')
        plt.xlabel('Step')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.title('Training and Evaluation Loss')

        # Right panel: eval perplexity on a log scale.
        plt.subplot(1, 2, 2)
        if self.eval_ppls:
            xs = self._eval_x(len(self.eval_ppls))
            plt.plot(xs, self.eval_ppls[:len(xs)], 'g-', label='Eval PPL')
            plt.xlabel('Step')
            plt.ylabel('Perplexity')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.title('Evaluation Perplexity')
            plt.yscale('log')

        plt.tight_layout()
        plot_path = os.path.join(self.plot_dir, f'training_plot_step_{self.steps[-1]}.png')
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        plt.close()

    def save(self):
        """Write the metric history to ``<output_dir>/training_history.json``."""
        data = {
            'train_losses': self.train_losses,
            'eval_losses': self.eval_losses,
            'eval_ppls': self.eval_ppls,
            'steps': self.steps,
            'eval_steps': self.eval_steps,
        }
        with open(os.path.join(self.output_dir, 'training_history.json'), 'w') as f:
            json.dump(data, f, indent=2)

    def load(self):
        """Restore the metric history from JSON if the file exists."""
        history_path = os.path.join(self.output_dir, 'training_history.json')
        if os.path.exists(history_path):
            with open(history_path, 'r') as f:
                data = json.load(f)
            self.train_losses = data['train_losses']
            self.eval_losses = data['eval_losses']
            self.eval_ppls = data['eval_ppls']
            self.steps = data['steps']
            # Older history files do not contain this key.
            self.eval_steps = data.get('eval_steps', [])

# ================== Utilities ==================

def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``."""
    import numpy as np

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

class SpeedEMA:
    """Exponential moving average of a stream of numbers.

    The first observation initialises the average; every later one is blended
    in with weight ``alpha``.
    """

    def __init__(self, alpha=0.1):
        self.alpha = alpha
        self.v = None

    def update(self, x):
        """Fold ``x`` into the average and return the new value."""
        if self.v is None:
            self.v = x
        else:
            self.v = (1 - self.alpha) * self.v + self.alpha * x
        return self.v

# ================== Data (stream packer) ==================

def stream_packer(stream: Iterable, tokenizer, block_size=2048, buffer_chars=3_000_000):
    """Pack a stream of documents into fixed-length token blocks.

    Texts are collected (each followed by a newline) until at least
    ``buffer_chars`` characters are buffered. The buffer is then tokenized in
    one call, a single EOS id is appended, and the ids are cut into blocks of
    exactly ``block_size`` tokens. Tokens that do not fill a whole block are
    dropped. When the stream ends, whatever is still buffered is packed the
    same way, so short or finite streams are not silently lost.

    Args:
        stream: iterable of mappings; the text is read from the ``"text"`` key,
            falling back to ``"content"``. Entries without text are skipped.
        tokenizer: callable returning a mapping with ``"input_ids"`` and
            exposing ``eos_token_id``.
        block_size: number of tokens per yielded block.
        buffer_chars: minimum number of buffered characters per tokenizer call.

    Yields:
        ``{"input_ids": [...]}`` dicts holding ``block_size`` token ids each.
    """

    def _pack(parts):
        # Same text the incremental concatenation would have produced.
        text = "\n".join(parts) + "\n"
        ids = list(tokenizer(text, return_attention_mask=False, add_special_tokens=False)["input_ids"])
        ids.append(tokenizer.eos_token_id)
        full = (len(ids) // block_size) * block_size
        for start in range(0, full, block_size):
            yield {"input_ids": ids[start:start + block_size]}

    # Collect pieces in a list and join once; repeated ``str +=`` re-copies
    # the growing buffer.
    parts, n_chars = [], 0
    for ex in stream:
        text = ex.get("text") or ex.get("content") or ""
        if not text:
            continue
        parts.append(text)
        n_chars += len(text) + 1  # +1 for the newline separator
        if n_chars >= buffer_chars:
            yield from _pack(parts)
            parts, n_chars = [], 0

    # Flush the remainder once the stream is exhausted.
    if parts:
        yield from _pack(parts)

class PackedIterable(TorchIterable):
    """Torch iterable dataset that yields packed token blocks.

    Wraps a document stream and delegates to ``stream_packer`` every time it
    is iterated.
    """

    def __init__(self, hf_iterable, tokenizer, block_size=2048, buffer_chars=3_000_000):
        super().__init__()
        self.hf_iterable, self.tokenizer = hf_iterable, tokenizer
        self.block_size, self.buffer_chars = block_size, buffer_chars

    def __iter__(self):
        yield from stream_packer(
            self.hf_iterable,
            self.tokenizer,
            self.block_size,
            self.buffer_chars,
        )

def lm_collate(batch, pad_id: int, block_size: int):
    """Collate token sequences into next-token-prediction tensors.

    Every sequence is truncated or right-padded with ``pad_id`` to
    ``block_size``. Labels are the inputs shifted left by one position
    (``labels[:, t] = input_ids[:, t + 1]``); the last position and every
    position whose target would fall into the padding are set to ``-100``.

    Padding is identified by sequence length rather than by token value.
    Callers may pass the EOS id as ``pad_id``; masking by value would then
    also hide genuine EOS targets inside the sequence, which the evaluation
    code in this file does not do.

    Args:
        batch: list of mappings with an ``"input_ids"`` sequence.
        pad_id: id used to fill short sequences.
        block_size: output sequence length.

    Returns:
        ``{"input_ids": LongTensor[B, block_size], "labels": LongTensor[B, block_size]}``.
    """
    n_rows = len(batch)
    x = torch.full((n_rows, block_size), pad_id, dtype=torch.long)
    lengths = torch.zeros(n_rows, dtype=torch.long)
    for row, ex in enumerate(batch):
        ids = torch.as_tensor(ex["input_ids"], dtype=torch.long)[:block_size]
        x[row, :ids.numel()] = ids
        lengths[row] = ids.numel()

    labels = torch.full_like(x, -100)
    labels[:, :-1] = x[:, 1:]

    # Position t predicts token t + 1; that target is real only if t + 1 < length.
    target_pos = torch.arange(block_size).unsqueeze(0) + 1
    labels[target_pos >= lengths.unsqueeze(1)] = -100

    return {"input_ids": x, "labels": labels}


# ================== Scheduler (cosine + warmup) ==================

def build_scheduler(optimizer, total_steps: int, warmup_ratio: float):
    """Create a linear-warmup + cosine-decay learning-rate schedule.

    The multiplier rises linearly from 0 to 1 over the first
    ``int(warmup_ratio * total_steps)`` steps, then follows half a cosine wave
    down to 0 at ``total_steps``. Steps beyond ``total_steps`` stay at 0;
    without the clamp the cosine would swing back up if the scheduler were
    stepped past the planned horizon.

    Args:
        optimizer: optimizer whose base learning rates are scaled.
        total_steps: planned number of scheduler steps.
        warmup_ratio: fraction of ``total_steps`` used for warmup.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    warmup = int(warmup_ratio * total_steps)
    decay_span = max(1, total_steps - warmup)

    def lr_lambda(step):
        if warmup > 0 and step < warmup:
            return step / max(1, warmup)
        if total_steps <= warmup:
            # The whole run is warmup; hold the base learning rate afterwards.
            return 1.0
        prog = min(1.0, (step - warmup) / decay_span)
        return 0.5 * (1.0 + math.cos(math.pi * prog))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# ================== Eval (PPL on small wiki sample) ==================
@torch.no_grad()
def eval_ppl(model, tokenizer, texts: List[str], block_size=2048, device="cuda", use_bf16=True):
    model.eval()
    ids: List[int] = []
    for t in texts:
        ids += tokenizer(t, add_special_tokens=False)["input_ids"] + [tokenizer.eos_token_id]
    full = (len(ids) // block_size) * block_size
    if full == 0:
        return float("nan"), float("nan")
    ids = torch.tensor(ids[:full], dtype=torch.long, device=device).view(-1, block_size)
    total_loss = 0.0
    amp_dtype = torch.bfloat16 if use_bf16 else torch.float16

    for i in range(ids.size(0)):
        inp = ids[i].unsqueeze(0)

        # 关键修复：创建正确的 shifted labels
        labels = torch.full_like(inp, -100)
        labels[:, :-1] = inp[:, 1:]  # 与训练时保持一致：labels = input_ids 右移一位

        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=(device=="cuda")):
            _, loss = model(inp, labels=labels)  # 现在 labels 是正确的移位版本

        total_loss += float(loss)

    model.train()
    mean_loss = total_loss / ids.size(0)
    return math.exp(mean_loss), mean_loss


# ================== Main (no argparse) ==================

def pick_amp_dtype():
    """Choose the autocast dtype for the current machine.

    Returns:
        ``(dtype, use_bf16)``: bfloat16 and ``True`` when a CUDA device with
        compute capability major version >= 8 is present, otherwise float16
        and ``False`` (this includes CPU-only machines).
    """
    if not torch.cuda.is_available():
        return torch.float16, False

    # Rough rule of thumb: compute capability 8.x and newer handle bf16 well.
    major, _minor = torch.cuda.get_device_capability()
    bf16_ok = major >= 8
    if bf16_ok:
        return torch.bfloat16, bf16_ok
    return torch.float16, bf16_ok


def _build_kv_indices(method: str):
    """Return the per-layer query-head -> KV-head mapping for ``method``."""
    if method == "GQA":
        # Uniform grouping: every layer lets the attention module decide.
        return [None for _ in range(args['n_layers'])]
    if method != "importance_score":
        raise ValueError(f"Unknown method: {method}")

    kv_indices = []
    for layer_id in range(args['n_layers']):
        kv_idx, group_sizes, _groups = assign_kv_groups_from_importance(
            importance_scores[layer_id],
            num_q_heads=args['n_head'],
            num_kv_heads=args['n_kvhead'],
        )
        print(f"[Layer {layer_id}] group_sizes={group_sizes.tolist()}, kv_idx={kv_idx.tolist()}")
        kv_indices.append(torch.as_tensor(kv_idx.tolist(), dtype=torch.long))
    return kv_indices


def _build_optimizer(model):
    """Create AdamW with weight decay on >=2-D non-embedding parameters only."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if param.ndim >= 2 and "embedding" not in name:
            decay.append(param)
        else:
            # Biases, norm weights and embeddings are not decayed.
            no_decay.append(param)
    optim_groups = [
        {"params": decay, "weight_decay": args['weight_decay']},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(optim_groups, lr=args['lr'], betas=(0.9, 0.95), eps=1e-8)


def _load_val_texts(n_docs: int):
    """Collect up to ``n_docs`` non-empty Wikipedia texts; ``[]`` on failure."""
    try:
        stream = load_dataset("wikimedia/wikipedia", "20231101.en", split="train", streaming=True)
        texts = []
        for ex in stream:
            text = ex.get('text')
            if text:
                texts.append(text)
                if len(texts) >= n_docs:
                    break
        return texts
    except Exception as err:
        # Evaluation is best-effort: training continues without it, but say so.
        print(f"[eval] could not load validation texts ({err!r}); periodic eval is disabled")
        return []


def _resolve_resume_path(resume: str, output_dir: str):
    """Translate the ``resume`` setting into a checkpoint path (or ``None``)."""
    if resume != 'latest':
        return resume

    prefix, suffix = 'checkpoint-', '.pt'
    numbered = []
    for fname in os.listdir(output_dir):
        if fname.startswith(prefix) and fname.endswith(suffix):
            tag = fname[len(prefix):-len(suffix)]
            # Skip non-numeric names such as "checkpoint-latest.pt".
            if tag.isdigit():
                numbered.append((int(tag), fname))
    if numbered:
        return os.path.join(output_dir, max(numbered)[1])

    fallback = os.path.join(output_dir, 'checkpoint-latest.pt')
    return fallback if os.path.isfile(fallback) else None


def _checkpoint_state(model, optimizer, scheduler, step, tracker):
    """Bundle everything needed to resume training into one dict."""
    return {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'step': step,
        'tracker': tracker,
        'config': args,
    }


def run_training():
    """Pre-train the model described by the global ``args`` dict.

    Builds tokenizer, model, streaming data pipeline, optimizer and scheduler,
    optionally resumes from a checkpoint, then trains with gradient
    accumulation until ``args['max_steps']`` optimizer steps are reached (or
    the data stream ends). Logs, evaluates and checkpoints periodically, and
    writes the final model, checkpoint, tokenizer and arguments to
    ``<output_dir>/final``.
    """
    set_seed(args['seed'])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tracker = TrainingTracker(args['output_dir'])

    # Tokenizer
    tok = AutoTokenizer.from_pretrained('gpt2')
    tok.model_max_length = args['max_seq_len']
    tok.pad_token = tok.eos_token

    # Model (GPTconfig / GPTNeoX are defined elsewhere in this file)
    cfg = GPTconfig(
        vocal_size=len(tok),
        n_layers=args['n_layers'],
        n_head=args['n_head'],
        n_kvhead=args['n_kvhead'],
        d_model=args['d_model'],
        d_ff=args['d_ff'],
        max_seq_len=args['max_seq_len'],
        dropout=args['dropout'],
        tie_weights=args['tie_weights'],
    )

    # Grouping strategy: "GQA" (uniform) or "importance_score" (non-uniform).
    now_method = "GQA"
    kv_indices = _build_kv_indices(now_method)

    model = GPTNeoX(cfg, kv_indices).to(device)

    # Data streams
    fw_edu = load_dataset('HuggingFaceFW/fineweb-edu', split='train', streaming=True)
    fw     = load_dataset('HuggingFaceFW/fineweb',     split='train', streaming=True)
    train_stream = interleave_datasets([fw_edu, fw], probabilities=[0.2, 0.8], seed=args['seed'])

    train_packed = PackedIterable(train_stream, tok, block_size=args['max_seq_len'], buffer_chars=args['buffer_chars'])
    collate_fn = lambda batch: lm_collate(batch, pad_id=tok.eos_token_id, block_size=args['max_seq_len'])
    loader = DataLoader(train_packed, batch_size=args['batch_size'], collate_fn=collate_fn, num_workers=2)

    # Optimizer / schedule
    optimizer = _build_optimizer(model)
    scheduler = build_scheduler(optimizer, total_steps=args['max_steps'], warmup_ratio=args['warmup_ratio'])

    # Mixed precision.
    # NOTE(review): when float16 is selected no gradient scaler is used; small
    # gradients may underflow on GPUs without bf16 support.
    amp_dtype, use_bf16 = pick_amp_dtype()

    # Eval set
    val_texts = _load_val_texts(args['val_docs'])

    # Resume
    os.makedirs(args['output_dir'], exist_ok=True)
    global_step = 0
    if args['resume']:
        ckpt_path = _resolve_resume_path(args['resume'], args['output_dir'])
        if ckpt_path and os.path.isfile(ckpt_path):
            # Checkpoints written by this script contain the pickled tracker
            # object, which the weights-only loader rejects. Only resume from
            # checkpoints you created yourself.
            state = torch.load(ckpt_path, map_location='cpu', weights_only=False)
            model.load_state_dict(state['model'])
            optimizer.load_state_dict(state['optimizer'])
            scheduler.load_state_dict(state['scheduler'])
            global_step = int(state.get('step', 0))
            if 'tracker' in state:
                tracker = state['tracker']
            else:
                tracker.load()  # fall back to the JSON history file
            print(f"[resume] loaded {ckpt_path} at step {global_step}")
        else:
            print(f"[resume] no checkpoint found for resume={args['resume']!r}; starting from scratch")
    # NOTE: the data stream is not fast-forwarded on resume; it restarts from
    # its beginning.

    pbar = tqdm(total=args['max_steps'], initial=global_step, desc="Training",
                unit="step", dynamic_ncols=True, position=0)

    # Training
    model.train()
    tokens_per_step = args['batch_size'] * args['max_seq_len'] * args['grad_accum']
    ema = SpeedEMA(0.1)
    t0 = time.time()
    accum = 0
    accumulated_loss = 0.0

    for batch in loader:
        if global_step >= args['max_steps']:
            break

        input_ids = batch['input_ids'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)

        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=(device=='cuda')):
            _, loss = model(input_ids, labels=labels)
            loss = loss / args['grad_accum']  # normalise for gradient accumulation

        loss.backward()
        accumulated_loss += loss.item()
        accum += 1

        if accum % args['grad_accum'] != 0:
            continue

        torch.nn.utils.clip_grad_norm_(model.parameters(), args['clip_grad'])
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        # Each micro-batch loss was already divided by grad_accum, so the sum
        # is the mean loss of this optimizer step.
        current_loss = accumulated_loss
        accumulated_loss = 0.0

        lr = scheduler.get_last_lr()[0]
        pbar.set_postfix({
            'loss': f'{current_loss:.4f}',
            'lr': f'{lr:.2e}',
            'step': f'{global_step}/{args["max_steps"]}'
        })

        global_step += 1
        pbar.update(1)

        # Record the loss exactly once per optimizer step.
        tracker.add_train_loss(global_step, current_loss)

        # logging
        if global_step % args['log_every'] == 0:
            dt = time.time() - t0
            # dt spans log_every optimizer steps, so scale the token count to match.
            tps = tokens_per_step * args['log_every'] / max(dt, 1e-6)
            ema_tps = ema.update(tps)

            remaining_steps = args['max_steps'] - global_step
            eta_seconds = remaining_steps * (dt / args['log_every'])
            eta_str = time.strftime("%H:%M:%S", time.gmtime(eta_seconds))

            print(f"\nstep {global_step:6d} | loss {current_loss:.4f} | lr {lr:.2e} | "
                  f"tokens/s {ema_tps:,.0f} | ETA {eta_str} | {(global_step/args['max_steps']):.1%}")
            t0 = time.time()

        # eval
        if args['eval_every'] and (global_step % args['eval_every'] == 0) and val_texts:
            ppl, eval_loss = eval_ppl(model, tok, val_texts, block_size=args['max_seq_len'], device=device, use_bf16=use_bf16)
            tracker.add_eval_loss(global_step, eval_loss, ppl)
            print(f"[eval] step {global_step} | ppl {ppl:.2f} | eval_loss {eval_loss:.4f}")
            tracker.plot_losses()

        # save checkpoint
        if args['save_every'] and (global_step % args['save_every'] == 0):
            ckpt_path = os.path.join(args['output_dir'], f"checkpoint-{global_step}.pt")
            payload = _checkpoint_state(model, optimizer, scheduler, global_step, tracker)
            torch.save(payload, ckpt_path)

            # Keep a copy under a fixed name as well.
            latest_path = os.path.join(args['output_dir'], "checkpoint-latest.pt")
            torch.save(payload, latest_path)

            # Persist the metric history as a separate JSON file.
            tracker.save()

            print(f"[save] {ckpt_path}")

    pbar.close()

    # final save
    final_dir = os.path.join(args['output_dir'], 'final')
    os.makedirs(final_dir, exist_ok=True)

    torch.save(model.state_dict(), os.path.join(final_dir, 'model.pt'))
    torch.save(
        _checkpoint_state(model, optimizer, scheduler, global_step, tracker),
        os.path.join(final_dir, 'checkpoint-final.pt'),
    )

    # tokenizer and run arguments
    tok.save_pretrained(final_dir)
    with open(os.path.join(final_dir, 'train_args.json'), 'w') as f:
        json.dump(args, f, indent=2)

    tracker.plot_losses()
    tracker.save()

    print(f'Pretrain completed! Final model saved to {final_dir}')

# ================== Kick off ==================
# Start training only when this file/cell is executed as the main program.
if __name__ == "__main__":
    run_training()

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/2410 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/27468 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/27468 [00:00<?, ?it/s]

README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/41 [00:00<?, ?it/s]

Training:   0%|          | 0/40000 [00:00<?, ?step/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (679041 > 2048). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (675570 > 2048). Running this sequence through the model will result in indexing errors



step     50 | loss 10.0323 | lr 2.50e-06 | tokens/s 731 | ETA 19:54:21 | 0.1%

step    100 | loss 9.5164 | lr 5.00e-06 | tokens/s 742 | ETA 17:18:16 | 0.2%

step    150 | loss 9.1142 | lr 7.50e-06 | tokens/s 749 | ETA 17:50:59 | 0.4%

step    200 | loss 8.7174 | lr 1.00e-05 | tokens/s 758 | ETA 17:18:50 | 0.5%

step    250 | loss 8.3645 | lr 1.25e-05 | tokens/s 763 | ETA 17:45:07 | 0.6%

step    300 | loss 7.9967 | lr 1.50e-05 | tokens/s 771 | ETA 17:15:44 | 0.8%

step    350 | loss 7.5827 | lr 1.75e-05 | tokens/s 775 | ETA 17:46:25 | 0.9%

step    400 | loss 7.4166 | lr 2.00e-05 | tokens/s 781 | ETA 17:14:41 | 1.0%

step    450 | loss 7.2176 | lr 2.25e-05 | tokens/s 787 | ETA 17:09:27 | 1.1%


Token indices sequence length is longer than the specified maximum sequence length for this model (8668 > 2048). Running this sequence through the model will result in indexing errors



step    500 | loss 7.0144 | lr 2.50e-05 | tokens/s 789 | ETA 17:41:16 | 1.2%
[eval] step 500 | ppl 1546.89 | eval_loss 7.3440

step    550 | loss 6.8935 | lr 2.75e-05 | tokens/s 778 | ETA 21:16:05 | 1.4%

step    600 | loss 6.8243 | lr 3.00e-05 | tokens/s 782 | ETA 17:36:20 | 1.5%

step    650 | loss 6.6659 | lr 3.25e-05 | tokens/s 787 | ETA 17:10:56 | 1.6%

step    700 | loss 6.4458 | lr 3.50e-05 | tokens/s 789 | ETA 17:38:08 | 1.8%

step    750 | loss 6.4388 | lr 3.75e-05 | tokens/s 792 | ETA 17:34:54 | 1.9%

step    800 | loss 6.4409 | lr 4.00e-05 | tokens/s 796 | ETA 17:02:29 | 2.0%

step    850 | loss 6.2858 | lr 4.25e-05 | tokens/s 798 | ETA 17:32:55 | 2.1%

step    900 | loss 6.3099 | lr 4.50e-05 | tokens/s 802 | ETA 16:59:41 | 2.2%

step    950 | loss 6.2393 | lr 4.75e-05 | tokens/s 803 | ETA 17:31:05 | 2.4%

step   1000 | loss 6.2395 | lr 5.00e-05 | tokens/s 806 | ETA 16:57:13 | 2.5%
[eval] step 1000 | ppl 696.84 | eval_loss 6.5465

step   1050 | loss 6.0064 | lr 5.25e-05 | t

'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 89f60df5-78ee-4551-b272-7bf39665cfd4)')' thrown while requesting GET https://huggingface.co/datasets/HuggingFaceFW/fineweb/resolve/9bb295ddab0e05d785b879661af7260fed5140fc/data/CC-MAIN-2013-20/000_00000.parquet
Retrying in 1s [Retry 1/5].



step   7050 | loss 4.0368 | lr 1.96e-04 | tokens/s 755 | ETA 18:35:43 | 17.6%

step   7100 | loss 3.8954 | lr 1.96e-04 | tokens/s 755 | ETA 15:48:13 | 17.8%

step   7150 | loss 3.9317 | lr 1.96e-04 | tokens/s 760 | ETA 15:00:28 | 17.9%

step   7200 | loss 3.8445 | lr 1.96e-04 | tokens/s 761 | ETA 15:21:59 | 18.0%

step   7250 | loss 3.8883 | lr 1.96e-04 | tokens/s 763 | ETA 15:25:10 | 18.1%

step   7300 | loss 3.9081 | lr 1.96e-04 | tokens/s 763 | ETA 15:28:10 | 18.2%

step   7350 | loss 3.9325 | lr 1.96e-04 | tokens/s 764 | ETA 15:25:12 | 18.4%

step   7400 | loss 3.9444 | lr 1.96e-04 | tokens/s 767 | ETA 14:55:30 | 18.5%

step   7450 | loss 3.8539 | lr 1.96e-04 | tokens/s 768 | ETA 15:20:16 | 18.6%

step   7500 | loss 3.9597 | lr 1.95e-04 | tokens/s 768 | ETA 15:18:13 | 18.8%
[eval] step 7500 | ppl 62.12 | eval_loss 4.1291

step   7550 | loss 3.8597 | lr 1.95e-04 | tokens/s 756 | ETA 18:12:44 | 18.9%

step   7600 | loss 3.8071 | lr 1.95e-04 | tokens/s 756 | ETA 15:38:15 | 19.0%

ste

'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 34361d69-59d3-4d6d-a96b-b2420a314eb4)')' thrown while requesting GET https://huggingface.co/datasets/HuggingFaceFW/fineweb/resolve/9bb295ddab0e05d785b879661af7260fed5140fc/data/CC-MAIN-2013-20/000_00000.parquet
Retrying in 1s [Retry 1/5].



step  14050 | loss 3.6566 | lr 1.64e-04 | tokens/s 746 | ETA 17:37:56 | 35.1%

step  14100 | loss 3.5978 | lr 1.64e-04 | tokens/s 751 | ETA 11:53:23 | 35.2%

step  14150 | loss 3.5914 | lr 1.63e-04 | tokens/s 754 | ETA 12:06:12 | 35.4%

step  14200 | loss 3.6875 | lr 1.63e-04 | tokens/s 756 | ETA 12:08:26 | 35.5%

step  14250 | loss 3.6901 | lr 1.63e-04 | tokens/s 758 | ETA 12:06:35 | 35.6%

step  14300 | loss 3.7380 | lr 1.62e-04 | tokens/s 759 | ETA 12:05:02 | 35.8%

step  14350 | loss 3.6841 | lr 1.62e-04 | tokens/s 763 | ETA 11:42:36 | 35.9%

step  14400 | loss 3.6358 | lr 1.62e-04 | tokens/s 764 | ETA 12:03:08 | 36.0%

step  14450 | loss 3.6437 | lr 1.61e-04 | tokens/s 765 | ETA 12:01:29 | 36.1%

step  14500 | loss 3.5631 | lr 1.61e-04 | tokens/s 766 | ETA 11:58:38 | 36.2%


KeyboardInterrupt: 

TODO: 试试分析一下自己训练得到的 pretrained 权重？
现在的对比还不公平（fair），因为两者的 model size 不一样（大模型吃亏）。

- 训练一个 MHA 模型，看训练到多少 step 时的 importance score 就已经最接近最终结果；越早越好（说明越 informative）
- 看一下上一次 pretrain（12 层）的权重有没有什么特征
- 改进分组策略？
- 试试不同层使用不同的 n_kv？比如前几层可以少一些