In [None]:
import os
import warnings
from pathlib import Path
from dotenv import load_dotenv
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from modules import Img2TextModel, FlickerDataModule
from pytorch_lightning.callbacks import ModelCheckpoint

from jlclient import jarvisclient
from jlclient.jarvisclient import *

# Load settings from ./project.env into the process environment.
dotenv_path = Path('./project.env')
load_dotenv(dotenv_path=dotenv_path)

# Suppress all Python warnings from this point on.
warnings.filterwarnings('ignore')

# Settings read from the environment. os.getenv() returns None when a
# variable is not set, so each of these may be None.
MODEL_ROOT_DIR = os.getenv('MODEL_ROOT_DIR')
ROOT_DIR = os.getenv('ROOT_DIR')
ANNOTATIONS = os.getenv('ANNOTATIONS')

# Credentials for the jarvisclient module come from the environment instead
# of being written into the notebook.
jarvisclient.token = os.getenv('JARVISLAB_TOKEN')
jarvisclient.user_id = os.getenv('JARVISLAB_USERID')

# A missing variable would otherwise only surface later as an obscure failure
# in whatever consumes the None value, so report unset/empty ones right away.
# This is informational only and does not stop the notebook.
_unset_settings = [
    name
    for name in (
        'MODEL_ROOT_DIR',
        'ROOT_DIR',
        'ANNOTATIONS',
        'JARVISLAB_TOKEN',
        'JARVISLAB_USERID',
    )
    if not os.getenv(name)
]
if _unset_settings:
    _unset_names = ', '.join(_unset_settings)
    print(f'unset environment variables (check {dotenv_path}): {_unset_names}')

In [None]:
# Fetch the remote instance and call resume() on it, then show its status.
# The instance id defaults to 64424 and can be overridden by setting the
# JARVISLAB_INSTANCE_ID environment variable (an empty value falls back to
# the default as well).
INSTANCE_ID = int(os.getenv('JARVISLAB_INSTANCE_ID') or '64424')
instance = User.get_instance(INSTANCE_ID)
instance.resume()
print(f'instance status: {instance.status}')

In [None]:
# Run the `nvidia-smi` shell command and show its output in the notebook.
# NOTE(review): presumably a check that a GPU is visible before training — confirm.
!nvidia-smi

In [None]:
# Weights & Biases logger for the "img2text" project.
logger = WandbLogger(project="img2text")

# Checkpoint callback: monitors 'val_loss' in "min" mode, writes into
# MODEL_ROOT_DIR and uses the model type as the checkpoint file name.
model_type = 'img2text_tiny_256x256_16p'
checkpoint_settings = dict(
    dirpath=MODEL_ROOT_DIR,
    filename=model_type,
    monitor='val_loss',
    mode="min",
)
checkpoint_callback = ModelCheckpoint(**checkpoint_settings)

# Data module configuration: dataset location, annotations and batch size.
datamodule_config = {
    'root_dir': ROOT_DIR,
    'annotations': ANNOTATIONS,
    'batch_size': 32,
}
datamodule = FlickerDataModule(datamodule_config)

# Model to be trained, configured with a single layer.
model_config = {'num_layers': 1}
model = Img2TextModel(model_config)

# Trainer: one GPU, at most 50 epochs, gradients clipped by value at 0.5,
# with the logger and checkpoint callback defined above attached.
trainer_options = dict(
    logger=logger,
    callbacks=[checkpoint_callback],
    accelerator='gpu',
    devices=1,
    max_epochs=50,
    gradient_clip_val=0.5,
    gradient_clip_algorithm='value',
)
trainer = Trainer(**trainer_options)

# Train the model, then pause the remote instance no matter how fit() ends.
# Pausing lives in `finally` so the instance is not left running when
# training completes normally or fails with an error; previously it was
# paused only when the run was interrupted from the keyboard.
# NOTE(review): some Lightning versions appear to handle Ctrl-C inside fit()
# themselves, in which case the `except` branch below never runs — confirm
# against the installed version.
try:
    trainer.fit(model, datamodule)
except KeyboardInterrupt:
    # A manual stop is an expected way to end a run, so it is not re-raised.
    pass
finally:
    instance.pause()
    print(f'instance status: {instance.status}')