# BERT Pre-training: MLM + NSP From Scratch

Notebook này implement BERT pre-training với **2 objectives kết hợp**:
1. **MLM (Masked Language Modeling)**: Dự đoán từ bị mask
2. **NSP (Next Sentence Prediction)**: Dự đoán câu tiếp theo có phải là câu kế tiếp không


## 1. Setup


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import numpy as np
import random
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


Device: cuda


## 2. Tokenizer đơn giản


In [2]:
class SimpleTokenizer:
    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        self.word_to_id = {}
        self.id_to_word = {}
        
        # Special tokens
        self.special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
        self.pad_id, self.unk_id, self.cls_id, self.sep_id, self.mask_id = 0, 1, 2, 3, 4
        
        for idx, token in enumerate(self.special_tokens):
            self.word_to_id[token] = idx
            self.id_to_word[idx] = token
    
    def build_vocab(self, texts):
        word_counter = Counter()
        for text in texts:
            word_counter.update(text.lower().split())
        
        vocab_count = len(self.special_tokens)
        for word, _ in word_counter.most_common():
            if vocab_count >= self.vocab_size:
                break
            if word not in self.word_to_id:
                self.word_to_id[word] = vocab_count
                self.id_to_word[vocab_count] = word
                vocab_count += 1
        print(f"Vocabulary size: {len(self.word_to_id)}")
    
    def encode(self, text, max_length=128):
        words = text.lower().split()
        token_ids = [self.cls_id] + [self.word_to_id.get(w, self.unk_id) for w in words] + [self.sep_id]
        
        if len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
            token_ids[-1] = self.sep_id
        else:
            token_ids.extend([self.pad_id] * (max_length - len(token_ids)))
        return token_ids


## 3. BERT Model Components


In [3]:
# Embedding: Token + Position + Segment
class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_length, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_seq_length, embed_dim)
        self.seg_emb = nn.Embedding(2, embed_dim)  # 0: sentence 1, 1: sentence 2
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, token_ids, segment_ids=None):
        batch_size, seq_len = token_ids.size()
        if segment_ids is None:
            segment_ids = torch.zeros_like(token_ids)
        
        pos = torch.arange(seq_len, device=token_ids.device).unsqueeze(0).expand(batch_size, -1)
        emb = self.token_emb(token_ids) + self.pos_emb(pos) + self.seg_emb(segment_ids)
        return self.dropout(self.norm(emb))

# Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = math.sqrt(self.head_dim)
        
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        batch_size, seq_len, embed_dim = x.size()
        Q = self.q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        context = torch.matmul(attn, V).transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        return self.out(context)

# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embed_dim)
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        x = self.norm1(x + self.dropout1(self.attn(x, mask)))
        x = self.norm2(x + self.dropout2(self.ff(x)))
        return x

# BERT Encoder
class BERT(nn.Module):
    def __init__(self, vocab_size, embed_dim=256, num_layers=4, num_heads=8, 
                 ff_dim=1024, max_seq_length=128, dropout=0.1):
        super().__init__()
        self.embeddings = BERTEmbedding(vocab_size, embed_dim, max_seq_length, dropout)
        self.layers = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ])
    
    def forward(self, token_ids, segment_ids=None, mask=None):
        x = self.embeddings(token_ids, segment_ids)
        for layer in self.layers:
            x = layer(x, mask)
        return x


## 4. BERT Pre-training Model (MLM + NSP)


In [4]:
class BERTForPreTraining(nn.Module):
    """
    BERT model cho pre-training với 2 objectives:
    - MLM: Dự đoán từ bị mask
    - NSP: Dự đoán câu tiếp theo
    """
    def __init__(self, vocab_size, embed_dim=256, num_layers=4, num_heads=8,
                 ff_dim=1024, max_seq_length=128, dropout=0.1):
        super().__init__()
        self.bert = BERT(vocab_size, embed_dim, num_layers, num_heads, ff_dim, max_seq_length, dropout)
        
        # MLM Head: Dự đoán từ bị mask
        self.mlm_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, vocab_size)
        )
        
        # NSP Head: Dự đoán câu tiếp theo (IsNext/NotNext)
        self.nsp_head = nn.Linear(embed_dim, 2)
    
    def forward(self, token_ids, segment_ids=None, mask=None):
        # BERT encoder output
        bert_output = self.bert(token_ids, segment_ids, mask)  # (batch, seq_len, embed_dim)
        
        # MLM: Dự đoán cho tất cả vị trí
        mlm_logits = self.mlm_head(bert_output)  # (batch, seq_len, vocab_size)
        
        # NSP: Dùng [CLS] token (vị trí 0) để dự đoán
        cls_output = bert_output[:, 0, :]  # (batch, embed_dim)
        nsp_logits = self.nsp_head(cls_output)  # (batch, 2)
        
        return mlm_logits, nsp_logits


## 5. Tạo dữ liệu cho MLM và NSP


In [5]:
def create_mlm_data(token_ids, tokenizer, mask_prob=0.15):
    """
    Tạo dữ liệu cho MLM:
    - 15% tokens được chọn để mask
    - Trong số đó: 80% thay bằng [MASK], 10% random token, 10% giữ nguyên
    """
    labels = token_ids.clone()
    vocab_size = len(tokenizer.word_to_id)
    
    for i, token_id in enumerate(token_ids):
        # Bỏ qua special tokens
        if token_id in [tokenizer.pad_id, tokenizer.cls_id, tokenizer.sep_id]:
            labels[i] = -100  # Ignore trong loss
            continue
        
        # 15% chance để mask
        if random.random() < mask_prob:
            rand = random.random()
            if rand < 0.8:  # 80%: thay bằng [MASK]
                token_ids[i] = tokenizer.mask_id
            elif rand < 0.9:  # 10%: thay bằng random token
                token_ids[i] = random.randint(5, vocab_size - 1)
            # 10% còn lại: giữ nguyên
    
    return token_ids, labels

def create_nsp_data(sentences, tokenizer, max_length=128):
    """
    Tạo cặp câu cho NSP:
    - 50% là câu tiếp theo (is_next=1)
    - 50% là câu ngẫu nhiên (is_next=0)
    """
    is_next = random.random() < 0.5
    
    if is_next:
        # Câu tiếp theo: lấy 2 câu liên tiếp
        idx = random.randint(0, len(sentences) - 2)
        sent1, sent2 = sentences[idx], sentences[idx + 1]
    else:
        # Không phải câu tiếp theo: lấy 2 câu ngẫu nhiên
        idx1, idx2 = random.sample(range(len(sentences)), 2)
        sent1, sent2 = sentences[idx1], sentences[idx2]
    
    # Tokenize: [CLS] sent1 [SEP] sent2 [SEP]
    words1 = sent1.lower().split()
    words2 = sent2.lower().split()
    
    token_ids = [tokenizer.cls_id]
    segment_ids = [0]
    
    # Câu 1
    for word in words1:
        token_ids.append(tokenizer.word_to_id.get(word, tokenizer.unk_id))
        segment_ids.append(0)
    token_ids.append(tokenizer.sep_id)
    segment_ids.append(0)
    
    # Câu 2
    for word in words2:
        token_ids.append(tokenizer.word_to_id.get(word, tokenizer.unk_id))
        segment_ids.append(1)
    token_ids.append(tokenizer.sep_id)
    segment_ids.append(1)
    
    # Pad hoặc truncate
    if len(token_ids) > max_length:
        token_ids = token_ids[:max_length]
        segment_ids = segment_ids[:max_length]
        token_ids[-1] = tokenizer.sep_id
    else:
        padding = max_length - len(token_ids)
        token_ids.extend([tokenizer.pad_id] * padding)
        segment_ids.extend([0] * padding)
    
    return token_ids, segment_ids, is_next


## 6. Dataset cho Pre-training


In [6]:
class PreTrainingDataset(Dataset):
    def __init__(self, sentences, tokenizer, max_length=128):
        self.sentences = sentences
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.sentences) - 1
    
    def __getitem__(self, idx):
        # Tạo cặp câu cho NSP
        token_ids, segment_ids, is_next = create_nsp_data(self.sentences, self.tokenizer, self.max_length)
        
        # Tạo mask cho valid tokens
        mask = [1 if tid != self.tokenizer.pad_id else 0 for tid in token_ids]
        
        # Áp dụng MLM masking
        token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
        token_ids_masked, mlm_labels = create_mlm_data(token_ids_tensor.clone(), self.tokenizer)
        
        return {
            'token_ids': token_ids_masked,
            'segment_ids': torch.tensor(segment_ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'mlm_labels': mlm_labels,  # Ground truth cho MLM
            'nsp_label': torch.tensor(1 if is_next else 0, dtype=torch.long)  # Ground truth cho NSP
        }


## 7. Loss Function: Kết hợp MLM + NSP


In [7]:
def compute_pretraining_loss(mlm_logits, nsp_logits, mlm_labels, nsp_labels):
    """
    Tính loss kết hợp cho MLM và NSP
    
    Args:
        mlm_logits: (batch, seq_len, vocab_size) - dự đoán từ bị mask
        nsp_logits: (batch, 2) - dự đoán IsNext/NotNext
        mlm_labels: (batch, seq_len) - ground truth, -100 cho vị trí không tính
        nsp_labels: (batch,) - 0 hoặc 1
    """
    # MLM Loss: chỉ tính cho vị trí bị mask (labels != -100)
    mlm_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
    mlm_loss = mlm_loss_fn(
        mlm_logits.view(-1, mlm_logits.size(-1)),  # (batch*seq_len, vocab_size)
        mlm_labels.view(-1)  # (batch*seq_len,)
    )
    
    # NSP Loss: binary classification
    nsp_loss_fn = nn.CrossEntropyLoss()
    nsp_loss = nsp_loss_fn(nsp_logits, nsp_labels)
    
    # Tổng loss: kết hợp cả 2
    total_loss = mlm_loss + nsp_loss
    
    return total_loss, mlm_loss, nsp_loss


## 8. Training: Pre-training với MLM + NSP


In [8]:
# Load dataset
from datasets import load_dataset
ds = load_dataset("uitnlp/vietnamese_students_feedback")

# Build tokenizer
tokenizer = SimpleTokenizer(vocab_size=5000)
tokenizer.build_vocab(ds["train"]["sentence"][:2000])

# Tạo dataset
sentences = ds["train"]["sentence"][:2000]
train_dataset = PreTrainingDataset(sentences, tokenizer, max_length=128)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

print(f"Training samples: {len(train_dataset)}")

# Tạo model
vocab_size = len(tokenizer.word_to_id)
model = BERTForPreTraining(
    vocab_size=vocab_size,
    embed_dim=256,
    num_layers=4,
    num_heads=8,
    ff_dim=1024,
    max_seq_length=128,
    dropout=0.1
).to(device)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 3

# Training loop
print("\n=== Pre-training với MLM + NSP ===\n")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_mlm_loss = 0
    total_nsp_loss = 0
    nsp_correct = 0
    nsp_total = 0
    
    for batch_idx, batch in enumerate(train_loader):
        token_ids = batch['token_ids'].to(device)
        segment_ids = batch['segment_ids'].to(device)
        mask = batch['mask'].to(device)
        mlm_labels = batch['mlm_labels'].to(device)
        nsp_labels = batch['nsp_label'].to(device)
        
        optimizer.zero_grad()
        
        # Forward: nhận cả MLM và NSP logits
        mlm_logits, nsp_logits = model(token_ids, segment_ids, mask)
        
        # Tính loss kết hợp
        loss, mlm_loss, nsp_loss = compute_pretraining_loss(
            mlm_logits, nsp_logits, mlm_labels, nsp_labels
        )
        
        # Backward
        loss.backward()
        optimizer.step()
        
        # Track metrics
        total_loss += loss.item()
        total_mlm_loss += mlm_loss.item()
        total_nsp_loss += nsp_loss.item()
        
        # NSP accuracy
        nsp_preds = nsp_logits.argmax(1)
        nsp_correct += (nsp_preds == nsp_labels).sum().item()
        nsp_total += nsp_labels.size(0)
        
        if (batch_idx + 1) % 20 == 0:
            print(f"  Batch {batch_idx+1}/{len(train_loader)}: "
                  f"Total={loss.item():.4f}, MLM={mlm_loss.item():.4f}, NSP={nsp_loss.item():.4f}")
    
    # Epoch summary
    avg_loss = total_loss / len(train_loader)
    avg_mlm = total_mlm_loss / len(train_loader)
    avg_nsp = total_nsp_loss / len(train_loader)
    nsp_acc = nsp_correct / nsp_total
    
    print(f"\nEpoch {epoch+1}/{num_epochs}:")
    print(f"  Total Loss: {avg_loss:.4f}")
    print(f"  MLM Loss: {avg_mlm:.4f}")
    print(f"  NSP Loss: {avg_nsp:.4f}")
    print(f"  NSP Accuracy: {nsp_acc:.4f}\n")


Vocabulary size: 1333
Training samples: 1999
Model parameters: 3,943,479

=== Pre-training với MLM + NSP ===

  Batch 20/125: Total=6.9605, MLM=6.2669, NSP=0.6936
  Batch 40/125: Total=6.1158, MLM=5.4096, NSP=0.7063
  Batch 60/125: Total=5.4566, MLM=4.7008, NSP=0.7558
  Batch 80/125: Total=4.9400, MLM=4.2496, NSP=0.6904
  Batch 100/125: Total=4.7449, MLM=4.1102, NSP=0.6347
  Batch 120/125: Total=4.3986, MLM=3.7173, NSP=0.6813

Epoch 1/3:
  Total Loss: 5.7477
  MLM Loss: 5.0302
  NSP Loss: 0.7175
  NSP Accuracy: 0.4987

  Batch 20/125: Total=4.4226, MLM=3.7060, NSP=0.7166
  Batch 40/125: Total=3.8275, MLM=3.0327, NSP=0.7949
  Batch 60/125: Total=3.2926, MLM=2.4737, NSP=0.8189
  Batch 80/125: Total=3.6135, MLM=2.9332, NSP=0.6803
  Batch 100/125: Total=3.8721, MLM=3.1804, NSP=0.6917
  Batch 120/125: Total=3.5815, MLM=2.8687, NSP=0.7128

Epoch 2/3:
  Total Loss: 3.7457
  MLM Loss: 3.0369
  NSP Loss: 0.7087
  NSP Accuracy: 0.4997

  Batch 20/125: Total=3.1548, MLM=2.4283, NSP=0.7265
  Batch

## Tóm tắt

**BERT Pre-training sử dụng 2 objectives kết hợp:**

1. **MLM (Masked Language Modeling)**:
   - Mask 15% tokens
   - Dự đoán từ bị mask
   - Loss: CrossEntropy (chỉ tính cho vị trí bị mask)

2. **NSP (Next Sentence Prediction)**:
   - Tạo cặp câu: 50% là câu tiếp theo, 50% là câu ngẫu nhiên
   - Dự đoán IsNext/NotNext từ [CLS] token
   - Loss: CrossEntropy (binary classification)

**Total Loss = MLM Loss + NSP Loss**

Cả 2 objectives được train **cùng lúc** trong một forward pass, giúp BERT học được:
- Hiểu ngữ nghĩa từ (từ MLM)
- Hiểu quan hệ giữa các câu (từ NSP)
