In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [30]:
def sample_input_ids():
    input_ids = torch.randint(0,30_000,(4,20))
    input_ids[:,0]=101 # CLS TOKEN
    input_ids[:,7]=103 # SEP TOKEN
    input_ids[:,-1]=103 # SEP TOKEN
    
    input_ids[:,16:]=0 # PAD TOKEN
    return input_ids

In [31]:
input_ids = sample_input_ids()
input_ids

tensor([[  101,  3413,  8401,  4797, 13763, 22724,  7058,   103, 24223,  9199,
         11515,   488, 27828,  3547, 18158, 10189,     0,     0,     0,     0],
        [  101,  4303, 15277,  4807, 26430, 21759, 24603,   103, 12689, 15363,
         14318,     5,  9876, 29946,  4058,   451,     0,     0,     0,     0],
        [  101, 12024, 27899,  8706, 19037, 28294, 16359,   103,   835, 13595,
         20214, 23694,  8276, 11985, 14855,  8413,     0,     0,     0,     0],
        [  101, 15900,   527, 15302, 26235, 16148, 10709,   103, 21784, 18223,
         13851, 11017, 22713, 25382, 20000, 24657,     0,     0,     0,     0]])

In [32]:
mlm_mask = torch.rand(input_ids.size()) < 0.15 * (input_ids!=101) * (input_ids!=103) * (input_ids!=0) 
mlm_mask

tensor([[False, False, False, False,  True, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False],
        [False, False, False,  True, False, False, False, False, False, False,
         False, False, False,  True, False, False, False, False, False, False],
        [False, False, False, False,  True, False, False, False, False, False,
         False,  True, False, False, False, False, False, False, False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False,  True, False,  True,  True, False, False, False, False]])

In [33]:
masked_tokens = input_ids * mlm_mask
masked_tokens

tensor([[    0,     0,     0,     0, 13763,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,  4807,     0,     0,     0,     0,     0,     0,
             0,     0,     0, 29946,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0, 19037,     0,     0,     0,     0,     0,
             0, 23694,     0,     0,     0,     0,     0,     0,     0,     0],
        [    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0, 22713,     0, 20000, 24657,     0,     0,     0,     0]])

In [34]:
input_ids[masked_tokens!=0]=1

In [35]:
input_ids

tensor([[  101,  3413,  8401,  4797,     1, 22724,  7058,   103, 24223,  9199,
         11515,   488, 27828,  3547, 18158, 10189,     0,     0,     0,     0],
        [  101,  4303, 15277,     1, 26430, 21759, 24603,   103, 12689, 15363,
         14318,     5,  9876,     1,  4058,   451,     0,     0,     0,     0],
        [  101, 12024, 27899,  8706,     1, 28294, 16359,   103,   835, 13595,
         20214,     1,  8276, 11985, 14855,  8413,     0,     0,     0,     0],
        [  101, 15900,   527, 15302, 26235, 16148, 10709,   103, 21784, 18223,
         13851, 11017,     1, 25382,     1,     1,     0,     0,     0,     0]])

In [36]:
labels = input_ids.clone()
labels[masked_tokens==0]=-100
labels

tensor([[-100, -100, -100, -100,    1, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100],
        [-100, -100, -100,    1, -100, -100, -100, -100, -100, -100, -100, -100,
         -100,    1, -100, -100, -100, -100, -100, -100],
        [-100, -100, -100, -100,    1, -100, -100, -100, -100, -100, -100,    1,
         -100, -100, -100, -100, -100, -100, -100, -100],
        [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
            1, -100,    1,    1, -100, -100, -100, -100]])

In [37]:
input_ids.dtype, labels.dtype

(torch.int64, torch.int64)

In [38]:
head = nn.Linear(512,30_000)

In [39]:
out = torch.rand(4,20,512)

In [40]:
logits = head(out)
logits.shape

torch.Size([4, 20, 30000])

In [41]:
loss = F.cross_entropy(logits.view(-1,logits.size(-1)),labels.view(-1),ignore_index=-100)

In [42]:
loss

tensor(10.5416, grad_fn=<NllLossBackward0>)

In [43]:
class MultiheadAttention(nn.Module):
    def __init__(self, dim, n_heads, dropout=0.):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        assert dim % n_heads == 0, 'dim should be div by n_heads'
        self.head_dim = self.dim // self.n_heads
        self.in_proj = nn.Linear(dim,dim*3)
        self.attn_dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5
        self.out_proj = nn.Linear(dim,dim)
        
    def forward(self,x,mask=None):
        b,t,c = x.shape
        q,k,v = self.in_proj(x).chunk(3,dim=-1)
        q = q.view(b,t,self.n_heads,self.head_dim).permute(0,2,1,3)
        k = k.view(b,t,self.n_heads,self.head_dim).permute(0,2,1,3)
        v = v.view(b,t,self.n_heads,self.head_dim).permute(0,2,1,3)
        
        qkT = torch.matmul(q,k.transpose(-1,-2)) * self.scale
        qkT = self.attn_dropout(qkT)
        
        if mask is not None:
            mask = mask.to(dtype=qkT.dtype,device=qkT.device)
            qkT = qkT.masked_fill(mask==0,float('-inf'))
            
        qkT = F.softmax(qkT,dim=-1)
            
        attn = torch.matmul(qkT,v)
        attn = attn.permute(0,2,1,3).contiguous().view(b,t,c)
        out = self.out_proj(attn)
        
        return out

In [44]:
class FeedForward(nn.Module):
    def __init__(self,dim,dropout=0.):
        super().__init__()
        self.feed_forward = nn.Sequential(
            nn.Linear(dim,dim*4),
            nn.Dropout(dropout),
            nn.GELU(),
            nn.Linear(dim*4,dim)
        )
        
    def forward(self, x):
        return self.feed_forward(x)

In [45]:
class EncoderBlock(nn.Module):
    def __init__(self, dim, n_heads, attn_dropout=0., mlp_dropout=0.):
        super().__init__()
        self.attn = MultiheadAttention(dim,n_heads,attn_dropout)
        self.ffd = FeedForward(dim,mlp_dropout)
        self.ln_1 = nn.LayerNorm(dim)
        self.ln_2 = nn.LayerNorm(dim)
        
    def forward(self,x,mask=None):
        x = self.ln_1(x)
        x = x + self.attn(x,mask)
        x = self.ln_2(x)
        x = x + self.ffd(x)
        return x

In [46]:
class Embedding(nn.Module):
    def __init__(self,vocab_size,max_len,dim):
        super().__init__()
        self.max_len = max_len
        self.class_embedding = nn.Embedding(vocab_size,dim)
        self.pos_embedding = nn.Embedding(max_len,dim)
    def forward(self,x):
        x = self.class_embedding(x)
        pos = torch.arange(0,x.size(1),device=x.device)
        x = x + self.pos_embedding(pos)
        return x

In [47]:
class MLMBERT(nn.Module):
    def __init__(self, config):
        
        super().__init__()
        
        self.embedding = Embedding(config['vocab_size'],config['max_len'],config['dim'])
        
        self.depth = config['depth']
        self.encoders = nn.ModuleList([
            EncoderBlock(
                dim=config['dim'],
                n_heads=config['n_heads'],
                attn_dropout=config['attn_dropout'],
                mlp_dropout=config['mlp_dropout']
            ) for _ in range(self.depth)
        ])
        
        self.mlm_head = nn.Linear(config['dim'],config['vocab_size'])
        
        self.pad_token_id = config['pad_token_id']
        self.mask_token_id = config['mask_token_id']
        
    def create_src_mask(self,src):
        return (src != self.pad_token_id).unsqueeze(1).unsqueeze(2) # N, 1, 1, src_len
    
    def forward(self,input_ids,labels=None):
        
        src_mask = self.create_src_mask(input_ids)
        enc_out = self.embedding(input_ids)
        for layer in self.encoders:
            enc_out = layer(enc_out,mask=src_mask)
                
        print(enc_out.shape)
        logits = self.mlm_head(enc_out)
        
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1,logits.size(-1)),labels.view(-1))
            return {'loss': loss, 'logits': logits}
        else:
            # assuming inference input_ids only have 1 [MASK] token
            mask_idx = (input_ids==self.mask_token_id).flatten().nonzero().item()
            mask_preds = F.softmax(logits[:,mask_idx,:],dim=-1).argmax(dim=-1)
            return {'mask_predictions':mask_preds}

In [48]:
config = {
    'dim': 512,
    'n_heads': 8,
    'attn_dropout': 0.1,
    'mlp_dropout': 0.1,
    'depth': 6,
    'vocab_size': 5_000,
    'max_len': 128,
    'pad_token_id': 0,
    'mask_token_id': 1
}

In [49]:
def sample():
    input_ids = torch.randint(0,config['vocab_size'],(4,20))
    input_ids[:,16:]=config['pad_token_id'] # PAD TOKEN
    mlm_mask = torch.rand(input_ids.size()) < 0.15 * (input_ids!=config['pad_token_id'])
    masked_tokens = input_ids * mlm_mask
    labels = input_ids.clone()
    labels[masked_tokens==config['pad_token_id']]=-100
    input_ids[masked_tokens!=0]=config['mask_token_id'] # MASK TOKEN
    return input_ids, labels

In [50]:
i,l = sample()
print(i)
print(l)

tensor([[2329,    1, 2274, 3112,    1,    1,    1, 4928,  321, 1492, 4794, 4964,
            1, 2607, 1656, 4623,    0,    0,    0,    0],
        [   1,    1, 3197, 2275, 2464, 2552, 4599,  269,  774,  721,  586,    1,
         1804,    1,    1, 3224,    0,    0,    0,    0],
        [2481,    1,  456, 1489, 1961,    1,    1, 3021,    1, 4425, 4870,    1,
         3808, 2035, 4949, 4294,    0,    0,    0,    0],
        [1893, 2362, 1262,  876, 1569,  718,    1, 3085, 1696,  775,    1, 4555,
         2769, 2454, 2614,    1,    0,    0,    0,    0]])
tensor([[-100, 3525, -100, -100, 4747, 4705, 3834, -100, -100, -100, -100, -100,
         4951, -100, -100, -100, -100, -100, -100, -100],
        [ 968,  424, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1195,
         -100, 1442, 4918, -100, -100, -100, -100, -100],
        [-100, 3001, -100, -100, -100, 4988, 2002, -100, 4265, -100, -100,  467,
         -100, -100, -100, -100, -100, -100, -100, -100],
        [-100, -100, -100,

In [51]:
i.shape, l.shape

(torch.Size([4, 20]), torch.Size([4, 20]))

In [52]:
model = MLMBERT(config)

In [53]:
out = model(i,l)

torch.Size([4, 20, 512])


In [54]:
out['logits'].shape

torch.Size([4, 20, 5000])