## Train with PyTorch Lightning

In [None]:
import os
import os.path as osp

# If an "imagecaption" directory sits in the current working directory
# (presumably the repository checkout), move into it so the relative paths
# used by later cells resolve from there.
# isdir() is used rather than exists(): a regular file with that name would
# pass exists() and then make os.chdir() raise NotADirectoryError.
if osp.isdir("imagecaption"):
    os.chdir("imagecaption")

In [None]:
!pip install -r train/requirements.txt

In [None]:
import os.path as osp
import sys
# Append the absolute path of the local "train" directory to sys.path (only
# once) so that the imports below can find modules located there.
train_dir = osp.abspath("train")
if train_dir not in sys.path:
    sys.path.append(train_dir)

# NOTE(review): pl_module and fine_tune are presumably the project modules
# that live under train/ -- they must be imported after the sys.path update.
import pl_module
import pytorch_lightning as pl
from huggingface_hub import HfApi, notebook_login
from fine_tune import get_input_model_name

# Fix the global seed at 42 for reproducible runs. Per the Lightning docs,
# workers=True additionally derives seeds for dataloader worker processes.
pl.seed_everything(42, workers=True)


### Log in to Hugging Face

In [None]:

# Authenticate this notebook session with the Hugging Face Hub
# (huggingface_hub prompts for an access token).
notebook_login()

### Load the dataset and processor

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from pl_module import ImageCaptioningModule
from dataset import ImageCaptioningDataset
from transformers import GitProcessor, GitForCausalLM

from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# One Hub repo id is used both as the dataset name and as the output model name.
dataset_desc_only = "soul11zz/image-caption-desc-only"
out_model = dataset_desc_only
# NOTE(review): model_name is computed here but not referenced again in the
# visible cells; the processor and model are both loaded from out_model.
# Confirm that this is intended.
model_name = get_input_model_name(out_model)

# Fetch the training and validation splits, in that order.
dt_train, dt_val = (
    load_dataset(dataset_desc_only, split=split_name)
    for split_name in ("train", "validation")
)

processor = GitProcessor.from_pretrained(out_model)

# Wrap each raw split together with the processor in the project's dataset class.
train_dataset, val_dataset = (
    ImageCaptioningDataset(raw_split, processor)
    for raw_split in (dt_train, dt_val)
)


### Prelims for PL Module

In [None]:
# Training settings and the directory that receives model checkpoints.
batch_size = 4
model_dir = "tmp/model"
epochs = 10

# exist_ok=True creates the directory (and parents) only when it is missing,
# without a separate check-then-create step.
os.makedirs(model_dir, exist_ok=True)

# On Windows (os.name == "nt") data is loaded in the main process; elsewhere
# one worker per CPU is requested. os.cpu_count() may return None, which
# DataLoader cannot accept as num_workers, so fall back to 0 in that case.
num_workers = 0 if os.name == "nt" else (os.cpu_count() or 0)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)

model = GitForCausalLM.from_pretrained(out_model)
# Starts empty; the checkpoint callback is added further down in this cell.
callbacks = []

# The Lightning module receives the processor, the model and both loaders.
pl_train_module = ImageCaptioningModule(
    processor, model, train_loader, val_loader, learning_rate=1e-2
)

### Trainer
# TensorBoard event files go under tb_logs/image-captioning.
logger = TensorBoardLogger("tb_logs", name="image-captioning")

# Checkpoints are written to model_dir, ranked by the lowest "val_loss";
# only the two best are retained.
checkpoint = ModelCheckpoint(
    dirpath=model_dir,
    filename="imcap-{epoch:02d}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_top_k=2,
)

callbacks.append(checkpoint)


### Create PL Trainer

In [None]:
# Trainer for the actual fine-tuning run.
trainer = pl.Trainer(
    logger=logger,
    # accelerator/devices replace the deprecated `gpus=1` and match the device
    # selection already used by the learning-rate tuning Trainer in this notebook.
    accelerator="cuda",
    devices=1,
    callbacks=callbacks,
    max_epochs=epochs,
    # Run validation after every training epoch.
    check_val_every_n_epoch=1,
    # val_check_interval=50,
    precision=16,
    # Run a single validation batch as a sanity check before training starts.
    num_sanity_val_steps=1,
)


### (optional) Tune Learning Rate

In [None]:
# A separate single-epoch Trainer on one CUDA device, used only for the
# learning-rate search.
# NOTE(review): per the Lightning docs the suggested rate is written back to
# the module's `lr` / `learning_rate` attribute -- confirm that
# ImageCaptioningModule exposes one.
tuner = pl.Trainer(
    max_epochs=1,
    devices=1,
    accelerator="cuda",
    auto_lr_find=True,
)
tuner.tune(pl_train_module)

### Run Training

In [None]:
# Start training. No dataloaders are passed here; the module was constructed
# with train_loader and val_loader and presumably serves them itself.
trainer.fit(pl_train_module)