In [1]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.nn import functional as F

max_length = 30  # total length (prompt + sampled tokens) at which generation stops
num_return_sequences = 5  # number of sequences sampled in parallel from the same prompt

# Use the GPU when one is available, otherwise fall back to the CPU so the
# script still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained model and tokenizer.
# Available checkpoints: gpt2, gpt2-medium, gpt2-large, and gpt2-xl
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

model.eval()  # inference mode: disables dropout
model.to(device)

# Fix the RNG seeds so the sampled continuations are reproducible.
# torch.cuda.manual_seed is silently ignored when CUDA is unavailable.
torch.manual_seed(43)
torch.cuda.manual_seed(43)

input_text = "What is the full name of USA?"
tokens = tokenizer.encode(input_text, return_tensors='pt')  # (1, T)
# Duplicate the prompt once per requested sequence -> (num_return_sequences, T)
tokens = tokens.repeat(num_return_sequences, 1)
# Single host-to-device transfer of the batched prompt.
x = tokens.to(device)

# Autoregressive top-k sampling. x enters as (B, T) and gains exactly one
# column per step, so the number of steps needed to reach max_length is known
# up front (an empty range if x is already long enough).
with torch.no_grad():  # inference only: skip autograd bookkeeping
    for _ in range(max_length - x.size(1)):
        # Forward the whole sequence and keep only the last position's logits.
        logits = model(x).logits[:, -1, :]  # (B, vocab_size)
        probs = logits.softmax(dim=-1)
        # Keep the 50 most probable tokens per row.
        topk_probs, topk_indices = probs.topk(50, dim=-1)  # each (B, 50)
        # multinomial treats each row as unnormalized weights, so the
        # truncated probabilities do not need to be renormalized first.
        ix = torch.multinomial(topk_probs, num_samples=1)  # (B, 1), positions within the top-k
        # Translate top-k positions back into vocabulary token ids.
        xcol = topk_indices.gather(1, ix)  # (B, 1)
        x = torch.cat([x, xcol], dim=1)  # (B, T + 1)

# Report every sampled sequence: the decoded text first, then the raw token
# ids on the following line.
for i in range(num_return_sequences):
    tokens = x[i][:max_length].tolist()  # row i, capped at max_length ids
    decoded = tokenizer.decode(tokens)
    print(f"Generated Sequence {i + 1}: {decoded}", f"{tokens}", sep="\n")



Generated Sequence 1: What is the full name of USA?

What is your background?

What is your job (doors, elevators, elevator)?
[2061, 318, 262, 1336, 1438, 286, 4916, 30, 198, 198, 2061, 318, 534, 4469, 30, 198, 198, 2061, 318, 534, 1693, 357, 4598, 669, 11, 7662, 2024, 11, 20932, 19427]
Generated Sequence 2: What is the full name of USA? USA = USA

If we didn't already know our name, just what country was this name, or
[2061, 318, 262, 1336, 1438, 286, 4916, 30, 4916, 796, 4916, 198, 198, 1532, 356, 1422, 470, 1541, 760, 674, 1438, 11, 655, 644, 1499, 373, 428, 1438, 11, 393]
Generated Sequence 3: What is the full name of USA? USA is US/Toronto - (9) 887-2220, USA is Toronto - (9)
[2061, 318, 262, 1336, 1438, 286, 4916, 30, 4916, 318, 1294, 14, 31359, 532, 357, 24, 8, 807, 5774, 12, 1828, 1238, 11, 4916, 318, 6586, 532, 357, 24, 8]
Generated Sequence 4: What is the full name of USA?

USA is the name given to all Canadian citizens living in Canada. It makes sense as the nation's
[2061, 