In [1]:

from model import *
from preprocessing import *
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.optim as optim
import lightning as L

from math import floor
import torch.optim.lr_scheduler  as lr_sc


# Seed the global RNGs (via Lightning) before building the model so that weight
# initialisation and the later training run are reproducible.
SEED = 42
L.seed_everything(SEED)

model = AlbertModel()


In [None]:


class ALBERT(L.LightningModule):
    """Lightning wrapper that trains a model on a joint MLM + NSP objective.

    The wrapped ``model`` is called with a tuple
    ``(input_ids, attention_mask, token_type_ids)`` and must return a pair
    ``(mlm_output, nsp_output)``. Both losses use ``nll_loss``, so the model
    presumably returns log-probabilities (log-softmax) — TODO confirm in model.py.

    Args:
        model: the network to train.
        lr: initial AdamW learning rate.
        lr_gamma: factor the learning rate is multiplied by every ``lr_step_size``
            optimizer steps.
        lr_step_size: number of optimizer steps between learning-rate decays.
    """

    def __init__(self, model, lr=6e-5, lr_gamma=0.5, lr_step_size=200):
        super().__init__()
        self.model = model
        self.learning_rate = lr
        self.lr_gamma = lr_gamma
        self.lr_step_size = lr_step_size
        # One (mlm_loss, nsp_loss) pair per training step, detached from autograd.
        self.loss_metrics = []

    def forward(self, x):
        """Run the wrapped model on ``x = (input_ids, attention_mask, token_type_ids)``."""
        return self.model(x)

    @staticmethod
    def _model_inputs(batch):
        """Return the (input_ids, attention_mask, token_type_ids) tuple the model consumes."""
        return batch['input_ids'], batch['attention_mask'], batch['token_type_ids']

    def training_step(self, batch):
        """Compute the summed MLM + NSP loss for one batch.

        ``batch`` must provide ``input_ids``, ``attention_mask``, ``token_type_ids``,
        ``labels`` (MLM targets; -100 at positions without a target, presumably the
        HF data-collator convention — confirm) and ``label`` (NSP targets).
        """
        y_hat_mlm, y_hat_nsp = self(self._model_inputs(batch))

        # NOTE(review): -100 positions are replaced by the input token, so the MLM
        # loss covers every position (including unmasked and padding tokens), not
        # only the masked ones as in the BERT/ALBERT objective. Passing
        # batch['labels'] directly would let nll_loss's default ignore_index=-100
        # restrict the loss to masked tokens — confirm which is intended.
        y_mlm = torch.where(batch['labels'] == -100, batch['input_ids'], batch['labels'])
        # reshape(-1) instead of squeeze(): squeeze() turns a batch of one into a
        # 0-d tensor, which nll_loss rejects for a (1, C) input.
        y_nsp = batch['label'].reshape(-1)

        # nll_loss expects the class dimension at dim 1, so move the MLM class
        # axis (presumably vocab, at dim 2) there.
        loss_mlm = F.nll_loss(y_hat_mlm.transpose(1, 2), y_mlm.squeeze(-1))
        loss_nsp = F.nll_loss(y_hat_nsp, y_nsp)
        loss = loss_mlm + loss_nsp

        lr = self.trainer.optimizers[0].param_groups[0]['lr']
        # Report through Lightning (progress bar + logger) instead of printing a line per step.
        self.log_dict({'lr': lr, 'NSP_loss': loss_nsp, 'MLM_loss': loss_mlm}, prog_bar=True)

        # Detach so the stored history does not keep each step's autograd graph alive.
        self.loss_metrics.append((loss_mlm.detach(), loss_nsp.detach()))
        return loss

    def configure_optimizers(self):
        """AdamW whose learning rate is multiplied by ``lr_gamma`` every ``lr_step_size`` steps."""
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        # Stepped once per optimizer step ("interval": "step"), StepLR yields
        # lr = lr0 * gamma ** (step // step_size) without the deprecated
        # scheduler.step(epoch=...) argument.
        # NOTE(review): the decay is uncapped; cap it explicitly if the learning
        # rate should stop shrinking after a fixed number of decays.
        lr_scheduler = lr_sc.StepLR(optimizer, step_size=self.lr_step_size, gamma=self.lr_gamma)
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]

    def lr_scheduler_step(self, scheduler, metric):
        """Advance the scheduler by one step; no metric is monitored, so ``metric`` is unused."""
        scheduler.step()

    def predict_step(self, batch):
        """Return the argmax MLM token ids and NSP class predictions for ``batch``.

        The wrapped model is moved to the device of ``batch['input_ids']`` instead of
        a hard-coded ``'cuda'``. This lets the method also be called directly (outside
        ``trainer.predict``) with a batch on either CPU or GPU.
        """
        inputs = self._model_inputs(batch)
        self.model.to(inputs[0].device)
        # Direct calls are not wrapped in Lightning's inference mode, so disable autograd here.
        with torch.no_grad():
            y_hat_mlm, y_hat_nsp = self(inputs)
        return y_hat_mlm.argmax(-1), y_hat_nsp.argmax(-1)




# Short pre-training run: a single epoch, capped at a fixed number of batches.
MAX_EPOCHS = 1
TRAIN_BATCH_LIMIT = 1000

albert = ALBERT(model)
trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    limit_train_batches=TRAIN_BATCH_LIMIT,
)
trainer.fit(model=albert, train_dataloaders=text_dataloader);


In [None]:
import matplotlib.pyplot as plt
# One (MLM, NSP) pair per training step; .item() works whether or not the stored
# loss tensors are still attached to the autograd graph.
losses = [(mlm.item(), nsp.item()) for mlm, nsp in albert.loss_metrics]
mlm_losses = [mlm for mlm, _ in losses]
nsp_losses = [nsp for _, nsp in losses]

fig, ax = plt.subplots()
ax.plot(mlm_losses, label='MLM loss')
ax.plot(nsp_losses, label='NSP loss')
ax.set(title='ALBERT pre-training loss', xlabel='training step', ylabel='loss')
ax.legend();


In [None]:
# Sanity-check the trained model on one batch. The batch object is assumed to
# expose .to(device) (e.g. an HF BatchEncoding) — TODO confirm.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = next(iter(text_dataloader)).to(device)
with torch.no_grad():
    output = albert.predict_step(data)

# reshape(-1) keeps a 1-D label vector even for a batch of one (squeeze() would give 0-d).
nsp_true = data['label'].reshape(-1)
nsp_correct = (nsp_true == output[1]).sum().item()
# Only the last expression of a cell is displayed, so return the count together
# with the labels and predictions.
nsp_correct, nsp_true, output[1]

In [None]:
# First 20 predicted MLM tokens of the first sequence. Slice and move to a plain
# Python list before decoding, rather than decoding the whole (possibly CUDA)
# tensor and discarding most of the result.
tokenizer.convert_ids_to_tokens(output[0][0][:20].tolist())