# Mask Language Model

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataset import BERTLanguageModelingDataset
from vocab import Vocab
    

In [2]:
data_dir = "ptb"
epochs = 10
batch_length = 32
batch_size = 16
lr = 0.001

n_layers = 1
d_emb = 200
d_hid = 250
p_drop = 0.2

interval_print = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load Dataset


In [3]:
vocab = Vocab(data_dir, mask_token='<mask>')
trainset = BERTLanguageModelingDataset(data_dir, vocab, seq_len=batch_length, split='train')
validset = BERTLanguageModelingDataset(data_dir, vocab, seq_len=batch_length, split='valid')
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size)


building vocab...


100%|██████████| 42068/42068 [00:00<00:00, 117222.17it/s]


[('the', 50770), ('<unk>', 45020), ('N', 32481), ('of', 24400), ('to', 23638), ('a', 21196), ('in', 18000), ('and', 17474), ("'s", 9784), ('that', 8931)]
end building vocab ...
['<mask>', '<pad>', '<eos>', 'the', '<unk>', 'N', 'of', 'to', 'a', 'in']


# Model

In [4]:
class WordEmbedding(nn.Module):
    def __init__(self, num_embeddomgs, embedding_dim, p_drop=0.):
        super(WordEmbedding, self).__init__()
        self.emb = nn.Embedding(num_embeddomgs, embedding_dim)
        self.dropout = nn.Dropout(p_drop)

    def forward(self, input):
        output = self.emb(input)
        output = self.dropout(output)
        return output         

class MLM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, p_drop):
        super(MLM, self).__init__()
        self.n_classes = vocab_size
        self.d_emb = embedding_dim


        self.word_embedding = WordEmbedding(self.n_classes, self.d_emb, p_drop=p_drop)
        self.layers = nn.GRU(self.d_emb, hidden_dim, n_layers, dropout=p_drop, batch_first=True, bidirectional=True)
        self.proj_layer = nn.Linear(hidden_dim*2, self.n_classes)
        
        self.drop = nn.Dropout(p_drop)
        # self.layer2 = nn.GRU(hidden_dim, self.n_classes)

    def forward(self, input):
        emb = self.word_embedding(input)
        
        # (bsz, len_step, h_dim)()
        output, h = self.layers(emb)

        output = self.drop(output)
        
        output = self.proj_layer(output)
        return output

class BidirectionalLM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers, p_drop):
        super(BidirectionalLM, self).__init__()
        self.n_classes = vocab_size
        self.d_emb = embedding_dim
        self.hidden_dim = hidden_dim


        self.word_embedding = WordEmbedding(self.n_classes, self.d_emb, p_drop=p_drop)
        # self.layers = nn.GRU(self.d_emb, hidden_dim, n_layers, dropout=p_drop, batch_first=True, bidirectional=True)
        
        self.forward_layers = nn.GRU(self.d_emb, hidden_dim, n_layers, dropout=p_drop, batch_first=True, bidirectional=False)
        self.backward_layers = nn.GRU(self.d_emb, hidden_dim, n_layers, dropout=p_drop, batch_first=True, bidirectional=False)


        self.proj_layer = nn.Linear(hidden_dim*2, self.n_classes)
        
        self.drop = nn.Dropout(p_drop)
        # self.layer2 = nn.GRU(hidden_dim, self.n_classes)

    def forward(self, input):
        # print(vocab.decode_tokids(input[0, :-2]), "@@",vocab.decode_tokids(input[0, 2:]))
        
        emb = self.word_embedding(input)
        # (bsz, len_step, h_dim)()
        # output, h = self.layers(emb)
        # forward_output, backward_output = output[:, :-2, :self.hidden_dim], output[:, 2:, self.hidden_dim:]
        forward_output, _ = self.forward_layers(emb[:, :-2])
        backward_output, _ = self.backward_layers(torch.flip(emb[:, 2:], dims =[-1]))
        # print(torch.flip(backward_output, dims =[-1]).shape)
        staggered_output = torch.cat((forward_output, torch.flip(backward_output, dims =[1])), dim=-1)
        # print(output.shape, staggered_output.shape)
        output = self.drop(staggered_output)
        # print(output.shape)

        output = self.proj_layer(output)
        return output

model = BidirectionalLM(vocab_size=vocab.size, embedding_dim=d_emb, hidden_dim=d_hid, n_layers=n_layers, p_drop=p_drop)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(),
                  lr = lr, # config.learning_rate - default is 5e-5, our notebook had 2e-5
                  eps = 1e-8 # config.adam_epsilon  - default is 1e-8.
                  )




In [5]:

n_iter, train_loss, best_ppl = 0, 0., float('inf')
for ep in range(epochs):
    print(f"[{ep}/{epochs}] epochs training...")
    
    # train
    model.train()
    for (mlm_train, mlm_target) in trainloader:
        n_iter += 1
        mlm_train = mlm_train.to(device)
        # print(mlm_train[0], vocab.padding_idx)
        mlm_target = mlm_train[:, 1:-1].contiguous().to(device)
        # print(mlm_train.size(), vocab.decode_tokids(mlm_target[0]))
        logits = model(mlm_train)
        # print(logits.shape, mlm_target.shape)
        loss = F.cross_entropy(logits.reshape(-1, vocab.size), mlm_target.reshape(-1), ignore_index=vocab.padding_idx)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        if n_iter % interval_print == 0:
            train_loss /= interval_print
            train_ppl = math.exp(train_loss)
            print(f"n_iter:{n_iter} loss: {train_loss:0.3f} ppl: {train_ppl:0.3f}")
            train_loss = 0
            
    
    model.eval()
    valid_loss = 0.
    for mlm_train, mlm_target in validloader:
        mlm_train = mlm_train.to(device)
        mlm_target = mlm_train[:, 1:-1].contiguous().to(device)
        
        # mlm_target = mlm_target.to(device)

        with torch.no_grad():

            logits = model(mlm_train)
            loss = F.cross_entropy(logits.reshape(-1, vocab.size), mlm_target.reshape(-1), ignore_index=vocab.padding_idx)
            valid_loss += loss.item()
            
        
    valid_loss = valid_loss/len(validloader)
    valid_ppl = math.exp(valid_loss)
        

    if valid_ppl < best_ppl:
        best_ppl = valid_ppl
        torch.save(model, "mlm-best.pth")
        print("### find best mode ###", best_ppl)

    print(f"validation vloss: {valid_loss:0.3f} vppl: {valid_ppl:0.3f}, best ppl: {best_ppl:0.3f}")


   



[0/10] epochs training...
n_iter:100 loss: 7.111 ppl: 1225.477
n_iter:200 loss: 6.447 ppl: 630.773
n_iter:300 loss: 6.027 ppl: 414.648
n_iter:400 loss: 6.041 ppl: 420.258
n_iter:500 loss: 5.733 ppl: 309.020
n_iter:600 loss: 5.846 ppl: 345.880
n_iter:700 loss: 5.675 ppl: 291.390
n_iter:800 loss: 5.539 ppl: 254.375
n_iter:900 loss: 5.467 ppl: 236.697
n_iter:1000 loss: 5.488 ppl: 241.832
n_iter:1100 loss: 5.348 ppl: 210.105
n_iter:1200 loss: 5.252 ppl: 191.003
n_iter:1300 loss: 5.200 ppl: 181.309
n_iter:1400 loss: 5.146 ppl: 171.700
n_iter:1500 loss: 5.095 ppl: 163.270
n_iter:1600 loss: 4.930 ppl: 138.397
n_iter:1700 loss: 5.064 ppl: 158.262
n_iter:1800 loss: 5.012 ppl: 150.222
n_iter:1900 loss: 4.964 ppl: 143.217
n_iter:2000 loss: 5.105 ppl: 164.836
n_iter:2100 loss: 4.833 ppl: 125.640
n_iter:2200 loss: 4.840 ppl: 126.460
n_iter:2300 loss: 4.907 ppl: 135.178
n_iter:2400 loss: 4.730 ppl: 113.307
n_iter:2500 loss: 4.704 ppl: 110.375
n_iter:2600 loss: 4.804 ppl: 121.986
### find best mode #

In [11]:
# input_text = "the u.s. is one of the few <mask> nations that does n't have a higher standard of regulation"
# input_text = "viewers can call a N number for additional advice"
input_text = "i <mask> you so much"
k = 5
model.eval()
mask_ind = input_text.split(" ").index("<mask>") -1 
# mask_ind = 1

mask_input = torch.tensor([vocab.encode_line(input_text, add_eos=True)]).to(device)
# mask_ind = torch.where(mask_input==vocab.mask_idx)
print(mask_ind)
mask_ind = torch.tensor([mask_ind]).to(device)

# masked_input, masked_labels = trainset.get_masked_input_and_labels(seq)

logits = F.softmax(model(mask_input), dim=-1)[:, mask_ind]
top_k = torch.topk(logits, k)

top_k_words = top_k.indices.flatten()
top_k_probs = top_k.values.flatten()
for i, (w, p) in enumerate(zip(top_k_words, top_k_probs)):
    print(f"{i}th 'predicted word (prob.)': {vocab.id2tok[w]} ({p:0.3f})")
    print(f"{i}th 'complete sentence': {input_text.replace('<mask>', vocab.id2tok[w])}"  )

(tensor([0], device='cuda:0'), tensor([1], device='cuda:0'))
tensor([[  2,   0, 113, 108, 124,   3]], device='cuda:0') (tensor([0], device='cuda:0'), tensor([1], device='cuda:0'))
0th 'predicted word (prob.)': you (0.087)
0th 'complete sentence': you you so much
1th 'predicted word (prob.)': officials (0.034)
1th 'complete sentence': officials you so much
2th 'predicted word (prob.)': is (0.028)
2th 'complete sentence': is you so much
3th 'predicted word (prob.)': so (0.026)
3th 'complete sentence': so you so much
4th 'predicted word (prob.)': of (0.025)
4th 'complete sentence': of you so much
