In [8]:
import matplotlib.pyplot as plt
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
# Checkpoint identifier shared by the tokenizer and the model so the two stay in sync.
model_name = "gpt2"
# Load the tokenizer and the causal language model for that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)



## KV cache

In [None]:
# Encode a single prompt as PyTorch tensors and print the encoding
# (the printed result contains `input_ids` and `attention_mask`).
prompt = "The quick brown fox jumped over the"
inputs = tokenizer(prompt, return_tensors="pt")
print(inputs)

def generate_token_with_past(inputs):
    """Run one forward pass and greedily choose the next token for a single sequence.

    Args:
        inputs: Mapping of keyword arguments forwarded unchanged to the
            notebook-level `model`.

    Returns:
        A `(next_token_id, past_key_values)` tuple: the argmax over the logits
        at the last position of batch row 0, and the `past_key_values`
        attribute of the model output.
    """
    with torch.no_grad():  # gradients are not computed
        model_out = model(**inputs)

    # Only the last position of the first (and only expected) row is used.
    final_step_logits = model_out.logits[0, -1, :]
    return final_step_logits.argmax(), model_out.past_key_values


def generate(inputs, max_tokens):
    """Greedily generate `max_tokens` tokens for one prompt, reusing cached keys/values.

    Args:
        inputs: Mapping with at least an "attention_mask" entry; it is passed
            to `generate_token_with_past` for the first step.
        max_tokens: Number of decoding steps to run.

    Returns:
        The decoded text of all generated tokens joined into one string.
    """
    pieces = []
    step_inputs = inputs
    for _ in range(max_tokens):
        token_id, kv_cache = generate_token_with_past(step_inputs)
        pieces.append(tokenizer.decode(token_id))

        # The mask grows by one column per step to cover the new token.
        grown_mask = torch.cat(
            [step_inputs["attention_mask"], torch.tensor([[1]])],
            dim=1,
        )
        step_inputs = {
            "input_ids": token_id.reshape((1, 1)),  # feed only the newly generated token id
            "attention_mask": grown_mask,
            "past_key_values": kv_cache,
        }
    return "".join(pieces)


# Generate 10 tokens for the single prompt encoded above and print the continuation.
tokens = generate(inputs, max_tokens=10)
print(tokens)

{'input_ids': tensor([[  464,  2068,  7586, 21831, 11687,   625,   262]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}
 fence and ran to the other side of the fence


## Batch decode

In [23]:
# Reuse the EOS token as the PAD token (PAD Token = EOS Token = 50256),
# on both the tokenizer and the model config.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# Pad (and truncate) on the left so newly generated tokens can be appended on the right.
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

In [41]:
# Prompts of different lengths; padding=True pads them so the batch can be
# returned as PyTorch tensors.
prompts = [
    "The quick brown fox jumped over the",
    "The rain in Spain falls",
    "What comes up must",
]
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

In [32]:
def generate_batch_tokens_with_past(inputs):
    """Run one forward pass and greedily choose the next token for every batch row.

    Args:
        inputs: Mapping of keyword arguments forwarded unchanged to the
            notebook-level `model`.

    Returns:
        A `(next_token_ids, past_key_values)` tuple: the per-row argmax over
        the logits at the last position, and the `past_key_values` attribute
        of the model output.
    """
    with torch.no_grad():
        model_out = model(**inputs)

    # Keep every row, but only the logits of the final position.
    final_step_logits = model_out.logits[:, -1, :]
    greedy_ids = torch.argmax(final_step_logits, dim=1)
    return greedy_ids, model_out.past_key_values

In [39]:
def generate_batch(inputs, max_tokens):
    """Greedily generate `max_tokens` tokens for every prompt in a padded batch.

    The first step runs the model on the full prompts; every later step feeds
    only the newly generated token together with the cached keys/values
    returned by the previous step.

    Args:
        inputs: Mapping with "input_ids" and "attention_mask" entries for the
            batch. The position bookkeeping below reads the last column of
            each row, so it assumes the prompts are left-padded (as this
            notebook configures the tokenizer).
        max_tokens: Number of decoding steps to run. Generation does not stop
            early; exactly this many tokens are produced per row.

    Returns:
        A list with one string per batch row: the decoded generated tokens of
        that row joined together.
    """
    # One list of decoded token strings per batch row.
    batch_size = inputs["input_ids"].shape[0]
    generated_tokens = [[] for _ in range(batch_size)]

    attention_mask = inputs["attention_mask"]
    # The running count of attended tokens gives each real token its position
    # index; masked (padding) slots get the placeholder position 1.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    next_inputs = {
        "position_ids": position_ids,
        **inputs
    }

    for _ in range(max_tokens):
        next_token_ids, past_key_values = \
            generate_batch_tokens_with_past(next_inputs)

        previous_mask = next_inputs["attention_mask"]
        next_inputs = {
            # Feed only the newly generated token id of each row.
            "input_ids": next_token_ids.reshape((-1, 1)),
            # The new token sits one position after the last one of the previous step.
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
            # Extend the mask by one attended column. new_ones() inherits the
            # mask's dtype and device, so the mask is not silently promoted
            # to float and the concatenation also works off-CPU.
            "attention_mask": torch.cat([
                previous_mask,
                previous_mask.new_ones((next_token_ids.shape[0], 1)),
            ], dim=1),
            "past_key_values": past_key_values,
        }

        next_tokens = tokenizer.batch_decode(next_token_ids)
        for i, token in enumerate(next_tokens):
            generated_tokens[i].append(token)
    return ["".join(tokens) for tokens in generated_tokens]

In [42]:
# Generate 10 tokens for each prompt in the padded batch.
generated_tokens = generate_batch(inputs, max_tokens=10)

In [43]:
# Print each prompt followed by its generated continuation.
for prompt, generated in zip(prompts, generated_tokens):
    print(prompt, f"\x1b[31m{generated}\x1b[0m\n") # show the generated text in red

The quick brown fox jumped over the [31m fence and ran to the other side of the fence[0m

The rain in Spain falls [31m on the first day of the month, and the[0m

What comes up must [31m be a good idea.

"I think[0m



## 吞吐量与延迟
- 探索批处理对延迟的影响（生成每个令牌需要多长时间）。 
- 观察吞吐量和延迟之间存在的基本权衡。

In [45]:
# constants
max_tokens = 10

# observations
durations = []
throughputs = []
latencies = []

batch_sizes = [2**p for p in range(8)] # [1, 2, 4, 8, 16, 32, 64, 128]
for batch_size in batch_sizes:
    print(f"bs= {batch_size}")

    # Record generation time and throughput for this batch size. The timed
    # region covers building the prompts, tokenizing them, and generating.
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring elapsed time; time.time() can jump if the system clock changes.
    t0 = time.perf_counter()
    batch_prompts = [
        prompts[i % len(prompts)] for i in range(batch_size) # cycle through prompts
    ]
    # Use a loop-local name so the notebook-level `inputs` and
    # `generated_tokens` from earlier cells are not overwritten.
    batch_inputs = tokenizer(
        batch_prompts, padding=True, return_tensors="pt"
    )
    generate_batch(batch_inputs, max_tokens=max_tokens)
    duration_s = time.perf_counter() - t0

    # Throughput: generated tokens per second across the whole batch.
    # Latency: average seconds per decoding step.
    ntokens = batch_size * max_tokens
    throughput = ntokens / duration_s
    avg_latency = duration_s / max_tokens
    print("duration", duration_s)
    print("throughput", throughput)
    print("avg latency", avg_latency)
    print()

    durations.append(duration_s)
    throughputs.append(throughput)
    latencies.append(avg_latency)

bs= 1
duration 0.6177549362182617
throughput 16.18764887775313
avg latency 0.061775493621826175

bs= 2
duration 0.9665660858154297
throughput 20.691808137595977
avg latency 0.09665660858154297

bs= 4
duration 0.851855993270874
throughput 46.95629345332405
avg latency 0.0851855993270874

bs= 8
duration 1.0085225105285645
throughput 79.32396070968429
avg latency 0.10085225105285645

bs= 16
duration 1.458956241607666
throughput 109.66744268059156
avg latency 0.1458956241607666

bs= 32
duration 1.9382057189941406
throughput 165.10115353806125
avg latency 0.19382057189941407

bs= 64
duration 3.983337640762329
throughput 160.66928232514007
avg latency 0.3983337640762329

bs= 128
duration 7.6774022579193115
throughput 166.72306035282548
avg latency 0.7677402257919311



In [None]:
def render_plot(x, y1, y2, x_label, y1_label, y2_label):
    """Plot two series against a shared base-2 log x-axis, each with its own y-axis.

    Args:
        x: Values for the shared x-axis.
        y1: Series drawn in red against the left y-axis.
        y2: Series drawn in blue against the right y-axis.
        x_label: Label for the x-axis.
        y1_label: Label for the left y-axis.
        y2_label: Label for the right y-axis.
    """
    def draw_series(axis, values, label, color):
        # Label, line and tick labels all share the series colour.
        axis.set_ylabel(label, color=color)
        axis.plot(x, values, color=color)
        axis.tick_params(axis='y', labelcolor=color)

    _fig, left_axis = plt.subplots()

    # First series on the left axis; this axis also carries the x-label.
    left_axis.set_xlabel(x_label)
    draw_series(left_axis, y1, y1_label, 'tab:red')

    # Log-scale the x-axis in base 2.
    left_axis.set_xscale('log', base=2)

    # Second series on a twin axis that shares the same x-axis.
    right_axis = left_axis.twinx()
    draw_series(right_axis, y2, y2_label, 'tab:blue')

    plt.show()

# Throughput and latency versus batch size, on one figure.
render_plot(
    x=batch_sizes,
    y1=throughputs,
    y2=latencies,
    x_label="Batch Size",
    y1_label="Throughput",
    y2_label="Latency",
)