In [None]:
!pip install transformers==4.24 cache_decorator pytorch_lightning==1.6.3 torchmetrics==0.7.0

In [None]:
# Mount Google Drive at /content/drive; the project code, datasets and
# checkpoints referenced by later cells all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import sys

# Directory on Drive added to the import path; presumably it holds the
# project modules (event_models, event_models_utils) imported by later cells.
PROJECT_DIR = '/content/drive/MyDrive/history'

# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [None]:
import os
from event_models_utils import (
    get_rams_data_dict, load_rams_data, 
)

from event_models import (
    EventGenModelWrapper, EventGenModel,
    RAMSEventGenDataModule, collate_RAMS, collate_argument_RAMS
)
from transformers import (
    BartModel, BartTokenizer,
)

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Fix the global random seed (42) so runs are repeatable; workers=True asks
# Lightning to derive seeds for dataloader worker processes as well.
seed_everything(42, workers=True)


In [None]:
# Build the RAMS data module from the dataset directory on Drive, then run
# its prepare/setup hooks immediately so the data is ready before training.
dm = RAMSEventGenDataModule(
    data_dir="/content/drive/MyDrive/history/datasets/rams",
    batch_size=1,
    num_workers=0,
    pin_memory=False,
)
dm.prepare_data()
dm.setup()

In [None]:
# Make sure the checkpoint root on Drive exists. exist_ok=True makes the call
# a no-op when the directory is already there, which keeps the cell safe to
# re-run and avoids the check-then-create race of a separate exists() test.
os.makedirs("/content/drive/MyDrive/history/checkpoints/", exist_ok=True)

In [None]:
# Upper bound on the number of training epochs; handed to Trainer(max_epochs=...).
max_epochs = 10

In [None]:
# Checkpoint policy: evaluated every epoch, keeping the two checkpoints with
# the lowest "valid_loss" under the event_gen directory on Drive.
# NOTE(review): assumes the model logs a metric named "valid_loss" — the
# logging code is not visible in this notebook; confirm in event_models.
event_checkpoint_callback = ModelCheckpoint(
    monitor="valid_loss",
    mode="min",
    save_top_k=2,
    every_n_epochs=1,
    dirpath="/content/drive/MyDrive/history/checkpoints/event_gen",
)

In [None]:
# Load the pretrained BART backbone and its matching tokenizer, then hand
# both to the event-generation wrapper that the Trainer will optimise.
pretrained_name = "facebook/bart-base"
bart = BartModel.from_pretrained(pretrained_name)
bart_tokenizer = BartTokenizer.from_pretrained(pretrained_name)
model = EventGenModelWrapper(bart=bart, bart_tokenizer=bart_tokenizer)

# TensorBoard logger rooted at "logs" with run name "event_gen".
logger = TensorBoardLogger("logs", name="event_gen")

In [None]:
# Trainer used for fitting: one GPU, 16-bit precision, deterministic ops,
# gradient clipping at 1, and the checkpoint callback / logger defined above.
trainer = Trainer(
    # hardware and numerics
    gpus=1,
    precision=16,
    deterministic=True,
    # training schedule
    max_epochs=max_epochs,
    gradient_clip_val=1,
    num_sanity_val_steps=0,  # skip the validation sanity check before training
    # bookkeeping
    logger=logger,
    callbacks=[event_checkpoint_callback],
    # To continue an interrupted run, uncomment the line below:
    # resume_from_checkpoint="/content/drive/MyDrive/history/checkpoints/event_gen/epoch=0-step=7328.ckpt"
)

In [None]:
# Run training: fit the wrapped model on the RAMS data module with the
# trainer configured in the previous cell.
trainer.fit(model, dm)

In [None]:
# Evaluate a saved checkpoint on the test split.
# This cell rebuilds the data module, backbone, tokenizer, logger and trainer,
# but still relies on `max_epochs` and `event_checkpoint_callback` defined in
# earlier cells.

# Single source of truth for the checkpoint under evaluation (previously the
# same literal was repeated three times in this cell).
TEST_CKPT_PATH = "/content/drive/MyDrive/history/checkpoints/event_gen/epoch=1-step=14657.ckpt"

dm = RAMSEventGenDataModule(
    batch_size=1,
    num_workers=0,
    data_dir="/content/drive/MyDrive/history/datasets/rams",
    pin_memory=False
)
bart = BartModel.from_pretrained("facebook/bart-base")
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

# load_from_checkpoint is a classmethod on Lightning modules, so call it on
# the class directly instead of first constructing a throwaway wrapper
# instance. bart / bart_tokenizer are forwarded to the wrapper's constructor.
# NOTE(review): assumes EventGenModelWrapper is a LightningModule — it is
# passed to Trainer.fit/test in this notebook, but its definition is not shown.
model = EventGenModelWrapper.load_from_checkpoint(
    TEST_CKPT_PATH,
    bart=bart,
    bart_tokenizer=bart_tokenizer
)
logger = TensorBoardLogger(
    "logs", name="event_gen"
)

# `resume_from_checkpoint` is intentionally not passed here: it is a
# deprecated Trainer argument that only affects `fit`; the checkpoint for
# evaluation is supplied through `ckpt_path` in `trainer.test` below.
trainer = Trainer(
    max_epochs=max_epochs,
    deterministic=True,
    gpus=1,
    precision=16,
    gradient_clip_val=1,
    logger=logger,
    callbacks=[event_checkpoint_callback],
    num_sanity_val_steps=0,
)
trainer.test(
    model,
    dm,
    ckpt_path=TEST_CKPT_PATH
)