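This notebook largely follows Andrej Karpathy's build-nanogpt tutorial (https://github.com/karpathy/build-nanogpt/): it loads the pretrained GPT-2 weights and samples several continuations of a fixed prompt using top-k sampling.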

In [1]:
# Sampling configuration used by the generation and decoding cells below.
SEED = 0         # seeds both the torch CPU and CUDA RNGs right before generation
BATCH_SIZE = 5   # number of sequences sampled in parallel from the same prompt
MAX_LENGTH = 30  # total tokens per sequence, prompt tokens included

In [2]:
import torch
from chesslm.gpt import GPT

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained "gpt2" checkpoint through the project's GPT class,
# switch it to evaluation mode, then move it onto the selected device.
# The cell deliberately ends in an assignment so no model repr is displayed.
model = GPT.from_pretrained("gpt2")
model.eval()
model = model.to(device)

loading weights from pretrained gpt: gpt2


In [3]:
import tiktoken

# GPT-2 BPE tokenizer, matching the "gpt2" weights loaded above.
enc = tiktoken.get_encoding("gpt2")

# Encode the prompt once, then replicate it BATCH_SIZE times so every row of
# the batch starts from the same prefix: `tokens` has shape (BATCH_SIZE, T).
prompt_ids = enc.encode("Hello, I'm a language model,")
tokens = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).repeat(BATCH_SIZE, 1)

# `x` is the working sequence batch that the generation cell extends in place.
x = tokens.to(device)

In [4]:
# Autoregressive top-k sampling. `x` starts as the (B, T) prompt batch built in
# the tokenization cell and grows by one token per step until it has
# MAX_LENGTH columns. Re-running this cell once `x` is already MAX_LENGTH long
# performs zero steps.
import torch.nn.functional as F
from tqdm.auto import trange  # picks the notebook or console progress bar as appropriate

TOP_K = 50  # top-k cutoff; 50 matches the Hugging Face pipeline default

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

for _ in trange(MAX_LENGTH - x.size(1), desc="Generating tokens"):
    # Nothing in the sampling step needs autograd, so keep it all gradient-free.
    with torch.no_grad():
        # forward the model to get the logits
        logits = model(x)  # (B, T, vocab_size)

        # only the last position predicts the next token
        logits = logits[:, -1, :]  # (B, vocab_size)

        # normalize over the vocabulary (last) axis
        probs = F.softmax(logits, dim=-1)

        # keep the k most likely tokens per row; both results are (B, k).
        # Clamp k so a small vocabulary cannot make topk fail.
        k = min(TOP_K, probs.size(-1))
        topk_probs, topk_indices = torch.topk(probs, k, dim=-1)

        # sample a position within the top-k; multinomial does not require the
        # weights to sum to 1, so no renormalization is needed
        ix = torch.multinomial(topk_probs, 1)  # (B, 1)

        # map the sampled top-k position back to a vocabulary id
        xcol = torch.gather(topk_indices, -1, ix)  # (B, 1)

    # append to the sequence
    x = torch.cat((x, xcol), dim=1)

Generating tokens:   0%|          | 0/22 [00:00<?, ?it/s]

In [5]:
# Decode and display every generated sequence (prompt + sampled continuation).
# Iterate over the rows actually present in `x` rather than range(BATCH_SIZE),
# so editing BATCH_SIZE without re-running generation cannot raise IndexError.
for row in x:
    # A dedicated name keeps the prompt `tokens` tensor from the tokenization
    # cell intact instead of overwriting it with a Python list.
    row_ids = row[:MAX_LENGTH].tolist()
    decoded = enc.decode(row_ids)
    display(decoded.strip())

"Hello, I'm a language model, and we're a model, we're a language the language, language the, language the, language the,"

"Hello, I'm a language model, we have a model we want to a language to, we want a model we, we have a language,"

"Hello, I'm a language model, I'm language I believe it an interpreter, interpreters I, interpreter a language I, a interpreter a,"

"Hello, I'm a language model, and a thing a thing a model and a and model a and model an object a and object an object,"

"Hello, I'm a language model, I think a language I think a language there a language I think a language a language a, some language a"