# Lesson 5. Model training

Pretraining is very expensive! Please check costs carefully before starting a pretraining project.

You can get a rough estimate of your training job's cost using [this calculator](https://huggingface.co/training-cluster) from Hugging Face. For training on other infrastructure, e.g. AWS or Google Cloud, please consult those providers for up-to-date cost estimates.
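For a first-order sanity check before opening the calculator, a common rule of thumb puts training compute at roughly 6 FLOPs per parameter per training token (forward plus backward pass). The sketch below turns that rule into GPU-hours and dollars. The model size, token budget, peak throughput, utilization, and hourly price are all placeholder assumptions for illustration, not quotes from any provider or measurements of any particular GPU.

```python
def estimate_training_cost(
    n_params: float,
    n_tokens: float,
    peak_flops_per_sec: float,
    utilization: float,
    usd_per_gpu_hour: float,
) -> dict:
    """Rough pretraining cost estimate using the ~6 * N * D FLOPs rule of thumb."""
    total_flops = 6 * n_params * n_tokens          # forward + backward compute
    achieved = peak_flops_per_sec * utilization    # realistic sustained throughput
    gpu_hours = total_flops / achieved / 3600      # seconds -> hours
    return {
        "total_flops": total_flops,
        "gpu_hours": gpu_hours,
        "usd": gpu_hours * usd_per_gpu_hour,
    }


# Hypothetical inputs: a ~250M-parameter model trained on 5B tokens, on a GPU
# with ~300 TFLOP/s bf16 peak at 40% utilization, priced at $2 per GPU-hour.
estimate = estimate_training_cost(
    n_params=250e6,
    n_tokens=5e9,
    peak_flops_per_sec=300e12,
    utilization=0.4,
    usd_per_gpu_hour=2.0,
)
for key, value in estimate.items():
    print(f"{key}: {value:,.2f}")
```

Swap in your own parameter count, token budget, and provider pricing. The result is an order-of-magnitude guide only; real utilization depends heavily on hardware, sequence length, and software stack.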

In [1]:
import warnings
warnings.filterwarnings('ignore')

## 1. Load the model to be trained

Load the upscaled model from the previous lesson:

In [7]:
# pip install datasets

In [1]:
import torch
from transformers import AutoModelForCausalLM
model_path = "./models"
pretrained_model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="cuda", 
    torch_dtype=torch.bfloat16,
    use_cache=False,
)

In [2]:
pretrained_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-11): 12 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=256, bias=False)
          (v_proj): Linear(in_features=1024, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (up_proj): Linear(in_features=1024, out_features=4096, bias=False)
          (down_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((1024,), eps=1e-06)
      )
    )
    (norm): LlamaRMSNorm((1024,), eps=1e-06)
    (rotary_emb): 
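The printout is cut off after `rotary_emb`, so the output head is not shown. As a sanity check on model size, the sketch below recounts parameters from the layer shapes printed above, assuming an untied `lm_head` of shape 32000 × 1024 (an assumption, since it is not visible in the output). Under that assumption, 6 × non-embedding parameters × training tokens (39,560 steps × 8 sequences × 128 tokens) matches the `total_flos` value reported by the trainer later in this lesson, which suggests the trainer's FLOP count follows the same 6 × N × D rule with embedding weights excluded.

```python
# Layer shapes read off the printed LlamaForCausalLM above.
VOCAB, HIDDEN, KV_DIM, FFN, N_LAYERS = 32000, 1024, 256, 4096, 12


def decoder_layer_params() -> int:
    """Parameters in one LlamaDecoderLayer (all projections have bias=False)."""
    attn = HIDDEN * HIDDEN + 2 * HIDDEN * KV_DIM + HIDDEN * HIDDEN  # q, k+v, o
    mlp = 3 * HIDDEN * FFN                                          # gate, up, down
    norms = 2 * HIDDEN                                              # two RMSNorm weights
    return attn + mlp + norms


embedding = VOCAB * HIDDEN                # embed_tokens
layers = N_LAYERS * decoder_layer_params()
final_norm = HIDDEN
lm_head = VOCAB * HIDDEN                  # ASSUMPTION: untied output head (not shown)

non_embedding = layers + final_norm + lm_head
total = embedding + non_embedding

# Tokens seen in this run: global_step * batch size * sequence length.
tokens = 39560 * 8 * 128
approx_flops = 6 * non_embedding * tokens

print(f"per layer:     {decoder_layer_params():,}")
print(f"total params:  {total:,}")
print(f"non-embedding: {non_embedding:,}")
print(f"6 * N * D:     {approx_flops:.6e}")
```

In the notebook itself, `sum(p.numel() for p in pretrained_model.parameters())` gives the actual count to compare against this estimate.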

## 2. Load dataset

Here you'll define a subclass of PyTorch's `Dataset` that implements the two methods the trainer relies on to read training data: `__len__` and `__getitem__`. This class will be used when you specify the dataset you created in Lesson 3 as the training data in the next section.

Note that the code has additional comment strings that don't appear in the video. These are to help you understand what each part of the code is doing.

In [3]:
import datasets
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, args, split="train"):
        """Initializes the custom dataset object."""
        self.args = args
        self.dataset = datasets.load_dataset(
            "parquet",
            data_files=args.dataset_name,
            split=split
        )

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Retrieves a single data sample from the dataset 
        at the specified index
        """
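        # Fetch the row once, then convert its token list to LongTensors for PyTorch
        token_ids = self.dataset[idx]["input_ids"]
        input_ids = torch.LongTensor(token_ids)
        # Labels are a copy of the inputs; the causal-LM model shifts them internally
        labels = torch.LongTensor(token_ids)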

        # Return the sample as a dictionary
        return {"input_ids": input_ids, "labels": labels}
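`labels` is an exact copy of `input_ids` because Hugging Face causal language models such as `LlamaForCausalLM` shift the labels internally when computing the loss: the logits at position *t* are scored against the token at position *t + 1*. The pure-Python sketch below involves no model, and its token IDs are made up. It shows the (context, target) pairs that this shift produces for one sequence.

```python
def next_token_pairs(input_ids):
    """Mimic the causal-LM label shift: position t predicts token t + 1.

    Returns (context, target) pairs. Predictions made at positions 0..n-2 are
    compared with labels at positions 1..n-1, so the last position has no target.
    """
    labels = list(input_ids)            # same as the dataset: labels == input_ids
    shifted_labels = labels[1:]         # drop the first label
    pairs = []
    for t, target in enumerate(shifted_labels):
        context = input_ids[: t + 1]    # everything the model has seen so far
        pairs.append((context, target))
    return pairs


# Hypothetical token IDs for a 5-token sequence
for context, target in next_token_pairs([1, 415, 2936, 9060, 28723]):
    print(context, "->", target)
```

By the same logic, each 128-token training sample yields 127 next-token predictions.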

## 3. Configure Training Arguments

Here you set up the training run. The training dataset you created in Lesson 3 is specified in the Dataset configuration section.

Note: there are comment strings in the cell below that don't appear in the video. These have been included to help you understand what each parameter does.

In [35]:
# from dataclasses import dataclass, field
# import transformers

# @dataclass
# class CustomArguments(transformers.TrainingArguments):
#     dataset_name: str = field(                           # Dataset configuration
#         default="./packaged_pretrain_dataset.parquet")
#     num_proc: int = field(default=1)                     # Number of subprocesses for data preprocessing
#     max_seq_length: int = field(default=128)              # Maximum sequence length

#     # Core training configurations
#     seed: int = field(default=0)                         # Random seed for initialization, ensuring reproducibility
#     optim: str = field(default="adamw_torch")            # Optimizer, here it's AdamW implemented in PyTorch
#     # max_steps: int = field(default=15000) # Number of maximum training steps
#     num_train_epochs: int = field(default=5)  # Define training in terms of epochs instead of steps
#     per_device_train_batch_size: int = field(default=8)  # Batch size per device during training

#     # Other training configurations
#     learning_rate: float = field(default=5e-5)           # Initial learning rate for the optimizer
#     weight_decay: float = field(default=0)               # Weight decay
#     warmup_steps: int = field(default=10)                # Number of steps for the learning rate warmup phase
#     lr_scheduler_type: str = field(default="linear")     # Type of learning rate scheduler
#     gradient_checkpointing: bool = field(default=True)   # Enable gradient checkpointing to save memory
#     dataloader_num_workers: int = field(default=2)       # Number of subprocesses for data loading
#     bf16: bool = field(default=True)                     # Use bfloat16 precision for training on supported hardware
#     gradient_accumulation_steps: int = field(default=1)  # Number of steps to accumulate gradients before updating model weights
    
#     # Logging configuration
#     logging_steps: int = field(default=1000)                # Frequency of logging training information
#     report_to: str = field(default="none")               # Destination for logging (e.g., WandB, TensorBoard)

#     # Saving configuration
#     save_strategy: str = field(default="epoch")  # Change save strategy to epoch if needed
#     save_total_limit: int = field(default=2)
#     # save_strategy: str = field(default="steps")          # Can be replaced with "epoch"
#     # save_steps: int = field(default=3)                   # Frequency of saving training checkpoint
#     # save_total_limit: int = field(default=2)             # The total number of checkpoints to be saved

In [4]:
from dataclasses import dataclass, field
import transformers

@dataclass
class CustomArguments(transformers.TrainingArguments):
    dataset_name: str = field(default="./packaged_pretrain_dataset.parquet")
    num_proc: int = field(default=1)
    max_seq_length: int = field(default=128)

    # Core training configurations
    seed: int = field(default=0)
    optim: str = field(default="adamw_torch")
    num_train_epochs: int = field(default=20)  # Train for 20 epochs
    per_device_train_batch_size: int = field(default=8)

    # Other training configurations
    learning_rate: float = field(default=5e-5)
    weight_decay: float = field(default=0)
    warmup_steps: int = field(default=10)
    lr_scheduler_type: str = field(default="linear")
    gradient_checkpointing: bool = field(default=True)
    dataloader_num_workers: int = field(default=2)
    bf16: bool = field(default=True)
    gradient_accumulation_steps: int = field(default=1)

    # Logging configuration
    logging_steps: int = field(default=1000)
    report_to: str = field(default="none")

    # Saving configuration
    save_strategy: str = field(default="epoch")  # Save after each epoch
    save_total_limit: int = field(default=2)
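With `warmup_steps=10` and `lr_scheduler_type="linear"`, the learning rate ramps from 0 up to `learning_rate` over the first 10 optimizer steps and then decays linearly to 0 at the final step. The function below is a plain-Python sketch of that shape using the usual linear-warmup/linear-decay formula; it is not the `transformers` scheduler object itself. The total of 39,560 steps is taken from the training run later in this lesson.

```python
def linear_warmup_decay_lr(step, base_lr=5e-5, warmup_steps=10, total_steps=39560):
    """Learning rate at a given optimizer step for linear warmup + linear decay."""
    if step < warmup_steps:
        # Warmup: scale linearly from 0 to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay: scale linearly from base_lr down to 0 at total_steps (never negative)
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


for step in (0, 5, 10, 19785, 39560):
    print(step, f"{linear_warmup_decay_lr(step):.3e}")
```

Because the decay is stretched over the whole run, changing `num_train_epochs` also changes how quickly the learning rate falls.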


Parse the custom arguments and set the output directory where the model will be saved: 

In [5]:
parser = transformers.HfArgumentParser(CustomArguments)
args, = parser.parse_args_into_dataclasses(
    args=["--output_dir", "output"]
)
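`HfArgumentParser` turns each dataclass field into a command-line flag. That is why `--output_dir` is passed here as a list of strings, and why any other default can be overridden the same way, for example by adding `"--num_train_epochs", "1"` to the list. The stdlib-only sketch below illustrates the idea with `argparse` on a small hypothetical dataclass. It is a simplified analog, not how `HfArgumentParser` is actually implemented; the real parser also handles booleans, enums, optional types, and default factories.

```python
import argparse
from dataclasses import MISSING, dataclass, fields


@dataclass
class ToyTrainingArgs:
    # Hypothetical subset of fields, loosely modeled on CustomArguments
    output_dir: str
    num_train_epochs: int = 20
    learning_rate: float = 5e-5
    report_to: str = "none"


def parse_into_dataclass(cls, argv):
    """Build one --flag per dataclass field, parse argv, and return an instance."""
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        has_default = f.default is not MISSING
        parser.add_argument(
            f"--{f.name}",
            type=f.type,                # str/int/float convert the raw string
            required=not has_default,   # fields without defaults must be given
            default=f.default if has_default else None,
        )
    namespace = parser.parse_args(argv)
    return cls(**vars(namespace))


# Named toy_args so it does not overwrite the notebook's `args`
toy_args = parse_into_dataclass(
    ToyTrainingArgs, ["--output_dir", "output", "--num_train_epochs", "1"]
)
print(toy_args)
```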

Set up the training dataset:

In [15]:
train_dataset = CustomDataset(args=args)

Check the shape of one sample from the dataset:

In [10]:
print("Input shape: ", train_dataset[0]['input_ids'].shape)

Input shape:  torch.Size([128])


## 4. Run the trainer and monitor the loss

First, set up a callback to log the loss values during training (note this cell is not shown in the video):

In [11]:
from transformers import TrainerCallback

# Define a custom callback that records the log dictionaries emitted during training
class LossLoggingCallback(TrainerCallback):
    def __init__(self):
        self.logs = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Called whenever the Trainer logs; store each log entry for later inspection
        if logs is not None:
            self.logs.append(logs)

# Initialize the callback
loss_logging_callback = LossLoggingCallback()
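After `trainer.train()` finishes, `loss_logging_callback.logs` holds one dictionary per logging event. The sketch below assumes that training-loss entries carry a `loss` key while the end-of-training summary carries `train_loss` instead, so a small filter is enough to pull out a loss curve, and a trailing moving average smooths it. The sample entries are illustrative: the loss values come from the run shown later in this lesson, and the epoch values are hypothetical.

```python
def loss_curve(logs):
    """Return (epoch, loss) pairs from Trainer log dicts that contain 'loss'."""
    return [(entry.get("epoch"), entry["loss"]) for entry in logs if "loss" in entry]


def moving_average(values, window=3):
    """Trailing moving average; the first window-1 points average fewer values."""
    smoothed = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed


# Illustrative log entries (in the notebook, use loss_logging_callback.logs)
example_logs = [
    {"loss": 3.4987, "epoch": 0.51},
    {"loss": 3.3111, "epoch": 1.01},
    {"loss": 2.9909, "epoch": 1.52},
    {"train_loss": 2.2105, "epoch": 20.0},  # summary entry: no 'loss' key
]
curve = loss_curve(example_logs)
losses = [loss for _, loss in curve]
print(curve)
print([round(x, 4) for x in moving_average(losses)])
```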

Then, create an instance of the Hugging Face `Trainer` class from the `transformers` library and call its `train()` method to start the training run:

In [16]:
from transformers import Trainer

trainer = Trainer(
    model=pretrained_model, 
    args=args, 
    train_dataset=train_dataset, 
    eval_dataset=None,
    callbacks=[loss_logging_callback] 
)

trainer.train()

Step,Training Loss
1000,3.4987
2000,3.3111
3000,2.9909
4000,2.9841
5000,2.7372
6000,2.7411
7000,2.544
8000,2.5559
9000,2.3979
10000,2.4021


TrainOutput(global_step=39560, training_loss=2.2105191222095395, metrics={'train_runtime': 4119.4865, 'train_samples_per_second': 76.825, 'train_steps_per_second': 9.603, 'total_flos': 5.231692675547136e+16, 'train_loss': 2.2105191222095395, 'epoch': 20.0})
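The `TrainOutput` numbers can be cross-checked with a little arithmetic. Assuming a single GPU, the effective batch is `per_device_train_batch_size × gradient_accumulation_steps = 8`, so each optimizer step consumes 8 sequences of 128 tokens. The snippet below derives steps per epoch, total tokens seen, and token throughput from the values printed above.

```python
# Values copied from the TrainOutput above
global_step = 39560
epochs = 20
train_runtime_s = 4119.4865

# Values from CustomArguments; single-device training is an ASSUMPTION
per_device_batch = 8
grad_accum = 1
n_devices = 1
seq_len = 128

effective_batch = per_device_batch * grad_accum * n_devices
steps_per_epoch = global_step // epochs
sequences_seen = global_step * effective_batch
tokens_seen = sequences_seen * seq_len

print(f"steps per epoch:   {steps_per_epoch}")
print(f"sequences seen:    {sequences_seen:,}")
print(f"tokens seen:       {tokens_seen:,}")
print(f"tokens per second: {tokens_seen / train_runtime_s:,.0f}")
```

Under the same single-device assumption, 1,978 steps per epoch at 8 sequences per step implies the packaged dataset holds between 15,817 and 15,824 rows, since the last batch of each epoch may be partial.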

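To save intermediate checkpoints every few steps instead of once per epoch, replace the saving configuration in `CustomArguments` with the fields below (they are commented out here so the cell runs on its own):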

In [11]:
# Saving configuration (paste inside the CustomArguments class body)
# save_strategy: str = field(default="steps")          # Can be replaced with "epoch"
# save_steps: int = field(default=3)                   # Frequency of saving training checkpoints, in steps
# save_total_limit: int = field(default=2)             # The total number of checkpoints to be saved
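With `save_strategy="epoch"` and `save_total_limit=2`, as configured for this run, the trainer writes a `checkpoint-<global_step>` folder under `output/` at the end of every epoch and deletes the oldest ones once more than two exist. The sketch below simulates that rotation in plain Python to show which folders should remain; it ignores the special handling of a tracked best checkpoint. The figure of 1,978 steps per epoch comes from 39,560 total steps divided by 20 epochs.

```python
def surviving_checkpoints(steps_per_epoch, num_epochs, save_total_limit):
    """Simulate per-epoch saving with rotation: keep only the newest N folders."""
    kept = []
    for epoch in range(1, num_epochs + 1):
        kept.append(f"checkpoint-{epoch * steps_per_epoch}")
        if len(kept) > save_total_limit:
            kept.pop(0)  # delete the oldest checkpoint
    return kept


print(surviving_checkpoints(steps_per_epoch=1978, num_epochs=20, save_total_limit=2))
```

Raise `save_total_limit` if you want to keep more intermediate checkpoints around for comparison, at the cost of disk space.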

### Checking the performance of an intermediate checkpoint

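Below, you can try generating text using a saved checkpoint of the model. The code loads `checkpoint-39560`, which the trainer saved at the end of the 20th and final epoch; to inspect an earlier, intermediate checkpoint instead, point `model_name_or_path` at another `checkpoint-<step>` folder in `./output`. As you did in previous lessons, you'll use the Solar tokenizer and then set up a `TextStreamer` object to display the text as it is generated: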

In [17]:
from transformers import AutoTokenizer, TextStreamer
model_path = "./models"
tokenizer = AutoTokenizer.from_pretrained(model_path)

In [19]:
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM
import torch

model_name_or_path = "./output/checkpoint-39560"
model2 = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,    
)


In [21]:
prompt = "Avara (in Kshatriya; the great Vamuka) would give food to their beloved,  "

inputs = tokenizer(prompt, return_tensors="pt").to(model2.device)

streamer = TextStreamer(
    tokenizer, 
    skip_prompt=True, 
    skip_special_tokens=True
)

outputs = model2.generate(
    **inputs, 
    streamer=streamer, 
    use_cache=True, 
    max_new_tokens=128,     
    do_sample=True,
    temperature=1.0,
)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


and their utmost in his own mind would not leave them in a moment.
  When any one of Sunder Mâyâ is mentally dead, he will have an opportunity of getting rid of the past and future. Otherwise he does not eat.
5. Kshatriya-vegetation makes a man pure. 6. By this time the soul has attained to a certain goal and can reach it where it is not free. Therefore the Mahâ-animal sacrifice is not a must; it is only a way of attaining Vaishâdha (honour) and Brah
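Because `do_sample=True`, each new token is drawn at random from the model's next-token distribution rather than picked greedily, so rerunning the cell gives a different continuation. `temperature` rescales the logits before the softmax: values below 1.0 sharpen the distribution and values above 1.0 flatten it. The output above also stops mid-word ("Brah") because generation hit the `max_new_tokens=128` limit. The sketch below illustrates temperature sampling on a made-up four-token vocabulary with hypothetical logits; it leaves out other filtering that `generate()` may apply, such as top-k.

```python
import math
import random


def sample_with_temperature(logits, temperature=1.0, rng=random):
    """Sample an index from softmax(logits / temperature); also return the probs."""
    if temperature <= 0:
        raise ValueError("temperature must be positive; use greedy decoding instead")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)                         # subtract max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    total_weight = sum(weights)
    probs = [w / total_weight for w in weights]
    idx = rng.choices(range(len(logits)), weights=probs, k=1)[0]
    return idx, probs


vocab = ["the", "soul", "sacrifice", "honour"]  # hypothetical tokens
logits = [2.0, 1.0, 0.5, -1.0]                  # hypothetical next-token logits

rng = random.Random(0)
for t in (0.5, 1.0, 2.0):
    idx, probs = sample_with_temperature(logits, temperature=t, rng=rng)
    print(f"T={t}: probs={[round(p, 3) for p in probs]} sampled={vocab[idx]!r}")
```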
