# Pythia-2.8B Baseline & KVPress 压缩实验

本 Notebook 旨在复现与评估无训练 KV Cache 压缩方法在 **Pythia-2.8B** 模型上的效果。

**实验设置：**
*   **模型**: `EleutherAI/pythia-2.8B`
*   **方法**: Baseline (无压缩), StreamingLLM
*   **数据集**: 
    *   `wikitext-2`: 标准 PPL 测试
    *   `pg-19`: 超长文本测试 (取单一 sample)
*   **指标**: Perplexity (PPL) 和 推理加速比 (Speedup)

In [1]:
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer


# 设置设备
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {device}")

  from .autonotebook import tqdm as notebook_tqdm


Running on cuda


### pythia的输入输出
首先我们使用70m模型了解pythia的格式（即GPTNeoXAttention）

- GPTNeoX 的forward函数显式接受 layer_past / cache_position / position_embeddings ，内部用 Cache 维护 KV。
- Pythia 使用的是 DynamicCache + DynamicLayer 。
- 每层 KV 的形状是 [batch, num_heads, seq_len, head_dim] 。
- DynamicLayer.keys / values 是可读可写的张量属性（kvpress 里也就是这样直接赋值）。

In [2]:
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
import inspect
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import Cache

print(inspect.signature(GPTNeoXAttention.forward))


model_id = "EleutherAI/pythia-70m"
print("loading model...")
model_test = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
model_test.eval()

text = "Hello world, this is a streaming KV cache test. " * 10
tok = AutoTokenizer.from_pretrained(model_id)
inputs = tok(text, return_tensors="pt")
with torch.no_grad():
    out = model_test(**inputs, use_cache=True)

cache = out.past_key_values
print("cache type:", type(cache))
print("num layers:", len(cache.layers))
layer0 = cache.layers[0]
print("layer0 type:", type(layer0))
print("keys shape:", layer0.keys.shape)
print("values shape:", layer0.values.shape)
print("dir(layer0):", [a for a in dir(layer0) if not a.startswith("_")])


(self, hidden_states: torch.FloatTensor, attention_mask: torch.FloatTensor, head_mask: Optional[torch.FloatTensor] = None, layer_past: Optional[transformers.cache_utils.Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[transformers.modeling_flash_attention_utils.FlashAttentionKwargs])
loading model...


'(ProtocolError('Connection aborted.', ConnectionResetError(10054, '远程主机强迫关闭了一个现有的连接。', None, 10054, None)), '(Request ID: cfbc5b79-3396-4c46-b991-3eac81e0b2be)')' thrown while requesting HEAD https://huggingface.co/EleutherAI/pythia-70m/resolve/main/config.json
Retrying in 1s [Retry 1/5].
`torch_dtype` is deprecated! Use `dtype` instead!


cache type: <class 'transformers.cache_utils.DynamicCache'>
num layers: 6
layer0 type: <class 'transformers.cache_utils.DynamicLayer'>
keys shape: torch.Size([1, 8, 121, 64])
values shape: torch.Size([1, 8, 121, 64])
dir(layer0): ['batch_repeat_interleave', 'batch_select_indices', 'crop', 'device', 'dtype', 'get_mask_sizes', 'get_max_cache_shape', 'get_seq_length', 'is_compileable', 'is_initialized', 'is_sliding', 'keys', 'lazy_initialization', 'offload', 'prefetch', 'reorder_cache', 'reset', 'update', 'values']


### 将StreamingLLM适配于pythia模型

In [3]:
class StreamingLLM:
    def __init__(self, n_sink: int = 4, window_size: int = 256):
        self.n_sink = n_sink
        self.window_size = window_size

    def build_context(self, input_ids: torch.Tensor) -> torch.Tensor:
        if input_ids.size(1) <= self.n_sink + self.window_size:
            return input_ids
        sink = input_ids[:, : self.n_sink]
        tail = input_ids[:, -self.window_size :]
        return torch.cat([sink, tail], dim=1)

    def compress_cache(self, cache) -> None:
        for layer in cache.layers:
            keys = layer.keys
            values = layer.values
            bsz, num_heads, seq_len, head_dim = keys.shape
            if seq_len <= self.n_sink + self.window_size:
                continue
            device = keys.device
            sink_end = min(self.n_sink, seq_len)
            tail_len = min(self.window_size, seq_len - sink_end)
            if tail_len <= 0:
                keep_idx = torch.arange(0, sink_end, device=device)
            else:
                tail_start = seq_len - tail_len
                keep_prefix = torch.arange(0, sink_end, device=device)
                keep_tail = torch.arange(tail_start, seq_len, device=device)
                keep_idx = torch.cat([keep_prefix, keep_tail], dim=0)
            keys = keys.index_select(2, keep_idx)
            values = values.index_select(2, keep_idx)
            layer.keys = keys
            layer.values = values

    @torch.no_grad()
    def generate(
        self,
        model: AutoModelForCausalLM,
        input_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
    ) -> torch.Tensor:
        device = next(model.parameters()).device
        input_ids = input_ids.to(device)
        outputs = model(input_ids, use_cache=True)
        cache = outputs.past_key_values
        self.compress_cache(cache)
        generated = input_ids
        last_token = generated[:, -1:]
        for _ in range(max_new_tokens):
            outputs = model(last_token, use_cache=True, past_key_values=cache)
            logits = outputs.logits[:, -1, :]
            cache = outputs.past_key_values
            self.compress_cache(cache)
            if temperature is not None and temperature > 0:
                logits = logits / temperature
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(logits, top_k)
                min_values = v[:, -1].unsqueeze(-1)
                logits = torch.where(logits < min_values, torch.full_like(logits, -float("inf")), logits)
            if top_p is not None and 0 < top_p < 1:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
                cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits = logits.masked_fill(indices_to_remove, -float("inf"))
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
            last_token = next_token
        return generated

def load_model_and_tokenizer(model_id: str, torch_dtype=torch.float16):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    )
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return model, tokenizer

def streaming_generate_from_text(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    n_sink: int = 4,
    window_size: int = 256,
    max_new_tokens: int = 50,
) -> str:
    wrapper = StreamingLLM(n_sink=n_sink, window_size=window_size)
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded.input_ids
    generated_ids = wrapper.generate(model, input_ids, max_new_tokens=max_new_tokens)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)



### 评估函数定义

In [4]:
def benchmark_speed_streaming_kv(
    model,
    tokenizer,
    prompt: str,
    n_sink: int = 4,
    window_size: int = 256,
    num_tokens: int = 50,
    batch_size: int = 1,
):
    device = next(model.parameters()).device
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(device)
    input_ids = input_ids.repeat(batch_size, 1)

    stream = StreamingLLM(n_sink=n_sink, window_size=window_size)

    start = time.time()
    with torch.no_grad():
        outputs = model(input_ids, use_cache=True)
    cache = outputs.past_key_values
    stream.compress_cache(cache)
    generated = input_ids
    last_token = generated[:, -1:]

    ttft = None

    for step in range(num_tokens):
        step_start = time.time() if step == 0 else None
        with torch.no_grad():
            outputs = model(last_token, use_cache=True, past_key_values=cache)
        cache = outputs.past_key_values
        stream.compress_cache(cache)
        logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        last_token = next_token
        if step == 0:
            ttft = time.time() - start

    end = time.time()
    total_time = end - start
    if ttft is None:
        ttft = total_time

    token_phase_time = max(total_time - ttft, 1e-6)
    total_tokens = num_tokens * batch_size
    tpot = token_phase_time / total_tokens
    throughput = total_tokens / token_phase_time

    num_params = sum(p.numel() for p in model.parameters())
    prompt_len = input_ids.shape[1]
    total_tokens_processed = (prompt_len + num_tokens) * batch_size
    total_flops = 2.0 * num_params * total_tokens_processed
    avg_flops_per_token = total_flops / total_tokens_processed

    return {
        "ttft": ttft,
        "tpot": tpot,
        "throughput": throughput,
        "total_time": total_time,
        "total_flops": total_flops,
        "avg_flops_per_token": avg_flops_per_token,
    }


def benchmark_speed_baseline(
    model,
    tokenizer,
    prompt: str,
    num_tokens: int = 50,
    batch_size: int = 1,
):
    device = next(model.parameters()).device
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc.input_ids.to(device)
    input_ids = input_ids.repeat(batch_size, 1)

    start = time.time()
    ttft = None

    generated = input_ids
    cache = None

    for step in range(num_tokens):
        with torch.no_grad():
            if cache is None:
                outputs = model(generated, use_cache=True)
            else:
                outputs = model(generated[:, -1:], use_cache=True, past_key_values=cache)
        cache = outputs.past_key_values
        logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        if step == 0:
            ttft = time.time() - start

    end = time.time()
    total_time = end - start
    if ttft is None:
        ttft = total_time

    token_phase_time = max(total_time - ttft, 1e-6)
    total_tokens = num_tokens * batch_size
    tpot = token_phase_time / total_tokens
    throughput = total_tokens / token_phase_time

    num_params = sum(p.numel() for p in model.parameters())
    prompt_len = input_ids.shape[1]
    total_tokens_processed = (prompt_len + num_tokens) * batch_size
    total_flops = 2.0 * num_params * total_tokens_processed
    avg_flops_per_token = total_flops / total_tokens_processed

    return {
        "ttft": ttft,
        "tpot": tpot,
        "throughput": throughput,
        "total_time": total_time,
        "total_flops": total_flops,
        "avg_flops_per_token": avg_flops_per_token,
    }


def evaluate_ppl_streaming_kv(
    model,
    tokenizer,
    text: str,
    n_sink: int = 4,
    window_size: int = 256,
    max_tokens: int = 2000,
):
    enc = tokenizer(text, return_tensors="pt")
    input_ids = enc.input_ids[:, :max_tokens].to(next(model.parameters()).device)
    seq_len = input_ids.size(1)
    if seq_len < 2:
        return float("inf")

    stream = StreamingLLM(n_sink=n_sink, window_size=window_size)
    cache = None
    losses = []
    ce = torch.nn.CrossEntropyLoss(reduction="none")

    for pos in range(seq_len - 1):
        cur = input_ids[:, pos:pos + 1]
        with torch.no_grad():
            if cache is None:
                outputs = model(cur, use_cache=True)
            else:
                outputs = model(cur, use_cache=True, past_key_values=cache)
            cache = outputs.past_key_values
            stream.compress_cache(cache)
            logits = outputs.logits[:, -1, :]
            target = input_ids[:, pos + 1]
            loss = ce(logits, target).mean()
            losses.append(loss)

    ppl = torch.exp(torch.stack(losses).mean())
    return ppl.item()


def evaluate_ppl_baseline(
    model,
    tokenizer,
    text: str,
    max_tokens: int = 2000,
):
    enc = tokenizer(text, return_tensors="pt")
    input_ids = enc.input_ids[:, :max_tokens].to(next(model.parameters()).device)
    seq_len = input_ids.size(1)
    if seq_len < 2:
        return float("inf")

    max_length = model.config.max_position_embeddings
    stride = 512
    nlls = []
    prev_end = 0
    ce = torch.nn.CrossEntropyLoss(reduction="none")

    for begin in range(0, seq_len, stride):
        end = min(begin + max_length, seq_len)
        trg_len = end - prev_end
        if trg_len <= 0:
            break
        cur = input_ids[:, begin:end]
        target = cur.clone()
        target[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(cur, labels=target)
            loss = outputs.loss

        nlls.append(loss)
        prev_end = end
        if end == seq_len:
            break

    if not nlls:
        return float("inf")
    ppl = torch.exp(torch.stack(nlls).mean())
    return ppl.item()



###  模型加载

In [9]:
model_id = "EleutherAI/pythia-2.8b"
model, tokenizer = load_model_and_tokenizer(model_id)

### 数据集加载

In [6]:
from datasets import load_dataset

def load_long_text_from_dataset(
    dataset_name: str = "wikitext",
    split: str = "test",
    limit_samples: int = 1,
    max_chars: int | None = None,
) -> str:
    if dataset_name == "wikitext":
        ds = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
        
        # 正确获取文本的方法：
        # 方法1：使用列表推导
        texts = [item["text"] for item in ds.select(range(limit_samples))]
        # 或者方法2：直接切片
        # texts = [ds[i]["text"] for i in range(min(limit_samples, len(ds)))]
        
        text = "\n\n".join(texts)
    
    elif dataset_name == "pg19":
        ds = load_dataset("pg19", split=split, streaming=True)
        sample = next(iter(ds))
        text = sample["text"]
    
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    if max_chars is not None and len(text) > max_chars:
        text = text[:max_chars]
    
    return text
# 固定一次抽样
text_wiki = load_long_text_from_dataset(
    dataset_name="wikitext",
    split="train",
    limit_samples=100,
    max_chars=100000,
)

text_pg19 = load_long_text_from_dataset(
    dataset_name="pg19",
    split="test",
    limit_samples=1,
    max_chars=100000,
)

print("Wiki text length:", len(text_wiki))
print("PG-19 text length:", len(text_pg19))



Wiki text length: 35786
PG-19 text length: 100000


### 实验主循环

In [34]:
prompt = "在一座海边小城里，工程师正在测试一种新的 KV 缓存压缩算法。"

# 生成效果看看（Streaming）
encoded = tokenizer(prompt, return_tensors="pt")
input_ids = encoded.input_ids
stream = StreamingLLM(n_sink=4, window_size=256)
out_ids = stream.generate(model, input_ids, max_new_tokens=50)
print("StreamingLLM sample:\n", repr(tokenizer.decode(out_ids[0], skip_special_tokens=False)))

with open("text/long_normal.txt", "r", encoding="utf-8") as f:
    text = f.read()

# 1) Baseline PPL
ppl_base = evaluate_ppl_baseline(
    model,
    tokenizer,
    text,
    max_tokens=2000,
)
print("Baseline PPL on custom_complex_dataset:", ppl_base)

# 2) StreamingLLM PPL
ppl_stream = evaluate_ppl_streaming_kv(
    model,
    tokenizer,
    text,
    n_sink=4,
    window_size=256,
    max_tokens=2000,
)
print("StreamingLLM PPL(256) on custom_complex_dataset:", ppl_stream)

# 2) StreamingLLM PPL
ppl_stream1 = evaluate_ppl_streaming_kv(
    model,
    tokenizer,
    text,
    n_sink=4,
    window_size=512,
    max_tokens=2000,
)
print("StreamingLLM PPL(512) on custom_complex_dataset:", ppl_stream1)

# 3) Baseline Speed
speed_base = benchmark_speed_baseline(
    model,
    tokenizer,
    prompt,
    num_tokens=50,
    batch_size=1,
)
print("Baseline speed (tok/s):", speed_base)

# 4) StreamingLLM Speed
speed_stream = benchmark_speed_streaming_kv(
    model,
    tokenizer,
    prompt,
    n_sink=4,
    window_size=256,
    num_tokens=50,
    batch_size=1,
)
print("StreamingLLM(256) KV-level speed (tok/s):", speed_stream)

speed_stream1 = benchmark_speed_streaming_kv(
    model,
    tokenizer,
    prompt,
    n_sink=4,
    window_size=512,
    num_tokens=50,
    batch_size=1,
)
print("StreamingLLM(512) KV-level speed (tok/s):", speed_stream1)

# 简单汇总
print("\n=== Summary ===")
print(f"Baseline   - PPL: {ppl_base:.3f},  Speed: {speed_base:.2f} tok/s")
print(f"Streaming(256)  - PPL: {ppl_stream:.3f},  Speed: {speed_stream:.2f} tok/s")
print(f"Streaming(512)  - PPL: {ppl_stream1:.3f},  Speed: {speed_stream1:.2f} tok/s")

StreamingLLM sample:
 '在一座海边小城里，工程师正在测试一种新的 KV 缓存压缩算法。そこ他们通过维护一个 KV 环形表（key-value table）来结合在一起，并乘以预期的数据量来评估响应'
Baseline PPL on custom_complex_dataset: 11.485358238220215
StreamingLLM PPL(256) on custom_complex_dataset: 247.625


: 

### Wiki/pg19 文本评估

In [10]:
# PPL测试
text_wiki_ppl = text_wiki[:5000]  
ppl_base_wiki = evaluate_ppl_baseline(model, tokenizer, text_wiki_ppl, max_tokens=1000)
print("Baseline PPL on WikiText (1000 tokens):", ppl_base_wiki)

ppl_stream_wiki_256 = evaluate_ppl_streaming_kv(model, tokenizer, text_wiki_ppl, n_sink=8, window_size=256, max_tokens=1000)
print("StreamingLLM(256) PPL on WikiText (1000 tokens):", ppl_stream_wiki_256)

ppl_stream_wiki_512 = evaluate_ppl_streaming_kv(model, tokenizer, text_wiki_ppl, n_sink=8, window_size=1024, max_tokens=1000)
print("StreamingLLM(512) PPL on WikiText (1000 tokens):", ppl_stream_wiki_512)



Baseline PPL on WikiText (1000 tokens): 13.31332778930664
StreamingLLM(256) PPL on WikiText (1000 tokens): 150.125
StreamingLLM(512) PPL on WikiText (1000 tokens): 13.3046875


In [27]:
# 加速测试
prompt_wiki = text_wiki[:2000]
speed_base_wiki = benchmark_speed_baseline(model, tokenizer, prompt_wiki, num_tokens=1000, batch_size=1)
print("Baseline speed on WikiText (500+1000 tokens):", speed_base_wiki)

speed_stream_wiki_256 = benchmark_speed_streaming_kv(model, tokenizer, prompt_wiki, n_sink=8, window_size=256, num_tokens=1000, batch_size=1)
print("StreamingLLM(256) speed on WikiText (500+1000 tokens):", speed_stream_wiki_256)

speed_stream_wiki_512 = benchmark_speed_streaming_kv(model, tokenizer, prompt_wiki, n_sink=8, window_size=512, num_tokens=1000, batch_size=1)
print("StreamingLLM(512) speed on WikiText (500+1000 tokens):", speed_stream_wiki_512)

Baseline speed on WikiText (500+1000 tokens): {'ttft': 0.5776915550231934, 'tpot': 0.047968844890594484, 'throughput': 20.846864298708088, 'total_time': 48.546536445617676, 'total_flops': 8114710999040.0, 'avg_flops_per_token': 5550417920.0}
StreamingLLM(256) speed on WikiText (500+1000 tokens): {'ttft': 0.28908491134643555, 'tpot': 0.08570580220222473, 'throughput': 11.667821481216382, 'total_time': 85.99488711357117, 'total_flops': 8114710999040.0, 'avg_flops_per_token': 5550417920.0}
StreamingLLM(512) speed on WikiText (500+1000 tokens): {'ttft': 0.2142009735107422, 'tpot': 0.059623798847198484, 'throughput': 16.771826340061967, 'total_time': 59.83799982070923, 'total_flops': 8114710999040.0, 'avg_flops_per_token': 5550417920.0}


In [28]:
# PPL测试  
text_pg19_ppl = text_pg19[:5000]  # 2048 tokens
ppl_base_pg19 = evaluate_ppl_baseline(model, tokenizer, text_pg19_ppl, max_tokens=1000)
print("Baseline PPL on PG-19 (1000 tokens):", ppl_base_pg19)

ppl_stream_pg19_256 = evaluate_ppl_streaming_kv(model, tokenizer, text_pg19_ppl, n_sink=8, window_size=256, max_tokens=1000)
print("StreamingLLM(256) PPL on PG-19 (1000 tokens):", ppl_stream_pg19_256)

ppl_stream_pg19_512 = evaluate_ppl_streaming_kv(model, tokenizer, text_pg19_ppl, n_sink=8, window_size=512, max_tokens=1000)
print("StreamingLLM(512) PPL on PG-19 (1000 tokens):", ppl_stream_pg19_512)


Baseline PPL on PG-19 (1000 tokens): 8.540501594543457
StreamingLLM(256) PPL on PG-19 (1000 tokens): 78.5
StreamingLLM(512) PPL on PG-19 (1000 tokens): 35.125


In [29]:
# 加速测试
prompt_pg19 = text_pg19[:2000]
speed_base_pg19 = benchmark_speed_baseline(model, tokenizer, prompt_pg19, num_tokens=1000, batch_size=1)
print("Baseline speed on PG-19 (500+1000 tokens):", speed_base_pg19)

speed_stream_pg19_256 = benchmark_speed_streaming_kv(model, tokenizer, prompt_pg19, n_sink=8, window_size=256, num_tokens=1000, batch_size=1)
print("StreamingLLM(256) speed on PG-19 (500+1000 tokens):", speed_stream_pg19_256)

speed_stream_pg19_512 = benchmark_speed_streaming_kv(model, tokenizer, prompt_pg19, n_sink=8, window_size=512, num_tokens=1000, batch_size=1)
print("StreamingLLM(512) speed on PG-19 (500+1000 tokens):", speed_stream_pg19_512)

Baseline speed on PG-19 (500+1000 tokens): {'ttft': 0.4057044982910156, 'tpot': 0.05524669170379639, 'throughput': 18.100631352940958, 'total_time': 55.6523962020874, 'total_flops': 8397782312960.0, 'avg_flops_per_token': 5550417920.0}
StreamingLLM(256) speed on PG-19 (500+1000 tokens): {'ttft': 0.3122568130493164, 'tpot': 0.08692020392417908, 'throughput': 11.504805037874796, 'total_time': 87.2324607372284, 'total_flops': 8397782312960.0, 'avg_flops_per_token': 5550417920.0}
StreamingLLM(512) speed on PG-19 (500+1000 tokens): {'ttft': 0.2873044013977051, 'tpot': 0.0594079430103302, 'throughput': 16.832765945559068, 'total_time': 59.695247411727905, 'total_flops': 8397782312960.0, 'avg_flops_per_token': 5550417920.0}


In [30]:
print("少token生成实验（token总量不超出pythia2.8B模型的训练长度2048）")
print("\n=== WikiText 汇总 ===")
print(f"Baseline - PPL: {ppl_base_wiki:.3f}, Speed: {speed_base_wiki['throughput']:.2f} tok/s")
print(f"Streaming(256) - PPL: {ppl_stream_wiki_256:.3f}, Speed: {speed_stream_wiki_256['throughput']:.2f} tok/s")
print(f"Streaming(512) - PPL: {ppl_stream_wiki_512:.3f}, Speed: {speed_stream_wiki_512['throughput']:.2f} tok/s")
print("\n=== PG-19 汇总 ===")
print(f"Baseline - PPL: {ppl_base_pg19:.3f}, Speed: {speed_base_pg19['throughput']:.2f} tok/s")
print(f"Streaming(256) - PPL: {ppl_stream_pg19_256:.3f}, Speed: {speed_stream_pg19_256['throughput']:.2f} tok/s")
print(f"Streaming(512) - PPL: {ppl_stream_pg19_512:.3f}, Speed: {speed_stream_pg19_512['throughput']:.2f} tok/s")


少token生成实验（token总量不超出pythia2.8B模型的训练长度2048）

=== WikiText 汇总 ===
Baseline - PPL: 13.313, Speed: 20.85 tok/s
Streaming(256) - PPL: 150.125, Speed: 11.67 tok/s
Streaming(512) - PPL: 42.688, Speed: 16.77 tok/s

=== PG-19 汇总 ===
Baseline - PPL: 8.541, Speed: 18.10 tok/s
Streaming(256) - PPL: 78.500, Speed: 11.50 tok/s
Streaming(512) - PPL: 35.125, Speed: 16.83 tok/s


In [None]:
# PPL测试
text_wiki_ppl = text_wiki[:15000]  # 3000 tokens
ppl_base_wiki_max = evaluate_ppl_baseline(model, tokenizer, text_wiki_ppl, max_tokens=3000)
print("Baseline PPL on WikiText (3000 tokens):", ppl_base_wiki_max)

ppl_stream_wiki_256_max = evaluate_ppl_streaming_kv(model, tokenizer, text_wiki_ppl, n_sink=8, window_size=256, max_tokens=3000)
print("StreamingLLM(256) PPL on WikiText (3000 tokens):", ppl_stream_wiki_256_max)

ppl_stream_wiki_512_max = evaluate_ppl_streaming_kv(model, tokenizer, text_wiki_ppl, n_sink=8, window_size=512, max_tokens=3000)
print("StreamingLLM(512) PPL on WikiText (3000 tokens):", ppl_stream_wiki_512_max)

In [None]:
# 加速测试
prompt_wiki = text_wiki[:5000]

speed_stream_wiki_256_max = benchmark_speed_streaming_kv(model, tokenizer, prompt_wiki, n_sink=8, window_size=256, num_tokens=2000, batch_size=1)
print("StreamingLLM(256) speed on WikiText (1000+2000 tokens):", speed_stream_wiki_256_max)

speed_stream_wiki_512_max = benchmark_speed_streaming_kv(model, tokenizer, prompt_wiki, n_sink=8, window_size=512, num_tokens=2000, batch_size=1)
print("StreamingLLM(512) speed on WikiText (1000+2000 tokens):", speed_stream_wiki_512_max)

speed_base_wiki_max = benchmark_speed_baseline(model, tokenizer, prompt_wiki, num_tokens=2000, batch_size=1)
print("Baseline speed on WikiText (1000+2000 tokens):", speed_base_wiki_max)

In [None]:
# PPL测试  
text_pg19_ppl = text_pg19[:15000]  # 3000 tokens
ppl_base_pg19_max = evaluate_ppl_baseline(model, tokenizer, text_pg19_ppl, max_tokens=3000)
print("Baseline PPL on PG-19 (3000 tokens):", ppl_base_pg19_max)

ppl_stream_pg19_256_max = evaluate_ppl_streaming_kv(model, tokenizer, text_pg19_ppl, n_sink=8, window_size=256, max_tokens=3000)
print("StreamingLLM(256) PPL on PG-19 (3000 tokens):", ppl_stream_pg19_256_max)

ppl_stream_pg19_512_max = evaluate_ppl_streaming_kv(model, tokenizer, text_pg19_ppl, n_sink=8, window_size=512, max_tokens=3000)
print("StreamingLLM(512) PPL on PG-19 (3000 tokens):", ppl_stream_pg19_512_max)


In [None]:
# 加速测试
prompt_pg19 = text_pg19[:5000]
speed_base_pg19_max = benchmark_speed_baseline(model, tokenizer, prompt_pg19, num_tokens=2000, batch_size=1)
print("Baseline speed on PG-19 (1000+2000 tokens):", speed_base_pg19_max)

speed_stream_pg19_256_max = benchmark_speed_streaming_kv(model, tokenizer, prompt_pg19, n_sink=8, window_size=256, num_tokens=2000, batch_size=1)
print("StreamingLLM(256) speed on PG-19 (1000+2000 tokens):", speed_stream_pg19_256_max)

speed_stream_pg19_512_max = benchmark_speed_streaming_kv(model, tokenizer, prompt_pg19, n_sink=8, window_size=512, num_tokens=2000, batch_size=1)
print("StreamingLLM(512) speed on PG-19 (1000+2000 tokens):", speed_stream_pg19_512_max)

In [None]:
print("大量token生成实验（token总量超出pythia2.8B模型的训练长度2048）")
print("\n=== WikiText 汇总 ===")
print(f"Baseline - PPL: {ppl_base_wiki_max:.3f}, Speed: {speed_base_wiki_max['throughput']:.2f} tok/s")
print(f"Streaming(256) - PPL: {ppl_stream_wiki_256_max:.3f}, Speed: {speed_stream_wiki_256_max['throughput']:.2f} tok/s")
print(f"Streaming(512) - PPL: {ppl_stream_wiki_512_max:.3f}, Speed: {speed_stream_wiki_512_max['throughput']:.2f} tok/s")
print("\n=== PG-19 汇总 ===")
print(f"Baseline - PPL: {ppl_base_pg19_max:.3f}, Speed: {speed_base_pg19_max['throughput']:.2f} tok/s")
print(f"Streaming(256) - PPL: {ppl_stream_pg19_256_max:.3f}, Speed: {speed_stream_pg19_256_max['throughput']:.2f} tok/s")
print(f"Streaming(512) - PPL: {ppl_stream_pg19_512_max:.3f}, Speed: {speed_stream_pg19_512_max['throughput']:.2f} tok/s")
