# Training a Transformer Encoder with MLM Objective on BERT-style tokens / wikitext dataset

### 1. Setup (retrieve the data; see `mask_dataset_for_mlm` for the important MLM preprocessing steps)

In [None]:
from modules.encoder import EncoderModel
from preprocess.mlm_preprocess import get_dataset_example, mask_dataset_for_mlm

# Fetch the example data and its tokenizer (inferred from the names bound here;
# see preprocess.mlm_preprocess for the exact return types).
input_ids, tokenizer = get_dataset_example()
# Derive the MLM training pair from the raw ids: masked inputs plus labels.
# NOTE(review): train_mlm below builds its loss with ignore_index=-100, so the
# labels presumably hold -100 at positions that should not contribute to the
# loss — confirm in mask_dataset_for_mlm.
mlm_input_ids, mlm_labels = mask_dataset_for_mlm(input_ids)

### 2. Define the model

In [None]:
import torch
from modules.encoder import EncoderModel

# Model hyperparameters. These module-level names are also read by train_mlm
# when it writes the checkpoint, so keep them in sync with the model built here.
# NOTE(review): vocab_size is hard-coded; it is presumably meant to equal
# tokenizer.vocab_size — confirm before swapping tokenizers.
vocab_size = 30522
embed_dim = 512
model_dim = 512
n_layers = 6
num_heads = 8
encoder = EncoderModel(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    model_dim=model_dim,
    n_layers=n_layers,
    num_heads=num_heads,
)

# Prefer the Apple Metal (MPS) backend when it is available, but fall back to
# the CPU instead of raising on machines without MPS support.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
encoder.to(device)


### 3. Prepare the DataLoader

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Pair every masked input sequence with its label sequence so that a batch
# yields (input_ids, labels) tuples.
dataset = TensorDataset(mlm_input_ids, mlm_labels)

# Mini-batches of 32 sequences, reshuffled on every pass over the data.
loader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=32,
)

### 4. Train

In [None]:
from tqdm import tqdm
def train_mlm(epochs, model, tokenizer, loader, optimizer=torch.optim.Adam, device=torch.device('cpu'),
              checkpoint_path='model_checkpoints/checkpoint.pth'):
    """Train ``model`` with a masked-language-modelling loss, then save a checkpoint.

    Args:
        epochs: Number of full passes over ``loader``.
        model: Module mapping a batch of token ids to per-token logits; the
            last output dimension is treated as the vocabulary axis.
        tokenizer: Object exposing ``vocab_size``; used only to record the
            vocabulary size in the checkpoint.
        loader: Sized iterable yielding ``(input_ids, labels)`` batches.
            Label positions equal to -100 are ignored by the loss.
        optimizer: Either a ready optimizer instance, or an optimizer class
            (the default), which is instantiated on ``model.parameters()``
            with that class's default hyperparameters.
        device: Device the model and every batch are moved to.
        checkpoint_path: Destination file for the checkpoint; missing parent
            directories are created.

    The checkpoint stores the model's ``state_dict`` together with the
    hyperparameters needed to rebuild it. ``embed_dim``, ``model_dim``,
    ``n_layers`` and ``num_heads`` are read from the notebook-level variables
    of the same names, so they must match the model being trained.
    """
    import os  # local import: only needed to create the checkpoint directory

    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
    model.train()
    model.to(device)

    # The default argument is the Adam *class*; build an instance so that
    # zero_grad()/step() below operate on this model's parameters.
    if isinstance(optimizer, type):
        optimizer = optimizer(model.parameters())

    total_batches = len(loader)
    with tqdm(total=epochs) as pbar:
        for _ in range(epochs):
            for cur_batch, (input_ids, labels) in enumerate(loader, start=1):
                input_ids = input_ids.to(device, dtype=torch.int64)
                labels = labels.to(device, dtype=torch.int64)
                optimizer.zero_grad()
                output = model(input_ids)
                # Flatten to (tokens, vocab) using the model's own output width
                # rather than the tokenizer's, so the two cannot disagree.
                loss = criterion(output.reshape(-1, output.size(-1)), labels.reshape(-1))
                loss.backward()
                optimizer.step()
                pbar.set_postfix(**{"batch: ": f"{cur_batch} / {total_batches}", "loss:": loss.item()})
            # The bar counts epochs (total=epochs), so advance it once per epoch.
            pbar.update(1)

    checkpoint = {'vocab_size': tokenizer.vocab_size,
                  'embed_dim': embed_dim,
                  'model_dim': model_dim,
                  'n_layers': n_layers,
                  'num_heads': num_heads,
                  'state_dict': model.state_dict()}
    # torch.save does not create missing directories.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(checkpoint, checkpoint_path)

# Train on whichever device the encoder currently lives on. Without an explicit
# `device`, train_mlm falls back to its CPU default and moves the model there,
# undoing the earlier encoder.to(...) call.
train_mlm(
    epochs=4,
    tokenizer=tokenizer,
    model=encoder,
    loader=loader,
    optimizer=torch.optim.Adam(encoder.parameters(), lr=1e-4),
    device=next(encoder.parameters()).device,
)

In [None]:
# print number of parameters in model
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to False are left out of the count.
    """
    trainable = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable += param.numel()
    return trainable

# Report the trainable-parameter count, formatted with thousands separators.
print(f'The model has {count_parameters(encoder):,} trainable parameters')

# Print the encoder's repr to inspect its architecture.
print(encoder)


In [None]:
def load_model_from_checkpoint(checkpoint_path):
    """Rebuild an ``EncoderModel`` from a checkpoint file and return it.

    The checkpoint is expected to be a dict holding the constructor
    hyperparameters (``vocab_size``, ``embed_dim``, ``model_dim``,
    ``n_layers``, ``num_heads``) plus the weights under ``state_dict``.
    Tensors are mapped onto the CPU while loading.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    hyperparam_keys = ('vocab_size', 'embed_dim', 'model_dim', 'n_layers', 'num_heads')
    hyperparams = {key: ckpt[key] for key in hyperparam_keys}

    model = EncoderModel(**hyperparams)
    model.load_state_dict(ckpt['state_dict'])
    return model

# Load the trained encoder from disk (the loader maps it onto the CPU).
# NOTE(review): the training cell writes 'model_checkpoints/checkpoint.pth',
# while this reads 'checkpoint2.pth' — confirm this is the intended file.
encoder = load_model_from_checkpoint('model_checkpoints/checkpoint2.pth')
# Switch to inference mode so train-only layers (e.g. dropout) are disabled.
encoder.eval()

# Test the model on input text with masked tokens.
text = "I love [MASK] and [MASK] ."
input_ids = torch.tensor([tokenizer.encode(text)])

# No gradients are needed for prediction; skipping them saves memory and time.
with torch.no_grad():
    outputs = encoder(input_ids)  # (batch_size, seq_len, vocab_size)

# softmax is monotonic, so the argmax of the raw logits selects the same token.
predicted_index = torch.argmax(outputs, dim=-1)
predicted_token = tokenizer.convert_ids_to_tokens(predicted_index[0].tolist())

print(predicted_token)

