# Training a Transformer Encoder with an MLM Objective on BERT-Style Tokens (WikiText Dataset)

### 1. Setup: retrieve the data (see `mask_dataset_for_mlm` for the important MLM preprocessing steps)

In [1]:
from modules.encoder import EncoderModel
from preprocess.mlm_preprocess import get_dataset_example, mask_dataset_for_mlm

# Load the example dataset together with its tokenizer. The cell output reports
# that a cached wikitext-2-raw-v1 dataset was found and reused.
input_ids, tokenizer = get_dataset_example()
# Derive the MLM inputs and labels from the token ids.
# NOTE(review): the masking rules live in preprocess.mlm_preprocess and are not
# visible here. The training loop later in this notebook ignores label
# positions equal to -100 — confirm mask_dataset_for_mlm marks unmasked
# positions that way.
mlm_input_ids, mlm_labels = mask_dataset_for_mlm(input_ids)

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset wikitext (/Users/rishubtamirisa/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
100%|██████████| 3/3 [00:00<00:00, 488.83it/s]
Loading cached processed dataset at /Users/rishubtamirisa/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-52483841b366ee3a.arrow
  0%|          | 0/37 [00:00<?, ?ba/s]Token indices sequence length is longer than the specified maximum sequence length for this model (647 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 37/37 [00:01<00:00, 20.12ba/s]
Loading cached processed dataset at /Users/rishubtamirisa/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-4d856a7a0dbc146c.arrow
Loading cached processed data

### 2. Define the model

In [2]:
import torch
from modules.encoder import EncoderModel

# Encoder hyperparameters. The vocabulary size comes from the tokenizer so the
# model agrees with the token ids it is fed.
vocab_size = tokenizer.vocab_size
embed_dim = model_dim = 512
n_layers, num_heads = 6, 8

# Build the encoder that the training cell optimizes.
encoder = EncoderModel(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    model_dim=model_dim,
    n_layers=n_layers,
    num_heads=num_heads,
)


### 3. Prepare DataLoaders

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Pair each masked input with its labels so both are batched together.
dataset = TensorDataset(mlm_input_ids, mlm_labels)
# Batches of 32 pairs, drawn in a new random order on every pass over the loader.
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)

### 4. Train

In [None]:
from tqdm import tqdm
def train_mlm(epochs, model, tokenizer, loader, optimizer=torch.optim.Adam, device=torch.device('cpu')):
    """Train ``model`` in place with a masked-language-modelling cross-entropy loss.

    Args:
        epochs: Number of full passes over ``loader``.
        model: Module called as ``model(input_ids)``. Its output is flattened to
            ``(-1, tokenizer.vocab_size)`` before the loss is computed.
        tokenizer: Only ``tokenizer.vocab_size`` is read.
        loader: Sized iterable yielding ``(input_ids, labels)`` batches. Label
            positions equal to -100 are excluded from the loss.
        optimizer: Either a ready optimizer instance, or an optimizer class
            (such as the default ``torch.optim.Adam``), which is instantiated
            over ``model.parameters()`` with that class's default settings.
        device: Device that the model and every batch are moved to.

    Returns:
        None.
    """
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
    model.train()
    model.to(device)

    # The default argument is an optimizer *class*; calling zero_grad()/step()
    # on the class itself raises TypeError, so build an instance first.
    if isinstance(optimizer, type):
        optimizer = optimizer(model.parameters())

    total_batches = len(loader)  # constant across epochs, so computed once
    with tqdm(total=epochs) as pbar:
        for _ in range(epochs):
            for cur_batch, (input_ids, labels) in enumerate(loader, start=1):
                input_ids = input_ids.to(device, dtype=torch.int64)
                labels = labels.to(device, dtype=torch.int64)

                optimizer.zero_grad()
                output = model(input_ids)
                # reshape (not view) so a non-contiguous model output still works.
                loss = criterion(output.reshape(-1, tokenizer.vocab_size), labels.reshape(-1))
                loss.backward()
                optimizer.step()

                pbar.set_postfix(**{"batch: ": f"{cur_batch} / {total_batches}", "loss:": loss.item()})
            # The bar counts epochs (total=epochs); without this it never moves.
            pbar.update(1)


# Four epochs of Adam (lr=1e-4) over the encoder's parameters; `device` is left
# at train_mlm's default.
train_mlm(
    epochs=4,
    model=encoder,
    tokenizer=tokenizer,
    loader=loader,
    optimizer=torch.optim.Adam(encoder.parameters(), lr=1e-4),
)