# *Lab: LLM Pretraining

Here we directly leverage the decoder architecture we made from previous sections. 
  

In [13]:
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import DataLoader
from torch import nn
import numpy as np
from omegaconf import OmegaConf
from llm_lab.model.rotary_decoder import RotaryDecoderModel
from llm_lab.utils.collate_utils import default_data_collator
from llm_lab.utils.common_utils import move_to_device
from transformers import AutoTokenizer
from datasets import load_dataset
from itertools import chain
from functools import partial

%load_ext autoreload
%autoreload 2 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Data

In [14]:
dataset_name =  "wikitext"
data_config = "wikitext-2-raw-v1"
text_column_name = "text"

# model parameters
model_name_or_path="openai-community/gpt2"

In [15]:
raw_datasets = load_dataset(dataset_name, data_config)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

In [16]:
def tokenize(examples):
    return tokenizer(examples[text_column_name])

In [17]:
def group_and_chunk(tokenized_examples, chunk_size=1024, chunk_key='input_ids'):
    keys = list(tokenized_examples.keys())
    # use chain to flatten list
    concat_examples = {k: list(chain(*tokenized_examples[k])) for k in keys}
    total_length = len(concat_examples[chunk_key])
    total_length = (total_length // chunk_size) * chunk_size
    
    result_dict = {
        k: [v[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, v in concat_examples.items()
    }

    return result_dict

In [18]:
tokenized_dataset = raw_datasets.map(
                    tokenize, 
                    batched=True)

In [19]:
chunk_data = tokenized_dataset.map(
                                    partial(group_and_chunk, 
                                            chunk_size=256),
                                        #chunk_size=tokenizer.model_max_length),
                                    batched=True,
                                    remove_columns=['text'])

In [20]:
chunk_data

DatasetDict({
    test: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 1104
    })
    train: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 9327
    })
    validation: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 964
    })
})

## Model

In [21]:
class DecoderCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.decoder = RotaryDecoderModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
    def forward(self, batch):
        hidden_states = self.decoder(input_ids=batch['input_ids'])
        logits = self.lm_head(hidden_states)
        return logits

## Training

In [22]:
def compute_batch_loss(batch, model, device):
    assert model.training
    move_to_device(batch, device)
    model_input = {'input_ids':batch['input_ids'],'attention_mask': batch['attention_mask']}
    logits = model(model_input)[:,:-1,:].contiguous()
    labels = batch['input_ids'][:,1:].contiguous()
    flat_labels = labels.view(-1)
    flat_logits = logits.view(-1, logits.shape[-1])
    loss = F.cross_entropy(flat_logits, flat_labels)
    return loss

def compute_eval_loss(eval_dataloader, model, device):
    assert not model.training
    all_losses = []
    with torch.no_grad():
        for batch in eval_dataloader:
            move_to_device(batch, device)
            model_input = {'input_ids':batch['input_ids'],'attention_mask': batch['attention_mask']}
            logits = model(model_input)[:,:-1,:].contiguous()
            labels = batch['input_ids'][:,1:].contiguous()
            flat_labels = labels.view(-1)
            flat_logits = logits.view(-1, logits.shape[-1])
            losses = F.cross_entropy(flat_logits, flat_labels, reduction='none').tolist()
            all_losses.extend(losses)
    
    mean_loss = np.mean(all_losses)
    return mean_loss

def train_model_epoch(model, 
                train_loader, 
                val_loader, 
                optimizer,
                device,
                train_config):
    
    global_steps = 0
    record_list = []
    model = model.to(device)
    for epoch in range(train_config.num_epochs):
        
        
        for batch in train_loader:
            model.train()
            loss = compute_batch_loss(batch, model, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_steps += 1
            if global_steps % train_config.log_freq == 0:
                model.eval()
                val_loss = compute_eval_loss(val_loader, model, device)
                record = {"epoch": epoch,
                          "step": global_steps,
                          "train_loss": loss.detach().item(),
                          "val_loss": val_loss}
                print(record)
                record_list.append(record)
        
    return record_list

In [23]:
def train_main(model_config, train_settings, chunk_data):
    
    torch.manual_seed(train_settings.seed)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
            
    model = DecoderCausalLM(config=model_config)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=train_settings.learning_rate,
                                  weight_decay=train_settings.weight_decay)
    
    
    train_loader = DataLoader(chunk_data['train'],
                                      batch_size=train_settings.batch_size,
                                      shuffle=True,
                                        num_workers=0,
                                        collate_fn=default_data_collator
    )
    
    val_loader =  DataLoader(chunk_data['validation'],
                                      batch_size=train_settings.batch_size,
                                      shuffle=False,
                                        num_workers=0,
                                        collate_fn=default_data_collator
    )
    
    train_model_epoch(model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                optimizer=optimizer,
                train_config=train_settings,
                device=device)
    
    
    

## Training Entry

In [24]:
model_config = {
    "vocab_size": 50257,    # Vocabulary size
    "max_position_embeddings": 1024,
    "hidden_size": 768,         # model dimension
    "intermediate_size": 768*4,
    "num_key_value_heads": 2,
    "num_heads": 4,          # Number of attention heads
    "num_layers": 6,         # Number of layers
    "attention_dropout": 0.1,       # Dropout rate
    "qkv_bias": False,       # Query-key-value bias
    "o_bias": True,
    "mlp_bias": True,
    "rms_norm_eps": 1e-6,
    "dropout": 0.1,
    "pad_token_id": tokenizer.eos_token_id,
    "causal_attention": True
}

model_config = OmegaConf.create(model_config)
train_settings = {
    "learning_rate": 5e-4,
    "num_epochs": 1,
    "batch_size": 2,
    "weight_decay": 0.1,
    "seed": 1,
    "log_freq": 50
}

train_settings = OmegaConf.create(train_settings)

# train model
train_main(model_config=model_config, train_settings=train_settings, chunk_data=chunk_data)
    

# save model
#torch.save(model.state_dict(), "model.pth")

# training process
# {'epoch': 0, 'step': 950, 'train_loss': 0.3691011965274811, 'val_loss': 0.4959029606489849}
# {'epoch': 0, 'step': 1000, 'train_loss': 0.4940517842769623, 'val_loss': 0.4663128215367203}
# {'epoch': 0, 'step': 1050, 'train_loss': 0.8797768950462341, 'val_loss': 0.4402653791785611}
# {'epoch': 0, 'step': 1100, 'train_loss': 0.34599336981773376, 'val_loss': 0.41212919295760314}
# {'epoch': 0, 'step': 1150, 'train_loss': 0.3531911373138428, 'val_loss': 0.4092062050130844}
# {'epoch': 0, 'step': 1200, 'train_loss': 0.4641529619693756, 'val_loss': 0.38234950190919664}
# {'epoch': 0, 'step': 1250, 'train_loss': 0.22967249155044556, 'val_loss': 0.3607293127420803}
# {'epoch': 0, 'step': 1300, 'train_loss': 0.3634558618068695, 'val_loss': 0.3436481123064947}
# {'epoch': 0, 'step': 1350, 'train_loss': 0.35325485467910767, 'val_loss': 0.3274566013152589}
# {'epoch': 0, 'step': 1400, 'train_loss': 0.09018289297819138, 'val_loss': 0.3139857236895701}


{'epoch': 0, 'step': 50, 'train_loss': 5.934116363525391, 'val_loss': 5.95978886500486}
{'epoch': 0, 'step': 100, 'train_loss': 4.6610188484191895, 'val_loss': 4.4738071170653555}
{'epoch': 0, 'step': 150, 'train_loss': 2.971130132675171, 'val_loss': 3.394472318179944}
{'epoch': 0, 'step': 200, 'train_loss': 2.0693352222442627, 'val_loss': 2.6074926472101367}
{'epoch': 0, 'step': 250, 'train_loss': 1.3895280361175537, 'val_loss': 2.1285733113386835}
{'epoch': 0, 'step': 300, 'train_loss': 1.6931740045547485, 'val_loss': 1.758039969718746}
{'epoch': 0, 'step': 350, 'train_loss': 1.4765475988388062, 'val_loss': 1.500211113480706}
{'epoch': 0, 'step': 400, 'train_loss': 1.3626762628555298, 'val_loss': 1.3146632623914765}
{'epoch': 0, 'step': 450, 'train_loss': 1.3208198547363281, 'val_loss': 1.1729037596728389}
{'epoch': 0, 'step': 500, 'train_loss': 1.2643789052963257, 'val_loss': 1.0483412505480632}
{'epoch': 0, 'step': 550, 'train_loss': 1.347705364227295, 'val_loss': 0.942060433071201