In [14]:
import torch
import torch.nn as nn

In [15]:
class BERTEmbedding(nn.Module) :
    def __init__(self, vocab_size, n_segments, max_len, embed_dim, dropout) :
        super().__init__()
        self.tok_embed = nn.Embedding(vocab_size, embed_dim)
        self.seg_embed = nn.Embedding(n_segments, embed_dim)
        self.pos_embed = nn.Embedding(max_len, embed_dim)
        self.drop = nn.Dropout(dropout)
        self.pos_inp = torch.tensor([i for i in range(max_len)],)


    def forward(self, sequence, segment) : 
        embed_val = self.tok_embed(sequence) + self.seg_embed(segment) + self.pos_embed(self.pos_inp)
        return embed_val

In [22]:
class BERT(nn.Module) : 
    def __init__(self, vocab_size, n_segments, max_len, embed_dim, n_layers, attn_heads, dropout ) :
        super().__init__()
        self.embedding = BERTEmbedding(vocab_size, n_segments, max_len, embed_dim, dropout)
        self.encoder_layer = nn.TransformerEncoderLayer(embed_dim, attn_heads, embed_dim+4)
        self.encoder_block = nn.TransformerEncoder(self.encoder_layer, n_layers)

    def forward(self, sequence, segment) : 
        out = self.embedding(sequence, segment)
        out = self.encoder_block(out)

        return out

In [23]:
#parameters
VOCAB_SIZE = 30000
N_SEGMENTS = 3
MAX_LEN = 512
EMBED_DIM = 768
N_LAYERS = 12
ATTN_HEADS = 12
DROPOUT = 0.1

In [24]:
sample_seq = torch.randint(high=VOCAB_SIZE, size=[MAX_LEN,])
sample_seg = torch.randint(high=N_SEGMENTS, size=[MAX_LEN,])

# Embedding BERT

In [25]:
embedding = BERTEmbedding(VOCAB_SIZE, N_SEGMENTS, MAX_LEN, EMBED_DIM, DROPOUT)

embedding_tensor = embedding(sample_seq,sample_seg)
print(embedding_tensor.size())

torch.Size([512, 768])


In [26]:
embedding_tensor

tensor([[ 2.4407,  0.4583,  1.0858,  ..., -0.4173,  1.9546,  0.2517],
        [-1.4944,  1.6902,  2.2356,  ...,  0.8020,  3.8421,  1.3661],
        [ 1.5868, -1.2223,  2.7905,  ..., -1.4159, -1.8224, -2.4992],
        ...,
        [-0.2949,  1.0737,  0.1272,  ..., -0.2352,  0.1508,  0.1010],
        [-2.1031,  0.1447, -0.3413,  ...,  1.1199, -0.9428, -2.3064],
        [-0.0131,  2.3977,  2.0010,  ...,  1.4253,  0.2366, -2.0512]],
       grad_fn=<AddBackward0>)

# BERT

In [27]:
bert = BERT(VOCAB_SIZE, N_SEGMENTS, MAX_LEN, EMBED_DIM, N_LAYERS, ATTN_HEADS, DROPOUT)
out = bert(sample_seq, sample_seg)
print(out)

tensor([[-0.4997, -0.8414, -0.9764,  ...,  0.0699,  2.2798, -0.9598],
        [-0.9944, -0.9993, -1.7293,  ..., -0.1628,  1.7765, -0.7004],
        [-0.7878, -1.0090, -1.7070,  ..., -0.4104,  1.3155, -0.3411],
        ...,
        [-0.7560, -1.4658, -1.5455,  ..., -0.1756,  1.7452, -0.8167],
        [-0.5068, -1.0970, -1.4658,  ...,  0.0860,  2.2402, -0.6997],
        [-0.3310, -0.4516, -1.5201,  ..., -0.1194,  1.7818, -0.8136]],
       grad_fn=<NativeLayerNormBackward0>)
