# Training a Transformer Encoder with an MLM Objective on BERT-Tokenized WikiText-103

### 1. Setup (retrieve the data; see `mask_dataset_for_mlm` for the important MLM preprocessing steps)

In [1]:
from modules.encoder import EncoderModel
from preprocess.mlm_preprocess import get_dataset_example, mask_dataset_for_mlm

# Fetch the example dataset; the second return value is the tokenizer that later cells
# use for `vocab_size`, `encode`, and `decode`.
input_ids, tokenizer = get_dataset_example()
# Derive the MLM training pair from the token ids: masked inputs plus their labels.
# NOTE(review): the training loss ignores label value -100, so unmasked positions are
# presumably labelled -100 here -- confirm in `mask_dataset_for_mlm`.
mlm_input_ids, mlm_labels = mask_dataset_for_mlm(input_ids)

  from .autonotebook import tqdm as notebook_tqdm


Downloading and preparing dataset wikitext/wikitext-103-v1 to /Users/rishubtamirisa/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126...


Downloading data: 100%|██████████| 190M/190M [00:48<00:00, 3.89MB/s] 
                                                                                            

Dataset wikitext downloaded and prepared to /Users/rishubtamirisa/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data.


100%|██████████| 3/3 [00:00<00:00, 15.95it/s]
 40%|████      | 2/5 [00:00<00:00, 16.84ba/s]Token indices sequence length is longer than the specified maximum sequence length for this model (547 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 5/5 [00:00<00:00, 19.20ba/s]
100%|██████████| 1802/1802 [01:32<00:00, 19.39ba/s]
100%|██████████| 4/4 [00:00<00:00, 18.74ba/s]
100%|██████████| 5/5 [00:02<00:00,  2.43ba/s]
100%|██████████| 1802/1802 [14:47<00:00,  2.03ba/s]
100%|██████████| 4/4 [00:01<00:00,  2.27ba/s]


### 2. Define the model

In [37]:
import torch
from modules.encoder import EncoderModel

# Model hyperparameters. The training cell stores embed_dim / model_dim / n_layers /
# num_heads in the checkpoint, so keep these names in sync with it.
vocab_size = tokenizer.vocab_size
embed_dim = 512
model_dim = 512
n_layers = 6
num_heads = 8
encoder = EncoderModel(vocab_size=vocab_size, embed_dim=embed_dim, model_dim=model_dim, n_layers=n_layers, num_heads=num_heads)

# Pick the best available accelerator instead of hard-coding "mps", which raises on
# machines without Apple's Metal backend. `torch.backends.mps` is looked up defensively
# because older torch builds do not expose it.
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Last expression of the cell: moves the model and displays its architecture.
encoder.to(device)


EncoderModel(
  (embedding): Embedding(30522, 512)
  (pos_en): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (encoder): Encoder(
    (encoder_layers): ModuleList(
      (0): EncoderBlock(
        (layer_norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (multi_head_attention): MultiHeadAttention(
          (qkv_weights_list): ModuleList(
            (0): ModuleList(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): Linear(in_features=512, out_features=512, bias=True)
              (2): Linear(in_features=512, out_features=512, bias=True)
            )
            (1): ModuleList(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): Linear(in_features=512, out_features=512, bias=True)
              (2): Linear(in_features=512, out_features=512, bias=True)
            )
            (2): ModuleList(
              (0): Linear(

### 3. Prepare DataLoaders

In [None]:
from torch.utils.data import TensorDataset, DataLoader

# Pair every masked input row with its label row so both are indexed and batched together.
dataset = TensorDataset(mlm_input_ids, mlm_labels)

# Mini-batches of 32 rows, reshuffled on every pass over the data.
loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=32,
)

### 4. Train

In [None]:
from tqdm import tqdm
def train_mlm(epochs, model, tokenizer, loader, optimizer=torch.optim.Adam, device=torch.device('cpu'),
              checkpoint_path='model_checkpoints/checkpoint.pth', model_config=None):
    """Train `model` with a cross-entropy MLM loss, then save a checkpoint.

    Args:
        epochs: Number of full passes over `loader`.
        model: Module called as `model(input_ids)`; its output is flattened to
            `(-1, tokenizer.vocab_size)` before the loss is computed.
        tokenizer: Only `tokenizer.vocab_size` is read.
        loader: Sized iterable yielding `(input_ids, labels)` batches.
        optimizer: Either a ready-made optimizer instance, or an optimizer class
            (the default), which is instantiated over `model.parameters()` with the
            class's default hyperparameters.
        device: Device the model and every batch are moved to.
        checkpoint_path: File the checkpoint is written to; missing parent
            directories are created.
        model_config: Optional dict with keys 'embed_dim', 'model_dim', 'n_layers'
            and 'num_heads' to store in the checkpoint. When None, the notebook
            globals of the same names are used.

    Returns:
        None. The checkpoint (hyperparameters plus `state_dict`) is written to
        `checkpoint_path` once all epochs have finished.
    """
    import os  # local import so the cell does not depend on a separate import cell

    # Resolve everything the final checkpoint needs *before* training, so a missing
    # global or an unwritable directory fails immediately instead of after all epochs.
    if model_config is None:
        model_config = {'embed_dim': embed_dim,
                        'model_dim': model_dim,
                        'n_layers': n_layers,
                        'num_heads': num_heads}
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Label positions equal to -100 are excluded from the loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)
    model.train()
    model.to(device)

    # The default argument is the Adam *class*; calling zero_grad()/step() on a class
    # fails, so build an instance over the model's parameters in that case.
    if isinstance(optimizer, type):
        optimizer = optimizer(model.parameters())

    total_batches = len(loader)
    with tqdm(total=epochs) as pbar:
        for _ in range(epochs):
            for cur_batch, batch in enumerate(loader, start=1):
                input_ids, labels = batch
                input_ids = input_ids.to(device, dtype=torch.int64)
                labels = labels.to(device, dtype=torch.int64)
                optimizer.zero_grad()
                output = model(input_ids)
                # Flatten predictions and labels so each position is one classification example.
                loss = criterion(output.view(-1, tokenizer.vocab_size), labels.view(-1))
                loss.backward()
                optimizer.step()
                pbar.set_postfix(**{"batch: ": f"{cur_batch} / {total_batches}", "loss:": loss.item()})
            # The bar is sized in epochs; without this it would stay at 0 forever.
            pbar.update(1)

        checkpoint = {'vocab_size': tokenizer.vocab_size,
                      'embed_dim': model_config['embed_dim'],
                      'model_dim': model_config['model_dim'],
                      'n_layers': model_config['n_layers'],
                      'num_heads': model_config['num_heads'],
                      'state_dict': model.state_dict()}
        torch.save(checkpoint, checkpoint_path)

# Train on whichever device the encoder already lives on. Without an explicit `device`,
# train_mlm falls back to its CPU default and moves the model off the accelerator.
train_mlm(epochs=4, tokenizer=tokenizer, model=encoder, loader=loader,
          optimizer=torch.optim.Adam(encoder.parameters(), lr=1e-4),
          device=next(encoder.parameters()).device)

In [None]:
# print number of parameters in model
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Only parameters whose `requires_grad` flag is set contribute; each one adds
    its `numel()` to the total.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

# Report the trainable-parameter count; the `:,` format spec adds thousands separators.
print(f'The model has {count_parameters(encoder):,} trainable parameters')

# print model architecture
print(encoder)


In [35]:
def load_model_from_checkpoint(checkpoint_path):
    """Rebuild an `EncoderModel` from a saved checkpoint and restore its weights.

    The checkpoint is expected to be a dict holding the constructor
    hyperparameters ('vocab_size', 'embed_dim', 'model_dim', 'n_layers',
    'num_heads') plus the weights under 'state_dict'. Tensors are mapped to the
    CPU while loading.

    NOTE(review): `torch.load` unpickles the file -- only load checkpoints that
    come from a trusted source.
    """
    ckpt = torch.load(checkpoint_path, map_location=torch.device('cpu'))

    # Constructor arguments are stored in the checkpoint under their own names.
    hyperparam_keys = ('vocab_size', 'embed_dim', 'model_dim', 'n_layers', 'num_heads')
    restored = EncoderModel(**{key: ckpt[key] for key in hyperparam_keys})

    restored.load_state_dict(ckpt['state_dict'])
    return restored

# load model from checkpoint -- this is the path the training cell (train_mlm) writes to
encoder = load_model_from_checkpoint('model_checkpoints/checkpoint.pth')
# Switch to inference mode: the model contains Dropout layers, which would otherwise
# stay active and make the predictions below change from run to run.
encoder.eval()

# test model on input text with masked tokens
text = "I love [MASK] and [MASK] ."
input_ids = torch.tensor([tokenizer.encode(text)])

# No gradients are needed for a prediction, so skip building the autograd graph.
with torch.no_grad():
    outputs = encoder(input_ids) # (batch_size, seq_len, vocab_size)

# argmax over the raw scores: softmax is monotonic, so applying it first cannot
# change which vocabulary index is largest.
predicted_index = torch.argmax(outputs, dim=-1)
predicted_token = tokenizer.convert_ids_to_tokens(predicted_index[0].tolist())

print(predicted_token)



KeyboardInterrupt: 

In [29]:
# decode row 300 of mlm_input_ids (not the first row) back to text, to inspect where
# the [MASK] tokens were placed
tokenizer.decode(mlm_input_ids[300].int())


'thought. families [MASK] [MASK] deities, with [MASK] father, [MASK], and child, represent the creation [MASK] new life and [MASK] succession of the father [MASK] the child, [MASK] pattern [MASK] connects divine families with royal [MASK]. osiri [MASK], isis, and horus formed the quintessential family of this type. the pattern they set grew [MASK] widespread over time, so that many deities in local cult centers, like ptah, sekhmet, and their child [MASK]fer [MASK] at memphis and amun [MASK] mut, and khon [MASK] at [MASK]bes, were assembled into family triads. genealogical connections like these are changeable, in'