Shows how to generate text from a prompt and a few sampling hyperparameters, using either minGPT or huggingface/transformers.

In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from mingpt.model import GPT
from mingpt.utils import set_seed
from mingpt.bpe import BPETokenizer
# Fixed seed so the sampled completions below can be reproduced on re-run
# (presumably set_seed seeds every RNG in use — see mingpt.utils.set_seed).
RANDOM_SEED = 3407
set_seed(RANDOM_SEED)

In [2]:
use_mingpt = True  # True: minGPT model; False: huggingface/transformers model
model_type = 'gpt2-xl'  # pretrained checkpoint name, used by both backends
# Prefer the GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA instead of failing at model.to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Load the pretrained weights with whichever backend was selected above.
if not use_mingpt:
    model = GPT2LMHeadModel.from_pretrained(model_type)
    # reuse the EOS id as the pad id; this silences a warning from the HF library
    model.config.pad_token_id = model.config.eos_token_id
else:
    model = GPT.from_pretrained(model_type)

# move the weights onto the chosen device, then switch to inference mode
model.to(device)
model.eval()

number of parameters: 1557.61M


In [4]:
def generate(prompt='', num_samples=10, steps=20, do_sample=True):
    """
    Generate text samples that continue a prompt and print them.

    Uses the module-level ``model`` and a tokenizer matching the backend
    selected by ``use_mingpt`` (minGPT ``BPETokenizer`` or the
    huggingface ``GPT2Tokenizer`` for ``model_type``).

    Args:
        prompt: Text prompt to start generation from. An empty string
            produces unconditional samples seeded with ``<|endoftext|>``.
        num_samples: Number of samples to generate, processed as one batch.
        steps: Number of new tokens to generate per sample.
        do_sample: Whether to sample while generating; forwarded to
            ``model.generate`` together with ``top_k=40``.

    Returns:
        None. Each sample is printed to stdout, preceded by a separator line.
    """
    # Tokenize the prompt into an integer input sequence.
    if use_mingpt:
        tokenizer = BPETokenizer()
        if prompt == '':
            # Unconditional sample: a (1, 1) tensor holding only the <|endoftext|> id.
            x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long)
        else:
            x = tokenizer(prompt)
    else:
        tokenizer = GPT2Tokenizer.from_pretrained(model_type)
        if prompt == '':
            # Unconditional sample: use the special end-of-text token as the prompt.
            prompt = '<|endoftext|>'
        encoded_input = tokenizer(prompt, return_tensors='pt')
        x = encoded_input['input_ids']

    # The input must be on the same device as the model. The unconditional
    # minGPT tensor above is created on the CPU, so move every branch's input
    # here rather than only the tokenized ones.
    x = x.to(device)

    # Batch all requested samples together by expanding the batch dimension.
    # NOTE(review): assumes x has batch size 1 here (single prompt) — expand
    # cannot grow a dimension larger than 1.
    x = x.expand(num_samples, -1)

    # Run the model for `steps` new tokens, for the whole batch at once.
    y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)

    # Decode and print each generated sample.
    for i in range(num_samples):
        out = tokenizer.decode(y[i].cpu().squeeze())
        print('-' * 80)
        print(out)

In [5]:
# Ten 20-token continuations of a short prompt.
generate(
    prompt='Andrej Karpathy, the',
    num_samples=10,
    steps=20,
)

--------------------------------------------------------------------------------
Andrej Karpathy, the chief of the criminal investigation department, said during a news conference, "We still have a lot of
--------------------------------------------------------------------------------
Andrej Karpathy, the man whom most of America believes is the architect of the current financial crisis. He runs the National Council
--------------------------------------------------------------------------------
Andrej Karpathy, the head of the Department for Regional Reform of Bulgaria and an MP in the centre-right GERB party
--------------------------------------------------------------------------------
Andrej Karpathy, the former head of the World Bank's IMF department, who worked closely with the IMF. The IMF had
--------------------------------------------------------------------------------
Andrej Karpathy, the vice president for innovation and research at Citi who oversaw the team's work to mak