In [1]:
!pip install tokenizers
!pip install torchtext
!pip install pytorch_lightning
!pip install datasets
!pip install tensorboard



In [2]:
!git clone https://github.com/shreyash-99/erav2.git

fatal: destination path 'erav2' already exists and is not an empty directory.


In [3]:
cd erav2/S18

/content/erav2/S18


In [4]:
from train import train_model, get_model, get_ds
from config_file import get_config, get_weights_file_path
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from pathlib import Path
import torch.nn as nn

config = get_config()
config["batch_size"] = 24
config["preload"] = None
config["num_epochs"] = 18


import torch
torch.cuda.amp.autocast(enabled=True)

train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

#Tensorboard
writer = SummaryWriter(config["experiment_name"])

#Adam is used to train each feature with a different learning rate.
#If some feature is appearing less, adam takes care of it
optimizer = torch.optim.Adam(model.parameters(), lr = config["lr"], eps = 1e-9)

initial_epoch = 0
global_step = 0


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Max length of the source sentence : 309
Max length of the source target : 274


In [5]:
MAX_LR = 10**-3 ## 10 times the value it is decided in the paper
STEPS_PER_EPOCH = len(train_dataloader)
EPOCHS = 18

In [6]:
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=MAX_LR, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS,
                                                div_factor=100, three_phase=False, pct_start=int(0.3*EPOCHS) / EPOCHS,
                                                anneal_strategy='linear', final_div_factor=100)

In [7]:
loss_fn = nn.CrossEntropyLoss(ignore_index = tokenizer_src.token_to_id("[PAD]"), label_smoothing=0.1)
Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)


# if config["preload"]:
#     model_filename = get_weights_file_path(config, config["preload"])
#     print("Preloading model {model_filename}")
#     state = torch.load(model_filename)
#     model.load_state_dict(state["model_state_dict"])
#     initial_epoch = state["epoch"] + 1
#     optimizer.load_state_dict(state["optimizer_state_dict"])
#     global_step = state["global_step"]
#     print("preloaded")


scaler = torch.cuda.amp.GradScaler()
lr = [0.0]

for epoch in range(initial_epoch, config["num_epochs"]):
    torch.cuda.empty_cache()
    # print(epoch)
    model.train()
    batch_iterator = tqdm(train_dataloader, desc = f"Processing Epoch {epoch:02d}")

    loss_list = [0.0]

    for batch in batch_iterator:
        optimizer.zero_grad(set_to_none=True)
        encoder_input = batch["encoder_input"].to(device)
        decoder_input = batch["decoder_input"].to(device)
        encoder_mask = batch["encoder_mask"].to(device)
        decoder_mask = batch["decoder_mask"].to(device)

        with torch.autocast(device_type='cuda',dtype = torch.float16 ):

            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            label = batch["label"].to(device)

        #Compute loss using cross entropy
            tgt_vocab_size = tokenizer_tgt.get_vocab_size()
            loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))
            loss_list.append(loss.item())

        batch_iterator.set_postfix({"recent loss": f"{loss.item():6.3f}" , "avg loss": f"{sum(loss_list)/len(loss_list):6.3f}", 'lr' : f'{lr[-1]}'} )

        #Log the loss
        writer.add_scalar('train_loss', loss.item(), global_step)
        writer.flush()

        scaler.scale(loss).backward()

        #Backpropogate loss
        # loss.backward()
        scale = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        skip_lr_sched = (scale > scaler.get_scale())
        if not skip_lr_sched:
            scheduler.step()
        lr.append(scheduler.get_last_lr())

        # #Update weights
        # optimizer.step()
        # optimizer.zero_grad(set_to_none=True)
        global_step+=1

    #run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, writer, global_step)


    model_filename = get_weights_file_path(config, f"{epoch:02d}")
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "global_step": global_step
        },
        model_filename
    )
    print("loss for this epoch is ", sum(loss_list)/len(loss_list))

Processing Epoch 00: 100%|██████████| 1213/1213 [04:30<00:00,  4.48it/s, recent loss=6.145, avg loss=6.899, lr=[0.00020786939313984168]]


loss for this epoch is  6.899453397637734


Processing Epoch 01: 100%|██████████| 1213/1213 [04:29<00:00,  4.50it/s, recent loss=5.415, avg loss=5.713, lr=[0.0004059020448548813]]


loss for this epoch is  5.712780703037143


Processing Epoch 02: 100%|██████████| 1213/1213 [04:25<00:00,  4.57it/s, recent loss=5.420, avg loss=5.291, lr=[0.0006039346965699209]]


loss for this epoch is  5.291489537897377


Processing Epoch 03: 100%|██████████| 1213/1213 [04:26<00:00,  4.55it/s, recent loss=4.896, avg loss=4.974, lr=[0.0008019673482849604]]


loss for this epoch is  4.974203303775441


Processing Epoch 04: 100%|██████████| 1213/1213 [04:26<00:00,  4.56it/s, recent loss=4.649, avg loss=4.687, lr=[0.000999836741424802]]


loss for this epoch is  4.687145909524632


Processing Epoch 05: 100%|██████████| 1213/1213 [04:27<00:00,  4.54it/s, recent loss=4.378, avg loss=4.381, lr=[0.0009231480246052381]]


loss for this epoch is  4.380527908366044


Processing Epoch 06: 100%|██████████| 1213/1213 [04:27<00:00,  4.54it/s, recent loss=3.993, avg loss=4.010, lr=[0.0008462960492104763]]


loss for this epoch is  4.009986962871457


Processing Epoch 07: 100%|██████████| 1213/1213 [04:25<00:00,  4.57it/s, recent loss=3.639, avg loss=3.652, lr=[0.0007694440738157144]]


loss for this epoch is  3.6523520765823037


Processing Epoch 08: 100%|██████████| 1213/1213 [04:26<00:00,  4.55it/s, recent loss=3.325, avg loss=3.309, lr=[0.0006925286892003298]]


loss for this epoch is  3.309226773323298


Processing Epoch 09: 100%|██████████| 1213/1213 [04:26<00:00,  4.55it/s, recent loss=3.249, avg loss=3.003, lr=[0.0006156767138055679]]


loss for this epoch is  3.003005013434459


Processing Epoch 10: 100%|██████████| 1213/1213 [04:27<00:00,  4.54it/s, recent loss=2.609, avg loss=2.746, lr=[0.0005387613291901834]]


loss for this epoch is  2.746154033764583


Processing Epoch 11: 100%|██████████| 1213/1213 [04:24<00:00,  4.59it/s, recent loss=2.449, avg loss=2.522, lr=[0.0004618459445747987]]


loss for this epoch is  2.521883590021086


Processing Epoch 12: 100%|██████████| 1213/1213 [04:25<00:00,  4.57it/s, recent loss=2.180, avg loss=2.338, lr=[0.00038493055995941403]]


loss for this epoch is  2.3379703037428503


Processing Epoch 13: 100%|██████████| 1213/1213 [04:24<00:00,  4.59it/s, recent loss=2.207, avg loss=2.181, lr=[0.00030807858456465215]]


loss for this epoch is  2.1806383894931347


Processing Epoch 14: 100%|██████████| 1213/1213 [04:26<00:00,  4.55it/s, recent loss=1.929, avg loss=2.052, lr=[0.00023122660916989027]]


loss for this epoch is  2.0522377175789885


Processing Epoch 15: 100%|██████████| 1213/1213 [04:25<00:00,  4.56it/s, recent loss=1.857, avg loss=1.952, lr=[0.0001543112245545057]]


loss for this epoch is  1.9515366337248286


Processing Epoch 16: 100%|██████████| 1213/1213 [04:25<00:00,  4.57it/s, recent loss=1.955, avg loss=1.870, lr=[7.745924915974394e-05]]


loss for this epoch is  1.869763157611231


Processing Epoch 17: 100%|██████████| 1213/1213 [04:24<00:00,  4.58it/s, recent loss=1.846, avg loss=1.813, lr=[5.438645443592641e-07]]


loss for this epoch is  1.8128717045807563


In [8]:
lr

[0.0,
 [1.016325857519789e-05],
 [1.0326517150395779e-05],
 [1.0489775725593668e-05],
 [1.0653034300791558e-05],
 [1.0816292875989446e-05],
 [1.0979551451187336e-05],
 [1.1142810026385226e-05],
 [1.1306068601583114e-05],
 [1.1469327176781003e-05],
 [1.1632585751978893e-05],
 [1.1795844327176783e-05],
 [1.195910290237467e-05],
 [1.212236147757256e-05],
 [1.2285620052770449e-05],
 [1.2448878627968338e-05],
 [1.2612137203166228e-05],
 [1.2775395778364118e-05],
 [1.2938654353562006e-05],
 [1.3101912928759895e-05],
 [1.3265171503957783e-05],
 [1.3428430079155673e-05],
 [1.3591688654353563e-05],
 [1.3754947229551453e-05],
 [1.3918205804749342e-05],
 [1.408146437994723e-05],
 [1.4244722955145118e-05],
 [1.4407981530343008e-05],
 [1.4571240105540898e-05],
 [1.4734498680738788e-05],
 [1.4897757255936677e-05],
 [1.5061015831134565e-05],
 [1.5224274406332453e-05],
 [1.5387532981530343e-05],
 [1.5550791556728233e-05],
 [1.5714050131926123e-05],
 [1.587730870712401e-05],
 [1.6040567282321902e-05],
