# 5.2 Training an LLM

In this section, we finally implement the code for pretraining the LLM, our GPTModel. To this end, we focus on a simple training loop, as shown in Figure 5.11, to keep the code concise and easy to read. However, interested readers can learn about more advanced techniques, including learning rate warmup, cosine annealing, and gradient clipping, in Appendix D, Adding Bells and Whistles to the Training Loop.

Figure 5.11 A typical training loop for training a deep neural network in PyTorch consists of several steps, iterating over batches in the training set for multiple epochs. In each loop, we compute the loss for each training set batch to determine the loss gradient, which we use to update the model weights in order to minimize the training set loss.

![image-20240422143154243](../img/fig-5-11.png)

The flowchart in Figure 5.11 depicts a typical PyTorch neural network training workflow, which we use to train the LLM. It outlines eight steps: iterating over each epoch, processing the batches, resetting and computing the gradients, updating the weights, and finally optional monitoring steps such as printing the losses and generating text samples. If you are relatively new to training deep neural networks with PyTorch and are unfamiliar with any of these steps, consider reading sections A.5 to A.8 in Appendix A, Introduction to PyTorch.

In code, we can implement this training flow with the following train_model_simple function:

**Listing 5.3 The main function for pretraining LLMs**

In [None]:
def train_model_simple(model, train_loader, val_loader, optimizer, device,
                       num_epochs, eval_freq, eval_iter, start_context):
    # Lists for tracking the losses and the number of tokens seen
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1

    for epoch in range(num_epochs):  # Main training loop
        model.train()
        for input_batch, target_batch in train_loader:
            optimizer.zero_grad()  # Reset loss gradients from the previous batch
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Compute the loss gradients
            optimizer.step()  # Update the model weights using the gradients
            tokens_seen += input_batch.numel()
            global_step += 1

            if global_step % eval_freq == 0:  # Optional evaluation step
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(f"Ep {epoch+1} (Step {global_step:06d}): "
                      f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")

        # Print a sample text after each epoch
        generate_and_print_sample(
            model, train_loader.dataset.tokenizer, device, start_context
        )
    return train_losses, val_losses, track_tokens_seen

Note that the train_model_simple function we just created uses two functions that have not yet been defined: evaluate_model and generate_and_print_sample.

The evaluate_model function corresponds to step 7 in Figure 5.11. The training loop calls it every eval_freq steps and prints the resulting training and validation set losses, so that we can evaluate whether training is improving the model.

More specifically, the evaluate_model function computes the losses for the training and validation sets while the model is in evaluation mode, with gradient tracking and dropout disabled:

In [None]:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
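    model.eval()  # Disable dropout for stable, reproducible loss estimates
    with torch.no_grad():  # Disable gradient tracking, which evaluation does not need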
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss

Similar to evaluate_model, the generate_and_print_sample function is a convenience function that we use to track whether the model is improving during training. Specifically, it takes a text snippet (start_context) as input, converts it into token IDs, and feeds them to the LLM to generate a text sample using the generate_text_simple function we used earlier:

In [None]:
def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(
            model=model, idx=encoded,
            max_new_tokens=50, context_size=context_size
        )
        decoded_text = token_ids_to_text(token_ids, tokenizer)
        print(decoded_text.replace("\n", " ")) # Compact print format
    model.train()

While the evaluate_model function gives us a numerical estimate of the model's training progress, the generate_and_print_sample function provides concrete examples of text generated by the model, which lets us judge its capabilities during training.

AdamW

The Adam optimizer is a popular choice for training deep neural networks. However, in our training loop, we use the AdamW optimizer. AdamW is a variant of Adam that improves the weight decay approach, which aims to minimize model complexity and prevent overfitting by penalizing larger weights. This adjustment allows AdamW to achieve more effective regularization and better generalization, which is why it is frequently used for training LLMs.
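
To make the difference in weight decay handling concrete, the following is a minimal scalar sketch, not PyTorch's actual implementation. The function name adam_step and its arguments are hypothetical. The betas and eps are assumed to match PyTorch's defaults, and the learning rate and weight decay are the values used in this section's training run. With a zero loss gradient, the decoupled (AdamW-style) decay shrinks the weight by only lr * weight_decay, whereas folding the decay into the gradient (Adam with an L2 penalty) passes it through Adam's normalization and yields a step of roughly lr:

```python
import math


def adam_step(w, grad, m, v, t, lr=0.0004, beta1=0.9, beta2=0.999,
              eps=1e-8, weight_decay=0.1, decoupled=True):
    """One scalar Adam update (hypothetical helper for illustration only)."""
    if decoupled:
        # AdamW style: shrink the weight directly, independent of the gradient
        w = w * (1.0 - lr * weight_decay)
    else:
        # Adam with L2 penalty: the decay term is added to the gradient
        grad = grad + weight_decay * w
    m = beta1 * m + (1 - beta1) * grad          # First-moment estimate
    v = beta2 * v + (1 - beta2) * grad * grad   # Second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # Bias correction
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


# One step from w = 1.0 with a zero loss gradient
w_adamw, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, decoupled=True)
w_adam_l2, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, decoupled=False)
print(f"AdamW (decoupled decay): {w_adamw:.6f}")
print(f"Adam + L2 penalty:       {w_adam_l2:.6f}")
```

In this toy setting, the decoupled variant applies a decay proportional to the weight itself, while the L2 variant's decay is rescaled by the adaptive step size.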

Let’s see all this in action by training a GPTModel instance for 10 epochs using the AdamW optimizer and the train_model_simple function we defined earlier.

In [None]:
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.to(device)
# .parameters() returns all trainable weight parameters of the model
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)
num_epochs = 10
train_losses, val_losses, tokens_seen = train_model_simple(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=5, eval_iter=1,
    start_context="Every effort moves you"
)

Executing the train_model_simple function starts the training process, which should take about 5 minutes to complete on a MacBook Air or a similar laptop. The output printed during this execution is as follows:

In [None]:
Ep 1 (Step 000000): Train loss 9.781, Val loss 9.933
Ep 1 (Step 000005): Train loss 8.111, Val loss 8.339
Every effort moves you,,,,,,,,,,,,.
Ep 2 (Step 000010): Train loss 6.661, Val loss 7.048
Ep 2 (Step 000015): Train loss 5.961, Val loss 6.616
Every effort moves you, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and,, and, and,
[...] (intermediate results omitted to save space)
Ep 9 (Step 000080): Train loss 0.541, Val loss 6.393
Every effort moves you?" "Yes--quite insensible to the irony. She wanted him vindicated--and by me!" He laughed again, and threw back the window-curtains, I had the donkey. "There were days when I
Ep 10 (Step 000085): Train loss 0.391, Val loss 6.452
Every effort moves you know," was one of the axioms he laid down

As we can see from the results printed during training, the training loss improves drastically, starting at 9.781 and converging to 0.391. The language skills of the model have improved considerably. In the beginning, the model can only append commas to the start context ("Every effort moves you,,,,,,,,,,,,") or repeat the word "and". At the end of training, it can generate grammatically correct text.
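
The step and epoch numbers in the printed log follow directly from the eval_freq check in the training loop. The sketch below replays only that counting logic, without any model. The helper name evaluation_schedule is hypothetical, and the value of 9 batches per epoch is an assumption inferred from the step numbers in the log rather than stated in this section:

```python
def evaluation_schedule(num_epochs, batches_per_epoch, eval_freq):
    """Return the (epoch, global_step) pairs at which the loop evaluates."""
    schedule = []
    global_step = -1  # Same starting value as in train_model_simple
    for epoch in range(num_epochs):
        for _ in range(batches_per_epoch):
            global_step += 1
            if global_step % eval_freq == 0:
                schedule.append((epoch + 1, global_step))
    return schedule


# Assumption: 9 batches per epoch, inferred from the printed step numbers
schedule = evaluation_schedule(num_epochs=10, batches_per_epoch=9, eval_freq=5)
for epoch, step in schedule[:4]:
    print(f"Ep {epoch} (Step {step:06d})")
print(f"{len(schedule)} evaluation points, last one at {schedule[-1]}")
```

Under this assumption, the schedule reproduces the epoch and step labels of the log lines shown above, including the final entry at epoch 10, step 85.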

Similar to the training set loss, the validation loss starts high (9.933) and decreases during training. However, it never becomes as small as the training set loss and remains at 6.452 after the 10th epoch.
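
One simple way to quantify this difference is the gap between validation and training loss at each logged step. The following sketch uses only the six log lines printed above (the omitted intermediate lines are not reconstructed), and the helper name loss_gaps is hypothetical:

```python
# (global_step, train_loss, val_loss) copied from the printed log above
logged = [
    (0, 9.781, 9.933),
    (5, 8.111, 8.339),
    (10, 6.661, 7.048),
    (15, 5.961, 6.616),
    (80, 0.541, 6.393),
    (85, 0.391, 6.452),
]


def loss_gaps(records):
    """Return (step, val_loss - train_loss) for each logged record."""
    return [(step, round(val - train, 3)) for step, train, val in records]


gaps = loss_gaps(logged)
for step, gap in gaps:
    print(f"Step {step:06d}: val - train = {gap:.3f}")
```

A gap that keeps growing while the training loss continues to fall is a common sign that the model is fitting the training set more closely than it generalizes.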

Before discussing validation loss in more detail, let’s create a simple plot showing the training and validation loss side by side:

In [None]:
import matplotlib.pyplot as plt


def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
    fig, ax1 = plt.subplots(figsize=(5, 3))
    # Plot training and validation loss against epochs
    ax1.plot(epochs_seen, train_losses, label="Training loss")
    ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss")
    ax1.set_xlabel("Epochs")
    ax1.set_ylabel("Loss")
    ax1.legend(loc="upper right")
    # Create a second x-axis that shares the same y-axis
    ax2 = ax1.twiny()
    # Invisible plot that aligns the ticks of the second axis
    ax2.plot(tokens_seen, train_losses, alpha=0)
    ax2.set_xlabel("Tokens seen")
    fig.tight_layout()
    plt.show()


epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))
plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)

The resulting training and validation loss plots are shown in Figure 5.12.

Figure 5.12 At the beginning of training, we observe that both the training set and validation set losses decrease dramatically, indicating that the model is learning. However, the training set loss continues to decrease after the second epoch, while the validation loss stagnates. This indicates that the model is still learning, but it is overfitting to the training set after epoch 2.

![image-20240422144030197](../img/fig-5-12.png)

As shown in Figure 5.12, both the training and validation losses start to improve during the first epoch. However, the losses start to diverge past the second epoch. This divergence, along with the fact that the validation loss is much larger than the training loss, indicates that the model is overfitting to the training data. We can confirm that the model memorizes the training data verbatim by searching for generated text snippets, such as "quite insensible to the irony", in the "The Verdict" text file.
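
The verbatim check described above can be sketched as a whitespace-insensitive substring search. The helper names are hypothetical, and the corpus string below is only a stand-in copied from the generated sample in the training log with a line break added. In practice, you would read the contents of the training text file instead:

```python
import re


def normalize_whitespace(text):
    """Collapse runs of whitespace so that line breaks do not hide a match."""
    return re.sub(r"\s+", " ", text).strip()


def appears_verbatim(snippet, corpus):
    """Return True if the snippet occurs verbatim in the corpus."""
    return normalize_whitespace(snippet) in normalize_whitespace(corpus)


# Stand-in for the training text (NOT the actual file contents)
stand_in_corpus = (
    '"Yes--quite insensible\nto the irony. '
    'She wanted him vindicated--and by me!"'
)

found = appears_verbatim("quite insensible to the irony", stand_in_corpus)
missing = appears_verbatim("a phrase that is not in the text", stand_in_corpus)
print(found, missing)
```

A match on a long generated phrase is strong evidence of memorization, whereas short phrases can also match by chance.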

This memorization is expected since we are working with a very small training dataset and training the model for multiple epochs. Usually, it is common to train a model on a much larger dataset for only one epoch.

As mentioned earlier, the interested reader can try training the model on Project Gutenberg’s 60,000 public domain books, where this kind of overfitting does not occur; see Appendix B for details.

In the next section, as shown in Figure 5.13, we will explore sampling methods that LLMs use to mitigate memorization effects and thus generate more novel text.

Figure 5.13 Our model can generate coherent text after implementing the training function. However, it often memorizes paragraphs from the training set verbatim. The following sections describe strategies for generating more diverse output text.

![image-20240422144152449](../img/fig-5-13.png)

As shown in Figure 5.13, the next section covers text generation strategies that reduce training data memorization and increase the originality of LLM-generated text. After that, we cover saving and loading model weights, as well as loading the pretrained weights of the GPT model from OpenAI.