Shows how one can generate text given a prompt and some hyperparameters, using either minGPT or huggingface/transformers

In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel,GPT2Config
from src.model import GPT
from src.utils import set_seed
from src.bpe import BPETokenizer
set_seed(3407)

In [2]:
use_mingpt = True # use minGPT or huggingface/transformers model?
model_type = 'gpt2'
device = 'cuda'

In [3]:
# if use_mingpt:
#     model = GPT.from_pretrained(model_type)
# else:
#     model = GPT2LMHeadModel.from_pretrained(model_type)
#     model.config.pad_token_id = model.config.eos_token_id # suppress a warning

config = GPT2Config(
    vocab_size=50257,
    n_layer=4,
    n_head=4,
    n_embd=256
)
model = GPT2LMHeadModel(config)  # 随机初始化

# ship model to device and set to eval mode
model.to(device)
model.eval();

In [9]:
def generate(prompt='', num_samples=10, steps=20, do_sample=True):
    if use_mingpt:
        tokenizer = BPETokenizer()
        if prompt == '':
            x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long).to(device)
        else:
            x = tokenizer(prompt).to(device)

        # 扩展 batch
        x = x.expand(num_samples, -1)

        # minGPT 不支持 attention_mask 和 pad_token_id
        y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)

        # decode
        for i in range(num_samples):
            out = tokenizer.decode(y[i].cpu().squeeze())
            print('-' * 80)
            print(out)

    else:
        from transformers import GPT2Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained(model_type)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        if prompt == '':
            prompt = '<|endoftext|>'
        encoded_input = tokenizer(prompt, return_tensors='pt', padding=True).to(device)
        x = encoded_input['input_ids']

        x = x.expand(num_samples, -1)

        y = model.generate(
            x,
            max_new_tokens=steps,
            do_sample=do_sample,
            top_k=40,
            pad_token_id=tokenizer.eos_token_id,
            attention_mask=(x != tokenizer.pad_token_id)
        )

        for i in range(num_samples):
            out = tokenizer.decode(y[i].cpu(), skip_special_tokens=True)
            print('-' * 80)
            print(out)


In [10]:
generate(prompt='zhixuduan is the', num_samples=10, steps=20)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


--------------------------------------------------------------------------------
zhixuduan is the Das Dewti DerSmall lenrast progressed Larry corrective Pamー�history robberies bunker horrendplom106gy delinqu
--------------------------------------------------------------------------------
zhixuduan is theurious Ahmed Staten Opp indisp democratically Bonnie cliffsriers Enlightenmentproductive disob Andrews quiteophone flamesカ impairment planners sed
--------------------------------------------------------------------------------
zhixuduan is theん smugglersbub rockedイ Der Labrador efforts Der cliffs Grill EDITIONalianocking hamstringdocspkg diagnosed Mods Hut
--------------------------------------------------------------------------------
zhixuduan is thearn Ms graffitiollow Nick correlation 281 Everton Everton Everton Dot Astros Sz limb Viz teamnesssilver moved Ghostbusters
--------------------------------------------------------------------------------
zhixuduan is thegenic]( assessedde