In [11]:
import json
import re
import torch.nn as nn
import torch

In [12]:
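# Kept for reference only (not executed): one-off helper that turns a plain-text
# vocabulary (one token per line) into the JSON token -> index file that
# `tokenizer` loads below.
# def ConvertVocab(path):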
#     index = 0
#     vocab = dict()
#     with open(path, "r", encoding="utf-8") as f:
#         for word in f:            
#             vocab[word.strip()] = index
#             index += 1
#     with open("vocabBert.txt", "w") as f:
#         json.dump(vocab, f)

In [13]:
# Scratch check of reshape semantics: 100 random values regrouped with an
# inferred leading dimension (-1) into blocks of shape (5, 2).
a = torch.rand(10, 10).reshape(-1, 5, 2)
a.shape

torch.Size([10, 5, 2])

In [69]:
class Attention(nn.Module):
    """Multi-head self-attention sub-layer with output projection, residual and LayerNorm.

    The input is projected to queries, keys and values, split into ``numHead``
    heads of ``hiddenDim // numHead`` features each, combined with scaled
    dot-product attention, projected back to ``hiddenDim`` and finally added to
    the input and layer-normalised.

    Args:
        hiddenDim: Size of the last dimension of the input and output (default 768).
        numHead: Number of attention heads; must divide ``hiddenDim`` (default 12).
        dropout: Dropout probability applied to the attention weights and to
            the projected output (default 0.1).

    Raises:
        ValueError: If ``hiddenDim`` is not divisible by ``numHead``.
    """

    def __init__(self, hiddenDim=768, numHead=12, dropout=0.1):
        super().__init__()
        if hiddenDim % numHead != 0:
            raise ValueError(
                f"hiddenDim ({hiddenDim}) must be divisible by numHead ({numHead})"
            )

        # Sub-module names and creation order are kept stable so existing
        # state dicts and seeded initialisation keep working.
        self.query = nn.Linear(in_features=hiddenDim, out_features=hiddenDim)
        self.key = nn.Linear(in_features=hiddenDim, out_features=hiddenDim)
        self.value = nn.Linear(in_features=hiddenDim, out_features=hiddenDim)
        self.drpOut = nn.Dropout(p=dropout)
        self.linear = nn.Linear(in_features=hiddenDim, out_features=hiddenDim)
        self.lNorm = nn.LayerNorm(hiddenDim, eps=1e-12)

        self.softmax = nn.Softmax(dim=-1)

        self.qkvDim = hiddenDim // numHead
        self.numHead = numHead

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Float tensor of shape ``(batch, seqLen, hiddenDim)``.

        Returns:
            Tensor of the same shape as ``x``: ``LayerNorm(dropout(linear(context)) + x)``.
        """
        batch, seqLen, hiddenDim = x.shape

        def splitHeads(t):
            # (batch, seqLen, hiddenDim) -> (batch, numHead, seqLen, qkvDim)
            return t.reshape(batch, seqLen, self.numHead, self.qkvDim).permute(0, 2, 1, 3)

        q = splitHeads(self.query(x))
        k = splitHeads(self.key(x))
        v = splitHeads(self.value(x))

        # Scaled dot-product attention; the scale is sqrt(per-head dim)
        # (8 for the default 64-wide heads).
        score = (q @ k.transpose(-1, -2)) / (self.qkvDim ** 0.5)
        sf = self.drpOut(self.softmax(score))

        # Merge the heads back: (batch, numHead, seqLen, qkvDim) -> (batch, seqLen, hiddenDim)
        attnOut = (sf @ v).permute(0, 2, 1, 3).reshape(batch, seqLen, hiddenDim)

        hiddenOut = self.drpOut(self.linear(attnOut))

        # The sub-layer output is the normalised residual only. The raw,
        # un-projected attention context must not be added on top of it: that
        # would bypass both the output projection and the LayerNorm.
        return self.lNorm(hiddenOut + x)

class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: 768 -> 3072 -> 768 with GELU.

    The projected result goes through dropout, is added back to the input
    (residual connection) and is then layer-normalised.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(768, 3072)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(3072, 768)
        self.lNorm = nn.LayerNorm(768, eps=1e-12)
        self.drpOut = nn.Dropout(p=0.1)

    def forward(self, x):
        """Return ``LayerNorm(dropout(linear2(gelu(linear1(x)))) + x)``."""
        expanded = self.gelu(self.linear1(x))
        projected = self.drpOut(self.linear2(expanded))
        return self.lNorm(projected + x)

class Transformer(nn.Module):
    """One encoder block: the attention sub-layer followed by the MLP sub-layer."""

    def __init__(self):
        super().__init__()
        self.attn = Attention()
        self.mlp = MLP()

    def forward(self, x):
        """Feed ``x`` through attention, then pass that result through the MLP."""
        return self.mlp(self.attn(x))

class BERT(nn.Module):
    """Encoder stack: summed word/position/segment embeddings followed by 12 blocks.

    The embedding tables hold 30522 word ids, 512 positions and 2 segment ids,
    each mapped to 768 features. The summed embeddings are layer-normalised,
    passed through dropout and then through every ``Transformer`` block in order.
    """

    def __init__(self):
        super().__init__()
        self.wordEmbd = nn.Embedding(num_embeddings=30522, embedding_dim=768)
        self.posEmbd = nn.Embedding(num_embeddings=512, embedding_dim=768)
        self.tokenEmbd = nn.Embedding(num_embeddings=2, embedding_dim=768)
        self.layerNorm = nn.LayerNorm(768, eps=1e-12)
        self.drpOut = nn.Dropout(p=0.1)
        self.layers = nn.ModuleList([
            Transformer() for _ in range(12)
        ])

    def forward(self, x):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of word ids with shape ``(batchSize, seqLen)``.

        Returns:
            Float tensor of shape ``(batchSize, seqLen, 768)``.

        Raises:
            IndexError: If ``seqLen`` exceeds the number of position embeddings.
        """
        batchSize, seqLen = x.shape

        maxPositions = self.posEmbd.num_embeddings
        if seqLen > maxPositions:
            # Fail early with a readable message instead of an opaque
            # out-of-range lookup inside the embedding kernel.
            raise IndexError(
                f"sequence length {seqLen} exceeds the {maxPositions} available position embeddings"
            )

        # Index tensors are created on the same device as the input so the
        # lookups also work once the model and inputs live on an accelerator;
        # int64 is the index dtype every embedding backend accepts.
        device = x.device

        # Only consider that there is only one sentence
        # That's why always take 0 index for the token embedding
        tokenEmbedInput = torch.zeros((batchSize, seqLen), dtype=torch.long, device=device)
        posIndex = torch.arange(seqLen, dtype=torch.long, device=device).expand(batchSize, seqLen)

        x = self.wordEmbd(x) + self.posEmbd(posIndex) + self.tokenEmbd(tokenEmbedInput)
        x = self.layerNorm(x)
        x = self.drpOut(x)
        for layer in self.layers:
            x = layer(x)
        return x

In [70]:
class tokenizer:
    """Minimal greedy WordPiece tokenizer backed by a JSON ``token -> id`` vocabulary.

    Text is lower-cased, split into runs of ``[a-z0-9]`` and single
    non-alphanumeric characters, and every piece is then broken into the
    longest matching vocabulary entries (continuation pieces carry a ``##``
    prefix).

    Attributes:
        path: Location the vocabulary was loaded from.
        vocab: Mapping from token string to integer id.
        reverseVocab: Mapping from integer id back to token string.
    """

    # A character that extends the current word; any other non-whitespace
    # character is emitted as a token of its own.
    _wordChar = re.compile(r"[a-z0-9]")

    def __init__(self, path = r"C:\Users\shiva\Desktop\IISC\code\NeuroCpp\Projects\The Dream\embedding\vocabBert.txt"):
        """Load the vocabulary JSON from ``path``.

        NOTE(review): the default is a machine-specific absolute path; pass
        ``path`` explicitly when running anywhere else.
        """
        self.path = path
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(path, "r", encoding="utf-8") as f:
            self.vocab = json.load(f)

        self.reverseVocab = {index: token for token, index in self.vocab.items()}

    def EncodeWord(self, word):
        """Split one word into vocabulary ids using greedy longest-match-first.

        Args:
            word: A single piece as produced by ``SplitWord``.

        Returns:
            List of ids. If some part of the word cannot be matched, the whole
            word maps to the id of ``"[UNK]"`` when the vocabulary has one,
            otherwise to an empty list.
        """
        res = []
        start = 0
        while start < len(word):
            end = len(word)
            pieceId = None
            # Shrink the candidate from the right until it is a known token.
            while end > start:
                piece = word[start:end]
                if start > 0:
                    piece = "##" + piece  # continuation pieces are stored with a ## prefix
                pieceId = self.vocab.get(piece)
                if pieceId is not None:
                    break
                end -= 1
            if pieceId is None:
                # Nothing matches at this position: give up on the whole word
                # rather than returning a misleading partial encoding.
                unkId = self.vocab.get("[UNK]")
                return [] if unkId is None else [unkId]
            res.append(pieceId)
            start = end
        return res

    def SplitWord(self, txt):
        """Split text into alphanumeric runs and single punctuation characters.

        Whitespace separates pieces and is discarded. Every other character
        outside ``[a-z0-9]`` becomes a one-character piece.

        Args:
            txt: Text to split; expected to be lower-cased already.

        Returns:
            List of pieces in their original order.
        """
        res = []
        currentTxt = ""
        for ch in txt:
            if self._wordChar.match(ch):
                currentTxt += ch
                continue
            if currentTxt:
                res.append(currentTxt)
                currentTxt = ""
            if not ch.isspace():
                res.append(ch)
        # Flush the last word when the text does not end in a separator.
        if currentTxt:
            res.append(currentTxt)
        return res

    def encode(self, txt):
        """Lower-case ``txt``, split it and return the flat list of token ids."""
        res = []
        for word in self.SplitWord(txt.lower()):
            res.extend(self.EncodeWord(word))
        return res

    def decode(self, li):
        """Turn a sequence of ids back into text.

        Continuation pieces (``##...``) are glued onto the preceding piece and
        the resulting pieces are joined with single spaces, so the original
        casing and spacing are not reproduced.

        Args:
            li: Iterable of integer ids.

        Returns:
            The decoded string.

        Raises:
            KeyError: If an id is not present in the vocabulary.
        """
        words = []
        for index in li:
            token = self.reverseVocab[index]
            if words and token.startswith("##"):
                words[-1] += token[2:]
            else:
                words.append(token)
        return " ".join(words)

In [71]:
# Tokenize a deliberately messy sample: mixed case, repeated punctuation,
# irregular spacing, embedded newlines and a bracketed special-token name.
tkn = tokenizer()
encodeVec = tkn.encode(
    "  Hello,,,   WORLD!! 42@@@openAI##   is---great??? \n\n   NLP---rocks :)    [unused10] ".lower()
)

In [74]:
# Number of token ids the tokenizer produced for the sample sentence above.
len(encodeVec)

35

In [72]:
# Push the encoded sample through a freshly constructed (randomly initialised)
# model as a batch of one sequence and inspect the output shape.
model = BERT()
model(torch.tensor(encodeVec).unsqueeze(0)).shape

torch.Size([1, 35, 768])