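### Training Pipeline

Fine-tune `MyLightningClassifierModel` on the tab-separated training file at `cnfg.train_path`, holding out a validation split and keeping the checkpoint with the best validation F1 (`val_f1`).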

In [None]:
import lightning as pl
import pandas as pd
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split

from en_grammar_checker.config import Config
from en_grammar_checker.datasets import get_train_data_loader, get_val_data_loader
from en_grammar_checker.trainer import MyLightningClassifierModel

In [None]:
# Project configuration object. Later cells read the data path, split ratio,
# seed, early-stopping patience, logs path and experiment name from it, and it
# is also handed to the data-loader factories and the Lightning module.
cnfg = Config()

### Read Data

In [None]:
# The training file is tab-separated with no header row, so the four column
# names are supplied explicitly in file order. `sep` is the canonical pandas
# spelling of the `delimiter` alias.
df = pd.read_csv(
    str(cnfg.train_path),
    sep="\t",
    header=None,
    names=["sentence_source", "label", "label_notes", "sentence"],
)

#### Create Train/Validation DataFrames

In [None]:
# Hold out a `cnfg.train_val_split` fraction of the rows for validation,
# seeding the shuffle with `cnfg.seed`. Each part then gets a fresh 0..n-1
# index, and the shuffled original index is dropped.
train_df, val_df = (
    part.reset_index(drop=True)
    for part in train_test_split(
        df, test_size=cnfg.train_val_split, random_state=cnfg.seed
    )
)

In [None]:
# (rows, columns) for the train and validation frames, shown as the cell output.
tuple(frame.shape for frame in (train_df, val_df))

((7268, 4), (1283, 4))

#### Get DataLoaders

In [None]:
# Build one DataLoader per split (train first, then validation). Both factories
# receive the shared config, presumably for tokenizer / batch-size settings;
# confirm in en_grammar_checker.datasets.
train_dataloader, val_dataloader = (
    get_train_data_loader(cnfg, train_df),
    get_val_data_loader(cnfg, val_df),
)



#### Lightning Model

In [None]:
# Lightning module passed to `trainer.fit` below, built from the same config
# object as the data loaders. The model summary printed during fit reports an
# `EnDeBertaClassifier` with ~1.2M trainable and ~434M non-trainable params.
task = MyLightningClassifierModel(cnfg)

#### Training

In [None]:
# Keep only the single best checkpoint by validation F1.
#
# No step/epoch trigger is passed on purpose. Lightning then defaults to
# every_n_epochs=1 and evaluates the save condition right after each
# validation run. As a result:
# - the saved weights are exactly the ones that produced the logged val_f1;
# - `{epoch}` in the filename is that validation epoch.
#
# The previous `every_n_train_steps=1` disabled the validation-end save and
# instead checked on every training batch. That stored weights one optimizer
# step past the evaluated ones, mislabeled the epoch, and could never save the
# result of the final validation run.
#
# `dirpath` is left unset, so Lightning places checkpoints under the Trainer's
# `default_root_dir` (plus logger name/version when a logger is active).
checkpoint_callback = ModelCheckpoint(
    monitor="val_f1",
    filename="{epoch}-{val_f1:.4f}",
    mode="max",
    save_top_k=1,
)

In [None]:
# Upper bound on training length. EarlyStopping normally ends the run sooner,
# once val_f1 has not improved for `cnfg.early_stopping_rounds` consecutive
# validation checks.
MAX_EPOCHS = 100

trainer = pl.Trainer(
    # "auto" selects the CUDA GPU when one is available (as in the run below)
    # and falls back to CPU instead of raising on machines without a GPU.
    accelerator="auto",
    default_root_dir=f"{cnfg.training_logs_path}/{cnfg.experiment_name}/",
    callbacks=[
        EarlyStopping(
            monitor="val_f1",
            mode="max",
            patience=cnfg.early_stopping_rounds,
        ),
        checkpoint_callback,
    ],
    max_epochs=MAX_EPOCHS,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/rohit/Desktop/rohit/virtualenvs/rohit_transformers/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
# Run the fit loop. The validation pass each epoch logs val_f1, which drives
# the EarlyStopping and ModelCheckpoint callbacks attached to `trainer`.
trainer.fit(
    model=task,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: ../training_logs/test_run1/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                | Params
----------------------------------------------
0 | model | EnDeBertaClassifier | 435 M 
----------------------------------------------
1.2 M     Trainable params
434 M     Non-trainable params
435 M     Total params
1,740.773 Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

	 Epoch: 0, Val F1: 0.4434782608695652, Val Precision: 0.5, Val Recall 0.3984375


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 0, Val F1: 0.7267668000707201, Val Precision: 0.7637950652772787, Val Recall 0.7208686626586371


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 1, Val F1: 0.7591337289215193, Val Precision: 0.7364923942987185, Val Recall 0.8086419270629797


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 2, Val F1: 0.7775745882616112, Val Precision: 0.7619789795185052, Val Recall 0.8023461150353179


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 3, Val F1: 0.7798964717101337, Val Precision: 0.7624850281470835, Val Recall 0.8089691804480537


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 4, Val F1: 0.7792233635218176, Val Precision: 0.7594023236315727, Val Recall 0.8150273462500728


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 5, Val F1: 0.7887168210989893, Val Precision: 0.7766244460414421, Val Recall 0.805603522422029


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 6, Val F1: 0.7836742786887263, Val Precision: 0.7941460055096419, Val Recall 0.7758873454623071


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 7, Val F1: 0.7511085245538189, Val Precision: 0.7248173433944185, Val Recall 0.8222096336732241


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 8, Val F1: 0.7883486001883612, Val Precision: 0.7878683075817463, Val Recall 0.7888352431484007


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 9, Val F1: 0.7914115084053196, Val Precision: 0.8026634926338484, Val Recall 0.7831302521008403


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 10, Val F1: 0.7813160330618952, Val Precision: 0.7613232123607618, Val Recall 0.8173601413859313


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 11, Val F1: 0.7840782431164819, Val Precision: 0.8131857108635765, Val Recall 0.77204639210117


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 12, Val F1: 0.7926495452710582, Val Precision: 0.8029165169481376, Val Recall 0.7848702959033105


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 13, Val F1: 0.7918695446896602, Val Precision: 0.7731764283147682, Val Recall 0.8231874508470254


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 14, Val F1: 0.7937334885984427, Val Precision: 0.8042939274164571, Val Recall 0.785784189656627


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 15, Val F1: 0.7982284812727337, Val Precision: 0.8086791831357049, Val Recall 0.7902896518320774


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 16, Val F1: 0.7501993678904344, Val Precision: 0.7208632770391663, Val Recall 0.8433389191453707


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 17, Val F1: 0.8029775077097241, Val Precision: 0.8057746436699006, Val Recall 0.8003726843272563


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 18, Val F1: 0.7872118946714839, Val Precision: 0.7625883339322075, Val Recall 0.837528674580081


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 19, Val F1: 0.8003338297041305, Val Precision: 0.7948960953407593, Val Recall 0.806590146378517


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 20, Val F1: 0.7994412084045086, Val Precision: 0.8190142532039766, Val Recall 0.787854962038313


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 21, Val F1: 0.7767716572189032, Val Precision: 0.7501542100850401, Val Recall 0.838098764832532


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 22, Val F1: 0.7808002517068022, Val Precision: 0.7542864414899988, Val Recall 0.8403449077413874


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 23, Val F1: 0.8037888780562406, Val Precision: 0.7872409869445443, Val Recall 0.8291257805530776


Validation: |                                             | 0/? [00:00<?, ?it/s]

	 Epoch: 24, Val F1: 0.7950479233226837, Val Precision: 0.7719771828961552, Val Recall 0.8385875643032878


In [None]:
# Optional learning-rate range test with Lightning's Tuner.
#
# It is off by default so "Run All" does not trigger an extra training sweep;
# flip the flag to run it. Because the call now lives in real code rather than
# a comment, it stays syntax-checked and visible.
#
# NOTE(review): `trainer` has already been fit by the previous cell, so a fresh
# Trainer/model is probably needed for a meaningful sweep; confirm.
# lr_find (update_attr=True by default) also looks for an `lr` /
# `learning_rate` attribute on the model; verify that MyLightningClassifierModel
# exposes one.
RUN_LR_FINDER = False

if RUN_LR_FINDER:
    from lightning.pytorch.tuner.tuning import Tuner

    lr_finder = Tuner(trainer).lr_find(
        model=task,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        min_lr=1e-08,
        max_lr=0.1,
        num_training=1000,
    )