In [None]:
!pip install tokenizers
!pip install torchtext
!pip install pytorch_lightning
!pip install datasets
!pip install tensorboard
!pip install lion_pytorch

# The loss did not decrease much and was NaN many times, which disrupted training.

In [None]:
from pathlib import Path

import torch
import torch.nn as nn
from lion_pytorch import Lion
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from config_file import get_config, get_weights_file_path
from train_DB_AMP_OCP import train_model, get_model, get_ds

# Run configuration: overrides applied on top of the project defaults from get_config().
config = get_config()
config["batch_size"] = 24
config["preload"] = None  # None = train from scratch
config["num_epochs"] = 18

# Seed torch's global RNG so weight init and torch-based shuffling are repeatable.
SEED = 42
torch.manual_seed(SEED)

# Note: mixed precision only takes effect when autocast is used as a context manager
# (`with torch.autocast(...)`), which the training loop does per step.

train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

# TensorBoard logs are written under the experiment name from the config.
writer = SummaryWriter(config["experiment_name"])

# Optimizer: Lion (lion_pytorch) with weight decay 1e-2.
# The lr given here is only a starting placeholder: the OneCycleLR scheduler built in
# the scheduler cell resets each param group's lr to max_lr / div_factor on construction.
optimizer = Lion(model.parameters(), lr=1e-4 / 10, weight_decay=1e-2)

initial_epoch = 0
global_step = 0


In [None]:
# Peak learning rate for OneCycleLR (10 times the value used in the paper).
MAX_LR = 10**-3
# The scheduler is stepped once per batch, so one "scheduler epoch" = one pass of the loader.
STEPS_PER_EPOCH = len(train_dataloader)
# Must equal the number of epochs the training loop runs. If training ran longer,
# OneCycleLR would raise once its total_steps is exceeded; if shorter, the LR would
# never finish annealing. Derive it from the config instead of repeating the literal.
EPOCHS = config["num_epochs"]

In [None]:
# Two-phase one-cycle schedule, stepped per batch:
#   * warm-up: LR rises linearly from MAX_LR / 100 to MAX_LR over the first
#     int(0.3 * EPOCHS) whole epochs;
#   * decay: LR falls linearly to (MAX_LR / 100) / 100 by the end of training.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    pct_start=int(0.3 * EPOCHS) / EPOCHS,
    anneal_strategy='linear',
    three_phase=False,
    div_factor=100,
    final_div_factor=100,
)

In [None]:
# Token-level cross-entropy with label smoothing. The labels are target-vocabulary
# ids, so padding is identified with the *target* tokenizer's [PAD] id.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id("[PAD]"), label_smoothing=0.1)
Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)

# fp16 autocast and loss scaling are only meaningful on CUDA; disable both explicitly otherwise.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Optionally resume from a checkpoint written at the end of an epoch below.
if config["preload"]:
    model_filename = get_weights_file_path(config, config["preload"])
    print(f"Preloading model {model_filename}")
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    initial_epoch = state["epoch"] + 1
    global_step = state["global_step"]
    # Older checkpoints were saved without these entries.
    if "scheduler_state_dict" in state:
        scheduler.load_state_dict(state["scheduler_state_dict"])
    if "scaler_state_dict" in state:
        scaler.load_state_dict(state["scaler_state_dict"])
    print("preloaded")

tgt_vocab_size = tokenizer_tgt.get_vocab_size()  # loop-invariant
lr = [scheduler.get_last_lr()[0]]  # learning-rate history, one float per step

for epoch in range(initial_epoch, config["num_epochs"]):
    torch.cuda.empty_cache()
    model.train()
    batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")

    # Running statistics over batches whose loss was finite (O(1) per step).
    loss_sum = 0.0
    num_finite = 0
    num_skipped = 0

    for batch in batch_iterator:
        optimizer.zero_grad(set_to_none=True)
        encoder_input = batch["encoder_input"].to(device)
        decoder_input = batch["decoder_input"].to(device)
        encoder_mask = batch["encoder_mask"].to(device)
        decoder_mask = batch["decoder_mask"].to(device)
        label = batch["label"].to(device)

        # Mixed-precision forward pass and loss.
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=use_amp):
            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)
            loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))

        loss_value = loss.item()  # single host/device sync per step
        writer.add_scalar('train_loss', loss_value, global_step)

        if not torch.isfinite(loss):
            # A NaN/inf loss can only produce NaN/inf gradients: GradScaler would skip
            # the step anyway but also shrink its scale, and the epoch average would
            # become NaN. Drop the batch instead and count it.
            num_skipped += 1
            batch_iterator.set_postfix({"recent loss": f"{loss_value:6.3f}", "skipped": num_skipped})
            global_step += 1
            continue

        loss_sum += loss_value
        num_finite += 1
        batch_iterator.set_postfix({"recent loss": f"{loss_value:6.3f}", "avg loss": f"{loss_sum / num_finite:6.3f}", 'lr': f'{lr[-1]:.3e}'})

        # Backpropagate the scaled loss; scaler.step() unscales the gradients and
        # skips optimizer.step() if they contain inf/NaN.
        scaler.scale(loss).backward()
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        # update() lowers the scale when the optimizer step was skipped; only advance
        # the LR schedule when the weights were actually updated.
        if scaler.get_scale() >= scale_before:
            scheduler.step()
        lr.append(scheduler.get_last_lr()[0])
        global_step += 1

    writer.flush()  # once per epoch rather than once per batch

    # TODO: run_validation(...) is not imported in this notebook; wire it in to
    # evaluate on val_dataloader at the end of each epoch.

    model_filename = get_weights_file_path(config, f"{epoch:02d}")
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "global_step": global_step,
        },
        model_filename,
    )
    epoch_loss = loss_sum / num_finite if num_finite else float("nan")
    print("loss for this epoch is ", epoch_loss)
    if num_skipped:
        print(f"skipped {num_skipped} batch(es) with a non-finite loss")