# Training portion for [Simplifying Transformer Blocks -- He et al.](https://arxiv.org/abs/2311.01906)

In [1]:
# Imports
import import_ipynb
import simplified_transformer_block # the other notebook, rip devops
import torch
import torch.nn.functional as F

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda
Initial loss: 7.078291416168213
Step 2, Loss: 7.033586502075195
Step 4, Loss: 6.944496154785156
Step 6, Loss: 6.855834007263184
Step 8, Loss: 6.767575740814209
Step 10, Loss: 6.679686069488525
Using device: cuda


## Training setup

### Training settings

In [None]:
# Model hyper-parameters (all forwarded to the SimpleLanguageModel constructor)
ctx_len = 128 # Context length in tokens; the A100 training run uses 512 if there is enough VRAM
d_model = 768 # Model width
num_layers = 18 # Number of transformer blocks
num_heads = 12 # Attention heads per block
dropout = 0.1 # Dropout probability
mlp_expansion_factor = 4 # The model's mlp_dim is d_model * mlp_expansion_factor
use_norm = True # Forwarded to the model as use_norm

# Training run settings
batch_size = 16 # Number of documents the DataLoader hands to collate_fn per batch
stride = 32 # Number of tokens shared by consecutive chunks (the overlap, not the step)
train_epoches = 1 # Number of passes over the dataloader
learning_rate = 1e-3 # Optimizer learning rate
checkpoint_path = "checkpoints/pre-train" # NOTE(review): not referenced by any cell visible here

### Tokenization, Dataset and Dataloader

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

from transformers import AutoTokenizer
# Load the pretrained GPT-2 tokenizer; it is used here to tokenize source code
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the end-of-sequence token as the padding token so padded batches can be built
tokenizer.pad_token = tokenizer.eos_token
vocab_size = tokenizer.vocab_size


# Load the codeparrot-clean-train dataset and keep its "train" split
dataset = load_dataset("codeparrot/codeparrot-clean-train")
train_dataset = dataset["train"]

# Display the available columns; "content" is the one that gets tokenized below
train_dataset.column_names

Repo card metadata block was not found. Setting CardData to empty.


Resolving data files:   0%|          | 0/53 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/107 [00:00<?, ?it/s]

['repo_name',
 'path',
 'copies',
 'size',
 'content',
 'license',
 'hash',
 'line_mean',
 'line_max',
 'alpha_frac',
 'autogenerated']

In [4]:
# Inspect the column schema (names and dtypes) of the training split
print(train_dataset.features)

{'repo_name': Value(dtype='string', id=None), 'path': Value(dtype='string', id=None), 'copies': Value(dtype='string', id=None), 'size': Value(dtype='string', id=None), 'content': Value(dtype='string', id=None), 'license': Value(dtype='string', id=None), 'hash': Value(dtype='int64', id=None), 'line_mean': Value(dtype='float64', id=None), 'line_max': Value(dtype='int64', id=None), 'alpha_frac': Value(dtype='float64', id=None), 'autogenerated': Value(dtype='bool', id=None)}


In [None]:
def collate_fn(batch):
    """Tokenize a batch of documents and cut them into overlapping windows.

    Each document's ``"content"`` string is tokenized without truncation and
    then split into windows of at most ``ctx_len`` tokens. Consecutive windows
    of the same document share ``stride`` tokens (the window start advances by
    ``ctx_len - stride``). Windowing stops as soon as a window reaches the end
    of the document, so no window is entirely contained in the previous one.

    Reads the module-level ``tokenizer``, ``ctx_len`` and ``stride``.

    Args:
        batch: list of dataset rows, each a mapping with a ``"content"`` string.

    Returns:
        dict with two ``torch.long`` tensors of identical shape
        ``(num_chunks, longest_chunk_len)``:
          * ``"input_ids"``: token ids, right-padded with ``tokenizer.pad_token_id``.
          * ``"attention_mask"``: 1 for real tokens, 0 for padding.
        ``num_chunks`` depends on document lengths, so it varies per batch and
        can be much larger than the number of documents in ``batch``.

    Raises:
        ValueError: if ``stride`` is not smaller than ``ctx_len``.
    """
    texts = [ex["content"] for ex in batch]

    # Tokenize WITHOUT truncation and WITHOUT padding: padding every document
    # to the longest one before chunking would create chunks made only of pad
    # tokens. Padding is applied once, after chunking.
    tokenized = tokenizer(
        texts,
        padding=False,
        truncation=False,  # keep full documents; they are windowed below
    )

    chunk_size = ctx_len
    step = ctx_len - stride
    if step <= 0:
        raise ValueError("stride must be smaller than ctx_len")

    input_ids_chunks = []
    attention_mask_chunks = []

    for ids in tokenized["input_ids"]:
        seq = torch.as_tensor(ids, dtype=torch.long)
        for start in range(0, len(seq), step):
            chunk = seq[start:start + chunk_size]
            input_ids_chunks.append(chunk)
            # Build the mask from the very same window so ids and mask always
            # have the same number of rows and the same per-row length.
            attention_mask_chunks.append(torch.ones_like(chunk))
            if start + chunk_size >= len(seq):
                # This window already reaches the end of the document; any
                # further window would only repeat tokens seen here.
                break

    if not input_ids_chunks:
        # Empty batch or only empty documents: pad_sequence rejects an empty list.
        empty = torch.zeros((0, 0), dtype=torch.long)
        return {"input_ids": empty, "attention_mask": empty.clone()}

    # Pad the chunks to a common length. input_ids use the tokenizer's pad id
    # (0 is a regular vocabulary entry); the mask is padded with 0.
    return {
        "input_ids": torch.nn.utils.rnn.pad_sequence(
            input_ids_chunks,
            batch_first=True,
            padding_value=tokenizer.pad_token_id
        ),
        "attention_mask": torch.nn.utils.rnn.pad_sequence(
            attention_mask_chunks,
            batch_first=True,
            padding_value=0
        )
    }

In [None]:
# Tokenization function
def tokenize_function(examples):
    """Tokenize the ``"content"`` field to a fixed length of ``ctx_len`` tokens.

    Longer inputs are truncated and shorter ones padded up to ``ctx_len``;
    the tokenizer is asked to return PyTorch tensors.

    NOTE(review): the DataLoader in this notebook batches through
    ``collate_fn`` instead; nothing visible here calls this helper.
    """
    tokenizer_kwargs = dict(
        truncation=True,
        max_length=ctx_len,
        padding="max_length",
        return_tensors="pt",
    )
    contents = examples["content"]
    return tokenizer(contents, **tokenizer_kwargs)

# Batch raw dataset rows and let collate_fn tokenize and chunk them on the fly.
# batch_size counts documents: collate_fn may turn each document into several
# chunks, so the tensors it returns can have more than batch_size rows.
dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)

Map:   0%|          | 0/5300000 [00:00<?, ? examples/s]

KeyboardInterrupt: 

### Model

In [None]:
# The model class lives in the companion notebook, which is imported as the
# module `simplified_transformer_block`; a plain `import` does not bring
# `SimpleLanguageModel` into this namespace, so reference it through the module.
model = simplified_transformer_block.SimpleLanguageModel(
    vocab_size=tokenizer.vocab_size,
    d_model=d_model,
    num_layers=num_layers,
    num_heads=num_heads,
    mlp_dim = d_model * mlp_expansion_factor,
    context_length=ctx_len,
    dropout=dropout,
    use_norm=use_norm
).to(device)

### Optimizer and training scheduler

In [None]:
# Loss scaler for fp16 mixed precision. It is only meaningful on CUDA, so it
# is disabled (turning scale/step/update into pass-throughs) on the CPU fallback.
scaler = torch.amp.GradScaler(enabled=(device.type == "cuda"))

# Use the learning rate from the training-settings cell instead of a duplicated literal
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Learning rate schedule with linear warmup, then linear decay.
# The scheduler is stepped once per batch, so steps are counted in batches.
num_training_steps = len(dataloader) * train_epoches
num_warmup_steps = len(dataloader) // 20  # warm up over 5% of the first epoch
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

### Training loop

In [None]:
# Mixed precision (fp16 autocast) is only used when running on CUDA.
use_amp = device.type == "cuda"
model.train()

for epoch in range(train_epoches):
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)

        # Fall back to an all-ones mask when the batch carries none, or one
        # whose shape does not line up with input_ids.
        attention_mask = batch.get("attention_mask")
        if attention_mask is None or attention_mask.shape != input_ids.shape:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = attention_mask.to(device)

        if input_ids.size(0) == 0 or input_ids.size(1) < 2:
            # Next-token prediction needs at least two tokens per sequence.
            continue

        # Next-token objective: the output at position t is scored against the
        # token at position t + 1. Scoring it against the token at position t
        # would only teach the model to copy its input.
        labels = input_ids[:, 1:].clone()
        # Exclude padding positions from the loss (-100 is cross_entropy's ignore_index).
        labels[attention_mask[:, 1:] == 0] = -100

        optimizer.zero_grad()

        # Forward pass with autocast
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            outputs = model(input_ids)
            # assumes outputs is (batch, seq_len, vocab) logits — TODO confirm against the model
            logits = outputs[:, :-1, :]
            # reshape (not view): the sliced tensors are no longer contiguous
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100
            )

        # Backward pass with scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        scheduler.step()