# Chapter 5: Text Generation

### 1. The Challenge with Generating Coherent Text

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
device

'cuda'

In [4]:
model_name = "gpt2-xl"

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [6]:
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [8]:
input_txt = "persistence is all you need"

In [9]:
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)

In [10]:
iterations = []

In [17]:
n_steps = 8
choices_per_step = 5

In [18]:
input_ids

tensor([[19276, 13274,   318,   477,   345,   761]], device='cuda:0')

In [19]:
input_ids[0]

tensor([19276, 13274,   318,   477,   345,   761], device='cuda:0')

In [13]:
# Step-by-step greedy decoding: at every step record the text generated so far and
# the `choices_per_step` most probable next tokens, then append the top-1 token.
# The loop works on a copy so that `input_ids` keeps holding only the prompt, which
# the inspection cells below rely on.
iterations = []  # reset so re-running this cell does not accumulate duplicates
generated_ids = input_ids.clone()
with torch.no_grad():  # inference only; no autograd graph is needed
    for _ in range(n_steps):
        iteration = dict()
        iteration["input"] = tokenizer.decode(generated_ids[0])
        output = model(input_ids=generated_ids)

        # Logits of the first (only) batch item at the last position, i.e. the
        # scores for the token that would follow the sequence so far.
        next_token_logits = output.logits[0, -1, :]
        next_token_probs = torch.softmax(next_token_logits, dim=-1)
        sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)

        # Keep the most probable candidates together with their probabilities.
        for choice_idx in range(choices_per_step):
            token_id = sorted_ids[choice_idx]
            token_prob = next_token_probs[token_id].item()
            iteration[f"choice {choice_idx + 1}"] = (
                f"{tokenizer.decode(token_id)} ({100 * token_prob:.2f}%)"
            )

        # Greedy step: extend the sequence with the single most probable token,
        # reshaped to (1, 1) so it can be concatenated along the sequence axis.
        best_token = sorted_ids[:1].unsqueeze(0)
        generated_ids = torch.cat([generated_ids, best_token], dim=-1)
        iterations.append(iteration)

In [20]:
tokenizer.decode(input_ids[0])

'persistence is all you need'

In [59]:
len(input_ids[0])

6

In [21]:
output = model(input_ids=input_ids)

In [24]:
output.keys()

odict_keys(['logits', 'past_key_values'])

In [29]:
output.logits[0]

tensor([[ 1.8732,  3.6249,  0.2828,  ..., -6.4631, -2.5632,  1.3299],
        [ 2.4064,  6.1909,  1.8162,  ..., -5.2652, -4.9478,  2.1214],
        [ 0.5427,  1.6651, -2.6166,  ..., -7.3001, -4.4530, -0.3493],
        [ 3.0592,  1.4659, -2.7849,  ..., -6.8109, -5.2147, -0.3743],
        [ 3.4758,  2.9291, -0.9446,  ..., -5.3347, -2.7920,  0.8841],
        [ 7.5406,  5.7096,  0.6126,  ..., -7.7698, -7.9736,  4.4805]],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [55]:
len(output.logits[0])

6

In [33]:
output.logits[0][-1][:]

tensor([ 7.5406,  5.7096,  0.6126,  ..., -7.7698, -7.9736,  4.4805],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [34]:
output.logits[0, -1, :]

tensor([ 7.5406,  5.7096,  0.6126,  ..., -7.7698, -7.9736,  4.4805],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [35]:
next_token_logits = output.logits[0, -1, :]

In [38]:
torch.softmax(next_token_logits, dim=-1)

tensor([2.1147e-02, 3.3888e-03, 2.0722e-05,  ..., 4.7428e-09, 3.8681e-09,
        9.9148e-04], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [40]:
input_txt2 = "persistence is all you need tores the five most probable tokens"

In [49]:
len(input_txt2.split())

11

In [42]:
input_ids2 = tokenizer(input_txt2, return_tensors="pt")["input_ids"].to(device)

In [47]:
len(input_ids2[0])

13

In [50]:
output2 = model(input_ids=input_ids2)

In [54]:
len(output2.logits[0])

13

# Flashcards

### 2. Greedy Search Decoding

##### Example 1

In [7]:
text = "The world is going to"

In [8]:
input_ids = tokenizer(text, return_tensors="pt")["input_ids"].to(device)

In [9]:
input_ids

tensor([[ 464,  995,  318, 1016,  284]], device='cuda:0')

In [10]:
type(model)

transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel

Given a tokenized text `input_ids`, run the model to predict the next token in the sequence and obtain the ids of the most likely candidate tokens

**Hint**: `next_token_logits = output.logits[0, -1, :]`

In [11]:
output = model(input_ids=input_ids)

In [40]:
output.logits.shape

torch.Size([1, 5, 50257])

In [41]:
output.logits

tensor([[[-1.2333, -0.1872, -3.2762,  ..., -2.9148, -5.0910, -1.0786],
         [ 2.4011,  3.5135, -1.9742,  ..., -7.9800, -1.3352,  0.3985],
         [ 0.3600,  0.7970, -3.2943,  ..., -5.8384, -3.4494, -0.4427],
         [ 1.2610,  2.0624, -2.2482,  ..., -6.9114, -8.1211, -1.2510],
         [ 0.2554,  0.5884, -2.6440,  ..., -6.0750, -7.4630, -0.4967]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

Given `output` is the output of a pre-trained `gpt2-xl` model generated using `transformers`

Extract the logit of the 5th token in the vocabulary (index `4`) at the last position of the sequence, and explain

In [42]:
output.logits[0, -1, :].shape

torch.Size([50257])

In [43]:
logit = output.logits[0, -1, :][4]

**Explain**

`output.logits[0, -1, :]`
- Returns the logits for all tokens in the vocabulary for the last timestep of the sequence
- This is because `output.logits` has a shape of `[batch_size, sequence_length, vocabulary_size]`, and `[0, -1, :]` indexes the logits for the first item in the batch and the last timestep of the sequence

`output.logits[0, -1, :][4]`
- Returns the logit for the 5th token in the vocabulary
- This is because `output.logits[0, -1, :]` is a 1-dimensional tensor with a size equal to the `vocab_size`, and `[4]` indexes the 5th element in that tensor.

In [44]:
logit

tensor(-2.4703, device='cuda:0', grad_fn=<SelectBackward0>)

##### Example 2 

In [None]:
text = "The world is going to"

In [None]:
input_ids = tokenizer(text, return_tensors="pt")["input_ids"].to(device)

In [9]:
input_ids

tensor([[ 464,  995,  318, 1016,  284]], device='cuda:0')

In [10]:
type(model)

transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel

Given a tokenized text `input_ids`, run the model to predict the next token in the sequence and obtain the ids of the most likely candidate tokens

**Hint**: `next_token_logits = output.logits[0, -1, :]`

In [11]:
output = model(input_ids=input_ids)

Retrieve the `ids` of the 5 tokens with the highest next-token probability from `output`, which is the output of a pre-trained `gpt2-xl` model using `transformers`

In [46]:
output.logits.shape

torch.Size([1, 5, 50257])

In [48]:
output.logits

tensor([[[-1.2333, -0.1872, -3.2762,  ..., -2.9148, -5.0910, -1.0786],
         [ 2.4011,  3.5135, -1.9742,  ..., -7.9800, -1.3352,  0.3985],
         [ 0.3600,  0.7970, -3.2943,  ..., -5.8384, -3.4494, -0.4427],
         [ 1.2610,  2.0624, -2.2482,  ..., -6.9114, -8.1211, -1.2510],
         [ 0.2554,  0.5884, -2.6440,  ..., -6.0750, -7.4630, -0.4967]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [49]:
next_token_logits = output.logits[0, -1, :]

In [49]:
next_token_logits

tensor([ 0.2554,  0.5884, -2.6440,  ..., -6.0750, -7.4630, -0.4967],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [51]:
len(next_token_logits)

50257

In [52]:
next_token_probs = torch.softmax(next_token_logits, dim=-1)

In [53]:
sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)

In [54]:
sorted_ids[:5]

tensor([5968,  307,  886, 1487,  423], device='cuda:0')

##### Example 3

In [42]:
input_ids

tensor([[ 464,  995,  318, 1016,  284]], device='cuda:0')

In [43]:
type(model), type(tokenizer)

(transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel,
 transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast)

Predict the next `20` tokens in the sequence given the tokenized text `input_ids` of sequence "The world is going to" using the `model`

**Hint** Convert `output` to text using `tokenizer`

In [44]:
max_new_tokens = 20

In [45]:
# Pass an explicit attention mask (all ones: a single, unpadded prompt) and a pad
# token id so that `generate` does not have to guess them. Using the end-of-sequence
# token for padding mirrors the fallback `generate` would otherwise apply on its own.
output = model.generate(
    input_ids,
    attention_mask=torch.ones_like(input_ids),
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=max_new_tokens,
)


In [46]:
output

tensor([[  464,   995,   318,  1016,   284,  5968,   287,   257,  1021,    65,
         11715,    13,   198,   198,   464,   995,   318,  1016,   284,  5968,
           287,   257,  1021,    65, 11715]], device='cuda:0')

In [47]:
decoded_text = tokenizer.decode(output[0])

In [48]:
decoded_text

'The world is going to hell in a handbasket.\n\nThe world is going to hell in a handbasket'

### 3. Beam Search Decoding

##### Example 1

In [137]:
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

In [162]:
indices = torch.tensor([[0], [1], [2]])

In [167]:
x

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

In [168]:
indices

tensor([[0],
        [1],
        [2]])

Write a line of code using `torch.gather` that selects elements from the second dimension of tensor `x` based on the values in tensor `indices`

In [169]:
selected_elements = torch.gather(x, dim=1, index=indices)

**Explain**

- `torch.gather` is a function that allows you to select specific elements from a tensor along a given dimension

- `dim=1`: means that we are indexing along the second dimension of `x`. In this case, `x` has 2 dimensions and the size of each dimension is 3. The first dimension corresponds to the number of rows in the tensor and the second dimension corresponds to the number of columns

- The `index` parameter specifies, for every output position, which element to pick along `dim`. With `dim=1` the result is `out[i][j] = x[i][indices[i][j]]`. Here `indices` is `[[0], [1], [2]]`, so `torch.gather` picks column `0` from row `0`, column `1` from row `1`, and column `2` from row `2` of `x`

In [170]:
selected_elements

tensor([[1],
        [5],
        [9]])

##### Example 2

In [56]:
import torch.nn.functional as F

In [65]:
output.logits

tensor([[[-1.2333, -0.1872, -3.2762,  ..., -2.9148, -5.0910, -1.0786],
         [ 2.4011,  3.5135, -1.9742,  ..., -7.9800, -1.3352,  0.3985],
         [ 0.3600,  0.7970, -3.2943,  ..., -5.8384, -3.4494, -0.4427],
         [ 1.2610,  2.0624, -2.2482,  ..., -6.9114, -8.1211, -1.2510],
         [ 0.2554,  0.5884, -2.6440,  ..., -6.0750, -7.4630, -0.4967]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

Explain
- `[1]`: First applying the softmax function to the logits tensor, which converts the logits into probabilities that sum to 1. The softmax is taken along the last dimension of the logits tensor, which is specified by the dim argument
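- `[2]`: Then taking the element-wise natural logarithm of those probabilities, which turns them into log-probabilities (values `<= 0`). Log-probabilities can be summed over a sequence instead of multiplying many small probabilities, which avoids numerical underflow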

In [73]:
def log_probs_from_logits(logits, labels=None):
    """Convert logits to log-probabilities, optionally selecting one per position.

    Args:
        logits: Tensor of unnormalized scores whose last dimension is the
            vocabulary, e.g. ``[batch_size, sequence_length, vocab_size]``.
        labels: Optional integer tensor of token ids with the shape of
            ``logits`` minus its last dimension. When given, only the
            log-probability of the labelled token is returned per position.

    Returns:
        If ``labels`` is ``None``: log-probabilities with the same shape as
        ``logits``. Otherwise: a tensor shaped like ``labels`` holding the
        log-probability assigned to each label.
    """
    # [1] + [2]: softmax followed by log, fused into a single call. The fused
    # form is numerically stable, whereas softmax(...).log() underflows to -inf
    # for very unlikely tokens.
    log_probs = F.log_softmax(logits, dim=-1)
    if labels is None:
        return log_probs

    # gather needs an index with the same rank as its input, so add a trailing
    # axis to the labels, pick along the vocabulary axis, then drop it again.
    indices = labels.unsqueeze(-1)
    log_prob_labels = torch.gather(log_probs, dim=-1, index=indices).squeeze(-1)
    return log_prob_labels

In [74]:
log_probs_from_logits(output.logits)

tensor([[[-11.1238, -10.0776, -13.1667,  ..., -12.8053, -14.9814, -10.9691],
         [ -8.3411,  -7.2288, -12.7165,  ..., -18.7223, -12.0775, -10.3438],
         [-11.1479, -10.7109, -14.8022,  ..., -17.3463, -14.9573, -11.9506],
         [-11.4963, -10.6949, -15.0055,  ..., -19.6687, -20.8785, -14.0083],
         [-11.5677, -11.2347, -14.4671,  ..., -17.8981, -19.2861, -12.3198]]],
       device='cuda:0', grad_fn=<LogBackward0>)