In [None]:
# Reload edited modules (e.g. the local `fof` package) before each cell runs,
# so source changes show up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Uncomment and give a device list to restrict which GPUs this kernel can see;
# it only takes effect if run before CUDA is initialized.
# %env CUDA_VISIBLE_DEVICES=

In [None]:
import os
import sys

# Make the repo root (parent of the current working directory, which Jupyter
# sets to the notebook's folder by default) importable so `fof` resolves.
# Use an absolute path so a later os.chdir cannot break imports, and guard
# against appending a duplicate entry every time this cell is re-run.
_REPO_ROOT = os.path.abspath("..")
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

In [None]:
import os
from pathlib import Path
from typing import Optional

import wandb

from fof.dataloader import ScicapDataModule
from fof.encdec import EncoderDecoderModel

In [None]:
# W&B run through which `test_model` downloads model-checkpoint artifacts.
WANDB_PROJECT = "figuring-out-figures"
run = wandb.init(project=WANDB_PROJECT)

In [None]:
import pytorch_lightning as pl

# Fix the global RNG seed so stochastic decoding (e.g. `use_top_p_sampling=True`
# further down) yields reproducible metrics under Restart & Run All. Cells that
# are re-run by hand still draw from the RNG stream in execution order.
SEED = 42
pl.seed_everything(SEED)

# Shared by every `test_model` call below.
datamodule = ScicapDataModule(
    "First-Sentence",
    batch_size=16,
    tokenizer=None,
    num_workers=32,
    root=Path("../scicap_data"),
    caption_type="orig",
)

# NOTE(review): `gpus=` is deprecated in newer PyTorch Lightning releases (they
# use `accelerator="gpu", devices=1`); kept as-is because the project's
# Lightning version is not visible here.
trainer = pl.Trainer(gpus=1)

In [None]:
def test_model(artifact_str: str = None, ckpt_path: str = None, use_test = False, **kwargs):
    if artifact_str is not None:
        artifact = run.use_artifact(artifact_str, type="model")
        artifact_dir = artifact.download()
        ckpt_path = Path(artifact_dir) / "model.ckpt"
    model = EncoderDecoderModel.load_from_checkpoint(ckpt_path, **kwargs)
    if use_test:
        trainer.test(model, datamodule=datamodule)
    else:
        trainer.validate(model, datamodule=datamodule)

In [None]:
# Text features only, best model.
# This checkpoint lives on local disk rather than in W&B; set FOF_EVAL_CKPT_DIR
# to point at its directory when running on a different machine.
EVAL_CKPT_DIR = Path(
    os.environ.get("FOF_EVAL_CKPT_DIR", "/data/kevin/arxiv/evaluation_checkpoints")
)
test_model(ckpt_path=str(EVAL_CKPT_DIR / "text-features-only.ckpt"), use_test=True)

In [None]:
# Image features (CLIP) only, original captions, best model.
test_model(
    artifact_str="figuring-out-figures/figuring-out-figures/model-yw5qm3wp:v16",
    use_test=True,
)


In [None]:
# Image features (CLIP), DistilGPT, normalized captions, SciBERT encoder for
# references, title, and abstract; best model, with `use_top_p_sampling` on.
test_model(
    artifact_str="figuring-out-figures/figuring-out-figures/model-1d7mntmw:v2",
    use_test=True,
    use_top_p_sampling=True,
)


In [None]:
# Image features (CLIP), DistilGPT, original captions, SciBERT encoder for
# references, title, and abstract; best model, with `use_top_p_sampling` off.
test_model(
    artifact_str="figuring-out-figures/figuring-out-figures/model-27py12gz:v2",
    use_test=True,
    use_top_p_sampling=False,
)


In [None]:
# DistilGPT, top-p captions (artifact version v1 of run 27py12gz).
# NOTE(review): no `use_top_p_sampling` override is passed here, so decoding
# presumably follows the checkpoint's saved hyperparameters — confirm those
# actually enable top-p sampling.
test_model(
    artifact_str="figuring-out-figures/figuring-out-figures/model-27py12gz:v1",
    use_test=True,
)