In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl



class BertLike(pl.LightningModule):
    def __init__(self, num_tokens, hidden_size, num_layers, num_heads, dropout_rate):
        super(BertLike, self).__init__()
        
        # tokenizer
        self.token_embedding = nn.Embedding(num_tokens, hidden_size)
        
        # positional embedding
        self.position_embedding = nn.Embedding(1000, hidden_size)
        
        # transformer encoder layers
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(hidden_size, num_heads, dim_feedforward=4*hidden_size, dropout=dropout_rate)
            for _ in range(num_layers)
        ])
        
        # dropout
        self.dropout = nn.Dropout(dropout_rate)
        
        # get the out tokens
        self.fc = nn.Linear(hidden_size, num_tokens)

    def forward(self, x):
        
        # token_emb (batch_size, seq_len, hidden_size)
        # get the token embedding
        token_emb = self.token_embedding(x) 
        
        # position_emb (1, seq_len, hidden_size)
        position_emb = self.position_embedding(torch.arange(x.size(1), device=x.device))[None, :, :] 
        
        # embedding (batch_size, seq_len, hidden_size)
        # add position embedding to token embedding
        emb = self.dropout(token_emb + position_emb) 
        
        # bert like block
        for layer in self.encoder_layers:
            emb = layer(emb)
            
        # logits (batch_size, num_tokens)
        # get the out tokens
        logits = self.fc(emb[:, -1, :]) 
        
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # get the validation data
        x, y = batch
        
        # get the predictions
        logits = self(x)
        
        # loss
        loss = F.cross_entropy(logits, y)
        
        # log metrics
        self.log('val_loss', loss, on_step=True, on_epoch=True)
        
        preds = logits.argmax(dim=-1)
        acc = (preds == y).float().mean()
        
        # log metrics
        self.log('val_acc', acc, on_step=True, on_epoch=True)
        
        return acc

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        return optimizer

In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# This code calculates the cross entropy loss between the logits and the true labels.

logits = torch.randn(10, 1000, 21)
y = torch.randint(0,21, (10, 1000, 21)).to(torch.float32)


loss = F.cross_entropy(logits, y)
print(loss)

tensor(74097.6406)
