# *Lab: LLM Pretraining

Here we directly leverage the decoder architecture we made from previous sections. 
  

In [19]:
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import DataLoader
from torch import nn
import numpy as np
from omegaconf import OmegaConf
from llm_lab.model.rotary_decoder import RotaryDecoderModel
from llm_lab.utils.collate_utils import default_data_collator
from llm_lab.utils.common_utils import move_to_device
from transformers import AutoTokenizer
from datasets import load_dataset
from itertools import chain
from functools import partial

%load_ext autoreload
%autoreload 2 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data

In [20]:
dataset_name =  "wikitext"
data_config = "wikitext-2-raw-v1"
text_column_name = "text"

# model parameters
model_name_or_path="openai-community/gpt2"

In [21]:
raw_datasets = load_dataset(dataset_name, data_config)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

In [22]:
def tokenize(examples):
    return tokenizer(examples[text_column_name])

In [23]:
def group_and_chunk(tokenized_examples, chunk_size=1024, chunk_key='input_ids'):
    keys = list(tokenized_examples.keys())
    # use chain to flatten list
    concat_examples = {k: list(chain(*tokenized_examples[k])) for k in keys}
    total_length = len(concat_examples[chunk_key])
    total_length = (total_length // chunk_size) * chunk_size
    
    result_dict = {
        k: [v[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, v in concat_examples.items()
    }

    return result_dict

In [24]:
tokenized_dataset = raw_datasets.map(
                    tokenize, 
                    batched=True)

In [25]:
chunk_data = tokenized_dataset.map(
                                    partial(group_and_chunk, 
                                            chunk_size=256),
                                        #chunk_size=tokenizer.model_max_length),
                                    batched=True,
                                    remove_columns=['text'])

In [26]:
chunk_data

DatasetDict({
    test: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 1104
    })
    train: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 9327
    })
    validation: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 964
    })
})

## Model

In [27]:
class DecoderCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.decoder = RotaryDecoderModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
    def forward(self, batch):
        hidden_states = self.decoder(input_ids=batch['input_ids'])
        logits = self.lm_head(hidden_states)
        return logits

## Training
### Loss computation routine

In [28]:
def compute_batch_loss(batch, model, device, mixed_precision_assert=False):
    assert model.training
    move_to_device(batch, device)
    model_input = {'input_ids':batch['input_ids'],'attention_mask': batch['attention_mask']}
    logits = model(model_input)[:,:-1,:].contiguous()
    labels = batch['input_ids'][:,1:].contiguous()
    flat_labels = labels.view(-1)
    flat_logits = logits.view(-1, logits.shape[-1])
    loss = F.cross_entropy(flat_logits, flat_labels)
    
    if mixed_precision_assert == True:
        # flat_logits is float16 because linear layers ``autocast`` to float16.
        # loss is float32 because ``cross entropy`` layers ``autocast`` to float32.
        assert flat_logits.dtype == torch.float16
        assert loss.dtype == torch.float32
    
    return loss

def compute_eval_loss(eval_dataloader, model, device):
    assert not model.training
    all_losses = []
    with torch.no_grad():
        for batch in eval_dataloader:
            move_to_device(batch, device)
            model_input = {'input_ids':batch['input_ids'],'attention_mask': batch['attention_mask']}
            logits = model(model_input)[:,:-1,:].contiguous()
            labels = batch['input_ids'][:,1:].contiguous()
            flat_labels = labels.view(-1)
            flat_logits = logits.view(-1, logits.shape[-1])
            losses = F.cross_entropy(flat_logits, flat_labels, reduction='none').tolist()
            all_losses.extend(losses)
    
    mean_loss = np.mean(all_losses)
    return mean_loss



### Timing utility

In [29]:
import time, gc
# Timing utilities
# https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
start_time = None

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.synchronize()
    start_time = time.time()

def end_timer_and_print(local_msg):
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))

### Full precision

In [30]:
def train_model_epoch(model, 
                train_loader, 
                val_loader, 
                optimizer,
                device,
                train_config):
    
    global_steps = 0
    record_list = []
    model = model.to(device)
    start_timer()
    for epoch in range(train_config.num_epochs):
        
        
        for batch in train_loader:
            model.train()
            loss = compute_batch_loss(batch, model, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_steps += 1
            if global_steps % train_config.log_freq == 0:
                model.eval()
                val_loss = compute_eval_loss(val_loader, model, device)
                record = {"epoch": epoch,
                          "step": global_steps,
                          "train_loss": loss.detach().item(),
                          "val_loss": val_loss}
                print(record)
                record_list.append(record)
    end_timer_and_print("Default precision:")
    return record_list

### Mixed precision

Instances of torch.autocast serve as context managers that allow regions of your script to run in mixed precision.

In these regions, CUDA ops run in a dtype chosen by autocast to improve performance while maintaining accuracy. See the Autocast Op Reference for details on what precision autocast chooses for each op, and under what circumstances.

In [31]:


def train_model_epoch_mixed_precision(model, 
                train_loader, 
                val_loader, 
                optimizer,
                device,
                train_config):
    
    global_steps = 0
    record_list = []
    model = model.to(device)
    # Gradient scaling helps prevent gradients with small magnitudes from flushing to zero
    # (“underflowing”) when training with mixed precision.
    scaler = torch.amp.GradScaler('cuda')
    start_timer()
    for epoch in range(train_config.num_epochs):
        
        for batch in train_loader:
            model.train()

            with torch.autocast(device_type='cuda', dtype=torch.float16):
                
                loss = compute_batch_loss(batch, model, device, mixed_precision_assert=True)
            
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            global_steps += 1
            if global_steps % train_config.log_freq == 0:
                model.eval()
                val_loss = compute_eval_loss(val_loader, model, device)
                record = {"epoch": epoch,
                          "step": global_steps,
                          "train_loss": loss.detach().item(),
                          "val_loss": val_loss}
                print(record)
                record_list.append(record)
    end_timer_and_print("Default precision:")
    return record_list
                

In [32]:
def train_main(model, train_settings, chunk_data, mixed_precision_training=False):
    
    torch.manual_seed(train_settings.seed)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
            

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=train_settings.learning_rate,
                                  weight_decay=train_settings.weight_decay)
    
    
    train_loader = DataLoader(chunk_data['train'],
                                      batch_size=train_settings.batch_size,
                                      shuffle=True,
                                        num_workers=0,
                                        collate_fn=default_data_collator
    )
    
    val_loader =  DataLoader(chunk_data['validation'],
                                      batch_size=train_settings.batch_size,
                                      shuffle=False,
                                        num_workers=0,
                                        collate_fn=default_data_collator
    )
    
    
    if mixed_precision_training:
      train_model_epoch_mixed_precision(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                optimizer=optimizer,
                train_config=train_settings,
                device=device)
    else:
      train_model_epoch(model=model,
                  train_loader=train_loader,
                  val_loader=val_loader,
                  optimizer=optimizer,
                  train_config=train_settings,
                  device=device)
    

    
    
    

## Training Entry

In [33]:
model_config = {
    "vocab_size": 50257,    # Vocabulary size
    "max_position_embeddings": 1024,
    "hidden_size": 128,         # model dimension
    "intermediate_size": 128*4,
    "num_key_value_heads": 2,
    "num_heads": 4,          # Number of attention heads
    "num_layers": 3,         # Number of layers
    "attention_dropout": 0.1,       # Dropout rate
    "qkv_bias": False,       # Query-key-value bias
    "o_bias": True,
    "mlp_bias": True,
    "rms_norm_eps": 1e-6,
    "dropout": 0.1,
    "pad_token_id": tokenizer.eos_token_id,
    "causal_attention": True
}

model_config = OmegaConf.create(model_config)
train_settings = {
    "learning_rate": 5e-4,
    "num_epochs": 1,
    "batch_size": 32,
    "weight_decay": 0.1,
    "seed": 1,
    "log_freq": 50
}

train_settings = OmegaConf.create(train_settings)


model = DecoderCausalLM(config=model_config)
# train model
train_main(model, train_settings=train_settings, chunk_data=chunk_data, mixed_precision_training=True)
    


{'epoch': 0, 'step': 50, 'train_loss': 7.689789772033691, 'val_loss': 7.668777453052184}
{'epoch': 0, 'step': 100, 'train_loss': 7.0601887702941895, 'val_loss': 7.037264205774467}
{'epoch': 0, 'step': 150, 'train_loss': 6.174334526062012, 'val_loss': 6.146691244994047}
{'epoch': 0, 'step': 200, 'train_loss': 5.481212139129639, 'val_loss': 5.386733918386623}
{'epoch': 0, 'step': 250, 'train_loss': 4.769058704376221, 'val_loss': 4.7812650817852775}

Default precision:
Total execution time = 24.310 sec
Max memory used by tensors = 5175275520 bytes


In [34]:
# train model
train_main(model, train_settings=train_settings, chunk_data=chunk_data, mixed_precision_training=False)

# On my 3090
# /home/yangyutu/miniconda3/envs/huggingface_lastest/lib/python3.9/site-packages/torch/cuda/memory.py:365: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
#   warnings.warn(
# {'epoch': 0, 'step': 50, 'train_loss': 3.1814770698547363, 'val_loss': 3.159256796090871}
# {'epoch': 0, 'step': 100, 'train_loss': 2.7032246589660645, 'val_loss': 2.6685547356172212}
# {'epoch': 0, 'step': 150, 'train_loss': 2.2208354473114014, 'val_loss': 2.2558248526470144}
# {'epoch': 0, 'step': 200, 'train_loss': 1.9303675889968872, 'val_loss': 1.9033248468423163}
# {'epoch': 0, 'step': 250, 'train_loss': 1.5353786945343018, 'val_loss': 1.6159611645910337}

# Default precision:
# Total execution time = 25.575 sec
# Max memory used by tensors = 6044132352 bytes

# On 4090

# {'epoch': 0, 'step': 50, 'train_loss': 3.354149103164673, 'val_loss': 3.3306671452994774}
# {'epoch': 0, 'step': 100, 'train_loss': 2.8767457008361816, 'val_loss': 2.844842069686138}
# {'epoch': 0, 'step': 150, 'train_loss': 2.39607834815979, 'val_loss': 2.4303341806986802}
# {'epoch': 0, 'step': 200, 'train_loss': 2.114605188369751, 'val_loss': 2.0717746545050604}
# {'epoch': 0, 'step': 250, 'train_loss': 1.7027183771133423, 'val_loss': 1.7752767190680252}

# Default precision:
# Total execution time = 18.650 sec
# Max memory used by tensors = 5682771968 bytes

{'epoch': 0, 'step': 50, 'train_loss': 3.673240900039673, 'val_loss': 3.660376475333091}
{'epoch': 0, 'step': 100, 'train_loss': 3.1750710010528564, 'val_loss': 3.129806725645031}
{'epoch': 0, 'step': 150, 'train_loss': 2.6211187839508057, 'val_loss': 2.6639581041889313}
{'epoch': 0, 'step': 200, 'train_loss': 2.2989678382873535, 'val_loss': 2.257827716357259}
{'epoch': 0, 'step': 250, 'train_loss': 1.8609188795089722, 'val_loss': 1.9229015331654424}

Default precision:
Total execution time = 32.994 sec
Max memory used by tensors = 5683595264 bytes
