In [83]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import kaiming_normal_
import math

In [84]:
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :]+\
                 self.pos_embedding[:token_embedding.size(0), :])
        
                             

In [85]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

In [87]:
class GPT2Model(nn.Module):
    def __init__(self,embedding_dim=512,vocab_sz=12000,d_model=512):
        super().__init__()
        
        self.embedding_dim = embedding_dim
        self.d_model = d_model
        self.vocab_sz = vocab_sz
    
        
        self.embedding = nn.Embedding(self.vocab_sz,self.embedding_dim,padding_idx=0).to(torch.float)
        
        self.emb_weight = nn.Parameter(self.embedding.weight)
        
        self.position_emb = PositionalEncoding(emb_size=self.embedding_dim,
                                              dropout=0.02)
        
        self.fc1 = nn.Linear(self.embedding_dim,self.d_model)
        
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=512,nhead=8,
                                                       batch_first=True)
        
        
        
        
        
    def forward(self,seq):
        
        s = self.embedding(seq[:,:-1])
        
        s = self.position_emb(s)
        
        sz = s.shape[1] #seq length
        
        mask = generate_square_subsequent_mask(sz)
        
        transformer = nn.TransformerDecoder(self.decoder_layer, num_layers=6)
                                             
        out = transformer(tgt=s,memory=s,tgt_mask=mask) 
        
        out = out @ self.embedding.weight.T
        
        return F.softmax(out,dim=1)
                                             

In [88]:
gpt = GPT2Model()

In [89]:
seq = torch.randint(12000,(20,15))

In [90]:
seq

tensor([[ 1444,  6866, 10298,  5349,  8742,  4656,  4974,  1553,  5392,  8350,
          7766,  5875,  6813,  5359,  1348],
        [ 5421,  7495,  5031, 10265,  4535,  5024,  2605,  9636,  5759,  5241,
          7620,   498,  6163,   155,  4595],
        [ 4922,  5422, 11667, 10973,  3140, 11864,  3417,  7335,  9431,  5849,
          7014,   473, 11552,  9431,  3023],
        [ 6025,  4873,   850,  6919,  7008,  9700,  2830,  9614,  5624,  7398,
         10037,  3465, 11286,  4638,  3101],
        [11191,  8705,  1452,  6506,  8122,  3990,  9915,  1362,  4087,   437,
          1608,   550,  1931,  9101,  8699],
        [  901, 11774,  5921, 10341, 10574,  5068,   124,  7600,  5595,  9150,
          9863, 11509,  8178,  9247,  8061],
        [10469,  6485,  2340,  2692,  3427,  2365,    53,  8785,  5364, 10366,
          5740,  1264,  5529,  5776,  7627],
        [ 2375,  1569,  6195,  7976,  5548,  7011, 11130,  8561,  4736,  9652,
          3656, 11413,  9062,  3189,  3810],
        

In [91]:
output = gpt(seq)

In [92]:
output.shape

torch.Size([20, 14, 12000])

In [93]:
nn.NLLLoss()(output.view(output.shape[0]*output.shape[1],-1),seq[:,1:].reshape(-1))

tensor(-0.0445, grad_fn=<NllLossBackward0>)