<a href="https://colab.research.google.com/github/rishabhd786/bert_pytorch/blob/master/bert_pretraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pytorch_pretrained_bert

Collecting pytorch_pretrained_bert
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)
[K     |██▋                             | 10kB 20.0MB/s eta 0:00:01[K     |█████▎                          | 20kB 6.7MB/s eta 0:00:01[K     |████████                        | 30kB 7.9MB/s eta 0:00:01[K     |██████████▋                     | 40kB 8.4MB/s eta 0:00:01[K     |█████████████▎                  | 51kB 6.9MB/s eta 0:00:01[K     |███████████████▉                | 61kB 7.4MB/s eta 0:00:01[K     |██████████████████▌             | 71kB 8.0MB/s eta 0:00:01[K     |█████████████████████▏          | 81kB 7.8MB/s eta 0:00:01[K     |███████████████████████▉        | 92kB 7.9MB/s eta 0:00:01[K     |██████████████████████████▌     | 102kB 8.1MB/s eta 0:00:01[K     |█████████████████████████████▏  | 112kB 8.1MB/s eta 0:00:01[K     |██████████████████████

In [0]:
import torch
import torch.nn as nn
from random import randint, shuffle
from random import random as rand
from pytorch_pretrained_bert.tokenization import BertTokenizer
import random
import math

In [0]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [0]:
def seek_random_offset(f, back_margin=2000):
    """ Move file pointer `f` to the start of a line after a random offset.

    A random position in [0, size - back_margin] is chosen, the pointer is
    moved there and the (probably partial) line at that position is consumed,
    so the next read starts on a line boundary.

    Args:
        f: seekable file object.
        back_margin: number of trailing positions excluded from the random
            range, so that some content remains to be read after the seek.

    Returns:
        None. Only the position of `f` is changed.
    """
    f.seek(0, 2)  # jump to end of file to measure its size
    # Clamp at 0: for a file shorter than back_margin the upper bound would be
    # negative and randint() would raise ValueError.
    max_offset = max(0, f.tell() - back_margin)
    f.seek(randint(0, max_offset), 0)
    f.readline()  # discard the rest of the line we landed in

In [0]:
class DataLoader():
    """ Load sentence pairs from a corpus file and yield pre-training batches.

    The corpus is read through two independent file pointers: `f_pos` is read
    sequentially (sentence A and the true next sentence), `f_neg` is moved to a
    random offset to draw an unrelated sentence B.

    Args:
        file: path of the text corpus; blank lines are treated as delimiters.
        batch_size: number of instances per yielded batch.
        max_len: maximum combined sequence length handed to the preprocessor.
        short_sampling_prob: probability of sampling a shorter-than-maximum
            sentence length for an instance.
    """
    def __init__(self, file, batch_size, max_len, short_sampling_prob=0.1):
        super().__init__()
        self.f_pos = open(file, "r", encoding='utf-8', errors='ignore')
        self.f_neg = open(file, "r", encoding='utf-8', errors='ignore')
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_len = max_len
        self.short_sampling_prob = short_sampling_prob
        self.batch_size = batch_size
        # Built on first iteration: Preprocess is defined in a later notebook
        # cell than the one that instantiates this loader, so resolving the
        # name here would fail on a fresh top-to-bottom run.
        self.preproc = None

    def close(self):
        """ Close both corpus file handles. """
        self.f_pos.close()
        self.f_neg.close()

    def read_tokens(self, f, length, discard_last_and_restart=True):
        """ Read tokens from file pointer with limited length.

        Lines are tokenized and accumulated until at least `length` tokens are
        collected.

        Args:
            f: file pointer to read lines from.
            length: minimum number of tokens to collect.
            discard_last_and_restart: on a blank line, drop what has been
                collected and start over (True) or return what has been
                collected so far (False).

        Returns:
            list of tokens, or None when the end of the file is reached first.
        """
        tokens = []
        while len(tokens) < length:
            line = f.readline()
            if not line: # end of file
                return None
            if not line.strip(): # blank line (delimiter)
                if discard_last_and_restart:
                    tokens = [] # throw away what was read and restart
                    continue
                else:
                    return tokens
            tokens.extend(self.tokenizer.tokenize(line.strip()))

        return tokens

    def __iter__(self): # iterator to load data
        """ Yield batches as lists of torch.long tensors, one per field
        returned by the preprocessor, until the corpus is exhausted. """
        if self.preproc is None:
            # max_pred must be an int so every instance is padded to the same
            # number of prediction slots; max_len is forwarded so padding
            # matches this loader's sequence length.
            self.preproc = Preprocess(int(self.max_len * 0.15), 0.15, self.max_len)

        half_len = int(self.max_len / 2)
        while True:
            batch = []
            for _ in range(self.batch_size):
                # occasionally sample a shorter sentence length
                len_tokens = randint(1, half_len) \
                    if rand() < self.short_sampling_prob \
                    else half_len

                is_next = rand() < 0.5 # whether token_b is next to token_a or not

                tokens_a = self.read_tokens(self.f_pos, len_tokens, True)
                if is_next:
                    f_next = self.f_pos
                else:
                    # only reposition the negative pointer when it is used
                    seek_random_offset(self.f_neg)
                    f_next = self.f_neg
                tokens_b = self.read_tokens(f_next, len_tokens, False)

                if tokens_a is None or tokens_b is None: # end of file
                    self.f_pos.seek(0, 0) # rewind for the next pass
                    return

                instance = self.preproc((is_next, tokens_a, tokens_b))
                batch.append(instance)

            # transpose list-of-instances into one tensor per field
            batch_tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*batch)]
            yield batch_tensors



In [0]:
# Corpus is read from Google Drive -- assumes the drive is already mounted at
# /content/drive (no mount cell is visible in this notebook; TODO confirm).
data_loader = DataLoader(
    file="/content/drive/My Drive/bert_data.txt",
    batch_size=2,
    max_len=512,
)


In [0]:
def truncate_tokens_pair(tokens_a, tokens_b, max_len):
    """Trim two token lists in place until their combined length fits.

    One token at a time is removed from the end of the longer list; when both
    lists have equal length the token is taken from `tokens_b`.
    """
    while len(tokens_a) + len(tokens_b) > max_len:
        longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        longer.pop()

def get_random_word(vocab_words):
    """Return one entry of `vocab_words` chosen uniformly at random.

    Args:
        vocab_words: iterable of vocabulary tokens (e.g. a list, or a dict
            whose keys are the tokens).

    Raises:
        IndexError: if `vocab_words` is empty.
    """
    # Sample over the real vocabulary size instead of a fixed 0..30000 range,
    # which overflowed small vocabularies and skipped entries past 30000.
    return random.choice(list(vocab_words))

In [0]:
class Preprocess():
    """ Pre-processing steps for pretraining transformer.

    Turns an (is_next, tokens_a, tokens_b) triple into fixed-length id lists
    for the masked-LM and next-sentence tasks.

    Args:
        max_pred: maximum number of masked positions per instance (a float is
            truncated to int).
        mask_prob: fraction of tokens selected for prediction.
        max_len: length every sequence is truncated/padded to.
    """
    def __init__(self, max_pred, mask_prob, max_len=512):
        super().__init__()
        self.max_pred = max_pred
        self.mask_prob = mask_prob
        self.indexer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_len = max_len

    def __call__(self,data):
        """ Build one training instance.

        Note: `tokens_a` / `tokens_b` are truncated in place.

        Returns:
            (input_ids, segment_ids, input_mask, masked_ids, masked_pos,
            masked_weights, is_next); the first three are padded to `max_len`,
            the three masked_* lists to int(max_pred).
        """
        is_next, tokens_a, tokens_b = data
        # leave room for [CLS] and the two [SEP] tokens
        truncate_tokens_pair(tokens_a, tokens_b, self.max_len - 3)

        # Add Special Tokens
        tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
        segment_ids = [0]*(len(tokens_a)+2) + [1]*(len(tokens_b)+1)
        input_mask = [1]*len(tokens)

        # For masked Language Models
        max_pred = int(self.max_pred)
        masked_tokens, masked_pos = [], []
        n_pred = min(max_pred, max(1, int(round(len(tokens)*self.mask_prob))))
        cand_pos = [i for i, token in enumerate(tokens)
                    if token != '[CLS]' and token != '[SEP]']
        shuffle(cand_pos)
        for pos in cand_pos[:n_pred]:
            masked_tokens.append(tokens[pos])
            masked_pos.append(pos)
            if rand() < 0.8: # 80%: replace with [MASK]
                tokens[pos] = '[MASK]'
            elif rand() < 0.5: # 10%: replace with a random word
                tokens[pos] = get_random_word(self.indexer.vocab)
            # remaining 10%: keep the original token
        masked_weights = [1]*len(masked_tokens)

        # Token Indexing
        input_ids = self.indexer.convert_tokens_to_ids(tokens)
        masked_ids = self.indexer.convert_tokens_to_ids(masked_tokens)

        # Zero Padding
        n_pad = int(self.max_len) - len(input_ids)
        input_ids.extend([0]*n_pad)
        segment_ids.extend([0]*n_pad)
        input_mask.extend([0]*n_pad)

        # Zero Padding for masked target. Pad from the number of positions
        # actually masked (which can be below n_pred when there are few
        # candidates) so every instance has the same length and can be batched.
        n_pad = max_pred - len(masked_ids)
        if n_pad > 0:
            masked_ids.extend([0]*n_pad)
            masked_pos.extend([0]*n_pad)
            masked_weights.extend([0]*n_pad)

        return (input_ids, segment_ids, input_mask, masked_ids, masked_pos, masked_weights, is_next)


In [0]:
def gelu(x):
    """Gaussian Error Linear Unit, exact erf form (as implemented by Hugging Face)."""
    # Standard normal CDF evaluated at x, written via the error function.
    normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * normal_cdf


class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        dim: size of the last (normalised) dimension.
        variance_epsilon: constant added to the variance, inside the square
            root, for numerical stability.
    """
    def __init__(self, dim, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))   # scale
        self.beta  = nn.Parameter(torch.zeros(dim))  # shift
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # biased variance: mean of squared deviations
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalised = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normalised + self.beta

In [0]:
class embedding(nn.Module):
  """Input embedding: sum of token, position and segment embeddings.

  Args:
    dim: embedding size.
    vocab_size: number of rows in the token embedding table.
    max_len: number of learned positions (longest supported sequence).
    n_segs: number of segment types.
  """
  def __init__(self,dim,vocab_size,max_len,n_segs):
    super().__init__()
    self.embed=nn.Embedding(vocab_size,dim)
    self.embedpos=nn.Embedding(max_len,dim)
    self.segembed=nn.Embedding(n_segs,dim)
    self.norm = LayerNorm(dim)
    self.drop = nn.Dropout(0.1)
  def forward(self,x,seg):
    """Embed token ids `x` and segment ids `seg` (both indexed as (batch, seq))."""
    seq_len = x.size(1)
    # position ids 0..seq_len-1, repeated for every row of the batch
    pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
    pos = pos.unsqueeze(0).expand_as(x)
    summed = self.embed(x)+self.embedpos(pos)+self.segembed(seg)
    # BERT normalises the summed embeddings first and applies dropout last;
    # the reverse order would hand a renormalised, dropout-perturbed signal
    # to the first encoder layer.
    return self.drop(self.norm(summed))

In [0]:
class Attention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Args:
    dim: model width; split evenly across `heads`.
    heads: number of attention heads.
    max_len: stored on the module; not used in the computation.
  """
  def __init__(self,dim,heads,max_len):
    super().__init__()
    self.q_mat=nn.Linear(dim,dim)
    self.k_mat=nn.Linear(dim,dim)
    self.v_mat=nn.Linear(dim,dim)
    self.dim=dim
    self.heads=heads
    self.max_len=max_len
    self.dk=dim//heads  # width of a single head
    self.drop=nn.Dropout(0.1)
    self.softmax=nn.Softmax(-1)
    self.out = nn.Linear(dim,dim)

  def _split_heads(self,t,batch):
    """Reshape (batch, seq, dim) into (batch, heads, seq, dk)."""
    return t.view(batch,-1,self.heads,self.dk).transpose(1,2)

  def forward(self,x,mask=None):
    """Attend over `x`; `mask` is 1 for positions to attend to and 0 otherwise."""
    batch=x.size(0)
    query=self._split_heads(self.q_mat(x),batch)
    key=self._split_heads(self.k_mat(x),batch)
    value=self._split_heads(self.v_mat(x),batch)

    scores=torch.matmul(query,key.transpose(2,3))/math.sqrt(self.dk)

    if mask is not None:
      # broadcast the mask over heads and query positions, then push the
      # scores of masked keys far down so softmax gives them ~0 weight
      penalty = 10000.0 * (1.0 - mask[:, None, None, :].float())
      scores = scores - penalty

    weights = self.drop(self.softmax(scores))
    context = torch.matmul(weights, value)

    # merge the heads back into one (batch, seq, dim) tensor
    merged = context.transpose(1,2).contiguous().view(batch, -1, self.dim)
    return self.out(merged)











In [0]:
class feedforward(nn.Module):
  """Position-wise feed-forward block: dim -> 4*dim -> dim with a GELU in between.

  `heads` and `max_len` are accepted but not used by this module.
  """
  def __init__(self,dim,heads,max_len):
    super().__init__()
    hidden=dim*4
    self.fc1=nn.Linear(dim,hidden)
    self.fc2=nn.Linear(hidden,dim)
  def forward(self,x):
    expanded=self.fc1(x)
    activated=gelu(expanded)
    return self.fc2(activated)

In [0]:
class Encoder(nn.Module):
  """One transformer encoder layer (post-LayerNorm, as in BERT).

  Self-attention and a feed-forward block, each wrapped in a residual
  connection followed by layer normalisation.

  Args:
    dim: model width.
    heads: number of attention heads.
    max_len: forwarded to the sub-modules.
  """
  def __init__(self,dim,heads,max_len):
    super().__init__()
    self.attention=Attention(dim,heads,max_len)
    self.norm1=LayerNorm(dim)
    self.ff=feedforward(dim,heads,max_len)
    self.norm2=LayerNorm(dim)
    self.drop = nn.Dropout(0.1)
  def forward(self,x,mask):
    """Apply the layer to `x`; `mask` is passed through to the attention."""
    attn_out=self.attention(x,mask)
    # Normalise the residual sum. Normalising `x` alone would throw the
    # attention output away and leave the attention weights without gradient.
    out=self.norm1(x+self.drop(attn_out))
    ff_out=self.ff(out)
    # dropout on each sub-layer output before its residual add, as in BERT
    out=self.norm2(out+self.drop(ff_out))
    return out


In [0]:
class AllEncode(nn.Module):
  """Embedding layer followed by a stack of six encoder layers.

  Args:
    dim: model width.
    heads: number of attention heads per layer.
    max_len: longest supported sequence.
    n_segs: number of segment types.
    vocab_size: size of the token embedding table. When omitted it is taken
      from the 'bert-base-uncased' tokenizer vocabulary, the tokenizer used
      elsewhere in this notebook.
  """
  def __init__(self,dim,heads,max_len,n_segs,vocab_size=None):
    super().__init__()
    if vocab_size is None:
      # previously read from an undefined global tokenizer
      vocab_size=len(BertTokenizer.from_pretrained('bert-base-uncased').vocab)
    self.embed=embedding(dim,vocab_size,max_len,n_segs)
    # kept as individually named attributes so parameter names stay stable
    self.encoder1=Encoder(dim,heads,max_len)
    self.encoder2=Encoder(dim,heads,max_len)
    self.encoder3=Encoder(dim,heads,max_len)
    self.encoder4=Encoder(dim,heads,max_len)
    self.encoder5=Encoder(dim,heads,max_len)
    self.encoder6=Encoder(dim,heads,max_len)

  def forward(self,x,mask,seg):
    """Encode token ids `x` with attention mask `mask` and segment ids `seg`."""
    out=self.embed(x,seg)
    layers=(self.encoder1,self.encoder2,self.encoder3,
            self.encoder4,self.encoder5,self.encoder6)
    for layer in layers:
      out=layer(out,mask)

    return out
    
    
    
    
    




In [0]:
class BertPreTrain(nn.Module):
  """Encoder stack with the two pre-training heads.

  * a 2-way classifier on the first position's hidden state (next-sentence)
  * a masked-token decoder whose weight matrix is the token embedding table
  """
  def __init__(self,dim,heads,max_len,n_seg):
    super().__init__()
    self.allenc=AllEncode(dim,heads,max_len,n_seg)
    # next-sentence head
    self.fc1=nn.Linear(dim,dim)
    self.tanh=nn.Tanh()
    self.fc2=nn.Linear(dim,2)
    # masked-LM head
    self.norm=LayerNorm(dim)
    token_table = self.allenc.embed.embed.weight
    vocab_rows, hidden = token_table.size()
    self.decoder = nn.Linear(hidden, vocab_rows, bias=False)
    self.decoder.weight = token_table  # share weights with the input embedding
    self.linear = nn.Linear(dim,dim)

  def forward(self,batch):
    """Return (next-sentence logits, masked-token logits) for a 7-field batch."""
    input_ids, segment_ids, input_mask, _, masked_pos, _, _ = batch

    hidden = self.allenc(input_ids,input_mask,segment_ids)

    # classification from the hidden state at position 0
    nsp_logits = self.fc2(self.tanh(self.fc1(hidden[:,0])))

    # pick the hidden vectors at the masked positions
    gather_index = masked_pos[:, :, None].expand(-1, -1, hidden.size(-1))
    picked = torch.gather(hidden, 1, gather_index)
    transformed = self.norm(gelu(self.linear(picked)))
    mlm_logits = self.decoder(transformed)

    return nsp_logits, mlm_logits




In [0]:
# Model to pre-train, moved to the selected device.
x = BertPreTrain(dim=768, heads=12, max_len=512, n_seg=2)
x = x.to(device)


In [0]:
# Next-sentence classification loss.
criterion1=nn.CrossEntropyLoss().to(device)
# Masked-LM loss. Preprocess zero-pads masked_ids up to max_pred, so target 0
# marks a padding slot; ignore it instead of averaging the loss over padding.
criterion2=nn.CrossEntropyLoss(ignore_index=0).to(device)


In [22]:
!pip install -U pytorch_warmup

Collecting pytorch_warmup
  Downloading https://files.pythonhosted.org/packages/7a/22/2fb600a06a1d1b493d54ac8fa6c41e96870985992fc504104e0620bc2ea4/pytorch_warmup-0.0.4-py3-none-any.whl
Installing collected packages: pytorch-warmup
Successfully installed pytorch-warmup-0.0.4


In [0]:
import pytorch_warmup as warmup

optimizer = torch.optim.AdamW(x.parameters(), lr=0.0001, betas=(0.9, 0.999), weight_decay=0.01)
# Constant schedule (factor 1.0 at every step). The training loop steps an
# `lr_scheduler`, and the warmup wrapper scales whatever LR the scheduler sets,
# so the scheduler has to exist -- and has to be created before the warmup
# object, which dampens the optimizer's LR as soon as it is constructed.
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda _: 1.0)
warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

In [0]:
def loss_func(model,batch):
  """Total pre-training loss: next-sentence loss plus masked-token loss.

  Args:
    model: callable returning (next-sentence logits, masked-token logits).
    batch: 7-field batch; only the masked target ids (4th field) and the
      is_next labels (last field) are read here.
  """
  _, _, _, masked_ids, _, _, is_next = batch
  nsp_logits, mlm_logits = model(batch)

  nsp_loss = criterion1(nsp_logits, is_next)
  # swap the last two dims so the class dimension comes second for the criterion
  mlm_loss = criterion2(mlm_logits.transpose(1,2), masked_ids)

  return nsp_loss + mlm_loss



In [0]:
# Number of passes over the corpus, and a running count of optimizer steps.
epochs, step = 3, 0

In [0]:
# Training loop: one optimizer update per batch.
for epoch in range(epochs):
  for batch in data_loader:
    batch = [t.to(device) for t in batch]
    optimizer.zero_grad()
    loss=loss_func(x,batch)
    loss.backward()
    optimizer.step()
    # The scheduler sets the scheduled LR and dampen() then scales it by the
    # warmup factor, so both must run on every iteration -- gating dampen() to
    # the first few steps cuts the warmup short.
    lr_scheduler.step()
    warmup_scheduler.dampen()
    step=step+1

    # .item() prints the scalar value rather than the tensor repr
    print("LOSS:",loss.item()," ","epoch[%d/%d]"%(epoch+1,epochs))
