Decoding Strategies - Temperature Scaling and Top-k Sampling

<h3>Temperature Scaling</h3>

- Previously, we always selected the token with the highest probability as the next token using `torch.argmax` (greedy decoding)
- To add variety, we can instead sample the next token from the probability distribution using `torch.multinomial(probs, num_samples=1)`
- Here, each index's chance of being picked corresponds to its probability in the input tensor
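To make the difference concrete, here is a plain-Python sketch (no PyTorch, so it runs anywhere) that mirrors `torch.argmax` and `torch.multinomial(probs, num_samples=1)` on a hypothetical five-token vocabulary with made-up probabilities. Greedy selection always returns the same index, while sampling returns each index roughly in proportion to its probability.

```python
import random
from collections import Counter

# Hypothetical next-token probabilities for a toy 5-token vocabulary (made-up values)
vocab = ["closer", "every", "forward", "inches", "toward"]
probs = [0.10, 0.05, 0.60, 0.05, 0.20]

def greedy_pick(probs):
    # Plain-Python equivalent of torch.argmax: index of the largest probability
    return max(range(len(probs)), key=lambda i: probs[i])

def sample_pick(probs, rng):
    # Plain-Python equivalent of torch.multinomial(probs, num_samples=1) for one draw
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(123)
counts = Counter(vocab[sample_pick(probs, rng)] for _ in range(1_000))

print("greedy :", vocab[greedy_pick(probs)])
print("sampled:", dict(counts))
```

Running the greedy pick repeatedly always yields `forward`; the sampled counts favor `forward` but also include the less likely tokens.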

We can control the distribution and selection process via a concept called temperature scaling
- "Temperature scaling" is just a fancy word for dividing the logits by a number greater than 0
- Temperatures greater than 1 will result in more uniformly distributed token probabilities after applying the softmax
- Temperatures smaller than 1 will result in more confident (sharper or more peaky) distributions after applying the softmax
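For reference, this is the standard softmax with temperature, where $z_i$ is the logit of token $i$, $T > 0$ is the temperature, and $V$ is the vocabulary size:

$$
p_i = \frac{\exp(z_i / T)}{\sum_{j=1}^{V} \exp(z_j / T)}
$$

With $T = 1$ this is the plain softmax. As $T \to 0^+$, all probability mass moves to the largest logit (assuming it is unique), which is why very low temperatures behave like `argmax`; as $T \to \infty$, every $z_j / T \to 0$, so each $p_i \to 1/V$.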

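<pre>import torch

def softmax_with_temperature(logits, temperature):
    # divide the logits by the temperature (must be > 0) before applying the softmax
    scaled_logits = logits / temperature
    return torch.softmax(scaled_logits, dim=0)  # dim=0: logits is a 1-D tensor over the vocabulary
</pre>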

A low temperature such as 0.1 produces a sharp distribution that behaves almost like `argmax`, while a high temperature produces a near-uniform distribution. Higher temperatures increase diversity, but they also increase the probability of nonsensical outputs; to reduce that risk, we can combine temperature scaling with top-k sampling.
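A plain-Python version of `softmax_with_temperature` (using `math` instead of PyTorch, with made-up logits for a toy five-token vocabulary) shows how the largest probability changes with the temperature. It is a sketch for illustration, not a replacement for the torch version above.

```python
import math

def softmax_with_temperature(logits, temperature):
    # Plain-Python equivalent of the torch version; temperature must be > 0
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability (softmax result is unchanged)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.5, 2.0, 1.0, 0.5, -1.0]  # made-up logits

for t in (0.1, 1.0, 5.0, 100.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t:>5}: max prob = {max(probs):.3f}")
```

At `T=0.1` nearly all the mass sits on the first token (argmax-like); as `T` grows, the largest probability shrinks toward `1/5 = 0.2`, the uniform value for five tokens.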


<h3>Top-k Sampling</h3>

To use higher temperatures for more diverse output while reducing the probability of nonsensical sentences, we can restrict sampling to the k most likely tokens (the logits of all other tokens are set to `-inf`, so they receive zero probability after the softmax):

![topk](../images/topk.png)
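The masking step can be sketched in plain Python (no PyTorch; made-up logits). `top_k_probs` is a hypothetical helper: it keeps the logits that are at least as large as the k-th largest one, replaces the rest with `-inf`, and then applies temperature scaling and the softmax, in the same order as the `generate` function below. As with a `<` comparison against the k-th value, ties with that value are kept.

```python
import math

def top_k_probs(logits, k, temperature=1.0):
    # Keep only logits >= the k-th largest; everything else becomes -inf
    kth_largest = sorted(logits, reverse=True)[k - 1]
    masked = [z if z >= kth_largest else float("-inf") for z in logits]
    # Temperature scaling: -inf / temperature stays -inf
    scaled = [z / temperature for z in masked]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.5, 2.0, 1.0, 0.5, -1.0]  # made-up logits
print(top_k_probs(logits, k=3, temperature=1.4))
```

Only the first three entries are non-zero, so even a high temperature can only redistribute probability among the three most likely tokens.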




Creating a new version of `generate_text_simple`, called `generate`, that adds these two decoding strategies

```python
import torch

def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):

    for _ in range(max_new_tokens):
        # limit the sequence to the model's supported context size
        idx_context_capped = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_context_capped)
        # keep only the logits at the last position; they score the next token
        logits = logits[:, -1, :]

        # top-k sampling
        if top_k is not None:
            # keep only the top-k logits; [:, -1:] keeps shape (batch, 1) so the comparison is per row
            top_logits, _ = torch.topk(logits, top_k)
            min_value = top_logits[:, -1:]
            # set all other logits to -inf so they get zero probability after the softmax
            logits = torch.where(logits < min_value, torch.tensor(float("-inf"), device=logits.device), logits)

        # temperature scaling
        if temperature > 0.0:
            logits = logits / temperature
            # subtract the row-wise max for numerical stability (does not change the softmax result)
            logits = logits - logits.max(dim=-1, keepdim=True).values
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # sample from the distribution
        else:
            # greedy decoding: pick the most likely token
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        # stop early if eos_id is specified and every sequence in the batch produced it
        if eos_id is not None and (idx_next == eos_id).all():
            break

        # append the new token to the existing sequence
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
```
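The order of operations inside the loop matters: top-k masking runs first, then temperature scaling, and `temperature=0.0` falls back to greedy `argmax`. Because dividing by a positive temperature leaves `-inf` at `-inf`, the top-k restriction survives the scaling. The sketch below restates one loop iteration for a single sequence in plain Python (no PyTorch); `choose_next_token` is a hypothetical helper for illustration, not part of the code above.

```python
import math
import random

def choose_next_token(logits, temperature=0.0, top_k=None, rng=random):
    """Hypothetical single-sequence restatement of one step of `generate`."""
    # 1) top-k: mask everything below the k-th largest logit with -inf
    if top_k is not None:
        kth_largest = sorted(logits, reverse=True)[top_k - 1]
        logits = [z if z >= kth_largest else float("-inf") for z in logits]
    # 2) greedy decoding when temperature is 0 (the default)
    if temperature <= 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    # 3) temperature scaling + softmax weights, then draw one index (like torch.multinomial)
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    weights = [math.exp(z - m) for z in scaled]  # masked entries get weight 0.0
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

logits = [4.5, 2.0, 1.0, 0.5, -1.0]  # made-up logits
rng = random.Random(123)
print(choose_next_token(logits))  # greedy: index 0
print(choose_next_token(logits, temperature=1.4, top_k=3, rng=rng))  # sampled from indices 0-2
```

A useful consequence of this ordering: with `top_k=1`, the result is the argmax no matter how high the temperature is.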



```python
torch.manual_seed(123)
from modules import GPTModel, GPT_CONFIG_124M, text_to_token_ids, token_ids_to_text
import tiktoken
tokenizer = tiktoken.get_encoding("gpt2")
inference_device = "cpu"

# freshly initialized GPTModel (no trained weights loaded), so the generated text is random
model = GPTModel(GPT_CONFIG_124M)

token_ids = generate(
    model=model,
    idx=text_to_token_ids("Every effort moves you", tokenizer).to(inference_device),
    max_new_tokens=15,
    context_size=GPT_CONFIG_124M["context_length"],
    top_k=25,
    temperature=1.4
)

print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
```

```
Output text:
 Every effort moves you Samoa Cubsdebug Saga reimb ShannonBow eight disable elemental Added scrapsassium Newsashes
```


**Saving and loading the model and optimizer state**

```python
model = GPTModel(GPT_CONFIG_124M)

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    # Use PyTorch 2.9 or newer for stable mps results; otherwise fall back to CPU
    major, minor = map(int, torch.__version__.split(".")[:2])
    device = torch.device("mps") if (major, minor) >= (2, 9) else torch.device("cpu")
else:
    device = torch.device("cpu")

model.to(device)
print("Device:", device)
```

```
Device: mps
```
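The device-selection branching can be checked without a GPU by pulling it into a pure function. `pick_device` below is a hypothetical helper that takes the availability flags and the version string as inputs instead of querying `torch`, and falls back to CPU in every case not covered by the CUDA and mps branches. Comparing `(major, minor)` as integer tuples matters: as strings, `"2.10"` would sort before `"2.9"`.

```python
def pick_device(cuda_available: bool, mps_available: bool, torch_version: str) -> str:
    """Hypothetical, torch-free restatement of the device-selection logic."""
    if cuda_available:
        return "cuda"
    if mps_available:
        # Parse major.minor from strings such as "2.9.0" or "2.9.0+cpu"
        major, minor = map(int, torch_version.split(".")[:2])
        if (major, minor) >= (2, 9):
            return "mps"
    # Fall back to CPU, including mps machines with PyTorch older than 2.9
    return "cpu"

print(pick_device(False, True, "2.10.0"))  # mps
print(pick_device(False, True, "2.8.0"))   # cpu
```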


```python
# `optimizer` must already exist here (e.g. the AdamW optimizer from the training loop)
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "model_and_optimizer.pth",
)
```

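```python
# map_location places the loaded tensors on the selected device
checkpoint = torch.load("model_and_optimizer.pth", map_location=device, weights_only=True)

model = GPTModel(GPT_CONFIG_124M)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)

# recreate the optimizer with the same hyperparameters, then restore its state (e.g. AdamW moment estimates)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.train();  # switch back to training mode to continue training
```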